Problem description
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-30-d9bacc2c4126> in <module>
44
45 gat = GATConv(dataset.num_features, 16)
---> 46 gat(data.x, data.edge_index).shape
D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
<ipython-input-30-d9bacc2c4126> in forward(self, x, edge_index)
31
32 adj = to_dense_adj(edge_index)
---> 33 attention = torch.where(adj > 0, e, 0)
34
35 attention = F.softmax(attention, dim=1)
RuntimeError: expected scalar type float but found __int64
Cause analysis:
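The call to torch.where() passes the integer literal 0 as the fill value. PyTorch turns it into an int64 tensor (shown as __int64 in the Windows error message), while the attention scores e are float32. In the PyTorch version used here, torch.where() requires both value arguments to have the same dtype, so the fix is simply to pass the fill value as a float.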
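A quick way to see the mismatch is to compare the dtypes involved. A bare Python int wrapped as a tensor becomes int64, while the scores are float32. In this minimal sketch, `e` is a hypothetical random stand-in for the attention scores computed in the original forward pass.

```python
import torch

# Hypothetical stand-in for the float32 attention scores `e`
e = torch.randn(4, 4)

# A bare Python int becomes an int64 tensor when wrapped
fill_int = torch.tensor(0)
# Building the fill value with e's dtype avoids the mismatch
fill_float = torch.tensor(0, dtype=e.dtype)

print(e.dtype)           # torch.float32
print(fill_int.dtype)    # torch.int64
print(fill_float.dtype)  # torch.float32
```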
Solution: create the fill value as a floating-point tensor instead of passing the integer literal 0:
attention = torch.where(adj > 0, e, torch.tensor(0, dtype=e.dtype, device=e.device))
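The masking step can be checked in isolation with the sketch below. It assumes a hypothetical 3-node graph: `e` holds random float32 scores and `adj` is a hand-written dense 0/1 adjacency matrix, so it does not depend on torch_geometric. It uses `torch.zeros_like(e)` as the fill value, which is one way to match both the dtype and the device of `e` automatically.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy inputs: float32 scores for 3 nodes and a
# dense 0/1 adjacency matrix describing which edges exist.
e = torch.randn(3, 3)
adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])

# Fill value with the same dtype and device as `e`, so torch.where
# never has to mix float32 with int64.
zero = torch.zeros_like(e)
attention = torch.where(adj > 0, e, zero)

# Row-wise softmax, as in the original forward pass
attention = F.softmax(attention, dim=1)
print(attention.dtype)  # torch.float32
```

Because `zeros_like` copies the dtype and device of `e`, the same line keeps working if the model is later moved to a GPU or switched to another floating-point precision.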