// synth_ai_core/data/objectives.rs
//! Objective specifications and reward observations.
//!
//! Types for defining optimization objectives and recording reward signals.

5use super::enums::{ObjectiveDirection, ObjectiveKey, RewardScope, RewardSource, RewardType};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9
10/// Specification for an optimization objective.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ObjectiveSpec {
13    /// Key identifying the objective metric.
14    pub key: ObjectiveKey,
15    /// Whether to maximize or minimize this objective.
16    pub direction: ObjectiveDirection,
17    /// Optional units (e.g., "ms", "usd", "tokens").
18    #[serde(default)]
19    pub units: Option<String>,
20    /// Human-readable description.
21    #[serde(default)]
22    pub description: Option<String>,
23    /// Target value (for constrained optimization).
24    #[serde(default)]
25    pub target: Option<f64>,
26    /// Minimum acceptable value.
27    #[serde(default)]
28    pub min_value: Option<f64>,
29    /// Maximum acceptable value.
30    #[serde(default)]
31    pub max_value: Option<f64>,
32}
33
34impl ObjectiveSpec {
35    /// Create a new objective spec.
36    pub fn new(key: ObjectiveKey, direction: ObjectiveDirection) -> Self {
37        Self {
38            key,
39            direction,
40            units: None,
41            description: None,
42            target: None,
43            min_value: None,
44            max_value: None,
45        }
46    }
47
48    /// Create a reward maximization objective.
49    pub fn maximize_reward() -> Self {
50        Self::new(ObjectiveKey::Reward, ObjectiveDirection::Maximize)
51    }
52
53    /// Create a cost minimization objective.
54    pub fn minimize_cost() -> Self {
55        Self::new(ObjectiveKey::CostUsd, ObjectiveDirection::Minimize).with_units("usd")
56    }
57
58    /// Create a latency minimization objective.
59    pub fn minimize_latency() -> Self {
60        Self::new(ObjectiveKey::LatencyMs, ObjectiveDirection::Minimize).with_units("ms")
61    }
62
63    /// Set units for this objective.
64    pub fn with_units(mut self, units: impl Into<String>) -> Self {
65        self.units = Some(units.into());
66        self
67    }
68
69    /// Set description for this objective.
70    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
71        self.description = Some(desc.into());
72        self
73    }
74
75    /// Check if a value satisfies the objective's constraints.
76    pub fn satisfies_constraints(&self, value: f64) -> bool {
77        if let Some(min) = self.min_value {
78            if value < min {
79                return false;
80            }
81        }
82        if let Some(max) = self.max_value {
83            if value > max {
84                return false;
85            }
86        }
87        true
88    }
89}
90
91/// A single reward observation.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct RewardObservation {
94    /// The reward value.
95    pub value: f64,
96    /// Type of reward signal.
97    #[serde(default)]
98    pub reward_type: RewardType,
99    /// Whether this is event-level or outcome-level.
100    #[serde(default)]
101    pub scope: RewardScope,
102    /// Source of the reward signal.
103    #[serde(default)]
104    pub source: RewardSource,
105    /// Which objective this reward corresponds to.
106    #[serde(default)]
107    pub objective_key: ObjectiveKey,
108    /// Optional event ID (for event-level rewards).
109    #[serde(default)]
110    pub event_id: Option<String>,
111    /// Optional turn number.
112    #[serde(default)]
113    pub turn_number: Option<i32>,
114    /// Additional metadata.
115    #[serde(default)]
116    pub metadata: HashMap<String, Value>,
117}
118
119impl RewardObservation {
120    /// Create a new reward observation.
121    pub fn new(value: f64) -> Self {
122        Self {
123            value,
124            reward_type: RewardType::default(),
125            scope: RewardScope::default(),
126            source: RewardSource::default(),
127            objective_key: ObjectiveKey::default(),
128            event_id: None,
129            turn_number: None,
130            metadata: HashMap::new(),
131        }
132    }
133
134    /// Create an outcome-level reward.
135    pub fn outcome(value: f64) -> Self {
136        Self::new(value).with_scope(RewardScope::Outcome)
137    }
138
139    /// Create an event-level reward.
140    pub fn event(value: f64, event_id: impl Into<String>) -> Self {
141        let mut obs = Self::new(value).with_scope(RewardScope::Event);
142        obs.event_id = Some(event_id.into());
143        obs
144    }
145
146    /// Set the reward type.
147    pub fn with_type(mut self, reward_type: RewardType) -> Self {
148        self.reward_type = reward_type;
149        self
150    }
151
152    /// Set the scope.
153    pub fn with_scope(mut self, scope: RewardScope) -> Self {
154        self.scope = scope;
155        self
156    }
157
158    /// Set the source.
159    pub fn with_source(mut self, source: RewardSource) -> Self {
160        self.source = source;
161        self
162    }
163}
164
165/// Assignment of objectives to an outcome (session-level).
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct OutcomeObjectiveAssignment {
168    /// Map of objective key to value.
169    pub objectives: HashMap<String, f64>,
170    /// Session ID this assignment belongs to.
171    #[serde(default)]
172    pub session_id: Option<String>,
173    /// Trace correlation ID.
174    #[serde(default)]
175    pub trace_id: Option<String>,
176    /// Additional metadata.
177    #[serde(default)]
178    pub metadata: HashMap<String, Value>,
179}
180
181impl OutcomeObjectiveAssignment {
182    /// Create a new outcome objective assignment.
183    pub fn new() -> Self {
184        Self {
185            objectives: HashMap::new(),
186            session_id: None,
187            trace_id: None,
188            metadata: HashMap::new(),
189        }
190    }
191
192    /// Add an objective value.
193    pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
194        self.objectives.insert(key.into(), value);
195        self
196    }
197
198    /// Set the session ID.
199    pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
200        self.session_id = Some(session_id.into());
201        self
202    }
203}
204
205impl Default for OutcomeObjectiveAssignment {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
211/// Assignment of objectives to an event.
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct EventObjectiveAssignment {
214    /// Event ID this assignment belongs to.
215    pub event_id: String,
216    /// Map of objective key to value.
217    pub objectives: HashMap<String, f64>,
218    /// Turn number.
219    #[serde(default)]
220    pub turn_number: Option<i32>,
221    /// Additional metadata.
222    #[serde(default)]
223    pub metadata: HashMap<String, Value>,
224}
225
226impl EventObjectiveAssignment {
227    /// Create a new event objective assignment.
228    pub fn new(event_id: impl Into<String>) -> Self {
229        Self {
230            event_id: event_id.into(),
231            objectives: HashMap::new(),
232            turn_number: None,
233            metadata: HashMap::new(),
234        }
235    }
236
237    /// Add an objective value.
238    pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
239        self.objectives.insert(key.into(), value);
240        self
241    }
242}
243
244/// Assignment of objectives to a specific instance (dataset example).
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct InstanceObjectiveAssignment {
247    /// Instance ID or seed.
248    pub instance_id: String,
249    /// Map of objective key to value.
250    pub objectives: HashMap<String, f64>,
251    /// Optional dataset split.
252    #[serde(default)]
253    pub split: Option<String>,
254    /// Additional metadata.
255    #[serde(default)]
256    pub metadata: HashMap<String, Value>,
257}
258
259impl InstanceObjectiveAssignment {
260    /// Create a new instance objective assignment.
261    pub fn new(instance_id: impl Into<String>) -> Self {
262        Self {
263            instance_id: instance_id.into(),
264            objectives: HashMap::new(),
265            split: None,
266            metadata: HashMap::new(),
267        }
268    }
269
270    /// Add an objective value.
271    pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
272        self.objectives.insert(key.into(), value);
273        self
274    }
275
276    /// Set the dataset split.
277    pub fn with_split(mut self, split: impl Into<String>) -> Self {
278        self.split = Some(split.into());
279        self
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_objective_spec() {
        let spec = ObjectiveSpec::maximize_reward();
        assert_eq!(spec.key, ObjectiveKey::Reward);
        assert_eq!(spec.direction, ObjectiveDirection::Maximize);
        assert!(spec.units.is_none());
    }

    #[test]
    fn test_preset_objectives() {
        let cost = ObjectiveSpec::minimize_cost();
        assert_eq!(cost.key, ObjectiveKey::CostUsd);
        assert_eq!(cost.direction, ObjectiveDirection::Minimize);
        assert_eq!(cost.units.as_deref(), Some("usd"));

        let latency = ObjectiveSpec::minimize_latency();
        assert_eq!(latency.key, ObjectiveKey::LatencyMs);
        assert_eq!(latency.direction, ObjectiveDirection::Minimize);
        assert_eq!(latency.units.as_deref(), Some("ms"));
    }

    #[test]
    fn test_objective_constraints() {
        let mut spec = ObjectiveSpec::minimize_latency();
        spec.max_value = Some(1000.0);

        assert!(spec.satisfies_constraints(500.0));
        assert!(!spec.satisfies_constraints(1500.0));
    }

    #[test]
    fn test_objective_constraints_bounds_are_inclusive() {
        let mut spec = ObjectiveSpec::minimize_cost();
        spec.min_value = Some(0.0);
        spec.max_value = Some(1.0);

        assert!(spec.satisfies_constraints(0.0));
        assert!(spec.satisfies_constraints(1.0));
        assert!(!spec.satisfies_constraints(-0.01));
        assert!(!spec.satisfies_constraints(1.01));
    }

    #[test]
    fn test_unconstrained_spec_accepts_finite_extremes() {
        let spec = ObjectiveSpec::maximize_reward();
        assert!(spec.satisfies_constraints(f64::MIN));
        assert!(spec.satisfies_constraints(f64::MAX));
    }

    #[test]
    fn test_reward_observation() {
        let obs = RewardObservation::outcome(0.95)
            .with_type(RewardType::Sparse)
            .with_source(RewardSource::Verifier);

        assert_eq!(obs.value, 0.95);
        assert_eq!(obs.scope, RewardScope::Outcome);
        assert_eq!(obs.source, RewardSource::Verifier);
    }

    #[test]
    fn test_event_reward_observation() {
        let obs = RewardObservation::event(0.5, "evt-1");

        assert_eq!(obs.value, 0.5);
        assert_eq!(obs.scope, RewardScope::Event);
        assert_eq!(obs.event_id.as_deref(), Some("evt-1"));
        assert!(obs.turn_number.is_none());
        assert!(obs.metadata.is_empty());
    }

    #[test]
    fn test_outcome_assignment() {
        let assignment = OutcomeObjectiveAssignment::new()
            .with_objective("reward", 0.85)
            .with_objective("cost_usd", 0.002)
            .with_session("session-123");

        assert_eq!(assignment.objectives.get("reward"), Some(&0.85));
        assert_eq!(assignment.session_id, Some("session-123".to_string()));
    }

    #[test]
    fn test_outcome_assignment_default_is_empty() {
        let assignment = OutcomeObjectiveAssignment::default();

        assert!(assignment.objectives.is_empty());
        assert!(assignment.session_id.is_none());
        assert!(assignment.trace_id.is_none());
        assert!(assignment.metadata.is_empty());
    }

    #[test]
    fn test_with_objective_overwrites_existing_key() {
        let assignment = OutcomeObjectiveAssignment::new()
            .with_objective("reward", 0.1)
            .with_objective("reward", 0.9);

        assert_eq!(assignment.objectives.len(), 1);
        assert_eq!(assignment.objectives.get("reward"), Some(&0.9));
    }

    #[test]
    fn test_event_and_instance_assignments() {
        let event = EventObjectiveAssignment::new("evt-7").with_objective("reward", 1.0);
        assert_eq!(event.event_id, "evt-7");
        assert_eq!(event.objectives.get("reward"), Some(&1.0));
        assert!(event.turn_number.is_none());

        let instance = InstanceObjectiveAssignment::new("seed-42")
            .with_objective("reward", 0.25)
            .with_split("train");
        assert_eq!(instance.instance_id, "seed-42");
        assert_eq!(instance.objectives.get("reward"), Some(&0.25));
        assert_eq!(instance.split.as_deref(), Some("train"));
    }

    #[test]
    fn test_serde() {
        let obs = RewardObservation::outcome(1.0);
        let json = serde_json::to_string(&obs).unwrap();
        let parsed: RewardObservation = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.value, 1.0);
        assert_eq!(parsed.scope, RewardScope::Outcome);
    }

    #[test]
    fn test_serde_missing_fields_use_defaults() {
        let parsed: RewardObservation = serde_json::from_str(r#"{"value": 0.5}"#).unwrap();

        assert_eq!(parsed.value, 0.5);
        assert_eq!(parsed.scope, RewardScope::default());
        assert_eq!(parsed.source, RewardSource::default());
        assert_eq!(parsed.objective_key, ObjectiveKey::default());
        assert!(parsed.event_id.is_none());
        assert!(parsed.turn_number.is_none());
        assert!(parsed.metadata.is_empty());
    }
}
332}