

本文為英文版的機器翻譯版本，如內容有任何歧義或不一致之處，概以英文版為準。

# FlashAttention


SMP v2 支援 [FlashAttention](https://github.com/HazyResearch/flash-attention) 核心，並可輕鬆將其套用至 Hugging Face Transformer 模型的各種案例。請注意，如果您使用 FlashAttention 套件 v2.0 或更新版本，SMP 會使用 FlashAttention v2；不過，Triton Flash Attention 預設為 FlashAttention v1.x 中的 Flash Attention 核心，使其僅在 FlashAttention v1 中受支援。

模組 (`nn.Module`) 是低階 API，可定義模型的注意力層。它應該在建立模型後立即套用，例如從 `AutoModelForCausalLM.from_config()` API 套用，以及在轉換模型或使用 FSDP 包裝模型之前套用。

## 使用 FlashAttention 核心進行自我關注


下列程式碼片段示範如何使用 SMP v2 提供的 [`torch.sagemaker.nn.attn.FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention) API。

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## 使用 FlashAttention 核心進行分組查詢注意力


SMP v2 也支援 [FlashAttention](https://github.com/HazyResearch/flash-attention) 核心進行分組查詢注意力 (GQA)，並可輕鬆將其套用至 Hugging Face Transformer 模型的各種案例。與原始注意力架構不同，GQA 會將查詢標頭平均分割為群組，而相同群組中的查詢標頭會共用相同的索引鍵和值標頭。因此，q 和 kv 前端會分別傳入轉接呼叫。注意：q 前端數目需要除以 kv 前端數目。

**使用 FlashGroupedQueryAttention 的範例**

下列程式碼片段示範如何使用 SMP v2 提供的 [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API。

```
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

SMP 程式庫也提供 [`torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn)，它使用低階 [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API。Hugging Face Transformer 從 v4.36.0 開始有一個類似的實作，稱為 [https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)。下列程式碼片段示範如何使用 SMP v2 `LlamaFlashAttention` API 或轉換器 `LlamaFlashAttention2` API 取代現有 Llama 模型的注意力層。

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
```