
A First Look at BladeDISC


In large-model training and inference, dynamic shapes are a major bottleneck. In NLP, for example, input sentences vary in length, so tensor shapes change dynamically and are only known at runtime. This is a real problem for machine learning compilers: state-of-the-art compilers such as XLA are built around static shapes and lose performance on dynamic workloads (e.g. by recompiling or padding for each new shape). BladeDISC is a machine learning compiler from Alibaba designed for dynamic shapes, and it has been validated by extensive experiments and production use. This post focuses on building BladeDISC, using it from PyTorch, and a first look at its architecture; later posts will cover the optimization pipeline and the paper.
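To make the dynamic shape problem concrete, here is a minimal sketch (the model name and sentences are simply the ones used later in this post) showing that the shape of the tokenized input is only known at runtime:

from transformers import AutoTokenizer

# Two batches with different sentence counts and lengths produce tensors of
# different shapes; a static-shape compiler would have to recompile or pad.
tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")

batch_a = tokenizer(["We are very happy to show you the story."], return_tensors="pt", padding=True)
batch_b = tokenizer(["Hi.", "We hope you don't hate it."], return_tensors="pt", padding=True)

print(batch_a["input_ids"].shape)  # e.g. torch.Size([1, 12])
print(batch_b["input_ids"].shape)  # e.g. torch.Size([2, 11]), a different shape at runtime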

Building from Source

Reference: the official Build from source guide.

  • Pull the BladeDISC image

    docker pull bladedisc/bladedisc:latest-devel-cu118
    • We use the cu118 (CUDA 11.8) variant here.
  • Run the image

    docker run --gpus all -it -v $PWD:/disc bladedisc/bladedisc:latest-devel-cu118 bash
  • Edit TORCH_BLADE_CI_BUILD_TORCH_VERSION in pytorch_blade/scripts/build_pytorch_blade.sh so that it matches one of the existing requirements.txt files (see the sketch after this list).

    During the build, installing onnx may fail because of limited bandwidth; append -i https://pypi.tuna.tsinghua.edu.cn/simple to the pip command to use the Tsinghua PyPI mirror.

  • Build the PyTorch wheel

    cd pytorch_blade && bash ./scripts/build_pytorch_blade.sh
    python setup.py bdist_wheel
    pip install ./pytorch_blade/dist/torch_blade-0.2.0+2.0.1.cu118-cp38-cp38-linux_x86_64.whl
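As a sketch of the version change from the third step above: the value must match one of the requirements files shipped under pytorch_blade (2.0.1+cu118 below is an assumption that matches the wheel name above), and you can either edit the default inside build_pytorch_blade.sh or, if the script reads the variable from the environment, export it before running:

# Assumed value; it has to correspond to an existing requirements file under pytorch_blade/
export TORCH_BLADE_CI_BUILD_TORCH_VERSION=2.0.1+cu118

# If the onnx install fails because of bandwidth, point pip at the Tsinghua mirror
python -m pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple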

Troubleshooting

If git complains that the repository is not a safe directory (dubious ownership), run the following inside the container:

git config --global --add safe.directory /disc

Quick Install

Refer to the docker install documentation.

Deploying a BERT Model with PyTorch

Downloading a Model from Hugging Face

  • Download the model manually (useful when the server's network connection is unreliable)

    Find the BERT sentiment inference model (nlptown/bert-base-multilingual-uncased-sentiment) on Hugging Face and manually download its main files: the weights plus the config and tokenizer files (e.g. pytorch_model.bin, config.json, vocab.txt, tokenizer_config.json, special_tokens_map.json).

    Then load the locally downloaded model in Python:

    model_path = "./model"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path).cuda().eval()
  • Download directly through the transformers package; this snippet is taken from the "Use this model" button on the model's Hugging Face page

    # Load model directly
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
    model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")

    Just make sure the transformers package is installed in the environment.

  • Download via huggingface-cli (a variant that downloads straight into ./model follows this list)

    huggingface-cli download nlptown/bert-base-multilingual-uncased-sentiment
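To line the CLI download up with the model_path = "./model" loading snippet above, the files can be placed into that directory with huggingface-cli's --local-dir option:

# Download the model files into ./model so that from_pretrained("./model") finds them
huggingface-cli download nlptown/bert-base-multilingual-uncased-sentiment --local-dir ./model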

A Testbench for BERT Inference

My test code is as follows:

import torch
import torch_blade
import time

from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TextClassificationPipeline,
)

############################################# download model from huggingface #############################################
model_path = "./model"

tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSequenceClassification.from_pretrained(model_path).cuda().eval()

def plain_tokenizer(inputs_str, return_tensors):
    inputs = tokenizer(inputs_str, return_tensors=return_tensors, padding=True)
    inputs = {key: value.cuda() for key, value in inputs.items()}

    # torch_blade.optimize does not accept None as an input
    if "token_type_ids" in inputs and inputs["token_type_ids"] is None:
        del inputs["token_type_ids"]

    return (inputs['input_ids'], inputs['attention_mask'], inputs.get('token_type_ids', None))

class PlainTextClassificationPipeline(TextClassificationPipeline):
    def _forward(self, model_inputs):
        return self.model(*model_inputs)

classifier = pipeline(
    'sentiment-analysis',
    model=model,
    tokenizer=plain_tokenizer,
    pipeline_class=PlainTextClassificationPipeline,
    device=0
)

input_strs = [
    "We are very happy to show you the story.",
    "We hope you don't hate it."
]

results = classifier(input_strs)

for inp_str, result in zip(input_strs, results):
    print(inp_str)
    print(f" label: {result['label']}, with a score: {round(result['score'], 4)}")

############################################# Use BladeDISC for optimization #############################################
inputs_str = "Hey, the cat is cute."
inputs = plain_tokenizer(inputs_str, return_tensors="pt")

torch_config = torch_blade.config.Config()
torch_config.enable_mlir_amp = False  # disable mixed precision

# Ensure inputs are properly formatted for optimization
model_inputs = tuple(i for i in inputs if i is not None)

with torch.no_grad(), torch_config:
    optimized_ts = torch_blade.optimize(model, allow_tracing=True, model_inputs=model_inputs)

# Move optimized model to CUDA
optimized_ts = optimized_ts.cuda()

# Save the optimized TorchScript model
torch.jit.save(optimized_ts, "opt.disc.pt")

############################################# testbench #############################################
@torch.no_grad()
def benchmark(model, inputs, num_iters=1000):
    # warm-up iterations before timing
    for _ in range(10):
        model(*inputs)
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        model(*inputs)
    torch.cuda.synchronize()
    end = time.time()
    return (end - start) / num_iters * 1000.0

def bench_and_report(input_strs):
    inputs = plain_tokenizer(input_strs, return_tensors="pt")
    model_inputs = tuple(i for i in inputs if i is not None)

    avg_latency_baseline = benchmark(model, model_inputs)
    avg_latency_bladedisc = benchmark(optimized_ts, model_inputs)

    print(f"Seqlen: {[len(s) for s in input_strs]}")
    print(f"Baseline: {avg_latency_baseline:.4f} ms")
    print(f"BladeDISC: {avg_latency_bladedisc:.4f} ms")
    print(f"BladeDISC speedup: {avg_latency_baseline / avg_latency_bladedisc:.4f}")

input_strs = [
    "We are very happy to show you the story.",
    "We hope you don't hate it."
]

bench_and_report(input_strs)

In the code above, the core BladeDISC call is the following:

with torch.no_grad(), torch_config:
    optimized_ts = torch_blade.optimize(model, allow_tracing=True, model_inputs=model_inputs)

Through compilation, this produces an optimized TorchScript module. Note that the PyTorch integration of BladeDISC currently supports inference only; training is not yet supported. If you want a deeper look at how Hugging Face model pipelines work, see the HuggingFace quick tour.
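For deployment, the saved opt.disc.pt can be reloaded later without re-running the optimization. A minimal sketch, assuming (as the example above suggests) that importing torch_blade is enough to register the custom ops needed by the optimized TorchScript module, and reusing the plain_tokenizer helper defined earlier:

import torch
import torch_blade  # assumed to register the DISC runtime ops before loading

# Reload the optimized TorchScript module saved by torch.jit.save above
reloaded = torch.jit.load("opt.disc.pt").cuda().eval()

inputs = plain_tokenizer("Hey, the cat is cute.", return_tensors="pt")
with torch.no_grad():
    outputs = reloaded(*[t for t in inputs if t is not None])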


Compared with the PyTorch baseline, this gives roughly a 1.7x speedup.
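Since BladeDISC's selling point is dynamic shapes, it is also worth calling the bench_and_report helper above with inputs of different lengths; the same optimized module should handle the new shapes without being rebuilt (a sketch reusing the testbench code):

# Different batch sizes and sequence lengths exercise the dynamic-shape path
# of the same optimized module built above.
bench_and_report(["Short text."])
bench_and_report([
    "A considerably longer sentence that yields a different sequence length at runtime.",
    "Another input so the batch size changes as well.",
    "And a third one.",
])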

PyTorch Workflow


For the workflow diagrams and further details, refer to the Torch-Blade tutorial.

References

  1. BladeDISC GitHub repository