PyTorch中计算KL散度详解

PyTorch计算KL散度详解

最近在进行方法设计时,需要度量分布之间的差异,由于样本间分布具有相似性,首先想到了便于实现的KL-Divergence,使用PyTorch中的内置方法时,踩了不少坑,在这里详细记录一下。

简介

首先简单介绍一下KL散度(具体的可以在各种技术博客看到讲解,我这里不做重点讨论)。
从名称可以看出来,它并不是严格意义上的距离(所以才叫做散度~),原因是它并不满足距离的对称性,为了弥补这种缺陷,出现了JS散度(这就是另一个故事了…)
我们先来看一下KL散度的形式:
PyTorch中计算KL散度详解

手动代码实现

可以看到,KL散度形式上还是比较直观的,我们先手撸一个试试:
这里我们随机设定两个随机变量P和Q

import torch
P = torch.tensor([0.4, 0.6])
Q = torch.tensor([0.3, 0.7])

快速算一下答案:
PyTorch中计算KL散度详解

数值计算实现版:

def DKL(_p, _q):
		"""calculate the KL divergence between _p and _q
		"""
    return  torch.sum(_p * (_p.log() - _q.log()), dim=-1)

divergence = DKL(P, Q)
print(divergence)
# tensor(0.0226)

上面的代码中,之所以求和时dim=-1是因为我在使用的过程中,考虑到有时是对batch中feature进行计算,所以这里只对特征维度进行求和。
接下来,就到了今天介绍的主角~

torch代码实现

torch中提供有两种不同的api用于计算KL散度,分别是torch.nn.functional.kl_div()torch.nn.KLDivLoss(),两者计算效果类似,区别无非是直接计算和作为损失函数类。

先介绍一下torch.nn.functional.kl_div()

注意,该方法的inputtargetPyTorch中计算KL散度详解PyTorch中计算KL散度详解PyTorch中计算KL散度详解的位置正好相反,从参数名称就可以看出来(target为目标分布PyTorch中计算KL散度详解input为待度量分布PyTorch中计算KL散度详解)。为了防止指代混乱,我后面统一用PyTorch中计算KL散度详解PyTorch中计算KL散度详解指代targetinput
在这里插入图片描述

reduction:该参数是结果应该以什么规约形式进行呈现,sum即为我们定义式中的效果,batchmean:按照batch大小求平均,mean:按照元素个数进行求平均

再看看log_target的效果:

if not log_target: # default
    loss_pointwise = target * (target.log() - input)
else:
    loss_pointwise = target.exp() * (target - input)

也就是说,如果log_target=False,此时计算方式为
PyTorch中计算KL散度详解
这和我们熟悉的定义式的计算方式是不同的,如果想要和定义式的效果一致,需要对input取对数操作(在官方文档中也有提及,建议将input映射到对数空间,防止数值下溢):

import torch.nn.Functional as F

print(F.kl_div(Q.log(), P, reduction='sum'))
#tensor(0.0226)

而当log_target=True时,此时的计算方式变为
PyTorch中计算KL散度详解
也就是说,此时我们对PyTorch中计算KL散度详解取对数操作即可得到定义式的效果:

print(F.kl_div(Q.log(), P.log(), 
	  log_target=True, reduction='sum'))
#tensor(0.0226)

这样设计的目的也是为了防止数值下溢。

torch.nn.KLDivLoss()的参数列表与torch.nn.functional.kl_div()类似,这里就不过多赘述。

总结

总的来说,当需要计算KL散度时,默认情况下需要对input取对数,并设置reduction='sum'方能得到与定义式相同的结果:

divergence = F.kl_div(Q.log(), P, reduction='sum')

由于我们度量的是两个分布的差异,因此通常需要对输入进行softmax归一化(如果已经归一化则无需此操作):

divergence = F.kl_div(Q.softmax(-1).log(), P.softmax(-1), reduction='sum')

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年2月25日 下午3:20
下一篇 2023年2月25日

相关推荐