Max_Pool模块完善

发布时间 2023-08-12 21:48:13作者: 李白的白

Max_Pool模块完善

什么是最大池化层(Max Pooling Layer)?

- 最大池化层是一种常用的池化层(Pooling Layer),它的作用是对输入的特征图(Feature Map)进行降维压缩,以加快运算速度,减少参数数量,防止过拟合,提高模型的尺度不变性和旋转不变性 。
- 最大池化层的原理是:在前向传播过程中,对每个特征图的区域(通常是2x2或3x3的窗口),选择其中的最大值作为该区域池化后的值;在反向传播过程中,梯度只通过前向传播时的最大值反向传播,其他位置的梯度为0 。
- 最大池化层可以分为重叠池化(Overlapping Pooling)和非重叠池化(Non-overlapping Pooling),区别在于窗口移动的步长(Stride)是否等于窗口大小 。
  - 重叠池化的步长小于窗口大小,例如AlexNet/GoogLeNet系列中采用的3x3窗口,步长为2的重叠池化。
  - 非重叠池化的步长等于窗口大小,例如VGG系列中采用的2x2窗口,步长为2的非重叠池化。
- 最大池化层的优点在于它能学习到图像的边缘和纹理结构,同时保留了一定的空间信息 。
最大池化层有以下几个优点:

- 减少模型的参数数量和计算量,从而提高模型的效率和速度。
- 增强模型对输入图像中特征位置变化的鲁棒性,从而提高模型的泛化能力。
- 提取更高层次或更抽象的特征,从而增强模型的表达能力。

目的

  • 为了实现YOLO tiny网络中的最大值池化层,需要对Max_Pool模块进行完善,增加对步长等于1的最大值池化的支持。
  • 步长等于1的最大值池化是在网络结构中的第11层,步长为1的最大池化层是在每个2x2的窗口内选择最大值作为输出,同时在输入特征图的两条边上填充0,使得输出特征图的尺寸与输入相同。

stride=1最大值池化的工作原理

  • stride等于1的最大池化是一种特殊的最大池化层,它的作用是对输入的特征图进行填充(Padding)和平滑(Smoothing),而不改变特征图的尺寸。
  • stride等于1的最大池化的原理是:在前向传播过程中,在输入特征图的两条相邻边(通常是上边和左边)填充0,使得特征图尺寸增加一个单位(例如13x13变成14x14);然后对每个特征图区域(通常是2x2窗口),选择其中的最大值作为该区域池化后的值;在反向传播过程中,梯度只通过前向传播时的最大值反向传播,其他位置的梯度为0。
  • stride等于1的最大池化可以分为四种情况,根据填充边和移动方向不同而有所区别:
    • 上左填充,右下移动:这是Yolo网络中使用的情况,在layer 11处采用了stride等于1,窗口大小为2x2的最大池化。
    • 上右填充,左下移动:这种情况与上一种情况类似,只是填充边和移动方向相反。
    • 下左填充,右上移动:这种情况与上一种情况类似,只是填充边和移动方向相反。
    • 下右填充,左上移动:这种情况与上一种情况类似,只是填充边和移动方向相反。
  • stride等于1的最大池化的优点在于它能平滑特征图,减少噪声,增强特征的鲁棒性。
  • 例如,如果输入图像为:
D0 D1 D2
D3 D4 D5
D6 D7 D8

则填充后的图像为:

0 0 0 0
0 D0 D1 D2
0 D3 D4 D5
0 D6 D7 D8

则输出图像为:

max(0,D0) max(D0,D1) max(D1,D2)
max(0,D3) max(D3,D4) max(D4,D5)
max(0,D6) max(D6,D7) max(D7,D8)

实现

时序:
image

  • 输入数据和输出数据的时序控制
    • 输入数据由 data_in 和 data_in_valid 信号组成,data_in_valid 信号为高时表示 data_in 信号有效。
    • 输出数据由 data_out 和 data_out_valid 信号组成,data_out_valid 信号为高时表示 data_out 信号有效。
    • 输入数据和输出数据都是按行顺序传输,每行有 13 个数据。
    • 输入数据和输出数据之间有一定延迟,因为需要进行比较和 FIFO 的读写操作。
  • 数据打拍和比较逻辑
    • 数据打拍是指将输入数据分成两个寄存器,一个寄存器存储当前数据,另一个寄存器存储上一个数据,以便进行相邻两个数据的比较。
    • 比较逻辑是指将两个寄存器中的数据进行比较,取最大值作为输出数据。
    • 数据打拍和比较逻辑需要根据 stride 的值进行不同的处理。
    • 当 stride 等于二时,数据打拍和比较逻辑只在 data_in_valid 信号为高时进行,即每两个数据进行一次打拍和比较。
    • 当 stride 等于一时,数据打拍和比较逻辑需要在每个数据到来时进行,即每个数据进行一次打拍和比较。此外,还需要在每行的第一个数据到来之前,将上一个寄存器的值赋为零,以实现填充数据的效果。
  • FIFO 的读写控制
    • FIFO 是指先进先出的存储器,用于存储上一行的最大值结果,以便与下一行的最大值结果进行比较。
    • FIFO 的读写控制需要根据 stride 的值进行不同的处理。
    • 当 stride 等于二时,FIFO 的写使能信号由 row_even_flag 和 col_even_flag 信号控制,即每四个数据写入一次 FIFO。FIFO 的读使能信号由 row_even_flag 信号控制,即从第二行开始每两行读取一次 FIFO。
    • 当 stride 等于一时,FIFO 的写使能信号由 data_in_valid 信号控制,即每个数据写入一次 FIFO。FIFO 的读使能信号由 row_cnt 信号控制,即从第二行开始每行读取一次 FIFO。此外,还需要在每次池化开始之前,将 FIFO 的数据清空,以避免干扰。
  • 行计数器和列计数器
    • 行计数器和列计数器用于记录当前处理的是第几行和第几列的数据,以便进行 FIFO 的读写控制和输出数据的时序控制。
    • 行计数器和列计数器需要根据 stride 的值进行不同的处理。
    • 当 stride 等于二时,行计数器和列计数器都是在 data_in_valid 信号的下降沿进行加一操作,即每两个数据加一次。
    • 当 stride 等于一时,行计数器和列计数器都是在 data_in_valid 信号的上升沿进行加一操作,即每个数据加一次。
  • stride 的判断和传递
    • stride 的判断是指根据配置表格中的 bit13 位来判断当前层的池化步长是多少,bit13 位为零表示步长为二,bit13 位为一表示步长为一。
    • stride 的传递是指将配置表格中的 bit13 位作为一个端口输入到池化模块中,并在池化模块中根据该端口来进行不同的处理。
  • 填充数据的生成
    • 填充数据的生成是指在 stride 等于一的情况下,在输入图像的两条边上添加零像素,以保证输出图像的尺寸和输入图像相同。
    • 填充数据的生成可以通过在每行的第一个数据到来之前,将上一个寄存器的值赋为零来实现。这样就相当于在输入图像的上边和左边各添加了一行或一列零像素。

Example:

  • 假设有一个\(3*3\)大小的输入图像,其数值如下
1 2 3
4 5 6
7 8 9

对该输入图像进行\(stride=1\)最大值池化,首先需要对其进行填充,在其上边和左边添加一行或一列0,得到一个\(4*4\)大小的填充后图像:

0 0 0 0
0 1 2 3
0 4 5 6
0 7 8 9

然后使用一个\(2*2\)大小的滑动窗口,在每个子区域内选择最大值作为输出,滑动窗口从左上角开始,每次向右或向下移动一个像素,直到覆盖整个填充后图像。我们可以得到一个\(3*3\)大小的输出图像,其数值如下:

max(0,1) = 1 max(1,2) = 2 max(2,3) = 3
max(0,4) = 4 max(4,5) = 5 max(5,6) = 6
max(0,7) = 7 max(7,8) = 8 max(8,9) = 9

可以看到,输出图像与输入图像尺寸相同,且每个位置的值都是对应子区域的最大值。

代码框架:

module max_pool_stride_1(
    input clk, rst, // 时钟和复位信号
    input data_in_valid, // 输入数据的有效标志
    input [7:0] data_in, // 输入数据,8位宽,每次输入一个数据
    output reg data_out_valid, // 输出数据的有效标志
    output reg [7:0] data_out // 输出数据,8位宽,每次输出一个数据
);

// 定义一些参数和信号
parameter data_width = 8; // 输入数据的位宽
parameter channel_num = 8; // 输入数据的通道数
parameter row_num = 13; // 输入数据的行数
parameter col_num = 13; // 输入数据的列数
parameter window_size = 2; // 窗口的大小
parameter stride = 1; // 窗口的步长
parameter fifo_depth = 14; // FIFO 缓存的深度

reg [7:0] data_in_ie; // 数据打拍后的输出
reg [7:0] max_data; // 比较器的输出,即相邻两个数据的最大值
reg fifo_write_en; // FIFO 缓存的写使能信号
reg fifo_read_en; // FIFO 缓存的读使能信号
wire [7:0] fifo_read_data; // FIFO 缓存的读数据信号
reg [3:0] row_cnt; // 行计数器,用于记录当前输入数据的行数
reg [3:0] col_cnt; // 列计数器,用于记录当前输入数据的列数

// 实例化 FIFO 缓存模块,使用 Xilinx 提供的 IP 核
FIFO_SRL #(
    .DATA_WIDTH(data_width), // 数据位宽
    .FIFO_DEPTH(fifo_depth), // FIFO 深度
    .ALMOST_EMPTY_OFFSET(1), // 几乎空偏移量,用于产生几乎空标志信号
    .ALMOST_FULL_OFFSET(1) // 几乎满偏移量,用于产生几乎满标志信号
) fifo (
    .clk(clk), // 时钟信号
    .srst(rst), // 同步复位信号
    .din(max_data), // 写数据信号,即比较器的输出
    .wr_en(fifo_write_en), // 写使能信号
    .rd_en(fifo_read_en), // 读使能信号
    .dout(fifo_read_data), // 读数据信号,即 FIFO 缓存中存储的上一行相邻两个数据的最大值
    .full(), // 满标志信号,本例中不使用
    .empty(), // 空标志信号,本例中不使用
    .prog_full(), // 几乎满标志信号,本例中不使用
    .prog_empty() // 几乎空标志信号,本例中不使用
);

// 使用时序逻辑来控制数据打拍和比较器的工作状态

always @(posedge clk) begin

    if (rst) begin

        data_in_ie <= 0; // 复位时将数据打拍后的输出清零

        max_data <= 0; // 复位时将比较器的输出清零

        row_cnt <= 0; // 复位时将行计数器清零

        col_cnt <= 0; // 复位时将列计数器清零

    end else begin

        // 当输入数据有效时,进行数据打拍和比较

        if (data_in_valid) begin

            // 数据打拍,将当前输入数据和上一个输入数据同时输出

            data_in_ie <= data_in;

            // 比较器,比较相邻两个数据的大小,并输出较大的数据

            if (data_in > data_in_ie) begin

                max_data <= data_in;

            end else begin

                max_data <= data_in_ie;

            end

            // 行计数器,根据输入数据的有效标志的下降沿来计数

            if (data_in_valid == 0 && data_in_ie == 1) begin

                row_cnt <= row_cnt + 1;

            end

            // 列计数器,根据输入数据的有效标志来计数

            col_cnt <= col_cnt + 1;

        end else begin

            // 当输入数据无效时,将数据打拍后的输出清零,保证第一行和第一列的数据只与 0 比较

            data_in_ie <= 0;

        end
    end
end
graph LR subgraph 输入特征图 A((1)) -- 2x2 窗口 --> B((0)) B -- 2x2 窗口 --> C((3)) D((4)) -- 2x2 窗口 --> E((6)) E -- 2x2 窗口 --> F((8)) G((9)) -- 2x2 窗口 --> H((5)) H -- 2x2 窗口 --> I((7)) end subgraph 输出特征图 J((4)) --> K((6)) --> L((8)) M((9)) --> N((9)) --> O((8)) P((9)) --> Q((7)) --> R((7)) end A -.-> J B -.-> K C -.-> L D -.-> M E -.-> N F -.-> O G -.-> P H -.-> Q I -.-> R style B fill:#f9f,stroke:#333,stroke-width:4px; style C fill:#f9f,stroke:#333,stroke-width:4px; style E fill:#f9f,stroke:#333,stroke-width:4px; style F fill:#f9f,stroke:#333,stroke-width:4px; style K fill:#ff6,stroke:#333,stroke-width:4