解决 clamp 函数会阻断梯度传播

发布时间 2024-01-10 17:33:54作者: 倒地

开端

若在网络的 forward 过程中使用 clamp 函数对数据进行截断,可能会阻断梯度传播。即,梯度变成零。

不妨先做一个实验。定义一个全连接网络 fc,通过输入 input_t 获得结果 pred,其值为 \(0.02\)

from torch.nn import functional as F
import torch.nn as nn
import torch

fc = nn.Linear(in_features=1, out_features=1, bias=True)
fc.weight.data = torch.tensor([[0.01]])
fc.bias.data = torch.tensor([[0.01]])

input_t = torch.tensor([[1.0]], dtype=torch.float32)
pred = fc(input_t)
print(pred)  # pred = 0.02

pred 进行反向传播,可以看到网络的权重都有梯度:

pred.backward()
print(fc.weight.grad)  # grad = 1.0
print(fc.bias.grad)  # grad = 1.0

如果使用 torch.clamp()pred 结果截断在 \((0.1,0.9)\) 范围内,会发现梯度消失了:

fc.zero_grad()
pred = fc(input_t)
pred = torch.clamp(pred, min=0.1, max=0.9)  # pred = 0.1
pred.backward()
print(fc.weight.grad)  # grad = 0.0
print(fc.bias.grad)  # grad = 0.0

解决方法

我们需要跳过 torch.clamp() 的梯度运算。

若是用 TensorFlow,使用 tf.stop_gradient 可以很方便地让 clamp 不参与梯度计算。在 PyTorch 就有点绕了。

最后我总结出了一个相对优雅的解决方案:

def nclamp(input, min, max):
    return input.clamp(min=min, max=max).detach() + input - input.detach()

参考来源