1use crate::cache::{CacheKey, CacheTier};
7use crate::filter::FilterKind;
8use crate::graph::NodeId;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
/// Identifier of a single run; every run- and node-level [`Event`] carries one.
pub type RunId = String;

/// Identifier of a study; every trial-, study- and generation-level [`Event`] carries one.
pub type StudyId = String;

/// Identifier of a trial; trial events pair it with the owning [`StudyId`].
pub type TrialId = String;
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
24pub struct MetricRecord {
25 pub name: String,
26 pub value: f64,
27 pub step: usize,
28 pub timestamp: DateTime<Utc>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct PlanSummary {
34 pub total_nodes: usize,
35 pub cached_nodes: usize,
36 pub parallel_branches: usize,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41#[serde(tag = "event_type")]
42#[non_exhaustive]
43pub enum Event {
44 RunStarted {
47 run_id: RunId,
48 plan_summary: PlanSummary,
49 },
50
51 NodeStarted {
53 run_id: RunId,
54 node_id: NodeId,
55 kind: FilterKind,
56 },
57
58 NodeProgress {
60 run_id: RunId,
61 node_id: NodeId,
62 progress: f32,
63 },
64
65 NodeCacheHit {
67 run_id: RunId,
68 node_id: NodeId,
69 key: CacheKey,
70 tier: CacheTier,
71 #[serde(with = "duration_millis")]
72 load_time: Duration,
73 },
74
75 NodeCompleted {
77 run_id: RunId,
78 node_id: NodeId,
79 #[serde(with = "duration_millis")]
80 duration: Duration,
81 output_summary: String,
82 },
83
84 NodeFailed {
86 run_id: RunId,
87 node_id: NodeId,
88 error: String,
89 },
90
91 RunCompleted {
93 run_id: RunId,
94 #[serde(with = "duration_millis")]
95 duration: Duration,
96 },
97
98 RunFailed { run_id: RunId, error: String },
100
101 TrialStarted {
104 study_id: StudyId,
105 trial_id: TrialId,
106 params: serde_json::Value,
107 },
108
109 TrialMetric {
111 study_id: StudyId,
112 trial_id: TrialId,
113 metric: MetricRecord,
114 },
115
116 TrialPruned {
118 study_id: StudyId,
119 trial_id: TrialId,
120 step: usize,
121 reason: String,
122 },
123
124 TrialCompleted {
126 study_id: StudyId,
127 trial_id: TrialId,
128 final_metrics: Vec<MetricRecord>,
129 },
130
131 TrialFailed {
133 study_id: StudyId,
134 trial_id: TrialId,
135 error: String,
136 },
137
138 StudyStarted {
141 study_id: StudyId,
142 name: String,
143 total_trials: usize,
144 },
145
146 StudyProgress {
148 study_id: StudyId,
149 completed: usize,
150 total: usize,
151 best_value: f64,
152 },
153
154 BestUpdated {
156 study_id: StudyId,
157 trial_id: TrialId,
158 value: f64,
159 params: serde_json::Value,
160 },
161
162 ParetoUpdated {
164 study_id: StudyId,
165 front_size: usize,
166 },
167
168 StudyCompleted {
170 study_id: StudyId,
171 best_trial_id: TrialId,
172 best_value: f64,
173 },
174
175 GenerationStarted {
178 study_id: StudyId,
179 generation: usize,
180 population_size: usize,
181 },
182
183 GenerationCompleted {
185 study_id: StudyId,
186 generation: usize,
187 best_fitness: f64,
188 mean_fitness: f64,
189 },
190
191 MemberExploited {
193 study_id: StudyId,
194 generation: usize,
195 replaced_id: String,
196 donor_id: String,
197 },
198}
199
200mod duration_millis {
202 use serde::{self, Deserialize, Deserializer, Serializer};
203 use std::time::Duration;
204
205 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
206 where
207 S: Serializer,
208 {
209 serializer.serialize_u64(duration.as_millis() as u64)
210 }
211
212 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
213 where
214 D: Deserializer<'de>,
215 {
216 let millis = u64::deserialize(deserializer)?;
217 Ok(Duration::from_millis(millis))
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `event` to JSON and parses it straight back.
    /// Returns both the JSON text and the decoded event.
    fn round_trip(event: &Event) -> (String, Event) {
        let json = serde_json::to_string(event).expect("Event serializes to JSON");
        let back = serde_json::from_str(&json).expect("serialized Event parses back");
        (json, back)
    }

    #[test]
    fn event_serde_run_started() {
        let event = Event::RunStarted {
            run_id: "run_001".into(),
            plan_summary: PlanSummary {
                total_nodes: 5,
                cached_nodes: 2,
                parallel_branches: 1,
            },
        };
        let (json, back) = round_trip(&event);
        // Internally tagged: the variant name is stored in `event_type`.
        assert!(json.contains(r#""event_type":"RunStarted""#));
        if let Event::RunStarted {
            run_id,
            plan_summary,
        } = back
        {
            assert_eq!(run_id, "run_001");
            assert_eq!(plan_summary.total_nodes, 5);
            assert_eq!(plan_summary.cached_nodes, 2);
            assert_eq!(plan_summary.parallel_branches, 1);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn event_serde_node_cache_hit() {
        let event = Event::NodeCacheHit {
            run_id: "run_001".into(),
            node_id: "scaler".into(),
            key: CacheKey::hash_data(b"test"),
            tier: CacheTier::Memory,
            load_time: Duration::from_micros(200),
        };
        let (_, back) = round_trip(&event);
        if let Event::NodeCacheHit {
            run_id,
            tier,
            load_time,
            ..
        } = back
        {
            assert_eq!(run_id, "run_001");
            assert_eq!(tier, CacheTier::Memory);
            // Durations are encoded in whole milliseconds, so 200µs truncates to zero.
            assert_eq!(load_time, Duration::ZERO);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn event_serde_trial_metric() {
        let event = Event::TrialMetric {
            study_id: "study_001".into(),
            trial_id: "trial_042".into(),
            metric: MetricRecord {
                name: "f1".into(),
                value: 0.847,
                step: 15,
                timestamp: Utc::now(),
            },
        };
        let (json, back) = round_trip(&event);
        assert!(json.contains("TrialMetric"));
        assert!(json.contains("0.847"));
        if let Event::TrialMetric {
            trial_id, metric, ..
        } = back
        {
            assert_eq!(trial_id, "trial_042");
            assert_eq!(metric.name, "f1");
            assert_eq!(metric.step, 15);
            assert!((metric.value - 0.847).abs() < f64::EPSILON);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn event_serde_study_completed() {
        let event = Event::StudyCompleted {
            study_id: "study_001".into(),
            best_trial_id: "trial_042".into(),
            best_value: 0.91,
        };
        let (_, back) = round_trip(&event);
        if let Event::StudyCompleted {
            best_trial_id,
            best_value,
            ..
        } = back
        {
            assert_eq!(best_trial_id, "trial_042");
            assert!((best_value - 0.91).abs() < f64::EPSILON);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn duration_serialized_as_millis() {
        let event = Event::NodeCompleted {
            run_id: "r".into(),
            node_id: "n".into(),
            duration: Duration::from_millis(1234),
            output_summary: "ok".into(),
        };
        let (json, back) = round_trip(&event);
        // Pin the exact field encoding, not just "1234 appears somewhere".
        assert!(json.contains(r#""duration":1234"#));
        if let Event::NodeCompleted { duration, .. } = back {
            assert_eq!(duration, Duration::from_millis(1234));
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn all_three_event_levels_serialize() {
        let events: Vec<Event> = vec![
            Event::RunStarted {
                run_id: "r".into(),
                plan_summary: PlanSummary {
                    total_nodes: 1,
                    cached_nodes: 0,
                    parallel_branches: 0,
                },
            },
            Event::RunCompleted {
                run_id: "r".into(),
                duration: Duration::from_secs(1),
            },
            Event::TrialStarted {
                study_id: "s".into(),
                trial_id: "t".into(),
                params: serde_json::json!({"lr": 0.01}),
            },
            Event::TrialPruned {
                study_id: "s".into(),
                trial_id: "t".into(),
                step: 5,
                reason: "below median".into(),
            },
            Event::StudyStarted {
                study_id: "s".into(),
                name: "test".into(),
                total_trials: 100,
            },
            Event::BestUpdated {
                study_id: "s".into(),
                trial_id: "t".into(),
                value: 0.95,
                params: serde_json::json!({"C": 1.0}),
            },
        ];

        for event in &events {
            let (json, back) = round_trip(event);
            assert!(json.contains(r#""event_type":"#));
            // Re-encoding the decoded event must reproduce the original JSON exactly.
            assert_eq!(serde_json::to_string(&back).unwrap(), json);
        }
    }
}