机器学习编译(二):张量程序抽象

发布时间 2023-07-22 12:22:51作者: machine_gun_lin

元张量函数 (primitive tensor function)

一个模型的执行包含 tensor 和 primitive tensor function,后者是定义 tensor 之间的计算步骤的函数(通常也叫 op,不过这里的范围更广,还包括 Module 等)。

../_images/primitive_tensor_func.png

上面的 linear、add、relu、softmax 都是元张量函数。

框架通常都会实现常见的计算操作,这些操作通常也会有 C++ 实现(更加高效):

../_images/tensor_func_abstractions.png

机器学习框架对模型编译的时候会把一些计算操作做一些优化,比如下面的例子,对于 0 ~ 127 的计算可以由单核处理器完成,但是考虑到计算之间的独立性,可以由 4 核处理器并行计算,这样大概能得到 4x 的计算速度的提升:

../_images/tensor_func_transformation.png

张量程序抽象

为了更有效地变换元张量函数,需要对这些函数做一些抽象,这些抽象包括:

  • 存储数据的多维数组:存储输入、输出、中间结果
  • 驱动张量计算的循环嵌套
  • 计算部分本身的语句:在循环的对应位置执行相应的计算

../_images/tensor_func_elements.png

使用张量程序抽象,可以把下面的左图变为右图:

../_images/tensor_func_seq_transform.png

这里只需要明确 tensor、循环、计算操作,就从一个单核的计算操作抽象成了等价但更高效的 4 核并行计算操作。