RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the

慈云数据 2024-03-12 技术支持 138 0

bug:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the
(图片来源网络,侵删)

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

源代码如下:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the
(图片来源网络,侵删)
if __name__ == "__main__":
    from torchsummary import summary
    model = UNet()
    print(model)
    summary(model, input_size=(1, 480, 480))

使用torchsummary可视化模型时候报错,报这个错误是因为类型不匹配,根据报错内容可以看出Input type为torch.FloatTensor(CPU数据类型),而weight type(即网络权重参数这些)为torch.cuda.FloatTensor(GPU数据类型)。

我们将model传到GPU上便可。将代码如下修改便可正常运行

if __name__ == "__main__":
    from torchsummary import summary
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)	# modify
    print(model)
    summary(model, input_size=(1, 480, 480))
微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon