

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

# Checkpointing menggunakan SMP
<a name="model-parallel-core-features-v2-checkpoints"></a>

Pustaka SageMaker model paralelisme (SMP) mendukung PyTorch API untuk pos pemeriksaan, dan menyediakan API yang membantu pos pemeriksaan dengan benar saat menggunakan pustaka SMP. 

PyTorch FSDP (Fully Sharded Data Parallelism) mendukung tiga jenis pos pemeriksaan: penuh, sharded, dan lokal, masing-masing melayani tujuan yang berbeda. Pos pemeriksaan penuh digunakan saat mengekspor model setelah pelatihan selesai, karena menghasilkan pos pemeriksaan penuh adalah proses yang mahal secara komputasi. Pos pemeriksaan sharded membantu menyimpan dan memuat status model yang dipecah untuk setiap peringkat individu. Dengan pos pemeriksaan sharded, Anda dapat melanjutkan pelatihan dengan konfigurasi perangkat keras yang berbeda, seperti jumlah GPU yang berbeda. Namun, memuat pos pemeriksaan sharded bisa lambat karena komunikasi yang terlibat di antara beberapa perangkat. Pustaka SMP menyediakan fungsionalitas pos pemeriksaan lokal, yang memungkinkan pengambilan status model lebih cepat tanpa overhead komunikasi tambahan. Perhatikan bahwa pos pemeriksaan yang dibuat oleh FSDP memerlukan penulisan ke sistem file jaringan bersama seperti Amazon FSx.

## Pos pemeriksaan lokal async
<a name="w2aac25c25c19c19c33b7"></a>

Saat melatih model pembelajaran mesin, tidak perlu iterasi berikutnya untuk menunggu file pos pemeriksaan disimpan ke disk. Dengan dirilisnya SMP v2.5, perpustakaan mendukung penyimpanan file pos pemeriksaan secara asinkron. Ini berarti bahwa iterasi pelatihan berikutnya dapat berjalan bersamaan dengan operasi input dan output (I/O) untuk membuat pos pemeriksaan, tanpa diperlambat atau ditahan oleh operasi tersebut. I/O Selain itu, proses pengambilan model sharded dan paramemeter pengoptimal PyTorch dapat memakan waktu karena komunikasi kolektif tambahan yang diperlukan untuk menukar metadata tensor terdistribusi di seluruh peringkat. Bahkan ketika menggunakan `StateDictType.LOCAL_STATE_DICT` untuk menyimpan pos pemeriksaan lokal untuk setiap peringkat, PyTorch masih memanggil kait yang melakukan komunikasi kolektif. Untuk mengurangi masalah ini dan mengurangi waktu yang diperlukan untuk pengambilan pos pemeriksaan, SMP memperkenalkan`SMStateDictType.SM_LOCAL_STATE_DICT`, yang memungkinkan pengambilan lebih cepat dari pos pemeriksaan model dan pengoptimal dengan melewati overhead komunikasi kolektif. 

**catatan**  
Menjaga konsistensi dalam FSDP `SHARD_DEGREE` adalah persyaratan untuk memanfaatkan. `SMStateDictType.SM_LOCAL_STATE_DICT` Pastikan bahwa `SHARD_DEGREE` sisa-sisa tidak berubah. Sementara jumlah replikasi model dapat bervariasi, tingkat pecahan model harus identik dengan pengaturan pelatihan sebelumnya saat melanjutkan dari pos pemeriksaan.

```
import os
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
)

global_rank = dist.get_rank()
save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Wait for the previous checkpointing done
maybe_finalize_async_calls(
    blocking=True, process_group=current_replication_group
)

# 3. Get model local checkpoint
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
       "model": model.state_dict(),
       "optimizer": optimizer.state_dict(),
        # Potentially add more customized state dicts.
    }

# 4. Save a local checkpoint 
async_save(
    state_dict,
    checkpoint_id=os.path.join(save_dir, sub_dir),
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
)
```

Cuplikan kode berikut menunjukkan cara memuat pos pemeriksaan menggunakan. `SMStateDictType.SM_LOCAL_STATE_DICT`

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
    init_optim_state
)
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"
global_rank = dist.get_rank()
checkpoint_id = os.path.join(load_dir, sub_dir)
storage_reader = DistributedFileSystemReader(checkpoint_id)

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Create local state_dict
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # Potentially add more customized state dicts.
    }
 
    # Init optimizer state_dict states by setting zero grads and step.
    init_optim_state(optimizer, skip_empty_param=True)
    state_dict["optimizer"] = optimizer.state_dict()
 
# 3. Load a checkpoint
load(
    state_dict=state_dict,
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
    storage_reader=storage_reader,
)
```

Menyimpan pos pemeriksaan untuk model bahasa besar (LLM) bisa mahal karena sering membutuhkan pembuatan volume sistem file yang besar. Untuk mengurangi biaya, Anda memiliki opsi untuk menyimpan pos pemeriksaan langsung ke Amazon S3 tanpa perlu layanan sistem file tambahan seperti Amazon FSx. Anda dapat memanfaatkan contoh sebelumnya dengan cuplikan kode berikut untuk menyimpan pos pemeriksaan ke S3 dengan menentukan URL S3 sebagai tujuan. 

```
key = os.path.join(checkpoint_dir, sub_dir)
checkpoint_id= f"{{s3://{your_s3_bucket}/{key}}}"
async_save(state_dict, checkpoint_id=checkpoint_id, **kw)
load(state_dict, checkpoint_id=checkpoint_id, **kw)
```

## Pos pemeriksaan sharded async
<a name="w2aac25c25c19c19c33b9"></a>

Mungkin ada situasi di mana Anda perlu melanjutkan pelatihan dengan konfigurasi perangkat keras yang berbeda, seperti mengubah jumlah GPU. Dalam kasus ini, proses pelatihan Anda harus memuat pos pemeriksaan saat resharding, yang berarti melanjutkan pelatihan berikutnya dengan jumlah yang berbeda. `SHARD_DEGREE` Untuk mengatasi skenario di mana Anda perlu melanjutkan pelatihan dengan jumlah yang berbeda`SHARD_DEGREE`, Anda harus menyimpan pos pemeriksaan model Anda menggunakan jenis kamus status sharded, yang diwakili oleh. `StateDictType.SHARDED_STATE_DICT` Menyimpan pos pemeriksaan dalam format ini memungkinkan Anda menangani proses resharding dengan benar saat melanjutkan pelatihan dengan konfigurasi perangkat keras yang dimodifikasi. Cuplikan kode yang disediakan menggambarkan cara menggunakan `tsm` API untuk menyimpan pos pemeriksaan sharded secara asinkron, memungkinkan proses pelatihan yang lebih efisien dan efisien.

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)

save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(save_dir, sub_dir)

# To determine whether curreto take part in checkpointing.
global_rank = dist.get_rank()
action_rank = state.ranker.get_rep_rank(global_rank) == 0
process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

# 1. wait for the previous checkpointing done
maybe_finalize_async_calls(blocking=True, process_group=process_group)

# 2. retrieve model & optimizer sharded state_dict
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": FSDP.optim_state_dict(model, optimizer),
        # Potentially add more customized state dicts.
    }
 
# 3. save checkpoints asynchronously using async_save
if action_rank:
    async_save(
        state_dict,
        checkpoint_id=checkpoint_id,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
    )
```

Proses memuat pos pemeriksaan bersama mirip dengan bagian sebelumnya, tetapi melibatkan penggunaan `torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader` dan `load` metodenya. `load`Metode kelas ini memungkinkan Anda untuk memuat data pos pemeriksaan bersama, mengikuti proses analog dengan yang dijelaskan sebelumnya.

```
import os
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)
 
 load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(load_dir, sub_dir)
reader = DistributedFileSystemReader(checkpoint_id)

process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
   # 1. Load model and everything else except the optimizer.
   state_dict = {
        "model": model.state_dict()
        # Potentially more customized state dicts.
   }
   load(
        state_dict,
        storage_reader=reader,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
   )
   model.load_state_dict(state_dict["model"])
 
   # 2. Load optimizer.
   optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optimizer",
        storage_reader=reader,
        process_group=process_group,
    )    
   flattened_optimizer_state = FSDP.optim_state_dict_to_load(
        optim_state["optimizer"], model, optimizer,
         group=model.process_group
   )
   optimizer.load_state_dict(flattened_optimizer_state)
```

## Pos pemeriksaan model lengkap
<a name="model-parallel-core-features-v2-checkpoints-full"></a>

Di akhir pelatihan, Anda dapat menyimpan pos pemeriksaan lengkap yang menggabungkan semua pecahan model ke dalam satu file pos pemeriksaan model. Pustaka SMP sepenuhnya mendukung API pos pemeriksaan model PyTorch lengkap, jadi Anda tidak perlu melakukan perubahan apa pun.

Perhatikan bahwa jika Anda menggunakan SMP[Paralelisme tensor](model-parallel-core-features-v2-tensor-parallelism.md), perpustakaan SMP mengubah model. Saat memeriksa model lengkap dalam kasus ini, pustaka SMP menerjemahkan model kembali ke format pos pemeriksaan Hugging Face Transformers secara default.

Jika Anda berlatih dengan paralelisme tensor SMP dan mematikan proses penerjemahan SMP, Anda dapat menggunakan `translate_on_save` argumen PyTorch `FullStateDictConfig` API untuk mengaktifkan atau menonaktifkan terjemahan otomatis SMP sesuai kebutuhan. Misalnya, jika Anda berfokus pada pelatihan model, Anda tidak perlu menambahkan proses terjemahan yang menambahkan overhead. Dalam hal ini, kami sarankan Anda untuk mengatur`translate_on_save=False`. Juga, jika Anda berencana untuk tetap menggunakan terjemahan SMP model untuk pelatihan lebih lanjut di masa depan, Anda dapat mematikannya untuk menyimpan terjemahan SMP model untuk digunakan nanti. Menerjemahkan model kembali ke format pos pemeriksaan model Hugging Face Transformers diperlukan saat Anda menyelesaikan pelatihan model Anda dan menggunakannya untuk inferensi.

```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig
import torch.sagemaker as tsm

# Save checkpoints.
with FSDP.state_dict_type(
    model, 
    StateDictType.FULL_STATE_DICT, 
    FullStateDictConfig(
        rank0_only=True, offload_to_cpu=True,
        # Default value is to translate back to Hugging Face Transformers format,
        # when saving full checkpoints for models trained with SMP tensor parallelism.
        # translate_on_save=True
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        logger.info("Processed state dict to save. Starting write to disk now.")
        os.makedirs({{save_dir}}, exist_ok=True)
        # This name is needed for HF from_pretrained API to work.
        torch.save(state_dict, os.path.join({{save_dir}}, "pytorch_model.bin"))
        hf_model_config.save_pretrained({{save_dir}})
    dist.barrier()
```

Perhatikan bahwa pilihannya `FullStateDictConfig(rank0_only=True, offload_to_cpu=True)` adalah mengumpulkan model pada CPU perangkat peringkat 0 untuk menghemat memori saat melatih model besar.

Untuk memuat kembali model untuk inferensi, Anda melakukannya seperti yang ditunjukkan pada contoh kode berikut. Perhatikan bahwa kelas `AutoModelForCausalLM` mungkin berubah ke kelas pembuat faktor lain di Hugging Face Transformers, `AutoModelForSeq2SeqLM` seperti, tergantung pada model Anda. Untuk informasi selengkapnya, lihat dokumentasi [Hugging Face Transformers](https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/auto#natural-language-processing).

```
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained({{save_dir}})
```