
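The previous chapter covered the fundamentals and basic usage of the Triton project, giving readers a working understanding of the Triton compiler. This chapter introduces two open-source operator libraries built on Triton: FlagGems and Liger Kernel. To dig deeper into the Triton operator development workflow and how it plugs into the PyTorch / Hugging Face workflow, this blog also provides a minimal Triton library implementation: dly-gem.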
The FlagGems Project
References
The Liger Kernel Project
How to Use It
Installation
git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel

# Install default dependencies
# setup.py will detect whether you are using AMD or NVIDIA
pip install -e .

# Set up development dependencies
pip install -e ".[dev]"
Minimal Usage
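Method 1: use AutoLigerKernelForCausalLM. It loads the Hugging Face model (downloading it if needed) and automatically swaps Liger's kernels into the original model. You may find that you cannot connect to Hugging Face; see the official Hugging Face documentation for details. Pointing HF_ENDPOINT at the mirror, HF_ENDPOINT=https://hf-mirror.com, is usually enough to solve the problem.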
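The mirror setting described above can also be applied from inside Python. The following is a minimal sketch, assuming the variable has to be in place before any Hugging Face library is imported, since the endpoint is read when those libraries load:

```python
import os

# Assumption: the mirror URL is the one named in the text above.
# setdefault keeps any value that was already exported in the shell.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

# Import Hugging Face libraries only after this point, for example:
# from liger_kernel.transformers import AutoLigerKernelForCausalLM
print(os.environ["HF_ENDPOINT"])
```

Exporting HF_ENDPOINT in the shell before launching Python has the same effect and avoids any dependence on import order.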
from liger_kernel.transformers import AutoLigerKernelForCausalLM
from transformers import AutoTokenizer
import torch

# Define the device once and reuse it everywhere
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load the model and place it on the device
model = AutoLigerKernelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map={"": device},  # key point: put the whole model on one device
).eval()

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize the prompt and move every input tensor to the device
input_text = "Explain quantum entanglement in plain language"
inputs = tokenizer(input_text, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Generate with gradient tracking disabled
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
    )

print("\nModel output:\n")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Method 2: build a custom model with Liger's API.
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
import torch.nn as nn
import torch

model = nn.Linear(128, 256).cuda()

# fuses linear + cross entropy layers together and performs chunk-by-chunk computation to reduce memory
loss_fn = LigerFusedLinearCrossEntropyLoss()

input = torch.randn(4, 128, requires_grad=True, device="cuda")
target = torch.randint(256, (4, ), device="cuda")

loss = loss_fn(model.weight, input, target)
loss.backward()
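The comment in the example above mentions chunk-by-chunk computation. The sketch below shows why chunking can leave the loss unchanged: a mean cross entropy is a sum of per-row terms, so the logits for only one chunk of rows need to exist at a time. It is written in plain Python so it runs without a GPU; the helper names (logits_for, summed_nll, full_loss, chunked_loss) are hypothetical, the bias term is left out on the assumption that the fused call, which only receives model.weight, does the same, and none of this is Liger's actual implementation.

```python
import math
import random

def logits_for(rows, weight):
    # weight holds one row per class, laid out like nn.Linear.weight
    return [[sum(w_i * x_i for w_i, x_i in zip(w, x)) for w in weight] for x in rows]

def summed_nll(logits, targets):
    # Sum of -log softmax(row)[target]; the row max is subtracted for stability
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[t]
    return total

def full_loss(xs, weight, targets):
    # Unfused path: build the logits for every row, then average
    return summed_nll(logits_for(xs, weight), targets) / len(xs)

def chunked_loss(xs, weight, targets, chunk_size):
    total = 0.0
    for start in range(0, len(xs), chunk_size):
        stop = start + chunk_size
        # Only this chunk's logits are alive during the iteration
        total += summed_nll(logits_for(xs[start:stop], weight), targets[start:stop])
    return total / len(xs)

random.seed(0)
hidden, classes, n = 8, 16, 6  # small illustrative sizes
weight = [[random.gauss(0, 1) for _ in range(hidden)] for _ in range(classes)]
xs = [[random.gauss(0, 1) for _ in range(hidden)] for _ in range(n)]
targets = [random.randrange(classes) for _ in range(n)]

same = math.isclose(full_loss(xs, weight, targets),
                    chunked_loss(xs, weight, targets, chunk_size=2))
print("chunked loss matches full loss:", same)
```

The same comparison is a reasonable sanity check for any fused loss: compute the unfused value on a small input and confirm the two agree within floating-point tolerance.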