为自定义消息丢失加速 pytorch 操作

社会演员多 pytorch 177

原文标题Speeding up pytorch operations for custom message dropout

我正在尝试在 PyTorch Geometric 中的自定义 MessagePassing 卷积中实现消息丢失。消息丢失包括随机忽略图中 p% 的边。我的想法是从输入edge_indexinforward()中随机删除其中的 p%。

edge_index是形状(2, num_edges)的张量,其中第一个维度是“from”节点 ID,第二个维度是“to”节点 ID”。所以我想我能做的就是选择range(N)的随机样本,然后用它来屏蔽其余索引:

    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # TODO: this is way too slow (4-5 times slower than without it)
            # message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
            random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
            edge_index_to_use = edge_index[:, random_keep_inx]
            edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...

但是,它太慢了,它使一个纪元变为 5′ 而不是 1′ 而没有(慢 5 倍)。在 PyTorch 中是否有更快的方法来做到这一点?

编辑:瓶颈似乎是random.sample()调用,而不是掩蔽。所以我想我应该问的是更快的替代方案。

原文链接:https://stackoverflow.com//questions/71900767/speeding-up-pytorch-operations-for-custom-message-dropout

回复

我来回复
  • Michael的头像
    Michael 评论

    我设法使用 PyTorch 的功能性 Dropout 创建了一个布尔掩码,这要快得多。现在一个纪元再次需要〜1’。比我在其他地方找到的具有排列的其他解决方案更好。

        def forward(self, x, edge_index, edge_attr=None):
            if self.message_dropout is not None:
                # message dropout -> randomly ignore p % of edges in the graph
                mask = F.dropout(torch.ones(edge_index.shape[1]), self.message_dropout, self.training) > 0
                edge_index_to_use = edge_index[:, mask]
                edge_attr_to_use = edge_attr[mask] if edge_attr is not None else None
            else:
                edge_index_to_use = edge_index
                edge_attr_to_use = edge_attr
    
            ...
    
    2年前 0条评论