CrossEntropyLoss

发布时间 2023-10-10 23:14:03 作者: oneDonkey

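输入 x 多一个维度，输出的损失值为什么差距这么大？原因在于 CrossEntropyLoss 总是把输入的第 1 维当作类别维。第一段代码中 x 的形状为 (64,224,224)，类别数被当成了 224；而与 x 同形状的浮点型 y 会被解释为各类别的概率（软标签，PyTorch 1.10 起支持），而不是类别索引。每个位置约有 112 个“概率”为 1 且没有归一化，所以损失被放大到几百。第二段代码中 x 的形状为 (64,2,224,224)，类别数为 2，long 型的 y 被当作每个像素的类别索引，这才是二分类逐像素损失的正确用法。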

import torch
import numpy as np

# Case 1: a 3-D input of shape (64, 224, 224).
# CrossEntropyLoss always treats dim 1 of the input as the class dimension, so
# this is read as N=64 samples, C=224 classes, and one extra dim of size 224.
x = torch.randn((64, 224, 224))

y = torch.rand((64, 224, 224))

# (y > 0.5).float() already returns a fresh tensor; re-wrapping it with
# torch.tensor(...) would only make a redundant copy and emit a UserWarning.
y = (y > 0.5).float()
# Because y is floating point and has the same shape as x, CrossEntropyLoss
# interprets it as per-class probabilities (soft labels) along dim 1, not as
# class indices. Roughly half of the 224 entries at each position are 1, and
# they are not normalized to sum to 1. The per-position loss is therefore
# -sum_c y_c * log_softmax(x)_c over ~112 classes, which explains the large value.
fun = torch.nn.CrossEntropyLoss()
print(fun(x, y))  # prints e.g. tensor(661.5663); exact value varies per run (no seed)
import torch
import numpy as np

# Case 2: a 4-D input of shape (64, 2, 224, 224). CrossEntropyLoss reads dim 1
# as the class dimension, giving N=64 samples, C=2 classes, and a 224x224 map.
x = torch.randn((64, 2, 224, 224))

y = torch.rand((64, 224, 224))

# An integer target of shape (N, 224, 224) holds one class index (0 or 1)
# per position. .long() converts directly, so no torch.tensor(...) re-wrap
# (which would copy and emit a UserWarning) is needed.
y = (y > 0.5).long()
fun = torch.nn.CrossEntropyLoss()
# Mean over all 64*224*224 positions of -log_softmax(x)[target class].
print(fun(x, y))  # prints e.g. tensor(0.9030); exact value varies per run (no seed)