Leon's Blog

分享一点有趣的技术

0%

访存密集算子融合

image-20250331214150737

这是算子优化系列的第一篇,主要聚焦于存储密集型算子的融合优化。探讨这一主题的动机源于 BladeDISC 编译器,这个由阿里云开发的编译器专注于解决动态 shape 问题,并在大模型业务中实现了显著的优化。BladeDISC 的一大亮点是复用了 astitch 论文中提出的多元访存算子融合方法,因此,本文将对这一问题展开深入讨论,从两方面来分析:XLA传统融合和BladeDISC使用的stich融合。

XLA代码解析

在阿里的bladedisc的ppt中,可以看到如下总结:

image-20250325195503234

先考虑从mlir-hlo项目入手,理解xla alike的fusion是如何做的。具体代码在mhlo_fusion.cc中。上述图片其实已经总结了XLA可以做的两种kernel fusion方式:KLoopKInput。可以看出,XLA还是重点关注的算子pattern是比较简单的。

Source Code源码分析

核心原理是通过形状约束分析和图论中的边收缩算法,动态识别可融合的操作组,并生成高效的融合计划。

形状约束是阿里团队提出的动态shape约束。

该代码是典型的两阶段code,第一阶段分析出fusion plan,第二阶段应用fusion plan做相应的fusion:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
void runOnOperation() override {
FuncOp func = getOperation();
if (!IsTargetFunc(func)) {
return;
}

// process each block and do fusion within a block.
for (Block& block : func) {
// 收集一个block内部的所有operation list
SmallVector<Operation*, 4> op_list;
for (Operation& op : block) {
op_list.push_back(&op);
}

FusionPlanner planner(op_list);
llvm::Optional<FusionPlan> plan = planner.Run();
if (!plan) {
emitError(func.getLoc(), "can't find a fusion plan");
signalPassFailure();
return;
}
if (!ApplyFusionPlan(*plan)) {
emitError(func.getLoc(), "apply fusion plan failed");
signalPassFailure();
return;
}
}
}

上述代码段时整个pass的入口,也是整体流程控制。

  • FusionPlanner planner(op_list)针对一个block的oplist,初始化一个fusion plan。
  • planner.Run()根据planner构建的图,做子图划分,生成潜在的fusion pattern集和。
  • ApplyFusionPlan为第二阶段,将fusion pattern list用于代码改写,完成最终的fusion操作。

接下来分别展开这三个函数。

基础知识

下面先来介绍一些辅助函数:

  1. 首先,先明确XLA Fusion的几个基本数据结构概念

    1
    2
    using FusionPattern = std::vector<Operation*>;
    using FusionPlan = std::vector<FusionPattern>;

    FusionPattern是可融合的operation集和。FusionPlan是FusionPattern集和。

  2. XLA针对memory-intensive的算子,主要考虑如下两个算子:ReduceOpelement-wiseOp。分别有如下可能的融合方案:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    // reduceOp主要判断,其operand的define point的op是否shape相同
    // 如果相同,reduceOp考虑的模式是operand融合
    bool IsFusibleWithOperand(Operation* op) {
    // 相比IsFusibleWithConsumer,支持reduce op
    return IsMhlo(op) &&
    (op->hasTrait<::mlir::OpTrait::Elementwise>() || isa<ReduceOp>(op));
    }

    // element-wise操作或是常量操作,可以考虑是否可以和consumer融合
    bool IsFusibleWithConsumer(Operation* op) {
    // 必须是MHLO操作,并且是elementwise操作,或是常量操作
    return IsMhlo(op) && (op->hasTrait<::mlir::OpTrait::Elementwise>() ||
    matchPattern(op, m_Constant()));
    }

    // 全局的isFusible判断
    bool IsFusible(Operation* op) {
    // 只有常量操作或是能和consumer融合的操作才是可融合的,或是能和operand融合的操作
    return matchPattern(op, m_Constant()) || IsFusibleWithConsumer(op) ||
    IsFusibleWithOperand(op);
    }
  3. 针对FusionPattern和其他block的交互:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
    SmallVector<Value, 4> inputs;
    DenseSet<Value> input_set;
    DenseSet<Operation*> op_set;
    // 收集一个fusion pattern里的所有operation
    for (Operation* op : pattern) {
    bool inserted = op_set.insert(op).second;
    (void)inserted;
    assert(inserted && "FusionPattern contains duplicate operations");
    }

    for (Operation* op : pattern) {
    for (Value operand : op->getOperands()) {
    Operation* operand_op = operand.getDefiningOp();
    // 如果defining op在pattern里,则跳过
    // 否则加入到inputs中,表示是该潜在的fusion pattern的输入
    if (op_set.find(operand_op) != op_set.end()) {
    // skip if defining op is in the pattern
    continue;
    }
    if (input_set.insert(operand).second) {
    inputs.push_back(operand);
    }
    }
    }
    return inputs;
    }

    // 收集融合操作中所有被外部使用的输出,作为融合后的输出
    SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
    SmallVector<Value, 4> outputs;
    DenseSet<Operation*> op_set;
    for (Operation* op : pattern) {
    // 检查是否有重复的operation
    bool inserted = op_set.insert(op).second;
    (void)inserted;
    assert(inserted && "FusionPattern contains duplicate operations");
    }

    // 尝试做融合的operation
    for (Operation* op : pattern) {
    for (Value result : op->getResults()) {
    // 判断该operation是否result有被外部使用
    bool has_external_user = llvm::any_of(
    result.getUses(),
    [&](OpOperand& use) { return !op_set.count(use.getOwner()); });
    // 显示收集被外部使用的output
    if (has_external_user) {
    outputs.push_back(result);
    }
    }
    }
    return outputs;
    }

    上述两个函数支持获取FusionPattern的外部输入和输出。

  4. 合并两个FusionPattern,这个函数比较重要,其实就是发现可以合并融合的operation list,扩大单个fusion region:

    1
    2
    3
    4
    5
    6
    FusionPattern MergeFusionPattern(const FusionPattern& lhs,
    const FusionPattern& rhs) {
    FusionPattern pattern(lhs);
    pattern.insert(pattern.end(), rhs.begin(), rhs.end());
    return pattern;
    }
  5. 并查集用于做shape推导,注意,XLA的fusion中,shape 推导是比较简单的:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    class ShapeConstraintAnalysis {
    public:
    explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
    PropagateEquality(op_list);
    }

    // Returns true is `lhs` and `rhs` are supposed to have same shape.
    bool HasSameShape(Value lhs, Value rhs) {
    // 单纯判断两者在unionfind中的位置是否相同
    return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
    }

    private:
    // shape equality propagation based on the shape constrains of
    // elementwise ops.
    // 针对elementwise操作,做shape相等传播
    void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
    bool converged = true;
    do {
    converged = true;
    // 显示对两个value做unionfind
    auto update = [&](Value lhs, Value rhs) {
    if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
    // 有更改,说明还没有完全收敛
    converged = false;
    impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
    }
    };
    for (Operation* op : op_list) {
    // 只对有InferShapeEqualityOpInterface trait的operation做shape相等传播
    auto op_fusibility = dyn_cast<InferShapeEqualityOpInterface>(op);
    if (!op_fusibility) continue;

    int numInput = op->getNumOperands();
    int numOutput = op->getNumResults();
    // shape equality propagation between inputs.
    for (int input1 = 0; input1 < numInput; ++input1)
    for (int input2 = input1 + 1; input2 < numInput; ++input2)
    // 通过op_fusibility.inferInputsShapeEquality函数判断两个input是否shape相等
    if (op_fusibility.inferInputsShapeEquality(input1, input2))
    update(op->getOperand(input1), op->getOperand(input2));

    // shape equality propagation between outputs.
    for (int output1 = 0; output1 < numOutput; ++output1)
    for (int output2 = output1 + 1; output2 < numOutput; ++output2)
    // 同理,判断两个output是否shape相等
    if (op_fusibility.inferOutputsShapeEquality(output1, output2))
    update(op->getResult(output1), op->getResult(output2));

    // shape equality propagation between input and output.
    for (int input = 0; input < numInput; ++input)
    for (int output = 0; output < numOutput; ++output)
    // 最关键的步骤,判断input和output是否shape相等
    if (op_fusibility.inferInputOutputShapeEquality(input, output))
    // 如果相等,则调用lambda函数,将两者做unionfind
    update(op->getOperand(input), op->getResult(output));
    }
    } while (!converged);
    }

    // a UnionFind set
    // 使用LLVM提供的内置UF集和
    EquivalenceClasses<ValueWrapper> impl_;
    };

planner的初始化

FusionPlanner的初始化逻辑是比较简单的:

1
2
3
4
5
6
7
8
9
class FusionPlanner {
public:
explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
: op_list_(op_list),
shape_analysis_(op_list),
cycle_detector_(op_list.size()) {
// 构建初始的cluster图
BuildNodeMap();
}

主要干了如下事情:

  1. 初始化op_list_,用于后续的fusion
  2. 初始化shape_analysis_工具,用于知道fusion的shape compatible analysis
  3. 初始化cycle detector,XLA的fusion中,一个条件是不能引入loop

最后构建一个nodemap,构建一个cluster图,用于后续子图生成,op遍历。

成员变量如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
const SmallVectorImpl<Operation*>& op_list_;

// Shape equality checker
ShapeConstraintAnalysis shape_analysis_;

// op -> node_id
std::unordered_map<Operation*, int> op_to_node_id_;

// make sure not introduce cycle after fusion
GraphCycles cycle_detector_;
std::vector<std::unique_ptr<Cluster>> cluster_storage_;

// a UnionFind set. Each set represents a (partial) fused pattern
// and has a leader as representation.
EquivalenceClasses<int32_t> leader_for_node_;

构建图的逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
void BuildNodeMap() {
// 当前初始构建图,每个operation为一个node,为一个cluster的head,并leader_for_node_的并查集存储
int num_nodes = op_list_.size();
for (int node_id = 0; node_id < num_nodes; ++node_id) {
// 针对当前operation,构建一个cluster
// 并设定当前op为cluster的头leader
Operation* op = op_list_[node_id];
MakeCluster(node_id);
op_to_node_id_[op] = node_id;

// learder_for_node_是一个UnionFind set
leader_for_node_.insert(node_id);
for (Value operand : op->getOperands()) {
Operation* operand_op = operand.getDefiningOp();
if (operand_op == nullptr) {
// skip block argument
continue;
}

// 检查operand的def point是否属于别的融合组
// 显示构建依赖关系
auto iter = op_to_node_id_.find(operand_op);
assert(iter != op_to_node_id_.end());
cycle_detector_.InsertEdge(iter->second, node_id);
}
}
}
  • 遍历每一个op,在unionfind中,每一个op为一个cluster,后续fusion操作才会融合cluster。
  • 初始化op_to_node_id_等的映射方式,unionfind中存储的是该op的node_id,即在当前fusionplan中的次序号。
  • operand的define op和当前op,在cycle_detector中显示插入依赖链条。

上述整体逻辑是十分清楚的。

在FusionPlanner中,有一个私有类是Cluster

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
// Represent a (partial) fused pattern
// 根据注释,这个cluster类并不是完整的融合模式,而是一个融合模式的一部分
// 这是unionfind的思想,每个cluster都是一个融合模式的一部分,最终通过unionfind合并
class Cluster {
public:
Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
const SmallVectorImpl<Operation*>& op_list = planner->op_list();
pattern_.push_back(op_list[node_id]);
}

// Merges `other` into this cluster, and clears `other`.
void Merge(Cluster* other) {
pattern_.insert(pattern_.end(), other->pattern_.begin(),
other->pattern_.end());
other->pattern_.clear();
}

// The number of nodes in this cluster.
int cluster_size() const { return pattern_.size(); }

// The ID of the cluster as represented in `cycle_detector_`.
int cycles_graph_node_id() const { return node_id_; }

// Sets the ID of the cluster as represented in `cycle_detector_`.
void set_cycles_graph_node_id(int cycles_graph_node_id) {
node_id_ = cycles_graph_node_id;
}

// Currently the fused pattern this cluster holds.
const FusionPattern& fused_pattern() { return pattern_; }

private:
// ID of the representative node of this cluster.
int node_id_;

// the fused pattern this cluster holds.
FusionPattern pattern_;
};

上述cluster最主要的功能,就是可以对于cluster中的operation,显示记录他们可以应用的FusionPattern。同时还支持merge操作,将cluster之间进行fusion。该原理是将另一个cluster的pattern拷贝入当前cluster,并清空另一个cluster的pattern。

2

planner运行

核心部分,主要作用是构建cluster融合,并给对应cluster赋予fusion pattern。

主函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// Returns a fusion plan if success, otherwise none.
// 返回一个fusion plan,如果没有找到则返回none
llvm::Optional<FusionPlan> Run() {
// Greedily search connected fusible pattern, and ops belonging to
// a same fusion pattern are grouped into a cluster.
// 每一个op有一个可融合的pattern,找寻compatible的pattern并合并对应的operation
RunEdgeContractionLoop();

// After doing edge contraction, each unique cluster having size
// more than one represents a potential fusion pattern.
// We collect all these clusters and construct a fusion plan.
// union find中,每一个有大于1的size的cluster都是一个融合模式
//
// Note that the ops in a fusion pattern are in topological ordering.
// fusion pattern中的operation是按照拓扑排序的
// 得到一个fusion plan集合后,后续apply这些fusion plan即可
FusionPlan plan;
DenseMap<int, int> pattern_ids;
for (Operation* op : op_list_) {
// 获取该op所属的cluster
Cluster* cluster = GetClusterForNode(op);
int node_id = cluster->cycles_graph_node_id();

// 不可融合或是融合模式的size小于等于1,跳过
if (!IsFusible(op_list_[node_id]) ||
EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
continue;
}

if (!pattern_ids.count(node_id)) {
int pattern_id = pattern_ids.size();
pattern_ids[node_id] = pattern_id;
plan.emplace_back();
}
// 特定的plan后添加operation
plan[pattern_ids[node_id]].push_back(op);
}
return plan;
}

这里需要理解一个重要的点:每一个cluster都有一个list,该list存储所有的fusion Pattern,一个cluster是否可以做融合,其实就是看发现了多少个fusion Pattern。

  • 如果fusion pattern > 1,则其内部可以融合
  • 如果cluster之间有兼容的fusion pattern,则intra cluster 融合是ok的。

识别出潜在的fusion后,构造fusion plan:一个fusion pattern的vector存储结构。

上述code可以分为两大阶段:RunEdgeContractionLoop()和融合pattern收集。

1

上述是一个完整的流程图。可以看出,run()函数最重要的function是RunEdgeContractionLoop()

1
2
3
4
5
6
7
8
9
10
// Greedily fuse connected node.
// 贪心算法,融合可融合的node
bool RunEdgeContractionLoop() {
using std::placeholders::_1;
using std::placeholders::_2;
// 理解std::bind操作
// 给TryToContractEdge函数绑定了两个参数,第一个参数是this指针,第二个参数是_1和_2,是占位符
return ForEachEdgeInPostOrder(
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
}

对于graph做后序遍历,并贪心引用边收缩算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// 按照后序遍历的顺序,对每一个edge执行fn函数
template <typename FnTy>
bool ForEachEdgeInPostOrder(FnTy fn) {
bool changed = false;
for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
// 可能在TryToContractEdge函数中修改后续的node,所以这里需要拷贝一份
std::vector<int32_t> successors_copy =
cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());

// 对该节点的后续分别尝试做融合
for (int to : successors_copy) {
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
// 这里传入的fn是TryToContractEdge函数,传入cluster_from和cluster_to两个参数,通过std::bind绑定。
bool contracted_edge = fn(cluster_from, cluster_to);
changed |= contracted_edge;
}
}

return changed;
}

该逻辑是,后续访问每个node,并尝试对后续的每个op尝试引用fn函数,即TryToContractEdge函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// This function check if fusing `from` with `to` is valid and if so perform
// the merge. The validity is based on the operations in the clusters and
// the compatibility of the shapes of the outputs of the would-be fused
// clusters.
// Returns true is the merge was performed.
// 尝试合并两个cluster,如果合并成功则返回true
bool TryToContractEdge(Cluster* from, Cluster* to) {
int node_to = to->cycles_graph_node_id();
int node_from = from->cycles_graph_node_id();

// Both node_to and node_from should be fusible
if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
return false;
}

// 判断node from是否可以和consumer融合
// 即该node是否是const或是elementwise
if (!IsFusibleWithConsumer(op_list_[node_from])) {
// This op cannot be fused with its consumers.
return false;
}

// 判断node to是否可以和operand融合
// 即该node是否是element wise或是const或是reduce
if (!IsFusibleWithOperand(op_list_[node_to])) {
// This op cannot be fused with its operands.
return false;
}

// Output shapes of a fusion pattern should be compatible as described in
// the document of this class.
// 判断两个cluster的输出是否shape相同
SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);

Value ref = InferEffectiveWorkloadShape(results[0]);
if (!llvm::all_of(results, [&](Value result) {
Value val = InferEffectiveWorkloadShape(result);
return shape_analysis_.HasSameShape(ref, val);
})) {
return false;
}

// 实际的fusion操作,将cluster做fusion
return MergeClusters(from, to);
}

注意,如果op是reduceOp,则其只能是to,不能是from,所以reduceOp一定在sub graph的end point。

  • 判断from和to是否可fuse:isfusable()

  • GetResultsOfFusedPattern

    1
    2
    3
    4
    5
    6
    7
    // returns the outputs if two cluster were merged
    SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
    // 将两个cluster的fused_pattern合并
    FusionPattern fused_pattern =
    MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
    return GetOutputsOfFusionPattern(fused_pattern);
    }

    尝试强行融合两个cluster的pattern,并获取一个fusion厚的output。

  • 显示做shape 判断,来判断前一步的fusion是否是合法的:

    1
    2
    3
    4
    5
    6
    7
    Value ref = InferEffectiveWorkloadShape(results[0]);
    if (!llvm::all_of(results, [&](Value result) {
    Value val = InferEffectiveWorkloadShape(result);
    return shape_analysis_.HasSameShape(ref, val);
    })) {
    return false;
    }

    这个inferworkload逻辑如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    // 根据value的类型,判断workload的shape的推导方式
    // 主要针对reduce op做特殊化处理
    Value InferEffectiveWorkloadShape(Value v) {
    Operation* op = v.getDefiningOp();

    // 如果是reduce op,则返回operand的shape
    // 否则,v本身用于推导shape
    return isa_and_nonnull<ReduceOp>(op) ? op->getOperand(0) : v;
    }
    • 取融合后第一个输出的“有效工作负载形状”作为参考。

    • 对普通操作,有效形状即输出本身的形状;对 ReduceOp,有效形状是操作数的形状(因为融合需基于其输入数据的形状)。

    • 对融合后的所有输出结果,逐一检查它们的有效形状是否与参考形状一致。

    • 若存在任意一个输出形状不匹配,返回 false,拒绝合并。

    这背后的原理如下:

    kLoop 融合:要求所有输出形状一致,以放入同一并行循环。

    kInput 融合

    允许包含 ReduceOp,但其有效形状需与其他输出的有效形状一致(即 ReduceOp 的输入形状需与其他输出形状一致)。例如,若融合模式包含一个reduceOp和elementwiseOp,做如下操作:

    1. ReduceOp 的有效形状是其输入(操作数)的形状。

    2. ElementWise 的有效形状是其输出的形状。

    3. 两者必须相同才能融合。

3

ApplyFusionPlan

源码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
bool ApplyFusionPlan(const FusionPlan& plan) {
// 遍历每个融合模式
for (const FusionPattern& pattern : plan) {
// 在pattern最后一个操作位置创建Builder
OpBuilder b(pattern.back());

// 收集所有操作的位置信息
SmallVector<Location, 4> locations;
locations.reserve(pattern.size());
for (Operation* op : pattern) {
locations.push_back(op->getLoc());
}
// 创建融合位置标识(用于调试)
Location fused_loc = FusedLoc::get(pattern.back()->getContext(), locations);

// 获取外部输入和输出
SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);

// 构建输出类型列表
SmallVector<Type, 4> output_types;
output_types.reserve(outputs.size());
for (Value v : outputs) {
output_types.push_back(v.getType());
}

/* 消费者调整阶段 */
// 记录融合操作集合
DenseSet<Operation*> fused_set(pattern.begin(), pattern.end());
DenseSet<Operation*> consumers_set; // 已处理的消费者
SmallVector<Operation*, 4> consumers_vec; // 待移动的消费者

// 确定原始代码范围:第一个融合操作到最后一个的迭代器范围
auto first_iter = pattern.front()->getIterator();
auto last_iter = pattern.back()->getIterator();

// 扫描区间内的所有操作
for (Operation& cur_op : llvm::make_range(first_iter, last_iter)) {
if (!fused_set.contains(&cur_op)) { // 非融合操作
// 检查是否是消费者:操作数来自融合集或已标记的消费者
bool is_consumer = llvm::any_of(cur_op.getOperands(),
[&](Value v) {
Operation* def_op = v.getDefiningOp();
return fused_set.contains(def_op) || consumers_set.contains(def_op);
});

if (is_consumer) {
consumers_set.insert(&cur_op); // 标记为已处理
consumers_vec.push_back(&cur_op); // 加入移动队列
}
}
}

// 逆序移动消费者到融合点之后(防止顺序移动导致迭代器失效)
for (auto* op : llvm::reverse(consumers_vec)) {
op->moveAfter(pattern.back()); // 重定位到融合操作末尾
}

/* 创建融合操作 */
FusionOp fusion = b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);

// 构建融合计算区域
Region& region = fusion.fused_computation();
region.push_back(new Block); // 创建基本块
Block& block = region.front();

// 将原始操作移入融合区域
for (Operation* op : pattern) {
op->moveBefore(&block, block.end()); // 保持原有顺序
}

// 在区域末尾插入ReturnOp
b.setInsertionPoint(&block, block.end());
b.create<mhlo::ReturnOp>(fused_loc, outputs);

/* 结果替换阶段 */
for (auto [output, fusion_result] : llvm::zip(outputs, fusion.getResults())) {
// 替换所有外部使用点
for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
if (use.getOwner()->getBlock() != &block) { // 外部使用
use.set(fusion_result); // 替换为融合结果
}
}
}
}
return true;
}

如下是ApplyFusionPlan的流程图

4

其中比较核心的是识别消费者操作,该consumer要求是直接或间接依赖fusion里的operation,但本身并不是fusion operation,并且其location在fusionOp的范围内,因此需要显示地移动

4

BladeDISC source code分析

BladeDISC的kernel 融合主要参考AStitchtitch fusion

4

上述很好地阐述了应用XLA和fusion stitich技术的差异。XLA编译器无法对middle reduce操作做fusion,fusion stitch划分四类memory intensive op融合方式,起到有效扩展作用:

Source code解读

一个总的框架:

上述详情参考BladeDISC slide

节点类型划分

stich的fusion pattern中,相比xla,其重点是将node划分为不同的类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// Represents a list of lmhlo ops that are going to be fused.
// Concepts for a fusion pattern:
// - Root op: the op whose output is the fusion-pattern's output.
// Sub-root op可以理解为stich fusion的缝合点,为用shared-memory缝合的op bound
// - Sub-root op: the op whose output is to be maintained on shared-memory for
// kStitch fusion. Currently, we only support row-reduction to be a sub-root
// op.
// - Regular xroot op: either a root op or a sub-root op, for whose operands
// we successfully build tile information during kStitch fusion-pattern init
// phase.
// - Irregular xroot op: an root op for whose operands we fail to build tile
// information durint kStitch fusion-pattern init phase.
// - Skeleton op: the op who will be used to build the loop skeleton when
// lowering a kStitch fusion to parallel loops. Currently, sub-root ops, and
// regular xroot ops who generate external only results, are skeleton ops.
// Other xroot ops are lowered with input-inline fusion phase.
// Note: for an regular xroot op which is not an skeleton op, the output data
// to be written should be coverred by its corresponding skeleton op.
// Otherwise, this xroot are regared as irregular.

上述注释中详细解读了node的分类:

  • Root op:fusion-pattern的output node,即fusion的边界。
  • Sub-root op:通过shared-memory fuse的op。
  • xroot op:在astich论文中,每个op的thread感知分配信息,都是通过分析sub-root或是root,然后反向传播到整个fusion区域。xroot op是分析出thread信息的root或sub-root
  • Irregular xroot:没有分析出thread信息的sub root或是root op。
  • Skeleton op:负责构建动态shape的并行循环骨架,是GPU kernel代码生成的模板基础。通过选择具有典型计算特征的子根操作(如行归约)作为骨架,能够自动推导出循环维度、分块策略等关键参数。主要负责codegen部分。

Shape analysis

GPU Stitch策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// 重点关注如何将stitch技术用在gpu上
class StitchGpuFusionStrategy : public FusionStrategy {
public:
StitchGpuFusionStrategy(const FusionOptions& options)
: FusionStrategy(options) {}
virtual bool isFusible(Operation* op) override;
virtual bool tryFuse(ShapeAnalysis& shapeAnalysis, FusionPattern& lhs,
FusionPattern& rhs, FusionPattern& target) override;
virtual bool initFusionPattern(ShapeAnalysis& shapeAnalysis,
FusionPattern& fusion_pattern) override;
virtual StringRef getName() override { return "StitchGpuFusionStrategy"; }

private:
virtual Value getEffectiveShape(FusionPattern& target, Value value);

bool tileCoverInfoPropagateO2I(
ShapeAnalysis& shapeAnalysis, DenseMap<Value, TileInfo>& tile_plan,
Operation* op, SmallVector<std::pair<Value, TileInfo>, 4>& in_info,
bool& cover);
bool findFusionPatternTypeAndSubroot(ShapeAnalysis& shapeAnalysis,
FusionPattern& fusion_pattern);
bool tileXroots(ShapeAnalysis& shapeAnalysis, FusionPattern& fusion_pattern);
bool backtraceTileAndCover(ShapeAnalysis& shapeAnalysis,
FusionPattern& fusion_pattern, Value value);
};

整体的流程框架如下:

6

上述流程图很好的表达了整个流程框架:

  • initFusionPatter负责初始化fusion的option信息

  • findFusionPatternTypeAndSubroot负责引用fusion,找寻subroot等特殊节点,其整体逻辑如下:

    7

  • tileXroots针对sub root做线程分配算法

    8

  • backtraceTileAndCover在一个fusion中,从不同sub root出发,做反向thread分配推导

    9

对应论文中的图片:

image-20250327184033247

上述source code和论文中的对应描述参考论文中的chapter4.3的steps。