Pytorch/Python中item()的用法

前言

在使用Pytorch训练模型时,用到python中的item()函数,如:

train_loss += loss.item()

现对item()函数用法做出总结。item()函数的作用是从包含单个元素的张量中取出该元素值,并保持该元素的类型不变。,即:该元素为整形,则返回整形,该元素为浮点型,则返回浮点型。官网解释如下:

Pytorch官网:https://pytorch.org/docs/stable/tensors.html?highlight=item#torch.Tensor.item

实验

做个测试:

import torch

x = torch.randn(2, 2)

print(x)
print(x[0,0])
print(x[0,0].item())

Output:

tensor([[-0.1405,  2.4767],
        [-0.6847,  0.0057]])
tensor(-0.1405)
-0.14052967727184296

总结

  1. 计算loss或者accuracy时,经常使用item()函数,而不是直接取对应的元素x[i,j]。
  2. item()函数取值时,保持该元素的类型不变。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2023年7月13日
下一篇 2023年7月13日

相关推荐