5.5 读写文件

发布时间 2023-07-26 10:09:43作者: Ann-

1. 加载和保存张量

我们可以使用torch提供的函数torch.save以及torch.load对张量进行保存和加载。

 torch.save的第一个参数是要保存的张量,第二个参数是要保存成为的文件名。注意,上面的代码中我们先创建张量x = [1,2,3,4],将其保存为文件x-file,此时x-file已存在。而我们再将新的x = [3,4,5,6]保存为x-file,便将x-file的内容覆盖了,再读取x-file并打印它,可以看到它的值是[3,4,5,6]了。

 torch.save还可以保存张量构成的列表:

 torch.save同样可以保存“以字符串为键,以张量为值”的字典。

 

2. 加载和保存模型参数

pytorch提供了保存整个模型参数的函数,而并不能保存模型结构(如,3层的多层感知机还是4层?)。网络模型的state_dict()函数返回模型各层的名字和参数构成的字典,我们可以用torch.save来保存网络的state_dict()。为了从保存的参数文件中恢复模型,我们需要创建原网络的一个备份。看一个例子,首先我们定义一个多层感知机:

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))

我们将net的参数保存到一个名字为mlp.params的文件中,使用torch.save(net.state_dict(),'mlp.params'):

torch.save(net.state_dict(),'mlp.params')

加载这个模型需要创建与原net相同结构的一个网络,并将参数加载到这个网络中,我们创建一个Clone,并使用Clone.load_state_dict()加载参数:

Clone = MLP()
Clone.load_state_dict(torch.load('mlp.params'))

此时,net的结构和参数与Clone的结构和参数应该是完全相同的了。那么,它们对任意输入X的输出也应该是相同的,我们验证一下:

Y = net(X)
Y2 = Clone(X)
print(Y==Y2)

输出为:

 

总结:

保存模型:用torch.save(net.state_dict,'filename')来保存参数

加载模型:首先创建与要加载的模型相同结构的网络Clone,然后用Clone.load_state_dict()加载“用torch.load('filename')加载到的”参数。