Skip to main content

synth_ai_core/data/
rewards.rs

//! Reward data structures.
//!
//! These mirror the Python `synth_ai.data.rewards` module and provide
//! pure data types for reward annotations and aggregates.
5
6use super::enums::{RewardSource, RewardType};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
11/// Episode-level reward summary.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct OutcomeRewardRecord {
14    /// Session ID this reward corresponds to.
15    pub session_id: String,
16    /// Aggregate reward value.
17    pub total_reward: f64,
18    /// Objective key (default "reward").
19    #[serde(default = "default_objective_key")]
20    pub objective_key: String,
21    /// Number of achievements.
22    #[serde(default)]
23    pub achievements_count: i32,
24    /// Total steps in the episode.
25    #[serde(default)]
26    pub total_steps: i32,
27    /// Optional metadata.
28    #[serde(default)]
29    pub metadata: HashMap<String, Value>,
30    /// Optional annotation.
31    #[serde(default)]
32    pub annotation: HashMap<String, Value>,
33    /// Creation timestamp (ISO string).
34    #[serde(default)]
35    pub created_at: Option<String>,
36}
37
/// Serde default provider for `objective_key` fields.
///
/// Referenced via `#[serde(default = "default_objective_key")]`, so a record
/// deserialized without an `objective_key` gets the string `"reward"`.
fn default_objective_key() -> String {
    String::from("reward")
}
41
42/// Event-level reward annotation.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct EventRewardRecord {
45    /// Event ID being rewarded.
46    pub event_id: String,
47    /// Session ID containing the event.
48    pub session_id: String,
49    /// Reward value.
50    pub reward_value: f64,
51    /// Objective key (default "reward").
52    #[serde(default = "default_objective_key")]
53    pub objective_key: String,
54    /// Optional reward type.
55    #[serde(default)]
56    pub reward_type: Option<RewardType>,
57    /// Optional rubric criterion key.
58    #[serde(default)]
59    pub key: Option<String>,
60    /// Optional turn number.
61    #[serde(default)]
62    pub turn_number: Option<i32>,
63    /// Optional reward source.
64    #[serde(default)]
65    pub source: Option<RewardSource>,
66    /// Optional annotation.
67    #[serde(default)]
68    pub annotation: HashMap<String, Value>,
69    /// Creation timestamp (ISO string).
70    #[serde(default)]
71    pub created_at: Option<String>,
72}
73
74/// Aggregated statistics for rewards.
75#[derive(Debug, Clone, Default, Serialize, Deserialize)]
76pub struct RewardAggregates {
77    pub mean: f64,
78    #[serde(default)]
79    pub median: f64,
80    #[serde(default)]
81    pub std: f64,
82    #[serde(default)]
83    pub n: i32,
84    #[serde(default)]
85    pub min_value: Option<f64>,
86    #[serde(default)]
87    pub max_value: Option<f64>,
88}
89
90/// Calibration example for verifier evaluation.
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct CalibrationExample {
93    /// Full session trace (V3/V4 format).
94    pub session_trace: Value,
95    /// Rewards per event.
96    pub event_rewards: Vec<f64>,
97    /// Overall outcome reward.
98    pub outcome_reward: f64,
99    /// Optional metadata.
100    #[serde(default)]
101    pub metadata: HashMap<String, Value>,
102}
103
104/// Gold-standard example for contrastive evaluation.
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct GoldExample {
107    /// Summary of the trace.
108    pub summary: String,
109    /// Gold score.
110    pub gold_score: f64,
111    /// Gold reasoning/explanation.
112    pub gold_reasoning: String,
113    /// Optional full session trace.
114    #[serde(default)]
115    pub session_trace: Option<Value>,
116    /// Optional metadata.
117    #[serde(default)]
118    pub metadata: HashMap<String, Value>,
119}