Pytorch学习笔记

总结学习Pytorch时的一些API

内容

1 torch.sort()

1 torch.sort()

torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)

1.1 作用

根据给定维度按升序或降序对输入张量进行排序。

1.2 参数

input: 需要是一个torch.Tensor类型的张量

dim: 给定一个张量的维度(int型),按照这个维度上的数值进行排序。如果不指定,默认按照张量的最后一个维度进行排序。

descending:传入一个布尔类型的数据(Ture、False),True代表降值排序,False代表升值排序。如果不指定,默认升值排序。

stable:传入一个布尔类型的数据(Ture、False),当一个张量中存在多个相同数字时,例如[2, 2, 1, 1],传入True不会打乱同一个数字的先后顺序(第一个1会排在第一个,第二个1会排在第二个)。如果不指定,默认False。

out:(Tensor, LongTensor) 的输出元组,可以选择用作输出缓冲区。如果不指定,默认None。

1.3 举例

首先只传入张量,其他参数默认:

import torch

tensor_a = torch.tensor([[2, 1],
                         [3, 4],
                         [6, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a)
print(sorted_tensor_a, '\n', indices)

#---------输出---------#
tensor([[1, 2],
        [3, 4],
        [5, 6]]) 
 tensor([[1, 0],
        [0, 1],
        [1, 0]])
#----------------------#

dim = 0 的情况:

import torch

tensor_a = torch.tensor([[6, 1],
                         [1, 4],
                         [2, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a, dim=0)
print(sorted_tensor_a, '\n', indices)

#---------输出---------#
tensor([[1, 1],
        [2, 4],
        [6, 5]]) 
 tensor([[1, 0],
        [2, 1],
        [0, 2]])
#----------------------#

descending = True 的情况:

import tensor

tensor_a = torch.tensor([[6, 1],
                         [1, 4],
                         [2, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a, dim=0, descending=True)
print(sorted_tensor_a, '\n', indices)

#---------输出---------#
tensor([[6, 5],
        [2, 4],
        [1, 1]]) 
 tensor([[0, 2],
        [2, 1],
        [1, 0]])
#----------------------#

stable =  True 的情况:

import torch

tensor_a = torch.tensor([0, 1] * 9)

sorted_tensor_a, indices = torch.sort(tensor_a, stable=True)
print(sorted_tensor_a, '\n', indices)
#---------------------------输出---------------------------#
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]) 
 tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16,  1,  3,  5,  7,  9, 11, 13, 15, 17])
#----------------------------------------------------------#

sorted_tensor_a, indices = torch.sort(tensor_a)
print(sorted_tensor_a, '\n', indices)
#---------------------------输出---------------------------#
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]) 
 tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16,  1,  3,  5,  7,  9, 11, 13, 15, 17])
#----------------------------------------------------------#

可以看到,在我的运行结果中,stable无论是True还是False,好像结果都是一样的,但是以下是官方教程中的例子:

Pytorch学习笔记 不知道为什么我的和官方输出的不一样。 . .

版权声明:本文为博主Balaboo原创文章,版权归属原作者,如果侵权,请联系我们删除!

原文链接:https://blog.csdn.net/fuss1207/article/details/123044790

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2022年2月22日 下午4:03
下一篇 2022年2月22日 下午4:24

相关推荐