with torch.no_grad():注意事项

发布时间 2023-07-21 10:26:56作者: ML_WG

1。 当执行原地操作时,例如 tensor.add_(x),将会在一个张量上直接修改数据,而不会创建新的张量。由于修改了张量的数据,因此计算图会失效,即计算图中的操作和输入输出关系都会发生变化。这会导致反向传播无法正确计算梯度。因此,PyTorch 禁止在需要梯度计算的张量上执行原地操作。为了解决这个问题,可以使用 with torch.no_grad() 上下文管理器来禁用梯度计算,并对结果进行复制,从而得到一个新的张量,从而避免原地操作对计算图的破坏。

点击查看代码
def sgd(params, lr, batch_size): #@save
"""小批量随机梯度下降"""
with torch.no_grad():
	for param in params:
		param -= lr * param.grad / batch_size
		param.grad.zero_()
在这段代码中,我们使用 with torch.no_grad() 上下文管理器来禁用梯度计算,并使用小批量随机梯度下降算法来更新模型参数。在每次迭代中,我们遍历模型的每个参数,并根据其梯度信息更新参数的值。在更新参数时,我们使用 param -= lr * param.grad / batch_size 的原地操作来更新参数的值,并使用 param.grad.zero_() 将梯度信息清零,以便进行下一个小批量的更新。

2.具体来说,with torch.no_grad() 会创建一个上下文环境,在该环境中计算的所有张量都不会被跟踪其操作历史,并且梯度信息也不会被记录。这可用于评估模型,例如在验证集上计算模型的性能指标,或者在训练过程中定期输出模型的训练进度。由于我们不需要计算梯度信息,因此使用 with torch.no_grad() 可以显著降低计算开销并减少内存使用。

点击查看代码
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # X和y的小批量损失
        # 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
        # 并以此计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
在这段代码中,我们使用 with torch.no_grad() 上下文管理器计算了训练集的损失,并输出了当前迭代的训练进度。由于这些计算不需要梯度信息,因此可以使用 with torch.no_grad() 来禁用梯度计算,从而提高计算效率并减少内存使用。