Network Optimization Methods – Regularization

Regularization

1.1 Introduction to Regularization

Regularization is also referred to as normalization. The most commonly used forms are L1 regularization and L2 regularization. Using L1 or L2 regularization simply means adding a regularization term to an ordinary cost function (for example, the mean squared error or cross-entropy cost function). For instance, the cross-entropy cost with an L1 regularization term is:
$$C = -\frac{1}{N}\sum_x \left[t\ln y + (1-t)\ln(1-y)\right] + \frac{\lambda}{N}\sum_w |w|$$
The cross-entropy cost with an L2 regularization term is:
$$C = -\frac{1}{N}\sum_x \left[t\ln y + (1-t)\ln(1-y)\right] + \frac{\lambda}{2N}\sum_w w^2$$
The L2-regularized cross-entropy can also be written as:
$$C = C_0 + \frac{\lambda}{2N}\sum_w w^2$$
Here $C_0$ is the original cost function and $\lambda$ is the regularization coefficient, a number greater than 0. The larger $\lambda$ is, the stronger the influence of the regularization term; the smaller $\lambda$ is, the weaker its influence; when $\lambda$ is 0, the regularization term effectively disappears. $N$ is the number of samples, and $w$ denotes all of the network's weight parameters and bias values.

Training a model is essentially the process of minimizing the cost function with gradient descent: the closer $t$ and $y$ are in the cross-entropy cost, the closer the cost is to 0. Looking at the regularized cost expression, minimizing it requires not only making $t$ close to $y$, but also driving the network's weight parameters $w$ toward 0. Both the L1 term $\frac{\lambda}{N}\sum_w |w|$ and the L2 term $\frac{\lambda}{2N}\sum_w w^2$ are non-negative, so minimizing the regularization term amounts to pushing the values of $w$ toward 0.
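To make this concrete, differentiating the L2-regularized cost above and plugging the result into the gradient-descent update (with learning rate $\eta$; this step is implied rather than written out in the text) gives:

$$\frac{\partial C}{\partial w} = \frac{\partial C_0}{\partial w} + \frac{\lambda}{N} w \quad\Longrightarrow\quad w \leftarrow \left(1 - \frac{\eta\lambda}{N}\right) w - \eta\frac{\partial C_0}{\partial w}$$

Because $1 - \frac{\eta\lambda}{N} < 1$, every update shrinks the weight slightly before applying the usual gradient step, which is why L2 regularization is also known as weight decay.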

1.2 Differences Between the L1 and L2 Regularization Terms

The L1 regularization term drives many of the network's weight parameters to exactly 0. If many of the weights are 0, the network's complexity, and with it its fitting capacity, is reduced, so overfitting becomes less likely.
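Intuitively, this sparsity comes from the gradient of the L1 term. A sketch of the corresponding update rule (learning rate $\eta$, $\mathrm{sgn}(w)$ the sign of $w$; this derivation is not written out in the original text):

$$\frac{\partial C}{\partial w} = \frac{\partial C_0}{\partial w} + \frac{\lambda}{N}\,\mathrm{sgn}(w) \quad\Longrightarrow\quad w \leftarrow w - \eta\frac{\lambda}{N}\,\mathrm{sgn}(w) - \eta\frac{\partial C_0}{\partial w}$$

The amount $\eta\frac{\lambda}{N}$ subtracted toward 0 does not shrink as $|w|$ gets smaller, so weights that contribute little to $C_0$ are pushed all the way to exactly 0.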

The L2 regularization term causes the network's weights to decay: the weight parameters become values close to 0. Note that close to 0 is not equal to 0; L2 regularization rarely drives a weight exactly to 0. L2 regularization is effective because once the weights $w$ become very small, $Wx+b$ also becomes a value close to 0. When using the sigmoid(x) or tanh(x) activation functions, the curve is very close to a straight line when x is near 0, as shown in the figure below.

(Figure: sigmoid(x) and tanh(x) are approximately linear for x near 0)

As a result, the network gains many more linear features and loses many nonlinear ones, its complexity decreases, and overfitting therefore becomes less likely.
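A quick numerical check of this near-linearity claim (a minimal sketch using NumPy; treating |x| <= 0.3 as "near 0" is an arbitrary choice for illustration):

import numpy as np

x = np.linspace(-0.3, 0.3, 7)            # inputs close to 0
tanh_vals = np.tanh(x)
sigmoid_vals = 1 / (1 + np.exp(-x))
# Near 0, tanh(x) is almost the straight line y = x,
# and sigmoid(x) is almost the straight line y = 0.5 + 0.25*x
print(np.max(np.abs(tanh_vals - x)))                    # < 0.01
print(np.max(np.abs(sigmoid_vals - (0.5 + 0.25*x))))    # < 0.001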

1.3 Regularization in Code

Here we apply regularization to digit recognition on the MNIST dataset.

The code was developed and run in Jupyter Notebook.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,Flatten
from tensorflow.keras.optimizers import SGD
import matplotlib.pyplot as plt
import numpy as np
# Import the l1 and l2 regularizers
from tensorflow.keras.regularizers import l1,l2
# Load the dataset
mnist = tf.keras.datasets.mnist
# Load the training set and test set
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize the training and test data; this helps speed up model training
x_train, x_test = x_train / 255.0, x_test / 255.0
# Convert the training and test labels to one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)

# Model definition: model1 uses L2 regularization
# l2(0.0003) applies L2 regularization with a coefficient of 0.0003
model1 = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(units=200,activation='tanh',kernel_regularizer=l2(0.0003)),
        Dense(units=100,activation='tanh',kernel_regularizer=l2(0.0003)),
        Dense(units=10,activation='softmax',kernel_regularizer=l2(0.0003))
        ])

# Define an identical model for comparison; model2 uses no regularization
model2 = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(units=200,activation='tanh'),
        Dense(units=100,activation='tanh'),
        Dense(units=10,activation='softmax')
        ])

# sgd defines the stochastic gradient descent optimizer
# loss='categorical_crossentropy' sets the cross-entropy cost function
# metrics=['accuracy'] makes the model also compute accuracy during training
sgd = SGD(0.2)
model1.compile(optimizer=sgd,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model2.compile(optimizer=sgd,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train the model by passing in the training data and labels
# Number of epochs is 30 (one pass over the whole training set is one epoch)
epochs = 30
# Batch size is 32 (32 samples are fed into the model per training step)
batch_size=32
# Train model1 first
history1 = model1.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test,y_test))
# Then train model2
history2 = model2.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test,y_test))

Training output:

Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 6s 101us/sample - loss: 0.4063 - accuracy: 0.9205 - val_loss: 0.2799 - val_accuracy: 0.9560
Epoch 2/30
60000/60000 [==============================] - 5s 78us/sample - loss: 0.2611 - accuracy: 0.9605 - val_loss: 0.2427 - val_accuracy: 0.9634
Epoch 3/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.2177 - accuracy: 0.9692 - val_loss: 0.2148 - val_accuracy: 0.9661
Epoch 4/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.1936 - accuracy: 0.9734 - val_loss: 0.1896 - val_accuracy: 0.9737
Epoch 5/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.1779 - accuracy: 0.9761 - val_loss: 0.1812 - val_accuracy: 0.9747
Epoch 6/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.1682 - accuracy: 0.9776 - val_loss: 0.1730 - val_accuracy: 0.9749
Epoch 7/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.1616 - accuracy: 0.9785 - val_loss: 0.1727 - val_accuracy: 0.9740
Epoch 8/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.1573 - accuracy: 0.9794 - val_loss: 0.1744 - val_accuracy: 0.9741
Epoch 9/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.1534 - accuracy: 0.9806 - val_loss: 0.1730 - val_accuracy: 0.9744
Epoch 10/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.1510 - accuracy: 0.9808 - val_loss: 0.1802 - val_accuracy: 0.9700
Epoch 11/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.1486 - accuracy: 0.9818 - val_loss: 0.1590 - val_accuracy: 0.9778
Epoch 12/30
60000/60000 [==============================] - 5s 77us/sample - loss: 0.1465 - accuracy: 0.9821 - val_loss: 0.1578 - val_accuracy: 0.9791
Epoch 13/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.1462 - accuracy: 0.9819 - val_loss: 0.1564 - val_accuracy: 0.9772
Epoch 14/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.1442 - accuracy: 0.9822 - val_loss: 0.1582 - val_accuracy: 0.9777
Epoch 15/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.1437 - accuracy: 0.9829 - val_loss: 0.1649 - val_accuracy: 0.9745
Epoch 16/30
60000/60000 [==============================] - 4s 75us/sample - loss: 0.1408 - accuracy: 0.9833 - val_loss: 0.1548 - val_accuracy: 0.9792
Epoch 17/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.1418 - accuracy: 0.9831 - val_loss: 0.1546 - val_accuracy: 0.9783
Epoch 18/30
60000/60000 [==============================] - 4s 75us/sample - loss: 0.1417 - accuracy: 0.9833 - val_loss: 0.1552 - val_accuracy: 0.9782
Epoch 19/30
60000/60000 [==============================] - 4s 75us/sample - loss: 0.1421 - accuracy: 0.9831 - val_loss: 0.1559 - val_accuracy: 0.9777
Epoch 20/30
60000/60000 [==============================] - 4s 75us/sample - loss: 0.1393 - accuracy: 0.9840 - val_loss: 0.1682 - val_accuracy: 0.9725
Epoch 21/30
60000/60000 [==============================] - 6s 92us/sample - loss: 0.1389 - accuracy: 0.9839 - val_loss: 0.1545 - val_accuracy: 0.9772
Epoch 22/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.1395 - accuracy: 0.9837 - val_loss: 0.1518 - val_accuracy: 0.9802
Epoch 23/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1396 - accuracy: 0.9834 - val_loss: 0.1484 - val_accuracy: 0.9792
Epoch 24/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1385 - accuracy: 0.9842 - val_loss: 0.1595 - val_accuracy: 0.9759
Epoch 25/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1379 - accuracy: 0.9842 - val_loss: 0.1694 - val_accuracy: 0.9737
Epoch 26/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1373 - accuracy: 0.9846 - val_loss: 0.1588 - val_accuracy: 0.9767
Epoch 27/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1377 - accuracy: 0.9839 - val_loss: 0.1512 - val_accuracy: 0.9797
Epoch 28/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1367 - accuracy: 0.9844 - val_loss: 0.1461 - val_accuracy: 0.9810
Epoch 29/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.1385 - accuracy: 0.9837 - val_loss: 0.1554 - val_accuracy: 0.9765
Epoch 30/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1371 - accuracy: 0.9843 - val_loss: 0.1631 - val_accuracy: 0.9751
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 6s 100us/sample - loss: 0.2524 - accuracy: 0.9245 - val_loss: 0.1453 - val_accuracy: 0.9544
Epoch 2/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1168 - accuracy: 0.9643 - val_loss: 0.1153 - val_accuracy: 0.9634
Epoch 3/30
60000/60000 [==============================] - 6s 93us/sample - loss: 0.0800 - accuracy: 0.9753 - val_loss: 0.0893 - val_accuracy: 0.9705
Epoch 4/30
60000/60000 [==============================] - 6s 92us/sample - loss: 0.0599 - accuracy: 0.9812 - val_loss: 0.0928 - val_accuracy: 0.9694
Epoch 5/30
60000/60000 [==============================] - 6s 92us/sample - loss: 0.0450 - accuracy: 0.9858 - val_loss: 0.0725 - val_accuracy: 0.9774
Epoch 6/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0337 - accuracy: 0.9897 - val_loss: 0.0804 - val_accuracy: 0.9744
Epoch 7/30
60000/60000 [==============================] - 5s 90us/sample - loss: 0.0253 - accuracy: 0.9925 - val_loss: 0.0749 - val_accuracy: 0.9784
Epoch 8/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0185 - accuracy: 0.9945 - val_loss: 0.0756 - val_accuracy: 0.9773
Epoch 9/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0129 - accuracy: 0.9968 - val_loss: 0.0712 - val_accuracy: 0.9788
Epoch 10/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0090 - accuracy: 0.9981 - val_loss: 0.0686 - val_accuracy: 0.9789
Epoch 11/30
60000/60000 [==============================] - 5s 90us/sample - loss: 0.0057 - accuracy: 0.9993 - val_loss: 0.0692 - val_accuracy: 0.9798
Epoch 12/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0040 - accuracy: 0.9996 - val_loss: 0.0688 - val_accuracy: 0.9789
Epoch 13/30
60000/60000 [==============================] - 6s 92us/sample - loss: 0.0032 - accuracy: 0.9997 - val_loss: 0.0690 - val_accuracy: 0.9793
Epoch 14/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0024 - accuracy: 0.9999 - val_loss: 0.0681 - val_accuracy: 0.9798
Epoch 15/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0018 - accuracy: 0.9999 - val_loss: 0.0690 - val_accuracy: 0.9800
Epoch 16/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0694 - val_accuracy: 0.9800
Epoch 17/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0014 - accuracy: 1.0000 - val_loss: 0.0696 - val_accuracy: 0.9803
Epoch 18/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0012 - accuracy: 1.0000 - val_loss: 0.0707 - val_accuracy: 0.9801
Epoch 19/30
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0011 - accuracy: 1.0000 - val_loss: 0.0723 - val_accuracy: 0.9798
Epoch 20/30
60000/60000 [==============================] - 5s 90us/sample - loss: 9.8140e-04 - accuracy: 1.0000 - val_loss: 0.0718 - val_accuracy: 0.9801
Epoch 21/30
60000/60000 [==============================] - 5s 91us/sample - loss: 8.9510e-04 - accuracy: 1.0000 - val_loss: 0.0718 - val_accuracy: 0.9803
Epoch 22/30
60000/60000 [==============================] - 5s 91us/sample - loss: 8.2853e-04 - accuracy: 1.0000 - val_loss: 0.0733 - val_accuracy: 0.9797
Epoch 23/30
60000/60000 [==============================] - 5s 91us/sample - loss: 7.6028e-04 - accuracy: 1.0000 - val_loss: 0.0732 - val_accuracy: 0.9807
Epoch 24/30
60000/60000 [==============================] - 5s 92us/sample - loss: 7.1290e-04 - accuracy: 1.0000 - val_loss: 0.0738 - val_accuracy: 0.9803
Epoch 25/30
60000/60000 [==============================] - 6s 99us/sample - loss: 6.6676e-04 - accuracy: 1.0000 - val_loss: 0.0734 - val_accuracy: 0.9797
Epoch 26/30
60000/60000 [==============================] - 6s 98us/sample - loss: 6.2845e-04 - accuracy: 1.0000 - val_loss: 0.0738 - val_accuracy: 0.9803
Epoch 27/30
60000/60000 [==============================] - 6s 98us/sample - loss: 5.9281e-04 - accuracy: 1.0000 - val_loss: 0.0747 - val_accuracy: 0.9802
Epoch 28/30
60000/60000 [==============================] - 6s 98us/sample - loss: 5.6025e-04 - accuracy: 1.0000 - val_loss: 0.0745 - val_accuracy: 0.9797
Epoch 29/30
60000/60000 [==============================] - 6s 98us/sample - loss: 5.3286e-04 - accuracy: 1.0000 - val_loss: 0.0752 - val_accuracy: 0.9803
Epoch 30/30
60000/60000 [==============================] - 6s 98us/sample - loss: 5.0925e-04 - accuracy: 1.0000 - val_loss: 0.0746 - val_accuracy: 0.9801
# Plot the validation-accuracy curve for model1
plt.plot(np.arange(epochs),history1.history['val_accuracy'],c='b',label='L2 Regularization')
# Plot the validation-accuracy curve for model2
plt.plot(np.arange(epochs),history2.history['val_accuracy'],c='y',label='FC')
# Legend
plt.legend()
# x-axis label
plt.xlabel('epochs')
# y-axis label
plt.ylabel('accuracy')
# Display the figure
plt.show()

(Figure: validation accuracy per epoch; model1 with L2 regularization vs. model2 without regularization)

The first 30 epochs above are the results for model1 with L2 regularization; the second 30 epochs are the results for model2 without regularization. The results show that with regularization, model1's training accuracy and validation accuracy stay close to each other, which confirms that regularization does help resist overfitting. However, the validation accuracy obtained with regularization is not particularly high, which shows that regularization is not suited to every scenario. Regularization tends to work well when the network structure is relatively complex and the amount of training data is relatively small. When the network is not very complex and the task is simple, adding regularization may actually lower accuracy. A similar situation holds for Dropout, so whether to use Dropout or regularization should be decided based on how well they actually perform in practice.
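One way to inspect the train/validation gap mentioned above is to plot the gap itself, reusing history1 and history2 from the training runs (a sketch added here for illustration, not part of the original experiment):

# Plot the difference between training accuracy and validation accuracy per epoch;
# a larger gap indicates more overfitting
gap1 = np.array(history1.history['accuracy']) - np.array(history1.history['val_accuracy'])
gap2 = np.array(history2.history['accuracy']) - np.array(history2.history['val_accuracy'])
plt.plot(np.arange(epochs), gap1, c='b', label='L2 Regularization')
plt.plot(np.arange(epochs), gap2, c='y', label='FC')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('train accuracy - val accuracy')
plt.show()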
