Skip to main content

somatize_core/
study.rs

//! Study — defines an optimization experiment with objectives and a search strategy.
//!
//! A [`Study`] holds the search space, the search strategy (Grid, Random,
//! Bayesian, Hyperband, or MultiObjective), a pruning strategy, the
//! objectives, and the trials run so far. Execution is orchestrated by
//! `StudyRunner` in the `soma-runtime` crate.
6
7use crate::event::MetricRecord;
8use crate::search::SearchSpace;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Direction of optimization for an objective.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum Direction {
15    Minimize,
16    Maximize,
17}
18
19/// An optimization objective (metric + direction).
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Objective {
22    pub metric: String,
23    pub direction: Direction,
24}
25
26/// Search strategy for hyperparameter optimization.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(tag = "strategy_type")]
29pub enum SearchStrategy {
30    /// Exhaustive grid search.
31    Grid { points_per_dim: usize },
32
33    /// Random sampling.
34    Random { n_trials: usize, seed: Option<u64> },
35
36    /// Bayesian optimization (TPE).
37    Bayesian {
38        n_trials: usize,
39        n_startup: usize,
40        seed: Option<u64>,
41    },
42
43    /// Successive halving with early stopping.
44    Hyperband {
45        max_resource: usize,
46        reduction_factor: usize,
47    },
48
49    /// Multi-objective optimization.
50    MultiObjective {
51        n_trials: usize,
52        objectives: Vec<Objective>,
53    },
54}
55
56impl SearchStrategy {
57    /// Planned number of trials (if known).
58    pub fn n_trials(&self) -> Option<usize> {
59        match self {
60            Self::Grid { .. } => None, // depends on search space
61            Self::Random { n_trials, .. } => Some(*n_trials),
62            Self::Bayesian { n_trials, .. } => Some(*n_trials),
63            Self::Hyperband { .. } => None, // depends on brackets
64            Self::MultiObjective { n_trials, .. } => Some(*n_trials),
65        }
66    }
67}
68
69/// Pruning strategy for early stopping of unpromising trials.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71#[serde(tag = "pruning_type")]
72pub enum PruningStrategy {
73    /// No pruning.
74    None,
75
76    /// Prune if metric is below median of completed trials at same step.
77    Median { n_warmup_steps: usize },
78
79    /// Prune if metric is below given percentile.
80    Percentile {
81        percentile: f64,
82        n_warmup_steps: usize,
83    },
84
85    /// Bracket-based pruning (used with Hyperband).
86    Hyperband,
87}
88
89/// State of a single trial.
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
91#[serde(tag = "trial_state")]
92pub enum TrialState {
93    Pending,
94    Running,
95    Completed,
96    Pruned { step: usize, reason: String },
97    Failed { error: String },
98}
99
100/// A single hyperparameter evaluation.
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct Trial {
103    pub id: String,
104    pub params: HashMap<String, serde_json::Value>,
105    pub state: TrialState,
106    pub metrics: Vec<MetricRecord>,
107    pub duration_ms: Option<u64>,
108}
109
110impl Trial {
111    pub fn new(id: impl Into<String>, params: HashMap<String, serde_json::Value>) -> Self {
112        Self {
113            id: id.into(),
114            params,
115            state: TrialState::Pending,
116            metrics: Vec::new(),
117            duration_ms: None,
118        }
119    }
120
121    /// Get the last recorded value for a specific metric.
122    pub fn best_metric(&self, name: &str, direction: Direction) -> Option<f64> {
123        let values: Vec<f64> = self
124            .metrics
125            .iter()
126            .filter(|m| m.name == name)
127            .map(|m| m.value)
128            .collect();
129        match direction {
130            Direction::Maximize => values.into_iter().reduce(f64::max),
131            Direction::Minimize => values.into_iter().reduce(f64::min),
132        }
133    }
134
135    pub fn is_complete(&self) -> bool {
136        matches!(self.state, TrialState::Completed)
137    }
138
139    pub fn is_terminal(&self) -> bool {
140        matches!(
141            self.state,
142            TrialState::Completed | TrialState::Pruned { .. } | TrialState::Failed { .. }
143        )
144    }
145}
146
147/// An optimization study: orchestrates multiple trials.
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct Study {
150    pub id: String,
151    pub name: String,
152    pub search_space: SearchSpace,
153    pub strategy: SearchStrategy,
154    pub pruning: PruningStrategy,
155    pub objectives: Vec<Objective>,
156    pub trials: Vec<Trial>,
157    pub frozen: HashMap<String, serde_json::Value>,
158}
159
160impl Study {
161    pub fn new(
162        name: impl Into<String>,
163        search_space: SearchSpace,
164        strategy: SearchStrategy,
165        objectives: Vec<Objective>,
166    ) -> Self {
167        Self {
168            id: uuid_v4(),
169            name: name.into(),
170            search_space,
171            strategy,
172            pruning: PruningStrategy::None,
173            objectives,
174            trials: Vec::new(),
175            frozen: HashMap::new(),
176        }
177    }
178
179    pub fn with_pruning(mut self, pruning: PruningStrategy) -> Self {
180        self.pruning = pruning;
181        self
182    }
183
184    /// Get completed trials.
185    pub fn completed_trials(&self) -> Vec<&Trial> {
186        self.trials.iter().filter(|t| t.is_complete()).collect()
187    }
188
189    /// Get the best trial for the primary objective.
190    pub fn best_trial(&self) -> Option<&Trial> {
191        let obj = self.objectives.first()?;
192        self.completed_trials()
193            .into_iter()
194            .filter_map(|t| {
195                let val = t.best_metric(&obj.metric, obj.direction)?;
196                Some((t, val))
197            })
198            .reduce(|best, current| match obj.direction {
199                Direction::Maximize => {
200                    if current.1 > best.1 {
201                        current
202                    } else {
203                        best
204                    }
205                }
206                Direction::Minimize => {
207                    if current.1 < best.1 {
208                        current
209                    } else {
210                        best
211                    }
212                }
213            })
214            .map(|(t, _)| t)
215    }
216
217    /// Number of total planned trials (if known).
218    pub fn total_trials(&self) -> Option<usize> {
219        self.strategy.n_trials()
220    }
221
222    /// Fraction of trials completed.
223    pub fn progress(&self) -> f64 {
224        let completed = self.trials.iter().filter(|t| t.is_terminal()).count();
225        match self.total_trials() {
226            Some(total) if total > 0 => completed as f64 / total as f64,
227            _ => 0.0,
228        }
229    }
230}
231
/// Generates a study identifier of the form `study_<hex>`.
///
/// Despite the name, this is not an RFC 4122 UUID. The hex part is the
/// current Unix time in nanoseconds.
///
/// To keep identifiers unique within the process, each value is forced to be
/// strictly greater than the previous one. This holds even when two calls land
/// within the clock's resolution, or when the clock stalls or steps backwards.
/// Uniqueness across processes is not guaranteed.
fn uuid_v4() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Last value handed out by this process.
    static LAST: AtomicU64 = AtomicU64::new(0);

    // Nanoseconds since the epoch fit in u64 until the year 2554; saturate beyond that.
    // A clock set before the epoch yields 0, which the monotonic bump below handles.
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| u64::try_from(d.as_nanos()).unwrap_or(u64::MAX))
        .unwrap_or_default();

    let prev = LAST
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |last| {
            Some(now.max(last.saturating_add(1)))
        })
        .expect("fetch_update closure always returns Some");
    // Recompute the value that was actually stored for the winning `prev`.
    let id = now.max(prev.saturating_add(1));
    format!("study_{id:x}")
}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::{Scale, SearchDimension};
    use chrono::Utc;
    use serde_json::json;

    /// A two-dimensional space: log-scaled `lr` and categorical `kernel`.
    fn sample_search_space() -> SearchSpace {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        });
        space.add(SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![json!("rbf"), json!("linear")],
        });
        space
    }

    /// A metric record at step 10, timestamped now.
    fn metric(name: &'static str, value: f64) -> MetricRecord {
        MetricRecord {
            name: name.into(),
            value,
            step: 10,
            timestamp: Utc::now(),
        }
    }

    /// A completed trial with a fixed `lr` param and a single `f1` record.
    fn make_trial(id: &str, f1: f64) -> Trial {
        let mut t = Trial::new(id, HashMap::from([("lr".into(), json!(0.01))]));
        t.state = TrialState::Completed;
        t.metrics.push(metric("f1", f1));
        t
    }

    /// Random search with 10 planned trials.
    fn random_10() -> SearchStrategy {
        SearchStrategy::Random {
            n_trials: 10,
            seed: None,
        }
    }

    /// A single objective maximizing `f1`.
    fn maximize_f1() -> Vec<Objective> {
        vec![Objective {
            metric: "f1".into(),
            direction: Direction::Maximize,
        }]
    }

    #[test]
    fn study_best_trial_maximize() {
        let mut study = Study::new("test", sample_search_space(), random_10(), maximize_f1());

        study.trials.push(make_trial("t1", 0.75));
        study.trials.push(make_trial("t2", 0.90));
        study.trials.push(make_trial("t3", 0.82));

        let best = study.best_trial().unwrap();
        assert_eq!(best.id, "t2");
    }

    #[test]
    fn study_best_trial_minimize() {
        let mut study = Study::new(
            "test",
            sample_search_space(),
            random_10(),
            vec![Objective {
                metric: "loss".into(),
                direction: Direction::Minimize,
            }],
        );

        let mut t1 = Trial::new("t1", HashMap::new());
        t1.state = TrialState::Completed;
        t1.metrics.push(metric("loss", 0.5));

        let mut t2 = Trial::new("t2", HashMap::new());
        t2.state = TrialState::Completed;
        t2.metrics.push(metric("loss", 0.3));

        study.trials.push(t1);
        study.trials.push(t2);

        let best = study.best_trial().unwrap();
        assert_eq!(best.id, "t2");
    }

    #[test]
    fn best_trial_ignores_unfinished_trials() {
        let mut study = Study::new("test", sample_search_space(), random_10(), maximize_f1());

        study.trials.push(make_trial("done", 0.6));

        let mut running = make_trial("running", 0.99);
        running.state = TrialState::Running;
        study.trials.push(running);

        let mut pruned = make_trial("pruned", 0.95);
        pruned.state = TrialState::Pruned {
            step: 3,
            reason: "below median".into(),
        };
        study.trials.push(pruned);

        assert_eq!(study.completed_trials().len(), 1);
        assert_eq!(study.best_trial().unwrap().id, "done");
    }

    #[test]
    fn best_trial_tie_keeps_earliest() {
        let mut study = Study::new("test", sample_search_space(), random_10(), maximize_f1());
        study.trials.push(make_trial("first", 0.8));
        study.trials.push(make_trial("second", 0.8));
        assert_eq!(study.best_trial().unwrap().id, "first");
    }

    #[test]
    fn best_trial_requires_objective() {
        let mut study = Study::new("no_objectives", SearchSpace::new(), random_10(), vec![]);
        study.trials.push(make_trial("t1", 0.5));
        assert!(study.best_trial().is_none());
    }

    #[test]
    fn best_metric_aggregates_across_records() {
        let mut t = Trial::new("t1", HashMap::new());
        for v in [0.7, 0.9, 0.8] {
            t.metrics.push(metric("f1", v));
        }
        t.metrics.push(metric("loss", 5.0));

        assert_eq!(t.best_metric("f1", Direction::Maximize), Some(0.9));
        assert_eq!(t.best_metric("f1", Direction::Minimize), Some(0.7));
        assert_eq!(t.best_metric("missing", Direction::Maximize), None);
    }

    #[test]
    fn study_progress() {
        let mut study = Study::new("test", sample_search_space(), random_10(), vec![]);

        assert_eq!(study.progress(), 0.0);

        study.trials.push(make_trial("t1", 0.5));
        study.trials.push(make_trial("t2", 0.6));
        assert!((study.progress() - 0.2).abs() < f64::EPSILON);
    }

    #[test]
    fn progress_counts_pruned_and_failed() {
        let mut study = Study::new(
            "test",
            sample_search_space(),
            SearchStrategy::Random {
                n_trials: 4,
                seed: None,
            },
            vec![],
        );

        study.trials.push(make_trial("completed", 0.5));

        let mut pruned = make_trial("pruned", 0.1);
        pruned.state = TrialState::Pruned {
            step: 2,
            reason: "bad".into(),
        };
        study.trials.push(pruned);

        let mut failed = make_trial("failed", 0.0);
        failed.state = TrialState::Failed {
            error: "oops".into(),
        };
        study.trials.push(failed);

        study.trials.push(Trial::new("pending", HashMap::new()));

        assert!((study.progress() - 0.75).abs() < f64::EPSILON);
    }

    #[test]
    fn progress_is_zero_when_total_unknown() {
        let mut study = Study::new(
            "grid",
            sample_search_space(),
            SearchStrategy::Grid { points_per_dim: 3 },
            vec![],
        );
        study.trials.push(make_trial("t1", 0.5));

        assert_eq!(study.total_trials(), None);
        assert_eq!(study.progress(), 0.0);
    }

    #[test]
    fn trial_terminal_states() {
        let mut t = Trial::new("t1", HashMap::new());
        assert!(!t.is_terminal());

        t.state = TrialState::Running;
        assert!(!t.is_terminal());

        t.state = TrialState::Completed;
        assert!(t.is_terminal());

        t.state = TrialState::Pruned {
            step: 5,
            reason: "bad".into(),
        };
        assert!(t.is_terminal());

        t.state = TrialState::Failed {
            error: "oops".into(),
        };
        assert!(t.is_terminal());
    }

    #[test]
    fn trial_is_complete_only_when_completed() {
        let mut t = Trial::new("t1", HashMap::new());
        assert_eq!(t.state, TrialState::Pending);
        assert!(t.metrics.is_empty());
        assert_eq!(t.duration_ms, None);
        assert!(!t.is_complete());

        t.state = TrialState::Pruned {
            step: 1,
            reason: "bad".into(),
        };
        assert!(!t.is_complete());

        t.state = TrialState::Completed;
        assert!(t.is_complete());
    }

    #[test]
    fn study_new_defaults_and_with_pruning() {
        let study = Study::new("defaults", SearchSpace::new(), random_10(), maximize_f1());
        assert!(matches!(study.pruning, PruningStrategy::None));
        assert!(study.trials.is_empty());
        assert!(study.frozen.is_empty());
        assert!(study.id.starts_with("study_"));

        let pruned = study.with_pruning(PruningStrategy::Hyperband);
        assert!(matches!(pruned.pruning, PruningStrategy::Hyperband));
    }

    #[test]
    fn study_serde_roundtrip() {
        let mut study = Study::new(
            "test_study",
            sample_search_space(),
            SearchStrategy::Bayesian {
                n_trials: 100,
                n_startup: 10,
                seed: Some(42),
            },
            maximize_f1(),
        );
        study.trials.push(make_trial("t1", 0.85));

        let json = serde_json::to_string(&study).unwrap();
        let deserialized: Study = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "test_study");
        assert_eq!(deserialized.trials.len(), 1);
    }

    #[test]
    fn study_serde_roundtrip_preserves_config() {
        let study = Study::new(
            "configured",
            sample_search_space(),
            SearchStrategy::Bayesian {
                n_trials: 100,
                n_startup: 10,
                seed: Some(42),
            },
            maximize_f1(),
        )
        .with_pruning(PruningStrategy::Median { n_warmup_steps: 5 });

        let json = serde_json::to_string(&study).unwrap();
        let back: Study = serde_json::from_str(&json).unwrap();

        assert_eq!(back.id, study.id);
        assert_eq!(back.strategy.n_trials(), Some(100));
        assert_eq!(back.objectives.len(), 1);
        assert_eq!(back.objectives[0].direction, Direction::Maximize);
        assert!(matches!(
            back.pruning,
            PruningStrategy::Median { n_warmup_steps: 5 }
        ));
    }

    #[test]
    fn trial_state_serde_is_internally_tagged() {
        let state = TrialState::Pruned {
            step: 5,
            reason: "bad".into(),
        };
        let value = serde_json::to_value(&state).unwrap();
        assert_eq!(
            value,
            json!({"trial_state": "Pruned", "step": 5, "reason": "bad"})
        );
        let back: TrialState = serde_json::from_value(value).unwrap();
        assert_eq!(back, state);
    }

    #[test]
    fn search_strategy_n_trials() {
        assert_eq!(
            SearchStrategy::Random {
                n_trials: 50,
                seed: None
            }
            .n_trials(),
            Some(50)
        );
        assert_eq!(SearchStrategy::Grid { points_per_dim: 5 }.n_trials(), None);
        assert_eq!(
            SearchStrategy::Bayesian {
                n_trials: 100,
                n_startup: 10,
                seed: None
            }
            .n_trials(),
            Some(100)
        );
        assert_eq!(
            SearchStrategy::Hyperband {
                max_resource: 81,
                reduction_factor: 3
            }
            .n_trials(),
            None
        );
        assert_eq!(
            SearchStrategy::MultiObjective {
                n_trials: 20,
                objectives: vec![]
            }
            .n_trials(),
            Some(20)
        );
    }

    #[test]
    fn no_best_trial_when_empty() {
        let study = Study::new("empty", SearchSpace::new(), random_10(), maximize_f1());
        assert!(study.best_trial().is_none());
    }
}