
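The Transform dialect is a scheduling DSL built on the MLIR infrastructure. Within a single MLIR file, the computation is defined with standard IR (the payload IR) and the scheduling strategy is defined with transform IR (the schedule IR); the transform-interpreter pass is then used to apply the schedule to the payload.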
Motivation of Transform Dialect
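Looking across many mature compilers, such as Halide, TVM, and TC, the following common patterns emerge:
- Schedule Representation: a collection of metadata that describes the optimization flow as structured data
- Declarative Specification: configuration by stating the desired target state rather than the concrete steps to reach it
- Multi-Versioning: generating several optimized variants for different hardware or scenarios
- Runtime Dispatch: selecting the best variant to execute through a dynamic decision mechanism
- Vertical Sequencing: composing optimizations in depth within a single functional domain
MLIR also aims to give users a TVM-like separation of computation and scheduling, so that they can customize the scheduling and optimization strategy for a given computation. In MLIR everything is an op, scheduling operations included; these ops are packaged in the transform dialect, which is how the Transform dialect came about.
The rest of this tutorial is mainly based on the official transform tutorial.
Chapter 1: Building a pipeline with transform ops
The transform dialect ships with a rich set of schedule ops. Because computation and scheduling are separated, the user has to define both a payload IR and a schedule IR. The example in the official transform tutorial takes a matrix multiplication followed by an elementwise addition and a ReLU, and optimizes its schedule with transform ops.
Payload IR (defines the computation in detail):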
// Original function to optimize.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                   -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
  // Elementwise addition.
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}
The payload IR is written with standard MLIR dialects such as linalg; nothing transform-specific is needed.
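To make the payload concrete, here is a minimal C++ sketch of what @fc_relu computes, written as plain loops. It is an illustration and not code from the tutorial: the function name fcReluReference, the row-major std::vector layout, and the runtime size n are assumptions, and the sketch assumes the usual linalg semantics in which linalg.matmul accumulates into its outs operand.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference for the payload: relu(lhs * rhs + init + bias).
// All matrices are n x n and stored row-major. `init` plays the role of
// %output, which linalg.matmul is assumed to accumulate into.
std::vector<float> fcReluReference(const std::vector<float> &lhs,
                                   const std::vector<float> &rhs,
                                   const std::vector<float> &bias,
                                   const std::vector<float> &init,
                                   std::size_t n) {
  assert(lhs.size() == n * n && rhs.size() == n * n);
  assert(bias.size() == n * n && init.size() == n * n);
  std::vector<float> out(init);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = out[i * n + j];
      // Matrix-matrix multiplication.
      for (std::size_t k = 0; k < n; ++k)
        acc += lhs[i * n + k] * rhs[k * n + j];
      // Elementwise addition of the bias.
      acc += bias[i * n + j];
      // Elementwise max with 0 (ReLU).
      out[i * n + j] = std::max(acc, 0.0f);
    }
  }
  return out;
}
```

In the tutorial the size is fixed at 512x512; the sketch keeps n as a parameter only so that it can be tried on small inputs.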
Schedule IR (scheduling primitives):
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %arg0: !transform.any_op,
      %arg1: !transform.op<"linalg.matmul">,
      %arg2: !transform.op<"linalg.elemwise_binary">) {
    // Since the %arg2 handle is associated with both elementwise operations,
    // we need to split it into two handles so we can target only the second
    // elementwise operation.
    %add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
        -> (!transform.any_op, !transform.any_op)
    // The actual tiling transformation takes tile sizes as attributes. It produces a
    // handle to the loop generated during tiling.
    %tiled, %loop = transform.structured.tile_using_forall %max tile_sizes [8, 32]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // We can now fuse the other operations into the loop. Here, we fuse
    // operations one-by-one. This requires the operation that is being fused
    // to define the value used within the loop, so the order of such fusions
    // is important. We could also use "transform.merge_handles" to obtain
    // a single handle to all operations and give it to `fuse_into_containing_op`
    // that would take care of the ordering in this case.
    %add_fused, %loop2 = transform.structured.fuse_into_containing_op %add into %loop
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused, %loop3 = transform.structured.fuse_into_containing_op %arg1 into %loop2
        : (!transform.op<"linalg.matmul">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Tile again to get the desired size. Note that this time this tiles the
    // "add" operation and fuses matmul into the loop, but doesn't affect the
    // "max" operation. This illustrates the precise targeting with the transform
    // dialect. Otherwise, it is difficult to differentiate "add" and "max", both
    // of which have the same kind.
    %tiled_second, %loop_second = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %matmul_fused_2, %loop_second_2 =
        transform.structured.fuse_into_containing_op %matmul_fused into %loop_second
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    // // Since outlining is currently only implemented for region-holding operations
    // // such as loops, use tiling to size 1 to materialize the outer loop that is
    // // going to be outlined.
    // %_0, %loop_third = transform.structured.tile_using_forall %tiled_second tile_sizes [1]
    //     : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // %_1, %outline_target = transform.structured.fuse_into_containing_op %matmul_fused_2 into %loop_third
    //     : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    // %func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
    //     : (!transform.any_op) -> (!transform.any_op, !transform.op<"func.call">)
    transform.yield
  }
}
The complete code snippet is as follows.
1 | // RUN: mlir-opt %s \ |
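The effect of the schedule can be approximated in plain C++. The sketch below is a hand-written analogy, not the output of mlir-opt: an outer loop nest iterates over 8x32 tiles of the result (standing in for the loop produced by tile_using_forall), and the matmul, add, and max for each tile are all computed inside that loop, which is what fusing into the containing op achieves. The second-level [4, 4] tiling is omitted for brevity. The name fcReluTiled, the row-major layout, the runtime size n, and the zero-initialized accumulator are assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tiled-and-fused version of relu(lhs * rhs + bias) for n x n
// row-major matrices. The two outer loops walk 8x32 result tiles; the inner
// loops do the fused work for one tile.
std::vector<float> fcReluTiled(const std::vector<float> &lhs,
                               const std::vector<float> &rhs,
                               const std::vector<float> &bias,
                               std::size_t n) {
  assert(lhs.size() == n * n && rhs.size() == n * n && bias.size() == n * n);
  constexpr std::size_t kTileRows = 8;
  constexpr std::size_t kTileCols = 32;
  std::vector<float> out(n * n, 0.0f);
  for (std::size_t i0 = 0; i0 < n; i0 += kTileRows) {
    for (std::size_t j0 = 0; j0 < n; j0 += kTileCols) {
      // Clamp the tile so that sizes not divisible by the tile shape work.
      const std::size_t iEnd = std::min(i0 + kTileRows, n);
      const std::size_t jEnd = std::min(j0 + kTileCols, n);
      for (std::size_t i = i0; i < iEnd; ++i) {
        for (std::size_t j = j0; j < jEnd; ++j) {
          float acc = 0.0f;
          // Fused matmul: only the slice needed by this tile is computed.
          for (std::size_t k = 0; k < n; ++k)
            acc += lhs[i * n + k] * rhs[k * n + j];
          // Fused elementwise add.
          acc += bias[i * n + j];
          // The op that was tiled first: elementwise max with 0.
          out[i * n + j] = std::max(acc, 0.0f);
        }
      }
    }
  }
  return out;
}
```

The point of the analogy is locality: each tile of the intermediate results is produced and consumed inside the same loop iteration instead of materializing three full 512x512 tensors one after another.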
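Chapter 2: Defining a simple custom transform operation
Goal
The operations built into the transform dialect often cannot express the schedule optimization we need. For example, in Chapter 1 the small tiles produced by the final tiling step could be replaced with a hand-written microkernel, but the transform dialect has no op that performs this replacement directly. We can implement a custom operation (transform.my.change_call_target) for it, used as follows: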
1 | // Rewrite the call target. |
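That is, the payload operation referenced by the %call handle is rewritten into a call to the microkernel.
In practical compiler development, a microkernel is a small, usually hand-written kernel. The compiler can automatically find ops that are eligible for replacement and substitute the microkernel to further improve performance.
Project structure
A transform extension is laid out much like an MLIR standalone project; the MLIR standalone project template can be used as a reference.
The project structure of Chapter 2 is as follows: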
1 | ├── build.sh |
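Defining an operation for the transform dialect mainly involves the following points:
- Define the operation with ODS, which auto-generates the C++ declarations and definitions (the .h.inc and .cpp.inc files).
- Implement the interfaces the operation requires.
  - TransformOpInterface is the interface every transform dialect op must implement; the main method to implement is apply.
  - MemoryEffectsOpInterface describes memory side effects; the method to implement is getEffects.
- Register the operation.
Code walkthrough
The code below is listed in this order; the comments in the code are already quite thorough.
Defining the operation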
/include/MyExtension.td
1 | //===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===// |
The TableGen definition above follows the same pattern as any other MLIR op definition. Note the following points:
Declaring the interfaces:
[DeclareOpInterfaceMethods<TransformOpInterface>,
 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]
Defining the operation:
// The arguments include the handle to the payload operations and the attribute that
// specifies the new callee. The handle must implement TransformHandleTypeInterface.
// We use a string attribute as the symbol may not exist in the transform IR so the
// verification may fail.
let arguments = (ins
  TransformHandleTypeInterface:$call,
  StrAttr:$new_target);
The key point is that the arguments must include a handle, and the type of that handle must implement TransformHandleTypeInterface, which means the handle can be manipulated by transform ops.
MyExtension.h
1 | //===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===// |
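This header works like that of any basic MLIR op: it uses macro definitions to pull in the op declarations from the generated .h.inc file, and it declares a registration function for the extension.
The CMakeLists.txt is as follows: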
1 | # Tell Tablegen to use MyExtension.td as input. |
Implementing the interfaces the operation requires
The code in /lib/MyExtension.cpp can be broken into several parts.
Defining the transform dialect extension:
// Define a new transform dialect extension. This uses the CRTP idiom to
// identify extensions.
class MyExtension
    : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The TypeID of this extension.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
  // The extension must derive the base constructor.
  using Base::Base;
  // This function initializes the extension, similarly to `initialize` in
  // dialect definitions. List individual operations and dependent dialects
  // here.
  void init();
};
Initializing the extension:
void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  // declareDependentDialect<MyDialect>();
  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes, operations,
  //     types) that may be produced by applying the transformation even when
  //     not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();
  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS. This call will perform additional
  // checks that the operations implement the transform and memory effect
  // interfaces required by the dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
      >();
}
One key point above is the difference between declareDependentDialect and declareGeneratedDialect. Dependent dialects are the ones the transform ops themselves need, while generated dialects are the ones that may be produced when the transform is applied. Here both scf and func are dialects that may be generated.
The other key point is that registerTransformOps registers the operations defined in ODS and also checks that each of them implements the transform and side-effect interfaces.
Overriding the interfaces for the specific op:
Overriding the apply method of the transform interface:
::mlir::DiagnosedSilenceableFailure mlir::transform::ChangeCallTargetOp::apply(
    // The rewriter that should be used when modifying IR.
    ::mlir::transform::TransformRewriter &rewriter,
    // The list of payload IR entities that will be associated with the
    // transform IR values defined by this transform operation. In this case, it
    // can remain empty as there are no results.
    ::mlir::transform::TransformResults &results,
    // The transform application state. This object can be used to query the
    // current associations between transform IR values and payload IR entities.
    // It can also carry additional user-defined state.
    ::mlir::transform::TransformState &state) {
  // First, we need to obtain the list of payload operations that are associated
  // with the operand handle.
  auto payload = state.getPayloadOps(getCall());
  // Then, we iterate over the list of operands and call the actual IR-mutating
  // function. We also check the preconditions here.
  for (Operation *payloadOp : payload) {
    auto call = dyn_cast<::mlir::func::CallOp>(payloadOp);
    if (!call) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError() << "only applies to func.call payloads";
      diag.attachNote(payloadOp->getLoc()) << "offending payload";
      return diag;
    }
    updateCallee(call, getNewTarget());
  }
  // If everything went well, return success.
  return DiagnosedSilenceableFailure::success();
}
Overriding the side-effect interface:
void mlir::transform::ChangeCallTargetOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  // Indicate that the `call` handle is only read by this operation because the
  // associated operation is not erased but rather modified in-place, so the
  // reference to it remains valid.
  onlyReadsHandle(getCallMutable(), effects);
  // Indicate that the payload is modified by this operation.
  modifiesPayload(effects);
}
Registering the whole MyExtension:
void registerMyExtension(::mlir::DialectRegistry &registry) {
  registry.addExtensions<MyExtension>();
}
The CMakeLists.txt is as follows:
1 | # Outside examples, this should be `add_mlir_library`. |
Writing the transform-opt driver
1 | //===-- transform-opt.cpp - Transform dialect tutorial entry point --------===// |
The CMakeLists.txt is as follows:
1 | get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) |
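This part is no different from an ordinary MLIR opt-style driver: it registers our MyExtension and then registers the passes. The important one is the interpreter pass, registered through registerInterpreterPass.
The top-level CMakeLists.txt is as follows: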
1 | cmake_minimum_required(VERSION 3.20.0) |
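Core implementation of the transform extension
The core of the transform extension is the implementation of the apply and getEffects methods of the transform.my.change_call_target op. These two methods determine, among other things, what code is produced after the transform is applied.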
1 | void mlir::transform::ChangeCallTargetOp::getEffects( |
1 | // Implementation of our transform dialect operation. |
1 | static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) { |
The comments in the code above are already very detailed, so they are not repeated here.
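As a recap of the control flow in apply, the following self-contained C++ sketch mimics it without depending on MLIR. Everything in it is hypothetical scaffolding: ToyOp stands in for a payload operation, ApplyResult for DiagnosedSilenceableFailure, and changeCallTarget for the body of ChangeCallTargetOp::apply. It only illustrates the pattern of walking the payload operations attached to a handle, rejecting anything that is not a call, and rewriting the callee in place.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a payload operation.
struct ToyOp {
  std::string name;   // e.g. "func.call"
  std::string callee; // only meaningful for calls
};

// Hypothetical stand-in for DiagnosedSilenceableFailure.
struct ApplyResult {
  bool succeeded;
  std::string message;
};

// Mimics the shape of apply(): `payload` plays the role of the operations
// associated with the handle operand.
ApplyResult changeCallTarget(const std::vector<ToyOp *> &payload,
                             const std::string &newTarget) {
  for (ToyOp *op : payload) {
    // Precondition check: the transform only applies to calls.
    if (op->name != "func.call")
      return {false, "only applies to func.call payloads"};
    // In-place update: the op is modified, not erased, which is why the
    // handle is only read in the side-effect declaration.
    op->callee = newTarget;
  }
  return {true, ""};
}
```

As in the tutorial code, operations visited before an offending payload have already been updated by the time the failure is reported.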
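Chapter 3: Implementing more complex transform operations
This part of the tutorial does two things:
- Add a constraint trait to the payload handle targeted by the transform operation from Chapter 2; using the trait simplifies the loop that walks the payload and matches func.call.
- Add a new op as a way to review the whole transform workflow. It works with CallOpInterface, so the handle is no longer limited to func.call and accepts any op that implements CallOpInterface.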
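The idea behind the second item, accepting any op that implements a call interface rather than one concrete op, can be sketched in plain C++. This is an analogy only and does not use MLIR: ToyOp, ToyCallInterface, ToyFuncCall, ToyOtherCall, ToyAdd, and retargetCalls are all hypothetical names, and dynamic_cast merely plays the role that an interface cast plays in MLIR.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical base class for all payload operations.
struct ToyOp {
  virtual ~ToyOp() = default;
};

// Hypothetical call-like interface, standing in for CallOpInterface.
struct ToyCallInterface {
  virtual ~ToyCallInterface() = default;
  virtual void setCallee(const std::string &target) = 0;
  virtual const std::string &callee() const = 0;
};

// Shared implementation for ops that behave like calls.
struct ToyCallBase : ToyOp, ToyCallInterface {
  void setCallee(const std::string &target) override { callee_ = target; }
  const std::string &callee() const override { return callee_; }

private:
  std::string callee_;
};

struct ToyFuncCall : ToyCallBase {};  // plays the role of func.call
struct ToyOtherCall : ToyCallBase {}; // some other call-like op
struct ToyAdd : ToyOp {};             // not a call

// Accepts any op that implements the interface, not just one concrete type.
bool retargetCalls(const std::vector<ToyOp *> &payload,
                   const std::string &newTarget) {
  for (ToyOp *op : payload) {
    auto *call = dynamic_cast<ToyCallInterface *>(op);
    if (!call)
      return false; // the payload does not implement the interface
    call->setCallee(newTarget);
  }
  return true;
}
```

Matching on the interface rather than on the concrete type is what lets the same transform op serve every call-like operation.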