

本文為英文版的機器翻譯版本，如內容有任何歧義或不一致之處，概以英文版為準。

# 將 SageMaker Smart Sifting 套用至 PyTorch 指令碼
<a name="train-smart-sifting-apply-to-pytorch-script"></a>

這些指示說明如何以訓練指令碼啟動 SageMaker Smart Sifting。

1. 設定 SageMaker Smart Sifting 介面。

   SageMaker Smart Sifting 程式庫會實作相對閾值損失型取樣技術，有助於篩選掉對降低損失值較無影響的範例。SageMaker Smart Sifting 演算法會透過向前傳遞，來計算每個輸入資料範例的損失值，並根據先前資料的損失值計算其相對百分位數。

   下列兩個參數是您建立篩選組態物件時，需要指定的 `RelativeProbabilisticSiftConfig` 類別的參數。
   + 指定用於訓練 `beta_value` 參數的資料比例。
   + 指定相較於 `loss_history_length` 參數的使用範例數量。

   下列程式碼範例示範如何設定 `RelativeProbabilisticSiftConfig` 類別的物件。

   ```
   from smart_sifting.sift_config.sift_configs import (
       RelativeProbabilisticSiftConfig
       LossConfig
       SiftingBaseConfig
   )
   
   sift_config=RelativeProbabilisticSiftConfig(
       beta_value=0.5,
       loss_history_length=500,
       loss_based_sift_config=LossConfig(
            sift_config=SiftingBaseConfig(sift_delay=0)
       )
   )
   ```

   如需 `loss_based_sift_config` 參數和相關類別的詳細資訊，請參閱 SageMaker Smart Sifting Python SDK 參考一節中的 [SageMaker Smart Sifting 組態模組](train-smart-sifting-pysdk-reference.md#train-smart-sifting-pysdk-base-config-modules)。

   上述程式碼範例中的 `sift_config` 物件在步驟 4 中用於設定 `SiftingDataloader` 類別。

1. (選用) 設定 SageMaker Smart Sifting 批次轉換類別。

   不同的訓練使用案例需要不同的訓練資料格式。考慮到各種資料格式，SageMaker Smart Sifting 演算法需要知道如何在特定批次上進行篩選。為了解決此問題，SageMaker Smart Sifting 會提供批次轉換模組，協助將批次轉換為可有效篩選的標準化格式。

   1. SageMaker Smart Sifting 會以下列格式處理訓練資料的批次轉換：Python 清單、字典、元組和張量。針對這些資料格式，SageMaker Smart Sifting 會自動轉換批次資料格式，您可以略過此步驟的其餘部分。如果您略過此步驟，請在步驟 4 中設定 `SiftingDataloader`，將 `SiftingDataloader` 的 `batch_transforms` 參數保留為預設值，也就是 `None`。

   1. 如果您的資料集不是這些格式，您仍應繼續進行此步驟的其餘部分，以使用 `SiftingBatchTransform` 建立自訂批次轉換。

      如果您的資料集不是 SageMaker Smart Sifting 支援的格式，您可能會遇到錯誤。您可以將 `batch_format_index` 或 `batch_transforms` 參數新增至您在步驟 4 中設定的 `SiftingDataloader` 類別，以解決此類資料格式錯誤。以下顯示因資料格式和解析度不相容而造成的範例錯誤。    
[\[See the AWS documentation website for more details\]](http://docs.aws.amazon.com/zh_tw/sagemaker/latest/dg/train-smart-sifting-apply-to-pytorch-script.html)

      若要解決上述問題，您需要使用 `SiftingBatchTransform` 模組建立自訂批次轉換類別。批次轉換類別應包含一對轉換和反向轉換函式。函式對會將您的資料格式轉換為 SageMaker Smart Sifting 演算法可以處理的格式。建立批次轉換類別之後，類別會傳回您將在步驟 4 中傳遞給 `SiftingDataloader` 類別的 `SiftingBatch` 物件。

      以下是 `SiftingBatchTransform` 模組的自訂批次轉換類別範例。
      + 資料載入器區塊具有輸入、遮罩和標籤時，使用 SageMaker Smart Sifting 的自訂清單批次轉換實作範例。

        ```
        from typing import Any
        
        import torch
        
        from smart_sifting.data_model.data_model_interface import SiftingBatchTransform
        from smart_sifting.data_model.list_batch import ListBatch
        
        class ListBatchTransform(SiftingBatchTransform):
            def transform(self, batch: Any):
                inputs = batch[0].tolist()
                labels = batch[-1].tolist()  # assume the last one is the list of labels
                return ListBatch(inputs, labels)
        
            def reverse_transform(self, list_batch: ListBatch):
                a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)]
                return a_batch
        ```
      + 不需要標籤進行反向轉換時，使用 SageMaker Smart Sifting 的自訂清單批次轉換實作範例。

        ```
        class ListBatchTransformNoLabels(SiftingBatchTransform):
            def transform(self, batch: Any):
                return ListBatch(batch[0].tolist())
        
            def reverse_transform(self, list_batch: ListBatch):
                a_batch = [torch.tensor(list_batch.inputs)]
                return a_batch
        ```
      + 資料載入器區塊具有輸入、遮罩和標籤時，使用 SageMaker Smart Sifting 的自訂張量批次實作範例。

        ```
        from typing import Any
        
        from smart_sifting.data_model.data_model_interface import SiftingBatchTransform
        from smart_sifting.data_model.tensor_batch import TensorBatch
        
        class TensorBatchTransform(SiftingBatchTransform):
            def transform(self, batch: Any):
                a_tensor_batch = TensorBatch(
                    batch[0], batch[-1]
                )  # assume the last one is the list of labels
                return a_tensor_batch
        
            def reverse_transform(self, tensor_batch: TensorBatch):
                a_batch = [tensor_batch.inputs, tensor_batch.labels]
                return a_batch
        ```

      建立 `SiftingBatchTransform` 實作的批次轉換類別之後，您可以在步驟 4 中使用此類別來設定 `SiftingDataloader` 類別。本指南的其餘部分假設您已建立 `ListBatchTransform` 類別。在步驟 4 中，此類別會傳遞至 `batch_transforms`。

1. 建立用來實作 SageMaker Smart Sifting `Loss` 介面的類別。本教學假設該類別名為 `SiftingImplementedLoss`。設定此類別時，建議您在模型訓練迴路中使用相同的損失函式。請完成下列子步驟，以建立 SageMaker Smart Sifting `Loss` 實作類別。

   1. SageMaker Smart Sifting 會計算每個訓練資料範例的損失值，而不是計算批次的單一損失值。為了確保 SageMaker Smart Sifting 使用相同的損失計算邏輯，請使用 SageMaker Smart Sifting `Loss` 模組來建立 Smart Sifting 實作損失函式；該模組使用您的損失函式並計算每個訓練範例的損失。
**提示**  
SageMaker Smart Sifting 演算法會在每個資料範例上執行，而不是在整個批次上執行，因此您應該新增初始化函式來設定 PyTorch 遺失函式，而不需要任何減少策略。  

      ```
      class SiftingImplementedLoss(Loss):  
          def __init__(self):
              self.loss = torch.nn.CrossEntropyLoss(reduction='none')
      ```
如以下程式碼範例所示。

   1. 定義接受 `original_batch` (如果您已在步驟 2 中設定批次轉換，則是接受 `transformed_batch`) 和 PyTorch 模型的損失函式。SageMaker Smart Sifting 在不減少的情況下使用指定的損失函式，為每個資料範例執行向前傳遞，以評估其損失值。

   下列程式碼是名為 `SiftingImplementedLoss` 的 Smart Sifting 實作 `Loss` 介面範例。

   ```
   from typing import Any
   
   import torch
   import torch.nn as nn
   from torch import Tensor
   
   from smart_sifting.data_model.data_model_interface import SiftingBatch
   from smart_sifting.loss.abstract_sift_loss_module import Loss
   
   model=... # a PyTorch model based on torch.nn.Module
   
   class SiftingImplementedLoss(Loss):   
       # You should add the following initializaztion function 
       # to calculate loss per sample, not per batch.
       def __init__(self):
           self.loss_no_reduction = torch.nn.CrossEntropyLoss(reduction='none')
   
       def loss(
           self,
           model: torch.nn.Module,
           transformed_batch: SiftingBatch,
           original_batch: Any = None,
       ) -> torch.Tensor:
           device = next(model.parameters()).device
           batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2
           # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2
   
           # compute loss
           outputs = model(batch)
           return self.loss_no_reduction(outputs.logits, batch[2])
   ```

   在訓練迴路實際向前傳遞之前，此篩選損失計算會在每次反覆運算中擷取批次的資料載入階段完成。然後比較個別損失值與先前的損失值，並根據您在步驟 1 中設定的 `RelativeProbabilisticSiftConfig` 物件來估計其相對百分位數。

1. 根據 SageMaker AI `SiftingDataloader` 類別包裝 PyTroch 資料載入器。

   最後，將您在先前步驟中設定的所有 SageMaker Smart Sifting 實作類別都使用到 SageMaker AI `SiftingDataloder` 設定類別中。此類別是 PyTorch [https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) 的包裝函式。透過包裝 PyTorch `DataLoader`，SageMaker Smart Sifting 會經註冊，在 PyTorch 訓練任務的每個反覆運算中作為資料載入的一部分執行。下列程式碼範例示範實作 SageMaker AI 資料篩選到 PyTorch `DataLoader`。

   ```
   from smart_sifting.dataloader.sift_dataloader import SiftingDataloader
   from torch.utils.data import DataLoader
   
   train_dataloader = DataLoader(...) # PyTorch data loader
   
   # Wrap the PyTorch data loader by SiftingDataloder
   train_dataloader = SiftingDataloader(
       sift_config=sift_config, # config object of RelativeProbabilisticSiftConfig
       orig_dataloader=train_dataloader,
       batch_transforms=ListBatchTransform(), # Optional, this is the custom class from step 2
       loss_impl=SiftingImplementedLoss(), # PyTorch loss function wrapped by the Sifting Loss interface
       model=model,
       log_batch_data=False
   )
   ```