

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

# FlashAttention
<a name="model-parallel-core-features-v2-flashattention"></a>

O SMP v2 suporta [FlashAttention](https://github.com/HazyResearch/flash-attention)kernels e facilita sua aplicação em vários cenários para modelos Hugging Face Transformer. Observe que, se você usa o FlashAttention pacote v2.0 ou posterior, o SMP usa a FlashAttention v2; no entanto, o padrão da atenção flash do Triton é o kernel de atenção flash na FlashAttention v1.x, tornando-o suportado exclusivamente na v1. FlashAttention 

O módulo (`nn.Module`) é uma API de baixo nível que define as camadas de atenção de um modelo. Ele deve ser aplicado logo após a criação do modelo, por exemplo, a partir da API `AutoModelForCausalLM.from_config()`, e antes de o modelo ser transformado ou envolvido ao FSDP.

## Use FlashAttention grãos para autoatenção
<a name="model-parallel-core-features-v2-flashattention-self"></a>

O trecho de código a seguir apresenta como usar a API [`torch.sagemaker.nn.attn. FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention) fornecida pelo SMP v2.

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## Use FlashAttention kernels para atenção de consultas agrupadas
<a name="model-parallel-core-features-v2-flashattention-grouped-query"></a>

O SMP v2 também suporta [FlashAttention](https://github.com/HazyResearch/flash-attention)kernels para atenção de consultas agrupadas (GQA) e facilita sua aplicação em vários cenários para modelos Hugging Face Transformer. Diferentemente da arquitetura de atenção original, a GQA divide de forma igualitária os cabeçalhos de consulta em grupos, e os cabeçalhos de consulta no mesmo grupo compartilham os mesmos cabeçalhos de chave e valor. Portanto, os cabeçalhos q e kv são passados para a chamada direta separadamente. Nota: o número de cabeçalhos q precisa ser divisível pelo número de cabeçalhos kv.

**Exemplo de uso FlashGroupedQueryAttention**

O trecho de código a seguir apresenta como usar a API [`torch.sagemaker.nn.attn. FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) fornecida pelo SMP v2.

```
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

A biblioteca de SMP também fornece [`torch.sagemaker.nn.huggingface.llama_flashattn. LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn), que usa a API [`torch.sagemaker.nn.attn. FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) em baixo nível. O Hugging Face Transformers tem uma implementação semelhante chamada [https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) a partir da v4.36.0. O trecho de código a seguir mostra como usar a API SMP v2 ou a API `LlamaFlashAttention`Transformers `LlamaFlashAttention2` para substituir as camadas de atenção de um modelo Llama existente.

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
```