Triton & Triton Ascend
核心概念
Grid
xxx
基础语法
编译控制
Memory & Control
- tl.cdiv(K, BK) 向上取整除法 (Ceiling Division)。相当于数学上的 $\lceil K / BK \rceil$,或者 Python 里的 (K + BK - 1) // BK。
- tl.make_block_ptr(…) 数据的基地址(base),整个大张量的形状(shape),大张量在内存里是怎么跨步的(strides),你要切的块有多大(block_shape),以及当前切块的起始坐标(offsets)。
- tl.load(ptr, boundary_check=…) 从全局显存(VRAM)加载数据到 GPU 的极速寄存器(SRAM)中。把上面定义好的指针 p_k 对应的数据块加载进来,变成一个二维张量 b_k。boundary_check=(0, 1) 是它的神仙功能:如果你的数据大小不是块大小的整数倍(比如最后一块超出了边界),它会自动帮你把越界的部分填充为 0,再也不用手动写 mask 了!
Math & Matrix
- tl.trans(b_k) 矩阵转置 (Transpose)。
- tl.dot(b_k, tl.trans(b_k)) 矩阵乘法
Indexing & Masking
- tl.arange(start, end) 的语法作用与 NumPy (np.arange) 或 PyTorch (torch.arange) 非常相似:它用于生成一个包含连续整数的 一维 Tensor(张量)。
- tl.where(condition, x, y) 如果 condition 为真,就取 x;如果为假,就取 y。
- [:, None] 和 [None, :]
- row_indices = tl.arange(0, BT)[:, None] 会把 [0,1,2] 变成一个列向量(3行1列)。
- col_indices = tl.arange(0, BT)[None, :] 会把它变成一个行向量(1行3列)。
常见错误
@triton.jit
这个 ValueError('Did you forget to add @triton.jit ? ...') 是 Triton 编译器中一个非常经典的“误导性”报错。它的出现通常并不是因为你真的忘记了加 @triton.jit,而是因为在 Triton 的内核代码中,你对一个局部 Tensor(也就是驻留在寄存器/SRAM 中的变量 masked_dot)使用了不支持的动态索引/高级索引。
在 Triton 中,masked_dot 是一个局部 Tensor(块/矩阵)。而 row_idx 是通过 tl.arange 计算出来的一个 Triton Tensor。
Triton 不支持使用一个 Tensor 作为索引去切片另一个局部 Tensor(类似于 PyTorch 中的高级索引 tensor[tensor_idx] 在 Triton 的局部变量上是不合法的)。
当 Triton 的 AST 解析器遇到 masked_dot[row_idx] 时,它无法将其编译为合法的寄存器操作,导致底层构建器(builder)在分发时崩溃,从而抛出这个令人迷惑的 _builder argument must be provided outside of JIT 错误。
Slice不支持
Triton 根本不支持对 Local Tensor(驻留在寄存器/SRAM中的局部矩阵)进行切片来提取子块 (Sub-block)。
哪怕你使用了 tl.static_range,哪怕 start_row 和 end_row 都是百分之百的静态常量,masked_dot[start:end, :] 这种语法在 Triton 中也完全不被支持!报错中的 slice(<tensor…>) 是 Triton 内部 AST 解析器拒绝该操作的表现。
为什么 Triton 不能切片局部变量?
在 PyTorch 中,切片是一个极其普通的操作。但在 Triton 中,像 masked_dot 这样的局部变量不是存储在连续显存里的,而是被分散存储在流式多处理器(SM)的无数个寄存器(Registers)中。
Triton 的编译器没有实现“从这些零散的寄存器里提取一部分并重新组装成一个新的小形状 Tensor”的逻辑。在 Triton 里,切片语法只能用来增加维度(比如 b_g[:, None]),绝不能用来切割局部矩阵!
性能优化
- 看飞书 浦江 GDN kkt的 triton chunk size 128 实现优化。
- autotune : https://github.com/hongziqi/fa-test/blob/main/npu/test_fa_fwd_npu_prof_tune.py
jx文档
- 充分利用核心数:使用NPU的24个cube核和48个vector核心;
- 充分利用缓存:分块来避免UB overflow。使用autotune优化。
- 连续访存:make_block_ptr 使用时 stride 需为1
- 避免退化情况:
实践
flash_attention_npu
mojo_opset
CV版实现
Ascend的硬件复杂,除了vector还有cube,如何利用好两者,写好CV版代码是进阶的要点。