1use crate::event::MetricRecord;
8use crate::search::SearchSpace;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum Direction {
15 Minimize,
16 Maximize,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Objective {
22 pub metric: String,
23 pub direction: Direction,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(tag = "strategy_type")]
29pub enum SearchStrategy {
30 Grid { points_per_dim: usize },
32
33 Random { n_trials: usize, seed: Option<u64> },
35
36 Bayesian {
38 n_trials: usize,
39 n_startup: usize,
40 seed: Option<u64>,
41 },
42
43 Hyperband {
45 max_resource: usize,
46 reduction_factor: usize,
47 },
48
49 MultiObjective {
51 n_trials: usize,
52 objectives: Vec<Objective>,
53 },
54}
55
56impl SearchStrategy {
57 pub fn n_trials(&self) -> Option<usize> {
59 match self {
60 Self::Grid { .. } => None, Self::Random { n_trials, .. } => Some(*n_trials),
62 Self::Bayesian { n_trials, .. } => Some(*n_trials),
63 Self::Hyperband { .. } => None, Self::MultiObjective { n_trials, .. } => Some(*n_trials),
65 }
66 }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71#[serde(tag = "pruning_type")]
72pub enum PruningStrategy {
73 None,
75
76 Median { n_warmup_steps: usize },
78
79 Percentile {
81 percentile: f64,
82 n_warmup_steps: usize,
83 },
84
85 Hyperband,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
91#[serde(tag = "trial_state")]
92pub enum TrialState {
93 Pending,
94 Running,
95 Completed,
96 Pruned { step: usize, reason: String },
97 Failed { error: String },
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct Trial {
103 pub id: String,
104 pub params: HashMap<String, serde_json::Value>,
105 pub state: TrialState,
106 pub metrics: Vec<MetricRecord>,
107 pub duration_ms: Option<u64>,
108}
109
110impl Trial {
111 pub fn new(id: impl Into<String>, params: HashMap<String, serde_json::Value>) -> Self {
112 Self {
113 id: id.into(),
114 params,
115 state: TrialState::Pending,
116 metrics: Vec::new(),
117 duration_ms: None,
118 }
119 }
120
121 pub fn best_metric(&self, name: &str, direction: Direction) -> Option<f64> {
123 let values: Vec<f64> = self
124 .metrics
125 .iter()
126 .filter(|m| m.name == name)
127 .map(|m| m.value)
128 .collect();
129 match direction {
130 Direction::Maximize => values.into_iter().reduce(f64::max),
131 Direction::Minimize => values.into_iter().reduce(f64::min),
132 }
133 }
134
135 pub fn is_complete(&self) -> bool {
136 matches!(self.state, TrialState::Completed)
137 }
138
139 pub fn is_terminal(&self) -> bool {
140 matches!(
141 self.state,
142 TrialState::Completed | TrialState::Pruned { .. } | TrialState::Failed { .. }
143 )
144 }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct Study {
150 pub id: String,
151 pub name: String,
152 pub search_space: SearchSpace,
153 pub strategy: SearchStrategy,
154 pub pruning: PruningStrategy,
155 pub objectives: Vec<Objective>,
156 pub trials: Vec<Trial>,
157 pub frozen: HashMap<String, serde_json::Value>,
158}
159
160impl Study {
161 pub fn new(
162 name: impl Into<String>,
163 search_space: SearchSpace,
164 strategy: SearchStrategy,
165 objectives: Vec<Objective>,
166 ) -> Self {
167 Self {
168 id: uuid_v4(),
169 name: name.into(),
170 search_space,
171 strategy,
172 pruning: PruningStrategy::None,
173 objectives,
174 trials: Vec::new(),
175 frozen: HashMap::new(),
176 }
177 }
178
179 pub fn with_pruning(mut self, pruning: PruningStrategy) -> Self {
180 self.pruning = pruning;
181 self
182 }
183
184 pub fn completed_trials(&self) -> Vec<&Trial> {
186 self.trials.iter().filter(|t| t.is_complete()).collect()
187 }
188
189 pub fn best_trial(&self) -> Option<&Trial> {
191 let obj = self.objectives.first()?;
192 self.completed_trials()
193 .into_iter()
194 .filter_map(|t| {
195 let val = t.best_metric(&obj.metric, obj.direction)?;
196 Some((t, val))
197 })
198 .reduce(|best, current| match obj.direction {
199 Direction::Maximize => {
200 if current.1 > best.1 {
201 current
202 } else {
203 best
204 }
205 }
206 Direction::Minimize => {
207 if current.1 < best.1 {
208 current
209 } else {
210 best
211 }
212 }
213 })
214 .map(|(t, _)| t)
215 }
216
217 pub fn total_trials(&self) -> Option<usize> {
219 self.strategy.n_trials()
220 }
221
222 pub fn progress(&self) -> f64 {
224 let completed = self.trials.iter().filter(|t| t.is_terminal()).count();
225 match self.total_trials() {
226 Some(total) if total > 0 => completed as f64 / total as f64,
227 _ => 0.0,
228 }
229 }
230}
231
/// Generates an identifier of the form `study_<hex>` for a new study.
///
/// Despite the name, this is not an RFC 4122 UUID. The hex payload is the
/// current Unix time in nanoseconds, but a process-wide atomic forces each
/// returned value to be strictly greater than the previous one. Two studies
/// created within the same clock tick therefore still get distinct ids, as do
/// studies created on a coarse clock or after the clock is stepped backwards.
/// Uniqueness is only guaranteed within a single process.
fn uuid_v4() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Most recent payload handed out by this process.
    static LAST: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // `as_nanos` is u128. A u64 of nanoseconds covers dates until ~2554, so
    // saturate rather than truncate if that is ever exceeded.
    let now = u64::try_from(nanos).unwrap_or(u64::MAX);

    // Atomically advance LAST to max(now, LAST + 1); this yields the previous value.
    // The closure always returns Some, so the Err arm is never taken.
    let prev = LAST
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |last| {
            Some(now.max(last.saturating_add(1)))
        })
        .unwrap_or_else(|last| last);
    let id = now.max(prev.saturating_add(1));
    format!("study_{id:x}")
}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::{Scale, SearchDimension};
    use chrono::Utc;
    use serde_json::json;

    /// Two-dimensional fixture space: a log-scaled float `lr` in
    /// [0.001, 0.1] and a categorical `kernel` over {"rbf", "linear"}.
    fn sample_search_space() -> SearchSpace {
        let lr = SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        };
        let kernel = SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![json!("rbf"), json!("linear")],
        };
        let mut space = SearchSpace::new();
        space.add(lr);
        space.add(kernel);
        space
    }

    /// Metric record at step 10, timestamped now.
    fn metric(name: &'static str, value: f64) -> MetricRecord {
        MetricRecord {
            name: name.into(),
            value,
            step: 10,
            timestamp: Utc::now(),
        }
    }

    /// Marks `trial` as completed and attaches a single metric record to it.
    fn finished_trial(mut trial: Trial, metric_name: &'static str, value: f64) -> Trial {
        trial.state = TrialState::Completed;
        trial.metrics.push(metric(metric_name, value));
        trial
    }

    /// Completed trial with `lr = 0.01` that reported a single `f1` value.
    fn make_trial(id: &str, f1: f64) -> Trial {
        let params: HashMap<String, serde_json::Value> =
            HashMap::from([("lr".into(), json!(0.01))]);
        finished_trial(Trial::new(id, params), "f1", f1)
    }

    fn objective(metric: &str, direction: Direction) -> Objective {
        Objective {
            metric: metric.into(),
            direction,
        }
    }

    /// Study using unseeded random search with a budget of 10 trials.
    fn random_study(name: &str, space: SearchSpace, objectives: Vec<Objective>) -> Study {
        let strategy = SearchStrategy::Random {
            n_trials: 10,
            seed: None,
        };
        Study::new(name, space, strategy, objectives)
    }

    #[test]
    fn study_best_trial_maximize() {
        let mut study = random_study(
            "test",
            sample_search_space(),
            vec![objective("f1", Direction::Maximize)],
        );
        for (id, f1) in [("t1", 0.75), ("t2", 0.90), ("t3", 0.82)] {
            study.trials.push(make_trial(id, f1));
        }

        assert_eq!(study.best_trial().unwrap().id, "t2");
    }

    #[test]
    fn study_best_trial_minimize() {
        let mut study = random_study(
            "test",
            sample_search_space(),
            vec![objective("loss", Direction::Minimize)],
        );
        study
            .trials
            .push(finished_trial(Trial::new("t1", HashMap::new()), "loss", 0.5));
        study
            .trials
            .push(finished_trial(Trial::new("t2", HashMap::new()), "loss", 0.3));

        assert_eq!(study.best_trial().unwrap().id, "t2");
    }

    #[test]
    fn study_progress() {
        let mut study = random_study("test", sample_search_space(), Vec::new());
        assert_eq!(study.progress(), 0.0);

        study
            .trials
            .extend([make_trial("t1", 0.5), make_trial("t2", 0.6)]);
        assert!((study.progress() - 0.2).abs() < f64::EPSILON);
    }

    #[test]
    fn trial_terminal_states() {
        let mut trial = Trial::new("t1", HashMap::new());
        assert!(!trial.is_terminal());

        let expectations = [
            (TrialState::Running, false),
            (TrialState::Completed, true),
            (
                TrialState::Pruned {
                    step: 5,
                    reason: "bad".into(),
                },
                true,
            ),
            (
                TrialState::Failed {
                    error: "oops".into(),
                },
                true,
            ),
        ];
        for (state, terminal) in expectations {
            trial.state = state;
            assert_eq!(trial.is_terminal(), terminal);
        }
    }

    #[test]
    fn study_serde_roundtrip() {
        let strategy = SearchStrategy::Bayesian {
            n_trials: 100,
            n_startup: 10,
            seed: Some(42),
        };
        let mut study = Study::new(
            "test_study",
            sample_search_space(),
            strategy,
            vec![objective("f1", Direction::Maximize)],
        );
        study.trials.push(make_trial("t1", 0.85));

        let encoded = serde_json::to_string(&study).unwrap();
        let decoded: Study = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.name, "test_study");
        assert_eq!(decoded.trials.len(), 1);
    }

    #[test]
    fn search_strategy_n_trials() {
        let random = SearchStrategy::Random {
            n_trials: 50,
            seed: None,
        };
        let grid = SearchStrategy::Grid { points_per_dim: 5 };
        let bayesian = SearchStrategy::Bayesian {
            n_trials: 100,
            n_startup: 10,
            seed: None,
        };

        assert_eq!(random.n_trials(), Some(50));
        assert_eq!(grid.n_trials(), None);
        assert_eq!(bayesian.n_trials(), Some(100));
    }

    #[test]
    fn no_best_trial_when_empty() {
        let study = random_study(
            "empty",
            SearchSpace::new(),
            vec![objective("f1", Direction::Maximize)],
        );
        assert!(study.best_trial().is_none());
    }
}