如何在pytorch中交换x和y?

乘风 pytorch 195

原文标题How to swap x and y in pytorch?

给定输出张量为 100 * 6(xmin,ymin,xmax,ymax,conf,class),如何在 pytorch 中获得 100 * 6(ymin,xmin,ymax,xmax,conf,class) 的张量?例如,给定一个张量

x = [[1,2,3,4,5,6],
     [7,8,9,10,11,12]], 

期望的结果是

y = [[2,1,4,3,5,6],
     [8,7,10,9,11,12]]

原文链接:https://stackoverflow.com//questions/71580812/how-to-swap-x-and-y-in-pytorch

回复

我来回复
  • Ishi的头像
    Ishi 评论

    我对此并不完全确定,但试试 N6.T 还是 N6.transpose?

    2年前 0条评论
  • dtlam26的头像
    dtlam26 评论

    你应该更清楚地指出什么是xy,它可能与轴混淆。但我明白你的意思,你想交换 ymin 和 xmin 位置以及 xmax、ymax 位置。

    因此,最简单的方法是创建一个时间张量tempt_x并通过切片交换值。

    with x = [[1,2,3,4,5,6],[7,8,9,10,11,12]]
    
    >>> tempt_x = x.copy()
    >>> x[:,0] = tempt_x[:,1]
    >>> x[:,1] = tempt_x[:,0]
    >>> x[:,2] = tempt_x[:,3]
    >>> x[:,3] = tempt_x[:,2]
    >>> x
    array([[ 2,  1,  4,  3,  5,  6],
           [ 8,  7, 10,  9, 11, 12]])
    
    2年前 0条评论
  • I'mahdi的头像
    I'mahdi 评论

    您可以使用torch.reshape并根据需要更改轴,然后创建如下所示的最终张量:(因为您说您有 100 * 6。我添加了另外两行以显示此代码可以在扩展版本中使用。)

    张量版本

    import torch
    x = torch.tensor([
        [1,2,3,4,5,6],
        [7,8,9,10,11,12], 
        [13,14,15,16,17,18],
        [19,20,21,22,23,24],     
    ])
    
    
    conf_class = x[:, -2:]
    tmp_x = x[:, :-2]
    tmp_x = torch.reshape(tmp_x, (-1,2))
    tmp_x = torch.cat((tmp_x[:,1::2], tmp_x[:,::2]), 1)
    tmp_x = torch.reshape(tmp_x, (-1,4))
    res = torch.cat((tmp_x, conf_class), 1)
    print(res)
    

    输出:

    tensor([[ 2,  1,  4,  3,  5,  6],
            [ 8,  7, 10,  9, 11, 12],
            [14, 13, 16, 15, 17, 18],
            [20, 19, 22, 21, 23, 24]])
    

    麻木版本

    import numpy as np
    
    a = np.array([
        [1,2,3,4,5,6],
        [7,8,9,10,11,12],
        [13,14,15,16,17,18],
        [19,20,21,22,23,24],
    ])
    
    conf_class = a[:, -2:]
    tmp_a = a[:, :-2]
    tmp_a = tmp_a.reshape(-1,2)
    tmp_a = np.concatenate((tmp_a[:,1::2], tmp_a[:,::2]), 1)
    tmp_a = tmp_a.reshape(-1,4)
    res = np.concatenate((tmp_a, conf_class), 1)
    print(res)
    

    输出:

    [[ 2  1  4  3  5  6]
     [ 8  7 10  9 11 12]
     [14 13 16 15 17 18]
     [20 19 22 21 23 24]]
    
    2年前 0条评论