本文接着慢慢磨pytorch基础。本来是想记录一下心得,结果码着码着又讲成了story。
gather函数:原始数据矩阵 根据索引矩阵 取到对应值矩阵
我们在学习Softmax回归从零实现的时候,需要定义一个交叉熵损失函数。我们会使用torch.gather函数的方法取原始数据矩阵中对应位置的值,接着再取log等处理。
先看的例子
可以先囫囵吞枣地看一下gather函数的例子:
# 变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。
import torch
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
# y是一个索引矩阵,包含了每个样本的正确类别的位置。
# 比如第1个样本是第1类,第2个样本是第3类。0即第1个,2即第2个。
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))
# 输出:
tensor([[0.1000],
[0.5000]])
可以看到y是一个索引矩阵,包含了每个样本的正确类别的位置。比如第1个样本是第1类,第2个样本是第3类(0是第1个,2是第3个)。
原始数据矩阵y_hat使用gather传入y后,成功取到了两个样本的预测概率分布 各自对应正确类别下的 概率。
所以这里已经可以理解:gather函数是 原始数据矩阵 根据 索引矩阵 取到 对应值矩阵 的一个过程。
这允许我们定义交叉熵损失函数:
def cross_entropy(y_hat, y):
return - torch.log(y_hat.gather(1, y.view(-1, 1)))
官方文档解释
官方文档中对gather函数的描述非常简洁,即按设定的维度方向,按该维度上的索引值取值。
给出的例子也很简单。但是直观上看,这是一个非常符合逻辑的缠绕函数,很难想象在维度方向上根据索引取值。
我的思考是为什么我们要用这么别扭的方法取值呢?定义交叉熵损失函数为啥要用gather呢?
从交叉熵损失函数定义的角度理解gather函数的使用
- 在Softmax回归中,我们知道要使用交叉熵损失函数来计算两个概率分布之间的差异。
还是一开始的例子,我们有 2个样本3类别的预测概率分布 y_hat 和实际的概率分布 y ,如图: - 根据交叉熵损失函数:
对于softmax回归来说,真实的概率分布只有在正确的类别值为1,所以损失函数最后可以化简为: 。即在预测概率分布中取到第p个值,而p就是索引值,对应正确的类别。 - 对于一个batch的数据来说,每一行是一个样本的各类别概率分布,每一行都会有一个正确分类,位于第
列。这个第 列也是真实分布 y 每一行中1的位置。所以我们只要知道每个样本正确的类别所在列的标号就行了。
即对于实际概率分布 y 矩阵: 我们只需要把 y 矩阵转变为索引矩阵。
故对 y_hat 的每个样本取索引值就得到了对应类别的预测概率。之后就可以进行取log等操作来定义损失函数了。 - 这样,根据一个索引矩阵对一个原始数据矩阵的行(或列)取索引对应的值,就是gather函数的具体做的事了。而交叉熵损失函数的定义正好符合了这样的数据特征,所以就能正好使用gather函数了。
这样做有什么好处?个人认为:
- 将真实概率矩阵(抑或叫One-Hot编码矩阵)缩减为索引矩阵可以大大减少内存开销。
- 索引值矩阵的维数没有变化,仍然保留了并行矩阵运算效率高的优势。
torch.gather()
讲了半天交叉熵损失函数,最后讲讲gather函数的使用。
对于初学者来说,最重要的参数就是dim。在开头定义交叉熵损失函数的时候y_hat.gather(1, y.view(-1, 1))
,参数dim=1,以行为方向进行索引。
直接用二维Tensor举例:
import torch
src = torch.arange(1, 16).reshape(5, 3)
"""
src:
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[13, 14, 15]])
"""
# 定义两个索引矩阵
index1 = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
index2 = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]]).t()
"""
index1:
tensor([[0, 1, 2],
[2, 3, 4],
[0, 2, 4]])
index2:
tensor([[1, 1],
[2, 2],
[0, 1],
[2, 0],
[1, 0]])
"""
# axis=0时,
output1 = src.gather(dim=0, index=index1)
print(output1)
"""
输出:
tensor([[ 1, 5, 9],
[ 7, 11, 15],
[ 1, 8, 15]])
"""
# axis=1
output2 = src.gather(dim=1, index=index2)
print(output2)
"""
输出:
tensor([[ 2, 2],
[ 6, 6],
[ 7, 8],
[12, 10],
[14, 13]])
"""
当一个数据矩阵使用索引矩阵取某个维度方向的值时,该方向的所有值都按照索引矩阵中该方向的索引值取值。
我们直接看下图就更加清晰明了(原图中Dim以1、2来说明,在pytorch中即为0、1):
指定了一个方向以后,index以指定方向一刀切下去,对这个方向的数字取值。
index和src的维度必须要一致,且除去切的那一维,其余维度shape也一致(扩展到三维也是这样)。
文章出处登录后可见!