1use super::enums::{ObjectiveDirection, ObjectiveKey, RewardScope, RewardSource, RewardType};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ObjectiveSpec {
13 pub key: ObjectiveKey,
15 pub direction: ObjectiveDirection,
17 #[serde(default)]
19 pub units: Option<String>,
20 #[serde(default)]
22 pub description: Option<String>,
23 #[serde(default)]
25 pub target: Option<f64>,
26 #[serde(default)]
28 pub min_value: Option<f64>,
29 #[serde(default)]
31 pub max_value: Option<f64>,
32}
33
34impl ObjectiveSpec {
35 pub fn new(key: ObjectiveKey, direction: ObjectiveDirection) -> Self {
37 Self {
38 key,
39 direction,
40 units: None,
41 description: None,
42 target: None,
43 min_value: None,
44 max_value: None,
45 }
46 }
47
48 pub fn maximize_reward() -> Self {
50 Self::new(ObjectiveKey::Reward, ObjectiveDirection::Maximize)
51 }
52
53 pub fn minimize_cost() -> Self {
55 Self::new(ObjectiveKey::CostUsd, ObjectiveDirection::Minimize).with_units("usd")
56 }
57
58 pub fn minimize_latency() -> Self {
60 Self::new(ObjectiveKey::LatencyMs, ObjectiveDirection::Minimize).with_units("ms")
61 }
62
63 pub fn with_units(mut self, units: impl Into<String>) -> Self {
65 self.units = Some(units.into());
66 self
67 }
68
69 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
71 self.description = Some(desc.into());
72 self
73 }
74
75 pub fn satisfies_constraints(&self, value: f64) -> bool {
77 if let Some(min) = self.min_value {
78 if value < min {
79 return false;
80 }
81 }
82 if let Some(max) = self.max_value {
83 if value > max {
84 return false;
85 }
86 }
87 true
88 }
89}
90
/// A single reward value together with metadata describing its kind, scope,
/// origin and the objective it is recorded against.
///
/// Only `value` is required when deserializing; every other field carries
/// `#[serde(default)]` and falls back to its `Default` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewardObservation {
    /// The observed reward value.
    pub value: f64,
    /// Kind of reward signal (e.g. `RewardType::Sparse`).
    #[serde(default)]
    pub reward_type: RewardType,
    /// What the reward applies to (e.g. `RewardScope::Outcome` or
    /// `RewardScope::Event`).
    #[serde(default)]
    pub scope: RewardScope,
    /// Producer of the reward (e.g. `RewardSource::Verifier`).
    #[serde(default)]
    pub source: RewardSource,
    /// Objective this observation is associated with.
    #[serde(default)]
    pub objective_key: ObjectiveKey,
    /// Identifier of the event the reward is attached to, if any; set by
    /// `RewardObservation::event`.
    #[serde(default)]
    pub event_id: Option<String>,
    /// Optional turn index. NOTE(review): presumably the conversation or
    /// episode turn; confirm with producers.
    #[serde(default)]
    pub turn_number: Option<i32>,
    /// Free-form JSON metadata keyed by string.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
118
119impl RewardObservation {
120 pub fn new(value: f64) -> Self {
122 Self {
123 value,
124 reward_type: RewardType::default(),
125 scope: RewardScope::default(),
126 source: RewardSource::default(),
127 objective_key: ObjectiveKey::default(),
128 event_id: None,
129 turn_number: None,
130 metadata: HashMap::new(),
131 }
132 }
133
134 pub fn outcome(value: f64) -> Self {
136 Self::new(value).with_scope(RewardScope::Outcome)
137 }
138
139 pub fn event(value: f64, event_id: impl Into<String>) -> Self {
141 let mut obs = Self::new(value).with_scope(RewardScope::Event);
142 obs.event_id = Some(event_id.into());
143 obs
144 }
145
146 pub fn with_type(mut self, reward_type: RewardType) -> Self {
148 self.reward_type = reward_type;
149 self
150 }
151
152 pub fn with_scope(mut self, scope: RewardScope) -> Self {
154 self.scope = scope;
155 self
156 }
157
158 pub fn with_source(mut self, source: RewardSource) -> Self {
160 self.source = source;
161 self
162 }
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct OutcomeObjectiveAssignment {
168 pub objectives: HashMap<String, f64>,
170 #[serde(default)]
172 pub session_id: Option<String>,
173 #[serde(default)]
175 pub trace_id: Option<String>,
176 #[serde(default)]
178 pub metadata: HashMap<String, Value>,
179}
180
181impl OutcomeObjectiveAssignment {
182 pub fn new() -> Self {
184 Self {
185 objectives: HashMap::new(),
186 session_id: None,
187 trace_id: None,
188 metadata: HashMap::new(),
189 }
190 }
191
192 pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
194 self.objectives.insert(key.into(), value);
195 self
196 }
197
198 pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
200 self.session_id = Some(session_id.into());
201 self
202 }
203}
204
205impl Default for OutcomeObjectiveAssignment {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct EventObjectiveAssignment {
214 pub event_id: String,
216 pub objectives: HashMap<String, f64>,
218 #[serde(default)]
220 pub turn_number: Option<i32>,
221 #[serde(default)]
223 pub metadata: HashMap<String, Value>,
224}
225
226impl EventObjectiveAssignment {
227 pub fn new(event_id: impl Into<String>) -> Self {
229 Self {
230 event_id: event_id.into(),
231 objectives: HashMap::new(),
232 turn_number: None,
233 metadata: HashMap::new(),
234 }
235 }
236
237 pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
239 self.objectives.insert(key.into(), value);
240 self
241 }
242}
243
244#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct InstanceObjectiveAssignment {
247 pub instance_id: String,
249 pub objectives: HashMap<String, f64>,
251 #[serde(default)]
253 pub split: Option<String>,
254 #[serde(default)]
256 pub metadata: HashMap<String, Value>,
257}
258
259impl InstanceObjectiveAssignment {
260 pub fn new(instance_id: impl Into<String>) -> Self {
262 Self {
263 instance_id: instance_id.into(),
264 objectives: HashMap::new(),
265 split: None,
266 metadata: HashMap::new(),
267 }
268 }
269
270 pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
272 self.objectives.insert(key.into(), value);
273 self
274 }
275
276 pub fn with_split(mut self, split: impl Into<String>) -> Self {
278 self.split = Some(split.into());
279 self
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_objective_spec() {
        let spec = ObjectiveSpec::maximize_reward();
        assert_eq!(spec.key, ObjectiveKey::Reward);
        assert_eq!(spec.direction, ObjectiveDirection::Maximize);
    }

    #[test]
    fn test_objective_spec_presets() {
        let cost = ObjectiveSpec::minimize_cost();
        assert_eq!(cost.key, ObjectiveKey::CostUsd);
        assert_eq!(cost.direction, ObjectiveDirection::Minimize);
        assert_eq!(cost.units.as_deref(), Some("usd"));

        let latency = ObjectiveSpec::minimize_latency();
        assert_eq!(latency.key, ObjectiveKey::LatencyMs);
        assert_eq!(latency.direction, ObjectiveDirection::Minimize);
        assert_eq!(latency.units.as_deref(), Some("ms"));
    }

    #[test]
    fn test_objective_constraints() {
        let mut spec = ObjectiveSpec::minimize_latency();
        spec.max_value = Some(1000.0);

        assert!(spec.satisfies_constraints(500.0));
        // Bounds are inclusive.
        assert!(spec.satisfies_constraints(1000.0));
        assert!(!spec.satisfies_constraints(1500.0));

        spec.min_value = Some(10.0);
        assert!(spec.satisfies_constraints(10.0));
        assert!(!spec.satisfies_constraints(5.0));
    }

    #[test]
    fn test_unbounded_objective_accepts_any_finite_value() {
        let spec = ObjectiveSpec::maximize_reward();
        assert!(spec.satisfies_constraints(f64::MIN));
        assert!(spec.satisfies_constraints(0.0));
        assert!(spec.satisfies_constraints(f64::MAX));
    }

    #[test]
    fn test_reward_observation() {
        let obs = RewardObservation::outcome(0.95)
            .with_type(RewardType::Sparse)
            .with_source(RewardSource::Verifier);

        assert_eq!(obs.value, 0.95);
        assert_eq!(obs.scope, RewardScope::Outcome);
        assert_eq!(obs.source, RewardSource::Verifier);
    }

    #[test]
    fn test_event_observation() {
        let obs = RewardObservation::event(0.5, "evt-1");
        assert_eq!(obs.value, 0.5);
        assert_eq!(obs.scope, RewardScope::Event);
        assert_eq!(obs.event_id.as_deref(), Some("evt-1"));
        assert!(obs.turn_number.is_none());
    }

    #[test]
    fn test_outcome_assignment() {
        let assignment = OutcomeObjectiveAssignment::new()
            .with_objective("reward", 0.85)
            .with_objective("cost_usd", 0.002)
            .with_session("session-123");

        assert_eq!(assignment.objectives.get("reward"), Some(&0.85));
        assert_eq!(assignment.session_id, Some("session-123".to_string()));
    }

    #[test]
    fn test_outcome_assignment_default_is_empty() {
        let assignment = OutcomeObjectiveAssignment::default();
        assert!(assignment.objectives.is_empty());
        assert!(assignment.session_id.is_none());
        assert!(assignment.trace_id.is_none());
        assert!(assignment.metadata.is_empty());
    }

    #[test]
    fn test_event_assignment() {
        let assignment = EventObjectiveAssignment::new("evt-1")
            .with_objective("reward", 1.0)
            .with_objective("reward", 0.5);

        assert_eq!(assignment.event_id, "evt-1");
        // A repeated key overwrites the earlier value.
        assert_eq!(assignment.objectives.len(), 1);
        assert_eq!(assignment.objectives.get("reward"), Some(&0.5));
        assert!(assignment.turn_number.is_none());
    }

    #[test]
    fn test_instance_assignment() {
        let assignment = InstanceObjectiveAssignment::new("inst-1")
            .with_objective("reward", 1.0)
            .with_split("test");

        assert_eq!(assignment.instance_id, "inst-1");
        assert_eq!(assignment.objectives.get("reward"), Some(&1.0));
        assert_eq!(assignment.split.as_deref(), Some("test"));
    }

    #[test]
    fn test_serde() {
        let obs = RewardObservation::outcome(1.0);
        let json = serde_json::to_string(&obs).unwrap();
        let parsed: RewardObservation = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.value, 1.0);
        assert_eq!(parsed.scope, RewardScope::Outcome);
    }

    #[test]
    fn test_serde_defaults_for_missing_fields() {
        let obs: RewardObservation = serde_json::from_str(r#"{"value": 0.25}"#).unwrap();
        assert_eq!(obs.value, 0.25);
        assert_eq!(obs.scope, RewardScope::default());
        assert_eq!(obs.source, RewardSource::default());
        assert_eq!(obs.objective_key, ObjectiveKey::default());
        assert!(obs.event_id.is_none());
        assert!(obs.turn_number.is_none());
        assert!(obs.metadata.is_empty());

        let assignment: OutcomeObjectiveAssignment =
            serde_json::from_str(r#"{"objectives": {"reward": 1.0}}"#).unwrap();
        assert_eq!(assignment.objectives.get("reward"), Some(&1.0));
        assert!(assignment.session_id.is_none());
        assert!(assignment.trace_id.is_none());
        assert!(assignment.metadata.is_empty());
    }
}