OneAPI 矩阵乘法实践

发布时间 2023-12-02 00:33:56作者: XZY3031

OneAPI 矩阵乘法

OneAPI 是一个由英特尔(Intel)推动的跨架构编程模型和开发工具的倡议。该倡议的目标是使开发人员能够在不同类型的处理器架构上编写性能高效的代码,包括 CPU、GPU、FPGA 等。OneAPI 的设计理念是实现统一的编程模型,以便开发人员能够更容易地利用异构计算资源,而不必为每个硬件架构单独编写不同的代码。

在SYCL编程模型中,parallel_for 函数用于指示SYCL运行时在指定的工作组内的可用工作项之间分发计算。每个工作组内的每个工作项都独立且并发地执行 lambda 表达式。

Work Group:

  • 概念: 将工作组视为并行的工作单元集合。工作组中的每个工作单元(工作项)都可以执行独立的计算。
  • 实际执行: 工作组内的工作项可以并行执行,每个工作项独立完成其任务。

parallel_for:

  • 概念: 类似于告诉整个工作组同时执行特定任务。工作组内的每个工作项执行其任务的一部分,但它们在整个工作组范围内协同工作。
  • 实际执行: 每个工作组内的工作项独立执行,并共同完成任务。

nd_range 和 work_group 在代码中的使用及参数解释:

  • nd_range: 定义了并行计算的全局和局部维度。它指定了全局工作项的范围和局部工作组的大小。
    • 在代码片段中使用了 nd_range<2>{{M, N}, {1, tile_size}},其中 {M, N} 定义了二维全局工作项范围,{1, tile_size} 指定了局部工作组的大小。
  • work_group: 是并行性的概念单元,将多个工作项组合在一起。它代表一组可以在本地上下文中协作和共享数据的工作项。
    • 工作组的大小在 nd_range 中定义。在提供的代码中,{1, tile_size} 指定了工作组在第二维上的大小。

全局 ND-range 和 局部 ND-range

  1. 全局 ND-range:
    • 全局 ND-range 表示整个问题的分布范围,即在全局范围内有多少个工作项。
    • 通常,全局 ND-range 用于划分任务的总体规模,例如,表示计算矩阵乘法的总大小,其中包含了所有需要计算的元素。
    • 在二维矩阵乘法的例子中,全局 ND-range 的第一个维度可能表示矩阵的行数,第二个维度表示列数。
  2. 局部 ND-range:
    • 局部 ND-range 表示每个工作组(work-group)的分布范围,即在每个工作组内有多少个工作项。
    • 工作组是由多个工作项组成的,这些工作项可以在同一设备上的计算单元(例如处理器核心)上并行执行。
    • 局部 ND-range 的设置通常与硬件架构和任务的性质有关。例如,可以选择一定数量的工作项组成一个工作组,以适应硬件上的并行性。
    • 局部 ND-range 通常用于定义每个工作组的工作项的分布,以便它们可以协同工作并共享局部内存

代码部分:

#include <chrono>
#include <iostream>
#include <sycl/sycl.hpp>

#define random_float() (rand() / double(RAND_MAX))

using namespace sycl;

const int tileX = 8;
const int tileY = 8;


double matrix_multiply_parallel(float *A, float *B, float *C, 
                  int M, int N, int K, 
                  int BLOCK, sycl::queue &q) {

  auto grid_rows = M / tileY;
  auto grid_cols = N / tileX;
  auto local_ndrange  = range<2>(BLOCK, BLOCK);
  auto global_ndrange = range<2>(grid_rows, grid_cols);

  double duration = 0.0f;

  auto e = q.submit([&](sycl::handler &h) {
      h.parallel_for(
          sycl::nd_range<2>(global_ndrange, local_ndrange), [=](sycl::nd_item<2> index) {

              int row = tileY * index.get_global_id(0);
              int col = tileX * index.get_global_id(1);

              float sum[tileY][tileX] = {0.0f};
              float subA[tileY] = {0.0f};
              float subB[tileX] = {0.0f};

              for (int k = 0; k < N; k++) {

                for(int m = 0; m < tileY; m++) {
                    subA[m] = A[(row + m) * N + k];
                } 

                for(int p = 0; p < tileX; p++) {
                    subB[p] = B[k * N + p + col];
                } 

                for (int m = 0; m < tileY; m++) {
                  for (int p = 0; p < tileX; p++) {
                    sum[m][p] += subA[m] * subB[p];
                  }
                }
              }

              for (int m = 0; m < tileY; m++) {
                for (int p = 0; p < tileX; p++) {
                  C[(row + m) * N + col + p] = sum[m][p];
                }
              }

          });
    });
    e.wait();

    duration += (e.get_profiling_info<info::event_profiling::command_end>() -
    e.get_profiling_info<info::event_profiling::command_start>()) /1000.0f/1000.0f;

    return(duration);
}


double matrix_multiply_normal(float *cA, float *cB, float *cC, int M, int N, int K) {
    
    double duration = 0.0;
    std::chrono::high_resolution_clock::time_point s, e;
    s = std::chrono::high_resolution_clock::now();
    for(int i = 0; i < M; i++) {
        for(int j = 0; j < N; j++) {
            float sum = 0.0f;
            for(int k = 0; k < K; k++) {
                sum +=  cA[i * K + k] * cB[k * N  + j];
            }
            cC[i * N + j] = sum;
        }
    }
    e = std::chrono::high_resolution_clock::now();
    duration = std::chrono::duration<float, std::milli>(e - s).count();
    return(duration);
}

int verify(float *normal_res, float *parallel_res, int length){
    int err = 0;
    for(int i = 0; i < length; i++) {
       if( fabs(normal_res[i] - parallel_res[i]) > 1e-3) {
          err++;
          printf("\n%lf, %lf, %d %lf", normal_res[i], parallel_res[i], i, fabs(normal_res[i]-parallel_res[i]));
       } 
    }
    return(err);
}

int gemm(const int M, 
         const int N, 
         const int K, 
         const int block_size,
         const int iterations, 
         sycl::queue &q) {

  std::cout << "Problem size: c(" << M << "," <<  N << ") ="
       << " a(" << M << "," << K << ") *" 
       << " b(" << K << "," << N << ")\n";

  auto A = malloc_shared<float>(M * K, q);
  auto B = malloc_shared<float>(K * N, q);
  auto C = malloc_shared<float>(M * N, q);
  auto C_host = malloc_host<float>(M * N, q);

  for(int i=0; i < M * K; i++) {
      A[i] = random_float();
  }

  for(int i=0; i < K * N; i++) {
      B[i] = random_float();
  }

  for(int i=0; i < M * N; i++) {
      C[i] = 0.0f;
      C_host[i] = 0.0f;
  }

  double flopsPerMatrixMul
      = 2.0 * static_cast<double>(M) * static_cast<double>(N) * static_cast<double>(K);

  double duration_parallel = 0.0f;
  double duration_normal = 0.0f;

  int warmup = 10;
  for (int run = 0; run < iterations + warmup; run++) {
    float duration = matrix_multiply_parallel(A, B, C, M, N, K, block_size, q);
    if(run >= warmup) duration_parallel += duration;
  }
  duration_parallel = duration_parallel / iterations;

  warmup = 2;
  for(int run = 0; run < iterations/2 + warmup; run++) {
      float duration = matrix_multiply_normal(A, B, C_host, M, N, K);
      if(run >= warmup) duration_normal += duration;
  }
  duration_normal = duration_normal / iterations/2;

  int errCode = 0;
  if(errCode > 0) printf("\nThere are %d errors\n", errCode);

  printf("\nGEMM size M = %d, N = %d, K = %d", M, N, K);
  printf("\nWork-Group size = %d * %d, tile_X = %d, tile_Y = %d", block_size, block_size, tileX, tileY);
  printf("\nPerformance Flops = %lf, \n" 
          "Parallel Computation Time = %lf (ms); \n"
          "Normal Computaiton Time = %lf (ms); \n"
          "Speedup = %lf\n", 
          flopsPerMatrixMul, duration_parallel, duration_normal, duration_normal/duration_parallel);

  free(A, q);
  free(B, q);
  free(C, q);
  free(C_host, q);

  return(errCode);
}

int main() {
    
  auto propList = sycl::property_list {sycl::property::queue::enable_profiling()};
  queue my_queue(default_selector_v , propList);

  int errCode = gemm(512, 512, 512, 
                     4,             
                     10,               
                     my_queue);

  return(errCode);
}

参考:

  1. https://github.com/pengzhao-intel/oneAPI_course/blob/main/code/gemm_basic.cpp
  2. Reinders 等 - 2021 - Data Parallel C++ Mastering DPC++ for Programming