ONNX 子图修改与动态/静态转换

发布时间 2023-10-18 14:15:09作者: Wangtn

子图修改

import onnx
import onnx_graphsurgeon as gs
import onnxruntime as ort
import numpy as np

def cut_subgraph(origin_graph_path, input_node_name_list, output_node_name_list, sub_graph_path):
    """Extract the sub-graph bounded by the given tensors and save it.

    The model at ``origin_graph_path`` is loaded into onnx-graphsurgeon.
    Its graph inputs are replaced by the tensors named in
    ``input_node_name_list``, and its outputs by those named in
    ``output_node_name_list``. ``graph.cleanup()`` then prunes whatever is
    no longer reachable, and the result is written to ``sub_graph_path``.

    Despite the parameter names, the entries are *tensor* names
    (e.g. ``'Softmax:0'``). Each name is looked up by subscript in
    ``graph.tensors()``, so an unknown name fails at that lookup.

    NOTE(review): if the source model carries no value_info for the chosen
    intermediate tensors, the new inputs/outputs may end up without
    dtype/shape. Running ONNX shape inference first may be required;
    confirm for the target model.
    """
    graph = gs.import_onnx(onnx.load(origin_graph_path))
    by_name = graph.tensors()

    # Re-bound the graph: the chosen tensors become its only inputs/outputs.
    graph.inputs = [by_name[name] for name in input_node_name_list]
    graph.outputs = [by_name[name] for name in output_node_name_list]

    # Drop nodes/tensors that no longer feed the new outputs.
    graph.cleanup()

    exported = gs.export_onnx(graph)
    onnx.save(exported, sub_graph_path)

 

动态 batch 转静态 batch（静态转动态同理）。注意：该方法只修改 batch 维，即各张量的第 0 维。

import onnx
import onnxruntime as ort
import numpy as np
import struct

def rebatch(infile_path, outfile_path, batch_size):
    """Rewrite the batch (first) dimension of an ONNX model and save it.

    The first dimension of every graph input, output and ``value_info``
    entry is set to ``batch_size``:

    * an ``int`` (e.g. ``32``) writes ``dim_value``, which gives a static
      batch;
    * a ``str`` (e.g. ``"N"``) writes ``dim_param``, which gives a symbolic,
      dynamic batch.

    Graph inputs that are really initializers (weights) are left alone,
    because older IR versions list initializers in ``graph.input``.
    Tensors without any shape dims are skipped.

    The first element of each ``Reshape`` target shape is set to ``-1``, so
    the reshape follows the actual batch. This applies only when the target
    shape is stored as a graph initializer. Shapes produced by ``Constant``
    nodes or computed at runtime are not touched.

    Args:
        infile_path: Path of the ONNX model to read.
        outfile_path: Path where the modified model is written.
        batch_size: ``int`` for a fixed batch, ``str`` for a symbolic one.
    """
    model = onnx.load(infile_path)
    graph = model.graph

    initializers = {init.name: init for init in graph.initializer}

    for tensor in list(graph.input) + list(graph.value_info) + list(graph.output):
        if tensor.name in initializers:
            # Weight listed as a graph input (IR < 4); its dim 0 is not a batch.
            continue
        dims = tensor.type.tensor_type.shape.dim
        if not dims:
            # Scalar or shape-less tensor: there is no batch dim to rewrite.
            continue
        if isinstance(batch_size, str):
            dims[0].dim_param = batch_size
        else:
            # dim_value / dim_param form a protobuf oneof; setting one clears the other.
            dims[0].dim_value = int(batch_size)

    for node in graph.node:
        if node.op_type != 'Reshape' or len(node.input) < 2:
            continue
        init = initializers.get(node.input[1])
        if init is None:
            continue
        if len(init.int64_data) > 0:
            init.int64_data[0] = -1
        elif len(init.raw_data) > 0:
            shape = bytearray(init.raw_data)
            # ONNX stores raw_data little-endian regardless of host byte order.
            struct.pack_into('<q', shape, 0, -1)
            init.raw_data = bytes(shape)

    onnx.save(model, outfile_path)

 

调用

# Example: cut a sub-graph out of the model, then rewrite its batch dimension to 32.
origin_graph_path = './del_maxmin/new_wdl.onnx'
input_node_name_list = [
    'Sum:0',
    'Embedding_layer/embedding_lookup_3:0',
    'Embedding_layer/embedding_lookup_1:0',
    'Embedding_layer/embedding_lookup:0',
]
output_node_name_list = ['Softmax:0']

# The cut sub-graph is written here and then read back by rebatch().
sub_graph_path = './simple_wdl.onnx'
cut_subgraph(origin_graph_path, input_node_name_list, output_node_name_list, sub_graph_path)
rebatch(sub_graph_path, 'tt.onnx', 32)

 

裁剪前的模型图。红圈标出的是希望作为新子图输入的张量位置。

它们的 name 与代码中 input_node_name_list / output_node_name_list 里的张量名一一对应：

ReduceSum的output

Gather 的 output

子图的输出节点

 

裁剪后的模型

 

参考

https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon#examples

https://github.com/onnx/onnx/issues/2182#issuecomment-881752539