深度学习中的max函数

首先生成一个3*3的Tensor

输入

a = torch.randn(3,3)
print(a)

输出

tensor([[-1.2324,  1.5269, -1.0409],
        [ 1.3819,  0.1164,  0.8842],
        [ 0.2303,  0.0043,  0.4198]])

输入

# 输出的是张量a中每一列中的最大值   以及最大值所处的索引
print(torch.max(a, 0))

输出

values=tensor([1.3819, 1.5269, 0.8842]),
indices=tensor([1, 0, 1]))

输入

# 输出的是张量a中每一列中的最大值
print(torch.max(a, 0)[0])

输出

tensor([1.3819, 1.5269, 0.8842])

输入

# 输出的是张量a中每一列中最大值所对应的索引
print(torch.max(a, 0)[1])

输出

tensor([1, 0, 1])

输入

# 输出的是张量a中每一行中的最大值   以及最大值所处的索引
print(torch.max(a, 1))

输出

values=tensor([1.5269, 1.3819, 0.4198]),
indices=tensor([1, 0, 2]))

输入

# 输出的是张量a中每一行的最大值
print(torch.max(a, 1)[0])

输出

tensor([1.5269, 1.3819, 0.4198])

输入

# 输出是张量a中每一行的最大值所对应的索引
print(torch.max(a, 1)[1])

输出

tensor([1, 0, 2])

Summary:
torch.max(张量,0) 输出的是张量每一列中的最大值,还有最大值所对应的索引,输出类型为张量
torch.max(张量,1) 输出的是张量每一行中的最大值,还有最大值所对应的索引,输出类型都为张量
torch.max()[0] 输出的仅仅是每一行或每一列的最大值
torch.max()[1] 输出的仅仅是每一行或每一列最大值所对应的索引

下面这种形式是在深度学习中用的最多的

# 输出的是张量a中每一列中的最大值   以及最大值所处的索引  行数为1 列数保持不变
print(a.max(0, keepdim=True))
# 输出的是张量a中每一列中的最大值
print(a.max(0, keepdim=True)[0])
print(a.max(0, keepdim=True)[0].size())
# 输出的是张量a中每一列最大值所对应的索引
print(a.max(0, keepdim=True)[1])
print(a.max(0, keepdim=True)[1].size())


# 输出的是张量a中每一行中的最大值   以及最大值所处的索引  行数为保持不变 列数为1
print(a.max(1, keepdim=True))
# 输出的是张量a中每一行中的最大值   
print(a.max(1, keepdim=True)[0])
print(a.max(1, keepdim=True)[0].size())
# 输出的是张量a中每一行中最大值所对应的索引  
print(a.max(1, keepdim=True)[1])
print(a.max(1, keepdim=True)[1].size())

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2022年5月31日
下一篇 2022年5月31日

相关推荐