前言
近期开始做关于预测模型和最优化解模型的大创,以前做数模也了解过一些,觉得是时候系统的学习学习了~身为一名才大二的超级小白,温故而知新是很有必要的,于是想在CSDN上记录学习成果!欢迎大家来和我一起讨论!
建筑模型
从最简单的计划模型开始。下面给出一组数据。
xs=np.array([0.5,0.6,0.8,1.1,1.4]) ys=np.array([5.0,5.5,6.0,6.8,7.0])
最简单的预测模型:y=w0+w1x。根据已知的模型找到w0,w1,尽可能精确的描述输入和输出的关系。
单样本误差:求出y’,单样本误差为
总样本误差:
损失函数(loss):即为总样本误差。找到w0和w1使得loss最小。这是一个三维数学模型
下面给出了相应的求解代码~并且画出了w0,w1,loss的三维图形
import numpy as np
import matplotlib.pyplot as mp
from mpl_toolkits.mplot3d import Axes3D
xs=np.array([0.5,0.6,0.8,1.1,1.4])
ys=np.array([5.0,5.5,6.0,6.8,7.0])
n=500 #把w0,w1分割500
w0_grid,w1_grid=np.meshgrid(np.linspace(-3,10,n),
np.linspace(-3,10,n))
for x,y in zip(xs,ys):
loss=(w0_grid+w1_grid*x-y)**2/2
#画图 画出每个w0,w1对应的loss,三维图
mp.figure('Loss Function',facecolor='lightgray')
ax3d = mp.gca(projection="3d")
ax3d.set_xlabel('w0')
ax3d.set_ylabel('w1')
ax3d.set_zlabel('loss')
ax3d.plot_surface(w0_grid,w1_grid,loss,
cstride=30,rstride=30,cmap='jet')
mp.show()
生成的3D图形如下
模型求解
得到了w0,w1,loss的关系,需要找到最小的loss对应的w0,w1输出。这里采用梯度下降公式。
我们先用二维求解(w0,y)举例:
梯度下降。先随便找一个点(w0,y),之后向最低点前进。
w1=w0±Lrate*y’ 。Lrate是学习率,y’是导数,w1是w的二次取值。当导数较大时,说明离最低点较远,于是较大,当导数较小时较小。Lrate自己取,可调。取较高的迭代次数可以向最低值无限接近。
对于三维,上述导数变为偏导数。可以推导出对应偏导数的公式(附代码)
#找到loss最低点对应的w0,w1,即为所求w0,w1
train_x=np.array([0.5,0.6,0.8,1.1,1.4])
train_y=np.array([5.0,5.5,6.0,6.8,7.0])
w0,w1=1,1 #随便赋初值
times=1000 #迭代1000次
lrate=0.01
for i in range(1,times+1):
#求偏导
d0=(w0+w1*train_x-train_y).sum()
d1=(w1*train_x**2+train_x*w0-train_x*train_y).sum()
#根据梯度下降公式
w0=w0-lrate*d0
w1=w1-lrate*d1
print('w0:',w0)
print('w1:',w1)
最后我们得到了w0,w1
w0: 4.065692318299849
w1: 2.2634176028710415
查看最终模型
最后,我们绘制散点图和回归线,我们只需要在进行预测时带入。
#通过w0,w1,模型参数画出回归线
linex=np.linspace(
train_x.min(),train_x.max(),100)
liney=w1*linex+w0
#画出x,y散点分布图及其回归线
mp.figure('Linear Regression',facecolor='lightgray')
mp.title('Linear Regression',fontsize=18)
mp.grid(linestyle=":")
mp.scatter(train_x,train_y,s=80,marker='o',
color='dodgerblue',label='Samples')
mp.plot(linex,liney,color='orangered',
linewidth=2,label='Regression Line')
mp.legend()
mp.show()
结语
在做DAISO的时候,输入输出数据不只是二维的,而且有上百组数据,所以需要一个更准确更合适的模型,加油!
文章出处登录后可见!
已经登录?立即刷新