更新时间:2024-09-14 GMT+08:00
分享

NPU_Flash_Attn融合算子约束

  1. query、key、value都需要梯度。默认开启重计算,则前向时qkv没有梯度,如果需要关闭重计算,可以在yaml配置 `disable_gradient_checkpointing: true` 关闭,但显存占用会直线上升。
  2. attn_mask 只支持布尔(bool)数据类型,或者为None。
  3. query的shape仅支持 [B, N1, S1, D],其中N1≤ 2048,D≤ 512并且dim== 4。
  4. 对于GQA,key的shape是 [B, N2, S2, D],其中 N2 ≤ 2048,并且N1是N2的正整数倍。

不满足以上场景,则不能实现NPU_Flash_Attn功能。

相关文档