SGD简介

SGD(Stochastic Gradient Descent),译为随机梯度下降,是深度学习中的常用的函数优化方法。

1.引例

在介绍SGD简介之前首先来引入一个例子,有三个人在山顶上正在思考如何快速的下山,老大,老二和老三分别提出了三个不同的观点。

  • 老大说:从山顶出发,每走一段路程,就寻找附近所有的山路,挑选最陡峭的山路继续前进,顾名思义,老大总是挑最陡峭的山路来走。

  • 老二说:从山顶出发,每走一段路程,就随机地寻找附近部分的山路,挑选最陡峭的山路继续前进,顾名思义,老二随机的寻找部分山路,然后走最陡峭的。

  • 老三说:从山顶出发,直接随机的挑选山路走,直到到达山脚。

老大的走法虽然每条路都是最优,但是在寻找最陡的山路的过程中会耗费大量的时间。

老三的走法较为随意,每次走的路有可能最优,可能最劣。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Vp2WZCwN-1676107686927)(/image_editor_upload/20220730/20220730050337_83545.png)]

那么你认为最先到达山脚呢?在学完SGD简介之后,你就会得到答案。

2.SGD介绍

2.1引入问题

给你一个SGD简介坐标系,上面有一些点,给你过原点的一条直线SGD简介,如何用最快的方法来拟合这些点?
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OOYBhgNT-1676107686929)(/image_editor_upload/20220730/20220730050742_20671.png)]

为了解决这个问题,我们要对问题定义一个目标,即让所有的点离直线的偏差最小。我们常用的误差函数为均方误差,对于一个点SGD简介来说,它与直线的均方误差可以定义为SGD简介
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-c26Kyylh-1676107686933)(/image_editor_upload/20220730/20220730050907_67003.png)]

SGD简介
完全平方展开:
SGD简介SGD简介
同理,点SGD简介SGD简介SGD简介SGD简介都是如此:
SGD简介SGD简介SGD简介
而我们最终的误差SGD简介
通过合并同类项:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TEdFVq6M-1676107686938)(/image_editor_upload/20220730/20220730033337_23748.png)]

因为SGD简介,所以SGD简介,所以SGD简介是一个向上的抛物线。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KEOyQ3GV-1676107686945)(/image_editor_upload/20220730/20220730035607_63923.png)]

2.2SGD的计算步骤

回到刚刚爬山那个问题,通过大量数据实验得知,老二的SGD简介方法能最快到达山脚。

3.SGD的代码实现

from sklearn.linear_model import SGDRegressor
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
X_scaler = StandardScaler()
y_scaler = StandardScaler()
X = [[50],[100],[150],[200],[250],[300],[50],[100],[150],[200],[250],[300],[50],[100],[150],[200],[250],[300],[50],[100],[150],[200],[250],[300],[50],[100],[150],[200],[250],[300],[50],[100],[150],[200],[250],[300],[50],[100],[150],[200],[250],[300],[50],[100],[150],[200],[250],[300]]
y = [[150],[200],[250],[280],[310],[330],[150],[200],[250],[280],[310],[330],[150],[200],[250],[280],[310],[330],[150],[200],[250],[280],[310],[330],[150],[200],[250],[280],[310],[330],[150],[200],[250],[280],[310],[330],[150],[200],[250],[280],[310],[330],[150],[200],[250],[280],[310],[330]]
#plt.show()
X = X_scaler.fit_transform(X) #用什么方法标准化数据?
y = y_scaler.fit_transform(y)
X_test = [[40],[400]] # 用来做最终效果测试
X_test = X_scaler.transform(X_test) 
model = SGDRegressor()
model.fit(X, y.ravel())
y_result = model.predict(X_test)
plt.title('single variable')
plt.xlabel('x')
plt.ylabel('y')
plt.grid(True)
plt.plot(X, y, 'k.')
plt.plot(X_test, y_result, 'g-')
plt.show()

结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DvmfgnmY-1676107686960)(/image_editor_upload/20220730/20220730050240_27787.png)]

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年6月25日
下一篇 2023年6月25日

相关推荐