

本文属于机器翻译版本。若本译文内容与英语原文存在差异，则一律以英文原文为准。

# 激活检查点
<a name="model-parallel-extended-features-pytorch-activation-checkpointing"></a>

*激活检查点*（或*梯度检查点*）技术通过清除某些层的激活并在向后传递期间重新计算它们，来减少内存使用量。实际上，这是用额外的计算时间来换取内存使用量的减少。如果对模块执行了检查点操作，则在向前传递结束时，该模块的输入和输出将保留在内存中。在向前传递期间，任何本应是模块内部计算一部分的中间张量都会被释放。在有检查点的模块的向后传递过程中，会重新计算这些张量。此时，有检查点的模块之外的层已经完成其向后传递，因此检查点操作的峰值内存使用量可能会更低。

**注意**  
此功能可在 SageMaker 模型并行度库 v1.6.0 及更高版本 PyTorch 中使用。

## 如何使用激活检查点
<a name="model-parallel-extended-for-pytorch-activation-checkpointing-how-to-use"></a>

使用 `smdistributed.modelparallel`，您可以按模块使用激活检查点。对于除 `torch.nn.Sequential` 之外的所有 `torch.nn` 模块，只有当从管道并行性的角度来看，模块树位于一个分区内时，才能对模块树执行检查点操作。对于 `torch.nn.Sequential` 模块，顺序模块内的每个模块树必须完全位于一个分区内，激活检查点才能起作用。使用手动分区时，请注意这些限制。

使用[自动模型分区](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-core-features.html#model-parallel-automated-model-splitting)时，您可在训练作业日志中找到以 `Partition assignments:` 开头的分区分配日志。如果一个模块在多个秩（例如，一个后代属于一个秩，另一个后代处于不同的秩）上分区，则库会忽略对模块执行检查点的尝试，并发出一条警告消息，说明该模块没有检查点。

**注意**  
 SageMaker 模型并行度库支持重叠和非重叠操作以及检查点`allreduce`操作。

**注意**  
PyTorch的本机检查点 API 与不兼容。`smdistributed.modelparallel`

**示例 1：**以下示例代码演示了当脚本中有模型定义时，如何使用激活检查点操作。

```
import torch.nn as nn
import torch.nn.functional as F

from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        # This call of fc1 will be checkpointed
        x = checkpoint(self.fc1, x)
        x = self.fc2(x)
        return F.log_softmax(x, 1)
```

**示例 2：**以下示例代码演示了当脚本中有顺序模型时，如何使用激活检查点操作。

```
import torch.nn as nn
from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1,20,5),
            nn.ReLU(),
            nn.Conv2d(20,64,5),
            nn.ReLU()
        )

    def forward(self, x):
        # This call of self.seq will be checkpointed
        x = checkpoint_sequential(self.seq, x)
        return F.log_softmax(x, 1)
```

**示例 3：**以下示例代码显示了在从库（例如和 Hugging Face Transformers PyTorch ）导入预建模型时如何使用激活检查点。无论您是否对顺序模型执行检查点操作，请完成以下过程：

1. 使用 `smp.DistributedModel()` 包装模型。

1. 为顺序层定义一个对象。

1. 使用 `smp.set_activation_checkpointig()` 包装顺序层对象。

```
import smdistributed.modelparallel.torch as smp
from transformers import AutoModelForCausalLM

smp.init()
model = AutoModelForCausalLM(*args, **kwargs)
model = smp.DistributedModel(model)

# Call set_activation_checkpointing API
transformer_layers = model.module.module.module.transformer.seq_layers
smp.set_activation_checkpointing(
    transformer_layers, pack_args_as_tuple=True, strategy='each')
```