RuntimeError: expected scalar type float but found __int64

问题描述

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-d9bacc2c4126> in <module>
     44 
     45 gat = GATConv(dataset.num_features, 16)
---> 46 gat(data.x, data.edge_index).shape

D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-30-d9bacc2c4126> in forward(self, x, edge_index)
     31 
     32         adj = to_dense_adj(edge_index)
---> 33         attention = torch.where(adj > 0, e, 0)
     34 
     35         attention = F.softmax(attention, dim=1)

RuntimeError: expected scalar type float but found __int64

原因分析:

调用torch.where()时传入了int类型整数,但是函数的输入参数要求传入float类型数据,所以修改下类型即可。

解决方案:

attention = torch.where(adj > 0, e, torch.tensor(0, dtype=torch.float32))

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年7月13日
下一篇 2023年7月13日

相关推荐