

本文為英文版的機器翻譯版本，如內容有任何歧義或不一致之處，概以英文版為準。

# 啟用檢查點
<a name="model-parallel-core-features-v2-pytorch-activation-checkpointing"></a>

*啟用檢查點*是減少記憶體使用量的技術，方法是清除某些圖層的啟用，並在向後傳遞期間重新加以運算。實際上，這是以額外運算時間換取減少記憶體使用量。如果對模組進行了檢查點作業，則在向前傳遞結束時，只有模組的初始輸入和模組的最終輸出會保留在記憶體中。PyTorch 在向前傳遞期間，會釋放屬於該模組內部運算一部分的任何中級張量。在檢查點模組的向後傳遞期間，PyTorch 會重新運算這些張量。此時，超出此檢查點模組的圖層已完成其向後傳遞，因此可降低運用檢查點的最高記憶體使用量。

SMP v2 支援 PyTorch 啟用檢查點模組 [https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/#activation-checkpointing](https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/#activation-checkpointing)。以下是 Hugging Face GPT-NeoX 模型啟用檢查點的範例。

**Hugging Face GPT-NeoX 模型的檢查點轉換器層**

```
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing
)
    
# check_fn receives a module as the arg, 
# and it needs to return whether the module is to be checkpointed
def is_transformer_layer(module):
    from transformers.models.gpt_neox import GPTNeoXLayer
    return isinstance(submodule, GPTNeoXLayer)
    
apply_activation_checkpointing(model, check_fn=is_transformer_layer)
```

**對 Hugging Face GPT-NeoX 模型的其他每個轉換器層進行檢查點**

```
# check_fn receives a module as arg, 
# and it needs to return whether the module is to be checkpointed
# here we define that function based on global variable (transformer_layers)
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing
)

transformer_layers = [
    m for m model.modules() if isinstance(m, GPTNeoXLayer)
]

def is_odd_transformer_layer(module):
    return transformer_layers.index(module) % 2 == 0
    
apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)
```

或者，PyTorch 也有用於檢查點的 `torch.utils.checkpoint` 模組，由 Hugging Face Transformer 模型的子集使用。此模組也適用於 SMP v2。不過，這需要您存取模型定義以新增檢查點包裝函式。因此，建議您使用 `apply_activation_checkpointing` 方法。