Leon's Blog

Sharing a few interesting bits of technology


GPU Backend Optimization (Part 2)


This is the second post in this series. Building on the paper from the Indian Institute of Technology on code generation targeting GPU tensor cores, and on the open-source IREE project, it walks through how GPU code generation is done on top of the MLIR infrastructure.

IREE Tensor Core Code Generation

IREE is an end-to-end open-source machine learning compiler that supports efficient deployment on edge devices. This chapter focuses on how the IREE project performs code generation targeting Nvidia GPUs.

Overview of the IREE Backend Code Generation Pipeline

A Detailed Look at matmul Code Generation

Note: this part of the walkthrough mainly follows the IREE Codegen blog.

Building on the paper discussed in the previous chapter, we focus on IREE's matmul code generation. In compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp, the dispatch lowering is generated as follows:

// Select the dispatch lowering pass pipeline for this executable based on the translation info
switch (translationInfo.getDispatchLoweringPassPipeline()) {
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDefault:
addGPUDefaultPassPipeline(pipeline, pipelineOptions);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUBaseLowering:
addGPUBaseLoweringPassPipeline(pipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUDistribute:
addGPUSimpleDistributePassPipeline(pipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorize:
addGPUVectorizationPassPipeline(pipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWinogradVectorize:
addGPUWinogradVectorizePassPipeline(pipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt:
addGPUMatmulSimtPassPipeline(pipeline, pipelineOptions);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore: {
// This pass pipeline is the interesting one and is worth stepping through in a debugger.
// Tensor core generation logic.
// Tile and distribute operations to workgroups

// This helper extracts the software pipelining depth from the pipeline configuration
FailureOr<int64_t> maybeDepth =
getSoftwarePipelineDepth(translationInfo.getConfiguration());
if (failed(maybeDepth)) {
funcOp.emitOpError(
"invalid matmul configuration without software pipelining config");
return signalPassFailure();
}
addGPUMatmulTensorCorePassPipeline(pipeline, pipelineOptions, *maybeDepth);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::
LLVMGPUMatmulTensorCoreMmaSync: {
FailureOr<int64_t> maybeDepth =
getSoftwarePipelineDepth(translationInfo.getConfiguration());
if (failed(maybeDepth)) {
funcOp.emitOpError(
"invalid matmul configuration without software pipelining config");
return signalPassFailure();
}
addGPUMatmulTensorCoreMmaSyncPassPipeline(pipeline, pipelineOptions,
*maybeDepth);
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTransposeSharedMem:
addGPUTransposePassPipeline(pipeline, pipelineOptions);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUVectorDistribute:
addGPUVectorDistributePassPipeline(pipeline, pipelineOptions,
/*usePadToModelSharedMemcpy=*/false);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::
LLVMGPUPadAndVectorDistribute:
addGPUVectorDistributePassPipeline(pipeline, pipelineOptions,
/*usePadToModelSharedMemcpy=*/true);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWarpReduction:
addGPUWarpReductionPassPipeline(pipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUPackUnPack:
addGPUPackUnPackPasses(pipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse:
addGPUTileAndFusePassPipeline(pipeline, pipelineOptions);
break;
// no pipeline specified, nothing to do.
case IREE::Codegen::DispatchLoweringPassPipeline::None:
return;
default:
funcOp.emitOpError("unsupported pipeline on GPU target.");
return signalPassFailure();
}

As the code shows, based on the translation info, IREE dispatches a different code generation pipeline for each executable variant (`hal.executable.variant` in IREE). Here we focus on LLVMGPUMatmulTensorCore, the compilation pipeline that targets the Tensor Core units on Nvidia GPUs.

Test Input

#compilation0 = #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCore workgroup_size = [64, 2, 1]
,
{ pipeline_depth = 3, store_stage = 1}>>
func.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf16_for_LLVMGPUMatmulTensorCore_32_32_16_64_2_1(%lhs: tensor<512x128xf16>, %rhs: tensor<128x512xf16>, %acc: tensor<512x512xf16>) -> tensor<512x512xf16> {
%result = linalg.matmul {compilation_info = #compilation0} ins(%lhs, %rhs: tensor<512x128xf16>, tensor<128x512xf16>) outs(%acc: tensor<512x512xf16>) -> tensor<512x512xf16>
return %result: tensor<512x512xf16>
}

This test is a simple matrix multiplication with dimensions <512x128xf16> x <128x512xf16> = <512x512xf16>. The key part of the test is the compilation-info attribute:

#compilation0 = #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCore workgroup_size = [64, 2, 1]
,
{ pipeline_depth = 3, store_stage = 1}>>

It defines the following metadata:

  • The workgroup (i.e. thread block) dimensions are [64, 2, 1].
  • The tile sizes are [32, 32, 16].
  • The software pipeline depth is 3; note that 3 is the minimum useful pipeline depth on the GPU (load, store, compute).
  • store_stage is 1; its precise meaning is not explored further here.

For this test, the compile script is as follows:

iree-compile --iree-hal-target-backends=cuda \
--iree-cuda-target=sm_86 \
--mlir-disable-threading \
--mlir-elide-elementsattrs-if-larger=10 \
--mlir-print-ir-after-all \
matmul.mlir -o test.vmfb \
2>&1 | tee output.dump

Tensor Core verifier

Now that we have an input and a compile script: before IREE applies the tensor core transformations to our code and configuration, it first checks whether the parameters are compatible. See the constraint checks on parameters such as compilation_info and workgroup_size in compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp. The following lists briefly explain what each parameter means and how it is constrained:

Core Concepts

  • Workgroup Size (workgroup_size): thread block dimensions (e.g. [64,2,1]); total threads must be ≤ 1024. Constraint source: CUDA hardware limit.
  • Tile Size (tile_sizes): matrix tile sizes (e.g. [32,32,16]); determines how much data each thread block processes. Constraint source: algorithmic tuning.
  • Thread Block Shape (threadBlockShape): per-dimension tile sizes inside a thread block (e.g. [32,32,16]); directly tied to the tile size. Constraint source: tiling strategy and hardware match.
  • Warp count (numWarps): number of warps per dimension in the thread block (e.g. [2,2,1]); computed as the workgroup size divided by the warp size (32). Constraint source: SM architecture.
  • Warp Shape (warpShape): the sub-tile a single warp computes (e.g. [16,16,16]). Constraint source: thread scheduling and instruction-level parallelism.
  • Instruction Shape (instructionShape): the matrix shape one Tensor Core instruction supports (f16 → 16x16x16, f32 → 16x16x8). Constraint source: Tensor Core architecture.

Tensor Core Constraint Checks

  • Total workgroup threads: workgroupSize[X] * workgroupSize[Y] * workgroupSize[Z] ≤ 1024. Otherwise the thread block exceeds the GPU's capacity and cannot launch.
  • Z-dimension threads: workgroupSize[Z] == 1. The Tensor Core mapping is two-dimensional; extending into Z would hurt data locality.
  • X-dimension threads: workgroupSize[X] % 32 == 0. Warp scheduling requires X to be a multiple of 32 (32 threads per warp).
  • Matrix size alignment: matmulShape[M/N/K] % threadBlockShape[M/N/K] == 0. Otherwise the tiles cannot evenly cover the matrix, causing incorrect results or degraded performance.
  • Warp tile alignment: warpShape[M/N/K] % instructionShape[M/N/K] == 0. Otherwise Tensor Core instructions cannot cover the warp tile and hardware utilization suffers.

There are two key points here:

  1. The smallest execution unit of a tensor core is a warp, so the verifier must check that the warp tile is aligned with the instruction shape.
  2. A workgroup is a thread block: tile_sizes describes how much computation one thread block performs, while the workgroup size describes how many threads the thread block has, from which the number of warps per workgroup follows by dividing by the warp size (32 by default). Keep these two concepts distinct.
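As a sanity check, the verifier's arithmetic can be reproduced in a few lines of standalone C++. This is only a sketch plugging in the values from our test case (it is not IREE code); it derives the warp grid and warp shape and evaluates the two alignment checks:

#include <cstdint>
#include <cstdio>

int main() {
  // Values taken from the compilation_info above.
  const int64_t workgroupSize[3] = {64, 2, 1};      // threads per block (x, y, z)
  const int64_t threadBlockShape[3] = {32, 32, 16}; // tile_sizes (M, N, K)
  const int64_t matmulShape[3] = {512, 512, 128};   // problem (M, N, K)
  const int64_t instructionShape[3] = {16, 16, 16}; // WMMA shape for f16
  const int64_t kWarpSize = 32;

  // Warp grid: x is divided by the warp size, y and z are used as-is.
  const int64_t numWarps[3] = {workgroupSize[0] / kWarpSize, workgroupSize[1], workgroupSize[2]};
  // Warp shape: warps along y split M, warps along x split N, warps along z split K.
  const int64_t warpShape[3] = {threadBlockShape[0] / numWarps[1],
                                threadBlockShape[1] / numWarps[0],
                                threadBlockShape[2] / numWarps[2]};

  std::printf("numWarps  = [%lld, %lld, %lld]\n",
              (long long)numWarps[0], (long long)numWarps[1], (long long)numWarps[2]);
  std::printf("warpShape = [%lld, %lld, %lld]\n",
              (long long)warpShape[0], (long long)warpShape[1], (long long)warpShape[2]);

  // The two divisibility checks performed by the verifier.
  bool blockCoversProblem = true, warpCoversMMA = true;
  for (int d = 0; d < 3; ++d) {
    blockCoversProblem &= matmulShape[d] % threadBlockShape[d] == 0;
    warpCoversMMA &= warpShape[d] % instructionShape[d] == 0;
  }
  std::printf("problem %% block tile == 0: %d, warp tile %% instruction == 0: %d\n",
              blockCoversProblem, warpCoversMMA);
}

For our configuration this prints numWarps = [2, 2, 1] and warpShape = [16, 16, 16], and both checks pass.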

The full tensor core verifier logic is fairly involved; read the comments in the code below for a complete picture:

// Number of warps in x, y, and z dim.
// Compute the number of warps in a workgroup.
// The warp size is 32; for a [64,2,1] workgroup this yields a [2,2,1] warp grid.
SmallVector<int64_t> numWarps{workgroupSize[kDimX] / kWarpSize,
workgroupSize[kDimY], workgroupSize[kDimZ]};

// Matrix-multiply problem shape in number of elements in M, N, and K dim.
// For our test case matmulShape is [512, 512, 128].
// These are the raw M, N, K parameters of the matmul.
SmallVector<int64_t> matmulShape{lhsShape[0], rhsShape[1], lhsShape[1]};

// Warp tile shape in number of elements in M, N, and K dim.
// Note that num warp in (x, y, z) dim are mapped to problem (M, N, K) dim as:
// DimY -> ProblemDimM, DimX -> ProblemDimN, DimZ -> ProblemDimK.
/*
Divide the thread block shape evenly among the warps.
Note: as the comment above states, the warps along (x, y, z) map onto the problem's (M, N, K) dimensions as follows:
  warps along y correspond to M;
  warps along x correspond to N;
  warps along z correspond to K.
*/
// [32/2, 32/2, 16/1] = [16, 16, 16] describes the work each warp is responsible for
SmallVector<int64_t> warpShape{threadBlockShape[kM] / numWarps[kDimY],
threadBlockShape[kN] / numWarps[kDimX],
threadBlockShape[kK] / numWarps[kDimZ]};

// Instruction shape in number of elements in M, N, and K dim.
// Determine the tensor core instruction shape
SmallVector<int64_t> instructionShape;
// f16 and bf16 map to an instruction shape of {16, 16, 16}; f32 maps to {16, 16, 8}
if (failed(getInstructionShape(
op, pipeline, llvm::cast<ShapedType>(lhsType).getElementType(),
instructionShape))) {
return failure();
}

// Verify that matmul problem shape can be tiled with the thread block shape.
// TODO: This check should be relaxed as we allow unaligned matmul shapes.
// Each dimension of the matmul problem (M, N, K) must be divisible by the configured tile size for tiling to be valid.
// Here this checks [512%32, 512%32, 128%16].
if (matmulShape[kM] % threadBlockShape[kM] != 0 ||
matmulShape[kN] % threadBlockShape[kN] != 0 ||
matmulShape[kK] % threadBlockShape[kK] != 0) {
return op->emitError("Thread block shape ")
<< threadBlockShape << " cannot be tiled on matmul shape "
<< matmulShape << " with compilation pipeline " << pipelineName;
}

// Verify that if warp shape can be tiled using warp-level Tensor core
// instruction shape.
// Ensure each warp tile (warpShape) is divisible by the Tensor Core instruction shape in M, N, and K.
// If not, the hardware compute units (tensor cores) cannot exactly cover the warp tile.
// For f16 the instruction shape is [16,16,16], i.e. this checks whether each warp's work can be expressed with Tensor Core instructions.
if (warpShape[kM] % instructionShape[kM] != 0 ||
warpShape[kN] % instructionShape[kN] != 0 ||
warpShape[kK] % instructionShape[kK] != 0) {
return op->emitError("Tensor Core instruction shape ")
<< instructionShape << " cannot be tiled on warp shape " << warpShape
<< " with compilation pipeline " << pipelineName;
}

Overview of the Matmul Tensor Core Generation Flow

Having covered the verifier for the Tensor Core pipeline, we now move on to the main pipeline that generates tensor core code, shown below:

void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options,
unsigned pipelineDepth) {
tileAndBufferize(funcPassManager);

// Distribute linalg onto warps within the workgroup.
funcPassManager.addPass(
createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/true));
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
if (pipelineDepth > 1) {
funcPassManager.addPass(createGPUMultiBufferingPass(
GPUMultiBufferingPassOptions{pipelineDepth}));
}
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(createRemoveSingleIterationLoopPass());

ReorderWorkgroupsStrategy reorderStrategy =
getReorderWorkgroupsStrategy(options.reorderStrategy);
funcPassManager.addPass(
createReorderWorkgroups(reorderStrategy, canReorderWorkgroups));

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Linalg -> vector
funcPassManager.addPass(
createLLVMGPUTensorCoreVectorizationPass(GPUTensorCoreType::WMMA));
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createOptimizeVectorTransferPass());
funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());

// Distribute shared memory copies.
funcPassManager.addPass(createMemrefCopyToLinalgPass());
funcPassManager.addPass(createGPUDistributeSharedMemoryCopyPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
if (options.enableReduceSharedMemoryBankConflicts) {
funcPassManager.addPass(createGPUReduceBankConflictsPass());
}

// Vector -> MMA ops
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(
createLLVMGPUVectorToGPUPass(GPUTensorCoreType::WMMA));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Hoist loop invariant code to avoid pipelining it.
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
// Pipeline memory operations.
GPUPipeliningPassOptions pipelieningOptions = {};
pipelieningOptions.epiloguePeeling = false;
pipelieningOptions.depth = pipelineDepth;
pipelieningOptions.scheduleIndex =
llvm::to_underlying(PipeliningSchedulingStrategy::loadGlobalStage0);
funcPassManager.addPass(createGPUPipeliningPass(pipelieningOptions));
// Optimize shared memory usage.
funcPassManager.addPass(createLLVMGPUPackSharedMemoryAllocPass());
}

Among the passes in this pipeline, we are mainly interested in the following eight:

  • LLVMGPUTileAndDistribute
  • GPUMultiBufferingPass
  • LLVMGPUTensorCoreVectorizationPass
  • GPUDistributeSharedMemoryCopyPass
  • GPUReduceBankConflictsPass
  • LLVMGPUVectorToGPU
  • GPUPipeliningPass
  • LLVMGPUPackSharedMemoryAlloc

Next we walk through these passes one by one, together with the code.

LLVMGPUTileAndDistribute Pass

The LLVMGPUTileAndDistribute pass tiles the computation according to the tile_sizes in the lowering_config and the workgroup size in the compilation info, and along the way uses shared memory as a cache to optimize memory accesses.

Our workload can be expressed as C<512x512> = A<512x128> x B<128x512>.

Concretely, it can be broken down into the following steps:

  1. Initial tiling: split the workload into <32x128> x <128x32> blocks.
  2. Memory promotion 1: place the C tile in shared memory.
  3. Continue tiling down to the configured <32x16> x <16x32>.
  4. Memory promotion 2: if the workgroup is larger than a warp, stage parts of A and B in shared memory as well.
  5. Tile further by warp size. Our thread block is [64,2,1] and the warp size is 32, so it splits into [2,2,1] warps. The <32x16> x <16x32> product accordingly becomes <32/2 x 16> x <16 x 32/2>. (A loop-nest sketch of the end result follows below.)
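Before looking at the per-step IR dumps, here is a plain C++ sketch of the loop structure those five steps are driving towards. It only mirrors the index structure under the shapes of our test case; the real IR operates on memrefs and eventually on WMMA fragments:

void tiledMatmulSketch() {
  const int M = 512, N = 512, K = 128;
  const int BM = 32, BN = 32, BK = 16; // workgroup (thread block) tile
  const int WM = 16, WN = 16;          // per-warp tile: BM/2, BN/2

  for (int wgY = 0; wgY < M / BM; ++wgY)       // one (wgY, wgX) pair = one workgroup
    for (int wgX = 0; wgX < N / BN; ++wgX)
      for (int k = 0; k < K; k += BK)          // step 3: explicit reduction loop over K
        for (int warpM = 0; warpM < BM / WM; ++warpM)   // step 5: the [2,2] warp grid
          for (int warpN = 0; warpN < BN / WN; ++warpN) {
            // Each warp accumulates a WM x WN (16x16) block of C over a
            // 16-element slice of K -- exactly one f16 WMMA operation.
          }
}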

Next we dump the IR generated after each of these steps while debugging.

Step 1: tiling the i and j dimensions
func.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf16_for_LLVMGPUMatmulTensorCore_32_32_16_64_2_1_dispatch_0_matmul_512x512x128_f16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCore workgroup_size = [64, 2, 1], {pipeline_depth = 3 : i64, store_stage = 1 : i64}>} {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%subview = memref.subview %0[%3, 0] [32, 128] [1, 1] : memref<512x128xf16, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%subview_0 = memref.subview %1[0, %4] [128, 32] [1, 1] : memref<128x512xf16, #hal.descriptor_type<storage_buffer>> to memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_1 = memref.subview %2[%3, %4] [32, 32] [1, 1] : memref<512x512xf16, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>} ins(%subview, %subview_0 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_1 : memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
return
}

This code implements an efficient 512x512 half-precision matrix multiplication on the GPU, with the following optimizations:

  1. Tile sizes chosen for Tensor Core acceleration (tile_sizes = [[32, 32, 16]]).
  2. Tile-based computation
    • Matrices A, B, and C are split into 32x128, 128x32, and 32x32 sub-matrices.
    • Indices are computed from workgroup_id_x/y, so each workgroup handles a fixed region (see the sketch below).
  3. Memory optimizations
    • 64-byte alignment (assume_alignment 64) improves GPU memory access.
    • Subviews extract matrix blocks, reducing data movement and improving locality.
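The index computation in point 2 is just the affine_map<()[s0] -> (s0 * 32)> seen in the IR. A tiny standalone C++ sketch of the mapping (not IREE code):

#include <cstdio>

int main() {
  const int tile = 32;                       // workgroup tile size in M and N
  for (int wgY = 0; wgY < 512 / tile; ++wgY)
    for (int wgX = 0; wgX < 512 / tile; ++wgX)
      // Row/column offsets of the 32x32 C block owned by workgroup (wgX, wgY),
      // matching %3 = s0 * 32 (rows) and %4 = s0 * 32 (columns) in the IR dump.
      std::printf("workgroup (%2d,%2d) -> C[%3d.., %3d..]\n", wgX, wgY, wgY * tile, wgX * tile);
}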
Step 2: promotion
//========================================== LLVM GPU Tile And Distribute ========================================
// Step1
func.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf16_for_LLVMGPUMatmulTensorCore_32_32_16_64_2_1_dispatch_0_matmul_512x512x128_f16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCore workgroup_size = [64, 2, 1], {pipeline_depth = 3 : i64, store_stage = 1 : i64}>} {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<32x32xf16, #gpu.address_space<workgroup>>
// Bind the interface buffer; this is matrix A
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
// This is matrix B
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
// This is matrix C
// Note that it is returned as memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>

// Get the workgroup ids; the workgroup grid is two-dimensional
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%subview = memref.subview %0[%3, 0] [32, 128] [1, 1] : memref<512x128xf16, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%subview_0 = memref.subview %1[0, %4] [128, 32] [1, 1] : memref<128x512xf16, #hal.descriptor_type<storage_buffer>> to memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
// Subview of C; this still lives in external (global) memory
%subview_1 = memref.subview %2[%3, %4] [32, 32] [1, 1] : memref<512x512xf16, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_2 = memref.subview %alloc[0, 0] [%c32, %c32] [1, 1] : memref<32x32xf16, #gpu.address_space<workgroup>> to memref<?x?xf16, strided<[32, 1]>, #gpu.address_space<workgroup>>
// Copy the C subview into workgroup memory; subview_2 is the workgroup-memory buffer
memref.copy %subview_1, %subview_2 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<?x?xf16, strided<[32, 1]>, #gpu.address_space<workgroup>>
// Perform the matmul, the core computation
linalg.matmul {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>} ins(%subview, %subview_0 : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_2 : memref<?x?xf16, strided<[32, 1]>, #gpu.address_space<workgroup>>)
// Copy the C result back to external memory
memref.copy %subview_2, %subview_1 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<?x?xf16, strided<[32, 1]>, #gpu.address_space<workgroup>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
return
}

The key point: the C tile is kept in workgroup (shared) memory.

image-20250513230338216

Comparing the code in the figure above, the changes are:

  • An explicit memref.alloc whose address_space is set to workgroup.
  • memref.copy operations (explicit loads from/to shared memory) and memref.subview operations introduced by the use of shared memory.

Why does this step only promote the C matrix and not A and B? See the MLIR GPU code generation paper for the rationale; an excerpt from the paper is reproduced below to aid understanding:

image-20250513182833242

Step 3: continue tiling, reducing K from 128 to 16
func.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf16_for_LLVMGPUMatmulTensorCore_32_32_16_64_2_1_dispatch_0_matmul_512x512x128_f16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCore workgroup_size = [64, 2, 1], {pipeline_depth = 3 : i64, store_stage = 1 : i64}>} {
%c16 = arith.constant 16 : index
%c128 = arith.constant 128 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<32x32xf16, #gpu.address_space<workgroup>>
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%subview = memref.subview %0[%3, 0] [32, 128] [1, 1] : memref<512x128xf16, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%subview_0 = memref.subview %1[0, %4] [128, 32] [1, 1] : memref<128x512xf16, #hal.descriptor_type<storage_buffer>> to memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_1 = memref.subview %2[%3, %4] [32, 32] [1, 1] : memref<512x512xf16, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.copy %subview_1, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, #gpu.address_space<workgroup>>
scf.for %arg0 = %c0 to %c128 step %c16 {
%subview_2 = memref.subview %subview[0, %arg0] [32, 16] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x16xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %subview_0[%arg0, 0] [16, 32] [1, 1] : memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.matmul {__internal_linalg_transform__ = "workgroup_k_tiled", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>} ins(%subview_2, %subview_3 : memref<32x16xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<16x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloc : memref<32x32xf16, #gpu.address_space<workgroup>>)
}
memref.copy %alloc, %subview_1 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, #gpu.address_space<workgroup>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
return
}

This step continues tiling along the K dimension, going from 128 down to 16.

image-20250513230638654

As shown in the figure above, the reduction dimension becomes an explicit loop, and redundant memref.copy operations are removed.

An interesting detail here is that the linalg.matmul marker changes from workgroup_memory to workgroup_k_tiled; see the tileReductionLoops function in the source:

/// Tiles to workgroup level. Workgroup tiling is done at the flow level but we
/// may have extra tiling for the reduction dimension. Therefore we tile again
/// without distributing.
static LogicalResult tileReductionLoops(mlir::FunctionOpInterface funcOp) {
auto tileSizesFn = [](OpBuilder &builder,
Operation *op) -> SmallVector<OpFoldResult> {
auto interfaceOp = cast<PartitionableLoopsInterface>(*op); // get the set of loops that can be partitioned

// Returns the loop IDs (0 being outermost) that are partitionable.
//
// If `maxNumPartitionedLoops` is passed the size of the vector returned
// is always lesser than the value.
auto partitionedLoops =
interfaceOp.getPartitionableLoops(kNumMaxParallelDims); // get the loops already distributed to workgroups
SmallVector<OpFoldResult> tileSizes =
getAsIndexOpFoldResult(op->getContext(), getTileSizes(op, 0));
auto zeroAttr = builder.getIndexAttr(0);
for (unsigned depth : partitionedLoops) {
if (depth < tileSizes.size()) { // set the tile size of parallel loops already distributed to workgroups to 0 (skip them)
tileSizes[depth] = zeroAttr;
}
}

int numLoops = cast<TilingInterface>(op).getLoopIteratorTypes().size(); // make tileSizes as long as the number of loop dimensions; dimensions without an explicit size default to 0 (not tiled)
tileSizes.resize(numLoops, zeroAttr);
return tileSizes;
};

auto tilingOptions =
scf::SCFTilingOptions().setTileSizeComputationFunction(tileSizesFn);

MLIRContext *context = funcOp.getContext();
LinalgTransformationFilter filter( // match source ops carrying the workgroup-memory marker
ArrayRef<StringAttr>{
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getWorkgroupKTiledMarker())); // once processed, tag them with the k-tiled marker
filter.setMatchByDefault();

return tileLinalgOpsWithFilter(funcOp, tilingOptions, filter);
}

image-20250514205936633

The figure above roughly summarizes the call chain of tileToSerialLoops. The important function in this code is tileSCF(), which is implemented on top of tileUsingSCF(); interested readers can study the source in mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp.

Step 4: when the workgroup has more threads than a warp, A and B can be promoted too

This step does two things:

  • It allocates two extra buffers in shared memory to cache the A and B tiles.
  • It inserts gpu.barrier before and after every shared-memory read/write (the memref.copy ops).

Let's look at the corresponding code:

SmallVector<int64_t> workgroupSize = maybeWorkgroupSize.value();
int64_t flatWorkgroupSize =
workgroupSize[0] * workgroupSize[1] * workgroupSize[2];
// Only promote to workgroup size if there are multiple warps.
if (flatWorkgroupSize > kWarpSize) { // e.g. for [64,2,1], flatWorkgroupSize = 128, i.e. 2x2 warps, so splitting into warps is worthwhile
RewritePatternSet promotionPatterns(&getContext());

populateContractPromotionPatterns(promotionPatterns, {0, 1}); // operand indices 0 and 1, i.e. promote the A and B matrices

if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(promotionPatterns)))) {
return signalPassFailure();
}
// Insert barriers before and after copies to workgroup memory.
insertBarriersAroundSharedMemoryCopy(funcOp);
}

The condition multiplies workgroupSize[0], [1], and [2] and compares the product against the warp size; if it is larger, the workgroup itself needs warp-level tiling. That also means the A and B tiles may be reused across warps, so it pays to stage them in shared memory as well. The logic for this step lives in populateContractPromotionPatterns():

/// This is an MLIR rewrite pattern that optimizes linalg::MatmulOp, linalg::BatchMatmulOp,
/// and linalg::GenericOp by promoting selected operands so that they are computed out of
/// workgroup (shared) memory, improving efficiency (mainly targeting GPUs).
/// The whole pass is built on the linalg transformations that MLIR provides, so it pays to
/// study the linalg transformations systematically.
void populateContractPromotionPatterns(RewritePatternSet &patterns,
ArrayRef<int64_t> operandsToPromote) {
// operandsToPromote: an array of operand indices to promote into workgroup shared memory (typically the A and B matrices)
MLIRContext *context = patterns.getContext();

// Promote the operandsToPromote of linalg::MatmulOp, linalg::BatchMatmulOp, and linalg::GenericOp.
// See PatternMatch.h for how patterns.insert works.
// LinalgPromotionPattern provides a generic promotion pattern over linalg ops.
patterns.insert<LinalgPromotionPattern<linalg::MatmulOp>,
LinalgPromotionPattern<linalg::BatchMatmulOp>,
LinalgPromotionPattern<linalg::GenericOp>>(
context,
// Configure the promotion options
linalg::LinalgPromotionOptions()
// set the allocation/deallocation callbacks
.setAllocationDeallocationFns(allocateWorkgroupMemory,
deallocateWorkgroupMemory)
// set the copy-in/copy-out callbacks
.setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
// select which operands to promote
.setOperandsToPromote(operandsToPromote)
// do not use full tile buffers
.setUseFullTileBuffers({false, false}),
LinalgTransformationFilter(
{StringAttr::get(context, getWorkgroupKTiledMarker())},
StringAttr::get(context, getWorkgroupMemoryMarker()))
.setMatchByDefault()
.addFilter(contractOpFilter));
}

This code nicely demonstrates the reusability of the MLIR infrastructure and the clean abstractions built by the Google team. The most important piece is LinalgPromotionPattern, IREE's wrapper around and extension of MLIR's RewritePattern. The dependency relationship is shown below:

image-20250514214336495

Let's start with the first layer of the wrapper: LinalgBasePromotionPattern

//===----------------------------------------------------------------------===//
// Transformations exposed as patterns, moved from upstream MLIR as IREE still
// heavily relies on patterns that compose through filters.
// TODO: Deprecate all the code below.
//===----------------------------------------------------------------------===//
///
/// Linalg promotion patterns.
///
/// Apply the `promoteSubViews` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViews` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
/// Entry point to match any LinalgOp
/// OpInterface. MatchAnyOpTag-based constructor
/// with a mandatory `filter`.
LinalgBasePromotionPattern(
MLIRContext *context, LinalgTransformationFilter f,
linalg::LinalgPromotionOptions options = linalg::LinalgPromotionOptions(),
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
filter(std::move(f)), options(std::move(options)) {}
/// Entry point to match a specific Linalg op.
LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context,
linalg::LinalgPromotionOptions options,
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
options(std::move(options)) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
if (failed(promoteSubviewsPrecondition(op, options)))
return failure();

// TODO: We cannot use root update here. This
// pattern is creating other ops, so if the
// promotion fails, those need to be cleaned
// up, which doesnt seem to be happening here.
// So to fail properly, we should be cloning
// the op and deleting the previous op. This
// needs more investigation.
rewriter.startOpModification(op);
std::optional<linalg::LinalgOp> promotedOp =
promoteSubViews(rewriter, cast<linalg::LinalgOp>(op), options);
if (!promotedOp) {
rewriter.cancelOpModification(op);
return op->emitError("subview promotion failed");
}
rewriter.finalizeOpModification(op);
filter.replaceLinalgTransformationFilter(rewriter, op);
return success();
}

private:
/// LinalgTransformMarker handles special
/// attribute manipulations.
LinalgTransformationFilter filter;
/// Promotion options.
linalg::LinalgPromotionOptions options;
};

As the constructors show, the two key ingredients are LinalgTransformationFilter and LinalgPromotionOptions, both of which were specified in the patterns.insert call:

patterns.insert<LinalgPromotionPattern<linalg::MatmulOp>,
LinalgPromotionPattern<linalg::BatchMatmulOp>,
LinalgPromotionPattern<linalg::GenericOp>>(
context,
// Configure the promotion options
linalg::LinalgPromotionOptions()
// set the allocation/deallocation callbacks
.setAllocationDeallocationFns(allocateWorkgroupMemory,
deallocateWorkgroupMemory)
// set the copy-in/copy-out callbacks
.setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
// select which operands to promote
.setOperandsToPromote(operandsToPromote)
// do not use full tile buffers
.setUseFullTileBuffers({false, false}),
LinalgTransformationFilter(
{StringAttr::get(context, getWorkgroupKTiledMarker())},
StringAttr::get(context, getWorkgroupMemoryMarker()))
.setMatchByDefault()
.addFilter(contractOpFilter));

Here, LinalgTransformationFilter plays the same role as before: it rewrites the WorkgroupKTiled marker into the WorkgroupMemory marker. The LinalgPromotionOptions definition is particularly interesting: it explicitly specifies how memory is allocated and freed, which copy functions to use, which operands to promote (0 and 1 correspond to operands 0 and 1 of the linalg op, i.e. A and B in our case), and whether full tile buffers are used. Tracing LinalgPromotionOptions, it is a public MLIR/LLVM class that lets users customize each of these steps through callbacks.
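The callback-driven, fluent options style is easy to appreciate in isolation. Below is a small self-contained C++ sketch of the same idiom; the names (PromotionOptions, setAllocationFn, ...) are made up for illustration and are not the real LinalgPromotionOptions API:

#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// A stripped-down imitation of the fluent, callback-based options idiom.
struct PromotionOptions {
  std::function<bool(int)> allocateFn;  // policy: may a buffer of this size be allocated?
  std::function<void(int)> copyInFn;    // policy: how to copy an operand in
  std::vector<int> operandsToPromote;   // which operands to promote

  PromotionOptions &setAllocationFn(std::function<bool(int)> fn) {
    allocateFn = std::move(fn);
    return *this;
  }
  PromotionOptions &setCopyInFn(std::function<void(int)> fn) {
    copyInFn = std::move(fn);
    return *this;
  }
  PromotionOptions &setOperandsToPromote(std::vector<int> idx) {
    operandsToPromote = std::move(idx);
    return *this;
  }
};

int main() {
  // The caller injects the policy; the "pass" below only orchestrates.
  PromotionOptions options = PromotionOptions()
      .setAllocationFn([](int bytes) { return bytes <= 48 * 1024; })  // fits in shared memory?
      .setCopyInFn([](int operand) { std::printf("copy operand %d to shared memory\n", operand); })
      .setOperandsToPromote({0, 1});                                  // promote A and B

  for (int operand : options.operandsToPromote)
    if (options.allocateFn(32 * 16 * 2))  // a 32x16 f16 tile
      options.copyInFn(operand);
}

The real LinalgPromotionOptions works the same way: the pattern body stays generic, and IREE plugs in allocateWorkgroupMemory / copyToWorkgroupMemory as the policies.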

With the construction covered, the other key part is the actual rewrite logic in matchAndRewrite. The comments in the code explain how it works:

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
// Step 1: check the filter and the promotion preconditions
if (failed(filter.checkAndNotify(rewriter, op)) ||
failed(promoteSubviewsPrecondition(op, options)))
return failure();

// Step 2: try to promote the subviews
rewriter.startOpModification(op);
std::optional<linalg::LinalgOp> promotedOp =
promoteSubViews(rewriter, cast<linalg::LinalgOp>(op), options);

// Step 3: handle the result of the promotion
if (!promotedOp) {
rewriter.cancelOpModification(op); // roll back on failure
return op->emitError("subview promotion failed");
}
rewriter.finalizeOpModification(op);
filter.replaceLinalgTransformationFilter(rewriter, op); // update the marker attribute
return success();
}

What makes this code especially nice is that functions such as promoteSubviewsPrecondition() are natively provided by the Linalg dialect transformations. This IREE wrapper is therefore a good template for handling memory promotion and can be ported directly into other projects.

Having covered the first layer, we move to the top-level wrapper: LinalgPromotionPattern

template <typename OpTy>
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
/// SFINAE: This constructor can only trigger for
/// concrete ops that have a static
/// `getOperationName` method.
template <typename ConcreateOpTy = OpTy>
LinalgPromotionPattern(
MLIRContext *context, linalg::LinalgPromotionOptions options,
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
f, benefit) {}
/// This constructor is available to anyone.
LinalgPromotionPattern(
StringRef opName, MLIRContext *context,
linalg::LinalgPromotionOptions options,
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
};

The only trick in this code is that the template parameter typename ConcreateOpTy = OpTy of the first constructor, combined with OpTy::getOperationName(), forms a SFINAE constraint. When OpTy has no static getOperationName() method, that constructor is simply discarded instead of causing a compile error, which acts as a safety net.
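The same trick can be demonstrated outside of MLIR. A minimal sketch of the idiom with hypothetical types (not IREE code):

#include <iostream>
#include <string>
#include <utility>

struct NamedOp {
  static std::string getOperationName() { return "demo.named"; }
};
struct AnonymousOp {}; // no getOperationName()

template <typename OpTy>
struct Pattern {
  // SFINAE: this constructor only participates in overload resolution when
  // ConcreteOpTy (defaulting to OpTy) has a static getOperationName().
  // Otherwise it is silently discarded instead of producing a hard error.
  template <typename ConcreteOpTy = OpTy,
            typename = decltype(ConcreteOpTy::getOperationName())>
  Pattern() : name(ConcreteOpTy::getOperationName()) {}

  // The constructor "available to anyone": the name is passed explicitly.
  explicit Pattern(std::string opName) : name(std::move(opName)) {}

  std::string name;
};

int main() {
  Pattern<NamedOp> a;                        // uses the SFINAE constructor
  Pattern<AnonymousOp> b("demo.anonymous");  // must use the explicit constructor
  std::cout << a.name << " " << b.name << "\n";
}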

The final IR produced at this stage is:

func.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf16_for_LLVMGPUMatmulTensorCore_32_32_16_64_2_1_dispatch_0_matmul_512x512x128_f16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCore workgroup_size = [64, 2, 1], {pipeline_depth = 3 : i64, store_stage = 1 : i64}>} {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c16 = arith.constant 16 : index
%alloc = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
%alloc_0 = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
%alloc_1 = memref.alloc() : memref<32x32xf16, #gpu.address_space<workgroup>>
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%subview = memref.subview %0[%3, 0] [32, 128] [1, 1] : memref<512x128xf16, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%subview_2 = memref.subview %1[0, %4] [128, 32] [1, 1] : memref<128x512xf16, #hal.descriptor_type<storage_buffer>> to memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %2[%3, %4] [32, 32] [1, 1] : memref<512x512xf16, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
memref.copy %subview_3, %alloc_1 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.for %arg0 = %c0 to %c128 step %c16 {
%subview_4 = memref.subview %subview[0, %arg0] [32, 16] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x16xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %subview_2[%arg0, 0] [16, 32] [1, 1] : memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
memref.copy %subview_4, %alloc_0 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x16xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x16xf16, #gpu.address_space<workgroup>>
memref.copy %subview_5, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<16x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x32xf16, #gpu.address_space<workgroup>>
gpu.barrier
linalg.matmul {__internal_linalg_transform__ = "workgroup_memory", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>} ins(%alloc_0, %alloc : memref<32x16xf16, #gpu.address_space<workgroup>>, memref<16x32xf16, #gpu.address_space<workgroup>>) outs(%alloc_1 : memref<32x32xf16, #gpu.address_space<workgroup>>)
}
gpu.barrier
memref.copy %alloc_1, %subview_3 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, #gpu.address_space<workgroup>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
return
}

Compared with the previous stage, the differences are clear: A and B are promoted to shared memory, there are more explicit copies, and gpu.barrier ops synchronize the shared-memory copies.

image-20250514220954451

image-20250514221013687

Step 5: further tiling per warp
func.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf16_for_LLVMGPUMatmulTensorCore_32_32_16_64_2_1_dispatch_0_matmul_512x512x128_f16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCore workgroup_size = [64, 2, 1], {pipeline_depth = 3 : i64, store_stage = 1 : i64}>} {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c16 = arith.constant 16 : index
%alloc = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
%alloc_0 = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
%alloc_1 = memref.alloc() : memref<32x32xf16, #gpu.address_space<workgroup>>
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%subview = memref.subview %0[%3, 0] [32, 128] [1, 1] : memref<512x128xf16, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%subview_2 = memref.subview %1[0, %4] [128, 32] [1, 1] : memref<128x512xf16, #hal.descriptor_type<storage_buffer>> to memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %2[%3, %4] [32, 32] [1, 1] : memref<512x512xf16, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
memref.copy %subview_3, %alloc_1 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.for %arg0 = %c0 to %c128 step %c16 {
%subview_4 = memref.subview %subview[0, %arg0] [32, 16] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x16xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %subview_2[%arg0, 0] [16, 32] [1, 1] : memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
memref.copy %subview_4, %alloc_0 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x16xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x16xf16, #gpu.address_space<workgroup>>
memref.copy %subview_5, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<16x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x32xf16, #gpu.address_space<workgroup>>
gpu.barrier
%c0_6 = arith.constant 0 : index
%c16_7 = arith.constant 16 : index
%c16_8 = arith.constant 16 : index
%thread_id_x = gpu.thread_id x
%5 = affine.apply affine_map<(d0) -> (d0 floordiv 32)>(%thread_id_x)
%c2 = arith.constant 2 : index
%thread_id_y = gpu.thread_id y
%c2_9 = arith.constant 2 : index
%c0_10 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c16_11 = arith.constant 16 : index
%c0_12 = arith.constant 0 : index
%c32_13 = arith.constant 32 : index
%c16_14 = arith.constant 16 : index
%6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%thread_id_y, %c16_11]
%7 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%6, %c0_10]
%8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%c2_9, %c16_11]
scf.for %arg1 = %7 to %c32 step %8 {
%9 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%5, %c16_14]
%10 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%9, %c0_12]
%11 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%c2, %c16_14]
scf.for %arg2 = %10 to %c32_13 step %11 {
%subview_15 = memref.subview %alloc_0[%arg1, 0] [16, 16] [1, 1] : memref<32x16xf16, #gpu.address_space<workgroup>> to memref<16x16xf16, strided<[16, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_16 = memref.subview %alloc[0, %arg2] [16, 16] [1, 1] : memref<16x32xf16, #gpu.address_space<workgroup>> to memref<16x16xf16, strided<[32, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_17 = memref.subview %alloc_1[%arg1, %arg2] [16, 16] [1, 1] : memref<32x32xf16, #gpu.address_space<workgroup>> to memref<16x16xf16, strided<[32, 1], offset: ?>, #gpu.address_space<workgroup>>
linalg.matmul {__internal_linalg_transform__ = "vectorize", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>} ins(%subview_15, %subview_16 : memref<16x16xf16, strided<[16, 1], offset: ?>, #gpu.address_space<workgroup>>, memref<16x16xf16, strided<[32, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%subview_17 : memref<16x16xf16, strided<[32, 1], offset: ?>, #gpu.address_space<workgroup>>)
}
}
}
gpu.barrier
memref.copy %alloc_1, %subview_3 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, #gpu.address_space<workgroup>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
return
}

This step prepares for turning linalg.matmul into WMMA API calls. The workgroup size (number of threads) is used to tile the thread block tile ([32,32,16]) further at warp granularity (32 threads per warp). Our workgroup size is [64,2,1] and the warp size is 32, so there are [2,2,1] warps. The thread block tile is therefore split [2,2,1] ways into [32/2, 32/2, 16/1] = [16,16,16]. Since our data type is f16, the getInstructionShape() method in compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp gives the instruction shape:

/// Returns the shape of the math instruction for the given pipeline and input
/// element type.
/// For example, for the Tensor Core (WMMA) pipeline, f16 and bf16 map to an instruction shape of {16, 16, 16}, while f32 maps to {16, 16, 8}.
/// Tensor core capabilities are fixed by the hardware, so the instruction shape must be chosen from the input element type.
static LogicalResult
getInstructionShape(Operation *op, CodeGenPipeline pipeline,
Type inputElementType,
SmallVector<int64_t> &instructionShape) {
switch (pipeline) {
case CodeGenPipeline::LLVMGPUMatmulSimt:
// SIMT Pipeline / CUDA Cores
instructionShape = {1, 1, 1};
break;

// Our pipeline is LLVMGPUMatmulTensorCore
case CodeGenPipeline::LLVMGPUMatmulTensorCore:
// Tensor Core Pipeline / WMMA API
// For f16 and bf16 the tensor core shape is 16x16x16
if (inputElementType.isF16() || inputElementType.isBF16()) {
instructionShape = {16, 16, 16};
// For f32 the tensor core shape is 16x16x8
} else if (inputElementType.isF32()) {
instructionShape = {16, 16, 8};
} else {
return op->emitError(
"Expected f16, bf16 or f32 for Tensor Core (WMMA) pipeline");
}
break;
case CodeGenPipeline::LLVMGPUMatmulTensorCoreMmaSync:
// Tensor Core Pipeline / MMA.SYNC
if (inputElementType.isF16() || inputElementType.isBF16()) {
instructionShape = {16, 8, 16};
} else if (inputElementType.isF32()) {
instructionShape = {16, 8, 8};
} else {
return op->emitError(
"Expected f16, bf16 or f32 for Tensor Core (MMA.SYNC) pipeline");
}
break;
default:
return op->emitError(
"Expected matmul SIMT, TensorCore(WMMA), or TensorCore(MMA.SYNC), "
"compilation pipeline");
}
return success();
}

For f16 the tensor core performs a [16,16,16] matmul, so our warp tile matches exactly. This warp tiling step is implemented by tileToInvocation() in the source:

/// Patterns for thread level tiling.
static LogicalResult tileToInvocation(mlir::FunctionOpInterface funcOp,
SmallVectorImpl<int64_t> &workgroupSize) {
linalg::TileSizeComputationFunction getInnerTileSizeFn =
[&](OpBuilder &builder, Operation *operation) {
return calculateDistributedTileSize(workgroupSize, builder, operation); // compute the distributed tile size; the warp grid here is [64/32, 2/1, 1/1] = [2, 2, 1]
};
auto getThreadProcInfoFn = // explicitly compute and materialize thread ids, moving closer to a GPU kernel representation
[&workgroupSize](OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
return getGPUThreadIdsAndCounts(builder, loc, parallelLoopRanges.size(),
workgroupSize);
};
linalg::LinalgLoopDistributionOptions invocationDistributionOptions; // distribution options for threads; procInfo is set below
invocationDistributionOptions.procInfo = getThreadProcInfoFn;

auto tilingOptions =
linalg::LinalgTilingOptions()
.setLoopType(linalg::LinalgTilingLoopType::Loops)
.setTileSizeComputationFunction(getInnerTileSizeFn)
.setDistributionOptions(invocationDistributionOptions);

MLIRContext *context = funcOp.getContext();
LinalgTransformationFilter f(
{StringAttr::get(context, getWorkgroupKTiledMarker()),
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getVectorizeMarker()));
f.addFilter([](Operation *op) {
// FFT doesn't support second level of tiling yet.
return success(!isa<IREE::LinalgExt::FftOp>(op));
}).setMatchByDefault();

return distributeLinalgOpsWithFilter(funcOp, tilingOptions, f);
}

With the comments, the source reads quite clearly. Compared with the earlier K-dimension tiling, the overall structure is the same (once again, the clean abstractions from the Google team and the MLIR community remove a lot of boilerplate).

The code change produced by this step is shown below:

image-20250514222617381

The most important outcome is that the tiling is complete and each thread's share of the work is assigned (via thread_id_x and thread_id_y):

%5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%thread_id_y]  // starting row along y
%6 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%thread_id_x) // starting column along x

A thread finds the warp it belongs to via (d0 floordiv 32), and the columns that warp owns via the * 16. Along x there are two warps, and each warp loads 16 consecutive columns, which matches the coalesced-memory access pattern. Along y there are only 2 threads, not enough to form a warp and with no coalescing to exploit, so each thread simply handles 16 consecutive rows. For the details, see the calculateDistributedTileSize source or read up on coalesced memory access. calculateDistributedTileSize computes the tile size assigned to each thread or warp; its source follows after the short sketch below.
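First, the index arithmetic in those two affine.apply ops, simulated in standalone C++ (a host-side sketch for workgroup_size = [64, 2, 1], not GPU code):

#include <cstdio>

int main() {
  const int warpSize = 32;
  const int blockDimX = 64, blockDimY = 2; // workgroup_size = [64, 2, 1]

  for (int ty = 0; ty < blockDimY; ++ty) {
    for (int tx = 0; tx < blockDimX; tx += warpSize) { // one line per warp
      int warpX = tx / warpSize; // (d0 floordiv 32): warp index along x
      int row = ty * 16;         // thread_id_y selects a 16-row strip of the 32x32 C tile
      int col = warpX * 16;      // the warp's 16-column strip
      std::printf("threads (y=%d, x=%d..%d) -> C[%d..%d, %d..%d]\n",
                  ty, tx, tx + warpSize - 1, row, row + 15, col, col + 15);
    }
  }
}

The four (ty, warpX) combinations cover the four 16x16 quadrants of the 32x32 C tile, one per warp.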

/// Return the tile size associated to one thread or warp based on the number of
/// element in the group.
static SmallVector<Value>
calculateDistributedTileSize(ArrayRef<int64_t> numElements, OpBuilder &builder,
Operation *operation) {
SmallVector<int64_t> blockTileSize = getTileSizes(operation, 0);
SmallVector<Value> tileSizesVal;
// Use partitionedLoop to know what loop needs to be distributed.
auto interfaceOp = cast<PartitionableLoopsInterface>(operation);
auto partitionedLoops =
interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
if (partitionedLoops.empty()) {
return tileSizesVal;
}
auto zero = builder.create<arith::ConstantIndexOp>(operation->getLoc(), 0);
tileSizesVal.resize(
cast<TilingInterface>(operation).getLoopIteratorTypes().size(), zero);

// partitionedLoops contains the dimensions we want to distribute.
// We are distributing them in order onto the different workgroup
// dimensions.
SmallVector<int64_t> distributedDim(numElements.begin(), numElements.end());
distributedDim.resize(partitionedLoops.size());
unsigned idIdx = 0;
std::reverse(distributedDim.begin(), distributedDim.end());
for (unsigned depth : partitionedLoops) {
if (depth >= blockTileSize.size())
continue;
tileSizesVal[depth] = builder.create<arith::ConstantIndexOp>(
operation->getLoc(),
llvm::divideCeil(blockTileSize[depth], distributedDim[idIdx++]));
if (idIdx == kNumMaxParallelDims)
break;
}
return tileSizesVal;
}

One thing to note: after every tiling step, IREE runs a canonicalization clean-up:

{
// Apply canonicalization patterns.
RewritePatternSet threadTilingCanonicalizationPatterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
populateAffineMinSCFCanonicalizationPattern(
threadTilingCanonicalizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(threadTilingCanonicalizationPatterns)))) {
return signalPassFailure();
}
}

This snippet is worth borrowing.

That completes the tileAndDistribute flow. The overall logic is fairly simple, but it is analyzed in detail here because this pass involves a large amount of tiling optimization whose ideas and implementation are worth learning from. Tiling was covered in detail in an earlier post; on heterogeneous architectures (GPU/NPU) with multi-level memory hierarchies, tiling is broadly useful.

GPUMultiBufferingPass

After the tiling and the GPU-oriented distribution of parallel work are done, IREE runs GPUMultiBufferingPass to exploit the GPU's pipelining capabilities. The pass is under a hundred lines of code, but really understanding it requires some background on GPU software pipelining; this post only gives a brief introduction. The IREE GPUMultiBufferingPass write-up contains the following passage:

By making extensive use of software pipelining, CUTLASS can maximize utilization of the GPU's compute resources and improve matmul performance. The software pipeline stages are:

  1. Load data from global memory
  2. Copy the data to shared memory / perform any necessary preprocessing
  3. Perform the computation
  4. Write the results back to shared memory or global memory

In a software-pipelined structure, while the GPU computes on data already in shared memory, it should also be reading the data needed by the next iteration from global memory. From the perspective of the memory hierarchy, this requires multi-buffering at the shared-memory level, so that an upstream pipeline stage can write data into shared memory while a downstream stage loads data from shared memory into registers. In other words, different loop iterations use different buffers, which eliminates the data dependence between iterations. The number of buffers used in the multi-buffering scheme is determined by the pipeline depth (also called the number of pipeline stages).

The figure below illustrates a GPU software pipeline:

image-20250517000503961

The GPU pipeline works on the same principle as a CPU's multi-stage pipeline, but because GPU control flow is much simpler, the pipeline design is greatly simplified compared with a CPU.
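To make the overlap concrete, here is a scalar C++ sketch of a depth-2 software pipeline (double buffering). On a real GPU the "load" stage would be an asynchronous global-to-shared copy overlapped with Tensor Core math, but the loop structure is the same:

#include <array>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> global = {1, 2, 3, 4, 5, 6, 7, 8}; // stand-in for global memory
  std::array<int, 2> shared{};                        // two buffers -> pipeline depth 2
  int acc = 0;

  shared[0] = global[0]; // prologue: issue the first load before the steady state

  for (std::size_t i = 0; i < global.size(); ++i) {
    std::size_t cur = i % 2, nxt = (i + 1) % 2;
    if (i + 1 < global.size())
      shared[nxt] = global[i + 1];    // "load" stage for iteration i+1 ...
    acc += shared[cur] * shared[cur]; // ... overlapped with "compute" for iteration i
  }
  std::printf("acc = %d\n", acc);
}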

With that basic picture of what GPU software pipelining buys us, let's look at GPUMultiBufferingPass. The pass is short enough to quote in full:

// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_GPUMULTIBUFFERINGPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

// TODO(hal): an important and fairly simple pass.
// This pass is a key step in enabling the software pipeline.
// The number of buffers used for multi-buffering matches the pipeline depth (number of pipeline stages).
// The core idea is to introduce multiple buffers so that loop iterations can execute as concurrently as possible, paving the way for software pipelining.
namespace {
struct GPUMultiBufferingPass final
: impl::GPUMultiBufferingPassBase<GPUMultiBufferingPass> {
using GPUMultiBufferingPassBase::GPUMultiBufferingPassBase;

void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();

// First hoist all shared memory allocations to the entry block of the
// function. We can see memref.alloc in loops after bufferizing scf.forall
// with promoted shared memory usage inside.

// Collect all allocOps that allocate in shared memory
// and hoist them to the top of the function
SmallVector<memref::AllocOp> allocs;
// Collect all the alloc operations.
// Collect every allocation made in shared memory
funcOp.walk([&](memref::AllocOp allocOp) {
// If this is a shared-memory allocation, it will be moved into the function's entry block;
// i.e. the multi-buffering optimization only targets shared-memory allocations
if (hasSharedMemoryAddressSpace(allocOp.getType()))
allocs.push_back(allocOp);
});

assert(funcOp.getBlocks().size() == 1);
// Move all collected allocOps to the very beginning of funcOp
for (memref::AllocOp allocOp : allocs) {
if (allocOp->getParentOp() != funcOp)
allocOp->moveBefore(&*funcOp.begin()->begin());
}

// Then perform multibuffering transformations.

allocs.clear();
// Collect all the alloc operations.
funcOp.walk([&](memref::AllocOp allocOp) {
// Skip allocations not used in a loop.
// A pattern worth noting: walk the user list and inspect each user's enclosing operation
for (Operation *user : allocOp->getUsers()) {
auto loop = user->getParentOfType<scf::ForOp>();
if (!loop)
return WalkResult::advance();
}
allocs.push_back(allocOp);

// WalkResult::interrupt(): stop the walk entirely.
// WalkResult::advance(): continue the walk.
// WalkResult::skip(): skip the not-yet-visited nested regions/blocks of the
// current op and continue with the next operation, region or block.
return WalkResult::advance();
});
// Apply multi-buffering to all of them.
for (memref::AllocOp alloc : allocs) {
// The real core of this pass: introduce multi-buffering to enable the GPU
// software pipeline. The preconditions are:
// 1. the alloc op's users must live inside a for loop
// 2. those users must not form a loop-carried dependence

if (failed(memref::multiBuffer(alloc, numBuffers))) {
// Error out and stop if any buffer cannot be multi buffered, as future
// software pipelining transformations will assume this happened.
alloc.emitOpError("cannot be multi-buffered");
return signalPassFailure();
}
}
}
};

} // namespace
} // namespace mlir::iree_compiler

This pass is only enabled when the configured pipeline depth is greater than 1; its core purpose is to use multi-buffering to overlap compute with data-transfer latency. The pass proceeds as follows:

  • Collect every shared-memory allocation (allocOp) and hoist it to the function entry, which simplifies the later rewrites.

  • Check, for each allocation, that it is used inside a subsequent loop and that those uses carry no loop-carried dependence.

    A loop-carried dependence is a data dependence between different iterations of a loop, i.e. a later iteration depends on results produced by an earlier one. Such a dependence limits how much the loop can be parallelized, because the iterations can no longer execute independently and must run in order.

    For example:

    for (int i = 1; i < N; i++) {
      A[i] = A[i-1] + B[i]; // iteration i depends on the result of iteration i-1
    }
  • Call the memref dialect helper multiBuffer() to perform the multi-buffering rewrite (a rough C++ analogy is sketched right below).
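
As a rough C++ analogy (not the actual IR rewrite), multiBuffer() turns one staging buffer that every iteration overwrites into numBuffers rotating slices, which removes the write-after-read hazard between consecutive iterations; copy_in and compute below are hypothetical helpers.

// Before: a single staging buffer is overwritten each iteration, so iteration
// i + 1 cannot start its copy-in until iteration i has finished reading `buf`:
//   for (int i = 0; i < n; ++i) { copy_in(buf, i); compute(buf); }

constexpr int kNumBuffers = 3;                            // = pipeline depth

static void copy_in(float *slice, int iter) { slice[0] = static_cast<float>(iter); } // hypothetical
static void compute(const float *slice) { (void)slice; }                             // hypothetical

// After multi-buffering: each iteration works on a different slice, so the
// copy for iteration i + 1 may overlap the compute of iteration i.
void multi_buffered(float (*buf)[1024], int n) {          // buf has kNumBuffers slices
  for (int i = 0; i < n; ++i) {
    float *slice = buf[i % kNumBuffers];                  // rotating slice selection
    copy_in(slice, i);
    compute(slice);
  }
}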

Having read through the pass, let's think about how multiBuffer() actually benefits the GPU pipeline. First, a debug dump gives a direct feel for what the pass does to the IR. Before the pass runs, the IR looks like this:

func.func @matmul_accumulate_512x128xf16_times_128x512xf16_into_512x512xf16_for_LLVMGPUMatmulTensorCore_32_32_16_64_2_1_dispatch_0_matmul_512x512x128_f16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulTensorCore workgroup_size = [64, 2, 1], {pipeline_depth = 3 : i64, store_stage = 1 : i64}>} {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c16 = arith.constant 16 : index
%alloc = memref.alloc() : memref<16x32xf16, #gpu.address_space<workgroup>>
%alloc_0 = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
%alloc_1 = memref.alloc() : memref<32x32xf16, #gpu.address_space<workgroup>>
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<512x128xf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<128x512xf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<512x512xf16, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_y]
%subview = memref.subview %0[%3, 0] [32, 128] [1, 1] : memref<512x128xf16, #hal.descriptor_type<storage_buffer>> to memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%4 = affine.apply affine_map<()[s0] -> (s0 * 32)>()[%workgroup_id_x]
%subview_2 = memref.subview %1[0, %4] [128, 32] [1, 1] : memref<128x512xf16, #hal.descriptor_type<storage_buffer>> to memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_3 = memref.subview %2[%3, %4] [32, 32] [1, 1] : memref<512x512xf16, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
memref.copy %subview_3, %alloc_1 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x32xf16, #gpu.address_space<workgroup>>
gpu.barrier
scf.for %arg0 = %c0 to %c128 step %c16 {
%subview_4 = memref.subview %subview[0, %arg0] [32, 16] [1, 1] : memref<32x128xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x16xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_5 = memref.subview %subview_2[%arg0, 0] [16, 32] [1, 1] : memref<128x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
memref.copy %subview_4, %alloc_0 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x16xf16, strided<[128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x16xf16, #gpu.address_space<workgroup>>
memref.copy %subview_5, %alloc {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<16x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<16x32xf16, #gpu.address_space<workgroup>>
gpu.barrier
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%thread_id_y]
%6 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%thread_id_x)
%subview_6 = memref.subview %alloc_0[%5, 0] [16, 16] [1, 1] : memref<32x16xf16, #gpu.address_space<workgroup>> to memref<16x16xf16, strided<[16, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_7 = memref.subview %alloc[0, %6] [16, 16] [1, 1] : memref<16x32xf16, #gpu.address_space<workgroup>> to memref<16x16xf16, strided<[32, 1], offset: ?>, #gpu.address_space<workgroup>>
%subview_8 = memref.subview %alloc_1[%5, %6] [16, 16] [1, 1] : memref<32x32xf16, #gpu.address_space<workgroup>> to memref<16x16xf16, strided<[32, 1], offset: ?>, #gpu.address_space<workgroup>>
linalg.matmul {__internal_linalg_transform__ = "vectorize", lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 32, 16]]>} ins(%subview_6, %subview_7 : memref<16x16xf16, strided<[16, 1], offset: ?>, #gpu.address_space<workgroup>>, memref<16x16xf16, strided<[32, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%subview_8 : memref<16x16xf16, strided<[32, 1], offset: ?>, #gpu.address_space<workgroup>>)
}
gpu.barrier
memref.copy %alloc_1, %subview_3 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x32xf16, #gpu.address_space<workgroup>> to memref<32x32xf16, strided<[512, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
gpu.barrier
return
}

After the pass, the IR changes as follows:

image-20250516213503860

As shown above, the shared-memory tiles for A and B are each buffered three ways. The C tile is allocated outside the loop and is also copied in and out outside the loop (it is the accumulator carried across the k iterations), so it does not take part in multi-buffering. For A and B, because the workgroup is larger than a warp, an additional warp-level tiling is applied, and their shared-memory tiles are consumed inside the tiled loop, which makes them eligible for multi-buffering. Since we configured pipeline depth = 3, the shared memory backing matrices A and B is multi-buffered with factor = 3.
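
To see the rotation concretely: with factor = 3 the original memref<32x16xf16> and memref<16x32xf16> allocations become memref<3x32x16xf16> and memref<3x16x32xf16>, and each iteration of the k-loop picks a slice via ((iv - lb) / step) % 3, the affine map built in the multiBuffer source walked through below. The small program below just checks that arithmetic for the scf.for %arg0 = %c0 to %c128 step %c16 loop above.

#include <cstdio>

// Evaluate the multi-buffer slice index ((iv - lb) / step) % factor for the
// k-loop: lb = 0, step = 16, upper bound = 128, factor = pipeline depth = 3.
int main() {
  const int lb = 0, step = 16, ub = 128, factor = 3;
  for (int iv = lb; iv < ub; iv += step)
    std::printf("iv = %3d -> slice %d\n", iv, ((iv - lb) / step) % factor);
  // Prints 0,1,2,0,1,2,0,1: consecutive iterations never share a slice, which
  // is what allows their copy and compute phases to overlap.
  return 0;
}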

image-20250516214558122

The figure above shows the computation inside the inner k-loop, which is where multi-buffering mainly rewrites the code. The reduction dimension is tiled into warp-sized chunks (c0 to c128 step c16); for each chunk, the A and B sub-tiles are first copied from global memory into workgroup (shared) memory and then consumed by the computation, with the C tile copied back out after the loop. With multi-buffering, these per-iteration stages can overlap across iterations, hiding latency; the resulting schedule looks like this:

image-20250516221417987

Notice that no unit is ever idle, so the multi-buffering scheme in this IREE pass clearly pays off. The actual multi-buffering work is carried out by MLIR's memref::multiBuffer(); for space reasons it is not dissected in depth here, only summarized with a flowchart. Readers who want more can study the source alongside the IREE GPUMultiBufferingPass walkthrough. The function also contains a number of reusable building blocks, which are listed below for reference when implementing similar components yourself.

Flowchart

image-20250520160243480

Code components

// ==================================== Extract loop information ========================================
// The four essentials of a loop: 1. induction var 2. lower bound 3. step 4. region
std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
if (!inductionVar || !lowerBound || !singleStep ||
!llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
return failure();
}
// 1. Construct the multi-buffered memref type.
ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
llvm::append_range(multiBufferedShape, originalShape); // multiBufferedShape = [multiBufferingFactor, original shape dims...]
LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
.setShape(multiBufferedShape)
.setLayout(MemRefLayoutAttrInterface());
LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");
// 2. Create the multi-buffered alloc.
Location loc = allocOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(allocOp);
auto mbAlloc = rewriter.create<memref::AllocOp>(
loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
// 3. Within the loop, build the modular leading index (i.e. each loop
// iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
rewriter.setInsertionPointToStart(
&candidateLoop.getLoopRegions().front()->front());
// This affine-expr construction is worth studying.
Value ivVal = *inductionVar;
Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
AffineExpr iv, lb, step;
bindDims(rewriter.getContext(), iv, lb, step);
// The trick here: an affine expr can be composed from other affine exprs.
Value bufferIndex = affine::makeComposedAffineApply(
rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
{ivVal, lbVal, stepVal});
LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");
// 4. Build the subview accessing the particular slice, taking modular
// rotation into account.
int64_t mbMemRefTypeRank = mbMemRefType.getRank();
IntegerAttr zero = rewriter.getIndexAttr(0);
IntegerAttr one = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);

// Two simple steps here: (1) use bufferIndex as the leading offset, (2) sizes are [1, original sizes...], with the leading dim being the multi-buffer dim.
// Offset is [bufferIndex, 0 ... 0 ].
offsets.front() = bufferIndex;
// Sizes is [1, original_size_0 ... original_size_n ].
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
// Strides is [1, 1 ... 1 ].
// The destination type can be inferred from the offsets, sizes and strides.
MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
originalShape, mbMemRefType, offsets, sizes, strides);
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
offsets, sizes, strides);

LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
// 5. Due to the recursive nature of replaceUsesAndPropagateType, dealloc uses
// need to be handled separately.
for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
if (!deallocOp)
continue;
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(deallocOp);

// Replace the deallocOp's buffer with the multi-buffered mbAlloc.
auto newDeallocOp =
rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
(void)newDeallocOp;
LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
rewriter.eraseOp(deallocOp);
}

LLVMGPUTensorCoreVectorizationPass

image-20250520161347725

The main job of this pass is to lower the linalg dialect into the vector dialect; for the linalg.matmul op, MLIR provides the vector.contract representation.

GPUDistributeSharedMemoryCopyPass

This pass distributes shared-memory copy operations across the threads of a workgroup.
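
Conceptually the result is the familiar strided cooperative-copy pattern, sketched here in plain C++ with tid / num_threads standing in for the GPU thread id and workgroup size; this illustrates the idea, not the pass's actual output.

// Each thread copies every num_threads-th element, so the workgroup copies the
// whole tile cooperatively with non-overlapping (and, for contiguous data,
// coalesced) accesses.
void distributed_copy(const float *global_src, float *shared_dst, int num_elems,
                      int tid, int num_threads) {
  for (int i = tid; i < num_elems; i += num_threads)
    shared_dst[i] = global_src[i];
}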

GPUReduceBankConflictsPass

This pass is an interesting optimization: it eliminates GPU shared-memory bank conflicts by adding padding to memref.alloc. One drawback of the current padding logic is that the padding size must be specified manually by the user, so the pass is only summarized with a flowchart here, followed by a short sketch of the bank arithmetic behind it:

未命名文件 (6)
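
Some background on why padding works, assuming the usual 32 banks of 4 bytes on NVIDIA GPUs: when a 32-float-wide row-major tile is walked down a column, every access lands in the same bank, while padding each row by one float spreads the column across banks. A minimal sketch of that bank arithmetic:

#include <cstdio>

// Bank index of element (row, col) in a row-major float tile of the given row
// width, assuming 32 banks of 4 bytes (one float per bank slot).
static int bank(int row, int col, int width) { return (row * width + col) % 32; }

int main() {
  // Walking down column 0: with width 32 every row hits bank 0 (a 32-way
  // conflict); with one float of padding per row (width 33) the accesses
  // spread across different banks.
  for (int row = 0; row < 4; ++row)
    std::printf("row %d: width 32 -> bank %2d, width 33 -> bank %2d\n",
                row, bank(row, 0, 32), bank(row, 0, 33));
  return 0;
}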

LLVMGPUVectorToGPU

GPUPipeliningPass

LLVMGPUPackSharedMemoryAlloc

References

  1. Paper: the MLIR GPU code generation work
  2. Blog: IREE GPU backend walkthrough