Quick Start: Mixed-Precision Training in PyTorch

Published: 2023-12-29 01:07:34 | Author: 倒地

Code example using mixed precision

Take a very basic training loop as the starting point:

for epoch in range(epochs):
    model.train()
    for i, (images, labels) in enumerate(loader_train):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

···

Add and remove lines as follows (lines marked + are added, lines marked - are removed) to enable mixed precision on CUDA.

+from torch.cuda.amp import GradScaler, autocast
+scaler = GradScaler()
 for epoch in range(epochs):
     model.train()
 
     for i, (images, labels) in enumerate(loader_train):
         images = images.to(device)
         labels = labels.to(device)
 
         optimizer.zero_grad()
 
-        output = model(images)
-        loss = criterion(output, labels)
-        loss.backward()
-        optimizer.step()
+        with autocast():
+            output = model(images)
+            loss = criterion(output, labels)
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()

···
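
The change above scales the loss before calling backward() and lets the scaler drive the optimizer step. The reason is that very small gradient values can round to zero in half precision. Here is a minimal sketch of that effect using only the Python standard library. It imitates float16 rounding with the struct module and does not use PyTorch itself. The helper to_fp16, the scale factor and the gradient value are assumptions chosen for illustration; GradScaler picks and adjusts its own scale during training.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Hypothetical helper for this sketch: packs the number as a 16-bit
    float ('e' format) and unpacks it again.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Assumed values, for illustration only.
SCALE = 2.0 ** 16   # a power-of-two scale factor
tiny_grad = 1e-8    # a gradient too small for float16

# Without scaling: the value is below the smallest float16 subnormal
# (about 6e-8) and rounds to zero, so the update would be lost.
unscaled = to_fp16(tiny_grad)

# With scaling: multiply first, store in float16, then divide the
# scale back out in higher precision (what the scaler does before the step).
scaled = to_fp16(tiny_grad * SCALE)
recovered = scaled / SCALE

print("float16 without scaling:", unscaled)   # 0.0
print("float16 with scaling, then unscaled:", recovered)
```

In the diff above, scaler.scale(loss) plays the role of the multiplication, scaler.step(optimizer) unscales the gradients before the optimizer update, and scaler.update() adjusts the scale factor for the next iteration.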