Leon's Blog

分享一点有趣的技术

0%

BladeDISC: RAL overview

image-20250429153010114

在先前的关于编译器后端&运行时博客中,简单介绍了各种成熟的机器学习编译器的runtime system。本文结合BladeDISC RAL文档详细解读一下BladeDISC的runtime设计。由于BladeDISC一般作为python/tensorflow的pulgin使用,所以其一部分runtime依托现成机器学习框架的runtime,本文也会简单补充一下pytorch的runtime系统。

BladeDISC RAL

RAL原理

在开始深入BladeDISC的具体源码实现之前,我们先要讲明Compiler,Runtime之间的中间界面:硬件抽象层(RAL)。

image-20250429154132773

如上图所示是各个部分的层级关系。对于RAL,BladeDISC编译器是如下解释的:

Runtime Abstraction Layer (RAL) 是 BladeDISC 编译器的核心组件,旨在连接编译器与多样化的运行时环境,解决跨平台兼容性和资源管理问题。RAL 的核心功能体现在两个方面:首先,它通过抽象统一的接口屏蔽不同运行时环境(如 TensorFlow、PyTorch 或独立二进制)的底层差异,使编译器只需针对 RAL 生成代码,而无需为每个平台单独适配。这种设计不仅简化了编译器的开发逻辑,还支持“一次编译,多处运行”,降低了用户在不同框架间迁移的成本。其次,RAL 通过上下文对象(Context)集中管理有状态资源(如 GPU 内核、内存等),采用懒初始化(Lazy Initialization)策略优化性能。例如,GPU 内核的加载仅在首次使用时触发,后续调用直接复用,避免了重复初始化的开销。通过将资源状态(如“内核是否已加载”)隐藏在上下文背后,RAL 对外提供初始化无关的接口(如 launch_kernel),使得编译器生成的代码无需关注资源初始化细节,只需调用简单接口即可完成任务。

BladeDISC RAL总体设计

这一块主要参考BladeDISC RAL设计文档。在设计文档中,从编译器角度和运行时角度两个角度来讲解RAL的设计理念。我个人的浅薄理解,从编译角度主要是解读如何将编译出的binary code,封装成Runtime API,而运行时角度则是如何调度运行封装好的Runtime API,从这两个角度均能体现RAL的设计的价值。

从compiler角度

RAL 在编译器侧通过一系列转换流程(Transformation Passes)将代码适配到 RAL 运行时环境,其核心设计围绕上下文注入、输入输出绑定和统一类型擦除的ABI展开。

  • 上下文注入

    将状态操作(如 GPU 内核加载、内存分配)与编译器逻辑解耦。具体实现方式上

    1. 自定义MLIR方言:通过 MLIR 框架定义 disc_ral 方言,引入 disc_ral.RalExecutionContextType 类型表示运行时上下文(Context)。该类型在 LLVM IR 中转换为指针类型,隐藏底层资源状态。
    2. 入口函数重写:译器入口函数(如 main)的首个参数强制注入 Context 对象,所有 RAL API 调用均以 Context 为第一个参数。例如:
    1
    2
    3
    4
    5
    6
    7
    8
    // 转换前:普通函数参数
    func @main(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) -> memref<?x?xf32>

    // 转换后:首参数注入 Context
    func @main(!disc_ral.context %ctx) {
    %arg0 = disc_ral.recv_input(%ctx, 0) // 通过 Context 获取输入
    ...
    }

    这种设计使得编译器无需感知资源状态(如“内核是否已加载”),专注优化逻辑(如算子融合、内存分配优化)。

  • 输入输出绑定

    标准化编译模块与宿主环境(如 TensorFlow/PyTorch)的数据交互接口。具体实现方式上:

    1. 动态输入输出接收:入口函数的输入/输出不再直接传递内存对象,而是通过 disc_ral.recv_inputdisc_ral.send_output API 从 Context 中按需获取或发送。例如:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      // 原始 IR:显式传递内存引用
      func @main(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) -> memref<?x?xf32>

      // 转换后:通过 Context 动态绑定
      func @main(!disc_ral.context %ctx) {
      %arg0 = disc_ral.recv_input(%ctx, 0) // 接收第 0 号输入
      %ret = alloc(...)
      disc_ral.send_output(%ctx, 0, %ret) // 发送第 0 号输出
      }

      这种动态输入输出绑定的好处是:

      • 允许部分输入未就绪时提前执行部分计算(如流水线并行),或提前发送部分输出,减少端到端延迟(runtime调度层面可以做一定的优化过程)。
  • 类型擦除call function api

    跨语言(C/C++/Python)兼容调用 RAL 函数,同时保持 ABI 稳定。具体实现方式如下:

    1. 类型擦除接口,负责将所有 RAL 函数在编译后统一转换为 C 语言风格的泛型接口 ral_api_call,通过 api_name 动态分发。例如:

      1
      2
      3
      4
      5
      // C++ 原函数
      void gemm(RalContext*, gpu_stream_handle*, MemRef<float,2> lhs, ...);

      // 统一类型擦除接口
      void ral_api_call(void* ctx, const char* api_name, void** args);
    2. 统一函数签名编码,api_name 按规则生成唯一标识,包含设备类型(如 CPU/GPU)、输入输出类型编码等。例如:

      1
      api_name = "ral_gemm___cpu___float_2_float_2_float_2___bool_bool"
    3. 自动注册机制,通过宏 TAO_RAL_API 将模板化的 C++ 函数注册为类型擦除接口,例如:

      1
      2
      3
      4
      // 注册 float 类型的 GEMM 到 CPU
      TAO_RAL_API("ral_gemm", "cpu", gemm<float>)
      // 注册 half 类型的 GEMM 到 GPU
      TAO_RAL_API("ral_gemm", "gpu", gemm<half>)

    这套类型擦除机制,允许开发者可直接使用 C++ 接口,编译器自动处理跨语言调用细节,避免手动维护类型转换代码。

从runtime角度

image-20250429153010114

上述流程图很好地概括了runtime系统的架构。具体地可以分为三个方面来解读:

Runtime Abstraction Layer(RAL)在运行时通过三层机制实现与宿主环境(如TensorFlow、PyTorch)的高效交互,平衡跨平台兼容性与性能优化。

  • 面向不同宿主平台(TensorFlow/PyTorch)的定制化适配层

    处理输入输出绑定和元数据同步。例如,TensorFlow通过其特有的Tensor对象与RAL内存结构交互,而PyTorch需适配torch.Tensor的传输逻辑。这种差异化实现确保了数据在宿主环境与RAL间的高效传递,同时复用宿主原生的设备资源访问方式(如直接操作CUDA流),避免冗余数据拷贝或格式转换,最小化跨环境调用的性能损耗。

  • 面向不同底层硬件的设备驱动API

    涵盖内存管理、任务启动和同步操作。由于不同宿主环境对设备的抽象差异显著(如TensorFlow使用DeviceContext管理GPU,而PyTorch依赖c10::Stream),RAL需为同一设备在不同环境中实现独立的驱动接口。例如,GPU内存分配在TensorFlow中通过GPUDevice::Allocate完成,而在PyTorch中则需调用CUDACachingAllocator,RAL通过封装统一的ral_malloc接口隐藏这些差异,使编译器生成的代码无需感知底层环境细节。

  • 跨平台共享的自定义内核(如第三方库加速的矩阵乘法或排序算法)。

    这些内核直接调用RAL驱动API访问设备资源,天然兼容不同宿主环境。例如,基于cuBLAS实现的GPU矩阵乘法内核通过ral_gpu_memcpyral_gpu_stream_sync操作内存与计算流,无论部署到TensorFlow还是PyTorch均无需修改代码。

通过上述三个层次的职责划分,可以看出整个BladeDISC的底层runtime设计,其实是大量依赖于成熟的pytorch和tensorflow的运行时系统的,BladeDISC做的事情是设计实现RAL层,高效完成跨宿主(指tensorflow/pytorch系统)适配。

一些个人体会:

这种分层设计(宿主适配层-设备驱动层-内核层)使得开发者只需维护单一版本的核心逻辑,显著降低多框架支持成本,同时通过复用宿主原生接口和懒初始化策略(如首次调用时加载GPU内核),确保运行时性能接近原生框架水平。此外,新增宿主环境(如ONNX Runtime)时,仅需扩展适配层与驱动层,无需改动上层内核,为BladeDISC的跨平台部署提供了高度可扩展的技术基础。

上述三个核心组件讲完之后,来介绍一下BladeDISC的运行时管理组件:

RAL Context 与 Execution Context 的架构解析

RAL Context 是 BladeDISC 运行时的核心管理单元,其设计紧密关联宿主适配层、设备驱动层和内核层,通过分层抽象实现跨平台兼容与高效执行。针对不同宿主环境(如 TensorFlow、PyTorch 或独立二进制),RAL 提供差异化的 Context 实现。例如,TensorFlow Context 直接操作 tf::Tensor 和 GPU 设备句柄,而 PyTorch Context 适配 torch::Tensorc10::Stream,二者均通过宿主适配层的 I/O 绑定接口完成数据转换,确保输入输出与宿主环境原生数据结构无缝对接。Context 的生命周期与编译后的二进制模块一致,负责全局资源管理(如 GPU 内核预加载、内存池初始化),其内部封装设备驱动层的多环境实现(如内存分配在 TensorFlow 中调用 GPUDevice::Allocate,在 PyTorch 中则使用 CUDACachingAllocator),使得编译器生成的代码无需感知底层差异。

RAL Execution Context 作为 Context 的轻量化派生实例,专为单次执行设计,负责运行时动态状态的簿记管理。每次调用编译后的二进制时,从 Context 创建独立的 Execution Context,记录此次执行的临时资源(如输入张量指针、异步计算流标识符)。例如,在 TensorFlow 多线程场景中,多个线程可并发创建各自的 Execution Context,共享 Context 的全局资源(如预加载的 GPU 内核),但独立维护执行状态(如当前 CUDA Stream),避免资源竞争。Execution Context 的生命周期仅限于单次调用,执行结束后立即释放临时状态,从而减少内存占用并支持高并发。这一机制通过设备驱动层的环境适配接口(如 ral_gpu_stream_sync)确保内核执行与宿主环境的原生异步模型兼容,例如 PyTorch 的 JIT 执行通过 Execution Context 传递 c10::cuda::CUDAStream,使自定义内核可直接利用框架管理的计算流。

PyTorch运行时系统

重点参考Pytorch CUDA runtimePytorch Internal

BladeDISC RAL源码解读

这一部分的复现源码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch_blade
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.c = torch.randn(10, 3)

def forward(self, x, y):
t1 = x + y
t2 = torch.matmul(t1, self.c)
t3 = torch.sum(t2)
return t3 * t3

my_cell = MyCell()
x = torch.rand(10, 10)
y = torch.rand(10, 10)

with torch.no_grad():
blade_cell = torch_blade.optimize(my_cell, allow_tracing=True, model_inputs=(x, y))

print(blade_cell(x, y))

注意开启

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// -----// IR Dump After DiscLowerToLibraryCallPass (disc-lower-to-library-call) //----- //
func.func @main(%arg0: !disc_ral.context) attributes {tf.entry_function = {input_placements = "gpu", inputs = "input.1_", output_placements = "gpu", outputs = "8"}} {
%0 = llvm.mlir.constant(0 : i32) : i32
%false = arith.constant false
%true = arith.constant true
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%1 = "disc_ral.dispatch"(%arg0, %c0) {backend_config = "", call_target_name = "ral_recv_input", device = "cpu", has_side_effect = false} : (!disc_ral.context, index) -> memref<?x10xf32, #gpu.address_space<global>>
%dim = memref.dim %1, %c0 : memref<?x10xf32, #gpu.address_space<global>>
%alloc = memref.alloc() : memref<10x10xf32, #gpu.address_space<global>>
"lmhlo.constant"(%alloc) {disc.device = "gpu", value = dense_resource<__elided__> : tensor<10x10xf32>} : (memref<10x10xf32, #gpu.address_space<global>>) -> ()
%alloc_0 = memref.alloc() : memref<10x10xf32, #gpu.address_space<global>>
"lmhlo.constant"(%alloc_0) {disc.device = "gpu", value = dense_resource<__elided__> : tensor<10x10xf32>} : (memref<10x10xf32, #gpu.address_space<global>>) -> ()
%alloc_1 = memref.alloc() : memref<10xf32, #gpu.address_space<global>>
"lmhlo.constant"(%alloc_1) {disc.device = "gpu", value = dense<[0.186997086, 0.235856801, 0.217500299, 0.25940907, 0.109970599, -0.152944937, 0.137896746, -0.189537019, 0.256005555, 0.235299528]> : tensor<10xf32>} : (memref<10xf32, #gpu.address_space<global>>) -> ()
%alloc_2 = memref.alloc() : memref<10xf32, #gpu.address_space<global>>
"lmhlo.constant"(%alloc_2) {disc.device = "gpu", value = dense<[0.281509697, -0.0671350583, -0.291665494, 0.300998032, -0.304899603, 0.23629041, -0.111676671, 0.304613203, 0.107744612, -0.118951075]> : tensor<10xf32>} : (memref<10xf32, #gpu.address_space<global>>) -> ()
%reinterpret_cast = memref.reinterpret_cast %1 to offset: [0], sizes: [%dim, 10], strides: [10, 1] {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>> to memref<?x10xf32, #gpu.address_space<global>>
%alloc_3 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
%2 = llvm.inttoptr %0 : i32 to !llvm.ptr<i8>
"disc_ral.dispatch"(%arg0, %2, %reinterpret_cast, %alloc_0, %alloc_3, %false, %false, %true) {backend_config = "", call_target_name = "ral_gemm", device = "gpu", has_side_effect = false} : (!disc_ral.context, !llvm.ptr<i8>, memref<?x10xf32, #gpu.address_space<global>>, memref<10x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, i1, i1, i1) -> ()
memref.dealloc %alloc_0 : memref<10x10xf32, #gpu.address_space<global>>
%alloca = memref.alloca() {alignment = 64 : i64} : memref<2xindex>
memref.store %dim, %alloca[%c0] : memref<2xindex>
memref.store %c10, %alloca[%c1] : memref<2xindex>
%3 = arith.muli %dim, %c10 : index
%4 = arith.remui %3, %c4 : index
%5 = arith.cmpi eq, %4, %c0 : index
%alloc_4 = memref.alloc() : memref<f32, #gpu.address_space<global>>
%alloc_5 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
%alloc_6 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
%alloc_7 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
%alloc_8 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
scf.if %5 {
"lmhlo.fusion"() ({
"lmhlo.constant"(%alloc_4) {disc.device = "gpu", value = dense<0.000000e+00> : tensor<f32>} : (memref<f32, #gpu.address_space<global>>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%alloc_1, %alloca, %alloc_5) {broadcast_dimensions = dense<1> : tensor<1xi64>, disc.device = "gpu"} : (memref<10xf32, #gpu.address_space<global>>, memref<2xindex>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.add"(%alloc_3, %alloc_5, %alloc_6) {disc.device = "gpu"} : (memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%alloc_4, %alloca, %alloc_8) {broadcast_dimensions = dense<> : tensor<0xi64>, disc.device = "gpu"} : (memref<f32, #gpu.address_space<global>>, memref<2xindex>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.maximum"(%alloc_6, %alloc_8, %alloc_7) {disc.device = "gpu"} : (memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.terminator"() : () -> ()
}) {disc.device = "gpu", disc.fusion.name = "main_kLoop_maximum__5_1_0", disc.fusion.tag = "Vec4", disc.fusion_type = "kLoop", disc_vectorize_or_tile_hint = 4 : i32} : () -> ()
} else {
"lmhlo.fusion"() ({
"lmhlo.constant"(%alloc_4) {disc.device = "gpu", value = dense<0.000000e+00> : tensor<f32>} : (memref<f32, #gpu.address_space<global>>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%alloc_1, %alloca, %alloc_5) {broadcast_dimensions = dense<1> : tensor<1xi64>, disc.device = "gpu"} : (memref<10xf32, #gpu.address_space<global>>, memref<2xindex>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.add"(%alloc_3, %alloc_5, %alloc_6) {disc.device = "gpu"} : (memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%alloc_4, %alloca, %alloc_8) {broadcast_dimensions = dense<> : tensor<0xi64>, disc.device = "gpu"} : (memref<f32, #gpu.address_space<global>>, memref<2xindex>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.maximum"(%alloc_6, %alloc_8, %alloc_7) {disc.device = "gpu"} : (memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.terminator"() : () -> ()
}) {disc.device = "gpu", disc.fusion.name = "main_kLoop_maximum__5_1_0", disc.fusion_type = "kLoop", disc_vectorize_or_tile_hint = 1 : i32} : () -> ()
}
memref.dealloc %alloc_8 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc_6 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc_5 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc_4 : memref<f32, #gpu.address_space<global>>
memref.dealloc %alloc_3 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc_1 : memref<10xf32, #gpu.address_space<global>>
%alloc_9 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
%6 = llvm.inttoptr %0 : i32 to !llvm.ptr<i8>
"disc_ral.dispatch"(%arg0, %6, %alloc_7, %alloc, %alloc_9, %false, %false, %true) {backend_config = "", call_target_name = "ral_gemm", device = "gpu", has_side_effect = false} : (!disc_ral.context, !llvm.ptr<i8>, memref<?x10xf32, #gpu.address_space<global>>, memref<10x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, i1, i1, i1) -> ()
memref.dealloc %alloc_7 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc : memref<10x10xf32, #gpu.address_space<global>>
%alloc_10 = memref.alloc() : memref<f32, #gpu.address_space<global>>
%alloc_11 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
%alloc_12 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
%alloc_13 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
%alloc_14 = memref.alloc(%dim) {kDiscSymbolicDimAttr = [@S0, @C10]} : memref<?x10xf32, #gpu.address_space<global>>
scf.if %5 {
"lmhlo.fusion"() ({
"lmhlo.constant"(%alloc_10) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32, #gpu.address_space<global>>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%alloc_10, %alloca, %alloc_11) {broadcast_dimensions = dense<> : tensor<0xi64>, disc.device = "gpu"} : (memref<f32, #gpu.address_space<global>>, memref<2xindex>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%alloc_2, %alloca, %alloc_12) {broadcast_dimensions = dense<1> : tensor<1xi64>, disc.device = "gpu"} : (memref<10xf32, #gpu.address_space<global>>, memref<2xindex>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.add"(%alloc_9, %alloc_12, %alloc_13) {disc.device = "gpu"} : (memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.maximum"(%alloc_13, %alloc_11, %alloc_14) {disc.device = "gpu"} : (memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.terminator"() : () -> ()
}) {disc.device = "gpu", disc.fusion.name = "main_kLoop_maximum__5_1_1", disc.fusion.tag = "Vec4", disc.fusion_type = "kLoop", disc_vectorize_or_tile_hint = 4 : i32} : () -> ()
} else {
"lmhlo.fusion"() ({
"lmhlo.constant"(%alloc_10) {value = dense<0.000000e+00> : tensor<f32>} : (memref<f32, #gpu.address_space<global>>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%alloc_10, %alloca, %alloc_11) {broadcast_dimensions = dense<> : tensor<0xi64>, disc.device = "gpu"} : (memref<f32, #gpu.address_space<global>>, memref<2xindex>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.dynamic_broadcast_in_dim"(%alloc_2, %alloca, %alloc_12) {broadcast_dimensions = dense<1> : tensor<1xi64>, disc.device = "gpu"} : (memref<10xf32, #gpu.address_space<global>>, memref<2xindex>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.add"(%alloc_9, %alloc_12, %alloc_13) {disc.device = "gpu"} : (memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.maximum"(%alloc_13, %alloc_11, %alloc_14) {disc.device = "gpu"} : (memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>, memref<?x10xf32, #gpu.address_space<global>>) -> ()
"lmhlo.terminator"() : () -> ()
}) {disc.device = "gpu", disc.fusion.name = "main_kLoop_maximum__5_1_1", disc.fusion_type = "kLoop", disc_vectorize_or_tile_hint = 1 : i32} : () -> ()
}
memref.dealloc %alloc_13 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc_12 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc_11 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc_10 : memref<f32, #gpu.address_space<global>>
memref.dealloc %alloc_9 : memref<?x10xf32, #gpu.address_space<global>>
memref.dealloc %alloc_2 : memref<10xf32, #gpu.address_space<global>>
"disc_ral.dispatch"(%arg0, %c0, %alloc_14) {backend_config = "", call_target_name = "ral_send_output", device = "cpu", has_side_effect = false} : (!disc_ral.context, index, memref<?x10xf32, #gpu.address_space<global>>) -> ()
return
}