

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

# Pelatihan presisi campuran
<a name="model-parallel-core-features-v2-mixed-precision"></a>

Pustaka SageMaker model paralelisme (SMP) v2 mendukung pelatihan presisi campuran di luar kotak dengan mengintegrasikan dengan kerangka kerja sumber terbuka seperti PyTorch FSDP dan Transformer Engine. Untuk mempelajari lebih lanjut, lihat topik berikut.

**Topics**
+ [Pelatihan presisi campuran dengan FP8 instans P5 menggunakan Transformer Engine](#model-parallel-core-features-v2-mixed-precision-fp8-training-on-p5)
+ [Pelatihan presisi campuran dengan tipe data setengah presisi menggunakan PyTorch FSDP](#model-parallel-core-features-v2-mixed-precision-half-precision)

## Pelatihan presisi campuran dengan FP8 instans P5 menggunakan Transformer Engine
<a name="model-parallel-core-features-v2-mixed-precision-fp8-training-on-p5"></a>

[Mulai dari perpustakaan SageMaker model paralelisme (SMP) v2.2.0, perpustakaan SMP terintegrasi dengan [Transformer Engine dan mendukung [pelatihan presisi FP8 campuran](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) di luar kotak, menjaga kompatibilitas dengan FSDP](https://docs.nvidia.com/deeplearning/transformer-engine/index.html). PyTorch `MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision) Ini berarti Anda dapat menggunakan PyTorch FSDP untuk pelatihan presisi campuran dan Mesin Transformer untuk FP8 pelatihan. Untuk lapisan model yang tidak didukung oleh fitur FP8 pelatihan Transformer Engine, lapisan tersebut kembali ke presisi campuran PyTorch FSDP.

**catatan**  
SMP v2 menawarkan FP8 dukungan untuk model Hugging Face Transformer berikut:  
GPT-Neox (tersedia di SMP v2.2.0 dan yang lebih baru)
Llama 2 (tersedia di SMP v2.2.0 dan yang lebih baru)
Mixtral 8x7b dan Mixtral 8x22b (tersedia dalam SMP v2.5.0 dan yang lebih baru)

**catatan**  
 FP8 Pelatihan tentang fitur P5 ini tersedia dalam kombinasi perpustakaan SageMaker dan perpustakaan berikut: PyTorch   
 SageMaker Python SDK v2.212.0 dan yang lebih baru
PyTorch v2.2.0 dan yang lebih baru

*FP8*(Presisi floating point 8-bit) adalah tipe data yang telah muncul sebagai paradigma lain untuk mempercepat pelatihan pembelajaran mendalam model LLM. Dengan dirilisnya tipe FP8 data GPUs pendukung NVIDIA H100, Anda bisa mendapatkan keuntungan dari keuntungan dari peningkatan kinerja pada instans P5 yang dilengkapi dengan H100 GPUs, sekaligus mempercepat pelatihan terdistribusi dengan pelatihan presisi campuran. FP8 

Tipe FP8 data selanjutnya bercabang ke format E4M3 dan E5M2. *E4M3* menawarkan presisi yang lebih baik, memiliki rentang dinamis terbatas, dan sangat ideal untuk forward pass dalam pelatihan model. *E5M2* memiliki rentang dinamis yang lebih luas, tetapi presisi berkurang, dan lebih cocok untuk lintasan mundur, di mana presisi kurang kritis dan rentang dinamis yang lebih luas menjadi bermanfaat. Oleh karena itu, kami menyarankan Anda menggunakan [resep FP8 strategi hibrida](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-recipe) untuk memanfaatkan karakteristik ini secara efektif.

Untuk tipe data setengah presisi (FP16 dan BF16), teknik penskalaan kerugian global seperti penskalaan kerugian statis atau penskalaan kerugian dinamis menangani masalah konvergensi yang timbul dari kehilangan informasi karena gradien pembulatan dalam setengah presisi. Namun, rentang dinamis bahkan FP8 lebih sempit, dan teknik penskalaan kerugian global tidak cukup. Pada titik ini, kita membutuhkan teknik penskalaan per-tensor berbutir halus. *Penskalaan tertunda* adalah strategi yang memilih faktor penskalaan berdasarkan nilai absolut maksimum yang diamati dalam sejumlah tensor dari iterasi sebelumnya. Ada trade-off dalam strategi ini; ia menggunakan manfaat kinerja penuh dari FP8 komputasi tetapi membutuhkan memori untuk menjaga riwayat nilai maksimum tensor. Untuk mempelajari lebih lanjut tentang strategi penskalaan tertunda secara umum, lihat paper [https://arxiv.org/pdf/2209.05433.pdf](https://arxiv.org/pdf/2209.05433.pdf).

Dalam praktiknya, penggunaan FP8 sangat membantu dalam semua skenario pelatihan pada instance P5. Kami sangat menyarankan untuk mengaktifkan FP8 bila memungkinkan untuk meningkatkan kinerja pelatihan.

SMP v2 mendukung Transformer Engine di luar kotak. Oleh karena itu, saat menjalankan FP8 pelatihan dengan SMP v2 pada instance P5 SageMaker AI (`ml.p5.48xlarge`), satu-satunya hal yang perlu Anda lakukan adalah mengimpor `torch.sagemaker` skrip pelatihan Anda dan tetap menggunakan paket Transformer Engine Python asli. Untuk mempelajari lebih lanjut tentang menggunakan Mesin Transformer untuk FP8 pelatihan secara umum, lihat [Menggunakan FP8 dengan Mesin Transformer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) dalam *dokumentasi NVIDIA Transformer Engine*. Cuplikan kode berikut menunjukkan bagaimana baris kode untuk mengimpor pustaka SMP dan pengaturan FP8 dalam skrip pelatihan Anda akan terlihat.

```
import torch.sagemaker as tsm
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Initialize the SMP torch.sagemaker API.
tsm.init()

# Define a transformer model and wrap it with the torch.sagemaker.transform API.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_config(ModelConfig)
model = tsm.transform(model)

# Enable E4M3 during forward pass, E5M2 during backward pass.
fp8_format = Format.HYBRID

# Create an FP8 recipe.
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

# Enable FP8 autocasting.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tsm.state.world_process_group):
    out = model(inp)

loss = out.sum()
loss.backward()
```

Untuk menemukan contoh praktis FP8 pelatihan dengan SMP v2 pada instance P5, lihat contoh notebook di [Accelerate SageMaker PyTorch FSDP Training of LLAMA-v2 (](https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/model_parallel_v2/llama_v2/smp-train-llama-fsdp-tp-fp8.ipynb)atau GPT-Neox) dengan instance P5. FP8 

## Pelatihan presisi campuran dengan tipe data setengah presisi menggunakan PyTorch FSDP
<a name="model-parallel-core-features-v2-mixed-precision-half-precision"></a>

SMP v2 mendukung [PyTorch FSDP `MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision) untuk pekerjaan pelatihan pada instans P4 dan P5. PyTorch FSDP menyediakan berbagai konfigurasi untuk presisi campuran untuk peningkatan kinerja dan pengurangan memori. 

**catatan**  
Pelatihan presisi campuran dengan fitur PyTorch FSDP ini tersedia dalam kombinasi perpustakaan SageMaker dan perpustakaan berikut. PyTorch   
SMP v2.0.0 dan yang lebih baru
 SageMaker Python SDK v2.200.0 dan yang lebih baru
PyTorch v2.0.1 dan yang lebih baru

Cara standar untuk mengonfigurasi model untuk presisi campuran adalah dengan membuat model`float32`, dan kemudian mengizinkan FSDP untuk mentransmisikan parameter ke `float16` atau dengan cepat dengan meneruskan `MixedPrecision` kebijakan, seperti yang ditunjukkan `bfloat16` pada cuplikan kode berikut. *Untuk informasi selengkapnya tentang opsi untuk mengubah parameter, reduksi, atau buffer untuk presisi campuran PyTorch, lihat [PyTorch FSDP `MixedPrecision` API dalam dokumentasi](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision). `dtype` PyTorch*

```
# Native PyTorch API
from torch.distributed.fsdp import MixedPrecision

dtype = torch.bfloat16
mixed_precision_policy = MixedPrecision(
    param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype
)

model = FSDP(
    model,
    ...,
    mixed_precision=mixed_precision_policy
)
```

Perhatikan bahwa model tertentu (seperti model Hugging Face Transformers Llama) mengharapkan buffer sebagai. `float32` Untuk menggunakan`float32`, ganti `torch.bfloat16` dengan `torch.float32` di baris yang mendefinisikan `dtype` objek.