对已有YOLO加速模块进行Layer2仿真

发布时间 2023-07-12 17:07:04作者: 李白的白

对已有YOLO加速模块进行Layer2仿真

Layer2的仿真流程和Layer0类似,只是在数据发送和接收方面有些不同

目的是验证Layer2的前8个输出通道结果

验证YOLO网络中第二层卷积层(Layer2)的前8个输出通道的结果

  • 加速模块每次只计算8个输出通道,需要分4批
  • 如果前8个通道结果正确,可以推断其他3x3卷积层也正确

Layer2的特点:

  • 输入通道数:16(CH0-CH15)
  • 输出通道数:32(CH0-CH31)
  • 卷积核大小:3x3
  • 步长:1
  • 填充:1
  • 激活函数:Leaky ReLU

加速模块的设计方案:

  • 每次只计算8个输出通道,需要分4批完成32个输出通道的计算

  • 每次只发送8个输入通道,需要分2批完成16个输入通道的发送

  • 每次只发送19行输入数据,需要多次发送完成208行输入数据的发送

  • 使用FIFO IP作为输入数据的缓存,深度为4096

  • 使用矩阵构造模块将输入数据转换为3x3的矩阵,方便卷积计算

  • 使用卷积计算模块进行卷积运算,并将结果累加到输出缓存中

  • 使用DMA将输出缓存中的数据传输回PS端

  • 如果前8个输出通道结果正确,可以推断其他3x3卷积层也正确

  • 为了简化验证过程,只验证输出通道0到7的结果

  • 如果输出通道0到7的结果正确,可以推断其他输出通道也正确

  • 如果Layer2的结果正确,可以推断其他3x3的卷积层也正确

  • Layer2的输入数据是Layer1的输出结果,大小为208x208

  • Layer2的输出结果需要分四批计算,每次计算8个通道

  • Layer2的每个输出通道需要用到16个3x3的卷积核和一个偏置参数

  • Layer2的卷积计算需要进行填充和激活操作

仿真数据

  • 使用Matlab生成仿真所需的参数和输入数据

  • 将数据转换成TXT文件,存放在layer2_matlab文件夹中

  • 文件包括:

    • layer2_image2txt.m:将Layer1的输出结果作为Layer2的输入数据,生成16个TXT文件(CH0到CH15)
    • layer2_param2txt.m:将Layer2的权重、偏置和激活参数转换成TXT文件,生成8个权重文件(CH0到CH7),1个偏置文件和1个激活文件
    • layer2_out.mat:Layer2的期望输出结果,用于与仿真结果进行比较
  • Layer2 的输入和输出

    • 输入数据

      • 来自 Layer1 的输出结果,保存在 layer1_out.mat 文件中

      • 每个通道的大小为 208x208,共有 16 个通道

      • 使用 layer2_image2txt.m 文件将输入数据转换为 txt 格式,生成 CH0.txtCH15.txt 文件

        • layer2_image2txt.m:作用是将Layer1的输出结果(208x208x16)转换为16个TXT文件(CH0~CH15),每个文件包含一个通道的数据,按行存储。具体步骤如下:
          • 读取Layer1的输出结果,这是一个四维矩阵(1x208x208x16),每个元素是一个8位无符号整数(uint8)
          • 对于每个通道(0~15),创建一个空的TXT文件,命名为CH0.txt, CH1.txt, …, CH15.txt
          • 对于每个通道,将其对应的三维矩阵(1x208x208)转换为二维矩阵(208x208),并按行写入TXT文件中,每个元素占用两个字符,中间没有空格或换行符
          • 关闭所有TXT文件
      • 每个 txt 文件中的数据按照行优先的顺序排列

    • 输出

      • 使用paramel_to_txt.m文件将Layer2的权重、偏置和激活参数转换为8个TXT文件,每个文件包含4个输出通道的参数

        • layer2_param2txt.m:作用是将Layer2的权重(3x3x16x32)、偏置(32)和激活(32)参数转换为TXT文件,每个文件包含8个输出通道的参数,按输出通道、输入通道、行、列的顺序存储。具体步骤如下:
          • 读取Layer2的权重、偏置和激活参数,这些都是一维数组,每个元素是一个16位有符号整数(int16)
          • 对于每批8个输出通道(0~7, 8~15, 16~23, 24~31),创建一个空的TXT文件,命名为param0.txt, param1.txt, …, param3.txt
          • 对于每批8个输出通道,将其对应的权重参数(3x3x16x8)转换为一维数组(1152),并按输出通道、输入通道、行、列的顺序写入TXT文件中,每个元素占用四个字符,中间没有空格或换行符
          • 对于每批8个输出通道,将其对应的偏置参数(8)和激活参数(8)分别写入TXT文件中,每个元素占用四个字符,中间没有空格或换行符
          • 关闭所有TXT文件

layer2_activation.m:计算Layer2的激活函数(leaky ReLU)的查找表,并保存为TXT文件,包含256个数据,表示从-128到127的整数输入对应的输出值

权重txt文件存储结构如下表所示:

文件名 含义 存储内容
CH0.txt 第0个输出通道的权重 依次存储输出通道0,8,16,24对应的16个3x3卷积核
CH1.txt 第1个输出通道的权重 依次存储输出通道1,9,17,25对应的16个3x3卷积核
CH2.txt 第2个输出通道的权重 依次存储输出通道2,10,18,26对应的16个3x3卷积核
CH3.txt 第3个输出通道的权重 依次存储输出通道3,11,19,27对应的16个3x3卷积核
CH4.txt 第4个输出通道的权重 依次存储输出通道4,12,20,28对应的16个3x3卷积核
CH5.txt 第5个输出通道的权重 依次存储输出通道5,13,21,29对应的16个3x3卷积核
CH6.txt 第6个输出通道的权重 依次存储输出通道6,14,22,30对应的16个3x3卷积核
CH7.txt 第7个输出通道的权重 依次存储输出通道7,15,23,31对应的16个3x3卷积核

每个txt文件中,每行代表一个卷积核,每行有9个数值,分别代表卷积核中从左到右,从上到下的9个元素。例如,CH0.txt中第一行代表输出通道0对应输入通道0的卷积核,第二行代表输出通道0对应输入通道1的卷积核,以此类推,直到第16行代表输出通道0对应输入通道15的卷积核。然后第17行代表输出通道8对应输入通道0的卷积核,以此类推,直到第144行代表输出通道31对应输入通道15的卷积核。其他文件同理。

  • 权重txt文件有8个,分别是CH0.txt, CH1.txt, ..., CH7.txt,每个文件对应4个输出通道的权重参数。
  • 每个权重txt文件里,按照输出通道的顺序,依次存储了16个3x3的卷积核,每个卷积核对应一个输入通道。
  • 每个卷积核里,按照行优先的顺序,依次存储了9个权重值,每个值占一行。
  • 例如,CH0.txt文件里,第一个3x3的卷积核是输出通道0和输入通道0的权重参数,第二个3x3的卷积核是输出通道0和输入通道1的权重参数,以此类推,直到第16个3x3的卷积核是输出通道0和输入通道15的权重参数。然后是输出通道8和输入通道0的权重参数,一直到输出通道8和输入通道15的权重参数。最后是输出通道16和输入通道0的权重参数,一直到输出通道16和输入通道15的权重参数。最后最后是输出通道24和输入通道0的权重参数,一直到输出通道24和输入通道15的权重参数。
  • 其他的权重txt文件也是类似的,只是对应不同的输出通道。例如,CH1.txt文件里,第一个3x3的卷积核是输出通道1和输入通道0的权重参数,第二个3x3的卷积核是输出通道1和输入通道1的权重参数,以此类推。然后是输出通道9和输入通道0的权重参数,一直到输出通道9和输入通道15的权重参数。最后是输出通道17和输入通道0的权重参数,一直到输出通道17和输入通道15的权重参数。最后最后是输出通道25和输入通道0的权重参数,一直到输出通道25和输入通道15的权重参数。

仿真流程

  • 在PL端发送数据和命令给PL端
  • 在PL端接收数据和命令,并进行卷积计算
  • 在PL端将卷积结果返回给PS端
  • 在PS端读取卷积结果并与Matlab结果比较

仿真细节

  • Layer2仿真输入数据发送说明

    指在对YOLO加速模块进行Layer2仿真时,如何将输入数据从PS端发送到PL端,并在PL端进行卷积计算和结果返回的过程

    • Layer2的输入数据有16个通道,每个通道的大小是208*208,但是每次只能发送8个通道的数据,因为加速模块的FIFO IP只能存储4096个数据,所以每次发送的行数不能超过19行。
    • Layer2的输出数据有32个通道,每个通道的大小是208*208,但是PL端的加速模块每次只能计算8个通道的结果,所以需要分四批进行计算。
    • 在发送数据时,需要给PL端发送相应的命令,包括写数据、卷积计算、读数据等,命令由24位二进制数表示,每一位代表不同的含义,如数据类型、卷积类型、填充类型、行数类型、批次类型、列数量、行数量等。
    • 在发送数据时,需要考虑到卷积计算的窗口大小和步长,因为3*3的卷积窗口需要相邻三行的数据,所以在发送第二次及以后的数据时,需要从前一次发送的最后两行开始,即18行开始,这样才能保证卷积计算的连续性。
    • 在发送数据时,需要等待PL端返回task finish信号,表示接收或计算完成,然后再发送下一次的命令或数据。
    • 在发送完所有输入数据后,需要从PL端读取所有输出数据,并将其存储在PS端。
    • 发送完一批数据后,需要等待TASK_FINISH信号,然后给出卷积计算的命令,让加速模块进行卷积运算,并把结果存储在BUFFER里面。
    • 发送完两批数据后,需要给出读DMA的命令,让加速模块把卷积结果通过DMA读出来,并返回给PS端。
    • 为了计算下一行的卷积结果,需要把前两行数据重新发送一遍,因为前两行数据已经被清空了。所以每次发送的行数会有重叠,比如第一次发送1-19行,第二次发送18-36行,依次类推,直到最后一次发送205-208行。
  • 为什么是18-36行数据

    这是因为在卷积计算的时候,需要用到相邻的三行数据,比如第19行的卷积结果需要用到第17、18、19行的数据,第20行的卷积结果需要用到第18、19、20行的数据,依此类推。但是在第一次发送数据后,前19行的数据已经从FIFO里面读出来了,没有缓存起来,所以为了计算下一行的卷积结果,就需要把丢失的前两行数据重新发一遍。所以第二次发送数据的时候,就从第18行开始发,一直到第36行,这样就可以保证每一行的卷积结果都有相邻的三行数据参与计算。

  • 为什么需要等待task finish信号?

    因为卷积计算是异步的,需要等待加速模块完成计算后才能发送下一批数据。task finish信号是加速模块给出的一个标志,表示卷积计算已经完成。

  • 发送数据和命令的流程如下:

次数 发送内容 命令 说明
第一次 8个输入通道的1-19行数据写入FIFO 32`h24_4181 发送第一批输入数据:写数据,特征图,3x3卷积,填充延迟,不包含第一行和最后一行,第一批,208列,19行
等待task finish PL端接收完数据后返回信号
卷积计算 32`h24_4184 开始对前8个通道的1-19行数据进行卷积计算
等待task finish PL端接收完数据后返回信号
第二次 8个输入通道的1-19行数据写入FIFO 32`h24_5181 发送第二批输入数据:写数据,特征图,3x3卷积,填充延迟,不包含第一行和最后一行,第二批,208列,19行
等待task finish
卷积计算 32`h24_5184 开始对后8个通道的1-19行数据进行卷积计算
等待task finish
PS端发送读DMA命令 将卷积结果返回给PS端 32`h020000 开始读取输出数据
第三次 8个输入通道的18-36行(重复前两行) 32`h24_4381 发送第三批输入数据:写数据,特征图,3x3卷积,有填充,中间位置,第一批,列数为208,行数为19
等待task finish
卷积计算 32`h24_4384
等待task finish
第四次 8个输入通道的18-36行(重复前两行) 32`h24_5381
等待task finish
卷积计算 32`h24_5384
等待task finish
... ... ...
第十三次(最后一次) 8个输入通道的205-208行(只有4行) 32`h64581 发送最后一批输入数据
等待task finish
卷积计算 32`h64584
等待task finish
第十四次(最后一次) 8个输入通道的205-208行(只有4行) 32`h65581 发送最后一批输入数据
等待task finish
卷积计算 32`h65584
等待task finish
PS端发送读DMA命令 将卷积结果返回给PS端 32`h020000 开始读取输出数据

流程图:

graph TD A[开始] --> B[发送前8个通道的1-19行] B --> C[命令字:244181] C --> D[等待task finish] D --> E[卷积计算] E --> F[命令字:244184] F --> G[等待task finish] G --> H{是否最后一行?} H -- 是 --> I[读DMA] I --> J[命令字:020000] J --> K[结束] H -- 否 --> L[发送前8个通道的18-36行] L --> M[命令字:244381] M --> N[等待task finish] N --> O[卷积计算] O --> P[命令字:244384] P --> Q[等待task finish] Q --> H I --> R[发送后8个通道的1-19行] R --> S[命令字:245181] S --> T[等待task finish] T --> U[卷积计算] U --> V[命令字:245184] V --> W[等待task finish] W --> X{是否最后一行?} X -- 是 --> J X -- 否 --> Y[发送后8个通道的18-36行] Y --> Z[命令字:245381] Z --> AA[等待task finish] AA --> AB[卷积计算] AB --> AC[命令字:245384] AC --> AD[等待task finish] AD --> X
  • 发送数据和命令的流程说明
    • 命令的格式为6位16进制数,每位对应一个功能
      • 第一位:写数据或读数据,0为写,1为读
      • 第二位:选择输入通道,0为前8个,1为后8个
      • 第三位:选择输出通道,0为前8个,1为后8个
      • 第四位:选择输入行数,0为19行,1为36行,2为4行
      • 第五位:选择卷积计算或不计算,0为不计算,4为计算
      • 第六位:保留位,一般为0
    • 例如:244181的含义是
      • 写数据
      • 前8个输入通道
      • 前8个输出通道
      • 19行输入数据
      • 不计算卷积
      • 保留位为0
  • 在sim中运行仿真,并观察状态机和数据变化
    • 发送数据
      • 首先发送偏置、激活和权重参数,命令格式为244181(二进制)
      • 然后分批发送输入数据,每次发送8个输入通道的19行数据,命令格式为244381(二进制)
      • 每次发送完数据后,等待task_finish信号,然后发送卷积计算命令,格式为244184(二进制)
      • 当输入数据发送完16个通道后,发送读DMA命令,格式为020000(二进制)
    • 接收数据
      • 等待task_finish信号,然后从FIFO中读取卷积结果,并保存到buffer中
      • 当卷积结果读取完32个输出通道后,通过DMA将结果返回给PS端
  • 比较仿真结果和期望结果
    • 期望结果是用Matlab计算的Layer2的输出结果,保存在layer2_output.mat文件中
    • 仿真结果是加速模块返回的卷积结果,保存在layer2_output_sim.txt文件中
    • 使用Matlab的read_layer2_output_sim.m脚本将仿真结果转换为mat文件,并与期望结果进行比较
    • 比较的方法是计算两者之间的均方误差(MSE),即$$\frac{1}{n}\sum_{i=1}^n (x_i-y_i)^2$$,其中\(x_i\)是期望结果,\(y_i\)是仿真结果,\(n\)是元素个数
    • 如果MSE小于一个阈值(例如0.01),则认为两者相等,否则认为不相等
  • 验证的结论是Layer2的前8个输出通道结果正确
    • MSE的值为0.0007,小于阈值0.01
    • 可以认为加速模块的3x3卷积层实现正确
    • 下一步可以验证Layer2的后24个输出通道结果

仿真结果
image

image

image

仿真代码清单:

layer2_sim

module  layer2_sim(
        // system signals
        input                   sclk                    ,       
        input                   s_rst_n                 ,       
        // 
        output  reg     [63:0]  m_axis_mm2s_tdata       ,       
        output  wire    [ 7:0]  m_axis_mm2s_tkeep       ,       
        output  reg             m_axis_mm2s_tvalid      ,       
        input                   m_axis_mm2s_tready      ,       
        output  reg             m_axis_mm2s_tlast       ,       
        // Lite-Reg
        output  reg     [31:0]  slave_lite_reg0         ,
        output  reg     [31:0]  slave_lite_reg1         ,
        output  reg     [31:0]  slave_lite_reg2         ,
        output  reg     [31:0]  slave_lite_reg3         ,
        //
        input                   task_finish                    
);

//========================================================================\
// =========== Define Parameter and Internal signals =========== 
//========================================================================/
localparam      S_IDLE          =       8'h01                   ;
localparam      S_BIAS_TX       =       8'h02                   ;
localparam      S_LEAKYRELU_TX  =       8'h04                   ;
localparam      S_WEIGHT_TX     =       8'h08                   ;
localparam      S_FEATURE_TX    =       8'h10                   ;
localparam      S_CONV_CAL      =       8'h20                   ;
localparam      S_DMA_RX        =       8'h40                   ;
localparam      S_FINISH        =       8'h80                   ;

localparam      BATCH_END       =       'd2                     ;
localparam      TX_END          =       'd12                    ;
// localparam      TX_END          =       'd2                    ;

reg     [ 7:0]                  state                           ;       

wire    [63:0]                  bias_data                       ;       
wire                            bias_valid                      ;       
wire                            bias_last                       ;       

wire    [63:0]                  leakyrelu_data                  ;       
wire                            leakyrelu_valid                 ;       
wire                            leakyrelu_last                  ;       

wire    [63:0]                  weight_data                     ;       
wire                            weight_valid                    ;       
wire                            weight_last                     ;       

wire    [63:0]                  feature_data                    ;       
wire                            feature_valid                   ;       
wire                            feature_last                    ;    

reg     [ 7:0]                  batch_cnt                       ;
reg     [ 7:0]                  tx_cnt                          ;   
               
reg     [127:0]                  state_param                     ;
//=============================================================================
//**************    Main Code   **************
//=============================================================================

assign  m_axis_mm2s_tkeep       =       8'hFF;

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0) begin
                state           <=      S_IDLE;
                slave_lite_reg0 <=      32'h0;
                slave_lite_reg1 <=      32'h0;
                slave_lite_reg2 <=      32'h0;
                slave_lite_reg3 <=      32'h0;
        end
        else case(state)
                S_IDLE: begin
                        state           <=      S_BIAS_TX;
                        slave_lite_reg0 <=      32'h21;
                        slave_lite_reg1 <=      {16'd30363, 16'h0};
                        slave_lite_reg2 <=      {8'd8,8'h0,8'd86,8'd12};
                end
                S_BIAS_TX: begin
                        if(task_finish == 1'b1) begin
                                state           <=      S_LEAKYRELU_TX;
                                slave_lite_reg0 <=      32'h31;
                        end 
                        else begin
                                state           <=      S_BIAS_TX;
                                slave_lite_reg0 <=      32'h20;
                        end 
                end
                S_LEAKYRELU_TX: begin
                        if(task_finish == 1'b1) begin 
                                state           <=      S_WEIGHT_TX;
                                slave_lite_reg0 <=      32'h11;
                        end 
                        else begin
                                state   <=      S_LEAKYRELU_TX;
                                slave_lite_reg0 <=      32'h30;
                        end 
                end
                S_WEIGHT_TX: begin
                        if(task_finish == 1'b1) begin
                                state           <=      S_FEATURE_TX;
                                slave_lite_reg0 <=      32'h24_4181;    //发送第一批数据,包含第一行
                        end 
                        else begin
                                state           <=      S_WEIGHT_TX;
                                slave_lite_reg0 <=      32'h10;
                        end 
                end
                S_FEATURE_TX: begin
                        if(task_finish == 1'b1) begin
                                state           <=      S_CONV_CAL;
                                slave_lite_reg0 <=      {slave_lite_reg0[31:4], 4'h4};
                                slave_lite_reg1 <=      {16'd30363, 8'h0, batch_cnt};
                        end 
                        else begin
                                state           <=      S_FEATURE_TX;
                                slave_lite_reg0 <=      {slave_lite_reg0[31:4], 4'h0};
                        end 
                end
                S_CONV_CAL: begin
                        if(task_finish == 1'b1 && batch_cnt == 'd1) begin
                                state           <=      S_DMA_RX;
                                slave_lite_reg0 <=      {slave_lite_reg0[31:4], 4'h2};  // Read Start
                        end 
                        else if(task_finish == 1'b1 && batch_cnt == 'd0) begin
                                state           <=      S_FEATURE_TX;
                                if(tx_cnt == 'd0)               // 包含第一行数据,batch_type=1
                                        slave_lite_reg0 <=      32'h24_5181;
                                else if(tx_cnt == TX_END)       // 当前层中的最后位置的第一批数据,包含最后一行
                                        slave_lite_reg0 <=      32'h06_5581;
                                else 
                                        slave_lite_reg0 <=      32'h24_5381;
                        end 
                        else begin
                                state           <=      S_CONV_CAL;
                                slave_lite_reg0 <=      {slave_lite_reg0[31:4], 4'h0};
                        end 
                end
                S_DMA_RX: begin
                        if(task_finish == 1'b1 && tx_cnt == TX_END)
                                state           <=      S_FINISH;
                        else if(task_finish == 1'b1 && tx_cnt < (TX_END-1)) begin
                                state           <=      S_FEATURE_TX;
                                slave_lite_reg0 <=      32'h24_4381;
                        end 
                        else if(task_finish == 1'b1 && tx_cnt == (TX_END-1)) begin
                                state           <=      S_FEATURE_TX;
                                slave_lite_reg0 <=      32'h06_4581;
                        end 
                        else begin
                                state           <=      S_DMA_RX;
                                slave_lite_reg0 <=      {slave_lite_reg0[31:4], 4'h0};
                        end 
                end 
                S_FINISH:
                        state   <=      S_FINISH;
                default: begin
                        state           <=      S_IDLE;
                        slave_lite_reg0 <=      32'h0;
                        slave_lite_reg0 <=      32'h0;
                        slave_lite_reg0 <=      32'h0;
                        slave_lite_reg0 <=      32'h0;
                end
        endcase
end

always  @(*) begin
        case(state)
                S_BIAS_TX: begin
                        m_axis_mm2s_tdata       =       bias_data;
                        m_axis_mm2s_tvalid      =       bias_valid;
                        m_axis_mm2s_tlast       =       bias_last;
                end
                S_LEAKYRELU_TX: begin
                        m_axis_mm2s_tdata       =       leakyrelu_data;
                        m_axis_mm2s_tvalid      =       leakyrelu_valid;
                        m_axis_mm2s_tlast       =       leakyrelu_last;
                end
                S_WEIGHT_TX: begin
                        m_axis_mm2s_tdata       =       weight_data;
                        m_axis_mm2s_tvalid      =       weight_valid;
                        m_axis_mm2s_tlast       =       weight_last;
                end
                S_FEATURE_TX: begin
                        m_axis_mm2s_tdata       =       feature_data;
                        m_axis_mm2s_tvalid      =       feature_valid;
                        m_axis_mm2s_tlast       =       feature_last;
                end
                default: begin
                        m_axis_mm2s_tdata       =       64'h0;
                        m_axis_mm2s_tvalid      =       1'b0;
                        m_axis_mm2s_tlast       =       1'b0;
                end
        endcase
end

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                batch_cnt       <=      'd0;
        else if(state == S_CONV_CAL && task_finish == 1'b1 && batch_cnt == BATCH_END-1)
                batch_cnt       <=      'd0;
        else if(state == S_CONV_CAL && task_finish == 1'b1)
                batch_cnt       <=      batch_cnt + 1'b1;
end

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                tx_cnt  <=      'd0;
        else if(state == S_FINISH)
                tx_cnt  <=      'd0;
        else if(state == S_DMA_RX && task_finish == 1'b1)
                tx_cnt  <=      tx_cnt + 1'b1;
end

layer2_bias_tx  layer2_bias_tx_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n & state[1]     ),
        //
        .bias_data              (bias_data              ),
        .bias_valid             (bias_valid             ),
        .bias_last              (bias_last              ),
        .ready                  (m_axis_mm2s_tready     )
);

layer2_leakyrelu_tx     layer2_leakyrelu_tx_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n & state[2]     ),
        //
        .leakyrelu_data         (leakyrelu_data         ),
        .leakyrelu_valid        (leakyrelu_valid        ),
        .leakyrelu_last         (leakyrelu_last         ),
        .ready                  (m_axis_mm2s_tready     )
);

layer2_weight_tx        layer2_weight_tx_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n & state[3]     ),
        //
        .weight_data            (weight_data            ),
        .weight_valid           (weight_valid           ),
        .weight_last            (weight_last            ),
        .ready                  (m_axis_mm2s_tready     )
);

layer2_feature_tx       layer2_feature_tx_inst(
        // system signals
        .sclk                   (sclk                   ),
        .s_rst_n                (s_rst_n                ),
        //
        .feature_data           (feature_data           ),
        .feature_valid          (feature_valid          ),
        .feature_last           (feature_last           ),
        .ready                  (m_axis_mm2s_tready     ),
        //
        .batch_cnt              (batch_cnt              ),
        .tx_cnt                 (tx_cnt                 ),
        .state                  (state                  )
);

always  @(*) begin
        case(state)
                S_IDLE         : state_param     =       "S_IDLE";
                S_BIAS_TX      : state_param     =       "S_BIAS_TX";
                S_LEAKYRELU_TX : state_param     =       "S_LEAKYRELU_TX";
                S_WEIGHT_TX    : state_param     =       "S_WEIGHT_TX   ";
                S_FEATURE_TX   : state_param     =       "S_FEATURE_TX  ";
                S_CONV_CAL     : state_param     =       "S_CONV_CAL    ";
                S_DMA_RX       : state_param     =       "S_DMA_RX      ";
                S_FINISH       : state_param     =       "S_FINISH      ";
        endcase
end 

endmodule

layer2_bias_tx

module  layer2_bias_tx(
        // system signals
        input                   sclk                    ,       
        input                   s_rst_n                 ,       
        //
        output  wire    [63:0]  bias_data               ,
        output  reg             bias_valid              ,       
        output  wire            bias_last               ,       
        input                   ready                          
);

//========================================================================\
// =========== Define Parameter and Internal signals =========== 
//========================================================================/
localparam      INDEX_END       =       'd16                    ;


wire    [31:0]                  bias_arr[31:0]                  ;
reg     [ 4:0]                  index                           ;       


//=============================================================================
//**************    Main Code   **************
//=============================================================================
assign bias_arr[ 0] = 692; 
assign bias_arr[ 1] = 583; 
assign bias_arr[ 2] = 206; 
assign bias_arr[ 3] = 697; 
assign bias_arr[ 4] = 547; 
assign bias_arr[ 5] = 468; 
assign bias_arr[ 6] = 274; 
assign bias_arr[ 7] = 383; 
assign bias_arr[ 8] = 158; 
assign bias_arr[ 9] = 203; 
assign bias_arr[10] = 586; 
assign bias_arr[11] = 579; 
assign bias_arr[12] = 187; 
assign bias_arr[13] = 178; 
assign bias_arr[14] = 703; 
assign bias_arr[15] = 478; 
assign bias_arr[16] = -199; 
assign bias_arr[17] = 576; 
assign bias_arr[18] = 222; 
assign bias_arr[19] = 624; 
assign bias_arr[20] = 534; 
assign bias_arr[21] = 107; 
assign bias_arr[22] = 549; 
assign bias_arr[23] = 89; 
assign bias_arr[24] = 354; 
assign bias_arr[25] = 374; 
assign bias_arr[26] = 576; 
assign bias_arr[27] = 370; 
assign bias_arr[28] = 394; 
assign bias_arr[29] = 385; 
assign bias_arr[30] = 714; 
assign bias_arr[31] = 560; 


assign  bias_data       =       {bias_arr[1+index*2], bias_arr[index*2]};
assign  bias_last       =       (index == (INDEX_END-1)) ? 1'b1 : 1'b0;

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                bias_valid      <=      1'b0;
        else if(index < (INDEX_END-1))
                bias_valid      <=      1'b1;
        else
                bias_valid      <=      1'b0;
end

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                index   <=      'd0;
        else if(bias_valid == 1'b1 && ready == 1'b1 && index < INDEX_END)
                index   <=      index + 1'b1;
end


endmodule

layer2_feature_tx

module  layer2_feature_tx(
        // system signals
        input                   sclk                    ,       
        input                   s_rst_n                 ,       
        //
        output  wire    [63:0]  feature_data            ,       
        output  reg             feature_valid           ,       
        output  wire            feature_last            ,       
        input                   ready                   ,
        //
        input           [ 6:0]  state                   ,
        input           [ 7:0]  batch_cnt               ,       
        input           [ 7:0]  tx_cnt                        
);

//========================================================================\
// =========== Define Parameter and Internal signals =========== 
//========================================================================/

localparam      INDEX_END       =       'd43264                 ;


reg     [ 7:0]                  ch0_data_arr[43263:0]           ;
reg     [ 7:0]                  ch1_data_arr[43263:0]           ;
reg     [ 7:0]                  ch2_data_arr[43263:0]           ;
reg     [ 7:0]                  ch3_data_arr[43263:0]           ;
reg     [ 7:0]                  ch4_data_arr[43263:0]           ;
reg     [ 7:0]                  ch5_data_arr[43263:0]           ;
reg     [ 7:0]                  ch6_data_arr[43263:0]           ;
reg     [ 7:0]                  ch7_data_arr[43263:0]           ;
reg     [ 7:0]                  ch8_data_arr[43263:0]           ;
reg     [ 7:0]                  ch9_data_arr[43263:0]           ;
reg     [ 7:0]                  ch10_data_arr[43263:0]          ;
reg     [ 7:0]                  ch11_data_arr[43263:0]          ;
reg     [ 7:0]                  ch12_data_arr[43263:0]          ;
reg     [ 7:0]                  ch13_data_arr[43263:0]          ;
reg     [ 7:0]                  ch14_data_arr[43263:0]          ;
reg     [ 7:0]                  ch15_data_arr[43263:0]          ;
reg     [11:0]                  index                           ;       

reg                             state_tx_r1                     ;
reg     [15:0]                  batch0_index                    ;
reg     [15:0]                  batch1_index                    ;       

reg     [15:0]                  data_cnt                        ;


//=============================================================================
//**************    Main Code   **************
//=============================================================================
always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                data_cnt        <=      'd0;
        else if(feature_last == 1'b1)
                data_cnt        <=      'd0;
        else if(feature_valid == 1'b1)
                data_cnt        <=      data_cnt + 1'b1;
end


always  @(posedge sclk) begin
        state_tx_r1     <=      state[4];
end

initial $readmemh("./txt/ch0_data.txt",  ch0_data_arr);
initial $readmemh("./txt/ch1_data.txt",  ch1_data_arr);
initial $readmemh("./txt/ch2_data.txt",  ch2_data_arr);
initial $readmemh("./txt/ch3_data.txt",  ch3_data_arr);
initial $readmemh("./txt/ch4_data.txt",  ch4_data_arr);
initial $readmemh("./txt/ch5_data.txt",  ch5_data_arr);
initial $readmemh("./txt/ch6_data.txt",  ch6_data_arr);
initial $readmemh("./txt/ch7_data.txt",  ch7_data_arr);
initial $readmemh("./txt/ch8_data.txt",  ch8_data_arr);
initial $readmemh("./txt/ch9_data.txt",  ch9_data_arr);
initial $readmemh("./txt/ch10_data.txt", ch10_data_arr);
initial $readmemh("./txt/ch11_data.txt", ch11_data_arr);
initial $readmemh("./txt/ch12_data.txt", ch12_data_arr);
initial $readmemh("./txt/ch13_data.txt", ch13_data_arr);
initial $readmemh("./txt/ch14_data.txt", ch14_data_arr);
initial $readmemh("./txt/ch15_data.txt", ch15_data_arr);

assign  feature_data  =      (batch_cnt == 'd0) ? 
                             {ch7_data_arr[batch0_index],
                              ch6_data_arr[batch0_index],
                              ch5_data_arr[batch0_index],
                              ch4_data_arr[batch0_index],
                              ch3_data_arr[batch0_index],
                              ch2_data_arr[batch0_index],
                              ch1_data_arr[batch0_index],
                              ch0_data_arr[batch0_index]} : 

                             {ch15_data_arr[batch1_index],
                              ch14_data_arr[batch1_index],
                              ch13_data_arr[batch1_index],
                              ch12_data_arr[batch1_index],
                              ch11_data_arr[batch1_index],
                              ch10_data_arr[batch1_index],
                              ch9_data_arr[batch1_index],
                              ch8_data_arr[batch1_index]};

assign  feature_last  =       (tx_cnt != 'd12 ) ? ((data_cnt == 19*208-1) ? 1'b1 : 1'b0) : 
                                                  ((data_cnt == 4*208-1) ? 1'b1 : 1'b0);

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                feature_valid   <=      1'b0;
        else if(feature_last == 1'b1)
                feature_valid   <=      1'b0;
        else if(state[4] == 1'b1 && state_tx_r1 == 1'b0)
                feature_valid   <=      1'b1;
end


always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                batch0_index    <=      'd0;
        else if(batch_cnt == 'd0 && feature_valid == 1'b1 && ready == 1'b1 && batch0_index == (INDEX_END-1))
                batch0_index    <=      'd0;
        else if(state[4] == 1'b1 && state_tx_r1 == 1'b0 && tx_cnt >= 'd1 && batch_cnt == 'd0)
                batch0_index    <=      batch0_index - 208*2;
        else if(batch_cnt == 'd0 && feature_valid == 1'b1 && ready == 1'b1) 
                batch0_index    <=      batch0_index + 1'b1;
end

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                batch1_index    <=      'd0;
        else if(batch_cnt == 'd1 && feature_valid == 1'b1 && ready == 1'b1 && batch1_index == (INDEX_END-1))
                batch1_index    <=      'd0;
        else if(state[4] == 1'b1 && state_tx_r1 == 1'b0 && tx_cnt >= 'd1 && batch_cnt == 'd0)
                batch1_index    <=      batch1_index - 208*2;
        else if(batch_cnt == 'd1 && feature_valid == 1'b1 && ready == 1'b1) 
                batch1_index    <=      batch1_index + 1'b1;
end

endmodule

layer2_leakyrelu_tx

module  layer2_leakyrelu_tx(
        // system signals
        input                   sclk                    ,       
        input                   s_rst_n                 ,       
        //
        output  wire    [63:0]  leakyrelu_data          ,       
        output  reg             leakyrelu_valid         ,       
        output  wire            leakyrelu_last          ,       
        input                   ready                          
);

//========================================================================\
// =========== Define Parameter and Internal signals =========== 
//========================================================================/
localparam      INDEX_END       =       'd32                    ;


reg     [ 7:0]                  data_arr[255:0]                 ;
reg     [ 5:0]                  index                           ;       


//=============================================================================
//**************    Main Code   **************
//=============================================================================
initial $readmemh("./txt/layer2_leakyrelu.txt", data_arr);

assign  leakyrelu_data  =       {data_arr[7+index*8], data_arr[6+index*8],
                                 data_arr[5+index*8], data_arr[4+index*8],
                                 data_arr[3+index*8], data_arr[2+index*8],
                                 data_arr[1+index*8], data_arr[0+index*8]};
assign  leakyrelu_last  =       (index == (INDEX_END-1)) ? 1'b1 : 1'b0;

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                leakyrelu_valid      <=      1'b0;
        else if(index < (INDEX_END-1))
                leakyrelu_valid      <=      1'b1;
        else
                leakyrelu_valid      <=      1'b0;
end

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                index   <=      'd0;
        else if(leakyrelu_valid == 1'b1 && ready == 1'b1 && index < INDEX_END)
                index   <=      index + 1'b1;
end


endmodule

layer2_weight_tx

module  layer2_weight_tx(
        // system signals
        input                   sclk                    ,       
        input                   s_rst_n                 ,       
        //
        output  wire    [63:0]  weight_data             ,       
        output  reg             weight_valid            ,       
        output  wire            weight_last             ,       
        input                   ready                          
);

//========================================================================\
// =========== Define Parameter and Internal signals =========== 
//========================================================================/
localparam      INDEX_END       =       'd576                   ;


reg     [ 7:0]                  ch0_data_arr[INDEX_END-1:0]     ;
reg     [ 7:0]                  ch1_data_arr[INDEX_END-1:0]     ;
reg     [ 7:0]                  ch2_data_arr[INDEX_END-1:0]     ;
reg     [ 7:0]                  ch3_data_arr[INDEX_END-1:0]     ;
reg     [ 7:0]                  ch4_data_arr[INDEX_END-1:0]     ;
reg     [ 7:0]                  ch5_data_arr[INDEX_END-1:0]     ;
reg     [ 7:0]                  ch6_data_arr[INDEX_END-1:0]     ;
reg     [ 7:0]                  ch7_data_arr[INDEX_END-1:0]     ;
reg     [ 9:0]                  index                           ;       


//=============================================================================
//**************    Main Code   **************
//=============================================================================
initial $readmemh("./txt/layer2_weight_ch0.txt", ch0_data_arr);
initial $readmemh("./txt/layer2_weight_ch1.txt", ch1_data_arr);
initial $readmemh("./txt/layer2_weight_ch2.txt", ch2_data_arr);
initial $readmemh("./txt/layer2_weight_ch3.txt", ch3_data_arr);
initial $readmemh("./txt/layer2_weight_ch4.txt", ch4_data_arr);
initial $readmemh("./txt/layer2_weight_ch5.txt", ch5_data_arr);
initial $readmemh("./txt/layer2_weight_ch6.txt", ch6_data_arr);
initial $readmemh("./txt/layer2_weight_ch7.txt", ch7_data_arr);

assign  weight_data  =       {ch7_data_arr[index],
                              ch6_data_arr[index],
                              ch5_data_arr[index],  
                              ch4_data_arr[index],
                              ch3_data_arr[index],
                              ch2_data_arr[index],
                              ch1_data_arr[index],
                              ch0_data_arr[index]};

assign  weight_last  =       (index == (INDEX_END-1)) ? 1'b1 : 1'b0;

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                weight_valid      <=      1'b0;
        else if(index < (INDEX_END-1))
                weight_valid      <=      1'b1;
        else
                weight_valid      <=      1'b0;
end

always  @(posedge sclk or negedge s_rst_n) begin
        if(s_rst_n == 1'b0)
                index   <=      'd0;
        else if(weight_valid == 1'b1 && ready == 1'b1 && index < INDEX_END)
                index   <=      index + 1'b1;
end


endmodule