UNet PyTorch 模型转 ONNX 模型完整代码

发布时间: 2023-11-14 16:52:47  作者: 猪大大BiuBiuBiu
 1 import os
 2 import torch
 3 import numpy as np
 4 from Unet import UNET
 5 os.environ["CUDA_VISIBLE_DEVICE"] = ""
 6 
 7 def main():
 8     demo = Demo(model_path="/xxx.pth.tar", output="pathto/xxx.onnx")
 9     demo.inference()
10     check_onnx(onnx_pth="path to xxx.onnx")
11 
12 
13 
def check_onnx(onnx_pth):
    """Validate an exported ONNX file and print its graph in readable form.

    Args:
        onnx_pth: path of the ``.onnx`` file to inspect.

    Raises whatever ``onnx.load`` / ``onnx.checker.check_model`` raise for an
    unreadable or malformed model.
    """
    # Imported lazily so the export path does not require the onnx package.
    import onnx

    loaded = onnx.load(onnx_pth)
    # Structural validation of the model's IR.
    onnx.checker.check_model(loaded)
    graph_text = onnx.helper.printable_graph(loaded.graph)
    print(graph_text)
23 
class WrappedModel(torch.nn.Module):
    """Wrap a network so its raw outputs are passed through an element-wise sigmoid.

    Exporting this wrapper instead of the bare network bakes the sigmoid into the
    ONNX graph, so the exported model emits values in (0, 1) directly.

    Args:
        model: the wrapped ``torch.nn.Module``; stored as ``self.model``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and apply ``torch.sigmoid`` to its output."""
        outs = self.model(x)
        new_outs = torch.sigmoid(outs)
        return new_outs
34 
35 
class Demo:
    """Load a trained UNET checkpoint and export it (with a sigmoid head) to ONNX.

    Args:
        model_path: checkpoint file produced by ``torch.save``; the weights are
            read from its ``"state_dict"`` entry.
        output: destination path of the exported ``.onnx`` file.
    """

    def __init__(self, model_path, output):
        self.model_path = model_path
        self.output_path = output
        # Export always runs on CPU. Set here too so resume()/init_model() work
        # even when called without init_torch_tensor() first.
        self.device = 'cpu'

    def init_torch_tensor(self):
        """Select CPU as the device and float32 as the default floating dtype."""
        self.device = 'cpu'
        # torch.set_default_tensor_type() is deprecated since PyTorch 2.1;
        # set_default_dtype(float32) is its documented replacement for CPU tensors.
        torch.set_default_dtype(torch.float32)

    def init_model(self, in_channels, out_channels):
        """Build a UNET with the given channel counts and move it to ``self.device``."""
        model = UNET(in_channels=in_channels, out_channels=out_channels).to(self.device)
        return model

    def resume(self, model, path):
        """Load checkpoint weights into ``model`` and wrap it in ``WrappedModel``.

        Args:
            model: network whose parameters are overwritten from the checkpoint.
            path: checkpoint file; its ``"state_dict"`` entry holds the weights.

        Returns:
            ``WrappedModel(model)``.

        Raises:
            FileNotFoundError: if ``path`` does not exist. Returning None here used
                to make inference() fail later with an unrelated AttributeError.
        """
        if not os.path.exists(path):
            raise FileNotFoundError("Checkpoint not found:" + path)
        states = torch.load(path, map_location=self.device)
        # Per the original author the checkpoint holds "state_dict" and "optimizer";
        # only the weights are needed for export.
        result = model.load_state_dict(states["state_dict"], strict=False)
        # strict=False tolerates key mismatches, but exporting a partially
        # initialised network silently would be worse, so surface any mismatch.
        if result.missing_keys:
            print("Warning: keys missing from checkpoint: " + ", ".join(result.missing_keys))
        if result.unexpected_keys:
            print("Warning: unexpected keys in checkpoint: " + ", ".join(result.unexpected_keys))

        model_sig = WrappedModel(model)
        print("Resume from " + path)
        return model_sig

    @staticmethod
    def _dummy_input(in_channels, height, width):
        """Return a random float32 tensor of shape (1, in_channels, height, width) in [0, 1)."""
        img = np.random.randint(0, 255, size=(height, width, in_channels), dtype=np.uint8)
        img = img.astype(np.float32) / 255  # alternative normalisation: (img / 255. - 0.5) / 0.5
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        return torch.from_numpy(img).unsqueeze(0).float()

    def inference(self, *, in_channels=3, out_channels=2, height=512, width=512,
                  opset_version=11):
        """Load the checkpoint and export the sigmoid-wrapped model to ``self.output_path``.

        All keyword arguments default to the values this method previously
        hard-coded, so ``inference()`` behaves as before.

        Args:
            in_channels / out_channels: channel counts passed to UNET.
            height / width: spatial size of the dummy input used for tracing.
            opset_version: ONNX opset passed to ``torch.onnx.export``.

        Raises:
            FileNotFoundError: if ``self.model_path`` does not exist.
        """
        self.init_torch_tensor()
        model = self.init_model(in_channels=in_channels, out_channels=out_channels)
        model_sig = self.resume(model, self.model_path)
        model_sig.eval()

        img = self._dummy_input(in_channels, height, width)

        # Dynamic axes: each key must be a name listed in input_names/output_names.
        # Each value is either a dict {dimension index: symbolic name} or a list of
        # dimension indices. Here batch, height and width are all left dynamic.
        dynamic_axes = {'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                        'output': {0: 'batch_size', 2: 'height', 3: 'width'}}

        # torch.onnx.export does not create missing parent directories.
        out_dir = os.path.dirname(self.output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        with torch.no_grad():
            img = img.to(self.device)
            torch.onnx.export(model_sig, img, self.output_path, input_names=['input'],
                              output_names=['output'], dynamic_axes=dynamic_axes,
                              keep_initializers_as_inputs=False, export_params=True,
                              verbose=True, opset_version=opset_version)
94 
# Script entry point: export the model, then validate the exported file.
if __name__ == "__main__":
    main()