foreword
本文介绍多项式回归。
1. 一个例子
线性回归可以很好地拟合线性分布的数据,但是对于非线性数据就没有用了,比如下面的数据:
这是一个简单的例子,它只有一个特征,输出变量只是这一特征的函数,但线性回归无法拟合。之前学过一个局部加权的线性回归算法,自然很适合这样的数据分布,但是那个算法每次预测都需要重新训练参数,拟合的非常好,但是有个很大的缺点就是时间成本太高高的。
观察这样的数据分布,发现一个线性函数无法拟合,那么非线性函数呢?例如,二次函数可以拟合吗?如果二次函数拟合不好,三次函数能拟合得更好吗?二次函数是对称函数。上图不是不对称的,所以二次函数的拟合效果可能不好。但是从数据分布可以看出,二次函数的拟合应该比线性函数(Linear function)的拟合好。次数增加后,拟合效果会更好,那么三次函数呢?熟悉三次函数的人可能会看到,只要正确选择三次函数的四个系数,三次函数就更接近于数据分布。根据泰勒公式,我们可以知道,如果一个函数是阶连续可微的,那么我们可以用阶多项式来逼近这个函数。这也启发了我们使用多项式来拟合非线性数据分布!
2. 多项式回归模型
如果想对输出变量 和特征向量 作回归,根据上面的启发,我们想用一个多项式函数来预测 ,假设我们使用 2 阶多项式来预测 ,也就是说我们想寻找参数 使:
我们只有特征,如何求特征的平方?这很简单!只需将训练集中的每个特征都发挥到幂。获得的方法是将列乘以列。也就是说,我们有 训练样本,矩阵 是:
我们只需要将矩阵除列之外的所有列加入到这个矩阵中所有列的平方和乘积形成的新列中,就会形成一个新的矩阵,然后把这个矩阵作为原矩阵使用普通的线性回归训练参数。正规方程解法为:
也可以使用梯度下降等解决方案,与普通线性回归完全一样。这是多项式回归。
上面介绍了二阶多项式回归的方法,高阶多项式回归也是如此。如果数据可以用相对低阶的多项式更好地拟合,那么多项式回归是更好的方法,但是如果低阶多项式不能很好地拟合数据,则需要更高阶的多项式来拟合,那么,新构建的矩阵特征的维度会太大,不利于训练。
3. 代码实现
代码使用线性函数(普通线性回归)、二次函数和三次函数来拟合上面的例子:
import numpy as np
import matplotlib.pyplot as plt
# 创造数据
def CreateData():
X = np.arange(0,10,0.3)
y = np.empty(X.shape[0])
for i in range(X.shape[0]):
y[i] = 1.1*X[i]**3 - 10*X[i]**2 + X[i] + np.random.uniform(-10,10)
return X[:,np.newaxis], y
X, y = CreateData()
X = np.insert(X, 0, 1, axis = 1)
# 数据可视化
plt.scatter(X[:,1], y, marker = 'x')
# 使用普通线性回归预测(一次函数)
theta = np.dot(np.linalg.inv(np.dot(X.T, X)), np.dot(X.T, y))
# 可视化回归曲线
t = np.linspace(-1, 11, 100)
plt.plot(t, theta[0] + theta[1] * t, c = 'blue')
# 使用二次函数回归
col_new = X[:,1]**2 # 新增加一列
X = np.hstack([X, col_new[:,np.newaxis]])
theta = np.dot(np.linalg.inv(np.dot(X.T, X)), np.dot(X.T, y))
# 可视化回归曲线
t = np.linspace(-1, 11, 100)
plt.plot(t, theta[0] + theta[1] * t + theta[2] * t**2, c = 'yellow')
# 使用三次函数回归
col_new = X[:,1]**3 # 新增加一列
X = np.hstack([X, col_new[:,np.newaxis]])
theta = np.dot(np.linalg.inv(np.dot(X.T, X)), np.dot(X.T, y))
# 可视化回归曲线
t = np.linspace(-1, 11, 100)
plt.plot(t, theta[0] + theta[1] * t + theta[2] * t**2 + theta[3] * t**3, c = 'red')
# 加标注
plt.legend([r"$y=\theta_0+\theta_1x$",
r"$y=\theta_0+\theta_1x+\theta_2x^2$",
r"$y=\theta_0+\theta_1x+\theta_2x^2+\theta_3x^3$"])
plt.show()
拟合结果如下:
很明显,可以看出线性函数完全没用,二次函数拟合比线性函数好,但也可以明显欠拟合,三次函数的拟合效果明显更好。这里不使用高阶多项式拟合。由于高阶多项式的逼近能力会越来越高,所以对于高阶多项式,必须考虑过拟合的问题,在使用该算法时选择合适的阶数很重要。
文章出处登录后可见!