Skip to main content

somatize_runtime/
study_runner.rs

1//! Study runner — orchestrates hyperparameter optimization.
2//!
3//! Iterates over trials: samples parameters, calls the executor,
4//! records metrics, and tracks the best result. Supports Grid,
5//! Random, and Bayesian sampling strategies.
6
7use crate::event_bus::EventBus;
8use crate::sampler::Sampler;
9use somatize_core::error::Result;
10use somatize_core::event::{Event, MetricRecord};
11use somatize_core::study::{Study, Trial, TrialState};
12use std::sync::Arc;
13use std::time::Instant;
14
/// Result of executing a trial. Separates control flow (pruning) from errors.
///
/// Pruning is modeled as a normal outcome so that the `Err` side of an
/// executor's `Result` stays reserved for unexpected failures.
#[derive(Debug, Clone)]
pub enum TrialOutcome {
    /// Trial completed successfully; the payload holds its final metrics.
    Completed(Vec<MetricRecord>),
    /// Trial was pruned (stopped early).
    ///
    /// `step` is the step at which it was stopped and `reason` is a
    /// free-form explanation supplied by the executor.
    Pruned { step: usize, reason: String },
}
23
/// Callback that executes a trial given sampled parameters.
///
/// Returns `Ok(TrialOutcome)` for normal completion or pruning,
/// `Err(SomaError)` only for unexpected failures.
///
/// Implementors must be `Send + Sync` (supertrait bounds).
pub trait TrialExecutor: Send + Sync {
    /// Runs one trial with the sampled `params`, a map from parameter name
    /// to its JSON value.
    ///
    /// # Errors
    ///
    /// Return `Err` only for unexpected failures; report early stopping as
    /// `Ok(TrialOutcome::Pruned { .. })` instead.
    fn execute_trial(
        &self,
        params: &std::collections::HashMap<String, serde_json::Value>,
    ) -> Result<TrialOutcome>;
}
34
/// Function-based trial executor for convenience.
///
/// Newtype around a function or closure `F`; the wrapped callable is the
/// public field `0`.
pub struct FnTrialExecutor<F>(pub F);
37
38impl<F> TrialExecutor for FnTrialExecutor<F>
39where
40    F: Fn(&std::collections::HashMap<String, serde_json::Value>) -> Result<TrialOutcome>
41        + Send
42        + Sync,
43{
44    fn execute_trial(
45        &self,
46        params: &std::collections::HashMap<String, serde_json::Value>,
47    ) -> Result<TrialOutcome> {
48        (self.0)(params)
49    }
50}
51
/// Runs a Study: samples parameters, executes trials, records results.
///
/// All study and trial lifecycle events are emitted on the shared event bus
/// held by the runner.
pub struct StudyRunner {
    /// Bus on which study and trial events are emitted.
    event_bus: Arc<EventBus>,
}
56
57impl StudyRunner {
58    pub fn new(event_bus: Arc<EventBus>) -> Self {
59        Self { event_bus }
60    }
61
62    /// Run the study to completion.
63    pub fn run(
64        &self,
65        study: &mut Study,
66        sampler: &mut dyn Sampler,
67        executor: &dyn TrialExecutor,
68    ) -> Result<()> {
69        let total = sampler.n_trials().unwrap_or(0);
70
71        self.event_bus.emit(Event::StudyStarted {
72            study_id: study.id.clone(),
73            name: study.name.clone(),
74            total_trials: total,
75        });
76
77        let mut trial_index = 0;
78
79        while let Some(params) = sampler.sample(&study.search_space, trial_index)? {
80            let trial_id = format!("trial_{trial_index:04}");
81            let mut trial = Trial::new(trial_id.clone(), params.clone());
82            trial.state = TrialState::Running;
83
84            self.event_bus.emit(Event::TrialStarted {
85                study_id: study.id.clone(),
86                trial_id: trial_id.clone(),
87                params: serde_json::json!(params),
88            });
89
90            let start = Instant::now();
91
92            match executor.execute_trial(&params) {
93                Ok(TrialOutcome::Completed(metrics)) => {
94                    trial.duration_ms = Some(start.elapsed().as_millis() as u64);
95                    trial.metrics = metrics.clone();
96                    trial.state = TrialState::Completed;
97
98                    for metric in &metrics {
99                        self.event_bus.emit(Event::TrialMetric {
100                            study_id: study.id.clone(),
101                            trial_id: trial_id.clone(),
102                            metric: metric.clone(),
103                        });
104                    }
105
106                    self.event_bus.emit(Event::TrialCompleted {
107                        study_id: study.id.clone(),
108                        trial_id: trial_id.clone(),
109                        final_metrics: metrics,
110                    });
111                }
112                Ok(TrialOutcome::Pruned { step, reason }) => {
113                    trial.duration_ms = Some(start.elapsed().as_millis() as u64);
114                    trial.state = TrialState::Pruned {
115                        step,
116                        reason: reason.clone(),
117                    };
118
119                    self.event_bus.emit(Event::TrialPruned {
120                        study_id: study.id.clone(),
121                        trial_id: trial_id.clone(),
122                        step,
123                        reason,
124                    });
125                }
126                Err(e) => {
127                    trial.duration_ms = Some(start.elapsed().as_millis() as u64);
128                    trial.state = TrialState::Failed {
129                        error: e.to_string(),
130                    };
131
132                    self.event_bus.emit(Event::TrialFailed {
133                        study_id: study.id.clone(),
134                        trial_id: trial_id.clone(),
135                        error: e.to_string(),
136                    });
137                }
138            }
139
140            study.trials.push(trial);
141
142            // Check if we have a new best
143            if let Some(best) = study.best_trial()
144                && best.id == trial_id
145                && let Some(obj) = study.objectives.first()
146                && let Some(val) = best.best_metric(&obj.metric, obj.direction)
147            {
148                self.event_bus.emit(Event::BestUpdated {
149                    study_id: study.id.clone(),
150                    trial_id: trial_id.clone(),
151                    value: val,
152                    params: serde_json::json!(params),
153                });
154            }
155
156            let completed = study.trials.iter().filter(|t| t.is_terminal()).count();
157            self.event_bus.emit(Event::StudyProgress {
158                study_id: study.id.clone(),
159                completed,
160                total,
161                best_value: study
162                    .best_trial()
163                    .and_then(|t| {
164                        study
165                            .objectives
166                            .first()
167                            .and_then(|o| t.best_metric(&o.metric, o.direction))
168                    })
169                    .unwrap_or(f64::NAN),
170            });
171
172            trial_index += 1;
173        }
174
175        let best_trial_id = study.best_trial().map(|t| t.id.clone()).unwrap_or_default();
176        let best_value = study
177            .best_trial()
178            .and_then(|t| {
179                study
180                    .objectives
181                    .first()
182                    .and_then(|o| t.best_metric(&o.metric, o.direction))
183            })
184            .unwrap_or(f64::NAN);
185
186        self.event_bus.emit(Event::StudyCompleted {
187            study_id: study.id.clone(),
188            best_trial_id,
189            best_value,
190        });
191
192        Ok(())
193    }
194}
195
196#[cfg(test)]
197mod tests {
198    use super::*;
199    use crate::sampler::{GridSampler, RandomSampler};
200    use chrono::Utc;
201    use somatize_core::error::SomaError;
202    use somatize_core::search::{Scale, SearchDimension, SearchSpace};
203    use somatize_core::study::{Direction, Objective, SearchStrategy};
204
205    fn sample_space() -> SearchSpace {
206        let mut space = SearchSpace::new();
207        space.add(SearchDimension::Float {
208            name: "lr".into(),
209            low: 0.001,
210            high: 0.1,
211            scale: Scale::Log,
212            default: None,
213        });
214        space.add(SearchDimension::Categorical {
215            name: "activation".into(),
216            choices: vec![serde_json::json!("relu"), serde_json::json!("tanh")],
217        });
218        space
219    }
220
221    /// Simple executor: f1 = 1.0 - |lr - 0.01| * 10
222    fn make_executor() -> FnTrialExecutor<
223        impl Fn(&std::collections::HashMap<String, serde_json::Value>) -> Result<TrialOutcome>,
224    > {
225        FnTrialExecutor(
226            |params: &std::collections::HashMap<String, serde_json::Value>| {
227                let lr = params["lr"].as_f64().unwrap();
228                let f1 = (1.0 - (lr - 0.01).abs() * 10.0).max(0.0);
229                Ok(TrialOutcome::Completed(vec![MetricRecord {
230                    name: "f1".into(),
231                    value: f1,
232                    step: 0,
233                    timestamp: Utc::now(),
234                }]))
235            },
236        )
237    }
238
239    #[test]
240    fn study_runner_grid_search() {
241        let bus = Arc::new(EventBus::new(256));
242        let mut rx = bus.subscribe();
243        let runner = StudyRunner::new(bus);
244
245        let space = sample_space();
246        let mut study = Study::new(
247            "grid_test",
248            space,
249            SearchStrategy::Grid { points_per_dim: 3 },
250            vec![Objective {
251                metric: "f1".into(),
252                direction: Direction::Maximize,
253            }],
254        );
255
256        let mut sampler = GridSampler::new(3);
257        let executor = make_executor();
258
259        runner.run(&mut study, &mut sampler, &executor).unwrap();
260
261        // 3 lr points * 2 activations = 6 trials
262        assert_eq!(study.trials.len(), 6);
263        assert!(study.trials.iter().all(|t| t.is_complete()));
264
265        // Best trial should have lr closest to 0.01
266        let best = study.best_trial().unwrap();
267        let best_lr = best.params["lr"].as_f64().unwrap();
268        assert!(
269            (best_lr - 0.01).abs() < 0.05,
270            "best lr should be near 0.01, got {best_lr}"
271        );
272
273        // Check events were emitted
274        let mut events = Vec::new();
275        while let Ok(e) = rx.try_recv() {
276            events.push(e);
277        }
278        assert!(
279            events
280                .iter()
281                .any(|e| matches!(e, Event::StudyStarted { .. }))
282        );
283        assert!(
284            events
285                .iter()
286                .any(|e| matches!(e, Event::TrialStarted { .. }))
287        );
288        assert!(
289            events
290                .iter()
291                .any(|e| matches!(e, Event::TrialCompleted { .. }))
292        );
293        assert!(
294            events
295                .iter()
296                .any(|e| matches!(e, Event::BestUpdated { .. }))
297        );
298        assert!(
299            events
300                .iter()
301                .any(|e| matches!(e, Event::StudyCompleted { .. }))
302        );
303    }
304
305    #[test]
306    fn study_runner_random_search() {
307        let bus = Arc::new(EventBus::new(256));
308        let runner = StudyRunner::new(bus);
309
310        let space = sample_space();
311        let mut study = Study::new(
312            "random_test",
313            space,
314            SearchStrategy::Random {
315                n_trials: 20,
316                seed: Some(42),
317            },
318            vec![Objective {
319                metric: "f1".into(),
320                direction: Direction::Maximize,
321            }],
322        );
323
324        let mut sampler = RandomSampler::new(20, Some(42));
325        let executor = make_executor();
326
327        runner.run(&mut study, &mut sampler, &executor).unwrap();
328
329        assert_eq!(study.trials.len(), 20);
330        assert!(study.best_trial().is_some());
331    }
332
333    #[test]
334    fn study_runner_handles_failed_trials() {
335        let bus = Arc::new(EventBus::new(256));
336        let runner = StudyRunner::new(bus);
337
338        let mut space = SearchSpace::new();
339        space.add(SearchDimension::Float {
340            name: "x".into(),
341            low: 0.0,
342            high: 1.0,
343            scale: Scale::Linear,
344            default: None,
345        });
346
347        let mut study = Study::new(
348            "fail_test",
349            space,
350            SearchStrategy::Random {
351                n_trials: 5,
352                seed: None,
353            },
354            vec![Objective {
355                metric: "f1".into(),
356                direction: Direction::Maximize,
357            }],
358        );
359
360        // Executor that fails on even trials
361        let executor = FnTrialExecutor(
362            |params: &std::collections::HashMap<String, serde_json::Value>| {
363                let x = params["x"].as_f64().unwrap();
364                if x > 0.5 {
365                    Err(SomaError::Other("too high".into()))
366                } else {
367                    Ok(TrialOutcome::Completed(vec![MetricRecord {
368                        name: "f1".into(),
369                        value: x,
370                        step: 0,
371                        timestamp: Utc::now(),
372                    }]))
373                }
374            },
375        );
376
377        let mut sampler = RandomSampler::new(5, Some(42));
378        runner.run(&mut study, &mut sampler, &executor).unwrap();
379
380        assert_eq!(study.trials.len(), 5);
381        // Some should be Failed
382        let failed = study
383            .trials
384            .iter()
385            .filter(|t| matches!(t.state, TrialState::Failed { .. }))
386            .count();
387        assert!(failed > 0, "should have some failed trials");
388    }
389
390    #[test]
391    fn study_runner_handles_pruned_trials() {
392        let bus = Arc::new(EventBus::new(256));
393        let runner = StudyRunner::new(bus);
394
395        let mut space = SearchSpace::new();
396        space.add(SearchDimension::Float {
397            name: "x".into(),
398            low: 0.0,
399            high: 1.0,
400            scale: Scale::Linear,
401            default: None,
402        });
403
404        let mut study = Study::new(
405            "prune_test",
406            space,
407            SearchStrategy::Random {
408                n_trials: 3,
409                seed: None,
410            },
411            vec![Objective {
412                metric: "f1".into(),
413                direction: Direction::Maximize,
414            }],
415        );
416
417        // Executor that prunes every trial
418        let executor = FnTrialExecutor(
419            |_params: &std::collections::HashMap<String, serde_json::Value>| {
420                Ok(TrialOutcome::Pruned {
421                    step: 5,
422                    reason: "below median".into(),
423                })
424            },
425        );
426
427        let mut sampler = RandomSampler::new(3, Some(42));
428        runner.run(&mut study, &mut sampler, &executor).unwrap();
429
430        assert!(
431            study
432                .trials
433                .iter()
434                .all(|t| matches!(t.state, TrialState::Pruned { .. }))
435        );
436    }
437
438    #[test]
439    fn study_progress_tracking() {
440        let bus = Arc::new(EventBus::new(256));
441        let mut rx = bus.subscribe();
442        let runner = StudyRunner::new(bus);
443
444        let mut space = SearchSpace::new();
445        space.add(SearchDimension::Float {
446            name: "x".into(),
447            low: 0.0,
448            high: 1.0,
449            scale: Scale::Linear,
450            default: None,
451        });
452
453        let mut study = Study::new(
454            "progress_test",
455            space,
456            SearchStrategy::Random {
457                n_trials: 3,
458                seed: None,
459            },
460            vec![Objective {
461                metric: "f1".into(),
462                direction: Direction::Maximize,
463            }],
464        );
465
466        let executor = FnTrialExecutor(
467            |_params: &std::collections::HashMap<String, serde_json::Value>| {
468                Ok(TrialOutcome::Completed(vec![MetricRecord {
469                    name: "f1".into(),
470                    value: 0.5,
471                    step: 0,
472                    timestamp: Utc::now(),
473                }]))
474            },
475        );
476
477        let mut sampler = RandomSampler::new(3, Some(42));
478        runner.run(&mut study, &mut sampler, &executor).unwrap();
479
480        // Collect progress events
481        let mut progress_events = Vec::new();
482        while let Ok(e) = rx.try_recv() {
483            if let Event::StudyProgress {
484                completed, total, ..
485            } = e
486            {
487                progress_events.push((completed, total));
488            }
489        }
490
491        assert_eq!(progress_events.len(), 3);
492        assert_eq!(progress_events[0], (1, 3));
493        assert_eq!(progress_events[1], (2, 3));
494        assert_eq!(progress_events[2], (3, 3));
495    }
496}