View a markdown version of this page

使用演算法來執行超參數調校工作 - Amazon SageMaker AI

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

使用演算法來執行超參數調校工作

下一節說明如何使用演算法資源在 Amazon SageMaker AI 中執行超參數調校任務。超參數調校工作會透過在您的資料集上,使用您指定的演算法和超參數範圍執行許多訓練工作,來尋找最佳版本的模型。它接著會根據您選擇的指標,選擇可讓模型取得最佳執行結果的超參數值。如需詳細資訊,請參閱使用 SageMaker AI 執行自動模型調校

您可以使用 Amazon SageMaker AI 主控台、低層級 Amazon SageMaker API 或 Amazon SageMaker Python SDK,使用演算法資源來建立超參數調校工作。

使用演算法來執行超參數調校工作 (主控台)

使用演算法來執行超參數調校工作 (主控台)
  1. 開啟位在 https://console.aws.amazon.com/sagemaker/ 的 SageMaker AI 主控台。

  2. 選擇演算法

  3. 我的演算法索引標籤的清單上選擇您建立的演算法,或在AWS Marketplace 訂閱索引標籤上選擇您訂閱的演算法。

  4. 選擇建立超參數調校工作

    會自動選取您選擇的演算法。

  5. 建立超參數調校工作頁面上,提供以下資訊:

    1. 針對暖啟動,選擇啟用暖啟動來使用先前超參數調校工作的資訊做為此超參數調校工作的起點。如需詳細資訊,請參閱執行超參數調校任務的暖啟動

      1. 若您的輸入資料與此超參數調校工作的父系工作相同,請選擇相同資料及演算法,或是選擇傳輸學習來針對此超參數調校工作使用額外或不同的輸入資料。

      2. 針對父系超參數調校工作,選擇最多 5 個超參數調校工作,做為此超參數調校工作的父系。

    2. 針對超參數調校工作名稱,輸入調校工作的名稱。

    3. 針對 IAM 角色,請選擇擁有必要許可,能在 SageMaker AI 中執行超參數調校工作的 IAM 角色,或是選擇建立新角色來允許 SageMaker AI 建立已連接 AmazonSageMakerFullAccess 受管政策的角色。如需相關資訊,請參閱如何使用 SageMaker AI 執行角色

    4. 針對 VPC,選擇您想要允許調校工作啟動以進行存取的訓練工作的 Amazon VPC。如需詳細資訊,請參閱讓 SageMaker AI 訓練任務可以存取 Amazon VPC 中的資源

    5. 選擇下一步

    6. 針對目標指標,選擇超參數調校工作用來判斷最佳超參數組合的指標,然後選擇是否要最小或最大化此指標。如需詳細資訊,請參閱檢視最佳訓練任務

    7. 針對超參數組態,選擇您希望調校工作搜尋之可調校的超參數範圍,並設定您希望在所有超參數調校工作所啟動訓練工作中維持一致的超參數值。如需詳細資訊,請參閱定義超參數範圍

    8. 選擇下一步

    9. 針對輸入資料組態,針對每個用於超參數調校工作的輸入資料通道,指定下列值。您可以在該演算法的演算法摘要頁面的通道規格區段下,查看您用於超參數調校支援的演算法通道,以及每個通道的內容類型、支援的壓縮類型和支援的輸入模式。

      1. 針對通道名稱,輸入輸入通道的名稱。

      2. 針對內容類型,輸入演算法針對通道所預期的資料內容類型。

      3. 針對壓縮類型,選擇要使用的資料壓縮類型 (若有的話)。

      4. 針對記錄包裝函式,若演算法預期 RecordIO 格式的資料,請選擇 RecordIO

      5. 針對 S3 資料類型S3 資料分佈類型S3 位置,請指定適當的值。如需這些值所代表意義的資訊,請參閱 S3DataSource

      6. 針對輸入模式,選擇檔案來從所佈建的機器學習 (ML) 儲存磁碟區下載資料,並將目錄掛載到 Docker 磁碟區。選擇管道來直接從 Amazon S3 串流資料到容器。

      7. 若要新增另一個輸入通道,請選擇新增通道。若您已完成新增輸入通道,請選擇完成

    10. 針對輸出位置,請指定下列值:

      1. 針對 S3 輸出路徑,選擇此超參數調校工作所啟動的訓練工作用來存放輸出 (例如模型成品) 的 S3 位置。

        注意

        您可以使用存放在此位置的模型成品,從超參數調校工作建立模型或模型套件。

      2. 對於加密金鑰,如果您希望 SageMaker AI 使用 AWS KMS 金鑰來加密 S3 位置的靜態輸出資料。

    11. 針對資源組態,提供下列資訊:

      1. 針對執行個體類型,選擇要針對每個超參數調校工作所啟動訓練工作使用的執行個體類型。

      2. 針對執行個體計數,輸入要針對每個超參數調校工作所啟動訓練工作使用的機器學習 (ML) 執行個體數量。

      3. 針對每個執行個體的額外磁碟區 (GB),輸入您希望佈建超參數調校工作所啟動每個訓練工作的機器學習 (ML) 儲存磁碟區大小。機器學習 (ML) 儲存磁碟區會存放模型成品及累加狀態。

      4. 對於加密金鑰,如果您希望 Amazon SageMaker AI 使用 AWS Key Management Service 金鑰來加密連接到訓練執行個體的 ML 儲存磁碟區中的資料,請指定 金鑰。

    12. 針對資源限制,提供下列資訊:

      1. 針對訓練工作數量上限,指定您希望超參數調校工作啟動的訓練工作數量上限。超參數調校工作最多能啟動 500 個訓練任務。

      2. 針對平行訓練工作數量上限,指定您希望超參數調校工作啟動的同時訓練工作數量上限。超參數調校工作最多能啟動 10 個同時訓練工作。

      3. 針對停止條件,指定您希望超參數調校工作所啟動的每個訓練工作執行時間上限 (秒、分鐘、小時或天數)。

    13. 針對標籤,請指定一或多個標籤來管理超參數調校工作。每個標籤皆包含索引鍵與選用值。每個資源的標籤鍵必須是唯一的。

    14. 選擇建立任務來執行超參數調校工作。

使用演算法來執行超參數調校工作 (API)

若要使用 SageMaker API,利用演算法來執行超參數調校工作,請在您傳遞給 CreateHyperParameterTuningJobAlgorithmSpecification 物件的 AlgorithmName 欄位中,指定演算法的名稱或 Amazon Resource Name (ARN)。如需 SageMaker AI 中超參數調校的資訊,請參閱 使用 SageMaker AI 執行自動模型調校

使用演算法來執行超參數調校工作 (Amazon SageMaker Python SDK)

使用您在 上建立或訂閱的演算法 AWS Marketplace 來建立超參數調校任務、建立 AlgorithmEstimator 物件,並將 Amazon Resource Name (ARN) 或演算法名稱指定為algorithm_arn引數的值。然後,使用您建立的 AlgorithmEstimator 做為 estimator 引數的值,初始化 HyperparameterTuner 物件。最後,呼叫 AlgorithmEstimatorfit 方法。例如:

from sagemaker import AlgorithmEstimator from sagemaker.tuner import HyperparameterTuner data_path = os.path.join(DATA_DIR, 'marketplace', 'training') algo = AlgorithmEstimator( algorithm_arn='arn:aws:sagemaker:us-east-2:764419575721:algorithm/scikit-decision-trees-1542410022', role='SageMakerRole', instance_count=1, instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, base_job_name='test-marketplace') train_input = algo.sagemaker_session.upload_data( path=data_path, key_prefix='integ-test-data/marketplace/train') algo.set_hyperparameters(max_leaf_nodes=10) tuner = HyperparameterTuner(estimator=algo, base_tuning_job_name='some-name', objective_metric_name='validation:accuracy', hyperparameter_ranges=hyperparameter_ranges, max_jobs=2, max_parallel_jobs=2) tuner.fit({'training': train_input}, include_cls_metadata=False) tuner.wait()