求一个字典的所有value中的最大值

发布时间 2023-09-18 09:07:00作者: 海_纳百川

已知一个字典中有多个类别key,每个类别value是一个torch.tensor(多个浮点型),求这个字典所有value中的最大值

import torch

my_dict = {
    'category1': torch.tensor([1.0, 2.0, 3.0]),
    'category2': torch.tensor([4.0, 5.0, 6.0]),
    'category3': torch.tensor([7.0, 8.0, 9.0])
}

# 使用列表推导和 torch.max() 找到最大值
max_value = max(torch.max(tensor).item() for tensor in my_dict.values())

print("最大值:", max_value)