Python报错 | RuntimeError: expected scalar type Long but found Float

发布时间 2023-07-07 13:04:57作者: 张Zong在修行

报错信息

在执行nlp自定义模型的训练函数的时候,报如下错误:

RuntimeError: expected scalar type Float but found Long

错误原因

错误信息指出了问题所在:模型期望的数据类型是 float,但实际上传递给模型的数据类型是 long。

这个错误通常是由于张量数据类型不匹配引起的。在 PyTorch 中,张量数据类型非常重要,因为它们指定了张量中存储的数值的精度和类型。如果您在模型的前向传递中使用了错误的数据类型,就会出现这个错误。

例如:

import torch
import torch.nn as nn

v = torch.tensor([0])
m = nn.Linear(1, 10)
m(v)

运行结果:

因为input也就是我们的v是torch.long类型的而weight是torch.float类型。所以在做矩阵乘法的时候这两种类型的不一致导致了报错。

解决方案

把v的dtype显示地设置成torch.float代码就成功运行了

import torch
import torch.nn as nn
# dtype=torch.float必不可少
v = torch.tensor([0], dtype=torch.float)
m = nn.Linear(1, 10)
m(v)

运行结果:

tensor([-0.6189, -0.9843, -0.7568,  0.9157,  0.5192, -0.6109, -0.5627, -0.7755,
        -0.9522,  0.7771], grad_fn=<AddBackward0>)