將 SageMaker Smart Sifting 套用至 Hugging Face 轉換器指令碼 - Amazon SageMaker AI

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

將 SageMaker Smart Sifting 套用至 Hugging Face 轉換器指令碼

有兩種方法可將 SageMaker Smart Sifting 實作至轉換器 Trainer 類別。

注意

如果您使用其中一個適用於 PyTorch 的 DLC 並安裝 SageMaker Smart Sifting 套件,請記得安裝 transformers 程式庫。您可以擴充 DLC 或傳遞 requirements.txt 至 SageMaker AI Python SDK 中 PyTorch (sagemaker.pytorch.PyTorch) 的訓練任務啟動器類別,來安裝其他套件。

簡易設定

將 SageMaker Smart Sifting 實作到轉換器 Trainer 類別的最簡單方法是使用 enable_sifting 函式。此函式接受現有 Trainer 物件,並使用 SiftingDataloader 包裝現有 DataLoader 物件。您可以繼續使用相同的訓練物件。請參閱以下使用範例。

from smart_sifting.integrations.trainer import enable_sifting from smart_sifting.loss.abstract_sift_loss_module import Loss from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) class SiftingImplementedLoss(Loss): def loss(self, model, transformed_batch, original_batch): loss_fct = MSELoss(reduction="none") # make sure to set reduction to "none" logits = model.bert(**original_batch) return loss_fct(logits, original_batch.get("labels")) sift_config = RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) ) trainer = Trainer(...) enable_sifting(trainer, sift_config, loss=SiftingImplementedLoss()) # updates the trainer with Sifting Loss and config trainer.train()

SiftingDataloader 類別是可重複運算的資料載入器。由於是在篩選期間隨機取樣,無法事先知道資料集的確切大小。因此,Hugging Face Trainer 會預期 max_steps 訓練引數。請留意,此引數會覆寫 epoch 組態參數 num_train_epochs。如果您的原始資料載入器也可以重複運算,或您的訓練使用 max_steps 和單一 epoch,則 SiftingDataloader 與現有資料載入器的表現相同。如果原始資料載入器無法重複運算或未提供 max_steps,Hugging Face 訓練器可能會擲回類似以下的錯誤訊息。

args.max_steps must be set to a positive value if dataloader does not have a length, was -1

為了解決此問題,enable_sifting 函式會提供選用 set_epochs 參數。這樣就可使用 Trainer 類別的 num_train_epochs argument 提供的 epoch 數量來啟用 epoch 訓練,並將 max_steps 設為最大系統整數,允許訓練進行,直到指定的 epoch 完成為止。

自訂設定

針對 SageMaker Smart Sifting 資料載入器的自訂整合,您可以使用自訂 Hugging Face Trainer 類別。在任何 Trainer 子類別中,get_train_dataloader() 函式都可以覆寫,以改為傳回 SiftingDataloader 類別的物件。如果有現有自訂訓練器,此方法可能較不具侵入性,但需要變更程式碼,而非只簡單的設定選項。以下是將 SageMaker Smart Sifting 實作到自訂 Hugging Face Trainer 類別的範例。

from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from smart_sifting.loss.abstract_sift_loss_module import Loss from smart_sifting.data_model.data_model_interface import SiftingBatch, SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class SiftingListBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch class SiftingImplementedLoss(): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.celoss = torch.nn.CrossEntropyLoss(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # compute loss outputs = model(batch) return self.celoss(outputs.logits, batch[2]) class SiftingImplementedTrainer(Trainer): def get_train_dataloader(self): dl = super().get_train_dataloader() sift_config = RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) ) return SiftingDataloader( sift_config=sift_config, orig_dataloader=dl, batch_transforms=SiftingListBatchTransform(), loss_impl=SiftingImplementedLoss(), model=self.model )

使用包裝的 Trainer 類別建立物件,如下所示。

trainer = SiftingImplementedTrainer( model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset ) trainer.train()