[MLIR] 转换流程详解(以Toy接入为例)

发布时间 2023-04-23 20:56:50作者: 多一些不为什么的坚持

参考资料:

[MLIR] 转换流程详解(以Toy接入为例) - 知乎 (zhihu.com)

在本文中我们使用 toy 语言接入 MLIR,最终转化为 LLVM IR (或目标代码)为例,来讲解 MLIR 的转换流程。具体的流程如下:

.toy 源文件 → AST → MLIRGen(遍历AST生成MLIR表达式) → Transformation(变形消除冗余) → Lowering → LLVM IR / JIT 编译引擎

1. Toy接入MLIR

本节对应 Chapter 2: Emitting Basic MLIR - MLIR (llvm.org)

1.1 Toy源码和AST

def multiply_transpose(a, b){
    return transpose(a) * transpose(b);
}
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<2, 3> = [1, 2, 3, 4, 5, 6];
  var c = multiply_transpose(a, b);
  print(c);
}
原toy教程的第一节生成ast的指令如下:
cd llvm-project/build/bin
./toyc-ch1 ../../mlir/test/Examples/Toy/Ch1/ast.toy --emit=ast

编译得到的AST如下

Module:
  Function 
    Proto 'multiply_transpose' @test/Examples/Toy/Ch1/ast.toy:4:1'
    Params: [a, b]
    Block {
      Return
        BinOp: * @test/Examples/Toy/Ch1/ast.toy:5:25
          Call 'transpose' [ @test/Examples/Toy/Ch1/ast.toy:5:10
            var: a @test/Examples/Toy/Ch1/ast.toy:5:20
          ]
          Call 'transpose' [ @test/Examples/Toy/Ch1/ast.toy:5:25
            var: b @test/Examples/Toy/Ch1/ast.toy:5:35
          ]
    } // Block
    ... // main函数的ast未写出

1.2 生成(未优化)MLIR表达式

MLIRGen 模块会遍历 AST ,递归调用子函数,构建 operation。operation 是 dialect 中重要的组成元素,用来表示 dialect 中的某个操作,一个 dialect 中可以有很多的 operation。

mlir::Value mlirGen(CallExperAST &call)
{
    llvm::StringRef callee = call.getCallee();
    auto location = loc(call.loc()); 

    SmallVector<mlir::Value, 4> operands;
    for(auto &expr:call.getArgs()){
        auto arg = mlirGen(*expr); // 递归调用
        if(!arg)
            return nullptr;
        operands.push_back(arg);
    }

    if(callee == "transpose"){
        if(call.getArgs().size() != 1){
            emitError(location, "MLIR codegen encountered an error: toy.transpose does not accept multiple arguments");
            return nullptr;
        }
        return bulider.creater<TransposeOp>(location, operands[0]);
    }
    ...
}

创建好的节点 operation 还没有输入参数等定义,Toy Dialect 模块负责定义各种操作和分析。(Toy Dialect 继承自 mlir::Dialect,并注册了属性、操作和数据类型等)

 Toy Dialect 模块的创建 见 MLIR初识 —— Dialect及Operation详解的 "3. 创建新的dialect"

// TransposeOp
void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value value){
    state.addTypes(UnrankedTensorType::get(bulider.getF64Type()));
    state.addOperands(value);
}

根据 ast 中的节点,生成的一系列 operations 最终组成 MLIR 表达式。(去除了loc的信息)

原toy教程的第一节生成MLIR 表达式的指令如下:
cd llvm-project/build/bin
./toyc-ch2 ../../mlir/test/Examples/Toy/Ch2/codegen.toy -emit=mlir -mlir-print-debuginfo

# 由toy ast 生成 MLIR 表达式
module{
  func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
    %0 = "toy.transpose"(%arg0): (tensor<*xf64>) -> tensor<*xf64>
    %1 = "toy.transpose"(%arg1): (tensor<*xf64>) -> tensor<*xf64>
    %2 = "toy.mul"(%0, %1): (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
    "toy.return"(%2): (tensor<*xf64>) -> ()
  }
  func @main(){
    %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> 
    %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> 
    %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> 
    %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64> 
    %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> 
    "toy.print"(%4) : (tensor<*x64>) -> ()
    "toy.return"() : () -> ()
  }
}

2. MLIR 表达式变形

本节的 2.1、2.2部分对应 Chapter 3: High-level Language-Specific Analysis and Transformation - MLIR (llvm.org)
本节的2.3部分对应 Chapter 4: Enabling Generic Transformation with Interfaces - MLIR (llvm.org) 

我们发现生成的 MLIR 表达式往往存在冗余的操作,为了提升程序性能就需要对表达式进行转换变形(Transformation 后的 MLIR表达式又可称为 Toy Dialect IR。)。MLIR 提供以下两种方式进行模式匹配转换:

其一,使用 C++ 手动编写代码进行表达式的匹配与重写

其二,使用基于规则的模式匹配和重写的声明式重写规则(DRR)进行,但该方法要求使用ODS定义操作。

2.1 手动编写代码进行表达式的匹配与重写

对于同一个变量,连续进行多次转置操作,必然存在冗余操作。本节以 "消除两个具有相互抵消效果的转置序列" 为例,说明第一种模式匹配转换方法。(Optimize Transpose using C++ style pattern-match and rewrite)

// toy 代码
def transpose_transpose(x) {
  return transpose(transpose(x));
}
// 未引入优化生成的 MLIR 表达式
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
  %1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
  toy.return %1 : tensor<*xf64>
}

1.第一步:直接使用 C++ 写出匹配和重写的代码

 下面这段代码位于在 ToyCombine.cpp 中,默认位置在 llvm-project/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp

// Fold transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  // 匹配该IR中的所有 toy.transpose
  /// mlir使用"benefit"对patterns进行排序,并按profitability顺序处理
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  // 尝试匹配并重写
  mlir::LogicalResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // 获取当前Op(一个TransposeOp)的操作数
    mlir::Value transposeInput = op.getOperand();
    // 获取当前Op的操作数对应的Op
    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
    // 如果当前Op的操作数对应的Op不是Transpose,重写失败
    if (!transposeInputOp)
      return failure();

    // 反之,当前Op就是TransposeOp
    // transposeInputOp.getOperand()就是x
    rewriter.replaceOp(op, {transposeInputOp.getOperand()});
    return success();
  }
};

 

2. 第二步:将自定义的匹配和重写模式登记为 canonicalization 模式,使得后续可以使用它 

下面这段代码位于 toyc.cpp 中,默认位置为 llvm-project/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp

void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  // SimplifyRedundantTranspose 就是第一步中定义的结构体(类)
  results.insert<SimplifyRedundantTranspose>(context);
}

 

3. 第三步:在Ops.td中设置相应选项

下面这段代码位于 Ops.td 中,默认位置为llvm-project/mlir/examples/toy/Ch3/include/toy/Ops.td

def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
  // MLIR 在优化代码时较为保守,可能会保留一些无效操作
  // 设置[NoSideEffect] 可解决这一问题
...
  // 确保启用规范化框架,应用 canonicalization pass
  let hasCanonicalizer = 1;
...
}

 

 4. 第四步:更新主文件以添加 optimization pipeline

 下面这段代码位于 toyc.cpp 中,默认位置在 llvm-project/mlir/examples/toy/Ch3/toyc.cpp

if (enableOpt) {// enableOpt 是从命令行输入的编译选项
  // 使用 PassManger 模块添加优化一道优化工艺
  mlir::PassManager pm(&context);
  applyPassManagerCLOptions(pm);
  // createCanonicalizerPass 创建并使用规范化框架
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  // 运行定义好的 canonicalizer 来优化 MLIR 表达式
  if (mlir::failed(pm.run(*module)))
      return 4;
}

5. 最后执行 toyc-ch3 ../../test/Examples/Toy/Ch3/transpose_transpose.toy -emit=mlir -opt,得到优化后的 Toy Dialect IR (MLIR表达式)如下

toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
  toy.return %arg0 : tensor<*xf64>
}

 

2.2 采用 DDR 自动生成匹配和重写函数

目前这个地方我认为太过于复杂

2.3 通用的转换接口

目前这个地方我认为太过于复杂

3. Lowering 过程

本节的3.1部分对应 Chapter 5: Partial Lowering to Lower-Level Dialects for Optimization - MLIR (llvm.org)
本节的3.2部分对应 Chapter 6: Lowering to LLVM and CodeGeneration - MLIR

在编译器一系列转换程序的过程中,越来越多的高层次的简明信息被打散,转换为低层次的细碎指令,这个过程被称为代码表示递降 lowerinng ,与之相反的过程被称为代码表示递升raising 。raising远比lowering困难,因为需要在庞杂的细节中找出宏观脉络。

lowering 过程中越晚执行的转换越有结构劣势,因为缺乏高层次信息。

lowering 主要是为了更贴近硬件做代码生成和做硬件相关的优化。

每次转换遍历(pass) 都需要保持原子性,在其内部可能会临时违反源程序语义,但在每个转换遍历之后,中间表示应该是正确的。编译器依赖每个遍历之后的中间表示验证 (validation) 来保证正确性。 在保证转换的正确性之后,才可进行优化。

3.1 从 MLIR 表达式进行部分 Lowering

MLIR 中有许多不同的 Dialect,lowering 过程其实就是在各种 Dialect 之间转化,而 MLIR 提供了一套统一的 DialectConversion 框架来实现不同 Dialect 之间的转化。

 1. 要使用 DialectConversion 框架需要 Three Components(组件)

  1. ConversionTarget ConversionTarget是DialectConversion框架的一个重要组件,它定义了需要转换的dialect,以及dialect之间的转换规则。通常,ConversionTarget会在编译器的初始化阶段进行创建,并在整个编译过程中使用。它包含了Dialect之间的转换规则,以及一些用于验证和调试的工具,如类型检查器、断言等。
  2. RewritePattern RewritePattern是一个用于匹配和重写操作的规则。它定义了在源dialect中匹配的操作,以及在目标dialect中的重写规则。每个RewritePattern通常由一个PatternMatcher和一个PatternRewriter组成。PatternMatcher定义了匹配操作的模式,而PatternRewriter定义了将匹配的操作重写为目标dialect中的操作的规则。
  3. TypeConverter TypeConverter是DialectConversion框架的另一个重要组件,它负责将源dialect中的类型映射到目标dialect中的类型。TypeConverter通常包含一组类型转换规则,用于将源dialect中的类型转换为目标dialect中的类型。在类型转换过程中,TypeConverter还可以执行一些其他的操作,如创建新类型、插入类型转换指令等。 这三个组件共同构成了DialectConversion框架,可以用于将不同dialect之间的操作进行转换。使用DialectConversion框架可以使编译器更加灵活,能够处理各种类型的dialect,并支持自定义的dialect之间的转换规则。

2. DialectConversion 框架的转换有 Tow Modes
(1)Partial: Not all input operations have to be legalized to the target 当前 Dialect 中某些 operation 在 lowering 中先进行保留(保留部分之前的信息)
(2)Full: All input operations have to be legalized to the target 当前 Dialect 中全部 operation 在 lowering 中全部去除(类似转换到 LLVM IR)

本节标题的部分lowering 意味着:从一个高抽象级别的 Dialect 到一个低抽象级别的 Dialect 过程中,可以只 lowering 其中一部分 operation,剩下的 operation 只需要升级与其他 operation 共存。现在以对 transformation 后的 MLIR 表达式进行 lowering为例:

// toy 源码
def multiply_transpose(a, b){
    return transpose(a) * transpose(b);
}
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<2, 3> = [1, 2, 3, 4, 5, 6];
  var c = multiply_transpose(a, b);
  print(c);
}
// transformation 后的 MLIR 表达式
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %2 = toy.mul %1, %1 : tensor<3x2xf64>
  toy.print %2 : tensor<3x2xf64>
  toy.return
}

 

 后面的太复杂,稍后看.....

3.2 混合 Dialect 表达式 Lowering 到 LLVM IR

  后面的太复杂,稍后看.....

总结

 本文介绍的 Toy 接入 MLIR 流程本质上还是高级语言的转换流程,但目前 MLIR 在人工智能领域应用较热,二者的转换前端区别较大,一个是抽象语法树(AST),一个是计算图IR(Computation Graph IR)。下图是以 Tensorflow 为例的转换流程。具体的流程为可参考 Codegen Dialect Overview - MLIR - LLVM Discussion Forums

 

 

 

 

 

 

 

 

 

 

 

1111111111111111