torch梯度计算相关

发布时间 2023-03-27 20:53:10作者: LinXiaoshu

torch梯度计算图

计算图中,默认只有叶子结点的梯度能够保留,如果要访问非叶子结点p的梯度数据,需要执行p.retain_grad().

torch计算图中requires_graddetach的区别

requires_gradtorch.Tensor中的属性,表示该张量是否需要计算梯度.而detach()则是方法,将此张量从当前计算图中脱离.这两者的区别在于:调用detach()后,默认将会把requires_grad设置为False.但脱离计算图只是阻断了此张量的梯度向后传播,脱离计算图仍然可以被计算梯度.

比如,在torch.utils.checkpoint中,有

def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)

这里将inp从计算图中脱离,却仍然指定其需要梯度,就是要求此梯度被计算,但不传播.