模型权重保存、加载、冻结(pytorch)

发布时间 2023-07-31 21:45:53作者: 湾仔码农

1. 保存整个网络

torch.save(net, PATH) 
model = torch.load(PATH)

2. 保存网络中的参数(速度快,占空间小)

torch.save(net.state_dict(),PATH)
model_dict = model.load_state_dict(torch.load(PATH))

model.state_dict函数会以有序字典OrderedDict形式返回模型训练过程中学习的权重weight和偏置bias参数,只有带有可学习参数的层(卷积层、全连接层等),以及注册的缓存(batchnorm的运行平均值)在state_dict 中才有记录。以下面的LeNet为例:

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)  # output(16, 14, 14)
        x = F.relu(self.conv2(x))  # output(32, 10, 10)
        x = self.pool2(x)  # output(32, 5, 5)
        x = x.view(-1, 32 * 5 * 5)  # output(32*5*5)
        x = F.relu(self.fc1(x))  # output(120)
        x = F.relu(self.fc2(x))  # output(84)
        x = self.fc3(x)  # output(10)
        return x


net = LeNet()
# 打印可学习层的参数
print(net.state_dict().keys())

上面的模型中,只有卷积层和全连接层具有可学习参数,所以net.state_dict()只会保存这两层的参数,而激活函数层的参数则不会保存。层的名字是上面实例化时确定的,如果是利用nn.Sequential定义多个层时,用层的位置索引表示每个层,如下所示:

示例:用nn.Sequential搭建模型时的state_dict

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(3, 16, 5),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.MaxPool2d(2, 2))

        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.feature(x)  # input(3, 32, 32)
        x = x.view(-1, 32 * 5 * 5)  # output(32*5*5)
        x = F.relu(self.fc1(x))  # output(120)
        x = F.relu(self.fc2(x))  # output(84)
        x = self.fc3(x)  # output(10)
        return x


net = LeNet()
# 打印可学习层的参数
print(net.state_dict().keys()) 

 

★模型加载

  • 当我们对网络模型结构进行优化改进时,如果改进的部分不包含可学习的层,那么可以直接加载预训练权重。如:如果我们对上述lenet模型进行改进,将激活函数层改为nn.Hardswish(),因为不包含可学习的参数,所以改进的模型的state_dict()没有改变,仍然可以直接加载lenet模型的权重文件。
  • 当我们改进的部分改变了可学习的参数时,如果直接加载预训练权重就会发生不匹配的错误,比如:卷积的维度改变后会报错 size mismatch for conv.weight...(2)新增一些层后会出现 Unexpected key(s) in state_dict等

解决方案:遍历预训练文件的每一层参数,将能够匹配成功的参数提取出来,再进行加载。

import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet_new(nn.Module):
    def __init__(self):
        super(LeNet_new, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)


    def forward(self, x):
        x = F.hardswish(self.conv1(x))  # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)  # output(16, 14, 14)
        x = F.hardswish(self.conv2(x))  # output(32, 10, 10)
        x = self.pool2(x)  # output(32, 5, 5)
        return x


def intersect_dicts(da, db):
    return {k: v for k, v in da.items() if k in db and v.shape == db[k].shape}


net = LeNet_new()
state_dict = torch.load("Lenet.pth")  # 加载预训练权重
print(state_dict.keys())
state_dict = intersect_dicts(state_dict, net.state_dict())  # 筛选权重参数
print(state_dict.keys())
net.load_state_dict(state_dict, strict=False)  # 模型加载预训练权重中可用的权重

3. 保存网络参数,同时保存优化器参数、损失值等(方便追加训练)

如果还想保存某一次训练采用的优化器、epochs等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来

# 保存
save_file = {"model": model.state_dict(),
                  "optimizer": optimizer.state_dict(),
                  "lr_scheduler": lr_scheduler.state_dict(),
                  "epoch": epoch,
                  "args": args}
torch.save(save_file, "save_weights/model_{}.pth".format(epoch))

# 加载
checkpoint = torch.load(path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1

4. 冻结训练

在加载预训练权重后,可能需要固定一部分模型的参数,只更新另一部分参数。有两种思路实现这个目标,一个是设置不要更新参数的网络层为requires_grad = False,另一个就是在定义优化器时只传入要更新的参数。最优写法时:将不更新的参数的requires_grad设置为False,同时不将该参数传入optimizer

示例:LeNet网络+MNIST手写识别+预训练模型加载+冻结训练

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_data = datasets.MNIST(root='../dataset', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_data = datasets.MNIST(root='../dataset', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5),
            nn.MaxPool2d(2, 2))
        self.fc1 = nn.Linear(32 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.feature(x)
        x = x.view(-1, 32 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train(epoch):
    loss_runtime = 0.0
    for batch, data in enumerate(tqdm(train_loader, 0)):
        x, y = data
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss_runtime += loss.item()
        loss_runtime /= x.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("after %s epochs, loss is %.8f" % (epoch + 1, loss_runtime))
    save_file = {"model": model.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "epoch": epoch}
    torch.save(save_file, "model_{}.pth".format(epoch))


def test():
    correct, total = 0, 0
    with torch.no_grad():
        for (x, y) in test_loader:
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            _, prediction = torch.max(y_pred.data, dim=1)
            correct += (prediction == y).sum().item()
            total += y.size(0)
            acc = correct / total
    print("accuracy on test set is :%5f" % acc)


if __name__ == '__main__':
    start_epoch = 0
    freeze_epoch = 0
    resume = "model_5.pth"
    freeze = True

    model = LeNet()
    device = ("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # 加载预训练权重
    if resume:
        checkpoint = torch.load(resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

        # 冻结训练
        if freeze:
            freeze_epoch = 5
            print("冻结前置特征提取网络权重,训练后面的全连接层")
            for param in model.feature.parameters():
                param.requires_grad = False  # 将不更新的参数的requires_grad设置为False,节省了计算这部分参数梯度的时间
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.5)
            for epoch in range(start_epoch, start_epoch + freeze_epoch):
                train(epoch)
                test()
            print("解冻前置特征提取网络权重,接着训练整个网络权重")
            for param in model.feature.parameters():
                param.requires_grad = True
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.5)

    for epoch in range(start_epoch + freeze_epoch, 100):
        train(epoch)
        test()

  

 

 

 

参考:

1. 加载预训练权重