主要问题:矩阵乘法时类型不一致。先看下面代码:
A = torch.arange(20,dtype=torch.float32).reshape(5,4)
a = torch.arange(4)
A.shape,a.shape,torch.mv(A, a)
出现该问题的主要原因是在A矩阵中已经设置了类型dtype=torch.float32,但在a向量中未设置,从下图可以看到,a的默认类型为int64,从而出现类型不一致的问题。
A = torch.arange(20,dtype=torch.float32).reshape(5,4)
a = torch.arange(4,dtype=torch.float32)
A.shape,a.shape,torch.mv(A, a)
文章出处登录后可见!
已经登录?立即刷新