

本文為英文版的機器翻譯版本，如內容有任何歧義或不一致之處，概以英文版為準。

# 啟用檢查點
<a name="model-parallel-extended-features-pytorch-activation-checkpointing"></a>

*啟用檢查點* (或*漸層檢查點*) 是減少記憶體使用量的技術，方法是清除某些圖層的啟用，並在向後傳遞期間重新加以運算。實際上，這是以額外運算時間換取減少記憶體使用量。如果模組進行檢查點作業，則在向前傳遞結束時，模組的輸入與輸出都將保留在記憶體。在向前傳遞期間，原本會成為該模組內部運算一部分的任何中級張量都將被釋放。在檢查點模組的向後傳遞期間，會重新運算這些張量。此時，超出此檢查點模組的圖層已完成其向後傳遞，因此可降低運用檢查點的最高記憶體使用量。

**注意**  
此功能適用 SageMaker 模型平行化程式庫 v1.6.0 及更高版本的 PyTorch。

## 如何使用啟用檢查點
<a name="model-parallel-extended-for-pytorch-activation-checkpointing-how-to-use"></a>

當使用 `smdistributed.modelparallel` 時，您可以在模組的精細程度使用啟用檢查點。對於除 `torch.nn.Sequential` 外的所有 `torch.nn` 模組，僅當從管道平行處理的角度而言，模組樹狀目錄位於單一分割內時，您才能對其進行檢查點作業。對於 `torch.nn.Sequential` 模組，循序模組內部的每個模組樹狀目錄必須完全位於單一分割內，以便啟用檢查點作業。當您使用手動分割時，請注意這些限制。

當您使用[自動化模型分割](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-core-features.html#model-parallel-automated-model-splitting)時，您可以在訓練任務日誌找到開頭為 `Partition assignments:` 的分割指派日誌。如果跨多個等級分割模組 (例如，其中一個子代位於某一等級，另一子代位於不同等級)，程式庫會忽略而不嘗試對模組進行檢查點作業，並提出警告訊息，指出不會檢查該模組。

**注意**  
SageMaker 模型平行化程式庫支援重疊與非重疊 `allreduce` 作業，並結合檢查點。

**注意**  
PyTorch 的原生檢查點 API 不相容 `smdistributed.modelparallel`。

**範例 1：**下列範例程式碼示範當指令碼具模型定義時，如何使用啟用檢查點。

```
import torch.nn as nn
import torch.nn.functional as F

from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        # This call of fc1 will be checkpointed
        x = checkpoint(self.fc1, x)
        x = self.fc2(x)
        return F.log_softmax(x, 1)
```

**範例 2：**下列範例程式碼示範當指令碼具循序模型時，如何使用啟用檢查點。

```
import torch.nn as nn
from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1,20,5),
            nn.ReLU(),
            nn.Conv2d(20,64,5),
            nn.ReLU()
        )

    def forward(self, x):
        # This call of self.seq will be checkpointed
        x = checkpoint_sequential(self.seq, x)
        return F.log_softmax(x, 1)
```

**範例 3：**下列範例程式碼示範當從程式庫 (例如 PyTorch 與 Hugging Face 轉換器) 匯入預先建置的模型時，如何使用啟用檢查點。無論您是否針對循序模組進行檢查點作業，請執行以下操作：

1. 以 `smp.DistributedModel()` 包裝模型。

1. 定義循序圖層物件。

1. 以 `smp.set_activation_checkpointig()` 包裝循序圖層物件。

```
import smdistributed.modelparallel.torch as smp
from transformers import AutoModelForCausalLM

smp.init()
model = AutoModelForCausalLM(*args, **kwargs)
model = smp.DistributedModel(model)

# Call set_activation_checkpointing API
transformer_layers = model.module.module.module.transformer.seq_layers
smp.set_activation_checkpointing(
    transformer_layers, pack_args_as_tuple=True, strategy='each')
```