

本文属于机器翻译版本。若本译文内容与英语原文存在差异，则一律以英文原文为准。

# 使用 A SageMaker I 估算器来运行训练作业
<a name="docker-containers-adapt-your-own-private-registry-estimator"></a>

您还可以使用 Pyth SageMaker on SDK 中的[估算器](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html)来处理 SageMaker 训练作业的配置和运行。以下代码示例显示如何使用私有 Docker 注册表中的映像配置和运行估算器。

1. 导入所需的库和依赖项，如以下代码示例中所示。

   ```
   import boto3
   import sagemaker
   from sagemaker.estimator import Estimator
   
   session = sagemaker.Session()
   
   role = sagemaker.get_execution_role()
   ```

1. 向您的训练映像、安全组和子网提供统一资源标识符 (URI)，用于您的训练作业 VPC 配置，如以下代码示例所示。

   ```
   image_uri = "myteam.myorg.com/docker-local/my-training-image:<IMAGE-TAG>"
   
   security_groups = ["sg-0123456789abcdef0"]
   subnets = ["subnet-0123456789abcdef0", "subnet-0123456789abcdef0"]
   ```

   有关`security_group_ids`和的更多信息`subnets`，请参阅 Pyth SageMaker on SDK 的 “[估算器](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html)” 部分中的相应参数描述。
**注意**  
SageMaker AI 使用您的 VPC 内的网络连接来访问您的 Docker 注册表中的镜像。要将您 Docker 注册表中的映像用于训练，注册表必须可以从您账户中的 Amazon VPC 访问。

1. 或者，如果您的 Docker 注册表需要身份验证，则还必须指定向 AI 提供访问凭证 SageMaker 的函数的 AWS Lambda 亚马逊资源名称 (ARN)。以下示例演示了如何指定 ARN。

   ```
   training_repository_credentials_provider_arn = "arn:aws:lambda:us-west-2:1234567890:function:test"
   ```

   有关使用需要身份验证的 Docker 注册表中的映像的更多信息，请参阅下文中的**使用需要身份验证的 Docker 注册表进行训练**。

1. 使用前面步骤中的代码示例来配置估算器，如以下代码示例所示。

   ```
   # The training repository access mode must be 'Vpc' for private docker registry jobs 
   training_repository_access_mode = "Vpc"
   
   # Specify the instance type, instance count you want to use
   instance_type="ml.m5.xlarge"
   instance_count=1
   
   # Specify the maximum number of seconds that a model training job can run
   max_run_time = 1800
   
   # Specify the output path for the model artifacts
   output_path = "s3://your-output-bucket/your-output-path"
   
   estimator = Estimator(
       image_uri=image_uri,
       role=role,
       subnets=subnets,
       security_group_ids=security_groups,
       training_repository_access_mode=training_repository_access_mode,
       training_repository_credentials_provider_arn=training_repository_credentials_provider_arn,  # remove this line if auth is not needed
       instance_type=instance_type,
       instance_count=instance_count,
       output_path=output_path,
       max_run=max_run_time
   )
   ```

1. 使用您的作业名称和输入路径作为参数来调用 `estimator.fit`，以启动训练作业，如以下代码示例所示。

   ```
   input_path = "s3://your-input-bucket/your-input-path"
   job_name = "your-job-name"
   
   estimator.fit(
       inputs=input_path,
       job_name=job_name
   )
   ```