y_hat[[0, 1], y]

发布时间 2023-04-28 22:16:51作者: 要多读书的陈小派

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]

这段代码使用了两个 PyTorch 张量 yy_hat,并通过 y_hat[[0, 1], y] 代码片段返回了一个包含两个元素的一维张量。

具体来说,该代码分为以下几步:

  • y = torch.tensor([0, 2]):创建了一个名为 y 的一维 PyTorch 张量,其中包含了两个元素 0 和 2。
  • y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]):创建了一个名为 y_hat 的二维 PyTorch 张量,其中包含了两个样本和三个类别的预测结果,即第一个样本属于第一个类别的概率为 0.1,属于第二个类别的概率为 0.3,属于第三个类别的概率为 0.6;第二个样本属于第一个类别的概率为 0.3,属于第二个类别的概率为 0.2,属于第三个类别的概率为 0.5。
  • y_hat[[0, 1], y]:使用了 PyTorch 张量的切片操作,选取了 y_hat 张量中的第 1 行和第 3 行(即 [0, 2] 中的元素),以及第 1 列和第 3 列(即 y 中的元素)交叉组合得到的四个元素,即 y_hat[0, 0]y_hat[1, 2]y_hat[0, 0]y_hat[1, 2],并将它们放到一个形状为 (2,) 的一维张量中返回。

可见,这段代码的作用是选取 y_hat 张量中对应真实标签的部分预测结果,用于计算损失函数或评估模型的准确率等。