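Implementing reward functions
Overview
A reward function (also called a scorer or grader) is the core component that evaluates model responses and provides the feedback signal for training. It must be implemented as a Lambda function that accepts model responses and returns reward scores.
Interface contract
Your reward function must receive and return data in the following format:
Sample training input

    {
      "messages": [
        {
          "role": "user",
          "content": "Do you have a dedicated security team?"
        }
      ],
      "reference_answer": {
        "compliant": "No",
        "explanation": "As an AI developed by Company, I do not have a traditional security team..."
      }
    }
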
Sample payload for the reward Lambda
Before sending data to your Lambda function, the container automatically transforms it as follows:
- Generates a model response for each prompt
- Appends the assistant turn (the generated response) to the messages array
- Adds a unique id field for tracking
Your Lambda function receives the data in this transformed format:

    {
      "id": "123",
      "messages": [
        {
          "role": "user",
          "content": "Do you have a dedicated security team?"
        },
        {
          "role": "assistant",
          "content": "As an AI developed by Amazon, I do not have a dedicated security team..."
        }
      ],
      "reference_answer": {
        "compliant": "No",
        "explanation": "As an AI developed by Company, I do not have a traditional security team..."
      }
    }

The reference_answer section is identical to the one in your training dataset sample.
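The three transformation steps above can be sketched in a few lines of Python. This is an illustration only, not the container's actual code: the function name build_reward_payload is hypothetical, and the use of a UUID for the id field is an assumption (the sample payload simply shows "123").

```python
import copy
import uuid


def build_reward_payload(training_sample: dict, generated_response: str) -> dict:
    """Mimic the documented transformation (hypothetical helper)."""
    # Work on a copy so the original training sample is left untouched.
    payload = copy.deepcopy(training_sample)
    # Append the generated response as the assistant turn.
    payload.setdefault("messages", []).append(
        {"role": "assistant", "content": generated_response}
    )
    # Add a unique id for tracking (UUID format is an assumption).
    payload["id"] = str(uuid.uuid4())
    return payload


sample = {
    "messages": [{"role": "user", "content": "Do you have a dedicated security team?"}],
    "reference_answer": {"compliant": "No", "explanation": "..."},
}
payload = build_reward_payload(sample, "As an AI, I do not have a security team...")
```

Fields such as reference_answer pass through unchanged, which is why a grader can read them from the root of each sample.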
Reward Lambda contract

    def lambda_handler(event, context):
        return lambda_grader(event)


    def lambda_grader(samples: list[dict]) -> list[dict]:
        """
        Args:
            samples: List of dictionaries in OpenAI format

            Example input:
            {
                "id": "123",
                "messages": [
                    {
                        "role": "user",
                        "content": "Do you have a dedicated security team?"
                    },
                    {
                        "role": "assistant",
                        "content": "As an AI developed by Company, I do not have a dedicated security team..."
                    }
                ],
                # This section will be same as your training dataset
                "reference_answer": {
                    "compliant": "No",
                    "explanation": "As an AI developed by Company, I do not have a traditional security team..."
                }
            }

        Returns:
            List of dictionaries with reward scores:
            {
                "id": str,                        # Same id as input sample
                "aggregate_reward_score": float,  # Overall score for the sample
                "metrics_list": [                 # OPTIONAL: Component scores
                    {
                        "name": str,              # Name of the component score
                        "value": float,           # Value of the component score
                        "type": str               # "Reward" or "Metric"
                    }
                ]
            }
        """

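To make the contract concrete, here is a minimal sketch of a grader that fills in the body. The scoring rule (word overlap between the assistant response and the reference explanation) is a placeholder chosen only for illustration, and the helper names _tokens and _last_assistant_content are hypothetical; a real grader would encode your own task criteria.

```python
import re


def _tokens(text: str) -> set[str]:
    """Lowercased word set used by the placeholder overlap score."""
    return set(re.findall(r"[a-z0-9']+", text.lower()))


def _last_assistant_content(messages: list[dict]) -> str:
    """Return the content of the most recent assistant turn, if any."""
    for message in reversed(messages):
        if message.get("role") == "assistant":
            return message.get("content", "")
    return ""


def lambda_grader(samples: list[dict]) -> list[dict]:
    results = []
    for sample in samples:
        response = _last_assistant_content(sample.get("messages", []))
        reference = sample.get("reference_answer", "")
        expected = reference.get("explanation", "") if isinstance(reference, dict) else str(reference)
        a, b = _tokens(response), _tokens(expected)
        # Jaccard overlap in [0.0, 1.0]; 0.0 when both sides are empty.
        overlap = len(a & b) / len(a | b) if (a | b) else 0.0
        results.append({
            "id": sample.get("id", "unknown"),
            "aggregate_reward_score": overlap,
            "metrics_list": [{"name": "word_overlap", "value": overlap, "type": "Metric"}],
        })
    return results


def lambda_handler(event, context):
    # Accept either a single sample or a list of samples.
    samples = event if isinstance(event, list) else [event]
    return lambda_grader(samples)
```

Each output record echoes the input id, so the caller can match scores back to samples regardless of order.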
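Input and output fields
Input fields
| Field | Description | Additional notes |
|---|---|---|
| id | Unique identifier for the sample | Echoed back in the output. String format |
| messages | Ordered chat history in OpenAI format | Array of message objects |
| messages[].role | Sender of the message | Common values: user, assistant, system |
| messages[].content | Text content of the message | Plain string |
| **metadata | Free-form information to aid grading | Object; optional fields passed through from the training data (for example, reference_answer) |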
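Output fields
| Field | Description | Additional notes |
|---|---|---|
| id | Same identifier as the input sample | Must match the input |
| aggregate_reward_score | Overall score for the sample | Float (for example, 0.0 to 1.0 or a task-defined range) |
| metrics_list | Component scores that make up the aggregate | Optional; array of metric objects with name, value, and type ("Reward" or "Metric") |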
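A structural check on each output record can catch contract violations before training starts. The sketch below is an assumption about what is worth checking, derived only from the field tables and the contract docstring; validate_result is a hypothetical helper, not part of any AWS SDK.

```python
import math

ALLOWED_METRIC_TYPES = {"Reward", "Metric"}


def _is_finite_number(value) -> bool:
    """True for int/float values that are not NaN or infinite (bools excluded)."""
    return (
        not isinstance(value, bool)
        and isinstance(value, (int, float))
        and math.isfinite(value)
    )


def validate_result(result: dict, expected_id: str) -> list[str]:
    """Return a list of human-readable problems; empty means the record looks valid."""
    problems = []
    if result.get("id") != expected_id:
        problems.append(f"id {result.get('id')!r} does not match input id {expected_id!r}")
    if not _is_finite_number(result.get("aggregate_reward_score")):
        problems.append("aggregate_reward_score must be a finite number")
    # metrics_list is optional, so a missing key is not an error.
    for index, metric in enumerate(result.get("metrics_list") or []):
        if not isinstance(metric.get("name"), str):
            problems.append(f"metrics_list[{index}].name must be a string")
        if not _is_finite_number(metric.get("value")):
            problems.append(f"metrics_list[{index}].value must be a finite number")
        if metric.get("type") not in ALLOWED_METRIC_TYPES:
            problems.append(f"metrics_list[{index}].type must be 'Reward' or 'Metric'")
    return problems
```

Running this over the grader's output for a handful of samples is a cheap first gate; it says nothing about whether the scores themselves are sensible.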
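Technical constraints
- Timeout: each Lambda invocation can run for at most 15 minutes
- Concurrency: must be able to handle rollout_worker_replicas * 64 concurrent requests
- Reliability: must implement proper error handling and consistently return valid scores
- Performance: optimize for fast execution (seconds, not minutes) so that training runs efficiently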
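Two of these constraints lend themselves to a short sketch: the concurrency figure is simple arithmetic, and the reliability requirement can be met by wrapping the scoring logic so that it always yields a number. The names required_concurrency and grade_safely are hypothetical, and using 0.0 as the fallback score is an assumption you may want to change for your task.

```python
import math
from typing import Callable

REQUESTS_PER_REPLICA = 64  # multiplier stated in the concurrency constraint


def required_concurrency(rollout_worker_replicas: int) -> int:
    """Concurrent requests the Lambda must be able to absorb."""
    return rollout_worker_replicas * REQUESTS_PER_REPLICA


def grade_safely(sample: dict, score_fn: Callable[[dict], float], fallback: float = 0.0) -> dict:
    """Run score_fn and always return a valid record, even if it raises."""
    try:
        score = float(score_fn(sample))
        if not math.isfinite(score):
            score = fallback  # NaN or infinity is not a usable reward
    except Exception as exc:  # broad on purpose: the grader must not crash
        print(f"Scoring failed for sample {sample.get('id')}: {exc}")
        score = fallback
    return {"id": sample.get("id", "unknown"), "aggregate_reward_score": score}


print(required_concurrency(4))  # 4 replicas * 64 = 256
```

With four rollout worker replicas, for example, the function should be able to serve 256 requests at once, which is worth comparing against the account's Lambda concurrency settings.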
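Best practices
- Minimize external API calls
- Use efficient algorithms and data structures
- Implement retry logic for transient failures
- Cache computations that can be reused
- Test thoroughly before training to make sure the function runs without errors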
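The retry and caching recommendations can be sketched with the standard library alone. TransientError, with_retries, and normalize_reference are hypothetical names introduced for this illustration; in practice you would catch whichever exception your external dependency raises for temporary failures.

```python
import functools
import time


class TransientError(Exception):
    """Placeholder for a temporary failure (for example, a throttled API call)."""


def with_retries(fn, max_attempts: int = 3, base_delay: float = 1.0, sleep=time.sleep):
    """Call fn(), retrying on TransientError with exponential backoff."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except TransientError:
            if attempt == max_attempts - 1:
                raise  # out of attempts: let the caller decide on a fallback
            sleep(base_delay * 2 ** attempt)  # 1 s, 2 s, 4 s, ...


@functools.lru_cache(maxsize=1024)
def normalize_reference(text: str) -> tuple[str, ...]:
    """Cache per-reference preprocessing, since references repeat across rollouts."""
    return tuple(text.lower().split())
```

The sleep parameter exists so that tests can replace the real delay. Caching only pays off within a warm Lambda execution environment, because module-level state is not shared across concurrent instances.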
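Using custom reward functions
Implement a custom reward function when you have task-specific evaluation criteria:
- Define evaluation criteria: determine what makes a good response for your task
- Implement the Lambda function: create a Lambda function that follows the interface contract
- Test locally: verify that the function returns correct scores for sample inputs
- Deploy to AWS: deploy the Lambda and note its ARN
- Configure the recipe: add the Lambda ARN to the recipe's reward_lambda_arn field
- Test with a small dataset: run RFT with a small amount of data to verify the integration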
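For the local testing step, one approach is to pair a few hand-written samples with the scores you expect and compare them against the handler's output. The sketch below assumes the handler returns a list of records as described in the contract; check_expected_scores and exact_match_handler are hypothetical names, and the toy handler stands in for your real grader.

```python
def check_expected_scores(handler, cases, tolerance: float = 0.05) -> list[str]:
    """Compare handler scores with hand-labelled expectations; return mismatches."""
    samples = [sample for sample, _ in cases]
    results = handler(samples, None)  # context is unused in a local run
    scores_by_id = {r["id"]: r["aggregate_reward_score"] for r in results}
    failures = []
    for sample, expected in cases:
        actual = scores_by_id.get(sample["id"])
        if actual is None or abs(actual - expected) > tolerance:
            failures.append(f"{sample['id']}: expected {expected}, got {actual}")
    return failures


def exact_match_handler(event, context):
    """Toy grader used only to demonstrate the harness."""
    return [
        {
            "id": s["id"],
            "aggregate_reward_score": 1.0
            if s["messages"][-1]["content"] == s["reference_answer"]
            else 0.0,
        }
        for s in event
    ]


cases = [
    ({"id": "t1", "messages": [{"role": "assistant", "content": "No"}], "reference_answer": "No"}, 1.0),
    ({"id": "t2", "messages": [{"role": "assistant", "content": "Yes"}], "reference_answer": "No"}, 0.0),
]
failures = check_expected_scores(exact_match_handler, cases)
```

An empty failures list means every sample scored within the tolerance; a missing id is reported as a failure as well, which catches graders that drop or rename ids.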
IAM permissions
Required permissions
The SageMaker execution role must have permission to invoke your Lambda function. Add this policy to the SageMaker execution role:

    {
      "Version": "2012-10-17",
      "Statement": [
        {
          "Effect": "Allow",
          "Action": [
            "lambda:InvokeFunction"
          ],
          "Resource": "arn:aws:lambda:region:account-id:function:function-name"
        }
      ]
    }

Lambda execution role
The Lambda function's execution role needs basic Lambda execution permissions:

    {
      "Version": "2012-10-17",
      "Statement": [
        {
          "Effect": "Allow",
          "Action": [
            "logs:CreateLogGroup",
            "logs:CreateLogStream",
            "logs:PutLogEvents"
          ],
          "Resource": "arn:aws:logs:*:*:*"
        }
      ]
    }

Additional permissions: if your Lambda function accesses other AWS services (for example, S3 for reference data or DynamoDB for logging), add those permissions to the Lambda execution role.
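Example: LLM-as-a-judge reward function
This example shows how to use an Amazon Bedrock model as a judge that evaluates model responses by comparing them against reference answers. The Lambda template gives you a framework for calling Amazon Bedrock to run the inference requests that perform the judging. The function follows the same input and output contract as other reward functions.
Implementation
The Lambda function uses a two-stage evaluation flow: the lambda_handler function extracts the model response and the reference answer from each input sample, and the lambda_graded function then calls Amazon Bedrock to score their semantic similarity. The implementation includes error handling, automatic retries for transient failures, and support for flexible reference answer formats (plain strings and structured dictionaries).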
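Implementation details:
- Retry logic: uses exponential backoff on throttling exceptions to cope with Bedrock API rate limits. The wait is 2 ** attempt seconds, so the default max_retries of 3 waits 1 s and then 2 s; a 4 s wait occurs only if max_retries is raised
- Error handling: returns a score of 0.0 when an evaluation fails instead of raising an exception
- Deterministic scoring: uses temperature=0.0 so that repeated evaluations give consistent results
- Flexible reference format: automatically handles reference answers in both string and dictionary formats
- Score clamping: ensures that all scores fall within the valid [0.0, 1.0] range
- Model agnostic: change JUDGE_MODEL_ID to use any Amazon Bedrock model (Nova, Llama, Mistral, and so on)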
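Score parsing and clamping deserve a closer look, because a judge model does not always return a bare number. The sketch below is a more tolerant alternative to calling float() directly on the model output; parse_judge_score is a hypothetical helper and is not part of the template in this section.

```python
import re

# First signed decimal number in the text, e.g. "0.85", "1", ".5", "-0.3".
_NUMBER = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)")


def parse_judge_score(text: str, default: float = 0.0) -> float:
    """Extract the first number from judge output and clamp it to [0.0, 1.0]."""
    match = _NUMBER.search(text)
    if match is None:
        return default  # nothing numeric: treat as a failed evaluation
    return max(0.0, min(1.0, float(match.group())))


print(parse_judge_score("0.85"))         # 0.85
print(parse_judge_score("Score: 0.7."))  # 0.7
print(parse_judge_score("1.5"))          # 1.0 (clamped)
print(parse_judge_score("n/a"))          # 0.0 (default)
```

The trade-off is that a lenient parser can pick up an unrelated number from a verbose answer, so keeping the strict system prompt and a small maxTokens value remains useful.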

    """
    LLM Judge Lambda POC - Working implementation using Amazon Bedrock
    """

    import json
    import time

    import boto3

    bedrock_runtime = boto3.client('bedrock-runtime', region_name='us-east-1')
    JUDGE_MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"
    SYSTEM_PROMPT = "You must output ONLY a number between 0.0 and 1.0. No explanations, no text, just the number."

    JUDGE_PROMPT_TEMPLATE = """Compare the following two responses and rate how similar they are on a scale of 0.0 to 1.0, where:
    - 1.0 means the responses are semantically equivalent (same meaning, even if worded differently)
    - 0.5 means the responses are partially similar
    - 0.0 means the responses are completely different or contradictory

    Response A: {response_a}

    Response B: {response_b}

    Output ONLY a number between 0.0 and 1.0. No explanations."""


    def lambda_graded(response_a: str, response_b: str, max_retries: int = 3) -> float:
        """Call Bedrock to compare responses and return similarity score."""
        prompt = JUDGE_PROMPT_TEMPLATE.format(response_a=response_a, response_b=response_b)

        for attempt in range(max_retries):
            try:
                response = bedrock_runtime.converse(
                    modelId=JUDGE_MODEL_ID,
                    messages=[{"role": "user", "content": [{"text": prompt}]}],
                    system=[{"text": SYSTEM_PROMPT}],
                    inferenceConfig={"temperature": 0.0, "maxTokens": 10}
                )
                print(f"Bedrock call successful: {response}")
                output = response['output']['message']['content'][0]['text'].strip()
                score = float(output)
                print(f"Score parsed: {score}")
                # Clamp to the valid [0.0, 1.0] range
                return max(0.0, min(1.0, score))
            except Exception as e:
                # Back off and retry on throttling; give up on anything else
                if "ThrottlingException" in str(e) and attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                else:
                    print(f"Bedrock call failed: {e}")
                    # Failed evaluations score 0.0 so the output stays a valid float
                    return 0.0
        return 0.0


    def lambda_handler(event, context):
        """AWS Lambda handler - processes samples from RFTEvalInvoker."""
        try:
            samples = event if isinstance(event, list) else [event]
            results = []

            for sample in samples:
                sample_id = sample.get("id", "unknown")
                messages = sample.get("messages", [])

                # Extract assistant response (response A)
                response_a = ""
                for msg in messages:
                    if msg.get("role") in ["assistant", "nova_assistant"]:
                        response_a = msg.get("content", "")
                        break

                # Extract reference answer from root level (no longer in metadata)
                reference_answer = sample.get("reference_answer", "")

                # Handle both string and dict reference_answer formats
                if isinstance(reference_answer, dict):
                    # If reference_answer is a dict, extract the explanation or compliant field
                    response_b = reference_answer.get("explanation", reference_answer.get("compliant", ""))
                else:
                    response_b = reference_answer

                # Nothing to compare: score 0.0 without calling Bedrock
                if not response_a or not response_b:
                    results.append({
                        "id": sample_id,
                        "aggregate_reward_score": 0.0,
                        "metrics_list": [{"name": "similarity_score", "value": 0.0, "type": "Metric"}]
                    })
                    continue

                # Get similarity score
                score = lambda_graded(response_a, response_b)

                results.append({
                    "id": sample_id,
                    "aggregate_reward_score": score,
                    "metrics_list": [
                        {
                            "name": "similarity_score",
                            "value": score,
                            "type": "Metric"
                        }
                    ]
                })

            return {"statusCode": 200, "body": json.dumps(results)}

        except Exception as e:
            print(f"Error: {e}")
            return {"statusCode": 500, "body": json.dumps({"error": str(e)})}

Input format
The Lambda uses the same input format as other reward functions:

    {
      "id": "sample-001",
      "messages": [
        {
          "role": "user",
          "content": "Do you have a dedicated security team?"
        },
        {
          "role": "assistant",
          "content": "As an AI developed by Amazon, I don't have a dedicated security team..."
        }
      ],
      "reference_answer": {
        "compliant": "No",
        "explanation": "As an AI developed by Company, I do not have a traditional security team..."
      },
      "my_custom_field": "custom_value"
    }

Output format

    {
      "id": "sample-001",
      "aggregate_reward_score": 0.85,
      "metrics_list": [
        {
          "name": "similarity_score",
          "value": 0.85,
          "type": "Metric"
        }
      ]
    }

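Deployment considerations
You may also need to adjust the prompt template and inference parameters to match the capabilities and API format of the model you choose.
- IAM permissions: the Lambda execution role must have bedrock:InvokeModel permission for the model you choose
- Timeout: set the Lambda timeout to at least 60 seconds to allow for Bedrock API latency and retries
- Region: deploy in a Region where the Bedrock model you choose is available
- Cost: monitor Bedrock API usage, because every evaluation makes one API call per sample
- Throughput: for large-scale evaluation, request a Bedrock quota increase to avoid throttling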
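Increasing Bedrock throughput
If you run into throttling during evaluation, increase your Bedrock model quota:
- Go to the AWS Service Quotas console
- Search for Bedrock and select your Region
- Find the quota for the model you chose (for example, "Invocations per minute for Claude 3.5 Sonnet")
- Choose "Request quota increase" and specify the throughput you need
- Provide a justification for the increase (for example, "RFT evaluation workload")
The retry logic built into the Lambda handles occasional throttling, but sustained high-throughput evaluation requires a corresponding quota increase.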
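The console steps above can also be scripted. The following is a hedged sketch that uses the Service Quotas API through a boto3 client passed in by the caller; it assumes the service code for Amazon Bedrock is "bedrock" and that quota names contain the model name, so verify both in your account first. The helper names find_bedrock_quotas and request_increase are hypothetical.

```python
def find_bedrock_quotas(client, name_fragment: str) -> list[dict]:
    """List Bedrock quotas whose name contains name_fragment (case-insensitive)."""
    matches = []
    kwargs = {"ServiceCode": "bedrock"}  # assumed service code
    while True:
        page = client.list_service_quotas(**kwargs)
        for quota in page.get("Quotas", []):
            if name_fragment.lower() in quota["QuotaName"].lower():
                matches.append(quota)
        token = page.get("NextToken")
        if not token:
            return matches
        kwargs["NextToken"] = token  # follow pagination


def request_increase(client, quota_code: str, desired_value: float) -> dict:
    """Submit a quota increase request for a single Bedrock quota."""
    return client.request_service_quota_increase(
        ServiceCode="bedrock",
        QuotaCode=quota_code,
        DesiredValue=float(desired_value),
    )
```

A typical call would create the client with boto3.client("service-quotas", region_name="us-east-1"), look up the quota for your judge model with find_bedrock_quotas, and pass its QuotaCode to request_increase. Requests made this way are still reviewed like console requests, and the API call does not carry the free-text justification from the last console step.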
Required IAM policy for the Lambda execution role to invoke Bedrock models:

    {
      "Version": "2012-10-17",
      "Statement": [
        {
          "Effect": "Allow",
          "Action": [
            "bedrock:InvokeModel"
          ],
          "Resource": "arn:aws:bedrock:*::foundation-model/*"
        }
      ]
    }
