y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
这段代码使用了两个 PyTorch 张量 y
和 y_hat
,并通过 y_hat[[0, 1], y]
代码片段返回了一个包含两个元素的一维张量。
具体来说,该代码分为以下几步:
y = torch.tensor([0, 2])
:创建了一个名为y
的一维 PyTorch 张量,其中包含了两个元素 0 和 2。y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
:创建了一个名为y_hat
的二维 PyTorch 张量,其中包含了两个样本和三个类别的预测结果,即第一个样本属于第一个类别的概率为 0.1,属于第二个类别的概率为 0.3,属于第三个类别的概率为 0.6;第二个样本属于第一个类别的概率为 0.3,属于第二个类别的概率为 0.2,属于第三个类别的概率为 0.5。y_hat[[0, 1], y]
:使用了 PyTorch 张量的切片操作,选取了y_hat
张量中的第 1 行和第 3 行(即[0, 2]
中的元素),以及第 1 列和第 3 列(即y
中的元素)交叉组合得到的四个元素,即y_hat[0, 0]
、y_hat[1, 2]
、y_hat[0, 0]
、y_hat[1, 2]
,并将它们放到一个形状为(2,)
的一维张量中返回。
可见,这段代码的作用是选取 y_hat
张量中对应真实标签的部分预测结果,用于计算损失函数或评估模型的准确率等。