

本文属于机器翻译版本。若本译文内容与英语原文存在差异，则一律以英文原文为准。

# 使用 Apache Spark 在亚马逊 A SageMaker I 上使用自定义算法进行模型训练和托管
<a name="apache-spark-example1-cust-algo"></a>

在中[SageMaker 适用于 Scala 的人工智能 Spark](apache-spark-example1.md)，之所以使用，`kMeansSageMakerEstimator`是因为该示例使用 Amazon A SageMaker I 提供的 k-means 算法进行模型训练。您可以选择使用自己的自定义算法进行模型训练。假设您已经创建了 Docker 映像，则可以创建自己的 `SageMakerEstimator` 并为您的自定义映像指定 Amazon Elastic Container Registry 路径。

以下示例显示如何从 `SageMakerEstimator` 中创建 `KMeansSageMakerEstimator`。在新的评估程序中，您可以显式指定训练和推理代码图像的 Docker 注册表路径。

```
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer

val estimator = new SageMakerEstimator(
  trainingImage =
    "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1",
  modelImage =
    "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1",
  requestRowSerializer = new ProtobufRequestRowSerializer(),
  responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(),
  hyperParameters = Map("k" -> "10", "feature_dim" -> "784"),
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1,
  trainingSparkDataFormat = "sagemaker")
```

在该代码中，`SageMakerEstimator` 构造函数中的参数包括：
+ `trainingImage` - 标识包含自定义代码的训练映像的 Docker 注册表路径。
+ `modelImage` - 标识包含推理代码的映像的 Docker 注册表路径。
+ `requestRowSerializer` - 实施 `com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer`。

  此参数对输入中的行进行序列化，将其发送`DataFrame`到 SageMaker AI 中托管的模型进行推理。
+ `responseRowDeserializer` - 实施 

  `com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer`.

  此参数将来自模型的响应（托管在 SageMaker AI 中）反序列化为。`DataFrame`
+ `trainingSparkDataFormat` - 指定 Spark 在将训练数据从 `DataFrame` 上传到 S3 时使用的数据格式。例如，`"sagemaker"` 用于 protobuf 格式，`"csv"` 用于逗号分隔值，`"libsvm"` 用于 LibSVM 格式。

您可以实施自己的 `RequestRowSerializer` 和 `ResponseRowDeserializer` 以序列化或反序列化您的推理代码支持的数据格式中的行，例如 .libsvm 或 .csv。