Leon's Blog

Sharing some interesting tech


Transform Dialect Tutorial 1


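The Transform dialect is a scheduling DSL built on the MLIR infrastructure. A single MLIR file holds two things. The computation (the payload IR) is written in standard IR. The schedule (the schedule IR) is written in transform IR. The transform-interpreter pass then applies the schedule to the payload.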
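The payload/schedule split can be illustrated outside MLIR with a toy C++ analogy. This is not MLIR code, and every name below is hypothetical. The "payload" is a per-element computation. Two different "schedules" visit the elements in different loop orders, and both produce the same result. The point of the sketch is that a schedule changes how something is computed, not what is computed.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Toy "payload": what to compute for element (i, j). Hypothetical, not MLIR API.
using Payload = std::function<float(std::size_t, std::size_t)>;

// Toy schedule 1: plain row-major traversal of an n x n result.
std::vector<float> runRowMajor(const Payload &f, std::size_t n) {
  std::vector<float> out(n * n);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      out[i * n + j] = f(i, j);
  return out;
}

// Toy schedule 2: tiled traversal with ti x tj tiles (edge tiles are clipped).
std::vector<float> runTiled(const Payload &f, std::size_t n, std::size_t ti,
                            std::size_t tj) {
  std::vector<float> out(n * n);
  for (std::size_t i0 = 0; i0 < n; i0 += ti)
    for (std::size_t j0 = 0; j0 < n; j0 += tj)
      for (std::size_t i = i0; i < i0 + ti && i < n; ++i)
        for (std::size_t j = j0; j < j0 + tj && j < n; ++j)
          out[i * n + j] = f(i, j);
  return out;
}
```

Swapping `runRowMajor` for `runTiled` leaves the payload lambda untouched. The transform dialect aims for the same property at the IR level.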

Motivation of Transform Dialect

Surveying mature compilers such as Halide, TVM and TC reveals the following common patterns:

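  1. Schedule Representation: the optimization recipe is described as a structured collection of metadata
  2. Declarative Specification: configuration states the desired target state rather than the concrete steps
  3. Multi-Versioning: several optimized variants are generated for different hardware or scenarios
  4. Runtime Dispatch: a dynamic decision mechanism selects the best variant at run time
  5. Vertical Sequencing: optimizations are composed in depth within a single functional domain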

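MLIR also aims to give users a TVM-style separation of computation and schedule. With it, users can define their own scheduling and optimization strategy for a given computation. In MLIR everything is an op, and that includes scheduling steps. Those steps are collected as ops in the transform dialect, which is where the dialect comes from.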

The rest of this tutorial is based mainly on the official Transform dialect tutorial (transform tutorial).

Chapter 1: Building a pipeline with transform ops

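The transform dialect provides a rich set of built-in schedule ops. Because computation and schedule are separated, the user defines both a payload IR and a schedule IR. The official tutorial's example schedules a matmul, an elementwise add and a ReLU using transform ops.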

  • payload ir (defines the computation in detail):

    // Original function to optimize.
    func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                       %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                       -> tensor<512x512xf32> {
      // Matrix-matrix multiplication.
      %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                              outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

      // Elementwise addition.
      %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
        ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
        outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

      // Elementwise max with 0 (ReLU).
      %c0f = arith.constant 0.0 : f32
      %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
        ins(%biased, %c0f : tensor<512x512xf32>, f32)
        outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
      func.return %relued : tensor<512x512xf32>
    }

    The payload IR is written directly with standard MLIR dialects such as linalg.

  • schedule ir (scheduling primitives):

    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(
          %arg0: !transform.any_op,
          %arg1: !transform.op<"linalg.matmul">,
          %arg2: !transform.op<"linalg.elemwise_binary">) {
        // Since the %arg2 handle is associated with both elementwise operations,
        // we need to split it into two handles so we can target only the second
        // elementwise operation.
        %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
            -> (!transform.any_op, !transform.any_op)

        // The actual tiling transformation takes tile sizes as attributes. It produces a
        // handle to the loop generated during tiling.
        %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32]
            : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

        // We can now fuse the other operations into the loop. Here, we fuse
        // operations one-by-one. This requires the operation that is being fused
        // to define the value used within the loop, so the order of such fusions
        // is important. We could also use "transform.merge_handles" to obtain
        // a single handle to all operations and give it to `fuse_into_containing_op`
        // that would take care of the ordering in this case.
        %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop
            : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
        %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2
            : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op)

        // Tile again to get the desired size. Note that this time this tiles the
        // "add" operation and fuses matmul into the loop, but doesn't affect the
        // "max" operation. This illustrates the precise targeting with the transform
        // dialect. Otherwise, it is difficult to differentiate "add" and "max", both
        // of which have the same kind.
        %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
            : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
        %matmul_fused_2, %loop_second_2 =
            transform.structured.fuse_into_containing_op %matmul_fused into %loop_second
            : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

        // // Since outlining is currently only implemented for region-holding operations
        // // such as loops, use tiling to size 1 to materialize the outer loop that is
        // // going to be outlined.
        // %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1]
        //     : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
        // %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third
        //     : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
        // %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
        //     : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)

        transform.yield
      }
    }
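The semantics of the payload @fc_relu can be written as a plain C++ loop nest for reference. This sketch makes a few assumptions: matrices are row-major `std::vector<float>`, `n` is a small parameter instead of 512, and the NaN behaviour of max is ignored. One detail is easy to miss. `linalg.matmul` accumulates into its `outs` operand, so the initial contents of `%output` contribute to the result. The two `linalg.elemwise_binary` ops, in contrast, simply overwrite their `outs` operand.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

using Matrix = std::vector<float>;  // row-major n x n (assumption of this sketch)

// Reference semantics of @fc_relu for n x n inputs (the payload uses n = 512).
Matrix fcReluReference(const Matrix &lhs, const Matrix &rhs, const Matrix &bias,
                       const Matrix &output, std::size_t n) {
  // linalg.matmul accumulates into its `outs` operand: C = output + lhs * rhs.
  Matrix matmul = output;
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t k = 0; k < n; ++k)
        matmul[i * n + j] += lhs[i * n + k] * rhs[k * n + j];

  // elemwise_binary<add> then elemwise_binary<max_signed> with 0.0 (ReLU);
  // both overwrite their `outs` operand instead of reading it.
  Matrix relued(n * n);
  for (std::size_t idx = 0; idx < n * n; ++idx) {
    float biased = matmul[idx] + bias[idx];
    relued[idx] = std::max(biased, 0.0f);
  }
  return relued;
}
```

With a zero-filled `%output` this reduces to the familiar `max(lhs * rhs + bias, 0)`.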

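The complete file is shown below. The RUN lines run mlir-opt with the transform-interpreter pass. Its debug-bind-trailing-args option binds the trailing arguments of @__transform_main to payload ops: %arg1 goes to all linalg.matmul ops and %arg2 to all linalg.elemwise_binary ops. %arg0 is bound to the payload root. canonicalize, cse and symbol-dce then clean up the output, and FileCheck verifies it. The CHECK lines expect an @outlined function and a call to it, so the outlining steps at the end of the schedule must be active (not commented out) for the test to pass.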

// RUN: mlir-opt %s \
// RUN: --pass-pipeline="builtin.module(transform-interpreter{ \
// RUN: debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary},\
// RUN: canonicalize,cse,symbol-dce)" |\
// RUN: FileCheck %s

// ****************************** IMPORTANT NOTE ******************************
//
// If you are changing this file, you may also need to change
// mlir/docs/Tutorials/Transform accordingly.
//
// ****************************************************************************

// Original function to optimize.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                   -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// CHECK: func @outlined
// CHECK: linalg.matmul
// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn<add>}

// CHECK-LABEL: func @fc_relu
// CHECK: scf.forall
// CHECK: scf.forall
// CHECK: %[[SLICE4:.+]] = tensor.extract_slice
// CHECK: %[[SLICE5:.+]] = tensor.extract_slice
// CHECK: %[[SLICE6:.+]] = tensor.extract_slice
// CHECK: %[[SLICE7:.+]] = tensor.extract_slice
// CHECK: %[[SLICE8:.+]] = tensor.extract_slice
// CHECK: func.call @outlined(%[[SLICE4]], %[[SLICE5]], %[[SLICE6]], %[[SLICE7]], %[[SLICE8]])
// CHECK-NOT: linalg.matmul
// CHECK-NOT: linalg.elemwise_binary
// CHECK: scf.forall.in_parallel
// CHECK: linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
// CHECK: scf.forall.in_parallel

// Declaration of the "microkernel" function that we will be targeting.
func.func private @microkernel(
    %lhs: tensor<4x512xf32>,
    %rhs: tensor<512x4xf32>,
    %bias: tensor<4x4xf32>,
    %init: tensor<4x4xf32>,
    %output: tensor<4x4xf32>) -> tensor<4x4xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op,
      %arg1: !transform.op<"linalg.matmul">,
      %arg2: !transform.op<"linalg.elemwise_binary">) {
    // Since the %arg2 handle is associated with both elementwise operations,
    // we need to split it into two handles so we can target only the second
    // elementwise operation.
    %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
        -> (!transform.any_op, !transform.any_op)

    // The actual tiling transformation takes tile sizes as attributes. It produces a
    // handle to the loop generated during tiling.
    %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // We can now fuse the other operations into the loop. Here, we fuse
    // operations one-by-one. This requires the operation that is being fused
    // to define the value used within the loop, so the order of such fusions
    // is important. We could also use "transform.merge_handles" to obtain
    // a single handle to all operations and give it to `fuse_into_containing_op`
    // that would take care of the ordering in this case.
    %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2
        : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Tile again to get the desired size. Note that this time this tiles the
    // "add" operation and fuses matmul into the loop, but doesn't affect the
    // "max" operation. This illustrates the precise targeting with the transform
    // dialect. Otherwise, it is difficult to differentiate "add" and "max", both
    // of which have the same kind.
    %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused_2, %loop_second_2 =
        transform.structured.fuse_into_containing_op %matmul_fused into %loop_second
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Since outlining is currently only implemented for region-holding operations
    // such as loops, use tiling to size 1 to materialize the outer loop that is
    // going to be outlined. These steps produce the @outlined function and the
    // call that the CHECK lines above expect.
    %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
        : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)

    transform.yield
  }
}
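The loop structure this schedule produces before outlining can be mimicked in plain C++ to check that tiling and fusion preserve the result. In this hypothetical sketch, `fcReluTiled` works as follows:

- It tiles the ReLU by [8, 32].
- Inside each tile, it computes the add and the fused matmul in [4, 4] sub-tiles.
- It then applies the ReLU to the whole 8x32 tile. This matches the CHECK lines, where the max `elemwise_binary` sits outside the inner `scf.forall`.

For simplicity, the sketch assumes every tile size divides the dimension it tiles, which holds for 512, 8, 32 and 4. With n = 512, the outer loop runs 64 x 16 iterations and each inner loop runs 2 x 8. Each 4x4 sub-tile reads a 4x512 slice of `%lhs` and a 512x4 slice of `%rhs`. Those are the shapes in the `@microkernel` declaration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<float>;  // row-major n x n (assumption of this sketch)

// Untiled reference: max(output + lhs * rhs + bias, 0).
Matrix fcReluNaive(const Matrix &lhs, const Matrix &rhs, const Matrix &bias,
                   const Matrix &output, std::size_t n) {
  Matrix result(n * n);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      float acc = output[i * n + j];
      for (std::size_t k = 0; k < n; ++k)
        acc += lhs[i * n + k] * rhs[k * n + j];
      result[i * n + j] = std::max(acc + bias[i * n + j], 0.0f);
    }
  return result;
}

// Toy model of the loops after the Chapter 1 schedule (hypothetical helper).
Matrix fcReluTiled(const Matrix &lhs, const Matrix &rhs, const Matrix &bias,
                   const Matrix &output, std::size_t n,
                   std::size_t outerI = 8, std::size_t outerJ = 32,
                   std::size_t innerI = 4, std::size_t innerJ = 4) {
  assert(n % outerI == 0 && n % outerJ == 0);
  assert(outerI % innerI == 0 && outerJ % innerJ == 0);
  Matrix result(n * n);
  for (std::size_t i0 = 0; i0 < n; i0 += outerI)      // outer scf.forall ("max")
    for (std::size_t j0 = 0; j0 < n; j0 += outerJ) {
      std::vector<float> biased(outerI * outerJ);    // "add" result for this tile
      for (std::size_t a0 = 0; a0 < outerI; a0 += innerI)    // inner scf.forall
        for (std::size_t b0 = 0; b0 < outerJ; b0 += innerJ)
          for (std::size_t a = a0; a < a0 + innerI; ++a)
            for (std::size_t b = b0; b < b0 + innerJ; ++b) {
              std::size_t i = i0 + a, j = j0 + b;
              // Fused matmul: reads row i of lhs and column j of rhs.
              float acc = output[i * n + j];
              for (std::size_t k = 0; k < n; ++k)
                acc += lhs[i * n + k] * rhs[k * n + j];
              biased[a * outerJ + b] = acc + bias[i * n + j];
            }
      // "max" was not re-tiled, so it consumes the whole outerI x outerJ tile.
      for (std::size_t a = 0; a < outerI; ++a)
        for (std::size_t b = 0; b < outerJ; ++b)
          result[(i0 + a) * n + (j0 + b)] =
              std::max(biased[a * outerJ + b], 0.0f);
    }
  return result;
}
```

Because each element accumulates over `k` in the same order in both functions, the two results are bit-for-bit identical. They do not merely agree within a tolerance.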

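Chapter 2: Defining a simple custom transform operation

Goal

The ops in the transform dialect often cannot express every scheduling optimization we need. For example, the small tiles produced at the end of Chapter 1 could be replaced by a hand-written microkernel, but the transform dialect has no op that performs this replacement directly. Instead, we can implement a custom op, transform.my.change_call_target, and use it as follows: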

// Rewrite the call target.
transform.my.change_call_target %call, "microkernel" : !transform.any_op

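This retargets every func.call operation associated with the %call handle so that it calls @microkernel. An example of such a handle is the call produced by transform.loop.outline in Chapter 1.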

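In real compiler development, a microkernel is a small, highly optimized kernel that is usually written by hand. The compiler can find ops or outlined calls that such a kernel can replace, then substitute the microkernel to further improve performance.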
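The intended behaviour of change_call_target can be modelled in a few lines of plain C++. This is a toy sketch, not the MLIR API. `ToyOp` and `ToyResult` are hypothetical stand-ins for payload ops and for `DiagnosedSilenceableFailure`. Like the `apply` method shown later in this chapter, the sketch checks and rewrites in a single pass. As a result, calls that come before an offending op in the handle are already retargeted when the error is reported.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for a payload op: its name and, for calls, the callee.
struct ToyOp {
  std::string name;    // e.g. "func.call"
  std::string callee;  // only meaningful when name == "func.call"
};

// Hypothetical result: success, or a (silenceable) error message.
struct ToyResult {
  bool ok;
  std::string message;
};

// Mirrors the loop in ChangeCallTargetOp::apply: for each op in the handle,
// fail if it is not a func.call, otherwise retarget it immediately.
ToyResult changeCallTarget(const std::vector<ToyOp *> &handle,
                           const std::string &newTarget) {
  for (ToyOp *op : handle) {
    if (op->name != "func.call")
      return {false, "only applies to func.call payloads"};
    op->callee = newTarget;
  }
  return {true, ""};
}
```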

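Project structure

A transform extension project is laid out much like an MLIR standalone project, so the MLIR standalone project template is a good reference.

The project structure for Chapter 2 is as follows: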

├── build.sh
├── CMakeLists.txt
├── include
│   ├── CMakeLists.txt
│   ├── MyExtension.h
│   └── MyExtension.td
├── lib
│   ├── CMakeLists.txt
│   └── MyExtension.cpp
├── test
│   ├── invalid.mlir
│   ├── sequence.mlir
│   └── test.sh
└── transform-opt
    ├── CMakeLists.txt
    └── transform-opt.cpp

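Defining an operation for the transform dialect mainly involves the following:

  • Define the op with ODS, which generates the C++ declarations and definitions (.h.inc and .cpp.inc).
  • Implement the interfaces the op needs:
    • TransformOpInterface must be implemented by every transform dialect op. Its main method is apply.
    • MemoryEffectsOpInterface describes the op's memory side effects. Its method to implement is getEffects.
  • Register the op.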
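The registration-time requirement can be pictured with a compile-time toy, and everything in it is hypothetical:

- The two `Toy*Interface` classes stand in for TransformOpInterface and MemoryEffectsOpInterface.
- `toyRegisterTransformOps` mimics how registerTransformOps takes the op list as template arguments.

There is one difference. The real function checks the interfaces when it registers the ops and asserts if one is missing, whereas this sketch uses `static_assert`.

```cpp
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical marker interfaces; these are NOT the MLIR classes.
struct ToyTransformOpInterface {
  virtual ~ToyTransformOpInterface() = default;
  virtual bool apply() = 0;
};
struct ToyMemoryEffectsOpInterface {
  virtual ~ToyMemoryEffectsOpInterface() = default;
  virtual std::vector<std::string> getEffects() const = 0;
};

// A toy op that implements both interfaces.
struct ToyChangeCallTargetOp : ToyTransformOpInterface,
                               ToyMemoryEffectsOpInterface {
  static constexpr const char *name = "transform.my.change_call_target";
  bool apply() override { return true; }
  std::vector<std::string> getEffects() const override {
    return {"onlyReadsHandle(call)", "modifiesPayload"};
  }
};

// Registers op classes by name after checking that each implements both
// required interfaces (checked at compile time in this toy).
template <typename... Ops>
std::vector<std::string> toyRegisterTransformOps() {
  static_assert((std::is_base_of_v<ToyTransformOpInterface, Ops> && ...),
                "op must implement the transform interface");
  static_assert((std::is_base_of_v<ToyMemoryEffectsOpInterface, Ops> && ...),
                "op must implement the memory effects interface");
  return {Ops::name...};
}
```

Registering a class that lacks either base fails to compile. The real check instead fires when the extension is loaded.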

Code walkthrough

The code is listed below in that order, and its comments are quite thorough.

Defining the operation

/include/MyExtension.td

//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines Transform dialect extension operations used in the
// Chapter 2 of the Transform dialect tutorial.
//
//===----------------------------------------------------------------------===//

#ifndef MY_EXTENSION
#define MY_EXTENSION

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Define the new operation. By convention, prefix its name with the name of the dialect
// extension, "my.". The full operation name will be further prefixed with "transform.".
def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
    // Indicate that the operation implements the required TransformOpInterface and
    // MemoryEffectsOpInterface.
    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  // Provide a brief and a full description. It is recommended that the latter describes
  // the effects on the operands and how the operation processes various failure modes.
  let summary = "Changes the callee of a call operation to the specified one";
  let description = [{
    For each `func.call` payload operation associated with the handle, changes its
    callee to be the symbol whose name is provided as an attribute to this operation.

    Generates a silenceable failure if the operand is associated with payload operations
    that are not `func.call`.
    Only reads the operand.
  }];

  // The arguments include the handle to the payload operations and the attribute that
  // specifies the new callee. The handle must implement TransformHandleTypeInterface.
  // We use a string attribute as the symbol may not exist in the transform IR so the
  // verification may fail.
  let arguments = (ins
    TransformHandleTypeInterface:$call,
    StrAttr:$new_target);

  // The results are empty as the transformation does not produce any new payload.
  let results = (outs);

  // Provide nice syntax.
  let assemblyFormat = "$call `,` $new_target attr-dict `:` type($call)";
}

#endif // MY_EXTENSION

The .td definition is written just like any other MLIR op definition. Note the following points:

  • Declaring the interfaces:

    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]
  • Defining the op's arguments:

    // The arguments include the handle to the payload operations and the attribute that
    // specifies the new callee. The handle must implement TransformHandleTypeInterface.
    // We use a string attribute as the symbol may not exist in the transform IR so the
    // verification may fail.
    let arguments = (ins
      TransformHandleTypeInterface:$call,
      StrAttr:$new_target);

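    The key point is that the arguments must include a handle whose type implements TransformHandleTypeInterface, for example !transform.any_op or !transform.op<"func.call">. In other words, the operand is a transform IR value that refers to payload ops and can be manipulated by transform ops.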
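Here is a rough mental model of a handle and its type, as a toy C++ sketch. The names are hypothetical, not MLIR API. The transform state maps each handle to a list of payload ops, similar to what `state.getPayloadOps(getCall())` returns in `apply`. A handle type can reject payload ops it does not describe. For instance, `!transform.op<"linalg.matmul">` only accepts linalg.matmul ops, while `!transform.any_op` accepts anything.

```cpp
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct ToyPayloadOp {
  std::string name;
};

// Toy handle type: no requiredName plays the role of !transform.any_op,
// a set requiredName plays the role of !transform.op<"...">.
struct ToyHandleType {
  std::optional<std::string> requiredName;
  bool checkPayload(const std::vector<const ToyPayloadOp *> &ops) const {
    if (!requiredName)
      return true;
    for (const ToyPayloadOp *op : ops)
      if (op->name != *requiredName)
        return false;
    return true;
  }
};

// Toy transform state: associates handle ids with lists of payload ops.
class ToyTransformState {
public:
  // Returns false (and records nothing) if the type constraint is violated.
  bool setPayloadOps(int handle, const ToyHandleType &type,
                     std::vector<const ToyPayloadOp *> ops) {
    if (!type.checkPayload(ops))
      return false;
    mapping_[handle] = std::move(ops);
    return true;
  }
  const std::vector<const ToyPayloadOp *> &getPayloadOps(int handle) const {
    return mapping_.at(handle);
  }

private:
  std::map<int, std::vector<const ToyPayloadOp *>> mapping_;
};
```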

/include/MyExtension.h

//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines Transform dialect extension operations used in the
// Chapter 2 of the Transform dialect tutorial.
//
//===----------------------------------------------------------------------===//

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"

#define GET_OP_CLASSES
#include "MyExtension.h.inc"

// Registers our Transform dialect extension.
void registerMyExtension(::mlir::DialectRegistry &registry);

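As in any MLIR dialect, this header defines GET_OP_CLASSES to pull in the op class declarations from MyExtension.h.inc. It also declares a function that registers the extension.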

/include/CMakeLists.txt:

# Tell Tablegen to use MyExtension.td as input.
set(LLVM_TARGET_DEFINITIONS MyExtension.td)

# Ask Tablegen to generate op declarations and definitions from ODS.
mlir_tablegen(MyExtension.h.inc -gen-op-decls)
mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)

# Add a CMakeTarget we can depend on to ensure the generation happens before the compilation.
add_public_tablegen_target(MyExtensionCh2IncGen)

Implementing the interfaces required by the operation

The code in /lib/MyExtension.cpp can be split into the following parts:

  • Defining the transform dialect extension:

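    #include "MyExtension.h"

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Dialect/SCF/IR/SCF.h"
    #include "mlir/Dialect/Transform/IR/TransformDialect.h"
    #include "mlir/IR/DialectRegistry.h"

    // Pull in the ODS-generated op definitions (the header includes only the
    // declarations).
    #define GET_OP_CLASSES
    #include "MyExtension.cpp.inc"

    // Define a new transform dialect extension. This uses the CRTP idiom to
    // identify extensions.
    class MyExtension
        : public ::mlir::transform::TransformDialectExtension<MyExtension> {
    public:
      // The TypeID of this extension.
      MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)

      // The extension must derive the base constructor.
      using Base::Base;

      // This function initializes the extension, similarly to `initialize` in
      // dialect definitions. List individual operations and dependent dialects
      // here.
      void init();
    };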
  • Initializing the extension:

    void MyExtension::init() {
      // Similarly to dialects, an extension can declare a dependent dialect. This
      // dialect will be loaded along with the extension and, therefore, along with
      // the Transform dialect. Only declare as dependent the dialects that contain
      // the attributes or types used by transform operations. Do NOT declare as
      // dependent the dialects produced during the transformation.
      // declareDependentDialect<MyDialect>();

      // When transformations are applied, they may produce new operations from
      // previously unloaded dialects. Typically, a pass would need to declare
      // itself dependent on the dialects containing such new operations. To avoid
      // confusion with the dialects the extension itself depends on, the Transform
      // dialects differentiates between:
      //   - dependent dialects, which are used by the transform operations, and
      //   - generated dialects, which contain the entities (attributes, operations,
      //     types) that may be produced by applying the transformation even when
      //     not present in the original payload IR.
      // In the following chapter, we will add operations that generate function
      // calls and structured control flow operations, so let's declare the
      // corresponding dialects as generated.
      declareGeneratedDialect<::mlir::scf::SCFDialect>();
      declareGeneratedDialect<::mlir::func::FuncDialect>();

      // Finally, we register the additional transform operations with the dialect.
      // List all operations generated from ODS. This call will perform additional
      // checks that the operations implement the transform and memory effect
      // interfaces required by the dialect interpreter and assert if they do not.
      registerTransformOps<
    #define GET_OP_LIST
    #include "MyExtension.cpp.inc"
          >();
    }
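    • One key point above is the difference between declareDependentDialect and declareGeneratedDialect. Dependent dialects are those whose attributes or types the transform ops themselves use. Generated dialects are those whose entities may be produced when the transforms are applied. Here, scf and func are dialects that may be generated.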

    • Another key point is that registerTransformOps registers the ODS-defined ops. It also checks that each of them implements the transform and side-effect interfaces, and asserts if one does not.

  • Implementing the interfaces for this specific op:

    • Overriding apply from TransformOpInterface:

      ::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
          // The rewriter that should be used when modifying IR.
          ::mlir::transform::TransformRewriter &rewriter,
          // The list of payload IR entities that will be associated with the
          // transform IR values defined by this transform operation. In this case, it
          // can remain empty as there are no results.
          ::mlir::transform::TransformResults &results,
          // The transform application state. This object can be used to query the
          // current associations between transform IR values and payload IR entities.
          // It can also carry additional user-defined state.
          ::mlir::transform::TransformState &state) {

        // First, we need to obtain the list of payload operations that are associated
        // with the operand handle.
        auto payload = state.getPayloadOps(getCall());

        // Then, we iterate over the list of operands and call the actual IR-mutating
        // function. We also check the preconditions here.
        for (Operation *payloadOp : payload) {
          auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
          if (!call) {
            DiagnosedSilenceableFailure diag =
                emitSilenceableError() << "only applies to func.call payloads";
            diag.attachNote(payloadOp->getLoc()) << "offending payload";
            return diag;
          }

          updateCallee(call, getNewTarget());
        }

        // If everything went well, return success.
        return DiagnosedSilenceableFailure::success();
      }
    • Overriding getEffects from the side-effect interface:

      void mlir::transform::ChangeCallTargetOp::getEffects(
          ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
        // Indicate that the `call` handle is only read by this operation because the
        // associated operation is not erased but rather modified in-place, so the
        // reference to it remains valid.
        onlyReadsHandle(getCallMutable(), effects);

        // Indicate that the payload is modified by this operation.
        modifiesPayload(effects);
      }
  • Registering the whole MyExtension:

    void registerMyExtension(::mlir::DialectRegistry &registry) {
      registry.addExtensions<MyExtension>();
    }
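The split between dependent and generated dialects can be modelled with a toy context. The names are hypothetical, not MLIR API. The timing is a simplifying assumption of this sketch. Dependent dialects are loaded as soon as the extension is applied, which matches the comment in `init()`. Generated dialects are only recorded at that point and are loaded right before the transforms run.

```cpp
#include <set>
#include <string>
#include <vector>

// Toy extension description (hypothetical).
struct ToyExtension {
  std::vector<std::string> dependentDialects;  // used BY the transform ops
  std::vector<std::string> generatedDialects;  // may be PRODUCED by transforms
};

// Toy context tracking which dialects are loaded.
class ToyContext {
public:
  void applyExtension(const ToyExtension &ext) {
    for (const auto &d : ext.dependentDialects)
      loaded_.insert(d);             // needed to parse/verify transform IR
    for (const auto &d : ext.generatedDialects)
      pendingGenerated_.insert(d);   // only needed once payload is rewritten
  }
  void prepareToRunTransforms() {
    loaded_.insert(pendingGenerated_.begin(), pendingGenerated_.end());
  }
  bool isLoaded(const std::string &d) const { return loaded_.count(d) != 0; }

private:
  std::set<std::string> loaded_{"transform"};
  std::set<std::string> pendingGenerated_;
};
```

For this chapter's extension, the call would be `applyExtension({{}, {"scf", "func"}})`: no dependent dialects, with scf and func generated.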

/lib/CMakeLists.txt:

# Outside examples, this should be `add_mlir_library`.
add_mlir_dialect_library(
  # Library called MyExtension.
  MyExtensionCh2

  # Built from the following source files.
  MyExtension.cpp

  # Make includes visible without top-level path.
  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include

  # Make sure ODS declaration and definitions are generated before compiling this.
  DEPENDS
  MyExtensionCh2IncGen

  # Link in the transform dialect, and all generated dialects.
  LINK_LIBS PRIVATE
  MLIRTransformDialect
  MLIRFuncDialect
  MLIRSCFDialect
)

Writing the transform-opt driver

//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the top-level file for the Transform dialect tutorial chapter 2.
//
//===----------------------------------------------------------------------===//

#include "MyExtension.h"

#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include <cstdlib>

namespace test {
void registerTestTransformDialectExtension(mlir::DialectRegistry &);
} // namespace test

int main(int argc, char **argv) {
  // Register all "core" dialects and our transform dialect extension.
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  mlir::registerAllExtensions(registry);
  registerMyExtension(registry);

  // Register transform interpreter pass.
  mlir::transform::registerInterpreterPass();

  // Register a handful of cleanup passes that we can run to make the output IR
  // look nicer.
  mlir::registerCanonicalizerPass();
  mlir::registerCSEPass();
  mlir::registerSymbolDCEPass();

  // Delegate to the MLIR utility for parsing and pass management.
  return mlir::MlirOptMain(argc, argv, "transform-opt-ch2", registry)
                 .succeeded()
             ? EXIT_SUCCESS
             : EXIT_FAILURE;
}

/transform-opt/CMakeLists.txt:

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
set(LIBS
  ${dialect_libs}
  ${conversion_libs}
  MLIRIR
  MLIRMlirOptMain
  MLIRSideEffectInterfaces
  MyExtensionCh2
)
add_llvm_executable(transform-opt transform-opt.cpp)

target_link_libraries(transform-opt PRIVATE ${LIBS})

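This code is no different from an ordinary MLIR opt-style driver. It registers our MyExtension, then a few passes. The key one is registerInterpreterPass, which registers the transform-interpreter pass that applies the schedule IR to the payload IR.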

Top-level CMakeLists.txt:

cmake_minimum_required(VERSION 3.20.0)
project(standalone-dialect LANGUAGES CXX C)

set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)

set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to")

if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
  find_package(MLIR REQUIRED CONFIG)

  message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
  message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")

  set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin)
  set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib)
  set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

  list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
  list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")

  include(TableGen)
  include(AddLLVM)
  include(AddMLIR)
  include(HandleLLVMOptions)
else()
  # Build via external projects mechanism
  set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
  set(MLIR_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include)
  set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
  set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
endif()

if(MLIR_ENABLE_BINDINGS_PYTHON)
  include(MLIRDetectPythonEnv)
  mlir_configure_python_dev_packages()
endif()

set(STANDALONE_SOURCE_DIR ${PROJECT_SOURCE_DIR})
set(STANDALONE_BINARY_DIR ${PROJECT_BINARY_DIR})
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${STANDALONE_SOURCE_DIR}/include)
include_directories(${STANDALONE_BINARY_DIR}/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(transform-opt)

Core implementation of the transform extension

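The core of the transform extension is the pair of methods implemented for the transform.my.change_call_target op: apply and getEffects. apply determines how the payload IR is rewritten. getEffects declares how the op affects its handle and the payload.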

void mlir::transform::ChangeCallTargetOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  // Indicate that the `call` handle is only read by this operation because the
  // associated operation is not erased but rather modified in-place, so the
  // reference to it remains valid.
  onlyReadsHandle(getCallMutable(), effects);

  // Indicate that the payload is modified by this operation.
  modifiesPayload(effects);
}
// Implementation of our transform dialect operation.
// This operation returns a tri-state result that can be one of:
// - success when the transformation succeeded;
// - definite failure when the transformation failed in such a way that
//   following transformations are impossible or undesirable, typically it could
//   have left payload IR in an invalid state; it is expected that a diagnostic
//   is emitted immediately before returning the definite error;
// - silenceable failure when the transformation failed but following
//   transformations are still applicable, typically this means a precondition
//   for the transformation is not satisfied and the payload IR has not been
//   modified. The silenceable failure additionally carries a Diagnostic that
//   can be emitted to the user.
::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
    // The rewriter that should be used when modifying IR.
    ::mlir::transform::TransformRewriter &rewriter,
    // The list of payload IR entities that will be associated with the
    // transform IR values defined by this transform operation. In this case, it
    // can remain empty as there are no results.
    ::mlir::transform::TransformResults &results,
    // The transform application state. This object can be used to query the
    // current associations between transform IR values and payload IR entities.
    // It can also carry additional user-defined state.
    ::mlir::transform::TransformState &state) {

  // First, we need to obtain the list of payload operations that are associated
  // with the operand handle.
  auto payload = state.getPayloadOps(getCall());

  // Then, we iterate over the list of operands and call the actual IR-mutating
  // function. We also check the preconditions here.
  for (Operation *payloadOp : payload) {
    auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
    if (!call) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError() << "only applies to func.call payloads";
      diag.attachNote(payloadOp->getLoc()) << "offending payload";
      return diag;
    }

    updateCallee(call, getNewTarget());
  }

  // If everything went well, return success.
  return DiagnosedSilenceableFailure::success();
}
static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
  call.setCallee(newTarget);
}

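The comments in the code above are detailed, so I won't repeat them here. One practical note: updateCallee is a static helper, so in MyExtension.cpp it must be defined before apply, which calls it.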
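The tri-state result described in the `apply` comments can be modelled as a small C++ enum plus a sequence runner. This is a toy sketch with hypothetical names. Its Propagate/Suppress modes loosely mirror the `failures(propagate)` / `failures(suppress)` options of `transform.sequence`. A definite failure always stops execution. A silenceable failure either stops execution or is silenced so that the next step still runs.

```cpp
#include <functional>
#include <vector>

// Toy version of the success / silenceable / definite outcome.
enum class ToyOutcome { Success, SilenceableFailure, DefiniteFailure };
enum class ToyFailureMode { Propagate, Suppress };

// Runs steps in order and reports how many were executed via `stepsRun`.
ToyOutcome runToySequence(const std::vector<std::function<ToyOutcome()>> &steps,
                          ToyFailureMode mode, int &stepsRun) {
  stepsRun = 0;
  for (const auto &step : steps) {
    ++stepsRun;
    ToyOutcome r = step();
    if (r == ToyOutcome::DefiniteFailure)
      return r;  // payload may be broken: never continue
    if (r == ToyOutcome::SilenceableFailure &&
        mode == ToyFailureMode::Propagate)
      return r;  // report the precondition failure to the caller
    // Otherwise (success, or suppressed silenceable failure) keep going.
  }
  return ToyOutcome::Success;
}
```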

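Chapter 3: Implementing a more complex transform operation

This part of the tutorial does two things:

  • It adds a constraint to the payload handle targeted by the Chapter 2 op. It also uses a trait (TransformEachOpTrait) to simplify the loop that iterates over the payload and matches func.call.
  • It adds a new op to review the full transform workflow. It also introduces a handle type based on CallOpInterface, so the handle is no longer limited to func.call and accepts any op that implements CallOpInterface.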
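The "apply to each" idea and the interface-based handle can be sketched together in plain C++. This is a toy with hypothetical names, not the MLIR API. A generic driver casts each payload op to the required kind and hands it to a per-op callback, loosely analogous to the `applyToOne` method used with TransformEachOpTrait. Because the required kind is an interface rather than a concrete op, any call-like op is accepted. In this toy, the driver simply stops at the first op of the wrong kind.

```cpp
#include <string>
#include <vector>

// Hypothetical payload op hierarchy (not MLIR classes).
struct ToyOp {
  virtual ~ToyOp() = default;
};
struct ToyCallOpInterface : ToyOp {
  virtual void setCallee(const std::string &callee) = 0;
  virtual std::string getCallee() const = 0;
};
struct ToyFuncCall : ToyCallOpInterface {
  std::string callee;
  void setCallee(const std::string &c) override { callee = c; }
  std::string getCallee() const override { return callee; }
};
struct ToyOtherCall : ToyCallOpInterface {  // some other call-like op
  std::string target;
  void setCallee(const std::string &c) override { target = c; }
  std::string getCallee() const override { return target; }
};
struct ToyMatmul : ToyOp {};

// Generic driver: checks each payload op is an OpT, then applies `applyToOne`.
template <typename OpT, typename Fn>
bool applyToEach(const std::vector<ToyOp *> &payload, Fn applyToOne) {
  for (ToyOp *op : payload) {
    auto *typed = dynamic_cast<OpT *>(op);
    if (!typed)
      return false;  // wrong kind of payload op
    applyToOne(*typed);
  }
  return true;
}
```

With this split, the per-op callback no longer needs its own type check. Widening the accepted ops is a matter of changing `OpT`.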

Chapter 4: Matching payload ops with transform ops
