1use super::enums::{ObjectiveDirection, ObjectiveKey, RewardScope, RewardSource, RewardType};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ObjectiveSpec {
13 pub key: ObjectiveKey,
15 pub direction: ObjectiveDirection,
17 #[serde(default)]
19 pub units: Option<String>,
20 #[serde(default)]
22 pub description: Option<String>,
23 #[serde(default)]
25 pub target: Option<f64>,
26 #[serde(default)]
28 pub min_value: Option<f64>,
29 #[serde(default)]
31 pub max_value: Option<f64>,
32}
33
34impl ObjectiveSpec {
35 pub fn new(key: ObjectiveKey, direction: ObjectiveDirection) -> Self {
37 Self {
38 key,
39 direction,
40 units: None,
41 description: None,
42 target: None,
43 min_value: None,
44 max_value: None,
45 }
46 }
47
48 pub fn maximize_reward() -> Self {
50 Self::new(ObjectiveKey::Reward, ObjectiveDirection::Maximize)
51 }
52
53 pub fn minimize_cost() -> Self {
55 Self::new(ObjectiveKey::CostUsd, ObjectiveDirection::Minimize)
56 .with_units("usd")
57 }
58
59 pub fn minimize_latency() -> Self {
61 Self::new(ObjectiveKey::LatencyMs, ObjectiveDirection::Minimize)
62 .with_units("ms")
63 }
64
65 pub fn with_units(mut self, units: impl Into<String>) -> Self {
67 self.units = Some(units.into());
68 self
69 }
70
71 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
73 self.description = Some(desc.into());
74 self
75 }
76
77 pub fn satisfies_constraints(&self, value: f64) -> bool {
79 if let Some(min) = self.min_value {
80 if value < min {
81 return false;
82 }
83 }
84 if let Some(max) = self.max_value {
85 if value > max {
86 return false;
87 }
88 }
89 true
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct RewardObservation {
96 pub value: f64,
98 #[serde(default)]
100 pub reward_type: RewardType,
101 #[serde(default)]
103 pub scope: RewardScope,
104 #[serde(default)]
106 pub source: RewardSource,
107 #[serde(default)]
109 pub objective_key: ObjectiveKey,
110 #[serde(default)]
112 pub event_id: Option<String>,
113 #[serde(default)]
115 pub turn_number: Option<i32>,
116 #[serde(default)]
118 pub metadata: HashMap<String, Value>,
119}
120
121impl RewardObservation {
122 pub fn new(value: f64) -> Self {
124 Self {
125 value,
126 reward_type: RewardType::default(),
127 scope: RewardScope::default(),
128 source: RewardSource::default(),
129 objective_key: ObjectiveKey::default(),
130 event_id: None,
131 turn_number: None,
132 metadata: HashMap::new(),
133 }
134 }
135
136 pub fn outcome(value: f64) -> Self {
138 Self::new(value).with_scope(RewardScope::Outcome)
139 }
140
141 pub fn event(value: f64, event_id: impl Into<String>) -> Self {
143 let mut obs = Self::new(value).with_scope(RewardScope::Event);
144 obs.event_id = Some(event_id.into());
145 obs
146 }
147
148 pub fn with_type(mut self, reward_type: RewardType) -> Self {
150 self.reward_type = reward_type;
151 self
152 }
153
154 pub fn with_scope(mut self, scope: RewardScope) -> Self {
156 self.scope = scope;
157 self
158 }
159
160 pub fn with_source(mut self, source: RewardSource) -> Self {
162 self.source = source;
163 self
164 }
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct OutcomeObjectiveAssignment {
170 pub objectives: HashMap<String, f64>,
172 #[serde(default)]
174 pub session_id: Option<String>,
175 #[serde(default)]
177 pub trace_id: Option<String>,
178 #[serde(default)]
180 pub metadata: HashMap<String, Value>,
181}
182
183impl OutcomeObjectiveAssignment {
184 pub fn new() -> Self {
186 Self {
187 objectives: HashMap::new(),
188 session_id: None,
189 trace_id: None,
190 metadata: HashMap::new(),
191 }
192 }
193
194 pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
196 self.objectives.insert(key.into(), value);
197 self
198 }
199
200 pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
202 self.session_id = Some(session_id.into());
203 self
204 }
205}
206
207impl Default for OutcomeObjectiveAssignment {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct EventObjectiveAssignment {
216 pub event_id: String,
218 pub objectives: HashMap<String, f64>,
220 #[serde(default)]
222 pub turn_number: Option<i32>,
223 #[serde(default)]
225 pub metadata: HashMap<String, Value>,
226}
227
228impl EventObjectiveAssignment {
229 pub fn new(event_id: impl Into<String>) -> Self {
231 Self {
232 event_id: event_id.into(),
233 objectives: HashMap::new(),
234 turn_number: None,
235 metadata: HashMap::new(),
236 }
237 }
238
239 pub fn with_objective(mut self, key: impl Into<String>, value: f64) -> Self {
241 self.objectives.insert(key.into(), value);
242 self
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    // `maximize_reward` pairs the reward key with the maximize direction and
    // attaches no units.
    #[test]
    fn test_objective_spec() {
        let spec = ObjectiveSpec::maximize_reward();
        assert_eq!(spec.key, ObjectiveKey::Reward);
        assert_eq!(spec.direction, ObjectiveDirection::Maximize);
        assert!(spec.units.is_none());
    }

    // The cost / latency shorthands attach their unit labels, and the builder
    // methods overwrite the corresponding fields.
    #[test]
    fn test_objective_units_and_description() {
        let cost = ObjectiveSpec::minimize_cost();
        assert_eq!(cost.key, ObjectiveKey::CostUsd);
        assert_eq!(cost.direction, ObjectiveDirection::Minimize);
        assert_eq!(cost.units.as_deref(), Some("usd"));

        let latency = ObjectiveSpec::minimize_latency().with_description("end-to-end");
        assert_eq!(latency.key, ObjectiveKey::LatencyMs);
        assert_eq!(latency.units.as_deref(), Some("ms"));
        assert_eq!(latency.description.as_deref(), Some("end-to-end"));

        let custom = ObjectiveSpec::maximize_reward().with_units("points");
        assert_eq!(custom.units.as_deref(), Some("points"));
    }

    // Bounds are optional and inclusive on both sides.
    #[test]
    fn test_objective_constraints() {
        let mut spec = ObjectiveSpec::minimize_latency();
        // No bounds configured yet: any finite value passes.
        assert!(spec.satisfies_constraints(-1.0e9));
        assert!(spec.satisfies_constraints(1.0e9));

        spec.max_value = Some(1000.0);
        assert!(spec.satisfies_constraints(500.0));
        assert!(spec.satisfies_constraints(1000.0));
        assert!(!spec.satisfies_constraints(1500.0));

        spec.min_value = Some(10.0);
        assert!(spec.satisfies_constraints(10.0));
        assert!(!spec.satisfies_constraints(9.5));
    }

    #[test]
    fn test_reward_observation() {
        let obs = RewardObservation::outcome(0.95)
            .with_type(RewardType::Sparse)
            .with_source(RewardSource::Verifier);

        assert_eq!(obs.value, 0.95);
        assert_eq!(obs.scope, RewardScope::Outcome);
        assert_eq!(obs.source, RewardSource::Verifier);
        assert!(obs.event_id.is_none());
        assert!(obs.metadata.is_empty());
    }

    // The event constructor records both the scope and the event id.
    #[test]
    fn test_event_observation() {
        let obs = RewardObservation::event(0.5, "evt-1");

        assert_eq!(obs.value, 0.5);
        assert_eq!(obs.scope, RewardScope::Event);
        assert_eq!(obs.event_id.as_deref(), Some("evt-1"));
    }

    #[test]
    fn test_outcome_assignment() {
        let assignment = OutcomeObjectiveAssignment::new()
            .with_objective("reward", 0.85)
            .with_objective("cost_usd", 0.002)
            .with_session("session-123");

        assert_eq!(assignment.objectives.len(), 2);
        assert_eq!(assignment.objectives.get("reward"), Some(&0.85));
        assert_eq!(assignment.objectives.get("cost_usd"), Some(&0.002));
        assert_eq!(assignment.session_id, Some("session-123".to_string()));
        assert!(assignment.trace_id.is_none());
    }

    // `Default` yields the same empty assignment as `new`.
    #[test]
    fn test_outcome_assignment_default() {
        let assignment = OutcomeObjectiveAssignment::default();

        assert!(assignment.objectives.is_empty());
        assert!(assignment.session_id.is_none());
        assert!(assignment.trace_id.is_none());
        assert!(assignment.metadata.is_empty());
    }

    #[test]
    fn test_event_assignment() {
        let assignment = EventObjectiveAssignment::new("evt-1")
            .with_objective("reward", 1.0)
            .with_objective("reward", 2.0);

        assert_eq!(assignment.event_id.as_str(), "evt-1");
        // Re-inserting under the same key overwrites the earlier value.
        assert_eq!(assignment.objectives.len(), 1);
        assert_eq!(assignment.objectives.get("reward"), Some(&2.0));
        assert!(assignment.turn_number.is_none());
    }

    // A JSON round trip preserves the fields set by the constructor.
    #[test]
    fn test_serde() {
        let obs = RewardObservation::outcome(1.0);
        let json = serde_json::to_string(&obs).unwrap();
        let parsed: RewardObservation = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.value, 1.0);
        assert_eq!(parsed.scope, RewardScope::Outcome);
        assert!(parsed.event_id.is_none());
        assert!(parsed.metadata.is_empty());
    }
}