
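Implementing reward functions - Amazon Nova

Implementing reward functions

Overview

The reward function (also called a scorer or grader) is the core component that evaluates model responses and provides the feedback signal for training. It must be implemented as a Lambda function that accepts model responses and returns reward scores.

Interface specification

The reward function must receive and return data in the following formats:

Example training sample input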

{
  "messages": [
    {
      "role": "user",
      "content": "Do you have a dedicated security team?"
    }
  ],
  "reference_answer": {
    "compliant": "No",
    "explanation": "As an AI developed by Company, I do not have a traditional security team..."
  }
}

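Example payload for the reward Lambda

Before sending data to the Lambda function, the container automatically transforms it as follows:

  1. Generates a model response for each prompt

  2. Appends the assistant turn (the generated response) to the messages array

  3. Adds a unique id field for tracking

The Lambda function receives data in this transformed format: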

{
  "id": "123",
  "messages": [
    {
      "role": "user",
      "content": "Do you have a dedicated security team?"
    },
    {
      "role": "assistant",
      "content": "As an AI developed by Amazon, I don't have a dedicated security team..."
    }
  ],
  # The following section is the same as in your training dataset sample
  "reference_answer": {
    "compliant": "No",
    "explanation": "As an AI developed by Company, I do not have a traditional security team..."
  }
}
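The three transformation steps can be sketched in a few lines of Python. This is only an illustration of the resulting data shape, not the container's actual code: the helper name `to_lambda_payload` is hypothetical, and generating the id with `uuid` is an assumption (the source only says the id is unique).

```python
import copy
import uuid


def to_lambda_payload(training_sample: dict, generated_response: str) -> dict:
    """Hypothetical helper that mirrors the transformation steps listed above."""
    payload = copy.deepcopy(training_sample)  # leave the dataset row untouched
    # Step 2: append the generated response as an assistant turn.
    payload["messages"].append({"role": "assistant", "content": generated_response})
    # Step 3: add a unique id for tracking (assumption: any unique string works).
    payload["id"] = str(uuid.uuid4())
    return payload


sample = {
    "messages": [{"role": "user", "content": "Do you have a dedicated security team?"}],
    "reference_answer": {
        "compliant": "No",
        "explanation": "As an AI developed by Company, I do not have a traditional security team...",
    },
}
payload = to_lambda_payload(sample, "As an AI developed by Amazon, I don't have a dedicated security team...")
```

Every other top-level field of the training sample, such as reference_answer, passes through unchanged, which is why the grader can read it directly from the event.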

Reward Lambda contract

def lambda_handler(event, context):
    return lambda_grader(event)


def lambda_grader(samples: list[dict]) -> list[dict]:
    """
    Args:
        samples: List of dictionaries in OpenAI format

    Example input:
        {
            "id": "123",
            "messages": [
                {
                    "role": "user",
                    "content": "Do you have a dedicated security team?"
                },
                {
                    "role": "assistant",
                    "content": "As an AI developed by Company, I don't have a dedicated security team..."
                }
            ],
            # This section will be the same as in your training dataset
            "reference_answer": {
                "compliant": "No",
                "explanation": "As an AI developed by Company, I do not have a traditional security team..."
            }
        }

    Returns:
        List of dictionaries with reward scores:
        {
            "id": str,                        # Same id as input sample
            "aggregate_reward_score": float,  # Overall score for the sample
            "metrics_list": [                 # OPTIONAL: Component scores
                {
                    "name": str,     # Name of the component score
                    "value": float,  # Value of the component score
                    "type": str      # "Reward" or "Metric"
                }
            ]
        }
    """
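To make the contract concrete, here is a minimal rule-based grader that fills in the body of lambda_grader. It is a toy sketch under stated assumptions: the scoring rule (does the response start with the reference "compliant" value?) and the metric name `compliant_match` are invented for illustration, and accepting either a single sample or a list as the event is an assumption.

```python
def lambda_grader(samples: list[dict]) -> list[dict]:
    """Toy grader: 1.0 if the last assistant turn starts with the reference verdict."""
    results = []
    for sample in samples:
        # Use the last assistant turn as the response to score.
        response = next(
            (m.get("content", "") for m in reversed(sample.get("messages", []))
             if m.get("role") == "assistant"),
            "",
        )
        reference = sample.get("reference_answer")
        if isinstance(reference, dict):
            expected = str(reference.get("compliant", ""))
        else:
            expected = str(reference or "")
        matched = bool(expected) and response.strip().lower().startswith(expected.lower())
        score = 1.0 if matched else 0.0
        results.append({
            "id": sample.get("id", ""),  # echo the input id unchanged
            "aggregate_reward_score": score,
            "metrics_list": [{"name": "compliant_match", "value": score, "type": "Reward"}],
        })
    return results


def lambda_handler(event, context):
    # Assumption: the event is either one sample or a list of samples.
    samples = event if isinstance(event, list) else [event]
    return lambda_grader(samples)
```

Because the function has no external dependencies, it can be called locally with a hand-written event and `None` as the context before it is deployed.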

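Input and output fields

Input fields

Field               Description                              Additional notes
id                  Unique identifier for the sample         Returned unchanged in the output. String format
messages            Ordered chat history in OpenAI format    Array of message objects
messages[].role     Sender of the message                    Common values: user, assistant, system
messages[].content  Text content of the message              Plain string
**metadata          Custom information to assist scoring     Object; optional fields passed through from the training data

Output fields

Field                   Description                                         Additional notes
id                      Same identifier as the input sample                 Must match the input
aggregate_reward_score  Overall score for the sample                        Float (for example 0.0 to 1.0, or a task-defined range)
metrics_list            Component scores that make up the aggregate score   Array of metric objects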
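A small checker can catch malformed output records before they reach training. The sketch below encodes only the rules stated in the tables and the contract; the function name `validate_reward_output` and the wording of its messages are hypothetical, and rejecting NaN or infinite scores is an added assumption.

```python
import math


def _is_number(value) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def validate_reward_output(sample: dict, result: dict) -> list[str]:
    """Return a list of problems; an empty list means the record looks well-formed."""
    problems = []
    if result.get("id") != sample.get("id"):
        problems.append("id does not match the input sample")
    score = result.get("aggregate_reward_score")
    if not _is_number(score) or not math.isfinite(score):
        problems.append("aggregate_reward_score must be a finite number")
    for metric in result.get("metrics_list") or []:  # metrics_list is optional
        if not isinstance(metric.get("name"), str):
            problems.append("metric name must be a string")
        if not _is_number(metric.get("value")):
            problems.append("metric value must be a number")
        if metric.get("type") not in ("Reward", "Metric"):
            problems.append("metric type must be 'Reward' or 'Metric'")
    return problems
```

Running this over every record a grader returns in local tests gives an early warning for mismatched ids or missing scores.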

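Technical constraints

  • Timeout: each Lambda invocation can run for at most 15 minutes

  • Concurrency: must be able to handle rollout_worker_replicas * 64 concurrent requests

  • Reliability: must implement robust error handling and consistently return valid scores

  • Performance: optimize for fast execution (seconds, not minutes) so that training runs efficiently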
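The concurrency and reliability constraints can be sketched as follows. Both helper names are hypothetical, and falling back to a score of 0.0 on failure is an assumption borrowed from the LLM-as-a-judge example on this page rather than a documented requirement.

```python
def required_concurrency(rollout_worker_replicas: int) -> int:
    """Concurrent requests the function must absorb, per the constraint above."""
    return rollout_worker_replicas * 64


def score_safely(grade_fn, sample: dict, fallback: float = 0.0) -> dict:
    """Run a grader on one sample and always return a well-formed record."""
    try:
        score = float(grade_fn(sample))
    except Exception as exc:  # broad on purpose: a crash must not reach the trainer
        print(f"grading failed for {sample.get('id')}: {exc}")
        score = fallback
    if score != score:  # NaN is the only float that is not equal to itself
        score = fallback
    return {"id": sample.get("id", ""), "aggregate_reward_score": score}
```

For example, a recipe with 4 rollout worker replicas implies 4 * 64 = 256 concurrent requests, which is worth comparing against the account's Lambda concurrency limit before training starts.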

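Best practices

  • Minimize external API calls

  • Use efficient algorithms and data structures

  • Implement retry logic for transient failures

  • Cache reusable computation results

  • Test thoroughly before training to make sure the function runs without errors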
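Two of these practices, retrying transient failures and caching reusable results, can be sketched with the standard library alone. `TransientError`, `with_retries` and `normalize` are hypothetical names introduced for illustration; a real function would map its own throttling or timeout exceptions onto the retry path.

```python
import functools
import time


class TransientError(Exception):
    """Hypothetical marker for failures worth retrying (throttling, timeouts)."""


def with_retries(fn, max_attempts: int = 3, base_delay: float = 1.0, sleep=time.sleep):
    """Call fn(); on TransientError wait base_delay * 2**attempt, then try again."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except TransientError:
            if attempt == max_attempts - 1:
                raise  # out of attempts: let the caller choose a fallback score
            sleep(base_delay * 2 ** attempt)


@functools.lru_cache(maxsize=1024)
def normalize(text: str) -> str:
    """Example of a reusable computation, cached per distinct input string."""
    return " ".join(text.lower().split())
```

The `sleep` parameter exists only so the backoff can be tested without waiting. Module-level caches such as `lru_cache` generally survive across invocations that reuse the same warm Lambda execution environment, so repeated reference answers are normalized once.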

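Using a custom reward function

When a task has its own evaluation criteria, you can implement a custom reward function:

  • Define the evaluation criteria: decide what makes a response good for the task at hand

  • Implement the Lambda function: create a Lambda function that follows the interface specification

  • Test locally: verify that the function returns correct scores for sample inputs

  • Deploy to AWS: deploy the Lambda and note its ARN

  • Configure the recipe: enter the Lambda ARN in the recipe's reward_lambda_arn field

  • Test with a small dataset: run RFT on a small amount of data to verify the integration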
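As an illustration of the first three steps, the sketch below turns two made-up criteria into weighted component scores and exercises the grader locally. The criteria, the weights, the 50-word limit and the choice of "Reward" as the metric type are all assumptions for the example, not recommendations from the source.

```python
# Hypothetical criteria and weights; replace them with checks that fit the task.
WEIGHTS = {"mentions_reference_verdict": 0.7, "is_concise": 0.3}


def grade(sample: dict) -> dict:
    """Score one sample as a weighted sum of simple, task-specific checks."""
    response = sample["messages"][-1]["content"]
    verdict = sample["reference_answer"]["compliant"]
    components = {
        # Naive substring check, good enough for a sketch.
        "mentions_reference_verdict": 1.0 if verdict.lower() in response.lower() else 0.0,
        "is_concise": 1.0 if len(response.split()) <= 50 else 0.0,
    }
    total = sum(WEIGHTS[name] * value for name, value in components.items())
    return {
        "id": sample["id"],
        "aggregate_reward_score": round(total, 4),  # round away float noise
        "metrics_list": [
            {"name": name, "value": value, "type": "Reward"}
            for name, value in components.items()
        ],
    }


# Local test: run the grader on a hand-written sample before deploying.
sample = {
    "id": "local-1",
    "messages": [
        {"role": "user", "content": "Do you have a dedicated security team?"},
        {"role": "assistant", "content": "No, I don't have a dedicated security team."},
    ],
    "reference_answer": {"compliant": "No"},
}
result = grade(sample)
```

Reporting each criterion in metrics_list makes it easier to see which component drives the aggregate score when a small-scale RFT run behaves unexpectedly.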

IAM permissions

Required permissions

The SageMaker execution role must have permission to invoke the Lambda function. Add this policy to the SageMaker execution role:

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "lambda:InvokeFunction"
      ],
      "Resource": "arn:aws:lambda:region:account-id:function:function-name"
    }
  ]
}

Lambda execution role

The Lambda function's execution role needs the basic Lambda execution permissions:

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:PutLogEvents"
      ],
      "Resource": "arn:aws:logs:*:*:*"
    }
  ]
}

Additional permissions: if the Lambda function accesses other AWS services (for example, S3 for reference data or DynamoDB for logging), add those permissions to the Lambda execution role.

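Example: LLM-as-a-judge reward function

This example shows how to use an Amazon Bedrock model as a judge that evaluates model responses by comparing them with reference answers. The Lambda template gives you a framework for sending inference requests to Amazon Bedrock to perform the judging. It follows the same input and output contract as other reward functions.

Implementation

This Lambda function uses a two-stage evaluation flow: lambda_handler extracts the model response and the reference answer from each input sample, then lambda_graded calls Amazon Bedrock to score their semantic similarity. The implementation includes robust error handling, automatically retries transient failures, and accepts reference answers in either string or structured dictionary form.

Implementation details:

  • Retry logic: implements exponential backoff for throttling exceptions (a delay of 2^attempt seconds: 1 s, 2 s, 4 s) to cope with Bedrock API rate throttling; with the default max_retries of 3, only the 1 s and 2 s delays are reached

  • Error handling: returns a score of 0.0 when evaluation fails instead of raising an exception

  • Deterministic scoring: uses temperature=0.0 so that repeated evaluations give consistent results

  • Flexible reference format: automatically handles reference answers in both string and dictionary form

  • Score clipping: ensures that every score falls within the valid [0.0, 1.0] range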

  • Model-agnostic: change JUDGE_MODEL_ID to use any Amazon Bedrock model (Nova, Llama, Mistral, and so on)

"""
LLM Judge Lambda POC - Working implementation using Amazon Bedrock
"""
import json
import time

import boto3

bedrock_runtime = boto3.client('bedrock-runtime', region_name='us-east-1')

JUDGE_MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"

SYSTEM_PROMPT = "You must output ONLY a number between 0.0 and 1.0. No explanations, no text, just the number."

JUDGE_PROMPT_TEMPLATE = """Compare the following two responses and rate how similar they are on a scale of 0.0 to 1.0, where:
- 1.0 means the responses are semantically equivalent (same meaning, even if worded differently)
- 0.5 means the responses are partially similar
- 0.0 means the responses are completely different or contradictory

Response A: {response_a}

Response B: {response_b}

Output ONLY a number between 0.0 and 1.0. No explanations."""


def lambda_graded(response_a: str, response_b: str, max_retries: int = 3) -> float:
    """Call Bedrock to compare responses and return similarity score."""
    prompt = JUDGE_PROMPT_TEMPLATE.format(response_a=response_a, response_b=response_b)

    for attempt in range(max_retries):
        try:
            response = bedrock_runtime.converse(
                modelId=JUDGE_MODEL_ID,
                messages=[{"role": "user", "content": [{"text": prompt}]}],
                system=[{"text": SYSTEM_PROMPT}],
                inferenceConfig={"temperature": 0.0, "maxTokens": 10}
            )
            print(f"Bedrock call successful: {response}")
            output = response['output']['message']['content'][0]['text'].strip()
            score = float(output)
            print(f"Score parsed: {score}")
            # Clip the score to the valid [0.0, 1.0] range
            return max(0.0, min(1.0, score))
        except Exception as e:
            if "ThrottlingException" in str(e) and attempt < max_retries - 1:
                # Exponential backoff before the next attempt: 1 s, then 2 s
                time.sleep(2 ** attempt)
            else:
                print(f"Bedrock call failed: {e}")
                # Return 0.0 on failure (not None) so the output score stays a valid float
                return 0.0
    return 0.0


def lambda_handler(event, context):
    """AWS Lambda handler - processes samples from RFTEvalInvoker."""
    try:
        samples = event if isinstance(event, list) else [event]
        results = []

        for sample in samples:
            sample_id = sample.get("id", "unknown")
            messages = sample.get("messages", [])

            # Extract assistant response (response A): the first assistant turn found
            response_a = ""
            for msg in messages:
                if msg.get("role") in ["assistant", "nova_assistant"]:
                    response_a = msg.get("content", "")
                    break

            # Extract reference answer from root level (no longer in metadata)
            reference_answer = sample.get("reference_answer", "")

            # Handle both string and dict reference_answer formats
            if isinstance(reference_answer, dict):
                # If reference_answer is a dict, extract the explanation or compliant field
                response_b = reference_answer.get("explanation", reference_answer.get("compliant", ""))
            else:
                response_b = reference_answer

            # Nothing to compare: score the sample 0.0 without calling Bedrock
            if not response_a or not response_b:
                results.append({
                    "id": sample_id,
                    "aggregate_reward_score": 0.0,
                    "metrics_list": [{"name": "similarity_score", "value": 0.0, "type": "Metric"}]
                })
                continue

            # Get similarity score
            score = lambda_graded(response_a, response_b)

            results.append({
                "id": sample_id,
                "aggregate_reward_score": score,
                "metrics_list": [
                    {
                        "name": "similarity_score",
                        "value": score,
                        "type": "Metric"
                    }
                ]
            })

        return {"statusCode": 200, "body": json.dumps(results)}

    except Exception as e:
        print(f"Error: {e}")
        return {"statusCode": 500, "body": json.dumps({"error": str(e)})}
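The example parses the judge's reply with a bare float() call, so any extra text in the reply counts as a failed evaluation. A more lenient variant is sketched below: it takes the first number found in the reply and then applies the same clipping. The helper `parse_score` is hypothetical and not part of the template; whether leniency is desirable depends on how strictly the judge prompt should be enforced.

```python
import re


def parse_score(judge_output: str, default: float = 0.0) -> float:
    """Extract the first number from the judge's reply and clip it to [0.0, 1.0]."""
    match = re.search(r"-?\d+(?:\.\d+)?", judge_output)
    if match is None:
        return default  # nothing numeric to parse: fall back instead of raising
    return max(0.0, min(1.0, float(match.group())))
```

With this helper, a reply such as "Score: 0.85" still yields 0.85, while a reply with no number falls back to the default.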

Input format

The Lambda uses the same input format as other reward functions:

{
  "id": "sample-001",
  "messages": [
    {
      "role": "user",
      "content": "Do you have a dedicated security team?"
    },
    {
      "role": "assistant",
      "content": "As an AI developed by Amazon, I don't have a dedicated security team..."
    }
  ],
  "reference_answer": {
    "compliant": "No",
    "explanation": "As an AI developed by Company, I do not have a traditional security team..."
  },
  "my_custom_field": "custom_value"
}

Output format

{
  "id": "sample-001",
  "aggregate_reward_score": 0.85,
  "metrics_list": [
    {
      "name": "similarity_score",
      "value": 0.85,
      "type": "Metric"
    }
  ]
}

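Deployment considerations

You may also need to adjust the prompt template and inference parameters to match the capabilities and API format of the model you choose.

  • IAM permissions: the Lambda execution role must have bedrock:InvokeModel permission for the selected model

  • Timeout: set the Lambda timeout to at least 60 seconds to accommodate Bedrock API latency and retries

  • Region: deploy in a Region where the selected Bedrock model is available

  • Cost: monitor Bedrock API usage; each sample triggers one API call in every evaluation

  • Throughput: for large-scale evaluation, request a Bedrock quota increase to avoid throttling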
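A rough way to sanity-check the timeout setting is to bound the worst case for one sample under the retry loop of the example function. The helper name is hypothetical, and the 10-second per-call latency used below is an illustrative assumption, not a measured value.

```python
def worst_case_seconds(max_retries: int, per_call_seconds: float) -> float:
    """Upper bound for one sample when every attempt but the last is throttled."""
    # The retry loop sleeps 2**attempt seconds after each failed attempt except the last.
    backoff = sum(2 ** attempt for attempt in range(max_retries - 1))
    return max_retries * per_call_seconds + backoff


# Assumed latency of 10 s per Bedrock call with the default of 3 attempts.
budget = worst_case_seconds(max_retries=3, per_call_seconds=10.0)  # 33.0 seconds
```

Under these assumptions one sample fits inside a 60-second timeout. If a single event carries several samples, the handler processes them one after another, so the bound scales with the number of samples.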

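Increasing Bedrock throughput

If you encounter throttling during evaluation, increase the Bedrock model quota:

  • Go to the AWS Service Quotas console

  • Search for Bedrock and select the relevant Region

  • Find the quota for the selected model (for example, "Claude 3.5 Sonnet invocations per minute")

  • Choose "Request quota increase" and specify the throughput you need

  • Provide a justification for the increase (for example, "RFT evaluation workload")

The retry logic built into the Lambda handles occasional throttling, but sustained high-throughput evaluation requires a corresponding quota increase.

Required IAM policy: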

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "bedrock:InvokeModel"
      ],
      "Resource": "arn:aws:bedrock:*::foundation-model/*"
    }
  ]
}
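The policy above allows invoking any foundation model in any Region. If a narrower grant is preferred, the same document can be generated for a single model. The sketch below is one possible variant, not a requirement from the source: the helper name is hypothetical, and the Region and model ID are simply the values used in the example Lambda.

```python
import json

JUDGE_MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"  # value from the example Lambda
REGION = "us-east-1"  # Region used by the example's bedrock-runtime client


def bedrock_invoke_policy(region: str, model_id: str) -> dict:
    """Build an IAM policy that allows bedrock:InvokeModel on one foundation model."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": ["bedrock:InvokeModel"],
                "Resource": f"arn:aws:bedrock:{region}::foundation-model/{model_id}",
            }
        ],
    }


policy = bedrock_invoke_policy(REGION, JUDGE_MODEL_ID)
print(json.dumps(policy, indent=2))
```

If JUDGE_MODEL_ID changes later, the policy has to be regenerated for the new model, which is the trade-off against the wildcard resource.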