

本文属于机器翻译版本。若本译文内容与英语原文存在差异，则一律以英文原文为准。

# 对脚本应用 SageMaker 智能筛选 PyTorch
<a name="train-smart-sifting-apply-to-pytorch-script"></a>

这些说明演示了如何使用训练脚本启用 SageMaker 智能筛选。

1. 配置 SageMaker 智能筛选界面。

    SageMaker 智能筛选库实现了一种基于相对阈值损耗的采样技术，该技术有助于筛选出对降低损耗值影响较小的样本。 SageMaker 智能筛选算法使用正向传递计算每个输入数据样本的损失值，并根据先前数据的损失值计算其相对百分位数。

   以下两个参数是创建筛选配置对象时需要为 `RelativeProbabilisticSiftConfig` 类指定的参数。
   + 指定用于 `beta_value` 参数训练的数据比例。
   + 使用 `loss_history_length` 参数指定用于比较的样本数。

   以下代码示例演示了如何设置 `RelativeProbabilisticSiftConfig` 类的对象。

   ```
   from smart_sifting.sift_config.sift_configs import (
       RelativeProbabilisticSiftConfig
       LossConfig
       SiftingBaseConfig
   )
   
   sift_config=RelativeProbabilisticSiftConfig(
       beta_value=0.5,
       loss_history_length=500,
       loss_based_sift_config=LossConfig(
            sift_config=SiftingBaseConfig(sift_delay=0)
       )
   )
   ```

   有关`loss_based_sift_config`参数和相关类的更多信息，请参阅[SageMaker 智能筛选配置模块](train-smart-sifting-pysdk-reference.md#train-smart-sifting-pysdk-base-config-modules) SageMaker 智能筛选 Python SDK 参考部分中的。

   前面代码示例中的 `sift_config` 对象在第 4 步中用于设置 `SiftingDataloader` 类。

1. （可选）配置 SageMaker 智能筛选批量转换类。

   不同的训练使用场景需要不同的训练数据格式。鉴于数据格式多种多样， SageMaker 智能筛选算法需要确定如何对特定批次进行筛选。为了解决这个问题， SageMaker 智能筛选提供了一个批量转换模块，可以帮助将批次转换为可以高效筛选的标准化格式。

   1. SageMaker 智能筛选处理以下格式的训练数据的批量转换：Python 列表、字典、元组和张量。对于这些数据格式， SageMaker 智能筛选会自动处理批量数据格式转换，您可以跳过此步骤的其余部分。如果您跳过此步骤，在配置 `SiftingDataloader` 的第 4 步中，请将 `SiftingDataloader` 的 `batch_transforms` 参数保留为默认值 `None`。

   1. 如果您的数据集不是这些格式，则您应继续本步骤的其余部分，使用 `SiftingBatchTransform` 创建自定义批量转换。

      如果您的数据集不是 SageMaker 智能筛选支持的格式之一，则可能会遇到错误。此类数据格式错误可以通过在 `SiftingDataloader` 类中添加 `batch_format_index` 或 `batch_transforms` 参数来解决，您可以在第 4 步中进行设置。下面显示了由于数据格式不兼容而导致的错误示例以及解决方法。    
[\[See the AWS documentation website for more details\]](http://docs.aws.amazon.com/zh_cn/sagemaker/latest/dg/train-smart-sifting-apply-to-pytorch-script.html)

      要解决上述问题，您需要使用 `SiftingBatchTransform` 模块创建自定义批处理转换类。批次转换类应由一对转换和反向转换函数组成。函数对将您的数据格式转换为 SageMaker 智能筛选算法可以处理的格式。创建批次转换类后，此类会返回一个 `SiftingBatch` 对象，您将在第 4 步中把此对象传递给 `SiftingDataloader` 类。

      以下是 `SiftingBatchTransform` 模块中自定义批次转换类的示例。
      + 使用 SageMaker 智能筛选实现自定义列表批量转换的示例，适用于数据加载器块包含输入、掩码和标签的情况。

        ```
        from typing import Any
        
        import torch
        
        from smart_sifting.data_model.data_model_interface import SiftingBatchTransform
        from smart_sifting.data_model.list_batch import ListBatch
        
        class ListBatchTransform(SiftingBatchTransform):
            def transform(self, batch: Any):
                inputs = batch[0].tolist()
                labels = batch[-1].tolist()  # assume the last one is the list of labels
                return ListBatch(inputs, labels)
        
            def reverse_transform(self, list_batch: ListBatch):
                a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)]
                return a_batch
        ```
      + 使用 SageMaker 智能筛选实现自定义列表批量转换的示例，适用于不需要标签进行反向转换的情况。

        ```
        class ListBatchTransformNoLabels(SiftingBatchTransform):
            def transform(self, batch: Any):
                return ListBatch(batch[0].tolist())
        
            def reverse_transform(self, list_batch: ListBatch):
                a_batch = [torch.tensor(list_batch.inputs)]
                return a_batch
        ```
      + 在数据加载器块有输入、掩码和标签的情况下，使用 SageMaker 智能筛选的自定义张量批处理实现示例。

        ```
        from typing import Any
        
        from smart_sifting.data_model.data_model_interface import SiftingBatchTransform
        from smart_sifting.data_model.tensor_batch import TensorBatch
        
        class TensorBatchTransform(SiftingBatchTransform):
            def transform(self, batch: Any):
                a_tensor_batch = TensorBatch(
                    batch[0], batch[-1]
                )  # assume the last one is the list of labels
                return a_tensor_batch
        
            def reverse_transform(self, tensor_batch: TensorBatch):
                a_batch = [tensor_batch.inputs, tensor_batch.labels]
                return a_batch
        ```

      在您创建已执行 `SiftingBatchTransform` 批次转换类后，可在第 4 步中使用 `SiftingDataloader` 类进行设置。本指南的其余部分假设已创建了一个 `ListBatchTransform` 类。在第 4 步中，此类将传递给 `batch_transforms`。

1. 创建用于实现 SageMaker 智能筛选`Loss`接口的类。本教程假定此类名为 `SiftingImplementedLoss`。在设置此类时，我们建议您在模型训练循环中使用相同的损失函数。按照以下子步骤创建 SageMaker 智能筛选`Loss`实现的类。

   1. SageMaker 智能筛选计算每个训练数据样本的损失值，而不是计算批次的单个损失值。为确保 SageMaker 智能筛选使用相同的损失计算逻辑，请使用 SageMaker 智能筛选`Loss`模块创建 smart-sifting-implemented损失函数，该模块使用您的损失函数并计算每个训练样本的损失。
**提示**  
SageMaker 智能筛选算法在每个数据样本上运行，而不是在整个批次上运行，因此您应该添加一个初始化函数来设置 PyTorch 损失函数，而无需任何还原策略。  

      ```
      class SiftingImplementedLoss(Loss):  
          def __init__(self):
              self.loss = torch.nn.CrossEntropyLoss(reduction='none')
      ```
以下代码示例也说明了这一点。

   1. 定义一个接受`original_batch`（或者`transformed_batch`如果您在步骤 2 中设置了批量变换）和 PyTorch模型的损失函数。 SageMaker 智能筛选使用不减值的指定损失函数，对每个数据样本进行正向传递，以评估其损失值。

   以下代码是一个名为的 smart-sifting-implemented`Loss`接口的示例`SiftingImplementedLoss`。

   ```
   from typing import Any
   
   import torch
   import torch.nn as nn
   from torch import Tensor
   
   from smart_sifting.data_model.data_model_interface import SiftingBatch
   from smart_sifting.loss.abstract_sift_loss_module import Loss
   
   model=... # a PyTorch model based on torch.nn.Module
   
   class SiftingImplementedLoss(Loss):   
       # You should add the following initializaztion function 
       # to calculate loss per sample, not per batch.
       def __init__(self):
           self.loss_no_reduction = torch.nn.CrossEntropyLoss(reduction='none')
   
       def loss(
           self,
           model: torch.nn.Module,
           transformed_batch: SiftingBatch,
           original_batch: Any = None,
       ) -> torch.Tensor:
           device = next(model.parameters()).device
           batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2
           # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2
   
           # compute loss
           outputs = model(batch)
           return self.loss_no_reduction(outputs.logits, batch[2])
   ```

   在训练循环进入实际前向传递之前，每次迭代获取批次数据的数据加载阶段都会进行筛选损失计算。然后将单个损失值与之前的损失值进行比较，并根据步骤 1 中设置的 `RelativeProbabilisticSiftConfig` 对象估算出其相对百分位数。

1. 按 SageMaker AI `SiftingDataloader` 类封装 PyTroch 数据加载器。

   最后，将您在前面步骤中配置的所有 SageMaker 智能筛选实现的类用于 SageMaker AI `SiftingDataloder` 配置类。这个类是的封装器。 PyTorch [https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)通过封装 PyTorch`DataLoader`， SageMaker 智能筛选被注册为在 PyTorch 训练作业的每次迭代中作为数据加载的一部分运行。以下代码示例演示如何实现 SageMaker AI 数据筛选到. PyTorch `DataLoader` 

   ```
   from smart_sifting.dataloader.sift_dataloader import SiftingDataloader
   from torch.utils.data import DataLoader
   
   train_dataloader = DataLoader(...) # PyTorch data loader
   
   # Wrap the PyTorch data loader by SiftingDataloder
   train_dataloader = SiftingDataloader(
       sift_config=sift_config, # config object of RelativeProbabilisticSiftConfig
       orig_dataloader=train_dataloader,
       batch_transforms=ListBatchTransform(), # Optional, this is the custom class from step 2
       loss_impl=SiftingImplementedLoss(), # PyTorch loss function wrapped by the Sifting Loss interface
       model=model,
       log_batch_data=False
   )
   ```