损失函数InfoNCE loss和cross entropy loss以及温度系数

还是基础知识的搬运哦

(1)对比学习常用的损失函数InfoNCE loss和cross entropy loss是否有联系?

(2)对比损失InfoNCE loss中有一个温度系数,其作用是什么?温度系数的设置对效果如何产生影响?

个人认为,这两个问题可以作为对比学习相关项目面试的考点,本文我们就一起盘一盘这两个问题。

1. InfoNCE loss公式

对比学习损失函数有多种,其中比较常用的一种是InfoNCE loss,InfoNCE loss其实跟交叉熵损失有着千丝万缕的关系,下面我们借用恺明大佬在他的论文MoCo里定义的InfoNCE loss公式来说明。论文MoCo提出,我们可以把对比学习看成是一个字典查询的任务,即训练一个编码器从而去做字典查询的任务。假设已经有一个编码好的query  (一个特征), 以及一系列编码好的样本 , 那么  可以看作是字典里的key。假设字典里只有一个  即  (称 为  positive) 是跟  是匹配的,那么  和  +就互为正样本对, 其余的  为  的负样本。一旦定义 好了正负样本对, 就需要一个对比学习的损失函数来指导模型来进行学习。这个损失函数需要满足 这些要求, 即当query  和唯一的正样本  相似, 并且和其他所有负样本key都不相似的时候, 这 个loss的值应该比较低。反之, 如果  和  不相似, 或者  和其他负样本的key相似了, 那么loss就 应该大, 从而惩罚模型, 促使模型进行参数更新。

损失函数InfoNCE loss和cross entropy loss以及温度系数
损失函数InfoNCE loss和cross entropy loss以及温度系数

2. InfoNCE loss和交叉熵损失有什么关系?

我们先从softmax说起,下面是softmax公式:

损失函数InfoNCE loss和cross entropy loss以及温度系数
损失函数InfoNCE loss和cross entropy loss以及温度系数

上式中的  在有监督学习里指的是这个数据集一共有多少类别, 比如CV的ImageNet数据集有 1000 类, k就是1000。

对于对比学习来说,理论上也是可以用上式去计算loss,但是实际上是行不通的。为什么呢?

还是拿CV领域的ImageNet数据集来举例,该数据集一共有128万张图片,我们使用数据增强手段(例如,随机裁剪、随机颜色失真、随机高斯模糊)来产生对比学习正样本对,每张图片就是单独一类,那k就是128万类,而不是1000类了,有多少张图就有多少类。但是softmax操作在如此多类别上进行计算是非常耗时的,再加上有指数运算的操作,当向量的维度是几百万的时候,计算复杂度是相当高的。所以对比学习用上式去计算loss是行不通的。

怎么办呢?NCE loss可以解决这个问题。

NCE(noise contrastive estimation)核心思想是将多分类问题转化成二分类问题,一个类是数据类别 data sample,另一个类是噪声类别 noisy sample,通过学习数据样本和噪声样本之间的区别,将数据样本去和噪声样本做对比,也就是“噪声对比(noise contrastive)”,从而发现数据中的一些特性。但是,如果把整个数据集剩下的数据都当作负样本(即噪声样本),虽然解决了类别多的问题,计算复杂度还是没有降下来,解决办法就是做负样本采样来计算loss,这就是estimation的含义,也就是说它只是估计和近似。一般来说,负样本选取的越多,就越接近整个数据集,效果自然会更好。

NCE loss常用在NLP模型中,公式如下:

损失函数InfoNCE loss和cross entropy loss以及温度系数
损失函数InfoNCE loss和cross entropy loss以及温度系数

上述公式细节详见:NCE loss(https://arxiv.org/pdf/1410.8251.pdf)

有了NCE loss,为什么还要用Info NCE loss呢?

Info NCE loss是NCE的一个简单变体,它认为如果你只把问题看作是一个二分类,只有数据样本和噪声样本的话,可能对模型学习不友好,因为很多噪声样本可能本就不是一个类,因此还是把它看成一个多分类问题比较合理(但这里的多分类kk指代的是负采样之后负样本的数量,下面会解释)。于是就有了InfoNCE loss,公式如下:

 

损失函数InfoNCE loss和cross entropy loss以及温度系数

上式中,  是模型出来的logits, 相当于上文  oftmax公式中的  是一个温度超参数, 是个标 量, 假设我们忽略 , 那么infoNCE loss其实就是cross entropy loss。唯一的区别是, 在cross entropy loss里,  指代的是数据集里类别的数量, 而在对比学习InfoNCE loss里, 这个k指的是负样本的数量。上式分母中的sum是在 1 个正样本和  个负样本上做的, 从0到 , 所以共  个样本, 也就是字典里所有的key。恺明大佬在MoCo里提到, InfoNCE loss其实就是一个cross entropy loss, 做的是一个  类的分类任务, 目的就是想把  这个图片分到  这个类。

另外,我们看下图中MoCo的伪代码,MoCo这个loss的实现就是基于cross entropy loss。

损失函数InfoNCE loss和cross entropy loss以及温度系数
损失函数InfoNCE loss和cross entropy loss以及温度系数

3. 温度系数的作用

温度系数  虽然只是一个超参数, 但它的设置是非常讲究的, 直接影响了模型的效果。上式Info NCE loss中的  相当于是logits, 温度系数可以用来控制logits的分布形状。对于既定的logits分 布的形状, 当  值变大, 则  就变小,  则会使得原来logits分布里的数值都变小, 且经过指数运算之后, 就变得更小了, 导致原来的logits分布变得更平滑。相反, 如果  取得值小,  就 变大, 原来的logits分布里的数值就相应的变大, 经过指数运算之后, 就变得更大, 使得这个分布变得更集中, 更peak。

如果温度系数设的越大,logits分布变得越平滑,那么对比损失会对所有的负样本一视同仁,导致模型学习没有轻重。如果温度系数设的过小,则模型会越关注特别困难的负样本,但其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差。

总之,温度系数的作用就是它控制了模型对负样本的区分度。

  whaosoft aiot http://143ai.com  

 

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年3月4日 下午5:01
下一篇 2023年3月4日 下午5:02

相关推荐