Python – جستجوی شبکه‌ای (Grid Search)

جستجوی شبکه چیست؟

مدل‌های یادگیری ماشین معمولاً شامل پارامترهایی هستند که می‌توان آنها را تنظیم کرد تا نحوه یادگیری مدل تغییر کند. برای مثال، مدل رگرسیون لجستیک از sklearn دارای پارامتری به نام C است که کنترل‌کننده‌ی منظم‌سازی (regularization) است و بر پیچیدگی مدل تأثیر می‌گذارد.

چگونه بهترین مقدار برای C را انتخاب کنیم؟ بهترین مقدار به داده‌هایی که برای آموزش مدل استفاده می‌شود، بستگی دارد.

نحوه عملکرد

یک روش برای انتخاب بهترین مقدار این است که مقادیر مختلف را امتحان کنیم و سپس مقداری که بهترین امتیاز را می‌دهد، انتخاب کنیم. این تکنیک به نام جستجوی شبکه (Grid Search) شناخته می‌شود. اگر بخواهیم مقادیر دو یا بیشتر پارامتر را انتخاب کنیم، تمام ترکیب‌های مقادیر را ارزیابی کرده و به این ترتیب یک شبکه از مقادیر ایجاد می‌کنیم.

قبل از ورود به مثال، خوب است که بدانیم پارامتری که تغییر می‌دهیم چه کاری انجام می‌دهد. مقادیر بالاتر C به مدل می‌گویند که داده‌های آموزشی بیشتر شبیه اطلاعات واقعی هستند و باید وزن بیشتری به داده‌های آموزشی بدهند. در حالی که مقادیر پایین‌تر C عکس این عمل را انجام می‌دهند.

استفاده از پارامترهای پیش‌فرض

اولاً بیایید ببینیم با استفاده از پارامترهای پایه بدون جستجوی شبکه چه نتایجی می‌توانیم تولید کنیم.

برای شروع، باید ابتدا مجموعه داده‌ای را که با آن کار خواهیم کرد بارگیری کنیم:

from sklearn import datasets
iris = datasets.load_iris()

برای ایجاد مدل، باید مجموعه‌ای از متغیرهای مستقل X و متغیر وابسته y داشته باشیم:

X = iris['data']
y = iris['target']

حالا مدل لجستیک را برای طبقه‌بندی گل‌های زنبق بارگیری می‌کنیم:

from sklearn.linear_model import LogisticRegression

مدل را ایجاد کرده و max_iter را به یک مقدار بالاتر تنظیم می‌کنیم تا اطمینان حاصل کنیم که مدل نتیجه‌ای پیدا کند. توجه داشته باشید که مقدار پیش‌فرض C در مدل رگرسیون لجستیک برابر با 1 است، که این مقدار را بعداً مقایسه خواهیم کرد.

در مثال زیر، داده‌های گل زنبق را بررسی کرده و سعی می‌کنیم مدلی را با مقادیر مختلف برای C در رگرسیون لجستیک آموزش دهیم:

logit = LogisticRegression(max_iter=10000)
print(logit.fit(X, y))
print(logit.score(X, y))

نتیجه:

با تنظیم پیش‌فرض C = 1، امتیازی برابر با 0.973 به دست آوردیم.

پیاده‌سازی جستجوی شبکه

ما همان مراحل قبل را دنبال خواهیم کرد، با این تفاوت که این بار یک بازه از مقادیر برای C تنظیم خواهیم کرد.

انتخاب مقادیر برای پارامترهای جستجو ترکیبی از دانش دامنه و تمرین خواهد بود.

از آنجایی که مقدار پیش‌فرض C برابر با 1 است، بازه‌ای از مقادیر اطراف آن را تنظیم خواهیم کرد:

C = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]

سپس یک حلقه for ایجاد خواهیم کرد تا مقادیر C را تغییر داده و مدل را با هر تغییر ارزیابی کنیم.

ابتدا لیستی خالی برای ذخیره امتیازات ایجاد خواهیم کرد:

scores = []

برای تغییر مقادیر C باید بر روی بازه مقادیر حلقه بزنیم و هر بار پارامتر را به‌روزرسانی کنیم:

for choice in C:
  logit.set_params(C=choice)
  logit.fit(X, y)
  scores.append(logit.score(X, y))

با ذخیره امتیازات در لیست، می‌توانیم بهترین انتخاب برای C را ارزیابی کنیم.

مثال:

from sklearn import datasets
from sklearn.linear_model import LogisticRegression

iris = datasets.load_iris()
X = iris['data']
y = iris['target']

logit = LogisticRegression(max_iter=10000)
C = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]
scores = []

for choice in C:
  logit.set_params(C=choice)
  logit.fit(X, y)
  scores.append(logit.score(X, y))

print(scores)

نتایج توضیح داده شده:

  • مقادیر پایین‌تر C عملکرد بدتری نسبت به پارامتر پایه 1 داشتند. با این حال، با افزایش مقدار C به 1.75، دقت مدل افزایش یافت.
  • به نظر می‌رسد که افزایش C بیشتر از این مقدار به افزایش دقت مدل کمک نمی‌کند.

نکات در مورد بهترین شیوه‌ها

ما مدل رگرسیون لجستیک خود را با استفاده از همان داده‌هایی که برای آموزش آن استفاده شده‌اند، ارزیابی کردیم. اگر مدل بیش از حد به آن داده‌ها تطابق داشته باشد، ممکن است برای پیش‌بینی داده‌های نادیده کارایی خوبی نداشته باشد. این خطای آماری به نام overfitting شناخته می‌شود.

برای جلوگیری از گمراه شدن توسط امتیازات روی داده‌های آموزشی، می‌توانیم بخشی از داده‌های خود را کنار بگذاریم و به‌طور خاص از آن برای تست مدل استفاده کنیم. به یاد داشته باشید که به سخنرانی درباره تقسیم داده‌های آموزشی/تستی مراجعه کنید تا از گمراه شدن و overfitting جلوگیری کنید.

پست های مرتبط

مطالعه این پست ها رو از دست ندین!
Python - محدود کردن داده (MongoDB Limit)

Python – محدود کردن داده (MongoDB Limit)

Python MongoDB محدود کردن نتایج برای محدود کردن نتایج در MongoDB، از متد limit() استفاده می‌کنیم. متد limit() یک...

بیشتر بخوانید
Python - بروزرسانی (MongoDB Update)

Python – بروزرسانی (MongoDB Update)

به‌روزرسانی یک رکورد برای به‌روزرسانی یک رکورد یا سند در MongoDB، از متد update_one() استفاده می‌کنیم. پارامتر اول متد...

بیشتر بخوانید
Python - حذف کالکشن (MongoDB Drop Collection)

Python – حذف کالکشن (MongoDB Drop Collection)

حذف کالکشن شما می‌توانید یک جدول یا کالکشن در MongoDB را با استفاده از متد drop() حذف کنید. مثالحذف...

بیشتر بخوانید

نظرات

سوالات و نظراتتون رو با ما به اشتراک بذارید

برای ارسال نظر لطفا ابتدا وارد حساب کاربری خود شوید.