Representation Learning with Contrastive Predictive Coding 论文阅读笔记

foreword

这是DeepMind在2019年发的一篇paper,也是无监督学习领域非常经典的一篇论文。它介绍了一种叫做Contrastive Predictive Coding的技术用于对序列形式的数据(e.g. 文本,声音信号等等)进行表征(representation)学习。

CPC 模型架构

首先我们先来看一下论文里的这张cpc模型架构图。

对于一组序列数据x,先通过一个编码器g_enc对序列的每个位置xt进行编码,转化为表征zt。然后这系列表征作为输入,传入一个自回归模型g_ar。该自回归模型在t时刻的输出为表征ct,其中包含了序列到t时刻为止的信息。

到目前为止,光从模型的架构上来看,并没有什么特别之处。不过真正有趣的地方才刚刚开始。首先,基于上述的模型架构,非常直觉的,我们可能会想到用t时刻为止模型学到的序列表征ct来预测之后的序列数据。但是谷歌的大佬们并没有这么做,他们并不是直接对p(x|c)进行建模。因为在篇paper中,他们希望模型学到的是一种high-level的表征,包含更多全局的信息,即文中提到的‘slow feature’,他们希望c就是这样的特征,但是用这种high-level的表征去预测高维数据是非常困难的。因此这里他们采用了另一种模型优化逻辑,对于x(future) 和 c(present), 他们希望最大化这两者的互信息(mutual information)。

其实这也非常符合直觉,真实的相邻点之间必然存在高的信息重叠,c的引入必然对于x的不确定性有大的减弱。

Denstiy Ratio

基于上述分析,作者定义了一个密度比来描述互信息。

在原文中,作者使用了一种简单的log-bilinear model 来对density ratio 进行参数化。

我们来看下这个表达式,里面的zt+k为xt+k经过编码器编码后得到的表征,ct是前面提到的在step t的high-level 表征。而中间的Wk为模型需要训练的参数,这里需要注意的一点是对于每个间隔k,都对应一个独立的参数矩阵。

非常容易理解,当编码器得到的表征zt+k和ct的乘积越大,那么也对应着更大的相似性以及互信息,所以这里的density ratio满足正比于互信息的条件。

InfoNCE Loss

我们已经来到了论文的关键部分,损失函数的定义。以下是原论文中提出的损失函数表达式。为什么损失函数是这种形式?

直觉理解

我们先来直观地理解一下,最小化损失函数等价于让期望表达式里的分子尽可能大,那么这里的分子代表什么呢?不就是前面定义的用来描述xt+k(future) 和 ct(present) 互信息的density ratio嘛!所以模型的优化目标就是希望让真实的future 和present有更高的互信息,让虚假的future和present的互信息尽可能小。
(注意,这部分描述并不准确,因为这里的density ratio并不完全等价于互信息)

这里我们引入对比学习的相关概念再来解释一下上面所谓的真实和虚假。其实上面的ct可以理解为anchor(锚点), xt+k(真实的future)可以理解为正例样本,而xj(虚假的future)就是负例样本。在一个包含N个样本的batch X中,我们有1个正例样本和N-1个负例样本。我们可以通过下面这张图来加深一下理解。

严谨推导

好了,这部分我们来严谨地推导下上述的损失函数为什么makes sense。我们先把这个式子放一边,来仔细思考一下我们的优化目标。还是这个N个样本,其中1个正例样本,N-1个负例样本,假设我们正例样本的index=i,即第i个样本为正,其余均为负。当我们把它看作是一个识别正样本的分类问题的时候,我们希望我们的模型去最大化p(idx=i|X,ct)。而这个式子由可以进一步拆解为如下的等式

第一个等式的分子代表idx=i的样本恰为正样本,从p(x|c) 中抽样,其余均为从p(x) 中抽样的负样本。分母的话就是所有情况的和,即idx=1,2,3,…N的样本是正样本,而其余样本为负样本。第二个等式需要一点小trick。当我们给定X,那么所有N个独立样本均从p(x) 中抽样得到的概率的乘积应该是一个定值,因此第一个等式右边分子的 N-1个负样本的概率连乘就可以转化为定值/p(正样本xi)。同理,分母的N-1个负样本的概率连乘也可以转化为定值/p(正样本xj)。这样就从第一个等式到了第二个等式。

我们仔细观察第二个等式,会发现分子分母不就是我们开始提到的互信息表达式吗!我们把p(xi|ct)/p(xi) 替换为 p(xt+k|ct)/p(xt+k),然后再用density ratio的表达式做个替换,就得到了我们的CPC loss中log里的表达式,最大化概率p(idx=i|X,ct) 就是最小化我们的CPC loss。

写在最后

到这里,CPC的核心部分就差不多讲完了,还有部分关于为什么最小化loss等价于最大化正例xt+k与context ct的互信息的证明这里就不再展开了,感兴趣的可以去看一下原文附录中的证明。总结一下,CPC主要利用对比学习的思想,将最大化正例样本与锚点表征的互信息作为模型的训练目标,进行模型的无监督训练,最终的目标还是将获取到的high-level表征用于下游任务。

参考

1.Representation Learning with Contrastive Predictive Coding
2.Contrastive Self-Supervised Learning[0][1]

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年4月28日
下一篇 2022年4月28日

相关推荐