Max_Pool模块实现

发布时间 2023-07-04 11:43:02作者: 李白的白

Max_Pool模块实现

  • - Max_Pool模块是一种**池化(pooling)**操作,用于对输入的特征图(feature map)进行降采样(downsampling),从而减少参数数量,提高计算效率,防止过拟合。
    - Max_Pool模块的原理是在输入的特征图上滑动一个固定大小的窗口(kernel),每次取窗口内的最大值作为输出的一个元素。窗口每次移动的距离称为步长(stride)。
    - Max_Pool模块可以有不同的参数设置,如窗口大小、步长、填充(padding)等,影响输出的特征图的形状。
    - Max_Pool模块通常接在卷积层(convolutional layer)后面,用于提取输入特征图中的最显著的特征,忽略一些细节信息。
    
  • 最大值池化:一种降低图像数据维度的方法,可以提高计算效率和抽象能力

    • 原理:在一个固定大小的窗口内,选取窗口内的最大值作为输出
    • 参数:窗口大小为(2x2)、步长(2)、无填充
    • 示例:
      • 输入特征图:$$\begin{bmatrix}1 & 2 & 3 & 4 \ 5 & 6 & 7 & 8 \ 9 & 10 & 11 & 12 \ 13 & 14 & 15 & 16 \end{bmatrix}$$
      • 输出特征图:$$\begin{bmatrix}6 & 8 \ 14 & 16 \end{bmatrix}$$
    • max_pool_example.png
  • 常规方案:先构造一个2*2的矩阵窗口,再分别比较窗口内的四个值,最后得到最大值

      • 需要调用FIFO,存储图像数据,同步输出同列数据
      • 需要定义两个寄存器,构造2*2的矩阵
      • FIFO的深度至少是图像一行的数据量,例如416
    • 优点:逻辑简单,易于理解

    • 缺点:需要调用FIFO存储图像数据,FIFO深度至少为一行图像数据量,占用较多存储资源!image

  • 优化方案:先比较相邻两个数据的最大值,再将结果写入FIFO,然后再比较FIFO读出的数据和当前数据的最大值,得到最终结果

    先比较相邻两个数据的最大值,再比较两个最大值的最大值,不需要构造矩阵窗口

    • 优点:存储资源占用较少,FIFO深度为图像一行数据量的一半,不需要额外的寄存器
    • 缺点:时序要求更高,需要在每个数据到来时进行比较和写入操作
    • 在图像数据过来的时候,先在内部定义一个计算器,对数据做一个打拍
    • 先比较一行里面相邻的两个数据的最大值,然后把结果写到FIFO里面
    • 当第二行数据过来的时候,再比较一行里面相邻的两个数据的最大值,然后从FIFO里面读出第一行的最大值,和当前的最大值再做一次比较,得到最终的池化结果
    • 这样可以节省FIFO的深度一半,只需要图像一行数据量的一半,例如第一层是208
    • 时序上面要求稍微高一点,需要在每个数据到来之后做比较和写入或读出操作
    • 示例
    • 假设输入图像数据如下:
    00 01 02 03
    10 11 12 13
    20 21 22 23
    30 31 32 33
    • 那么输出图像数据如下:
    11 13
    31 33
    • 具体过程如下:
      • 当第一行数据00,01,02,03进入模块时,先用一个寄存器R1对数据进行缓存和打拍(Shift Register)(即延迟一个时钟周期)为00,01,02,03
      • 比较器(Comparator)比较相邻两个数据00和01,得到最大值01,并通过fifo_wr_data写入FIFO
      • 比较器比较相邻两个数据02和03,得到最大值03,并通过fifo_wr_data写入FIFO
      • 此时FIFO中有两个数据01和03,深度为一行的一半(208)
      • 当第二行数据10,11,12,13进入模块时,R1同样进行缓存、打拍和比较为10,11,12,13
      • 比较器比较相邻两个数据10和11,得到最大值11,并暂存为Max_Data
      • FIFO读出第一行的第一个最大值01,并和Max_Data比较,得到更大的值\(11\),并输出为Out_Datapool_data
      • 比较器比较相邻两个数据12和13,得到最大值13,并暂存为Max_Data
      • FIFO读出第一行的第二个最大值03,并和Max_Data比较,得到更大的值\(13\),并输出为Out_Datapool_data
      • 此时输出的第一行数据为\(11\)\(13\)
      • 以此类推,可以得到输出的第二行数据为\(31\)\(33\)
    • 时序图:

    image

    • 最大值池化的代码实现

      • 定义输入输出端口和内部信号
      module max_pool(
          input clk,
          input rst_n,
          input data_in_valid,
          input [7:0] data_in,
          output reg data_out_valid,
          output reg [7:0] data_out
      );
      
      reg [7:0] data_in_r1;
      reg [7:0] fifo_write_data;
      reg fifo_write_en;
      reg [7:0] max_data;
      reg [7:0] fifo_read_data;
      reg fifo_read_en;
      reg row_even_flag;
      reg data_in_value_r1;
      
      • 对输入数据做打拍和比较,得到fifo_write_data和fifo_write_en
      always @(posedge clk or negedge rst_n) begin
          if (~rst_n) begin
              data_in_r1 <= 8'h00;
          end else begin
              if (data_in_valid) begin
                  data_in_r1 <= data_in;
              end else begin
                  data_in_r1 <= 8'h00;
              end 
          end 
      end 
      
      always @(*) begin
          fifo_write_data = 8'h00;
          if (data_in_r1 >= data_in) begin
              fifo_write_data = data_in_r1;
          end else begin
              fifo_write_data = data_in;
          end 
      end 
      
      assign fifo_write_en = ~row_even_flag & data_in_value_r1 & data_in_valid;
      
      • 调用FIFO模块,存储每行相邻两个数据的最大值,并读出第一行的最大值
      fifo_256x8 fifo_inst(
          .clk(clk),
          .rst_n(rst_n),
          .din(fifo_write_data),
          .wr_en(fifo_write_en),
          .rd_en(fifo_read_en),
          .dout(fifo_read_data)
      );
      
      • 比较第一行和第二行的最大值,得到data_out和data_out_valid
      always @(*) begin
          data_out = 8'h00;
          if (max_data >= fifo_read_data) begin
              data_out = max_data;
          end else begin
              data_out = fifo_read_data;
          end 
      end 
      
      assign data_out_valid = fifo_read_en;
      
      • 定义奇偶数行的标志信号和上升沿和下降沿的标志信号,控制时序逻辑
      always @(posedge clk or negedge rst_n) begin
          if (~rst_n) begin
              row_even_flag <= 1'b0;
          end else begin
              if (data_in_value_r1 & ~data_in_valid) begin
                  row_even_flag <= ~row_even_flag;
              end 
          end 
      end 
      
      always @(posedge clk or negedge rst_n) begin
          if (~rst_n) begin
              data_in_value_r1 <= 1'b0;
          end else begin
              if (data_in_valid) begin
                  data_in_value_r1 <= data_in_valid;
              end else begin
                  data_in_value_r1 <= 1'b0;
              end 
          end 
      end 
      

    max_pool_ch0

    单通道最大值池化模块主要代码:

    assign  max_data        =       fifo_wr_data;
    
    always  @(posedge sclk) begin
            data_in_r1      <=      data_in;
    end
    
    always  @(posedge sclk or negedge s_rst_n) begin
            if(s_rst_n == 1'b0)
                    fifo_wr_data    <=      'd0;
            else if(data_in_r1 >= data_in)
                    fifo_wr_data    <=      data_in_r1;
            else
                    fifo_wr_data    <=      data_in;
    end
    
    
    always  @(posedge sclk or negedge s_rst_n) begin
            if(s_rst_n == 1'b0)
                    data_out        <=      'd0;
            else if(max_data >= fifo_rd_data)
                    data_out        <=      max_data;
            else
                    data_out        <=      fifo_rd_data;
    end
    
    pool_fifo_ip    pool_fifo_ip (
            .clk                    (sclk                   ),      // input wire clk
            .srst                   (~s_rst_n               ),      // input wire srst
            .din                    (fifo_wr_data           ),      // input wire [7 : 0] din
            .wr_en                  (fifo_wr_en             ),      // input wire wr_en
            .rd_en                  (fifo_rd_en             ),      // input wire rd_en
            .dout                   (fifo_rd_data           ),      // output wire [7 : 0] dout
            .full                   (                       ),      // output wire full
            .empty                  (                       ),      // output wire empty
            .data_count             (                       )       // output wire [8 : 0] data_count
    );
    

    max_pool_8ch

    8通道最大值池化模块主要代码:

    • 将单通道模块复制8份(max_pool_ch0 ~ max_pool_ch7)
    • 将输入输出数据合并为8位宽(data_in, data_out)
    • 将FIFO读写使能信号统一在外部定义(fifo_wr_en, fifo_rd_en)
    always  @(posedge sclk) begin
            data_in_vld_r1  <=      data_in_vld;
            data_out_vld    <=      fifo_rd_en;
    end
    
    always  @(posedge sclk or negedge s_rst_n) begin
            if(s_rst_n == 1'b0)
                    col_even_flag   <=      1'b0;
            else if(data_in_vld == 1'b1)
                    col_even_flag   <=      ~col_even_flag;
            else
                    col_even_flag   <=      1'b0;
    end
    
    always  @(posedge sclk or negedge s_rst_n) begin
            if(s_rst_n == 1'b0)
                    fifo_wr_en      <=      1'b0;
            else if(row_even_flag == 1'b1 && col_even_flag == 1'b1)
                    fifo_wr_en      <=      1'b1;
            else
                    fifo_wr_en      <=      1'b0;
    end
    
    always  @(posedge sclk or negedge s_rst_n) begin
            if(s_rst_n == 1'b0)
                    fifo_rd_en      <=      1'b0;
            else if(row_even_flag == 1'b0 && col_even_flag == 1'b1)
                    fifo_rd_en      <=      1'b1;
            else
                    fifo_rd_en      <=      1'b0;
    end
    
    always  @(posedge sclk or negedge s_rst_n) begin
            if(s_rst_n == 1'b0)
                    row_even_flag   <=      1'b0;
            else if(data_in_vld == 1'b1 && data_in_vld_r1 == 1'b0)
                    row_even_flag   <=      ~row_even_flag;
    end
    
    max_pool        ch0_max_pool_inst(
            // system signals
            .sclk                   (sclk                   ),
            .s_rst_n                (s_rst_n                ),
            // 
            .data_in                (ch0_data_in            ),
            .data_out               (ch0_data_out           ),
            .fifo_wr_en             (fifo_wr_en             ),
            .fifo_rd_en             (fifo_rd_en             )
    );
    
    max_pool        ch1_max_pool_inst(
            // system signals
            .sclk                   (sclk                   ),
            .s_rst_n                (s_rst_n                ),
            // 
            .data_in                (ch1_data_in            ),
            .data_out               (ch1_data_out           ),
            .fifo_wr_en             (fifo_wr_en             ),
            .fifo_rd_en             (fifo_rd_en             )
    );
    
    max_pool        ch2_max_pool_inst(
            // system signals
            .sclk                   (sclk                   ),
            .s_rst_n                (s_rst_n                ),
            // 
            .data_in                (ch2_data_in            ),
            .data_out               (ch2_data_out           ),
            .fifo_wr_en             (fifo_wr_en             ),
            .fifo_rd_en             (fifo_rd_en             )
    );
    
    max_pool        ch3_max_pool_inst(
            // system signals
            .sclk                   (sclk                   ),
            .s_rst_n                (s_rst_n                ),
            // 
            .data_in                (ch3_data_in            ),
            .data_out               (ch3_data_out           ),
            .fifo_wr_en             (fifo_wr_en             ),
            .fifo_rd_en             (fifo_rd_en             )
    );
    
    max_pool        ch4_max_pool_inst(
            // system signals
            .sclk                   (sclk                   ),
            .s_rst_n                (s_rst_n                ),
            // 
            .data_in                (ch4_data_in            ),
            .data_out               (ch4_data_out           ),
            .fifo_wr_en             (fifo_wr_en             ),
            .fifo_rd_en             (fifo_rd_en             )
    );
    
    max_pool        ch5_max_pool_inst(
            // system signals
            .sclk                   (sclk                   ),
            .s_rst_n                (s_rst_n                ),
            // 
            .data_in                (ch5_data_in            ),
            .data_out               (ch5_data_out           ),
            .fifo_wr_en             (fifo_wr_en             ),
            .fifo_rd_en             (fifo_rd_en             )
    );
    
    max_pool        ch6_max_pool_inst(
            // system signals
            .sclk                   (sclk                   ),
            .s_rst_n                (s_rst_n                ),
            // 
            .data_in                (ch6_data_in            ),
            .data_out               (ch6_data_out           ),
            .fifo_wr_en             (fifo_wr_en             ),
            .fifo_rd_en             (fifo_rd_en             )
    );
    
    max_pool        ch7_max_pool_inst(
            // system signals
            .sclk                   (sclk                   ),
            .s_rst_n                (s_rst_n                ),
            // 
            .data_in                (ch7_data_in            ),
            .data_out               (ch7_data_out           ),
            .fifo_wr_en             (fifo_wr_en             ),
            .fifo_rd_en             (fifo_rd_en             )
    );