总结学习Pytorch时的一些API
内容
1 torch.sort()
1 torch.sort()
torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)
1.1 作用
根据给定维度按升序或降序对输入张量进行排序。
1.2 参数
input: 需要是一个torch.Tensor类型的张量
dim: 给定一个张量的维度(int型),按照这个维度上的数值进行排序。如果不指定,默认按照张量的最后一个维度进行排序。
descending:传入一个布尔类型的数据(Ture、False),True代表降值排序,False代表升值排序。如果不指定,默认升值排序。
stable:传入一个布尔类型的数据(Ture、False),当一个张量中存在多个相同数字时,例如[2, 2, 1, 1],传入True不会打乱同一个数字的先后顺序(第一个1会排在第一个,第二个1会排在第二个)。如果不指定,默认False。
out:(Tensor, LongTensor) 的输出元组,可以选择用作输出缓冲区。如果不指定,默认None。
1.3 举例
首先只传入张量,其他参数默认:
import torch
tensor_a = torch.tensor([[2, 1],
[3, 4],
[6, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a)
print(sorted_tensor_a, '\n', indices)
#---------输出---------#
tensor([[1, 2],
[3, 4],
[5, 6]])
tensor([[1, 0],
[0, 1],
[1, 0]])
#----------------------#
dim = 0 的情况:
import torch
tensor_a = torch.tensor([[6, 1],
[1, 4],
[2, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a, dim=0)
print(sorted_tensor_a, '\n', indices)
#---------输出---------#
tensor([[1, 1],
[2, 4],
[6, 5]])
tensor([[1, 0],
[2, 1],
[0, 2]])
#----------------------#
descending = True 的情况:
import tensor
tensor_a = torch.tensor([[6, 1],
[1, 4],
[2, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a, dim=0, descending=True)
print(sorted_tensor_a, '\n', indices)
#---------输出---------#
tensor([[6, 5],
[2, 4],
[1, 1]])
tensor([[0, 2],
[2, 1],
[1, 0]])
#----------------------#
stable = True 的情况:
import torch
tensor_a = torch.tensor([0, 1] * 9)
sorted_tensor_a, indices = torch.sort(tensor_a, stable=True)
print(sorted_tensor_a, '\n', indices)
#---------------------------输出---------------------------#
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])
#----------------------------------------------------------#
sorted_tensor_a, indices = torch.sort(tensor_a)
print(sorted_tensor_a, '\n', indices)
#---------------------------输出---------------------------#
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])
#----------------------------------------------------------#
可以看到,在我的运行结果中,stable无论是True还是False,好像结果都是一样的,但是以下是官方教程中的例子:
不知道为什么我的和官方输出的不一样。 . .
版权声明:本文为博主Balaboo原创文章,版权归属原作者,如果侵权,请联系我们删除!
原文链接:https://blog.csdn.net/fuss1207/article/details/123044790