ControlNet-trt Optimization Notes 2: Building the ControlNet Network from Scratch with the TRT API

Published 2023-10-07 17:36:03, author: wildkid1024


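As mentioned in the previous part, the ControlNet network can be rebuilt by constructing the TRT network by hand. This avoids the precision loss introduced by the intermediate ONNX conversion, and it also avoids operators being broken up into many small pieces during that conversion. Operators that TRT does not support can be added as plugins.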

Basic concepts

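tensorrt.INetworkDefinition: the network definition object. It can be produced by a parser or built directly with the TensorRT API.
tensorrt.Builder: builds a CudaEngine from a NetworkDefinition and the matching BuilderConfig. The CudaEngine is the built, binary computation graph.
tensorrt.IExecutionContext: created from a CudaEngine. One CudaEngine can create multiple ExecutionContexts.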

Notes:

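  1. Below, network generally refers to a tensorrt.INetworkDefinition object.
  2. x comes in two forms: a tensorrt.ITensor, usually for the very first input, or a tensorrt.ILayer, usually for intermediate inputs (its tensor is obtained with x.get_output(0)). A tensorrt.ITensor can be thought of as an edge of the computation graph, and a tensorrt.ILayer as a node.
  3. Operators that carry weights take weight_map together with the parameter-name prefix, and every operator helper returns a tensorrt.ILayer object.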
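Because x can be either an ITensor or an ILayer, the helpers keep converting with checks such as `x if isinstance(x, trt.ITensor) else x.get_output(0)`. A small normalizing helper can hide that. This is a sketch assuming only that layer objects expose `get_output(0)`, as tensorrt.ILayer does; `as_tensor`, `FakeTensor` and `FakeLayer` are hypothetical names, and the stand-in classes exist only so the example runs without TensorRT installed.

```python
class FakeTensor:
    """Hypothetical stand-in for tensorrt.ITensor (illustration only)."""

    def __init__(self, name):
        self.name = name


class FakeLayer:
    """Hypothetical stand-in for tensorrt.ILayer: exposes get_output(index)."""

    def __init__(self, output):
        self._output = output

    def get_output(self, index):
        assert index == 0, "these stand-in layers have a single output"
        return self._output


def as_tensor(x):
    """Unwrap layer-like objects via get_output(0); pass tensors through unchanged."""
    return x.get_output(0) if hasattr(x, "get_output") else x


t = FakeTensor("input")
layer = FakeLayer(t)
print(as_tensor(t).name, as_tensor(layer).name)  # input input
```

With real TensorRT objects, such a helper could replace the repeated `x.get_output(0)` calls so every operator accepts either form.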

Common TRT API functions

add_input(self: tensorrt.tensorrt.INetworkDefinition, 
          name: str, 
         dtype: tensorrt.tensorrt.DataType, 
         shape: tensorrt.tensorrt.Dims) -> tensorrt.tensorrt.ITensor

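Purpose: adds an input to the network  
Parameters:    name  - name of the input  
              dtype - data type of the tensor, e.g. trt.float32  
              shape - shape of the tensor; it must have fewer than 2^30 elements  
Returns: a new tensor  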
add_scale(self: tensorrt.tensorrt.INetworkDefinition, 
         input: tensorrt.tensorrt.ITensor, 
          mode: tensorrt.tensorrt.ScaleMode, 
         shift: tensorrt.tensorrt.Weights , 
         scale: tensorrt.tensorrt.Weights , 
         power: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IScaleLayer
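Purpose: scales every element, computing $output = (input \cdot scale + shift)^{power}$  
Parameters:    input - input tensor, with at least three dimensions  
              mode - scaling mode, e.g. trt.ScaleMode.UNIFORM, which applies a single coefficient to every element  
              shift - Weights, the shift value in the formula  
              scale - Weights, the scale value in the formula  
              power - Weights, the power value in the formula  
If a Weights argument is provided, its shape depends on mode:  
        UNIFORM: a single element  
        CHANNEL: one element per channel  
        ELEMENTWISE: the same shape as input  
Returns: a new layer, or None  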
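The scale formula can be checked with a pure-Python reference. This is a sketch of the documented formula, not TensorRT's implementation: the mode is passed as a string instead of trt.ScaleMode, ELEMENTWISE is omitted, and `scale_ref` is a hypothetical name.

```python
def scale_ref(x, mode, shift=None, scale=None, power=None):
    """Reference for output = (input * scale + shift) ** power on one sample x[c][i].

    UNIFORM: shift/scale/power hold one value each.
    CHANNEL: they hold one value per channel.
    Missing weights default to shift=0, scale=1, power=1 (the identity).
    """
    num_ch = len(x)

    def expand(w, default):
        if w is None:
            return [default] * num_ch
        if mode == "UNIFORM":
            assert len(w) == 1
            return list(w) * num_ch
        if mode == "CHANNEL":
            assert len(w) == num_ch
            return list(w)
        raise ValueError("only UNIFORM and CHANNEL are sketched here")

    sh, sc, pw = expand(shift, 0.0), expand(scale, 1.0), expand(power, 1.0)
    return [[(v * sc[c] + sh[c]) ** pw[c] for v in x[c]] for c in range(num_ch)]


# The "0.5 * erf(.) + 0.5" step of the GELU in feed_forward is a UNIFORM scale.
print(scale_ref([[-1.0, 0.0, 1.0]], "UNIFORM", shift=[0.5], scale=[0.5]))
```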
add_slice(self: tensorrt.tensorrt.INetworkDefinition, 
         input: tensorrt.tensorrt.ITensor, 
         start: tensorrt.tensorrt.Dims, 
         shape: tensorrt.tensorrt.Dims, 
        stride: tensorrt.tensorrt.Dims) -> tensorrt.tensorrt.ISliceLayer

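Purpose: slices a tensor; along each axis, out[i] = input[start + i * stride]
Parameters:    input - input tensor
            start - starting index
            shape - output shape
            stride - slice stride

Returns: a new layer, or None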
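The slice semantics, including the trt.SliceMode.WRAP mode that input_block sets, can be sketched on nested lists. This is an illustrative reference under the assumption that the default mode rejects out-of-range coordinates while WRAP takes them modulo the axis length; `slice_ref` is a hypothetical name.

```python
def slice_ref(x, start, shape, stride, wrap=False):
    """Reference for ISliceLayer on nested lists.

    Along each axis: out[i] = x[start + i * stride]. With wrap=True the index is
    taken modulo the axis length (trt.SliceMode.WRAP); otherwise an out-of-range
    index raises, like the default strict-bounds mode.
    """
    if not start:  # no axes left: x is a single element
        return x
    n = len(x)
    out = []
    for i in range(shape[0]):
        idx = start[0] + i * stride[0]
        if wrap:
            idx %= n
        elif not 0 <= idx < n:
            raise IndexError(f"slice index {idx} out of range for axis of length {n}")
        out.append(slice_ref(x[idx], start[1:], shape[1:], stride[1:], wrap))
    return out


# feed_forward splits the last axis in two halves with start offsets 0 and ch * 4.
row = [[1, 2, 3, 4]]
print(slice_ref(row, [0, 0], [1, 2], [1, 1]), slice_ref(row, [0, 2], [1, 2], [1, 1]))
```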
add_constant(self: tensorrt.tensorrt.INetworkDefinition, 
            shape: tensorrt.tensorrt.Dims, 
          weights: tensorrt.tensorrt.Weights) → tensorrt.tensorrt.IConstantLayer

Purpose: adds a constant layer, which turns a Weights object into a layer and, from there, a tensor  
Parameters:    shape - shape  
            weights - the Weights object  
Returns: a new layer, or None    
add_elementwise(self: tensorrt.tensorrt.INetworkDefinition, 
              input1: tensorrt.tensorrt.ITensor, 
              input2: tensorrt.tensorrt.ITensor, 
              op: tensorrt.tensorrt.ElementWiseOperation) → tensorrt.tensorrt.IElementWiseLayer

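Purpose: binary elementwise operation
Parameters:  input1 (input2) - input tensors; they must have the same rank, and along each dimension the lengths must match or one of them must be 1 (that tensor is then broadcast)
              op - the binary operator, from ElementWiseOperation, e.g.:
                    trt.ElementWiseOperation.PROD (product)
                    trt.ElementWiseOperation.SUM (addition)

Returns: a new layer, or None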
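The broadcasting rule is what lets the attention and FFN code add a [1, 1, ch] bias to a [2, h * w, ch] tensor. A minimal shape-only sketch, assuming static shapes (no -1 dimensions); `broadcast_shape` is a hypothetical helper.

```python
def broadcast_shape(a, b):
    """Output shape of an elementwise op on input shapes a and b.

    Both inputs must have the same rank; along each axis the lengths must be
    equal, or one of them must be 1, in which case that input is broadcast.
    """
    if len(a) != len(b):
        raise ValueError(f"rank mismatch: {a} vs {b}")
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"incompatible dims {da} and {db} in {a} vs {b}")
    return tuple(out)


print(broadcast_shape((2, 4096, 320), (1, 1, 320)))  # bias add
print(broadcast_shape((2, 320, 32, 48), (2, 320, 1, 1)))  # timestep embedding add
```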
add_unary(self: tensorrt.tensorrt.INetworkDefinition,
         input: tensorrt.tensorrt.ITensor, 
         op: tensorrt.tensorrt.UnaryOperation) → tensorrt.tensorrt.IUnaryLayer
Purpose: unary operation
Parameters:  input - input tensor
              op - the unary operator, from UnaryOperation, e.g.:
                    trt.UnaryOperation.EXP (natural exponential)
                    trt.UnaryOperation.LOG (natural logarithm)

Returns: a new layer, or None
add_convolution(self: tensorrt.tensorrt.INetworkDefinition, 
            input: tensorrt.tensorrt.ITensor, 
            num_output_maps: int, 
            kernel_shape: tensorrt.tensorrt.DimsHW, 
            kernel: tensorrt.tensorrt.Weights, 
            bias: tensorrt.tensorrt.Weights = None)→ tensorrt.tensorrt.IConvolutionLayer
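Purpose: adds a 2D convolution
Parameters:     input - input tensor, 4-D
                num_output_maps - number of output feature maps, i.e. the channel count of the next layer
                kernel_shape - kernel size
                kernel - kernel weights
                bias - bias weights
Returns: a new layer, or None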
add_activation(self: tensorrt.tensorrt.INetworkDefinition, 
            input: tensorrt.tensorrt.ITensor, 
            type: tensorrt.tensorrt.ActivationType) → tensorrt.tensorrt.IActivationLayer
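Purpose: adds an activation layer that applies the activation elementwise; the output has the same shape as the input
Parameters:     input – input tensor
                type – activation type, e.g. RELU, SIGMOID, TANH, LEAKY_RELU; see tensorrt.ActivationType.
Returns: a new layer, or None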
add_normalization(self: tensorrt.tensorrt.INetworkDefinition,
                input: tensorrt.tensorrt.ITensor, 
                scale: tensorrt.tensorrt.ITensor, 
                bias: tensorrt.tensorrt.ITensor, 
                axesMask: int)→ tensorrt.tensorrt.INormalizationLayer
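Purpose: adds a normalization layer that computes $Y = \frac{X - \mathrm{Mean}(X, axes)}{\sqrt{\mathrm{Variance}(X, axes) + \epsilon}} \cdot S + B$. Internally TRT actually implements it with instance norm, so in some cases it has to be replaced with a hand-written implementation
Parameters:     input – input tensor  
                scale – the scale parameter S  
                bias – the bias parameter B  
                axesMask – the axes to reduce over, passed as a bitmask in which bit i (1 << i) selects axis i  
Returns: a new layer, or None  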
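Both axesMask here and the axes attribute of the softmax layer use the same bitmask encoding. A small sketch for building and reading such masks; `axes_mask` and `mask_to_axes` are hypothetical helper names.

```python
def axes_mask(*axes):
    """Build an axes bitmask: bit i selects axis i."""
    mask = 0
    for axis in axes:
        mask |= 1 << axis
    return mask


def mask_to_axes(mask):
    """Decode a bitmask back into a sorted list of axis indices."""
    return [i for i in range(mask.bit_length()) if (mask >> i) & 1]


# layer_norm on [batch, seq_len, channels] reduces over axis 2;
# reducing over H and W of an NCHW tensor would use axes 2 and 3.
print(axes_mask(2), axes_mask(2, 3), mask_to_axes(12))
```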
add_matrix_multiply(self: tensorrt.tensorrt.INetworkDefinition,
                input0: tensorrt.tensorrt.ITensor, 
                op0: tensorrt.tensorrt.MatrixOperation, 
                input1: tensorrt.tensorrt.ITensor, 
                op1: tensorrt.tensorrt.MatrixOperation) → tensorrt.tensorrt.IMatrixMultiplyLayer
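Purpose: adds a matrix multiplication. There are four cases: matrix × matrix, matrix × vector, vector × matrix, and vector × vector
Parameters:     input0 – first input tensor
                op0 – how to treat input0 (trt.MatrixOperation): NONE, TRANSPOSE, or VECTOR
                input1 – second input tensor
                op1 – how to treat input1 (trt.MatrixOperation): NONE, TRANSPOSE, or VECTOR
Returns: a new layer, or None 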
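The attention and FFN code pass PyTorch weights with MatrixOperation.TRANSPOSE. PyTorch's nn.Linear stores its weight as [out_features, in_features] and computes x @ W^T, so transposing inside the matmul avoids re-laying out the weights offline (the leading 1 in the (1, ch, ch) constants is assumed to be broadcast over the batch). A pure-Python sketch for 2-D inputs; `matmul_ref` is a hypothetical name.

```python
def matmul_ref(a, b, transpose_b=False):
    """a @ b, or a @ b^T when transpose_b is set (MatrixOperation.TRANSPOSE on input1)."""
    if transpose_b:
        b = [list(col) for col in zip(*b)]
    inner = len(b)
    assert all(len(row) == inner for row in a), "inner dimensions must match"
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(len(b[0]))] for row in a]


x = [[1, 2]]                    # one token, in_features = 2
w = [[1, 0], [0, 1], [1, 1]]    # Linear weight layout: [out_features = 3, in_features = 2]
print(matmul_ref(x, w, transpose_b=True))  # [[1, 2, 3]]
```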
add_shuffle(self: tensorrt.tensorrt.INetworkDefinition, 
            input: tensorrt.tensorrt.ITensor)→ tensorrt.tensorrt.IShuffleLayer
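Purpose: adds a shuffle layer, which covers the transpose and reshape operators; it applies first_transpose, then reshape_dims, then second_transpose
Parameters:  input – the input tensor
Returns: a new layer, or None 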
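The shape bookkeeping of a shuffle layer can be traced without TensorRT. This sketch assumes the default behaviour where a 0 in reshape_dims copies the input dimension and a single -1 is inferred; `shuffle_shape` is a hypothetical helper that only tracks shapes, not data.

```python
from math import prod


def shuffle_shape(shape, first_transpose=None, reshape_dims=None, second_transpose=None):
    """Output shape of a shuffle: first_transpose -> reshape -> second_transpose."""
    if first_transpose is not None:
        shape = tuple(shape[p] for p in first_transpose)
    if reshape_dims is not None:
        # 0 copies the matching input dimension; -1 is inferred from the element count.
        dims = [shape[i] if d == 0 else d for i, d in enumerate(reshape_dims)]
        if -1 in dims:
            known = prod(d for d in dims if d != -1)
            dims[dims.index(-1)] = prod(shape) // known
        assert prod(dims) == prod(shape), "reshape must preserve the element count"
        shape = tuple(dims)
    if second_transpose is not None:
        shape = tuple(shape[p] for p in second_transpose)
    return tuple(shape)


# Head split for ch = 320, 8 heads, h * w = 4096 (as in self_attention).
s = shuffle_shape((2, 4096, 320), reshape_dims=(2, -1, 8, 40), second_transpose=(0, 2, 1, 3))
print(s, shuffle_shape(s, reshape_dims=(16, -1, 40)))
```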
add_softmax(self: tensorrt.tensorrt.INetworkDefinition, 
            input: tensorrt.tensorrt.ITensor)→ tensorrt.tensorrt.ISoftMaxLayer
Purpose: adds a softmax layer that applies softmax along the axes given by its axes attribute, which is a bitmask
Parameters: input – input tensor
Returns: a new layer, or None 
add_gather(self: tensorrt.tensorrt.INetworkDefinition, 
            input: tensorrt.tensorrt.ITensor, 
            indices: tensorrt.tensorrt.ITensor, 
            axis: int)→ tensorrt.tensorrt.IGatherLayer
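Purpose: adds a gather layer that takes the entries of input selected by indices along axis
Parameters:   input – input tensor
        indices – tensor of indices used to produce the output tensor
        axis – the axis to gather along; in implicit-batch networks it cannot be the batch axis
Returns: a new layer, or None 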
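In resblock, add_gather with axis=0 picks a precomputed timestep embedding out of a [20, ch, 1, 1] constant. The sketch below assumes emb holds one step index per batch element (2 here); the table values and step number are made up for illustration, and `gather_axis0` is a hypothetical name.

```python
def gather_axis0(table, indices):
    """Gather along axis 0: out[i] = table[indices[i]]."""
    return [table[i] for i in indices]


# Hypothetical table: 20 precomputed embeddings, 3 channels each.
table = [[float(step)] * 3 for step in range(20)]
emb = [5, 5]  # assumed: the same step index for both batch entries
print(gather_axis0(table, emb))
```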
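add_einsum(self: tensorrt.tensorrt.INetworkDefinition,
           inputs: List[tensorrt.tensorrt.ITensor],
           equation: str) → tensorrt.tensorrt.IEinsumLayer
Purpose: adds an Einstein-summation layer, equivalent to einsum; here it is mainly used for (batched) matrix multiplication
Parameters:   inputs – input tensors
        equation – the einsum equation
Returns: a new layer, or None 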
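The two equations used in the attention code are batched matrix products: 'b i d, b j d -> b i j' is Q @ K^T per batch, and 'b i j, b j d -> b i d' is S @ V. A pure-Python sketch of both; the function names are hypothetical.

```python
def einsum_bid_bjd_bij(q, k):
    """'b i d, b j d -> b i j': for each batch b, S[b] = Q[b] @ K[b]^T."""
    return [
        [[sum(qi[d] * kj[d] for d in range(len(qi))) for kj in kb] for qi in qb]
        for qb, kb in zip(q, k)
    ]


def einsum_bij_bjd_bid(s, v):
    """'b i j, b j d -> b i d': for each batch b, O[b] = S[b] @ V[b]."""
    return [
        [[sum(si[j] * vb[j][d] for j in range(len(si))) for d in range(len(vb[0]))] for si in sb]
        for sb, vb in zip(s, v)
    ]


q = [[[1, 0], [0, 1]]]          # batch 1, 2 queries, d = 2
k = [[[1, 2], [3, 4], [5, 6]]]  # batch 1, 3 keys, d = 2
s = einsum_bid_bjd_bij(q, k)
print(s, einsum_bij_bjd_bid(s, k))
```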

Key TRT operators

Convolution

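Since TRT supports convolution natively, add_convolution is called directly. Note that conv can also take the raw network input of the first layer (an ITensor), which is why it checks the type of x.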

def conv(network, weight_map, x, ch, pre, kernel, padding, stride):
    x = network.add_convolution(
            input=x if isinstance(x, trt.ITensor) else x.get_output(0),
            num_output_maps=ch,
            kernel_shape=(kernel, kernel),
            kernel=weight_map['{}.weight'.format(pre)],
            bias=weight_map['{}.bias'.format(pre)])
    assert x
    x.padding = (padding, padding)
    x.stride = (stride, stride)
    return x
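The kernel/padding/stride combinations passed to conv determine the feature-map sizes: 3×3 with padding 1 and stride 1 keeps the size, 3×3 with padding 1 and stride 2 halves it (downsample), and 1×1 with padding 0 keeps it (proj_in, zero_convs, skip_connection). A sketch of the standard output-size formula without dilation; `conv_out_size` is a hypothetical helper, and the 32×48 latent size is taken from the slice shape used later.

```python
def conv_out_size(size, kernel, padding, stride):
    """Spatial output length of a convolution without dilation: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


h, w = 32, 48
print(conv_out_size(h, 3, 1, 2), conv_out_size(w, 3, 1, 2))  # downsample: 16 24
```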

Activation

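The SiLU operator is split into two operations, SIGMOID and PROD (x * sigmoid(x)), which is essentially the same as what the ONNX export produces.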

def silu(network, x):
    y = network.add_activation(x.get_output(0), trt.ActivationType.SIGMOID)
    assert y
    x = network.add_elementwise(x.get_output(0), y.get_output(0), trt.ElementWiseOperation.PROD)
    return x
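A scalar reference for the two-layer SiLU, useful for spot-checking engine outputs; `silu_ref` is a hypothetical name, and the sketch does not guard against overflow in math.exp for very negative inputs.

```python
import math


def silu_ref(x):
    """SiLU as built above: SIGMOID, then PROD with the input."""
    sigmoid = 1.0 / (1.0 + math.exp(-x))  # ActivationType.SIGMOID
    return x * sigmoid                     # ElementWiseOperation.PROD


print(silu_ref(0.0), silu_ref(1.0))
```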

Normalization

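GroupNorm here is implemented with a plugin. Two attributes are defined through PluginField: epsilon, the small constant added to the variance, and bSwish, which selects whether a Swish (SiLU) activation is fused after the normalization.
The plugin takes the previous layer's output, the weights, and the bias as inputs, and outputs the group-normalized values.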

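import ctypes

import numpy as np
import tensorrt as trt

# Load the custom plugin library (presumably where the 'GroupNorm' plugin is defined) so its creator gets registered.
ctypes.CDLL('./trt/libmyplugins.so.1', mode=ctypes.RTLD_GLOBAL)

TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
trt.init_libnvinfer_plugins(TRT_LOGGER, '')
gn_plugin_creator = trt.get_plugin_registry().get_plugin_creator('GroupNorm', "1")

# EPS is not defined in this excerpt; 1e-5 (torch.nn.GroupNorm's default) is an assumption.
EPS = 1e-5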

def group_norm(network, weight_map, h, pre, epsilon=EPS, silu=False):
    ch = h.get_output(0).shape[1]
    # plugin_creator = trt.get_plugin_registry().get_plugin_creator('GroupNorm', "1")
    plugin_creator = gn_plugin_creator
    s = network.add_constant([1, ch, 1, 1], weight_map['{}.weight'.format(pre)])
    b = network.add_constant([1, ch, 1, 1], weight_map['{}.bias'.format(pre)])

    eps_attr = trt.PluginField("epsilon", np.array([epsilon], dtype=np.float32), type=trt.PluginFieldType.FLOAT32)
    silu_attr = trt.PluginField("bSwish", np.array([1 if silu else 0], dtype=np.int32), type=trt.PluginFieldType.INT32)
    field_collection = trt.PluginFieldCollection([eps_attr, silu_attr])

    plugin = plugin_creator.create_plugin(name='{}.group_norm'.format(pre), field_collection=field_collection)
    n = network.add_plugin_v2(inputs=[h.get_output(0), s.get_output(0), b.get_output(0)], plugin=plugin)
    return n
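The plugin's source is not shown, so the following is only a pure-Python sketch of what it is assumed to compute: torch.nn.GroupNorm semantics (contiguous channel groups, biased variance, per-channel affine) with an optional fused SiLU for bSwish. Stable Diffusion uses 32 groups, but the group count is not among the plugin fields shown here; `group_norm_ref` is a hypothetical name.

```python
import math


def group_norm_ref(x, gamma, beta, num_groups, eps=1e-5, swish=False):
    """Reference GroupNorm for one sample x[c][i] (channels x spatial elements)."""
    num_ch = len(x)
    assert num_ch % num_groups == 0
    per_group = num_ch // num_groups
    out = []
    for g in range(num_groups):
        chans = range(g * per_group, (g + 1) * per_group)
        vals = [v for c in chans for v in x[c]]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)  # biased variance, as in PyTorch
        inv_std = 1.0 / math.sqrt(var + eps)
        for c in chans:
            row = [(v - mean) * inv_std * gamma[c] + beta[c] for v in x[c]]
            if swish:  # bSwish = 1: SiLU fused after the affine step
                row = [v / (1.0 + math.exp(-v)) for v in row]
            out.append(row)
    return out


x = [[1.0, 3.0], [5.0, 7.0]]
print(group_norm_ref(x, [1.0, 1.0], [0.0, 0.0], num_groups=2, eps=0.0))
```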

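layer_norm performs the following computation:
$Y = \frac{X - \mathrm{Mean}(X, axes)}{\sqrt{\mathrm{Variance}(X, axes) + \epsilon}} \cdot S + B$
The result depends on which axes are reduced. Here axesMask = 1 << 2 selects axis 2. For sequence tasks the input is laid out as [batch, seq_len, channels], so the normalization runs over the last (channel) dimension, matching PyTorch's LayerNorm.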

def layer_norm(network, weight_map, h, pre, epsilon=EPS):
    scale_np = weight_map['{}.weight'.format(pre)]
    ch = scale_np.shape[0]
    scale = network.add_constant([1, 1, ch], scale_np)
    bias_np = weight_map['{}.bias'.format(pre)]
    bias = network.add_constant([1, 1, ch], bias_np)
    n = network.add_normalization(
        h.get_output(0),
        scale=scale.get_output(0),
        bias=bias.get_output(0),
        axesMask=1 << 2)
    assert n
    n.epsilon = epsilon

    return n    
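For checking the engine against PyTorch, a pure-Python reference of the same normalization over the last axis of x[token][channel] can help. This is a sketch of the formula above, not the TRT kernel; `layer_norm_ref` is a hypothetical name.

```python
import math


def layer_norm_ref(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last axis of x[token][channel] (axesMask = 1 << 2 on [B, S, C])."""
    out = []
    for token in x:
        mean = sum(token) / len(token)
        var = sum((v - mean) ** 2 for v in token) / len(token)
        inv_std = 1.0 / math.sqrt(var + eps)
        out.append([(v - mean) * inv_std * g + b for v, g, b in zip(token, gamma, beta)])
    return out


print(layer_norm_ref([[1.0, 3.0]], [2.0, 2.0], [1.0, 1.0], eps=0.0))  # [[-1.0, 3.0]]
```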

Attention

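Because TRT does not directly support the 4-D matrix multiply-add used here, H and W are merged into one dimension. The MHA here has 8 heads, and the heads are folded into the batch dimension for the computation, which gives the following transformation (d = c / 8):
[2, h * w, c] -> [2, h * w, 8, d] -> [2, 8, h * w, d] -> [16, h * w, d]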

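In terms of computation, q, k and v are each obtained with a separate matrix multiplication. This leaves room for optimization: the three products can be computed together instead of separately, which parallelizes better.
The q·k product is computed with add_einsum; after scaling and softmax, the result is multiplied by v, and the output must be reshaped back to [2, h * w, c].
The function ends with the to_out projection (matrix multiply plus bias); the residual connection itself is added by the caller, basic_transformer.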
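The fused-QKV idea can be checked numerically: stacking Wq, Wk and Wv along the output dimension into one [3·ch, ch] weight, doing a single matmul, and splitting the last axis gives the same q, k, v as three separate matmuls. A pure-Python sketch; in TRT this would presumably become one add_matrix_multiply followed by slices, which is not benchmarked here. `linear` and `fused_qkv` are hypothetical names.

```python
def linear(x, w):
    """x @ w^T with w in PyTorch Linear layout [out_features][in_features]."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def fused_qkv(x, wq, wk, wv):
    """One matmul against the stacked [3 * ch, ch] weight, then split into q, k, v."""
    w_qkv = wq + wk + wv  # stack along the output dimension
    qkv = linear(x, w_qkv)
    ch = len(wq)
    q = [row[:ch] for row in qkv]
    k = [row[ch:2 * ch] for row in qkv]
    v = [row[2 * ch:] for row in qkv]
    return q, k, v


x = [[1, 2], [3, 4]]
wq, wk, wv = [[1, 0], [0, 1]], [[2, 0], [0, 2]], [[0, 1], [1, 0]]
print(fused_qkv(x, wq, wk, wv))
```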

def self_attention(network, weight_map, i, ch, x):   
    heads = 8
    dim_head = ch / heads
    scale = dim_head ** -0.5

    wq = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_q.weight'.format(i)])
    wk = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_k.weight'.format(i)])
    wv = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_v.weight'.format(i)])

    q = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    wq.get_output(0), trt.MatrixOperation.TRANSPOSE)
    k = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    wk.get_output(0), trt.MatrixOperation.TRANSPOSE)
    v = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    wv.get_output(0), trt.MatrixOperation.TRANSPOSE)

    # q [2, h * w, c] -> [2, h * w, 8, d] -> [2, 8, h * w, d] -> [16, h * w, d]
    q = network.add_shuffle(q.get_output(0))
    q.reshape_dims = (2, -1, 8, ch // 8)
    q.second_transpose = trt.Permutation([0, 2, 1, 3])
    q = network.add_shuffle(q.get_output(0))
    q.reshape_dims = (16, -1, ch // 8)

    k = network.add_shuffle(k.get_output(0))
    k.reshape_dims = (2, -1, 8, ch // 8)
    k.second_transpose = trt.Permutation([0, 2, 1, 3])
    k = network.add_shuffle(k.get_output(0))
    k.reshape_dims = (16, -1, ch // 8)

    v = network.add_shuffle(v.get_output(0))
    v.reshape_dims = (2, -1, 8, ch // 8)
    v.second_transpose = trt.Permutation([0, 2, 1, 3])
    v = network.add_shuffle(v.get_output(0))
    v.reshape_dims = (16, -1, ch // 8)

    s = network.add_einsum([q.get_output(0), k.get_output(0)], 'b i d, b j d -> b i j')
    print(s.get_output(0).shape)

    s = network.add_scale(s.get_output(0), mode=trt.ScaleMode.UNIFORM,
                          scale=trt.Weights(np.array([scale], np.float32)))

    s = network.add_softmax(s.get_output(0))
    s.axes = 1<<2

    out = network.add_einsum([s.get_output(0), v.get_output(0)], 'b i j, b j d -> b i d')
    # [16, h * w, d] -> [2, 8, h * w, d] -> [2, h * w, 8, d] -> [2, h * w, c]
    out = network.add_shuffle(out.get_output(0))
    out.reshape_dims = (2, 8, -1, ch // 8)
    out.second_transpose = trt.Permutation([0, 2, 1, 3])
    out = network.add_shuffle(out.get_output(0))
    out.reshape_dims = (2, -1, ch)

    # to_out
    outw = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn1.to_out.0.weight'.format(i)])
    outb = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.attn1.to_out.0.bias'.format(i)])

    out = network.add_matrix_multiply(out.get_output(0), trt.MatrixOperation.NONE,
                                      outw.get_output(0), trt.MatrixOperation.TRANSPOSE)

    out = network.add_elementwise(out.get_output(0), outb.get_output(0), trt.ElementWiseOperation.SUM)

    return out

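Cross attention is similar to self attention. The difference is that k and v come from the context rather than from x; only q is computed from the weights and the previous layer's output. Judging from the slicing code, context['context'] holds the precomputed K and V projections of the 77-token text context for all cross-attention layers, concatenated along the last dimension, and context['start'] tracks the current offset (in units of ch / 8), moving past this layer's K and V after each call.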
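Replaying the offset bookkeeping of cross_attention shows the assumed layout of the precomputed context tensor: [K0 | V0 | K1 | V1 | ...], each block ch wide. The channel sequence in the example is hypothetical, and `kv_offsets` is an illustrative helper, not part of the original code.

```python
def kv_offsets(channels):
    """Start offsets of each layer's K and V slices, mirroring cross_attention.

    k starts at 8 * start, v at 8 * (start + dim), then start += 2 * dim,
    with dim = ch // 8 (8 heads).
    """
    start = 0
    offsets = []
    for ch in channels:
        dim = ch // 8
        offsets.append((8 * start, 8 * (start + dim)))
        start += 2 * dim
    return offsets


print(kv_offsets([320, 320, 640]))  # [(0, 320), (640, 960), (1280, 1920)]
```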

def cross_attention(network, weight_map, i, ch, x, context):
    heads = 8
    dim_head = ch / heads
    scale = dim_head ** -0.5

    wq = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn2.to_q.weight'.format(i)])

    q = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    wq.get_output(0), trt.MatrixOperation.TRANSPOSE)
    # [2, h*w, c]

    dim = ch // 8
    k = network.add_slice(context['context'],
                          trt.Dims([0, 0, 8 * context['start']]),
                          trt.Dims([2, 77, ch]),
                          trt.Dims([1, 1, 1]))
    v = network.add_slice(context['context'],
                          trt.Dims([0, 0, 8 * (context['start'] + dim)]),
                          trt.Dims([2, 77, ch]),
                          trt.Dims([1, 1, 1]))
    context['start'] += 2 * dim

    q = network.add_shuffle(q.get_output(0))
    q.reshape_dims = (2, -1, 8, ch // 8)
    q.second_transpose = trt.Permutation([0, 2, 1, 3])
    q = network.add_shuffle(q.get_output(0))
    q.reshape_dims = (16, -1, ch // 8)

    k = network.add_shuffle(k.get_output(0))
    k.reshape_dims = (2, -1, 8, ch // 8)
    k.second_transpose = trt.Permutation([0, 2, 1, 3])
    k = network.add_shuffle(k.get_output(0))
    k.reshape_dims = (16, -1, ch // 8)

    v = network.add_shuffle(v.get_output(0))
    v.reshape_dims = (2, -1, 8, ch // 8)
    v.second_transpose = trt.Permutation([0, 2, 1, 3])
    v = network.add_shuffle(v.get_output(0))
    v.reshape_dims = (16, -1, ch // 8)

    s = network.add_einsum([q.get_output(0), k.get_output(0)], 'b i d, b j d -> b i j')
    print(s.get_output(0).shape)

    # scale = network.add_constant((1, 1, 1), np.array([scale], np.float32))
    # s = network.add_elementwise(s.get_output(0), scale.get_output(0), trt.ElementWiseOperation.PROD)
    s = network.add_scale(s.get_output(0), mode=trt.ScaleMode.UNIFORM,
                          scale=trt.Weights(np.array([scale], np.float32)))

    s = network.add_softmax(s.get_output(0))
    s.axes = 1<<2

    out = network.add_einsum([s.get_output(0), v.get_output(0)], 'b i j, b j d -> b i d')
    out = network.add_shuffle(out.get_output(0))
    out.reshape_dims = (2, 8, -1, ch // 8)
    out.second_transpose = trt.Permutation([0, 2, 1, 3])

    out = network.add_shuffle(out.get_output(0))
    out.reshape_dims = (2, -1, ch)

    # to_out
    outw = network.add_constant((1, ch, ch), weight_map['{}.transformer_blocks.0.attn2.to_out.0.weight'.format(i)])
    outb = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.attn2.to_out.0.bias'.format(i)])

    out = network.add_matrix_multiply(out.get_output(0), trt.MatrixOperation.NONE,
                                      outw.get_output(0), trt.MatrixOperation.TRANSPOSE)

    out = network.add_elementwise(out.get_output(0), outb.get_output(0), trt.ElementWiseOperation.SUM)

    return out

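The FFN is grouped with the attention operators as well. It consists of a fully connected projection, a GEGLU-style gate (the projected output is split into two halves, the second half goes through GELU and is multiplied with the first half), and a second fully connected layer. Note that the matrix multiply and the bias add are computed as separate layers.
add_unary is the unary-operator layer; here it computes erf, because GELU is assembled as x * 0.5 * (1 + erf(x / √2)).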
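A scalar reference for the erf-based GELU as it is assembled from layers in feed_forward, plus the GEGLU split, to compare against engine outputs. The function names are hypothetical; the layer each line corresponds to is noted in the comments.

```python
import math


def gelu(x):
    """Exact GELU as assembled in feed_forward: x * (0.5 + 0.5 * erf(x / sqrt(2)))."""
    e = x * 2 ** -0.5   # add_scale, scale = 2^-0.5
    e = math.erf(e)     # add_unary, ERF
    e = e * 0.5 + 0.5   # add_scale, scale = 0.5, shift = 0.5
    return x * e        # add_elementwise, PROD


def geglu(row):
    """Split a projected row into (hidden, gate) halves and return hidden * gelu(gate)."""
    half = len(row) // 2
    hidden, gate = row[:half], row[half:]
    return [h * gelu(g) for h, g in zip(hidden, gate)]


print(gelu(1.0), geglu([2.0, 3.0, 0.0, 1.0]))
```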

def feed_forward(network, weight_map, i, ch, x):
    w1 = network.add_constant((1, ch * 8, ch), weight_map['{}.transformer_blocks.0.ff.net.0.proj.weight'.format(i)])
    b1 = network.add_constant((1, 1, ch * 8), weight_map['{}.transformer_blocks.0.ff.net.0.proj.bias'.format(i)])
    n = network.add_matrix_multiply(x.get_output(0), trt.MatrixOperation.NONE,
                                    w1.get_output(0), trt.MatrixOperation.TRANSPOSE)
    n = network.add_elementwise(n.get_output(0), b1.get_output(0), trt.ElementWiseOperation.SUM)

    hw = n.get_output(0).shape[1]
    # w = n.get_output(0).shape[3]
    n1 = network.add_slice(n.get_output(0), trt.Dims([0, 0, 0]), trt.Dims([2, hw, ch * 4]), trt.Dims([1, 1, 1]))
    n2 = network.add_slice(n.get_output(0), trt.Dims([0, 0, ch * 4]), trt.Dims([2, hw, ch * 4]), trt.Dims([1, 1, 1]))

    # gelu
    e = network.add_scale(n2.get_output(0), mode=trt.ScaleMode.UNIFORM, scale=trt.Weights(np.array([2 ** -0.5], np.float32)))
    e = network.add_unary(e.get_output(0), trt.UnaryOperation.ERF)
    e = network.add_scale(e.get_output(0), mode=trt.ScaleMode.UNIFORM,
                          scale=trt.Weights(np.array([0.5], np.float32)),
                          shift=trt.Weights(np.array([0.5], np.float32)))

    n = network.add_elementwise(n2.get_output(0), e.get_output(0), trt.ElementWiseOperation.PROD)
    n = network.add_elementwise(n.get_output(0), n1.get_output(0), trt.ElementWiseOperation.PROD)

    w2 = network.add_constant((1, ch, ch * 4), weight_map['{}.transformer_blocks.0.ff.net.2.weight'.format(i)])
    b2 = network.add_constant((1, 1, ch), weight_map['{}.transformer_blocks.0.ff.net.2.bias'.format(i)])
    n = network.add_matrix_multiply(n.get_output(0), trt.MatrixOperation.NONE,
                                    w2.get_output(0), trt.MatrixOperation.TRANSPOSE)
    n = network.add_elementwise(n.get_output(0), b2.get_output(0), trt.ElementWiseOperation.SUM)

    return n

Key modules

Transformer

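The basic transformer is not discussed in detail here; it is the standard attn1 → attn2 → FFN sequence, each step preceded by a LayerNorm and followed by a residual add. Note that because TRT does not handle the 4-D case here, an extra reshape is needed before and after the block to convert between [n, c, h, w] and [n, h * w, c].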

def basic_transformer(network, weight_map, i, ch, x, context):
    H = x.get_output(0).shape[2]
    W = x.get_output(0).shape[3]

    # n c h w -> b (h w) c
    x = network.add_shuffle(x.get_output(0))
    x.first_transpose = trt.Permutation([0, 2, 3, 1])
    x.reshape_dims = (2, -1, ch)

    # attn1
    n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm1'.format(i))
    
    attn1 = self_attention(network, weight_map, i, ch, n)
    x = network.add_elementwise(attn1.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)

    # attn2
    n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm2'.format(i))
    attn2 = cross_attention(network, weight_map, i, ch, n, context)
    x = network.add_elementwise(attn2.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)

    # ff
    n = layer_norm(network, weight_map, x, '{}.transformer_blocks.0.norm3'.format(i))
    ff = feed_forward(network, weight_map, i, ch, n)
    
    x = network.add_elementwise(ff.get_output(0), x.get_output(0), trt.ElementWiseOperation.SUM)

    # n (h w) c -> n c h w
    x = network.add_shuffle(x.get_output(0))
    x.first_transpose = trt.Permutation([0, 2, 1])
    x.reshape_dims = (2, ch, H, W)
    return x
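The two shuffles at the start and end of basic_transformer are inverses of each other: first_transpose [0, 2, 3, 1] plus reshape (2, -1, ch) flattens NCHW into [n, h * w, c], and first_transpose [0, 2, 1] plus reshape (2, ch, H, W) undoes it. A pure-Python sketch of the data movement on nested lists; the function names are hypothetical.

```python
def nchw_to_nlc(x):
    """[n][c][h][w] -> [n][h*w][c]: move C last, then merge H and W row-major."""
    out = []
    for sample in x:
        h, w = len(sample[0]), len(sample[0][0])
        out.append([[sample[c][i][j] for c in range(len(sample))] for i in range(h) for j in range(w)])
    return out


def nlc_to_nchw(x, h, w):
    """[n][h*w][c] -> [n][c][h][w]: move C back to axis 1, then split H*W into H and W."""
    out = []
    for sample in x:
        ch = len(sample[0])
        out.append([[[sample[i * w + j][c] for j in range(w)] for i in range(h)] for c in range(ch)])
    return out


x = [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]  # n = 1, c = 2, h = 2, w = 2
tokens = nchw_to_nlc(x)
print(tokens, nlc_to_nchw(tokens, 2, 2) == x)
```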

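spatial_transformer wraps basic_transformer with a GroupNorm and two 1×1 convolution projections (proj_in and proj_out), and adds the block input back as a residual.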

def spatial_transformer(network, weight_map, i, ch, h, context):
    # return h
    # norm
    n = group_norm(network, weight_map, h, '{}.norm'.format(i), 1e-6)
    # proj_in
    n = conv(network, weight_map, n, ch, '{}.proj_in'.format(i), 1, 0, 1)

    # BasicTransformerBlock
    n = basic_transformer(network, weight_map, i, ch, n, context)

    # proj_out
    n = conv(network, weight_map, n, ch, '{}.proj_out'.format(i), 1, 0, 1)

    h = network.add_elementwise(n.get_output(0), h.get_output(0), trt.ElementWiseOperation.SUM)
    return h

Sampling

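Downsampling is a stride-2 3×3 convolution; upsampling is a 2× nearest-neighbor resize followed by a 3×3 convolution; zero_convs is a 1×1 convolution that keeps the feature-map size and channel count unchanged.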
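The resize in upsample (scales [1, 1, 2, 2], nearest) turns every value of a feature map into a 2×2 block, assuming the default coordinate mapping. A pure-Python sketch for a single [h][w] map; `upsample_nearest_2x` is a hypothetical name.

```python
def upsample_nearest_2x(img):
    """2x nearest-neighbour resize of one [h][w] feature map."""
    out = []
    for row in img:
        wide = [v for v in row for _ in range(2)]  # repeat each column
        out.append(wide)
        out.append(list(wide))                     # repeat each row
    return out


print(upsample_nearest_2x([[1, 2], [3, 4]]))
```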

def input_first(network, weight_map, pre, h):
    h = conv(network, weight_map, h, 320, '{}.input_blocks.0.0'.format(pre), 3, 1, 1)
    return h

def downsample(network, weight_map, i, ch, x):
    x = conv(network, weight_map, x, ch, '{}.op'.format(i), 3, 1, 2)
    return x

def upsample(network, weight_map, i, ch, x):
    x = network.add_resize(x.get_output(0))
    x.scales = [1, 1, 2, 2]
    x.resize_mode = trt.ResizeMode.NEAREST

    x = conv(network, weight_map, x, ch, '{}.conv'.format(i), 3, 1, 1)

    return x

def zero_convs(network, weight_map, x, i):
    ch = x.get_output(0).shape[1]
    x = conv(network, weight_map, x, ch, 'control_model.zero_convs.{}.0'.format(i), 1, 0, 1)
    return x

Blocks

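resblock is a residual block: GroupNorm + SiLU followed by a 3×3 convolution, then the timestep embedding is added, then another GroupNorm + SiLU and 3×3 convolution, and finally the input is added back (through a 1×1 skip convolution when the channel count changes). Both SiLU activations are fused into the GroupNorm plugin via bSwish. The timestep embedding is not computed inside the network: a [20, ch, 1, 1] constant holds precomputed embeddings, presumably one per sampling step, and add_gather picks the entry selected by emb.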

def resblock(network, weight_map, embed_weight, i, ch, h, emb):
    print('resblock: ', h.get_output(0).shape, '{}.in_layers.0'.format(i))
    ## in_layers
    # group_norm
    n = group_norm(network, weight_map, h, '{}.in_layers.0'.format(i), silu=True)
    # silu
    # n = silu(network, n)
    # conv_nd
    n = conv(network, weight_map, n, ch, '{}.in_layers.2'.format(i), 3, 1, 1)

    print('in_layers: ', n.get_output(0).shape)

    ## emb_layers
    m = network.add_constant([20, ch, 1, 1], embed_weight.pop(0))
    m = network.add_gather(m.get_output(0), emb, axis=0)
    print('emb_layers: ', m.get_output(0).shape)

    n = network.add_elementwise(n.get_output(0), m.get_output(0), trt.ElementWiseOperation.SUM)

    ## out_layers
    n = group_norm(network, weight_map, n, '{}.out_layers.0'.format(i), silu=True)
    # n = silu(network, n)
    n = conv(network, weight_map, n, ch, '{}.out_layers.3'.format(i), 3, 1, 1)

    print('out_layers: ', n.get_output(0).shape)

    in_ch = h.get_output(0).shape[1]
    if in_ch != ch:
        # skip_connection
        h = conv(network, weight_map, h, ch, '{}.skip_connection'.format(i), 1, 0, 1)

    h = network.add_elementwise(n.get_output(0), h.get_output(0), trt.ElementWiseOperation.SUM)
    return h

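input_block is made up of resblocks and spatial_transformers at different levels with different channel counts.
middle_block is a combination of resblock and spatial_transformer.
output_blocks is similar to input_block, except that the downsampling in input_block becomes upsampling in output_blocks; each output block first concatenates the matching skip feature (with the ControlNet feature added to it).
These three blocks are the core of the UNet and correspond to its process of first downsampling into feature space and then upsampling back to image resolution.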
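The skip-connection bookkeeping can be replayed without TensorRT: input_block pushes 12 features onto hs, and output_blocks pops exactly 12 in reverse order (3 per level). ControlNet's input_block_control produces the same 12 features (each through a zero conv), plus one from middle_block_out, which is why unet pops one control feature for the middle block before output_blocks. The sketch mirrors the loop structure of the code; the function names are hypothetical.

```python
def unet_skip_channels(model_channels=320, channel_mult=(1, 2, 4, 4), num_res_blocks=2):
    """Channel count of each feature pushed onto hs, in input_block's order."""
    hs = [model_channels]  # output of input_blocks.0 (the first conv)
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        for _ in range(num_res_blocks):
            hs.append(ch)  # resblock (+ spatial_transformer)
        if level != len(channel_mult) - 1:
            hs.append(ch)  # downsample keeps the channel count
    return hs


def output_block_plan(hs, model_channels=320, channel_mult=(1, 2, 4, 4), num_res_blocks=2):
    """(block channels, popped skip channels) for every output block, in output_blocks' order."""
    hs = list(hs)
    plan = []
    for level, mult in list(enumerate(channel_mult))[::-1]:
        for _ in range(num_res_blocks + 1):
            plan.append((model_channels * mult, hs.pop()))
    return plan


skips = unet_skip_channels()
print(len(skips), output_block_plan(skips))
```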

def input_block(network, weight_map, embed_weight, h, emb, context, model_name):
    hs = []
    h = input_first(network, weight_map, model_name, h)
    h = network.add_slice(h.get_output(0), trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
    h.mode = trt.SliceMode.WRAP

    #return h
    hs.append(h)

    channel_mult = [1, 2, 4, 4]
    num_res_blocks = [2] * 4

    model_channels = 320
    index = 1
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        for nr in range(num_res_blocks[level]):
            pre = '{}.input_blocks.{}'.format(model_name, index)
            h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
            print('resblock: ', h.get_output(0).shape)
            if level != len(channel_mult) -1:
                h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
            hs.append(h)

            # ch = mult * model_channels
            index = index + 1

        if level != len(channel_mult) - 1:
            pre = '{}.input_blocks.{}'.format(model_name, index)
            out_ch = ch
            h = downsample(network, weight_map, '{}.0'.format(pre), out_ch, h)
            hs.append(h)
            index = index + 1
        
        # if index == 10:
    return hs, h

def middle_block(network, weight_map, embed_weight, h, emb, context, model_name):
    pre = '{}.middle_block'.format(model_name)
    h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), 1280, h, emb)
    h = spatial_transformer(network, weight_map, '{}.1'.format(pre), 1280, h, context)
    h = resblock(network, weight_map, embed_weight, '{}.2'.format(pre), 1280, h, emb)
    return h

def output_blocks(network, weight_map, embed_weight, h, emb, context, control, hs):
    channel_mult = [1, 2, 4, 4]
    num_res_blocks = [2] * 4

    model_channels = 320
    index = 0
    for level, mult in list(enumerate(channel_mult))[::-1]:
        ch = model_channels * mult
        for i in range(num_res_blocks[level] + 1):
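            print(control[-1].get_output(0).shape, hs[-1].get_output(0).shape, len(hs), h.get_output(0).shape)
            # control and hs hold ILayers, so take their output tensors before adding
            c = network.add_elementwise(control.pop().get_output(0), hs.pop().get_output(0), trt.ElementWiseOperation.SUM)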
            h = network.add_concatenation([h.get_output(0), c.get_output(0)])
            print('output: ', index, h.get_output(0).shape)
            pre = 'model.diffusion_model.output_blocks.{}'.format(index)
            h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
            print('resblock: ', h.get_output(0).shape)
            if level != len(channel_mult) -1:
                h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
            
            if level and i == num_res_blocks[level]:
                h = upsample(network, weight_map,
                             '{}.{}'.format(pre, 1 if level == len(channel_mult) - 1 else 2), ch, h)
            index = index + 1
    print(h.get_output(0).shape, len(hs), len(control), index)
    return h

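input_block_control is the first half (the encoder) of ControlNet. Structurally it has the same parameters as the UNet's input blocks, but it adds the hint to the output of the first convolution and appends a zero_convs layer with learned parameters after every block.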

def input_block_control(network, weight_map, embed_weight, h, emb, context, hint):
    hs = []
    h = input_first(network, weight_map, 'control_model', h)
    h = network.add_elementwise(h.get_output(0), hint, trt.ElementWiseOperation.SUM)

    h = network.add_slice(h.get_output(0), trt.Dims([0, 0, 0, 0]), trt.Dims([2, 320, 32, 48]), trt.Dims([1, 1, 1, 1]))
    h.mode = trt.SliceMode.WRAP
    hs.append(zero_convs(network, weight_map, h, 0))
    # h [2, 320, 32, 48]

    channel_mult = [1, 2, 4, 4]
    num_res_blocks = [2] * 4

    model_channels = 320
    index = 1
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        for nr in range(num_res_blocks[level]):
            pre = 'control_model.input_blocks.{}'.format(index)
            h = resblock(network, weight_map, embed_weight, '{}.0'.format(pre), ch, h, emb)
            print('resblock: ', h.get_output(0).shape)
            if level != len(channel_mult) -1:
                h = spatial_transformer(network, weight_map, '{}.1'.format(pre), ch, h, context)
            hs.append(zero_convs(network, weight_map, h, index))

            # ch = mult * model_channels
            index = index + 1

        if level != len(channel_mult) - 1:
            pre = 'control_model.input_blocks.{}'.format(index)
            out_ch = ch
            h = downsample(network, weight_map, '{}.0'.format(pre), out_ch, h)
            hs.append(zero_convs(network, weight_map, h, index))
            index = index + 1
        
        # if index == 10:
    return hs, h

Network assembly

controlnet

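Here h, hint, emb and context go through input_block_control, which returns the control features (one per input block, each after its zero convolution) and h. h then passes through middle_block and the 1×1 middle_block_out convolution, whose output is appended as the last control feature, so control holds features at multiple scales.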

def control_net(network, weight_map, embed_weight, h, hint, emb, context):
    # #####################
    # # time_embed
    # #####################

    #####################
    # input_blocks
    #####################
    control, h = input_block_control(network, weight_map, embed_weight, h, emb, context, hint)
    print(h.get_output(0).shape)

    #####################
    # middle_blocks
    #####################   
    h = middle_block(network, weight_map, embed_weight, h, emb, context, 'control_model')
    h = conv(network, weight_map, h, 1280, 'control_model.middle_block_out.0', 1, 0, 1)

    control.append(h)
    return control

Unet

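The UNet is relatively simple: the input passes through input_block, middle_block and output_blocks, with the control features added to the middle-block output and to every skip connection, and a final GroupNorm + SiLU and 3×3 convolution produce the 4-channel output.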

def unet(network, weight_map, embed_weight, h, emb, context, control):
    # #####################
    # # time_embed
    # #####################


    #####################
    # input_blocks
    #####################
    hs, h = input_block(network, weight_map, embed_weight, h, emb, context, 'model.diffusion_model')
    print(h.get_output(0).shape)

    #####################
    # middle_blocks
    #####################   
    h = middle_block(network, weight_map, embed_weight, h, emb, context, 'model.diffusion_model')
    print(h.get_output(0).shape)

    h = network.add_elementwise(h.get_output(0), control.pop().get_output(0), trt.ElementWiseOperation.SUM)

    #####################
    # output_blocks
    #####################
    h = output_blocks(network, weight_map, embed_weight, h, emb, context, control, hs)

    # out
    # group_norm
    # h = group_norm_sile(network, weight_map, h)
    h = group_norm(network, weight_map, h, 'model.diffusion_model.out.0', silu=True)
    # silu
    # h = silu(network, h)
    # conv_nd
    h = conv(network, weight_map, h, 4, 'model.diffusion_model.out.2', 3, 1, 1)

    return h

References

  1. nvidia python api: https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/
  2. xiatwhu: https://github.com/deeplearning/xiatwhu/trt2023