

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

# Support pour FlashAttention
<a name="model-parallel-attention-head-size-for-flash-attention"></a>

Support de FlashAttention est une fonctionnalité de la bibliothèque applicable uniquement au modèle de *transformateur distribué*, qui est un modèle de transformateur intégré [https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/latest/smd_model_parallel_pytorch.html#smdistributed-modelparallel-torch-distributedmodel](https://sagemaker.readthedocs.io/en/v2.199.0/api/training/smp_versions/latest/smd_model_parallel_pytorch.html#smdistributed-modelparallel-torch-distributedmodel)pour l'apprentissage parallèle entre modèles. Cette fonctionnalité est également compatible avec [Parallélisme de tenseur](model-parallel-extended-features-pytorch-tensor-parallelism.md). 

La [FlashAttention](https://github.com/HazyResearch/flash-attention)bibliothèque ne prend en charge les modèles que lorsqu'elle `attention_head_size` est définie sur une valeur multiple de 8 et inférieure à 128. Par conséquent, lorsque vous entraînez un transformateur distribué et que vous vous assurez qu'il FlashAttention fonctionne correctement, vous devez ajuster les paramètres pour que la taille de la tête d'attention soit conforme aux exigences. Pour plus d'informations, voir également [Installation et fonctionnalités](https://github.com/HazyResearch/flash-attention#installation-and-features) du *FlashAttention GitHubréférentiel*.

Supposons, par exemple, que vous configurez un modèle Transformer avec `hidden_width=864` et `num_heads=48`. La taille de la tête de FlashAttention est calculée comme suit`attention_head_size = hidden_width / num_heads = 864 / 48 = 18`. Pour l'activer FlashAttention, vous devez ajuster le `num_heads` paramètre à`54`, de sorte que`attention_head_size = hidden_width / num_heads = 864 / 54 = 16`, soit un multiple de 8.