رگرسیون خطی (Linear Regression)
رگرسیون (Regression) یعنی پیدا کردن رابطه بین متغیرها (Variables). در یادگیری ماشین، این رابطه برای پیش بینی آینده استفاده می شود. رگرسیون خطی یعنی یک خط صاف از میان نقطه ها رد می کنیم تا مقدارهای آینده حدس زده شوند.
رگرسیون خطی چیست؟
در رگرسیون خطی (Linear Regression) با استفاده از رابطه نقطه ها یک خط رسم می کنیم. سپس با همان خط، مقدارهای جدید پیش بینی می شوند. این دقیقاً مثل حدس زدن نمره با توجه به روند قبلی است.
1) رسم نمودار پراکندگی
اول داده ها را نقطه ای رسم می کنیم. محور x سن خودرو است. محور y سرعت آن است. این کار الگو را به ما نشان می دهد.
import matplotlib.pyplot as plt
x = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
y = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]
plt.scatter(x, y)
plt.show()
2) کشیدن خط رگرسیون
با SciPy شیب (Slope) و عرض از مبدأ (Intercept) را حساب می کنیم. سپس همان خط را روی نمودار می کشیم.
import matplotlib.pyplot as plt
from scipy import stats
x = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
y = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(v):
return slope * v + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()
3) بررسی ضریب رابطه r
ضریب همبستگی r از -1 تا 1 است. صفر یعنی بی ربط. نزدیک 1 یا -1 یعنی رابطه قوی. اگر r ضعیف باشد، پیش بینی با خط خوب نمی شود.
from scipy import stats
x = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
y = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
print(r)
نکته: مقدار حدود -0.76 یعنی رابطه وجود دارد، اما کامل نیست. با این حال، می توان برای پیش بینی از خط استفاده کرد.
4) پیش بینی مقدار جدید
حالا سرعت یک خودرو 10 ساله را حدس بزنیم. کافی است x را 10 بگذاریم و تابع را صدا بزنیم.
from scipy import stats
x = [5, 7, 8, 7, 2, 17, 2, 9, 4, 11, 12, 9, 6]
y = [99, 86, 87, 88, 111, 86, 103, 87, 94, 78, 77, 85, 86]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(v):
return slope * v + intercept
speed = myfunc(10)
print(speed)
نکته: نتیجه حدود 85.6 است. روی نمودار هم قابل مشاهده است.
5) وقتی داده ها خطی نیستند
اگر داده ها الگوی خطی نداشته باشند، خط بد می افتد. در این حالت r خیلی کوچک می شود.
import matplotlib.pyplot as plt
from scipy import stats
x = [89, 43, 36, 36, 95, 10, 66, 34, 38, 20, 26, 29, 48, 64, 6, 5, 36, 66, 72, 40]
y = [21, 46, 3, 35, 67, 95, 53, 72, 58, 10, 26, 34, 90, 33, 38, 20, 56, 2, 47, 15]
slope, intercept, r, p, std_err = stats.linregress(x, y)
def myfunc(v):
return slope * v + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()
بیایید r را هم چاپ کنیم. اگر تقریباً صفر باشد، خط مناسب نیست.
from scipy import stats
x = [89, 43, 36, 36, 95, 10, 66, 34, 38, 20, 26, 29, 48, 64, 6, 5, 36, 66, 72, 40]
y = [21, 46, 3, 35, 67, 95, 53, 72, 58, 10, 26, 34, 90, 33, 38, 20, 56, 2, 47, 15]
slope, intercept, r, p, std_err = stats.linregress(x, y)
print(r)
نکته: مقدار حدود 0.013 یعنی رابطه بسیار ضعیف است. پس این داده برای رگرسیون خطی مناسب نیست.
جمع بندی سریع
- رگرسیون خطی یعنی کشیدن یک خط برای پیش بینی.
- اول scatter بکش و الگو را ببین.
- با linregress شیب و عرض را بگیر.
- اگر r نزدیک صفر است، روش مناسب نیست.
- برای x جدید، y را حدس بزن.