CrossEntropyLoss: RuntimeError: expected scalar type Float but found Long neural network

发布时间 2023-11-27 22:54:01作者: 哆啦小火车

错误分析

  这个错误通常指的是期望接受的参数类型是Float, 但是程序员传入的是Int 。 通常会需要我们去检查传入的 inputtarget 的数据类型有没有匹配。在传入的数据中,通常 input 希望是 Float 类型,targetInt 类型。
  但是通常也许会发现传入的参数是符合要求的,但是仍然会报这样的错误,那么这个时候就需要注意查看 CrossEntropyLoss 中传入的参数 weight 的类型,传入的参数weight 也必须是一个浮点数,即,如果你设置成 [1, 2] 也必须写成 [1.0, 2.0] 的形式。

样例

   CorssEntropyLoss 的参数使用的样例代码如下:

class_weights = torch.tensor([1.0, 2.0], device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

通常, 这个参数是我们在做分类任务时,当我们期待对少数类样本投以更多关注时就可以开始设置,在异常检测的领域比较常见。