A NAIVE CONV2D CPU IMPL

Published: 2023-04-07 14:34:31  Author: ijpq
/******************************************************************************

Welcome to GDB Online.
GDB online is an online compiler and debugger tool for C, C++, Python, Java, PHP, Ruby, Perl,
C#, OCaml, VB, Swift, Pascal, Fortran, Haskell, Objective-C, Assembly, HTML, CSS, JS, SQLite, Prolog.
Code, Compile, Run and Debug online from anywhere in world.

*******************************************************************************/
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <vector>

using namespace std;

// Problem configuration (all values are element counts).
//   input feature map : N  x IC x IH x IW
//   filter bank       : OC x IC x FH x FW
//   output feature map: N  x OC x OH x OW
size_t N = 1, IC = 3, IH = 2, IW = 2;
size_t OC = 2, FH = 2, FW = 2;
size_t stride = 1, pad = 1;

// Output extent of a strided, zero-padded convolution along each axis:
//   out = (in + 2 * pad - filter) / stride + 1
// Height and width are computed independently so that non-square inputs or
// filters get the correct output width (OW used to be a copy of OH).
// NOTE: these are size_t, so the expression wraps if in + 2*pad < filter;
// the configuration above must keep the filter no larger than the padded input.
size_t OH = (IH + 2 * pad - FH) / stride + 1;
size_t OW = (IW + 2 * pad - FW) / stride + 1;


/**
 * Fill a dense a x b x c x d buffer (row-major, last index fastest) with the
 * repeating pattern 1, 2, 3, 0, 1, 2, 3, 0, ... in storage order.
 *
 * Nothing is written when any extent is zero or negative.
 */
void assign(float *ptr, int a, int b, int c, int d) {
    if (a <= 0 || b <= 0 || c <= 0 || d <= 0) {
        return;
    }
    // Walking the four indices of a contiguous row-major tensor in nested
    // order visits the elements in plain storage order, so one running
    // counter is enough; element k receives (k + 1) % 4.
    const int total = a * b * c * d;
    for (int idx = 0; idx < total; ++idx) {
        ptr[idx] = static_cast<float>((idx + 1) % 4);
    }
}

/**
 * Print a dense a x b x c x d buffer (row-major) to stdout.
 *
 * Each run of d consecutive values (one innermost row) is printed on its own
 * line as "%.2f " entries, and a single blank line follows the whole dump.
 */
void printout(float *ptr, int a, int b, int c, int d) {
    // Number of innermost rows; zero as soon as any outer extent is
    // non-positive, exactly like the bounds of nested loops would give.
    const int rows = (a > 0 && b > 0 && c > 0) ? a * b * c : 0;
    for (int row = 0; row < rows; ++row) {
        for (int col = 0; col < d; ++col) {
            printf("%.2f ", ptr[row * d + col]);
        }
        printf("\n");
    }
    printf("\n");
}

// NxICxIHxIW, OCxICxFHxFW, NxOCxOHxOW
void
direct_conv (float *feat, float *weight, float *out)
{
    for (int n = 0;  n < N ; n++) {
        for (int oc =0; oc < OC; oc++) {
            for (int oh = 0 ; oh < OH; oh++) {
                for (int ow = 0;ow < OW; ow++) {
                    float value = 0;
                    auto ih_start = oh*stride-pad;
                    auto iw_start = ow*stride-pad;
                    for (int ic = 0; ic < IC; ic++) {
                        for (int fh = 0; fh < FH; fh++) {
                            for (int fw = 0; fw <  FW; fw++) {
                                
                                float f = 0;
                                if (ih_start+fh >= 0 && iw_start +fw>=0 && ih_start+fh < IH && iw_start+fw < IW) {
                                    f = *(feat + n *(IC*(IH)*(IW)) + ic*((IH) * (IW)) + (ih_start + fh)*(IW) + iw_start+fw);
                                }
                                auto w = *(weight + oc*(IC*FH*FW) +ic*(FH*FW) + fh*FW+fw);
                                value += f*w;
                                printf("feat@{%d,%d,%d,%d}, val:%f\n", n,ic,ih_start+fh,iw_start+fw,f);
                                printf("weight@{%d,%d,%d,%d}, val:%f\n", oc,ic,fh,fw,w);
                            }
                        }
                    }
                    *(out + n*(OC*OH*OW) + oc*(OH*OW) + oh*OW + ow) = value;
                    
                }
            }
        }
    }
}

// Demo driver: fills a small input and filter bank with a deterministic
// pattern, prints both, runs the convolution and prints the result.
int
main ()
{
  // std::vector owns each buffer, so the memory is released on every exit
  // path (the previous malloc'd buffers were never freed or null-checked).
  std::vector<float> feat (N * IC * IH * IW);
  assign (feat.data (), N, IC, IH, IW);
  printf ("feat\n");
  printout (feat.data (), N, IC, IH, IW);
  printf ("\n");

  printf ("weight\n");
  std::vector<float> weight (OC * IC * FH * FW);
  assign (weight.data (), OC, IC, FH, FW);
  printout (weight.data (), OC, IC, FH, FW);
  printf ("\n");

  // direct_conv writes every output element; value-initialisation is harmless.
  std::vector<float> out (N * OC * OH * OW);
  direct_conv (feat.data (), weight.data (), out.data ());
  printf ("out\n");
  printout (out.data (), N, OC, OH, OW);
  printf ("\n");
  return 0;
}