原因
跑模型的时候,用的是多卡加载torch.nn.DataParallel(self.model),测试是用的单卡模糊加载保存的模型权重,很多模型参数都没有加载成功,自然会导致测试效果很差。
解决方法
`
如果你想要用nn.DataParallel来加载模型
state_dict = torch.load('model.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = 'module.' + k # 添加'module.'前缀
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
如果你想要不用nn.DataParallel来加载模型
state_dict = torch.load('model.pth')
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # 删除'module.'前缀
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
`
解决效果
改动前:
改动后: