深度学习(pytorch载入onnx测试)

发布时间: 2023-10-19 23:01:56  作者: Dsp Tian

测试使用的是之前文章中训练得到的AlexNet模型。

首先将pth文件转为onnx文件:

import torch
import torch.nn as nn

# Custom AlexNet model
class AlexNet(nn.Module):
    """AlexNet-style CNN mapping a 1-channel 227x227 image to 2 raw scores.

    Five convolutions (max-pooling after the 1st, 2nd and 5th) feed three
    fully connected layers. No softmax is applied; ``forward`` returns the
    output of ``fc3`` directly.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)

        # Classifier head; 256*6*6 is the flattened size of the last pooled
        # feature map when the input is 227x227.
        self.fc1 = nn.Linear(in_features=256 * 6 * 6, out_features=4096)
        self.fc2 = nn.Linear(in_features=4096, out_features=4096)
        self.fc3 = nn.Linear(in_features=4096, out_features=2)

    def forward(self, x):
        """Return a ``(batch, 2)`` tensor for a ``(batch, 1, 227, 227)`` input."""
        # Stage 1: 1*227*227 -> 96*55*55 -> (pool) 96*27*27
        feat = torch.max_pool2d(torch.relu(self.conv1(x)), 3, stride=2)
        # Stage 2: 96*27*27 -> 256*27*27 -> (pool) 256*13*13
        feat = torch.max_pool2d(torch.relu(self.conv2(feat)), 3, stride=2)
        # Stage 3: three 3x3 convolutions at 13*13 (256 -> 384 -> 384 -> 256)
        feat = torch.relu(self.conv3(feat))
        feat = torch.relu(self.conv4(feat))
        feat = torch.relu(self.conv5(feat))
        # Final pool: 256*13*13 -> 256*6*6
        feat = torch.max_pool2d(feat, 3, stride=2)

        # Flatten everything but the batch axis: 256*6*6 -> 9216
        flat = feat.view(feat.size(0), -1)

        hidden = torch.relu(self.fc1(flat))      # 9216 -> 4096
        hidden = torch.relu(self.fc2(hidden))    # 4096 -> 4096
        return self.fc3(hidden)                  # 4096 -> 2

# Load the trained parameters into a freshly constructed model.
model = AlexNet()
# map_location='cpu' lets a checkpoint whose tensors were saved on a GPU be
# loaded on a CPU-only machine.
model.load_state_dict(torch.load('alexnet.pth', map_location='cpu'))
model.eval()  # switch to evaluation mode before exporting

# Example input handed to the exporter: one single-channel 227x227 image,
# the input size that fc1 (256*6*6 inputs) is dimensioned for.
dummy_input = torch.randn(1, 1, 227, 227)

# Export the model to ONNX format.
onnx_filename = 'alexnet.onnx'

torch.onnx.export(
    model,
    dummy_input,
    onnx_filename,
    verbose=False,
    input_names=["image"],
    output_names=["class"],
)

print(f"Model successfully exported as {onnx_filename}.")

然后分别用pth文件和onnx文件对同一张图片进行推理测试,两者的输出结果应基本一致:

import onnxruntime
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn as nn

# Custom AlexNet model
class AlexNet(nn.Module):
    """AlexNet-style CNN mapping a 1-channel 227x227 image to 2 raw scores.

    Five convolutions (max-pooling after the 1st, 2nd and 5th) feed three
    fully connected layers. No softmax is applied; ``forward`` returns the
    output of ``fc3`` directly.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)

        # Classifier head; 256*6*6 is the flattened size of the last pooled
        # feature map when the input is 227x227.
        self.fc1 = nn.Linear(in_features=256 * 6 * 6, out_features=4096)
        self.fc2 = nn.Linear(in_features=4096, out_features=4096)
        self.fc3 = nn.Linear(in_features=4096, out_features=2)

    def forward(self, x):
        """Return a ``(batch, 2)`` tensor for a ``(batch, 1, 227, 227)`` input."""
        # Stage 1: 1*227*227 -> 96*55*55 -> (pool) 96*27*27
        feat = torch.max_pool2d(torch.relu(self.conv1(x)), 3, stride=2)
        # Stage 2: 96*27*27 -> 256*27*27 -> (pool) 256*13*13
        feat = torch.max_pool2d(torch.relu(self.conv2(feat)), 3, stride=2)
        # Stage 3: three 3x3 convolutions at 13*13 (256 -> 384 -> 384 -> 256)
        feat = torch.relu(self.conv3(feat))
        feat = torch.relu(self.conv4(feat))
        feat = torch.relu(self.conv5(feat))
        # Final pool: 256*13*13 -> 256*6*6
        feat = torch.max_pool2d(feat, 3, stride=2)

        # Flatten everything but the batch axis: 256*6*6 -> 9216
        flat = feat.view(feat.size(0), -1)

        hidden = torch.relu(self.fc1(flat))      # 9216 -> 4096
        hidden = torch.relu(self.fc2(hidden))    # 4096 -> 4096
        return self.fc3(hidden)                  # 4096 -> 2

def load_img(path="./data/cat_dog/1.jpeg", size=(227, 227)):
    """Load an image file as a grayscale tensor with a leading batch axis.

    Args:
        path: Image file to read. Defaults to the sample image used by this
            script, so existing ``load_img()`` calls behave as before.
        size: ``(width, height)`` passed to PIL's ``resize``. Defaults to
            227x227, the input size the AlexNet model in this file expects.

    Returns:
        The output of ``transforms.ToTensor()`` for the resized mode-``'L'``
        image with a batch axis prepended by ``unsqueeze(0)``.
    """
    # The context manager closes the underlying file deterministically;
    # convert()/resize() return new in-memory images that outlive it.
    with Image.open(path) as src:
        gray = src.convert('L').resize(size)
    tensor = transforms.ToTensor()(gray)
    return tensor.unsqueeze(0)

def test_pth():
    """Run the PyTorch checkpoint on the sample image and print its output."""
    img = load_img()
    model = AlexNet()
    # map_location="cpu" lets a checkpoint whose tensors were saved on a GPU
    # be loaded on a CPU-only machine.
    model.load_state_dict(torch.load("alexnet.pth", map_location="cpu"))
    model.eval()
    # Inference only: disabling autograd avoids building a graph, and the
    # printed tensor then carries no grad_fn.
    with torch.no_grad():
        outs = model(img)
    print(outs)

def test_onnx():
    """Run the exported ONNX model on the sample image and print its outputs."""
    img = load_img()
    # Name the execution provider explicitly: onnxruntime builds that include
    # providers beyond the CPU one require the `providers` argument
    # (ORT >= 1.9), and the CPU provider is available in every build.
    session = onnxruntime.InferenceSession(
        'alexnet.onnx', providers=['CPUExecutionProvider']
    )
    input_name = session.get_inputs()[0].name
    # Passing None as the output list requests every output of the model.
    outs = session.run(None, {input_name: img.numpy()})
    print(outs)

# Run the ONNX Runtime check first, then the PyTorch one, so the two printed
# outputs can be compared side by side.
for run_check in (test_onnx, test_pth):
    run_check()