torch.argmax()函数-截至2023年8月28日

发布时间 2023-08-28 09:58:48作者: 你的-昵称

argmax函数参数dim=0表示从列获取最大值索引,dim=1从行获取最大值索引,dim=-1从最后一个维度获取最大值索引[1]

举例

import torch
d = torch.tensor([[9,7,6],
				[4,8,2],
                 [5,10,0]])
print(torch.argmax(d , dim=0))#结果应为9,10,6的所在列的索引==》0,2,0
print(torch.argmax(d , dim=1))#结果应为9,8,10所在行的索引==》0,1,1
print(torch.argmax(d , dim=-1))#结果应为9,8,10所在行的索引==》0,1,1

运行结果

image-20230826153122554


  1. https://blog.csdn.net/weixin_42494287/article/details/92797061 ↩︎