torch保存模型

发布时间 2023-11-24 11:06:05作者: 黑逍逍

保存模型有两种方式,方式不同,在调用模型的时候也不同

我更建议用torch.jit。。。这样不需要在写模型的参数

torch.save

保存模型:
import torch
import torch.nn as nn

# 假设 model 是你的 PyTorch 模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

model = SimpleModel()

# 保存模型到文件
torch.save(model.state_dict(), 'model.pth')
解释:
model.state_dict() 返回模型的参数字典,torch.save 将这个字典保存到名为 model.pth 的文件中。

  

调用模型:
import torch
import torch.nn as nn

# 假设 model 是你的 PyTorch 模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

model = SimpleModel()

# 加载模型参数
model.load_state_dict(torch.load('model.pth'))

# 将模型设为评估模式(如果是测试模型)
model.eval()
outputs = model(data.float())

  

torch.jit.script

TorchScript — PyTorch 2.1 documentation

torch.jit 模块是 PyTorch 中的即时(just-in-time)编译模块,提供了一种将 PyTorch 模型转换为脚本(script)或 Torch 脚本(TorchScript)的方法。Torch 脚本是一种中间表示形式,可以在不依赖 Python 解释器的情况下在 PyTorch 中运行。

可以将整个模型保存为一个 Torch 脚本文件,而不仅仅是模型的参数。这样做可以更轻松地保存和加载整个模型。

保存模型:
import torch
import torch.jit

# model 是我的 PyTorch 模型
class SimpleModel(torch.nn.Module):
    def forward(self, x):
        return x + 1

model = SimpleModel()

# 将模型转换为 Torch 脚本
scripted_model = torch.jit.script(model)
# 保存 Torch 脚本到文件
scripted_model.save("scripted_model.pt")

 

# 调用模型 
loaded_model = torch.jit.load("scripted_model.pt")

# 将模型设为评估模式(如果是测试模型)
model.eval()
outputs = model(data.float())