loss.item()用法和注意事项详解

慈云数据 2024-03-13 技术支持 130 0

.item()方法是,取一个元素张量里面的具体元素值并返回该值,可以将一个零维张量转换成int型或者float型,在计算loss,accuracy时常用到。

loss.item()用法和注意事项详解
(图片来源网络,侵删)

作用

1.item()取出张量具体位置的元素元素值

loss.item()用法和注意事项详解
(图片来源网络,侵删)

2.并且返回的是该位置元素值的高精度值

3.保持原元素类型不变;必须指定位置

4.节省内存(不会计入计算图)

import torch
loss = torch.randn(2, 2)
print(loss)
print(loss[1,1])
print(loss[1,1].item())

输出结果

tensor([[-2.0274, -1.5974],

        [-1.4775,  1.9320]])

tensor(1.9320)

1.9319512844085693



其它:

loss = criterion(out, label)
    loss_sum += loss     # 
微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon