1use crate::event_bus::EventBus;
8use crate::sampler::Sampler;
9use somatize_core::error::Result;
10use somatize_core::event::{Event, MetricRecord};
11use somatize_core::study::{Study, Trial, TrialState};
12use std::sync::Arc;
13use std::time::Instant;
14
15#[derive(Debug, Clone)]
17pub enum TrialOutcome {
18 Completed(Vec<MetricRecord>),
20 Pruned { step: usize, reason: String },
22}
23
24pub trait TrialExecutor: Send + Sync {
29 fn execute_trial(
30 &self,
31 params: &std::collections::HashMap<String, serde_json::Value>,
32 ) -> Result<TrialOutcome>;
33}
34
35pub struct FnTrialExecutor<F>(pub F);
37
38impl<F> TrialExecutor for FnTrialExecutor<F>
39where
40 F: Fn(&std::collections::HashMap<String, serde_json::Value>) -> Result<TrialOutcome>
41 + Send
42 + Sync,
43{
44 fn execute_trial(
45 &self,
46 params: &std::collections::HashMap<String, serde_json::Value>,
47 ) -> Result<TrialOutcome> {
48 (self.0)(params)
49 }
50}
51
52pub struct StudyRunner {
54 event_bus: Arc<EventBus>,
55}
56
57impl StudyRunner {
58 pub fn new(event_bus: Arc<EventBus>) -> Self {
59 Self { event_bus }
60 }
61
62 pub fn run(
64 &self,
65 study: &mut Study,
66 sampler: &mut dyn Sampler,
67 executor: &dyn TrialExecutor,
68 ) -> Result<()> {
69 let total = sampler.n_trials().unwrap_or(0);
70
71 self.event_bus.emit(Event::StudyStarted {
72 study_id: study.id.clone(),
73 name: study.name.clone(),
74 total_trials: total,
75 });
76
77 let mut trial_index = 0;
78
79 while let Some(params) = sampler.sample(&study.search_space, trial_index)? {
80 let trial_id = format!("trial_{trial_index:04}");
81 let mut trial = Trial::new(trial_id.clone(), params.clone());
82 trial.state = TrialState::Running;
83
84 self.event_bus.emit(Event::TrialStarted {
85 study_id: study.id.clone(),
86 trial_id: trial_id.clone(),
87 params: serde_json::json!(params),
88 });
89
90 let start = Instant::now();
91
92 match executor.execute_trial(¶ms) {
93 Ok(TrialOutcome::Completed(metrics)) => {
94 trial.duration_ms = Some(start.elapsed().as_millis() as u64);
95 trial.metrics = metrics.clone();
96 trial.state = TrialState::Completed;
97
98 for metric in &metrics {
99 self.event_bus.emit(Event::TrialMetric {
100 study_id: study.id.clone(),
101 trial_id: trial_id.clone(),
102 metric: metric.clone(),
103 });
104 }
105
106 self.event_bus.emit(Event::TrialCompleted {
107 study_id: study.id.clone(),
108 trial_id: trial_id.clone(),
109 final_metrics: metrics,
110 });
111 }
112 Ok(TrialOutcome::Pruned { step, reason }) => {
113 trial.duration_ms = Some(start.elapsed().as_millis() as u64);
114 trial.state = TrialState::Pruned {
115 step,
116 reason: reason.clone(),
117 };
118
119 self.event_bus.emit(Event::TrialPruned {
120 study_id: study.id.clone(),
121 trial_id: trial_id.clone(),
122 step,
123 reason,
124 });
125 }
126 Err(e) => {
127 trial.duration_ms = Some(start.elapsed().as_millis() as u64);
128 trial.state = TrialState::Failed {
129 error: e.to_string(),
130 };
131
132 self.event_bus.emit(Event::TrialFailed {
133 study_id: study.id.clone(),
134 trial_id: trial_id.clone(),
135 error: e.to_string(),
136 });
137 }
138 }
139
140 study.trials.push(trial);
141
142 if let Some(best) = study.best_trial()
144 && best.id == trial_id
145 && let Some(obj) = study.objectives.first()
146 && let Some(val) = best.best_metric(&obj.metric, obj.direction)
147 {
148 self.event_bus.emit(Event::BestUpdated {
149 study_id: study.id.clone(),
150 trial_id: trial_id.clone(),
151 value: val,
152 params: serde_json::json!(params),
153 });
154 }
155
156 let completed = study.trials.iter().filter(|t| t.is_terminal()).count();
157 self.event_bus.emit(Event::StudyProgress {
158 study_id: study.id.clone(),
159 completed,
160 total,
161 best_value: study
162 .best_trial()
163 .and_then(|t| {
164 study
165 .objectives
166 .first()
167 .and_then(|o| t.best_metric(&o.metric, o.direction))
168 })
169 .unwrap_or(f64::NAN),
170 });
171
172 trial_index += 1;
173 }
174
175 let best_trial_id = study.best_trial().map(|t| t.id.clone()).unwrap_or_default();
176 let best_value = study
177 .best_trial()
178 .and_then(|t| {
179 study
180 .objectives
181 .first()
182 .and_then(|o| t.best_metric(&o.metric, o.direction))
183 })
184 .unwrap_or(f64::NAN);
185
186 self.event_bus.emit(Event::StudyCompleted {
187 study_id: study.id.clone(),
188 best_trial_id,
189 best_value,
190 });
191
192 Ok(())
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::{GridSampler, RandomSampler};
    use chrono::Utc;
    use somatize_core::error::SomaError;
    use somatize_core::search::{Scale, SearchDimension, SearchSpace};
    use somatize_core::study::{Direction, Objective, SearchStrategy};

    /// Parameter map type handed to every `TrialExecutor`.
    type Params = std::collections::HashMap<String, serde_json::Value>;

    /// Two dimensions: log-scaled `lr` in [0.001, 0.1] and categorical
    /// `activation` with choices "relu" and "tanh".
    fn sample_space() -> SearchSpace {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        });
        space.add(SearchDimension::Categorical {
            name: "activation".into(),
            choices: vec![serde_json::json!("relu"), serde_json::json!("tanh")],
        });
        space
    }

    /// One linear float dimension `x` in [0.0, 1.0].
    fn unit_x_space() -> SearchSpace {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "x".into(),
            low: 0.0,
            high: 1.0,
            scale: Scale::Linear,
            default: None,
        });
        space
    }

    /// The single objective used by every test: maximize `f1`.
    fn maximize_f1() -> Vec<Objective> {
        vec![Objective {
            metric: "f1".into(),
            direction: Direction::Maximize,
        }]
    }

    /// A step-0 `f1` metric record stamped with the current time.
    fn f1_record(value: f64) -> MetricRecord {
        MetricRecord {
            name: "f1".into(),
            value,
            step: 0,
            timestamp: Utc::now(),
        }
    }

    /// Executor whose `f1` is 1.0 at `lr == 0.01` and drops linearly with
    /// distance from it, floored at 0.0.
    fn make_executor() -> FnTrialExecutor<impl Fn(&Params) -> Result<TrialOutcome>> {
        FnTrialExecutor(|params: &Params| {
            let lr = params["lr"].as_f64().unwrap();
            let f1 = (1.0 - (lr - 0.01).abs() * 10.0).max(0.0);
            Ok(TrialOutcome::Completed(vec![f1_record(f1)]))
        })
    }

    #[test]
    fn study_runner_grid_search() {
        let bus = Arc::new(EventBus::new(256));
        let mut rx = bus.subscribe();
        let runner = StudyRunner::new(bus);

        let mut study = Study::new(
            "grid_test",
            sample_space(),
            SearchStrategy::Grid { points_per_dim: 3 },
            maximize_f1(),
        );

        let mut sampler = GridSampler::new(3);
        let executor = make_executor();

        runner.run(&mut study, &mut sampler, &executor).unwrap();

        // 3 grid points for `lr` x 2 `activation` choices.
        assert_eq!(study.trials.len(), 6);
        assert!(study.trials.iter().all(|t| t.is_complete()));

        let best = study.best_trial().unwrap();
        let best_lr = best.params["lr"].as_f64().unwrap();
        assert!(
            (best_lr - 0.01).abs() < 0.05,
            "best lr should be near 0.01, got {best_lr}"
        );

        let mut events = Vec::new();
        while let Ok(e) = rx.try_recv() {
            events.push(e);
        }
        assert!(events.iter().any(|e| matches!(e, Event::StudyStarted { .. })));
        assert!(events.iter().any(|e| matches!(e, Event::TrialStarted { .. })));
        assert!(events.iter().any(|e| matches!(e, Event::TrialCompleted { .. })));
        assert!(events.iter().any(|e| matches!(e, Event::BestUpdated { .. })));
        assert!(events.iter().any(|e| matches!(e, Event::StudyCompleted { .. })));
    }

    #[test]
    fn study_runner_random_search() {
        let bus = Arc::new(EventBus::new(256));
        let runner = StudyRunner::new(bus);

        let mut study = Study::new(
            "random_test",
            sample_space(),
            SearchStrategy::Random {
                n_trials: 20,
                seed: Some(42),
            },
            maximize_f1(),
        );

        let mut sampler = RandomSampler::new(20, Some(42));
        let executor = make_executor();

        runner.run(&mut study, &mut sampler, &executor).unwrap();

        assert_eq!(study.trials.len(), 20);
        assert!(study.best_trial().is_some());
    }

    #[test]
    fn study_runner_handles_failed_trials() {
        let bus = Arc::new(EventBus::new(256));
        let runner = StudyRunner::new(bus);

        let mut study = Study::new(
            "fail_test",
            unit_x_space(),
            SearchStrategy::Random {
                n_trials: 5,
                seed: None,
            },
            maximize_f1(),
        );

        // Samples with x > 0.5 fail; the rest complete with f1 = x.
        let executor = FnTrialExecutor(|params: &Params| {
            let x = params["x"].as_f64().unwrap();
            if x > 0.5 {
                Err(SomaError::Other("too high".into()))
            } else {
                Ok(TrialOutcome::Completed(vec![f1_record(x)]))
            }
        });

        let mut sampler = RandomSampler::new(5, Some(42));
        runner.run(&mut study, &mut sampler, &executor).unwrap();

        assert_eq!(study.trials.len(), 5);
        let failed = study
            .trials
            .iter()
            .filter(|t| matches!(t.state, TrialState::Failed { .. }))
            .count();
        assert!(failed > 0, "should have some failed trials");
    }

    #[test]
    fn study_runner_handles_pruned_trials() {
        let bus = Arc::new(EventBus::new(256));
        let runner = StudyRunner::new(bus);

        let mut study = Study::new(
            "prune_test",
            unit_x_space(),
            SearchStrategy::Random {
                n_trials: 3,
                seed: None,
            },
            maximize_f1(),
        );

        let executor = FnTrialExecutor(|_params: &Params| {
            Ok(TrialOutcome::Pruned {
                step: 5,
                reason: "below median".into(),
            })
        });

        let mut sampler = RandomSampler::new(3, Some(42));
        runner.run(&mut study, &mut sampler, &executor).unwrap();

        // Without this, `.all(..)` below would pass vacuously on zero trials.
        assert_eq!(study.trials.len(), 3);
        assert!(
            study
                .trials
                .iter()
                .all(|t| matches!(t.state, TrialState::Pruned { .. }))
        );
    }

    #[test]
    fn study_progress_tracking() {
        let bus = Arc::new(EventBus::new(256));
        let mut rx = bus.subscribe();
        let runner = StudyRunner::new(bus);

        let mut study = Study::new(
            "progress_test",
            unit_x_space(),
            SearchStrategy::Random {
                n_trials: 3,
                seed: None,
            },
            maximize_f1(),
        );

        let executor = FnTrialExecutor(|_params: &Params| {
            Ok(TrialOutcome::Completed(vec![f1_record(0.5)]))
        });

        let mut sampler = RandomSampler::new(3, Some(42));
        runner.run(&mut study, &mut sampler, &executor).unwrap();

        let mut progress_events = Vec::new();
        while let Ok(e) = rx.try_recv() {
            if let Event::StudyProgress {
                completed, total, ..
            } = e
            {
                progress_events.push((completed, total));
            }
        }

        assert_eq!(progress_events, vec![(1, 3), (2, 3), (3, 3)]);
    }
}