YOLO Int8量化模块实现

发布时间 2023-06-21 18:44:10作者: 李白的白
什么是量化?
  • 量化是一种将浮点数转换为整数的方法,可以减少计算量和存储空间,提高模型的运行效率和部署能力。
  • 量化的过程可以表示为:

\[Q(x) = round(\frac{x}{s}) + z \]

  • 其中,\(x\)是浮点数,\(s\)是缩放因子(scale factor),\(z\)是零点(zero point),\(Q(x)\)是量化后的整数。
  • 缩放因子\(s\)和零点\(z\)可以根据不同的量化方法和范围来确定,例如对称量化(symmetric quantization)或非对称量化(asymmetric quantization),无符号整数(unsigned integer)或有符号整数(signed integer)等。
为什么要对YOLO进行Int8量化?
  • YOLO是一种流行的目标检测算法,它可以快速地在图像中定位和识别多个目标。
  • YOLO的模型通常使用浮点数来表示参数和特征,这样可以保证模型的精度和表达能力,但也带来了较高的计算量和存储空间的需求。
  • 为了在移动设备或嵌入式系统上部署YOLO模型,需要对模型进行压缩和优化,以适应有限的资源和性能要求。
  • Int8量化是一种常用的压缩和优化方法,它可以将浮点数转换为8位整数,从而减少模型的大小和运行时间,同时尽量保持模型的精度和效果。
如何对YOLO进行Int8量化?
  • 对YOLO进行Int8量化的主要步骤如下:

    • 确定量化方法和范围,例如使用非对称量化和无符号整数。
    • 计算每一层的缩放因子\(s\)和零点\(z\),根据输入数据和权重的分布和范围来确定。
    • 对每一层的输入数据和权重进行量化,即将浮点数转换为整数,根据公式\(Q(x) = round(\frac{x}{s}) + z\)来计算。
    • 对每一层的输出数据进行反量化,即将整数转换为浮点数,根据公式\(x = s(Q(x) - z)\)来计算。
    • 对每一层的卷积操作进行优化,使用整数乘法和移位代替浮点乘法和除法,减少计算量和提高精度。
    • 对每一层的激活函数进行调整,使用整数表示激活输出的范围和零点,避免数据溢出或损失。
  • 下面以第0层卷积层为例,详细说明如何对YOLO进行Int8量化。

第0层卷积层的Int8量化
  • 第0层卷积层的输入数据和权重的缩放因子\(s\)和零点\(z\)分别为:

    • 输入数据:\(s_1 = 0.00784314, z_1 = 127\)
    • 权重:\(s_2 = 0.00248109, z_2 = 0\)
  • 第0层卷积层的输入数据和权重的量化公式分别为:

    • 输入数据:\(Q_1(x) = round(\frac{x}{s_1}) + z_1\)
    • 权重:\(Q_2(x) = round(\frac{x}{s_2}) + z_2\)
  • 第0层卷积层的输出数据的缩放因子\(s\)和零点\(z\)分别为:

    • 输出数据:\(s_3 = s_1 \times s_2, z_3 = 86\)
  • 第0层卷积层的输出数据的反量化公式为:

    • 输出数据:\(x = s_3(Q_3(x) - z_3)\)
  • 第0层卷积层的卷积操作的优化方法为:

    • 使用整数乘法和移位代替浮点乘法和除法,即:

      \[Q_3(x) = round(\frac{Q_1(x) \times Q_2(x)}{s_3}) + z_3 \]

      等价于:

      \[Q_3(x) = \frac{(Q_1(x) \times Q_2(x)) \times M}{2^{15}} >> S + z_3 \]

    • 其中,\(M\)是一个整数,用来调整缩放因子\(s_3\)的大小,使其接近\(2^{15}\),以提高精度和避免溢出,例如:

      \[M = round(\frac{2^{15}}{s_3}) \]

    • \(S\)是一个整数,用来表示移位的次数,根据缩放因子\(s_3\)的大小来确定,例如:

      \[S = -log_2(s_3) \]

    • 第0层卷积层的\(M\)\(S\)分别为:

      • \(M = 19290\)
      • \(S = -4\)
  • 第0层卷积层的激活函数的调整方法为:

    • 使用整数表示激活输出的范围和零点,例如使用ReLU6作为激活函数,那么激活输出的范围为\([0, 6]\),零点为\(0\),则可以用整数表示为:

      \[Q_a(x) = min(max(Q_3(x), 0), \frac{6}{s_a}) + z_a \]

    • 其中,\(s_a\)是激活输出的缩放因子,可以根据激活输出的范围来确定,例如:

      \[s_a = \frac{6}{255} \]

    • \(z_a\)是激活输出的零点,可以根据激活函数的类型来确定,例如对于ReLU6来说,零点为\(0\)

  • 第0层卷积层的激活函数没有进行调整,因为它还没有用到。但是在后面的层中,它会用到激活输出的缩放因子和零点。

YOLO Int8量化模块的实现过程如下:

  • 首先,定义一个quant_int8模块,用于将24位有符号整数数据输入转换为8位有符号整数数据输出。该模块的输入参数包括mult(乘法因子),shift(右移位数),zero_point(零点偏移量)。该模块的内部逻辑如下:
    • 使用mult_gen_0模块对data_in和mult进行乘法运算,得到39位有符号整数结果mult_rslt。
    • 使用always语句在时钟上升沿或复位信号下降沿对mult_rslt进行右移运算,并加上zero_point,得到8位有符号整数结果data_out。
  • 然后,定义一个quant_int8_8ch模块,用于将8个通道的24位有符号整数数据输入转换为8个通道的8位有符号整数数据输出。该模块的输入参数与quant_int8模块相同,但是对每个通道都使用了一个quant_int8模块实例来进行转换。该模块还使用了一个shift_reg模块来延迟data_in_vld信号,使其与data_out_vld信号对齐。

YOLO Int8量化模块的时序分析如下:

  • quant_int8模块的时钟周期为1ns,因此其乘法运算需要1ns完成,右移运算需要1ns完成,加法运算需要1ns完成。因此,quant_int8模块的总延迟为3ns。
  • quant_int8_8ch模块的时钟周期也为1ns,因此其对每个通道的转换需要3ns完成。由于各个通道是并行处理的,因此quant_int8_8ch模块的总延迟也为3ns。
  • shift_reg模块的时钟周期为1ns,因此其对data_in_vld信号的延迟为5ns。
  • 因此,YOLO Int8量化模块的总延迟为3ns + 5ns = 8ns。

YOLO Int8量化模块可能遇到的问题和解决方法如下:

  • 问题一:量化过程可能导致精度损失和信息丢失,影响YOLO网络的性能和准确度。
    • 解决方法一:选择合适的量化参数(mult,shift,zero_point),使得量化后的数据分布尽可能接近原始数据分布。
    • 解决方法二:在训练YOLO网络时,使用量化感知训练(Quantization-aware training)方法,使得网络能够适应量化后的数据,并减少精度损失。
  • 问题二:量化过程可能导致溢出或欠流现象,使得数据超出范围或变为零。
    • 解决方法一:在量化前,对数据进行范围检测和裁剪,使得数据不超过8位有符号整数的最大值或最小值。
    • 解决方法二:在量化后,对数据进行饱和运算(Saturation arithmetic),使得数据不超过8位有符号整数的最大值或最小值。

代码清单:

quant_int8

module  quant_int8(
        // system signals
        input                   sclk                    ,       
        input                   s_rst_n                 ,       
        //
        input   signed  [23:0]  data_in                 ,
        //
        input           [14:0]  mult                    ,       
        input           [ 7:0]  shift                   ,       
        input           [ 7:0]  zero_point              ,       
        //
        output  reg     [ 7:0]  data_out                       
);

//========================================================================\
// =========== Define Parameter and Internal signals =========== 
//========================================================================/

wire    signed  [38:0]          mult_rslt                       ;

reg     [23:0]                  shift_rslt                      ;


//=============================================================================
//**************    Main Code   **************
//=============================================================================
mult_gen_0      mult_gen_0_inst (
        .CLK                    (sclk                   ),      // input wire CLK
        .A                      (data_in                ),      // input wire [23 : 0] A
        .B                      (mult                   ),      // input wire [14 : 0] B
        .P                      (mult_rslt              )       // output wire [38 : 0] P
);


always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0) begin
                shift_rslt      <=      'd0;
                data_out        <=      'd0;
        end
        else begin
                shift_rslt      <=      mult_rslt[38:15] >> shift;
                data_out        <=      shift_rslt + zero_point;
        end
end




endmodule

quant_int8_8ch

module  quant_int8_8ch(
        // system signals
        input                   sclk                    ,       
        input                   s_rst_n                 ,       
        //
        input   signed  [23:0]  ch0_data_in             ,       
        input   signed  [23:0]  ch1_data_in             ,       
        input   signed  [23:0]  ch2_data_in             ,       
        input   signed  [23:0]  ch3_data_in             ,       
        input   signed  [23:0]  ch4_data_in             ,       
        input   signed  [23:0]  ch5_data_in             ,       
        input   signed  [23:0]  ch6_data_in             ,       
        input   signed  [23:0]  ch7_data_in             ,       
        input                   data_in_vld             ,       
        //
        input           [14:0]  mult                    ,       
        input           [ 7:0]  shift                   ,       
        input           [ 7:0]  zero_point              ,       
        //
        output  wire    [ 7:0]  ch0_data_out            ,       
        output  wire    [ 7:0]  ch1_data_out            ,       
        output  wire    [ 7:0]  ch2_data_out            ,       
        output  wire    [ 7:0]  ch3_data_out            ,       
        output  wire    [ 7:0]  ch4_data_out            ,       
        output  wire    [ 7:0]  ch5_data_out            ,       
        output  wire    [ 7:0]  ch6_data_out            ,       
        output  wire    [ 7:0]  ch7_data_out            ,       
        output  wire            data_out_vld                   
);

//========================================================================\
// =========== Define Parameter and Internal signals =========== 
//========================================================================/



//=============================================================================
//**************    Main Code   **************
//=============================================================================
shift_reg #(
        .DLY_CNT                (5                      )
)shift_reg_inst(
        .sclk                   (sclk                   ),
        .data_in                (data_in_vld            ),
        .data_out               (data_out_vld           )
);


quant_int8      ch0_quant_int8_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .data_in                (ch0_data_in            ),
        //
        .mult                   (mult                   ),
        .shift                  (shift                  ),
        .zero_point             (zero_point             ),
        //
        .data_out               (ch0_data_out           )
);

quant_int8      ch1_quant_int8_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .data_in                (ch1_data_in            ),
        //
        .mult                   (mult                   ),
        .shift                  (shift                  ),
        .zero_point             (zero_point             ),
        //
        .data_out               (ch1_data_out           )
);

quant_int8      ch2_quant_int8_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .data_in                (ch2_data_in            ),
        //
        .mult                   (mult                   ),
        .shift                  (shift                  ),
        .zero_point             (zero_point             ),
        //
        .data_out               (ch2_data_out           )
);

quant_int8      ch3_quant_int8_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .data_in                (ch3_data_in            ),
        //
        .mult                   (mult                   ),
        .shift                  (shift                  ),
        .zero_point             (zero_point             ),
        //
        .data_out               (ch3_data_out           )
);

quant_int8      ch4_quant_int8_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .data_in                (ch4_data_in            ),
        //
        .mult                   (mult                   ),
        .shift                  (shift                  ),
        .zero_point             (zero_point             ),
        //
        .data_out               (ch4_data_out           )
);

quant_int8      ch5_quant_int8_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .data_in                (ch5_data_in            ),
        //
        .mult                   (mult                   ),
        .shift                  (shift                  ),
        .zero_point             (zero_point             ),
        //
        .data_out               (ch5_data_out           )
);

quant_int8      ch6_quant_int8_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .data_in                (ch6_data_in            ),
        //
        .mult                   (mult                   ),
        .shift                  (shift                  ),
        .zero_point             (zero_point             ),
        //
        .data_out               (ch6_data_out           )
);

quant_int8      ch7_quant_int8_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .data_in                (ch7_data_in            ),
        //
        .mult                   (mult                   ),
        .shift                  (shift                  ),
        .zero_point             (zero_point             ),
        //
        .data_out               (ch7_data_out           )
);

endmodule