Teacher-Student Dual Networks

Posted 2023-07-17 16:04:23  Author: helloWorldhelloWorld

1. Declare the student and teacher networks (backbone_model is the student, ema_model is the teacher)

backbone_model = Net(gps=opt.gps, blocks=opt.blocks)  # student network
backbone_model = backbone_model.to(device)
ema_model = Net(gps=opt.gps, blocks=opt.blocks)       # teacher network (EMA of the student)
ema_model = ema_model.to(device)
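As written, the two networks start from independent random initializations. Many mean-teacher setups instead start the teacher from the student's weights. That choice is an assumption here and is not part of the original code. A minimal sketch follows. It uses a hypothetical tiny Net in place of the real dehazing model and keeps the original two-constructor structure, adding one load_state_dict call:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical stand-in for the real Net(gps=..., blocks=...)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(self.conv(x))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

backbone_model = Net().to(device)  # student
ema_model = Net().to(device)       # teacher, randomly initialized on its own
ema_model.load_state_dict(backbone_model.state_dict())  # start from the student's weights

# Every state_dict entry (weights and BatchNorm buffers) is now equal,
# but the two models still own separate storage.
same = all(
    torch.equal(v_s, v_t)
    for v_s, v_t in zip(backbone_model.state_dict().values(), ema_model.state_dict().values())
)
print(same)  # True
```

load_state_dict copies values, so later student updates do not leak into the teacher except through the EMA step.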

2. Keep the teacher network out of gradient updates (only the student requires gradients)

for param in backbone_model.parameters():
    param.requires_grad = True   # student: trained by backpropagation
for param in ema_model.parameters():
    param.requires_grad = False  # teacher: updated only through the EMA step (step 4)
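With requires_grad = False, a backward pass leaves the teacher's .grad fields empty, so only the student receives gradients. The following small check uses hypothetical nn.Linear stand-ins for both models. It also passes only the trainable parameters to the optimizer, which is an assumption about the surrounding training script:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

backbone_model = nn.Linear(4, 2)  # hypothetical student
ema_model = nn.Linear(4, 2)       # hypothetical teacher

for param in backbone_model.parameters():
    param.requires_grad = True
for param in ema_model.parameters():
    param.requires_grad = False

# Only parameters that require gradients are handed to the optimizer.
optimizer = torch.optim.Adam(
    [p for p in backbone_model.parameters() if p.requires_grad], lr=1e-3
)

x = torch.randn(8, 4)
# The teacher's output enters the loss but contributes no gradient path.
loss = backbone_model(x).pow(2).mean() + ema_model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(all(p.grad is not None for p in backbone_model.parameters()))  # True
print(all(p.grad is None for p in ema_model.parameters()))           # True
```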

3. Feed the input to the teacher network without computing gradients

with torch.no_grad():  # no autograd graph is built for the teacher's forward pass
    real_out = ema_model(real_hazy_img)
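Inside torch.no_grad(), PyTorch records no autograd graph. The teacher's output is therefore a plain tensor that can act as a fixed target. In the small check below, a hypothetical convolution stands in for ema_model and a random tensor stands in for real_hazy_img. The teacher's parameters deliberately still have requires_grad=True, to show that no_grad alone prevents graph construction:

```python
import torch
import torch.nn as nn

ema_model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # hypothetical teacher; params still require grad
real_hazy_img = torch.rand(1, 3, 16, 16)              # hypothetical input batch

with torch.no_grad():
    real_out = ema_model(real_hazy_img)

print(real_out.requires_grad, real_out.grad_fn)  # False None

# Without no_grad, the same call builds a graph through the teacher's weights.
print(ema_model(real_hazy_img).requires_grad)  # True
```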

4. Pass the student network's parameters to the teacher network through an exponential moving average (EMA)

if opt.ema:
    state_dict_backbone = backbone_model.state_dict()
    state_dict_ema_model = ema_model.state_dict()
    # state_dict() tensors share storage with the model, so copy_ updates the teacher in place
    for (k_backbone, v_backbone), (k_ema, v_ema) in zip(state_dict_backbone.items(), state_dict_ema_model.items()):
        assert k_backbone == k_ema
        assert v_backbone.shape == v_ema.shape
        if 'num_batches_tracked' in k_ema:
            v_ema.copy_(v_backbone)  # integer BatchNorm counter: copy, do not average
        else:
            # teacher = momentum * teacher + (1 - momentum) * student
            v_ema.copy_(v_ema * opt.momentum + (1. - opt.momentum) * v_backbone)  # momentum=0.999
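The loop applies the update theta_teacher <- m * theta_teacher + (1 - m) * theta_student, where m = opt.momentum. It covers every state_dict entry, including weights and BatchNorm running statistics. The integer counter num_batches_tracked is copied directly instead. The sketch below wraps the same logic in a hypothetical ema_update function and checks one step numerically. The nn.Sequential toy models and the momentum of 0.9 are assumptions chosen so the result is easy to verify:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(student: nn.Module, teacher: nn.Module, momentum: float) -> None:
    """Hypothetical helper: teacher <- momentum * teacher + (1 - momentum) * student."""
    student_state = student.state_dict()
    teacher_state = teacher.state_dict()
    for (k_s, v_s), (k_t, v_t) in zip(student_state.items(), teacher_state.items()):
        assert k_s == k_t and v_s.shape == v_t.shape
        if "num_batches_tracked" in k_t:
            v_t.copy_(v_s)  # integer counter: copy, do not average
        else:
            v_t.copy_(v_t * momentum + (1.0 - momentum) * v_s)


student = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))
teacher = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))
teacher.load_state_dict(student.state_dict())

# Make one weight differ so the averaging is visible.
with torch.no_grad():
    student[0].weight.fill_(1.0)
    teacher[0].weight.fill_(0.0)
student[1].num_batches_tracked.fill_(5)

ema_update(student, teacher, momentum=0.9)
print(teacher[0].weight)               # every entry is about 0.1 (0.9 * 0 + 0.1 * 1)
print(teacher[1].num_batches_tracked)  # tensor(5)
```

In a mean-teacher loop, this update is typically run once per iteration, right after the student's optimizer.step(). With m = 0.999, the teacher changes slowly and smooths out the student's step-to-step noise.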

5. At test time, evaluate both models, the student and the teacher

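ssim_eval_1, psnr_eval_1, ssim_eval_2, psnr_eval_2 = test(backbone_model, ema_model, test_loader)

# inside test(): both networks run on the same input batch
with torch.no_grad():
    pred = backbone_model(input)  # student (backbone) prediction
    ema_pred = ema_model(input)   # teacher (EMA) prediction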