站点图标 AI技术聚合

Pytorch/Python中item()的用法

Table of Contents

前言

在使用Pytorch训练模型时,用到python中的item()函数,如:

train_loss += loss.item()

现对item()函数用法做出总结。item()函数的作用是从包含单个元素的张量中取出该元素值,并保持该元素的类型不变。,即:该元素为整形,则返回整形,该元素为浮点型,则返回浮点型。官网解释如下:

Pytorch官网:https://pytorch.org/docs/stable/tensors.html?highlight=item#torch.Tensor.item

实验

做个测试:

import torch

x = torch.randn(2, 2)

print(x)
print(x[0,0])
print(x[0,0].item())

Output:

tensor([[-0.1405,  2.4767],
        [-0.6847,  0.0057]])
tensor(-0.1405)
-0.14052967727184296

总结

  1. 计算loss或者accuracy时,经常使用item()函数,而不是直接取对应的元素x[i,j]。
  2. item()函数取值时,保持该元素的类型不变。

文章出处登录后可见!

已经登录?立即刷新
退出移动版