pytorch中的矩阵乘法操作:torch.mm(), torch.bmm(), torch.mul()和*, torch.dot(), torch.mv(), @, torch.matmul()

😄 无聊整理下torch里的张量的各种乘法相关操作。

文章目录

  • 0、简单提一下广播法则的定义:
  • 1、torch.mm()
  • 2、torch.bmm()
  • 3、torch.mul()和*
  • 4、torch.dot()
  • 5、torch.mv()
  • 6、@
  • 7、torch.matmul()

0、简单提一下广播法则的定义:

  • 1、让所有输入张量都向其中shape最长的矩阵看齐,shape不足的部分在前面加1补齐。
  • 2、两个张量的维度要么在某一个维度一致,若不一致其中一个维度为1也可广播。否则不能广播。【如两个维度:(4, 1, 4)和(2, 1)可以广播,因为他们不相等的维度其中一个为1就可以广播了。】

1、torch.mm()

- 只适合于二维张量的矩阵乘法。
- m x n, n x p -> m x p

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)
out = torch.mm(mat1, mat2)
out.shape
# torch.Size([2, 4])

2、torch.bmm()

- 只适合于三维张量的矩阵乘法,与torch.mm类似,但多了一个batch_size维度。
- b x m x n, b x n x p -> b x m x p

mat1 = torch.randn(8, 2, 3)
mat2 = torch.randn(8, 3, 4)
out = torch.bmm(mat1, mat2)
out.shape
# torch.Size([8, 2, 4])

3、torch.mul()和*

  • - ⭐ torch.mul()和*等价。
    - 张量对应位置元素相乘。
    - 将输入张量input的每个元素与另一个向量or标量other相乘,返回一个新的张量out,两者维度需满足广播规则
# 方式1:张量 和 标量相乘
input = torch.randn(3)
other = 100
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([3])

# 方式2:张量 和 张量(需满足广播规则)
input = torch.randn(4, 1, 4)
other = torch.randn(2, 1)
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([4, 2, 4])

# 方式3:元素对应项相乘
input = torch.randn(3, 2)
other = torch.randn(3, 2)
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([3, 2])

4、torch.dot()

向量点积:两向量对应位置相乘然后全部相加。只能支持两个一维向量。

5、torch.mv()

矩阵和向量的乘法

  • 第一个参数只能是二维的,第二个参数是一维的,则在其维数末尾追加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(2,6)和(6)运算过程:(2,6)和(6,1) -> (2,1) -> (2)
mat1 = torch.randn(6,8)
mat2 = torch.randn(8)
out = torch.mv(mat1, mat2)
out.shape
# torch.Size([6])

6、@

torch中的@操作是可以实现前面的某几个函数,是一种强大的操作。

  • 若mat1和mat2都是两个一维向量,那么对应操作就是torch.dot()
  • 若mat1是二维向量,mat2是一维向量,那么对应操作就是torch.mv()
  • 若mat1和mat2都是两个二维向量,那么对应操作就是torch.mm()

7、torch.matmul()

torch.matmul()与@操作类似,但是torch.matmul()不止局限于一维和二维,可以进行高维张量的乘法。两个张量的矩阵乘积。其行为取决于张量的维数如下:

  • 1、如果两个张量都是一维的,则返回点积(标量)。

  • 2、如果两个参数都是二维的,则返回矩阵-矩阵乘积。

  • 3、如果第一个参数是二维的,第二个参数是一维的,则在其维数末尾追加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(2,6)和(6)运算过程:(2,6)和(6,1) -> (2,1) -> (2)

  • 4、如果第一个参数是一维的(则在其维数前加一个1,),第二个参数是二维的,则返回矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(6)和(6,2)运算过程:(1,6)和(6,2) -> (1,2) -> (2)

  • 5、对3和4的总结。如果两个参数至少是一个参数是一维的,且至少一个参数是N维的(其中N > 2),则返回一个批处理矩阵乘法。如果第一个参数是一维的,则在其维数前加上1,以便批处理矩阵相乘,然后删除。如果第二个参数是一维的,则为批处理矩阵倍数的目的,将在其维上追加一个1,然后删除它。非矩阵(即批处理)维度是广播的(因此必须是可广播的)

  • 两个参数都是N维(>2),只有非矩阵的维度才是可以广播的,最后两维需满足矩阵乘法即m x n, n x p -> m x p。如bx1xnxm, kxmxp -> jxkxnxp

  >>> # vector x vector
  >>> tensor1 = torch.randn(3)
  >>> tensor2 = torch.randn(3)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([])
  >>> # matrix x vector
  >>> tensor1 = torch.randn(3, 4)
  >>> tensor2 = torch.randn(4)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([3])
  >>> # batched matrix x broadcasted vector
  >>> tensor1 = torch.randn(10, 3, 4)
  >>> tensor2 = torch.randn(4)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([10, 3])
  >>> # batched matrix x batched matrix
  >>> tensor1 = torch.randn(10, 3, 4)
  >>> tensor2 = torch.randn(10, 4, 5)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([10, 3, 5])
  >>> # batched matrix x broadcasted matrix
  >>> tensor1 = torch.randn(10, 3, 4)
  >>> tensor2 = torch.randn(4, 5)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([10, 3, 5])

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年11月2日
下一篇 2023年11月2日

相关推荐