View a markdown version of this page

將 SageMaker Smart Sifting 套用至 PyTorch 指令碼 - Amazon SageMaker AI

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

將 SageMaker Smart Sifting 套用至 PyTorch 指令碼

這些指示說明如何以訓練指令碼啟動 SageMaker Smart Sifting。

  1. 設定 SageMaker Smart Sifting 介面。

    SageMaker Smart Sifting 程式庫會實作相對閾值損失型取樣技術,有助於篩選掉對降低損失值較無影響的範例。SageMaker Smart Sifting 演算法會透過向前傳遞,來計算每個輸入資料範例的損失值,並根據先前資料的損失值計算其相對百分位數。

    下列兩個參數是您建立篩選組態物件時,需要指定的 RelativeProbabilisticSiftConfig 類別的參數。

    • 指定用於訓練 beta_value 參數的資料比例。

    • 指定相較於 loss_history_length 參數的使用範例數量。

    下列程式碼範例示範如何設定 RelativeProbabilisticSiftConfig 類別的物件。

    from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )

    如需 loss_based_sift_config 參數和相關類別的詳細資訊,請參閱 SageMaker Smart Sifting Python SDK 參考一節中的 SageMaker Smart Sifting 組態模組

    上述程式碼範例中的 sift_config 物件在步驟 4 中用於設定 SiftingDataloader 類別。

  2. (選用) 設定 SageMaker Smart Sifting 批次轉換類別。

    不同的訓練使用案例需要不同的訓練資料格式。考慮到各種資料格式,SageMaker Smart Sifting 演算法需要知道如何在特定批次上進行篩選。為了解決此問題,SageMaker Smart Sifting 會提供批次轉換模組,協助將批次轉換為可有效篩選的標準化格式。

    1. SageMaker Smart Sifting 會以下列格式處理訓練資料的批次轉換:Python 清單、字典、元組和張量。針對這些資料格式,SageMaker Smart Sifting 會自動轉換批次資料格式,您可以略過此步驟的其餘部分。如果您略過此步驟,請在步驟 4 中設定 SiftingDataloader,將 SiftingDataloaderbatch_transforms 參數保留為預設值,也就是 None

    2. 如果您的資料集不是這些格式,您仍應繼續進行此步驟的其餘部分,以使用 SiftingBatchTransform 建立自訂批次轉換。

      如果您的資料集不是 SageMaker Smart Sifting 支援的格式,您可能會遇到錯誤。您可以將 batch_format_indexbatch_transforms 參數新增至您在步驟 4 中設定的 SiftingDataloader 類別,以解決此類資料格式錯誤。以下顯示因資料格式和解析度不相容而造成的範例錯誤。

      錯誤訊息 Resolution

      根據預設,系統不支援 {type(batch)} 類型的批次。

      此錯誤表示系統預設不支援該批次格式。您應該實作自訂批次轉換類別,並將該類別指定至 SiftingDataloader 類別的 batch_transforms 參數來使用此類別。

      無法為 {type(batch)} 類型的批次編製索引

      此錯誤表示無法正常編製批次物件的索引。使用者必須實作自訂批次轉換,並使用 batch_transforms 參數傳遞此轉換。

      批次大小 {batch_size} 不符合維度 0 或維度 1 的大小

      您提供的批次大小不符合批次的第 0 個或第 1 個維度時,就會發生此錯誤。使用者必須實作自訂批次轉換,並使用 batch_transforms 參數傳遞此轉換。

      維度 0 和維度 1 皆符合批次大小

      此錯誤表示,由於多個維度符合您提供的批次大小,需要更多資訊才能篩選批次。使用者可以提供 batch_format_index 參數,指明批次是否可依範例或功能編製索引。使用者也可以實作自訂批次轉換,但這種方式較不必要。

      若要解決上述問題,您需要使用 SiftingBatchTransform 模組建立自訂批次轉換類別。批次轉換類別應包含一對轉換和反向轉換函式。函式對會將您的資料格式轉換為 SageMaker Smart Sifting 演算法可以處理的格式。建立批次轉換類別之後,類別會傳回您將在步驟 4 中傳遞給 SiftingDataloader 類別的 SiftingBatch 物件。

      以下是 SiftingBatchTransform 模組的自訂批次轉換類別範例。

      • 資料載入器區塊具有輸入、遮罩和標籤時,使用 SageMaker Smart Sifting 的自訂清單批次轉換實作範例。

        from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class ListBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch
      • 不需要標籤進行反向轉換時,使用 SageMaker Smart Sifting 的自訂清單批次轉換實作範例。

        class ListBatchTransformNoLabels(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch
      • 資料載入器區塊具有輸入、遮罩和標籤時,使用 SageMaker Smart Sifting 的自訂張量批次實作範例。

        from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch class TensorBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch

      建立 SiftingBatchTransform 實作的批次轉換類別之後,您可以在步驟 4 中使用此類別來設定 SiftingDataloader 類別。本指南的其餘部分假設您已建立 ListBatchTransform 類別。在步驟 4 中,此類別會傳遞至 batch_transforms

  3. 建立用來實作 SageMaker Smart Sifting Loss 介面的類別。本教學假設該類別名為 SiftingImplementedLoss。設定此類別時,建議您在模型訓練迴路中使用相同的損失函式。請完成下列子步驟,以建立 SageMaker Smart Sifting Loss 實作類別。

    1. SageMaker Smart Sifting 會計算每個訓練資料範例的損失值,而不是計算批次的單一損失值。為了確保 SageMaker Smart Sifting 使用相同的損失計算邏輯,請使用 SageMaker Smart Sifting Loss 模組來建立 Smart Sifting 實作損失函式;該模組使用您的損失函式並計算每個訓練範例的損失。

      提示

      SageMaker Smart Sifting 演算法會在每個資料範例上執行,而不是在整個批次上執行,因此您應該新增初始化函式來設定 PyTorch 遺失函式,而不需要任何減少策略。

      class SiftingImplementedLoss(Loss): def __init__(self): self.loss = torch.nn.CrossEntropyLoss(reduction='none')

      如以下程式碼範例所示。

    2. 定義接受 original_batch (如果您已在步驟 2 中設定批次轉換,則是接受 transformed_batch) 和 PyTorch 模型的損失函式。SageMaker Smart Sifting 在不減少的情況下使用指定的損失函式,為每個資料範例執行向前傳遞,以評估其損失值。

    下列程式碼是名為 SiftingImplementedLoss 的 Smart Sifting 實作 Loss 介面範例。

    from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module class SiftingImplementedLoss(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction = torch.nn.CrossEntropyLoss(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction(outputs.logits, batch[2])

    在訓練迴路實際向前傳遞之前,此篩選損失計算會在每次反覆運算中擷取批次的資料載入階段完成。然後比較個別損失值與先前的損失值,並根據您在步驟 1 中設定的 RelativeProbabilisticSiftConfig 物件來估計其相對百分位數。

  4. 根據 SageMaker AI SiftingDataloader 類別包裝 PyTroch 資料載入器。

    最後,將您在先前步驟中設定的所有 SageMaker Smart Sifting 實作類別都使用到 SageMaker AI SiftingDataloder 設定類別中。此類別是 PyTorch DataLoader 的包裝函式。透過包裝 PyTorch DataLoader,SageMaker Smart Sifting 會經註冊,在 PyTorch 訓練任務的每個反覆運算中作為資料載入的一部分執行。下列程式碼範例示範實作 SageMaker AI 資料篩選到 PyTorch DataLoader

    from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=sift_config, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader, batch_transforms=ListBatchTransform(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss(), # PyTorch loss function wrapped by the Sifting Loss interface model=model, log_batch_data=False )