ControlNet-trt Optimization Notes 4: ONNX Graph Modification and Rebuilding

Published 2023-10-09 14:23:57  Author: wildkid1024


This section summarizes the network-level optimizations, centered on operator plugins. It focuses on the following points:

  • Modifying the ONNX graph to add plugins for unsupported operators
  • Adding the pre/post-processing parts by exporting them as ONNX graphs

ONNX graph surgeon

The original graph contains a large number of GroupNorm (GN) operations. They are fine in FP32, but in FP16 the Pow and Exp operations inside GN can overflow the representable range, making the results inaccurate.
One fix is to rewrite the graph by hand and add a dedicated GN operator. The first step is to perform surgeon operations on the ONNX graph so that a GN operator can be inserted into it. However, since the ONNX opset lowers GN into an IN + MM pattern, the whole process takes two passes: first decompose IN into a mean-sub-pow subgraph, then fuse that operator pattern back into a single GN operator.

The example code below performs the decomposition in three steps: first, pull out the input tensors and attributes of the original node; second, build a list of new nodes that replaces the original operator's computation; third, disconnect the original node and wire the new nodes into its place, much like splicing a linked list.

import numpy as np
import onnx_graphsurgeon as gs


# Replace each InstanceNormalization node with an explicit
# ReduceMean-Sub-Pow-ReduceMean-Add-Sqrt-Div-Mul-Add subgraph.
def decompose_instancenorms(graph):
    nRemoveInstanceNorm = 0
    for node in graph.nodes:
        if node.op == "InstanceNormalization":
            name = node.name + "/"
            input_tensor = node.inputs[0]
            output_tensor = node.outputs[0]
            mean_out = gs.Variable(name=name + "mean_out")
            mean_node = gs.Node(op="ReduceMean", name=name + "mean_node", attrs={"axes": [-1]}, inputs=[input_tensor], outputs=[mean_out])
            sub_out = gs.Variable(name=name + "sub_out")
            sub_node = gs.Node(op="Sub", name=name + "sub_node", attrs={}, inputs=[input_tensor, mean_out], outputs=[sub_out])
            pow_out = gs.Variable(name=name + "pow_out")
            pow_const = gs.Constant(name=name + "pow_const", values=np.array([2.0], dtype=np.float32))
            pow_node = gs.Node(op="Pow", name=name + "pow_node", attrs={}, inputs=[sub_out, pow_const], outputs=[pow_out])
            mean2_out = gs.Variable(name=name + "mean2_out")
            mean2_node = gs.Node(op="ReduceMean", name=name + "mean2_node", attrs={"axes": [-1]}, inputs=[pow_out], outputs=[mean2_out])
            epsilon_out = gs.Variable(name=name + "epsilon_out")
            epsilon_const = gs.Constant(name=name + "epsilon_const", values=np.array([node.attrs["epsilon"]], dtype=np.float32))
            epsilon_node = gs.Node(op="Add", name=name + "epsilon_node", attrs={}, inputs=[mean2_out, epsilon_const], outputs=[epsilon_out])
            sqrt_out = gs.Variable(name=name + "sqrt_out")
            sqrt_node = gs.Node(op="Sqrt", name=name + "sqrt_node", attrs={}, inputs=[epsilon_out], outputs=[sqrt_out])
            div_out = gs.Variable(name=name + "div_out")
            div_node = gs.Node(op="Div", name=name + "div_node", attrs={}, inputs=[sub_out, sqrt_out], outputs=[div_out])
            constantScale = gs.Constant("InstanceNormScaleV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[1].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
            constantBias = gs.Constant("InstanceBiasV-" + str(nRemoveInstanceNorm), np.ascontiguousarray(node.inputs[2].inputs[0].attrs["value"].values.reshape(1, 32, 1)))
            mul_out = gs.Variable(name=name + "mul_out")
            mul_node = gs.Node(op="Mul", name=name + "mul_node", attrs={}, inputs=[div_out, constantScale], outputs=[mul_out])
            add_node = gs.Node(op="Add", name=name + "add_node", attrs={}, inputs=[mul_out, constantBias], outputs=[output_tensor])
            graph.nodes.extend([mean_node, sub_node, pow_node, mean2_node, epsilon_node, sqrt_node, div_node, mul_node, add_node])
            node.inputs = []
            node.outputs = []
            nRemoveInstanceNorm += 1

    graph.cleanup().toposort()
    print("remove IN")
    print(nRemoveInstanceNorm)
    return graph

Fusing the pattern back into a single operator mirrors the decomposition, just in reverse. The key point is that the attributes and the input/output signature must match the CUDA operator plugin exactly, otherwise the corresponding plugin will not be found at engine build time.

from collections import OrderedDict
from copy import deepcopy


# Fuse the decomposed pattern back into a single GroupNorm plugin node.
def insert_groupnorm_plugin(graph):
    nGroupNormPlugin = 0
    for node in graph.nodes:
        if node.op == "Reshape" and node.outputs != [] and \
            node.o().op == "ReduceMean" and node.o(1).op == "Sub" and node.o().o() == node.o(1) and \
            node.o().o().o().o().o().o().o().o().o().o().o().op == "Mul" and \
            node.o().o().o().o().o().o().o().o().o().o().o().o().op == "Add" and \
            len(node.o().o().o().o().o().o().o().o().inputs[1].values.shape) == 3 :

            assert len(node.outputs) == 1
            inputTensor = node.inputs[0]

            gammaNode = node.o().o().o().o().o().o().o().o().o().o().o()
            index = [type(i) == gs.ir.tensor.Constant for i in gammaNode.inputs].index(True)
            gamma = np.array(deepcopy(gammaNode.inputs[index].values.tolist()), dtype=np.float32)
            constantGamma = gs.Constant("groupNormGamma-" + str(nGroupNormPlugin), np.ascontiguousarray(gamma.reshape(-1)))  # MUST use np.ascontiguousarray, or TRT will regard the shape of this Constant as (0) !!!

            betaNode = gammaNode.o()
            index = [type(i) == gs.ir.tensor.Constant for i in betaNode.inputs].index(True)
            beta = np.array(deepcopy(betaNode.inputs[index].values.tolist()), dtype=np.float32)
            constantBeta = gs.Constant("groupNormBeta-" + str(nGroupNormPlugin), np.ascontiguousarray(beta.reshape(-1)))

            epsilon = node.o().o().o().o().o().inputs[1].values.tolist()[0]

            if betaNode.o().op == "Sigmoid":  # need Swish
                bSwish = True
                lastNode = betaNode.o().o()  # Mul node of Swish
            else:
                bSwish = False
                lastNode = betaNode  # Cast node after Group Norm

            if lastNode.o().op == "Cast":
                lastNode = lastNode.o()
            inputList = [inputTensor, constantGamma, constantBeta]
            groupNormV = gs.Variable("GroupNormV-" + str(nGroupNormPlugin), np.dtype(np.float16), inputTensor.shape)
            groupNormN = gs.Node("GroupNorm", "GroupNormN-" + str(nGroupNormPlugin), inputs=inputList, outputs=[groupNormV], attrs=OrderedDict([('epsilon', epsilon), ('bSwish', int(bSwish))]))
            graph.nodes.append(groupNormN)

            for subNode in graph.nodes:
                if lastNode.outputs[0] in subNode.inputs:
                    index = subNode.inputs.index(lastNode.outputs[0])
                    subNode.inputs[index] = groupNormV
            
            lastNode.outputs = []
            nGroupNormPlugin += 1

    graph.cleanup().toposort()
    print("GroupNorm")
    print(nGroupNormPlugin)
    return graph
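
Putting the two passes together, a minimal driver sketch might look like this (the file names are placeholders, not from the original repo):

import onnx
import onnx_graphsurgeon as gs

# Hypothetical driver: load the exported model, run both passes, save the result.
graph = gs.import_onnx(onnx.load("unet.onnx"))          # placeholder file name
graph = decompose_instancenorms(graph)
graph = insert_groupnorm_plugin(graph)
onnx.save(gs.export_onnx(graph), "unet_gn.onnx")        # placeholder file name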

Another way to deal with FP16 overflow is to keep both the offending layer and the layer feeding it in higher precision. The example below handles Softmax overflow by forcing the preceding layer and the Softmax layer itself to run in FP32.

# `indices` is assumed to be a sorted list of layer indices in trt_network
# (a trt.INetworkDefinition); pairwise() yields consecutive (i, i_next) pairs,
# e.g. itertools.pairwise on Python 3.10+.
for i, i_next in pairwise(indices):
    layer = trt_network.get_layer(i)
    next_layer = trt_network.get_layer(i_next)
    # only consider layers whose outputs are all execution tensors
    if not all([
        layer.get_output(j).is_execution_tensor
        for j in range(layer.num_outputs)
    ]):
        continue
    if layer.get_output_type(0) != trt.float32:
        continue
    # run the layer feeding a Softmax, and the Softmax itself, in fp32
    if next_layer.type == trt.LayerType.SOFTMAX:
        layer.precision = trt.DataType.FLOAT
        next_layer.precision = trt.DataType.FLOAT
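
Note that per-layer precision settings only take effect when the builder is told to honor them. A minimal sketch, assuming `config` is the trt.IBuilderConfig used to build the engine (TensorRT 8.x):

import tensorrt as trt

config.set_flag(trt.BuilderFlag.FP16)                        # build the engine in fp16 overall
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)  # keep layers explicitly marked FLOAT in fp32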

A third source of overflow is operator attributes whose values are too large or too small. In that case the attribute has to be changed from inf to a finite number; in the example code, every -np.inf is replaced with a large but finite negative value (-100000):

import onnx

# new_onnx_model is assumed to be the onnx.ModelProto loaded via onnx.load(...).
# Replace every -inf constant in the graph with a large but finite negative value.
for node in new_onnx_model.graph.node:
    if node.op_type == "ConstantOfShape":
        attr = node.attribute[0]
        if attr.name == "value" and attr.t.data_type == onnx.TensorProto.FLOAT:
            np_array = np.frombuffer(attr.t.raw_data, dtype=np.float32).copy()
            np_array[np_array == -np.inf] = -100000  # change every -inf to -100000
            attr.t.raw_data = np_array.tobytes()
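
After patching the initializers, the modified model just needs to be written back out, for example:

onnx.save(new_onnx_model, "model_no_inf.onnx")  # placeholder file name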

Pre/post-processing as an ONNX graph

This is not a huge speedup, but it is a refreshing trick. In the DDIM loop there is a chunk of post-processing after the controlnet/UNet step; by converting this pre/post-processing from plain torch computation into an ONNX graph, it too can be accelerated with TRT, i.e. the post-processing becomes its own postnet graph. One issue is that the schedule constants differ with the number of sampling steps, so a cleaner approach is to combine everything into one larger graph and avoid passing those extra parameters in.

import numpy as np
import torch
import torch.nn as nn


class PostNet(nn.Module):
    def __init__(self):
        super().__init__()

        # step = 20
        # self.alphas = torch.from_numpy(np.array([0.9983, 0.9505, 0.8930, 0.8264, 0.7521, 0.6722, 0.5888, 0.5048, 0.4229,0.3456, 0.2750, 
        #     0.2128, 0.1598, 0.1163, 0.0819, 0.0557, 0.0365, 0.0231,0.0140, 0.0082]))
        # self.alphas_prev = torch.from_numpy(np.array([0.99914998,0.99829602, 0.95052433, 0.89298052, 0.82639927, 0.75214338,
        #                     0.67215145, 0.58881873, 0.50481856, 0.42288151, 0.34555823, 0.27499905,
        #                     0.21278252, 0.15981644, 0.11632485, 0.08191671, 0.05571903, 0.03654652,
        #                     0.02307699, 0.0140049 ]))
        # self.sqrt_one_minus_alphas = torch.from_numpy(np.array([0.0413, 0.2224, 0.3271, 0.4167, 0.4979, 0.5726, 0.6412, 0.7037, 0.7597,
        #                                 0.8090, 0.8515, 0.8873, 0.9166, 0.9400, 0.9582, 0.9717, 0.9816, 0.9884,
        #                                 0.9930, 0.9959]))
        # self.sigmas = torch.from_numpy(np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))

        # self.time_range = [951, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451, 401, 351, 301, 251, 201, 151, 101,51, 1]

        # step = 10
        self.alphas = torch.from_numpy(np.array([0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365,0.0140]))
        self.alphas_prev = torch.from_numpy(np.array([0.99914998, 0.99829602, 0.89298052, 0.75214338, 0.58881873, 0.42288151,0.27499905,  0.15981644, 0.08191671, 0.03654652]))
        self.sqrt_one_minus_alphas = torch.from_numpy(np.array([0.0413, 0.3271, 0.4979, 0.6412, 0.7597, 0.8515, 0.9166, 0.9582, 0.9816,
                0.9930]))
        self.sigmas = torch.from_numpy(np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))

    def forward(self,x,image,unconditional_guidance_scale,index):
        # classifier-free guidance: image[0] is the conditional prediction, image[1] the unconditional one
        e_t = image[1].unsqueeze(0) + unconditional_guidance_scale * (image[0].unsqueeze(0) - image[1].unsqueeze(0))

        a_t = self.alphas[index]
        a_prev =  self.alphas_prev[index]
        sqrt_one_minus_at = self.sqrt_one_minus_alphas[index]
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        dir_xt = (1. - a_prev).sqrt() * e_t
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt
        return x_prev,  pred_x0
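
For completeness, a hedged sketch of how such a module could be exported to ONNX (the latent shapes, file name, and opset below are assumptions; the original post does not show the export call):

# Hypothetical export; shapes, file name, and opset are placeholders.
model = PostNet().eval()
x     = torch.randn(1, 4, 32, 48)            # current latent
image = torch.randn(2, 4, 32, 48)            # stacked [cond, uncond] UNet outputs
scale = torch.tensor(9.0)                    # unconditional_guidance_scale
index = torch.tensor(5, dtype=torch.int64)   # DDIM step index
torch.onnx.export(model, (x, image, scale, index), "postnet.onnx",
                  input_names=["x", "image", "scale", "index"],
                  output_names=["x_prev", "pred_x0"],
                  opset_version=17)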

References

  1. Dataxu: https://github.com/TRT2022/ControlNet_TensorRT
  2. Tlntin: https://tianchi.aliyun.com/forum/post/574634