【pytorch】torch.cdist使用说明

使用说明

torch.cdist的使用介绍如官网所示,

它是批量计算两个向量集合的距离。

其中, x1和x2是输入的两个向量集合。

p 默认为2,为欧几里德距离。

它的功能上等同于 scipy.spatial.distance.cdist(input,’minkowski’, p=p)

如果x1的shape是 [B,P,M], x2的shape是[B,R,M],则cdist的结果shape是 [B,P,R]

进一步的解释

x1一般是输入矢量,而x2一般是码本。

x2中所有的元素分别与x1中的每一个元素求欧几里德距离(当p默认为2时)

如下面示例

import torch

x1 = torch.FloatTensor([0.1, 0.2, 0, 0.5]).view(4, 1)

x2 = torch.FloatTensor([0.2, 0.3]).view(2, 1)

print(torch.cdist(x1,x2))

x2中的所有元素分别与x1中的每一个元素求欧几里德距离,即有如下步骤

【pytorch】torch.cdist使用说明

所以运行结果为

扩张到2维的情况

如下面示例

import torch

x1 = torch.FloatTensor([0.1, 0.2, 0.1, 0.5, 0.2, -0.9, 0.8, 0.4]).view(4, 2)

x2 = torch.FloatTensor([0.2, 0.3, 0, 0.1]).view(2, 2)

print(torch.cdist(x1,x2))

x1和x2数据是二维的,

x2中的所有元素分别与x1中的每一个元素求欧几里德距离,即有如下步骤

【pytorch】torch.cdist使用说明

所以结果如下

p=2的欧几里德距离也是L2范式,如果p=1即是L1范式
上面的例子修改一下p参数

import torch

x1 = torch.FloatTensor([0.1, 0.2, 0.1, 0.5, 0.2, -0.9, 0.8, 0.4]).view(4, 2)

x2 = torch.FloatTensor([0.2, 0.3, 0, 0.1]).view(2, 2)

print(torch.cdist(x1,x2,p=1))

结果如下,这里就不一个一个运算了。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年6月4日
下一篇 2023年6月4日

相关推荐