Skip to main content

somatize_core/
event.rs

//! Runtime lifecycle events — emitted during plan execution.
//!
//! Events track run/node, trial, study, and population-based-training
//! (generation) state transitions and are broadcast via the `EventBus`
//! for observability and debugging.
5
6use crate::cache::{CacheKey, CacheTier};
7use crate::filter::FilterKind;
8use crate::graph::NodeId;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
// Identifier aliases. All three are plain `String`s, so any string converts
// into them and the compiler does not distinguish one kind of id from another.

/// Identifier of a single pipeline run.
pub type RunId = String;

/// Identifier of an optimization study.
pub type StudyId = String;

/// Identifier of one trial inside a study.
pub type TrialId = String;
21
22/// A metric measurement reported during training.
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
24pub struct MetricRecord {
25    pub name: String,
26    pub value: f64,
27    pub step: usize,
28    pub timestamp: DateTime<Utc>,
29}
30
31/// Summary of a compiled plan (for event payloads without the full plan).
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct PlanSummary {
34    pub total_nodes: usize,
35    pub cached_nodes: usize,
36    pub parallel_branches: usize,
37}
38
39/// Structured events emitted during execution at three levels.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41#[serde(tag = "event_type")]
42#[non_exhaustive]
43pub enum Event {
44    // ── Level 1: Pipeline execution (per run) ──
45    /// A pipeline run has started.
46    RunStarted {
47        run_id: RunId,
48        plan_summary: PlanSummary,
49    },
50
51    /// A filter node has started execution.
52    NodeStarted {
53        run_id: RunId,
54        node_id: NodeId,
55        kind: FilterKind,
56    },
57
58    /// A filter node reports progress (0.0 to 1.0).
59    NodeProgress {
60        run_id: RunId,
61        node_id: NodeId,
62        progress: f32,
63    },
64
65    /// A filter node's result was loaded from cache.
66    NodeCacheHit {
67        run_id: RunId,
68        node_id: NodeId,
69        key: CacheKey,
70        tier: CacheTier,
71        #[serde(with = "duration_millis")]
72        load_time: Duration,
73    },
74
75    /// A filter node completed successfully.
76    NodeCompleted {
77        run_id: RunId,
78        node_id: NodeId,
79        #[serde(with = "duration_millis")]
80        duration: Duration,
81        output_summary: String,
82    },
83
84    /// A filter node failed.
85    NodeFailed {
86        run_id: RunId,
87        node_id: NodeId,
88        error: String,
89    },
90
91    /// The pipeline run completed.
92    RunCompleted {
93        run_id: RunId,
94        #[serde(with = "duration_millis")]
95        duration: Duration,
96    },
97
98    /// The pipeline run failed.
99    RunFailed { run_id: RunId, error: String },
100
101    // ── Level 2: Trial execution (per hyperparameter set) ──
102    /// A new trial has started.
103    TrialStarted {
104        study_id: StudyId,
105        trial_id: TrialId,
106        params: serde_json::Value,
107    },
108
109    /// A trial reports an intermediate metric.
110    TrialMetric {
111        study_id: StudyId,
112        trial_id: TrialId,
113        metric: MetricRecord,
114    },
115
116    /// A trial was pruned (stopped early).
117    TrialPruned {
118        study_id: StudyId,
119        trial_id: TrialId,
120        step: usize,
121        reason: String,
122    },
123
124    /// A trial completed successfully.
125    TrialCompleted {
126        study_id: StudyId,
127        trial_id: TrialId,
128        final_metrics: Vec<MetricRecord>,
129    },
130
131    /// A trial failed.
132    TrialFailed {
133        study_id: StudyId,
134        trial_id: TrialId,
135        error: String,
136    },
137
138    // ── Level 3: Study execution (optimization session) ──
139    /// An optimization study has started.
140    StudyStarted {
141        study_id: StudyId,
142        name: String,
143        total_trials: usize,
144    },
145
146    /// Study progress update.
147    StudyProgress {
148        study_id: StudyId,
149        completed: usize,
150        total: usize,
151        best_value: f64,
152    },
153
154    /// The best trial has been updated.
155    BestUpdated {
156        study_id: StudyId,
157        trial_id: TrialId,
158        value: f64,
159        params: serde_json::Value,
160    },
161
162    /// The Pareto front has changed (multi-objective).
163    ParetoUpdated {
164        study_id: StudyId,
165        front_size: usize,
166    },
167
168    /// The study completed.
169    StudyCompleted {
170        study_id: StudyId,
171        best_trial_id: TrialId,
172        best_value: f64,
173    },
174
175    // ── Level 4: Population-Based Training ──
176    /// A PBT generation started (train → evaluate → exploit/explore).
177    GenerationStarted {
178        study_id: StudyId,
179        generation: usize,
180        population_size: usize,
181    },
182
183    /// A PBT generation completed.
184    GenerationCompleted {
185        study_id: StudyId,
186        generation: usize,
187        best_fitness: f64,
188        mean_fitness: f64,
189    },
190
191    /// A population member was replaced during exploit step.
192    MemberExploited {
193        study_id: StudyId,
194        generation: usize,
195        replaced_id: String,
196        donor_id: String,
197    },
198}
199
200/// Serde helper: Duration as milliseconds (u64).
201mod duration_millis {
202    use serde::{self, Deserialize, Deserializer, Serializer};
203    use std::time::Duration;
204
205    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
206    where
207        S: Serializer,
208    {
209        serializer.serialize_u64(duration.as_millis() as u64)
210    }
211
212    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
213    where
214        D: Deserializer<'de>,
215    {
216        let millis = u64::deserialize(deserializer)?;
217        Ok(Duration::from_millis(millis))
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `event` to JSON, checks that the internal `event_type` tag
    /// equals `expected_tag`, and returns the event deserialized from that JSON.
    fn round_trip(event: &Event, expected_tag: &str) -> Event {
        let json = serde_json::to_string(event).expect("event serializes");
        let value: serde_json::Value = serde_json::from_str(&json).expect("output is valid JSON");
        assert_eq!(value["event_type"], expected_tag, "wrong tag in {}", json);
        serde_json::from_str(&json).expect("event deserializes")
    }

    #[test]
    fn event_serde_run_started() {
        let event = Event::RunStarted {
            run_id: "run_001".into(),
            plan_summary: PlanSummary {
                total_nodes: 5,
                cached_nodes: 2,
                parallel_branches: 1,
            },
        };
        match round_trip(&event, "RunStarted") {
            Event::RunStarted {
                run_id,
                plan_summary,
            } => {
                assert_eq!(run_id, "run_001");
                assert_eq!(plan_summary.total_nodes, 5);
                assert_eq!(plan_summary.cached_nodes, 2);
                assert_eq!(plan_summary.parallel_branches, 1);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn event_serde_node_cache_hit() {
        let event = Event::NodeCacheHit {
            run_id: "run_001".into(),
            node_id: "scaler".into(),
            key: CacheKey::hash_data(b"test"),
            tier: CacheTier::Memory,
            load_time: Duration::from_micros(200),
        };
        match round_trip(&event, "NodeCacheHit") {
            Event::NodeCacheHit {
                tier, load_time, ..
            } => {
                assert_eq!(tier, CacheTier::Memory);
                // Durations travel as whole milliseconds, so 200µs truncates to zero.
                assert_eq!(load_time, Duration::from_millis(0));
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn event_serde_trial_metric() {
        let metric = MetricRecord {
            name: "f1".into(),
            value: 0.847,
            step: 15,
            timestamp: Utc::now(),
        };
        let event = Event::TrialMetric {
            study_id: "study_001".into(),
            trial_id: "trial_042".into(),
            metric: metric.clone(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("0.847"));
        match round_trip(&event, "TrialMetric") {
            Event::TrialMetric {
                study_id,
                trial_id,
                metric: back,
            } => {
                assert_eq!(study_id, "study_001");
                assert_eq!(trial_id, "trial_042");
                assert_eq!(back.name, metric.name);
                assert_eq!(back.step, metric.step);
                assert_eq!(back.timestamp, metric.timestamp);
                // serde_json's default float parsing is best-effort; compare with a tolerance.
                assert!((back.value - metric.value).abs() < f64::EPSILON);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn event_serde_study_completed() {
        let event = Event::StudyCompleted {
            study_id: "study_001".into(),
            best_trial_id: "trial_042".into(),
            best_value: 0.91,
        };
        match round_trip(&event, "StudyCompleted") {
            Event::StudyCompleted {
                best_trial_id,
                best_value,
                ..
            } => {
                assert_eq!(best_trial_id, "trial_042");
                assert!((best_value - 0.91).abs() < f64::EPSILON);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn duration_serialized_as_millis() {
        let event = Event::NodeCompleted {
            run_id: "r".into(),
            node_id: "n".into(),
            duration: Duration::from_millis(1234),
            output_summary: "ok".into(),
        };
        let value = serde_json::to_value(&event).unwrap();
        assert_eq!(value["duration"], 1234_u64);
        match round_trip(&event, "NodeCompleted") {
            Event::NodeCompleted { duration, .. } => {
                assert_eq!(duration, Duration::from_millis(1234));
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn every_event_level_round_trips() {
        let cases: Vec<(&str, Event)> = vec![
            // Level 1
            (
                "RunStarted",
                Event::RunStarted {
                    run_id: "r".into(),
                    plan_summary: PlanSummary {
                        total_nodes: 1,
                        cached_nodes: 0,
                        parallel_branches: 0,
                    },
                },
            ),
            (
                "NodeProgress",
                Event::NodeProgress {
                    run_id: "r".into(),
                    node_id: "n".into(),
                    progress: 0.5,
                },
            ),
            (
                "NodeFailed",
                Event::NodeFailed {
                    run_id: "r".into(),
                    node_id: "n".into(),
                    error: "boom".into(),
                },
            ),
            (
                "RunCompleted",
                Event::RunCompleted {
                    run_id: "r".into(),
                    duration: Duration::from_secs(1),
                },
            ),
            (
                "RunFailed",
                Event::RunFailed {
                    run_id: "r".into(),
                    error: "boom".into(),
                },
            ),
            // Level 2
            (
                "TrialStarted",
                Event::TrialStarted {
                    study_id: "s".into(),
                    trial_id: "t".into(),
                    params: serde_json::json!({"lr": 0.01}),
                },
            ),
            (
                "TrialPruned",
                Event::TrialPruned {
                    study_id: "s".into(),
                    trial_id: "t".into(),
                    step: 5,
                    reason: "below median".into(),
                },
            ),
            (
                "TrialCompleted",
                Event::TrialCompleted {
                    study_id: "s".into(),
                    trial_id: "t".into(),
                    final_metrics: Vec::new(),
                },
            ),
            (
                "TrialFailed",
                Event::TrialFailed {
                    study_id: "s".into(),
                    trial_id: "t".into(),
                    error: "boom".into(),
                },
            ),
            // Level 3
            (
                "StudyStarted",
                Event::StudyStarted {
                    study_id: "s".into(),
                    name: "test".into(),
                    total_trials: 100,
                },
            ),
            (
                "StudyProgress",
                Event::StudyProgress {
                    study_id: "s".into(),
                    completed: 3,
                    total: 10,
                    best_value: 0.8,
                },
            ),
            (
                "BestUpdated",
                Event::BestUpdated {
                    study_id: "s".into(),
                    trial_id: "t".into(),
                    value: 0.95,
                    params: serde_json::json!({"C": 1.0}),
                },
            ),
            (
                "ParetoUpdated",
                Event::ParetoUpdated {
                    study_id: "s".into(),
                    front_size: 4,
                },
            ),
            // Level 4
            (
                "GenerationStarted",
                Event::GenerationStarted {
                    study_id: "s".into(),
                    generation: 1,
                    population_size: 8,
                },
            ),
            (
                "GenerationCompleted",
                Event::GenerationCompleted {
                    study_id: "s".into(),
                    generation: 1,
                    best_fitness: 0.9,
                    mean_fitness: 0.7,
                },
            ),
            (
                "MemberExploited",
                Event::MemberExploited {
                    study_id: "s".into(),
                    generation: 1,
                    replaced_id: "m3".into(),
                    donor_id: "m0".into(),
                },
            ),
        ];

        for (tag, event) in &cases {
            let _ = round_trip(event, tag);
        }
    }
}