许多机器学习方法中, 尤其是深度学习中的神经网络, 都存在几个问题:
- 模型容易过拟合.
- 模型在受到微小扰动(噪声)后, 预测结果会受相当程度的影响.
为了减少过拟合现象, 典型的监督学习中会添加一个新的损失项. 在半监督学习中, 同样存在一种正则化方法, 即一致性正则化(Consistency Regularization).
一致性正则化
具体来说, 基于平滑假设和聚类假设, 具有不同标签的数据点在低密度区域分离, 并且相似的数据点具有相似的输出. 那么, 如果对一个未标记的数据应用实际的扰动, 其预测结果不应该发生显著变化, 也就是输出具有一致性.
由于这种方法一般基于模型输出的预测向量, 不需要具体的标签, 所以其刚好能能应用于半监督学习. 通过在未标记数据上构造添加扰动后的预测结果 与正常预测结果 之间的无监督正则化损失项, 提高模型的泛化能力.
数学化表达如下:
其中, 为度量函数, 一般采用 KL 散度或 JS 散度, 当然也可使用交叉熵或平方误差等. 是数据增强函数, 会添加一些噪声扰动, 为模型参数, 多个模型的参数可以共用, 或者通过一些转换建立联系(Mean-Teacher), 如 EMA , 也可以相互独立(Daul-Student).
的类型在近几年的论文有以下几种:
- 常规的数据增强, 平移旋转, 或随机 dropout 等, 如 -model.
- 时序移动平均, 如 Temporal Ensembling, Mean-Teacher, SWA.
- 对抗样本扰动, 如 VAT, Adversarial Dropout.
- 高级数据增强, 如 UDA, Self-supervised SSL.
- 线性混合, 如 MixMatch.
blog 文章传送门:
1. Ladder Networks
Ladder Networks 文献暂时没看, 内容以后更新.
2. -Model
-Model 模型如下图所示:
单独对于有标记样本来说, 其前向传播一次, 注入扰动后在预测值与真实值之间进行交叉熵计算. 最后将两个损失函数进行加权求和即为损失函数 .
-Model 算法伪代码如下:
3. Temporal Ensembling
Temporal Ensembling 模型如下图所示:
Temporal Ensembling 算法伪代码如下:
4. Mean-Teacher
针对 Temporal Ensembling 的一些缺点做出了改进. 此前的 Temporal Ensembling 在每个 epoch 只进行一次 EMA, 无法满足大型数据集的学习, 且无法实现模型的在线训练. 为了克服这个问题, Mean Teacher 能在每个 epoch 中的每个 step 进行模型权重的更新. 也就是将原来计算输出向量 的过程变成了计算整个网络的参数 .
Mean Teacher 算法模型如下:
5. Dual-Students
与 Mean-Teacher 相比, Dual Student 用另一个学生代替老师. 两名 Student 共享具有不同初始状态的相同网络架构并分别更新, 避免 Mean-Teacher 中的权重耦合问题带来的局限性, 因为 Teacher 本质上是 Student 的 EMA.
Dual Student 通过同时训练两个独立的模型来获得松耦合的目标. Dual Student 模型如下图所示:
- 问题1. 如何定义和获取模型的可靠知识.
- 问题2. 如何相互交换知识.
为解决问题1, 引入稳定样本(Stable Sample)概念. 定义如下: 给定一个常数 , 一个满足平滑假设的数据集 和一个模型 , 对所有 满足 , 且:
- , 在 附近, 它们的预测标签是相同的.
- 满足不等式:
则 是关于 的稳定样本. 如下图中, 只有 , 满足.
Student 的最终约束是三个部分的组合: 分类约束、每个模型中的一致性约束和模型之间的稳定性约束. 如下:
Daul-Student 训练过程如下:
6. Fast-SWA
Fast-SWA 文献暂时没看, 内容以后更新.
7. Virtual Adversarial Training(VAT)
(PS: 这篇论文中还有些东西没有明白, 后面会继续深入理解)
VAT 是一种基于熵最小化的正则方法, 并提出一种对于给定输入评估模型输出条件分布局部光滑性的虚拟对抗损失, 虚拟对抗性损失可被定义为每个输入数据点周围的条件标签分布对局部扰动的鲁棒性. VAT 主要思想为: 为提升模型的鲁棒性, 对输入样本加入对抗性扰动, 将其与原始样本做一致性正则化.
在对抗训练中, 对抗方向定义为在输入数据点处, 可以最大程度地降低模型正确分类的概率, 或者是可以最大程度地"偏离"模型预测与正确标签的方向. 基于此, VAT 中引入虚拟对抗方向, 与传统对抗训练不同的是, 即使在没有标签信息的情况下, 也可以在未标记的数据点上定义虚拟对抗方向, 就好像有一个"虚拟"标签一样.
VAT 中的损失函数如下:
实际上我们没有关于 的直接信息, 因此采取策略用 替换 . 如果带标签的样本比较多时, 会逼近 , 即用 生成的虚拟标签代替不知道的标签, 并根据虚拟标签计算对抗方向, 因此 用上一步的 替代, 损失函数更新如下:
综上, 给整个目标函数加上损失函数, 这里损失函数取平均, 最终得到完整的目标函数, 如下所示:
其中 为带标签数据的负对数似然函数. 整个正则化过程中, 只有两个超参数: 正则化系数: , 对抗方向的范数约束: .
下图显示了 VAT 如何在二维合成数据集上进行半监督学习:
8. Adversarial Dropout(AdD)
是 VAT 的变种, VAT 是在 input data 上加对抗扰动, AdD 则是在网络中间层进行对抗性 dropout.
结合对抗 dropout 的附加损失函数的描述如下:
其中 , 为 adversarial dropout mask, 为随机采样 dropout mask, 为 dropout layer 的维度.
引入边界条件 . 如果没有这个约束, 具有对抗 dropout 的网络可能会变成没有连接的神经网络层.
在对抗性训练的一般形式中, 关键点是线性扰动 的存在. 可以将具有对抗性扰动的输入解释为对抗性噪声输入 . 从这个角度来看, 对抗训练的作者将对抗方向限制在加性高斯噪声 的空间上, 其中 是输入层上的采样高斯噪声. 相比之下, 对抗性 dropout 可以被认为是通过 masking 隐藏单元产生的噪声空间, , 是对抗性选择的 dropout 状态. 如果假设对抗性训练是输入上的高斯加性扰动, 则扰动本质上是线性的, 但如果对抗性 dropout 施加在多个层上, 则对抗性 dropout 可能是非线性扰动.
学习的完整目标函数由下式给出:
其中 是 中 对应的 的负对数似然.
9. Interpolation Consistency Training(ICT)
有研究表明: 对抗性扰动训练会损害泛化性能. 为了克服这个问题, 便提出了插值一致性训练(ICT), 简单来说, ICT 通过在未标记点 , 的插值 上的一致性预测 来规范半监督学习.
为什么这种插值的策略有效呢? 因为只有在决策边界的数据点才更加有效, 因为对于这些点加入扰动之后, 可能把这个点变为另外一类, 可以提高模型的决策难度, 提升泛化能力.
根据 mixup 式子:
ICT 训练分类器 以在未标记点的插值中提供一致性预测:
其中 时 的滑动平均.
在监督学习环境中, mixup 是实现大边距决策边界的一种方法. 在 mixup 中, 通过强制预测模型在样本之间线性变化, 将决策边界推离类别边界, 通过训练模型 来预测 的"假标签" 来将 mixup 扩展到半监督学习.
ICT 模型如下图所示:
10. Unsupervised Data Augmentation(UDA)
为了加强一致性, 现有方法通常采用简单的噪声注入方法, 例如添加高斯噪声. 相比之下, UDA 中使用监督学习的更强的数据增强作用于半监督的一致性训练框架中, 研究表明此方法可以带来更卓越的性能.
UDA 模型如下:
- 1.带标签样本的的有监督交叉熵损失 CE loss.
- 2.在未标记数据上, 对数据进行增强处理后的预测结果 与增强前的预测结果 之间的一致性惩罚 loss.
UDA 针对不同任务有不同的数据增强策略
-
RandAugment for Image Classification. 使用一种名为 RandAugment 的数据增强方法, 该方法受到 AutoAugment 的启发. AutoAugment 使用一种搜索方法将 Python 图像库(PIL)中的所有图像处理转换结合起来, 以找到一个好的增强策略. 在 RandAugment 中, 不使用搜索, 而是从 PIL 中的同一组增强变换中统一采样. 换句话说, RandAugment 更简单, 不需要标记数据, 因为不需要搜索最优策略.
-
Back-translation for Text Classification. 当用作扩充方法时, 回译(Back-translation)是指将语言 A 中的现有示例 翻译成另一种语言 B, 然后将其翻译回 A 以获得扩充示例 的过程. 回译可以生成不同的释义, 同时保留原始句子的语义, 从而显着提高问答的性能. 如下图所示. 论文的案例中, 使用 Back-translation 来解释文本分类任务的训练数据.
-
Word replacing with TF-IDF for Text Classification. 虽然回译擅长维护句子的全局语义, 但对保留哪些单词几乎没有控制. 这个要求对于主题分类任务很重要, 因此, 用低 TF-IDF 值替换.
文章出处登录后可见!