模型超参数基本都没改,测试时加载模型报模型结构不匹配,设置模糊加载模型即:model.load_state_dict(torch.load(model_path), strict=Fasle),但效果出奇的差

发布时间 2023-08-15 17:35:03作者: Kurie

原因

跑模型的时候,用的是多卡加载torch.nn.DataParallel(self.model),测试是用的单卡模糊加载保存的模型权重,很多模型参数都没有加载成功,自然会导致测试效果很差。

解决方法

`

如果你想要用nn.DataParallel来加载模型

state_dict = torch.load('model.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = 'module.' + k # 添加'module.'前缀
new_state_dict[name] = v
model.load_state_dict(new_state_dict)

如果你想要不用nn.DataParallel来加载模型

state_dict = torch.load('model.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # 删除'module.'前缀
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
`

解决效果

改动前:

改动后: