Skip to main content

synth_ai_core/data/
objectives.rs

//! Objective specifications and reward observations.
//!
//! Types for defining optimization objectives and recording reward signals.
4
5use super::enums::{ObjectiveDirection, ObjectiveKey, RewardScope, RewardSource, RewardType};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9
10/// Specification for an optimization objective.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ObjectiveSpec {
13    /// Key identifying the objective metric.
14    pub key: ObjectiveKey,
15    /// Whether to maximize or minimize this objective.
16    pub direction: ObjectiveDirection,
17    /// Optional units (e.g., "ms", "usd", "tokens").
18    #[serde(default)]
19    pub units: Option<String>,
20    /// Human-readable description.
21    #[serde(default)]
22    pub description: Option<String>,
23    /// Target value (for constrained optimization).
24    #[serde(default)]
25    pub target: Option<f64>,
26    /// Minimum acceptable value.
27    #[serde(default)]
28    pub min_value: Option<f64>,
29    /// Maximum acceptable value.
30    #[serde(default)]
31    pub max_value: Option<f64>,
32}
33
34impl ObjectiveSpec {
35    /// Create a new objective spec.
36    pub fn new(key: ObjectiveKey, direction: ObjectiveDirection) -> Self {
37        Self {
38            key,
39            direction,
40            units: None,
41            description: None,
42            target: None,
43            min_value: None,
44            max_value: None,
45        }
46    }
47
48    /// Create a reward maximization objective.
49    pub fn maximize_reward() -> Self {
50        Self::new(ObjectiveKey::Reward, ObjectiveDirection::Maximize)
51    }
52
53    /// Create a cost minimization objective.
54    pub fn minimize_cost() -> Self {
55        Self::new(ObjectiveKey::CostUsd, ObjectiveDirection::Minimize)
56            .with_units("usd")
57    }
58
59    /// Create a latency minimization objective.
60    pub fn minimize_latency() -> Self {
61        Self::new(ObjectiveKey::LatencyMs, ObjectiveDirection::Minimize)
62            .with_units("ms")
63    }
64
65    /// Set units for this objective.
66    pub fn with_units(mut self, units: impl Into<String>) -> Self {
67        self.units = Some(units.into());
68        self
69    }
70
71    /// Set description for this objective.
72    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
73        self.description = Some(desc.into());
74        self
75    }
76
77    /// Check if a value satisfies the objective's constraints.
78    pub fn satisfies_constraints(&self, value: f64) -> bool {
79        if let Some(min) = self.min_value {
80            if value < min {
81                return false;
82            }
83        }
84        if let Some(max) = self.max_value {
85            if value > max {
86                return false;
87            }
88        }
89        true
90    }
91}
92
93/// A single reward observation.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct RewardObservation {
96    /// The reward value.
97    pub value: f64,
98    /// Type of reward signal.
99    #[serde(default)]
100    pub reward_type: RewardType,
101    /// Whether this is event-level or outcome-level.
102    #[serde(default)]
103    pub scope: RewardScope,
104    /// Source of the reward signal.
105    #[serde(default)]
106    pub source: RewardSource,
107    /// Which objective this reward corresponds to.
108    #[serde(default)]
109    pub objective_key: ObjectiveKey,
110    /// Optional event ID (for event-level rewards).
111    #[serde(default)]
112    pub event_id: Option<String>,
113    /// Optional turn number.
114    #[serde(default)]
115    pub turn_number: Option<i32>,
116    /// Additional metadata.
117    #[serde(default)]
118    pub metadata: HashMap<String, Value>,
119}
120
121impl RewardObservation {
122    /// Create a new reward observation.
123    pub fn new(value: f64) -> Self {
124        Self {
125            value,
126            reward_type: RewardType::default(),
127            scope: RewardScope::default(),
128            source: RewardSource::default(),
129            objective_key: ObjectiveKey::default(),
130            event_id: None,
131            turn_number: None,
132            metadata: HashMap::new(),
133        }
134    }
135
136    /// Create an outcome-level reward.
137    pub fn outcome(value: f64) -> Self {
138        Self::new(value).with_scope(RewardScope::Outcome)
139    }
140
141    /// Create an event-level reward.
142    pub fn event(value: f64, event_id: impl Into<String>) -> Self {
143        let mut obs = Self::new(value).with_scope(RewardScope::Event);
144        obs.event_id = Some(event_id.into());
145        obs
146    }
147
148    /// Set the reward type.
149    pub fn with_type(mut self, reward_type: RewardType) -> Self {
150        self.reward_type = reward_type;
151        self
152    }
153
154    /// Set the scope.
155    pub fn with_scope(mut self, scope: RewardScope) -> Self {
156        self.scope = scope;
157        self
158    }
159
160    /// Set the source.
161    pub fn with_source(mut self, source: RewardSource) -> Self {
162        self.source = source;
163        self
164    }
165}
166
167/// Assignment of objectives to an outcome (session-level).
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct OutcomeObjectiveAssignment {
170    /// Map of objective key to value.
171    pub objectives: HashMap<String, f64>,
172    /// Session ID this assignment belongs to.
173    #[serde(default)]
174    pub session_id: Option<String>,
175    /// Trace correlation ID.
176    #[serde(default)]
177    pub trace_id: Option<String>,
178    /// Additional metadata.
179    #[serde(default)]
180    pub metadata: HashMap<String, Value>,
181}
182
183impl OutcomeObjectiveAssignment {
184    /// Create a new outcome objective assignment.
185    pub fn new() -> Self {
186        Self {
187            objectives: HashMap::new(),
188            session_id: None,
189            trace_id: None,
190            metadata: HashMap::new(),
191        }
192    }
193
194    /// Add an objective value.
195    pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
196        self.objectives.insert(key.into(), value);
197        self
198    }
199
200    /// Set the session ID.
201    pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
202        self.session_id = Some(session_id.into());
203        self
204    }
205}
206
207impl Default for OutcomeObjectiveAssignment {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213/// Assignment of objectives to an event.
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct EventObjectiveAssignment {
216    /// Event ID this assignment belongs to.
217    pub event_id: String,
218    /// Map of objective key to value.
219    pub objectives: HashMap<String, f64>,
220    /// Turn number.
221    #[serde(default)]
222    pub turn_number: Option<i32>,
223    /// Additional metadata.
224    #[serde(default)]
225    pub metadata: HashMap<String, Value>,
226}
227
228impl EventObjectiveAssignment {
229    /// Create a new event objective assignment.
230    pub fn new(event_id: impl Into<String>) -> Self {
231        Self {
232            event_id: event_id.into(),
233            objectives: HashMap::new(),
234            turn_number: None,
235            metadata: HashMap::new(),
236        }
237    }
238
239    /// Add an objective value.
240    pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
241        self.objectives.insert(key.into(), value);
242        self
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_objective_spec() {
        let spec = ObjectiveSpec::maximize_reward();
        assert_eq!(spec.key, ObjectiveKey::Reward);
        assert_eq!(spec.direction, ObjectiveDirection::Maximize);
        assert_eq!(spec.units, None);
    }

    #[test]
    fn test_objective_presets_set_units() {
        let cost = ObjectiveSpec::minimize_cost();
        assert_eq!(cost.key, ObjectiveKey::CostUsd);
        assert_eq!(cost.direction, ObjectiveDirection::Minimize);
        assert_eq!(cost.units.as_deref(), Some("usd"));

        let latency = ObjectiveSpec::minimize_latency().with_description("p50 latency");
        assert_eq!(latency.key, ObjectiveKey::LatencyMs);
        assert_eq!(latency.units.as_deref(), Some("ms"));
        assert_eq!(latency.description.as_deref(), Some("p50 latency"));
    }

    #[test]
    fn test_objective_constraints() {
        let mut spec = ObjectiveSpec::minimize_latency();
        spec.max_value = Some(1000.0);

        assert!(spec.satisfies_constraints(500.0));
        assert!(!spec.satisfies_constraints(1500.0));
    }

    #[test]
    fn test_objective_constraints_are_inclusive() {
        let mut spec = ObjectiveSpec::maximize_reward();
        // Without bounds every finite value is accepted.
        assert!(spec.satisfies_constraints(-1e9));

        spec.min_value = Some(0.0);
        spec.max_value = Some(1.0);
        assert!(spec.satisfies_constraints(0.0));
        assert!(spec.satisfies_constraints(1.0));
        assert!(!spec.satisfies_constraints(-0.1));
        assert!(!spec.satisfies_constraints(1.1));
    }

    #[test]
    fn test_reward_observation() {
        let obs = RewardObservation::outcome(0.95)
            .with_type(RewardType::Sparse)
            .with_source(RewardSource::Verifier);

        assert_eq!(obs.value, 0.95);
        assert_eq!(obs.scope, RewardScope::Outcome);
        assert_eq!(obs.source, RewardSource::Verifier);
    }

    #[test]
    fn test_event_reward() {
        let obs = RewardObservation::event(0.25, "evt-1");
        assert_eq!(obs.value, 0.25);
        assert_eq!(obs.scope, RewardScope::Event);
        assert_eq!(obs.event_id.as_deref(), Some("evt-1"));
        assert_eq!(obs.turn_number, None);
        assert!(obs.metadata.is_empty());
    }

    #[test]
    fn test_outcome_assignment() {
        let assignment = OutcomeObjectiveAssignment::new()
            .with_objective("reward", 0.85)
            .with_objective("cost_usd", 0.002)
            .with_session("session-123");

        assert_eq!(assignment.objectives.get("reward"), Some(&0.85));
        assert_eq!(assignment.session_id, Some("session-123".to_string()));
    }

    #[test]
    fn test_outcome_assignment_default_is_empty() {
        let assignment = OutcomeObjectiveAssignment::default();
        assert!(assignment.objectives.is_empty());
        assert_eq!(assignment.session_id, None);
        assert_eq!(assignment.trace_id, None);
        assert!(assignment.metadata.is_empty());
    }

    #[test]
    fn test_event_assignment_overwrites_key() {
        let assignment = EventObjectiveAssignment::new("evt-2")
            .with_objective("reward", 1.0)
            .with_objective("reward", 0.5);

        assert_eq!(assignment.event_id, "evt-2");
        // Re-adding a key replaces the earlier value rather than accumulating.
        assert_eq!(assignment.objectives.len(), 1);
        assert_eq!(assignment.objectives.get("reward"), Some(&0.5));
    }

    #[test]
    fn test_serde() {
        let obs = RewardObservation::outcome(1.0);
        let json = serde_json::to_string(&obs).unwrap();
        let parsed: RewardObservation = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.value, 1.0);
    }

    #[test]
    fn test_serde_round_trip_preserves_fields() {
        let obs = RewardObservation::event(0.75, "evt-9").with_source(RewardSource::Verifier);
        let json = serde_json::to_string(&obs).unwrap();
        let parsed: RewardObservation = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.value, 0.75);
        assert_eq!(parsed.scope, RewardScope::Event);
        assert_eq!(parsed.source, RewardSource::Verifier);
        assert_eq!(parsed.event_id.as_deref(), Some("evt-9"));
    }

    #[test]
    fn test_deserialize_applies_defaults() {
        // Only `value` is required; every other field is `#[serde(default)]`.
        let parsed: RewardObservation = serde_json::from_str(r#"{"value": 0.5}"#).unwrap();

        assert_eq!(parsed.value, 0.5);
        assert_eq!(parsed.scope, RewardScope::default());
        assert_eq!(parsed.source, RewardSource::default());
        assert_eq!(parsed.objective_key, ObjectiveKey::default());
        assert_eq!(parsed.event_id, None);
        assert_eq!(parsed.turn_number, None);
        assert!(parsed.metadata.is_empty());
    }
}