1use crate::{OptimizerError, OptimizerResult};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::time::{Duration, Instant};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
/// Describes the hyperparameter search space: continuous ranges, integer
/// ranges, and categorical choices, built up via the chained `add_*` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSpace {
    /// Continuous parameters, keyed by name, with inclusive `(min, max)` bounds.
    pub continuous: HashMap<String, (f32, f32)>,
    /// Integer parameters, keyed by name, with inclusive `(min, max)` bounds.
    pub discrete: HashMap<String, (i32, i32)>,
    /// Categorical parameters, keyed by name, with the allowed choices.
    pub categorical: HashMap<String, Vec<String>>,
    /// Names of continuous parameters to be sampled on a logarithmic scale
    /// (their bounds are expected to be positive, since `ln()` is applied).
    pub log_scale: Vec<String>,
}
24
25impl HyperparameterSpace {
26 pub fn new() -> Self {
27 Self {
28 continuous: HashMap::new(),
29 discrete: HashMap::new(),
30 categorical: HashMap::new(),
31 log_scale: Vec::new(),
32 }
33 }
34
35 pub fn add_continuous(mut self, name: &str, min: f32, max: f32) -> Self {
36 self.continuous.insert(name.to_string(), (min, max));
37 self
38 }
39
40 pub fn add_discrete(mut self, name: &str, min: i32, max: i32) -> Self {
41 self.discrete.insert(name.to_string(), (min, max));
42 self
43 }
44
45 pub fn add_categorical(mut self, name: &str, choices: Vec<&str>) -> Self {
46 self.categorical.insert(
47 name.to_string(),
48 choices.iter().map(|s| s.to_string()).collect(),
49 );
50 self
51 }
52
53 pub fn set_log_scale(mut self, names: Vec<&str>) -> Self {
54 self.log_scale = names.iter().map(|s| s.to_string()).collect();
55 self
56 }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct HyperparameterConfig {
62 pub parameters: HashMap<String, HyperparameterValue>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub enum HyperparameterValue {
67 Float(f32),
68 Int(i32),
69 String(String),
70}
71
72impl HyperparameterConfig {
73 pub fn new() -> Self {
74 Self {
75 parameters: HashMap::new(),
76 }
77 }
78
79 pub fn set_float(&mut self, name: &str, value: f32) {
80 self.parameters
81 .insert(name.to_string(), HyperparameterValue::Float(value));
82 }
83
84 pub fn set_int(&mut self, name: &str, value: i32) {
85 self.parameters
86 .insert(name.to_string(), HyperparameterValue::Int(value));
87 }
88
89 pub fn set_string(&mut self, name: &str, value: &str) {
90 self.parameters.insert(
91 name.to_string(),
92 HyperparameterValue::String(value.to_string()),
93 );
94 }
95
96 pub fn get_float(&self, name: &str) -> OptimizerResult<f32> {
97 match self.parameters.get(name) {
98 Some(HyperparameterValue::Float(v)) => Ok(*v),
99 _ => Err(OptimizerError::InvalidParameter(format!(
100 "Parameter {} not found or not float",
101 name
102 ))),
103 }
104 }
105
106 pub fn get_int(&self, name: &str) -> OptimizerResult<i32> {
107 match self.parameters.get(name) {
108 Some(HyperparameterValue::Int(v)) => Ok(*v),
109 _ => Err(OptimizerError::InvalidParameter(format!(
110 "Parameter {} not found or not int",
111 name
112 ))),
113 }
114 }
115
116 pub fn get_string(&self, name: &str) -> OptimizerResult<&str> {
117 match self.parameters.get(name) {
118 Some(HyperparameterValue::String(v)) => Ok(v),
119 _ => Err(OptimizerError::InvalidParameter(format!(
120 "Parameter {} not found or not string",
121 name
122 ))),
123 }
124 }
125}
126
/// Search strategy used by `HyperparameterTuner::optimize`.
#[derive(Debug, Clone)]
pub enum SearchStrategy {
    /// Evaluate `n_trials` independently sampled configurations.
    Random { n_trials: usize },
    /// Grid over the continuous parameters with `n_points_per_dim` points per
    /// dimension; total trial count grows exponentially with dimensions.
    Grid { n_points_per_dim: usize },
    /// Budget of `n_trials` split between exploration and exploitation.
    Bayesian {
        n_trials: usize,
        /// NOTE(review): currently never consulted by the tuner (destructured
        /// with `..` in `optimize`) — confirm before relying on a choice here.
        acquisition_function: AcquisitionFunction,
    },
    /// Genetic search with tournament selection, crossover and mutation.
    Evolutionary {
        population_size: usize,
        n_generations: usize,
        mutation_rate: f32,
    },
}
146
/// Acquisition functions for Bayesian optimization.
///
/// NOTE(review): these variants are declared but not currently consulted
/// anywhere in this module — confirm before relying on a particular choice.
#[derive(Debug, Clone)]
pub enum AcquisitionFunction {
    ExpectedImprovement,
    ProbabilityOfImprovement,
    /// Upper confidence bound with exploration weight `kappa`.
    UpperConfidenceBound { kappa: f32 },
}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct Trial {
157 pub config: HyperparameterConfig,
158 pub objective_value: f32,
159 pub duration: Duration,
160 pub metadata: HashMap<String, String>,
161}
162
/// Full description of a tuning run.
#[derive(Debug, Clone)]
pub struct TuningConfig {
    /// How candidate configurations are proposed.
    pub search_strategy: SearchStrategy,
    /// Defines the direction of "better" when comparing trials.
    pub objective: ObjectiveFunction,
    /// The space configurations are drawn from.
    pub search_space: HyperparameterSpace,
    /// Optional wall-clock budget for the whole run.
    pub max_duration: Option<Duration>,
    /// Optional plateau-based early stopping.
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// NOTE(review): not read anywhere in this module — trials run
    /// sequentially. Confirm before relying on parallelism.
    pub parallel_trials: usize,
}
173
/// What the tuner optimizes, and in which direction.
#[derive(Debug, Clone)]
pub enum ObjectiveFunction {
    /// Lower objective values are better.
    MinimizeValidationLoss,
    /// Higher objective values are better.
    MaximizeValidationAccuracy,
    /// User-supplied scoring function; treated as lower-is-better when
    /// comparing trials (see `update_best_trial`).
    Custom(fn(&HyperparameterConfig) -> f32),
}
183
/// Plateau-based early stopping: the run halts when the objective spread over
/// the most recent `patience` trials falls below `min_delta`.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Number of most-recent trials inspected for a plateau.
    pub patience: usize,
    /// Minimum `(max - min)` objective spread required to keep searching.
    pub min_delta: f32,
}
189
/// Drives a hyperparameter search according to a `TuningConfig`.
pub struct HyperparameterTuner {
    /// Strategy, search space, objective and stopping rules for the run.
    config: TuningConfig,
    /// Successfully completed trials, in execution order.
    trials: Vec<Trial>,
    /// Best trial seen so far according to the configured objective.
    best_trial: Option<Trial>,
    /// Set when `optimize` starts; used to enforce the time budget.
    start_time: Option<Instant>,
}
197
198impl HyperparameterTuner {
199 pub fn new(config: TuningConfig) -> Self {
200 Self {
201 config,
202 trials: Vec::new(),
203 best_trial: None,
204 start_time: None,
205 }
206 }
207
208 pub fn optimize<F>(&mut self, objective_fn: F) -> OptimizerResult<HyperparameterConfig>
210 where
211 F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
212 {
213 self.start_time = Some(Instant::now());
214
215 match &self.config.search_strategy {
216 SearchStrategy::Random { n_trials } => self.random_search(*n_trials, objective_fn),
217 SearchStrategy::Grid { n_points_per_dim } => {
218 self.grid_search(*n_points_per_dim, objective_fn)
219 }
220 SearchStrategy::Bayesian { n_trials, .. } => {
221 self.bayesian_optimization(*n_trials, objective_fn)
222 }
223 SearchStrategy::Evolutionary {
224 population_size,
225 n_generations,
226 ..
227 } => self.evolutionary_search(*population_size, *n_generations, objective_fn),
228 }
229 }
230
231 fn random_search<F>(
232 &mut self,
233 n_trials: usize,
234 objective_fn: F,
235 ) -> OptimizerResult<HyperparameterConfig>
236 where
237 F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
238 {
239 for _ in 0..n_trials {
240 if self.should_stop() {
241 break;
242 }
243
244 let config = self.sample_random_config()?;
245 let trial_start = Instant::now();
246
247 match objective_fn(&config) {
248 Ok(objective_value) => {
249 let trial = Trial {
250 config: config.clone(),
251 objective_value,
252 duration: trial_start.elapsed(),
253 metadata: HashMap::new(),
254 };
255
256 self.update_best_trial(&trial);
257 self.trials.push(trial);
258 }
259 Err(e) => {
260 log::warn!("Trial failed: {:?}", e);
261 }
262 }
263 }
264
265 self.get_best_config()
266 }
267
268 fn grid_search<F>(
269 &mut self,
270 n_points_per_dim: usize,
271 objective_fn: F,
272 ) -> OptimizerResult<HyperparameterConfig>
273 where
274 F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
275 {
276 let grid_points = self.generate_grid_points(n_points_per_dim)?;
277
278 for config in grid_points {
279 if self.should_stop() {
280 break;
281 }
282
283 let trial_start = Instant::now();
284
285 match objective_fn(&config) {
286 Ok(objective_value) => {
287 let trial = Trial {
288 config: config.clone(),
289 objective_value,
290 duration: trial_start.elapsed(),
291 metadata: HashMap::new(),
292 };
293
294 self.update_best_trial(&trial);
295 self.trials.push(trial);
296 }
297 Err(e) => {
298 log::warn!("Trial failed: {:?}", e);
299 }
300 }
301 }
302
303 self.get_best_config()
304 }
305
306 fn bayesian_optimization<F>(
307 &mut self,
308 n_trials: usize,
309 objective_fn: F,
310 ) -> OptimizerResult<HyperparameterConfig>
311 where
312 F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
313 {
314 let n_random = (n_trials as f32 * 0.3) as usize;
319 for _ in 0..n_random {
320 if self.should_stop() {
321 break;
322 }
323
324 let config = self.sample_random_config()?;
325 let trial_start = Instant::now();
326
327 match objective_fn(&config) {
328 Ok(objective_value) => {
329 let trial = Trial {
330 config: config.clone(),
331 objective_value,
332 duration: trial_start.elapsed(),
333 metadata: HashMap::new(),
334 };
335
336 self.update_best_trial(&trial);
337 self.trials.push(trial);
338 }
339 Err(e) => {
340 log::warn!("Trial failed: {:?}", e);
341 }
342 }
343 }
344
345 for _ in n_random..n_trials {
347 if self.should_stop() {
348 break;
349 }
350
351 let config = if let Some(best) = &self.best_trial {
352 self.sample_around_config(&best.config, 0.1)?
353 } else {
354 self.sample_random_config()?
355 };
356
357 let trial_start = Instant::now();
358
359 match objective_fn(&config) {
360 Ok(objective_value) => {
361 let trial = Trial {
362 config: config.clone(),
363 objective_value,
364 duration: trial_start.elapsed(),
365 metadata: HashMap::new(),
366 };
367
368 self.update_best_trial(&trial);
369 self.trials.push(trial);
370 }
371 Err(e) => {
372 log::warn!("Trial failed: {:?}", e);
373 }
374 }
375 }
376
377 self.get_best_config()
378 }
379
380 fn evolutionary_search<F>(
381 &mut self,
382 population_size: usize,
383 n_generations: usize,
384 objective_fn: F,
385 ) -> OptimizerResult<HyperparameterConfig>
386 where
387 F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
388 {
389 let mut population = Vec::new();
391 for _ in 0..population_size {
392 let config = self.sample_random_config()?;
393 if let Ok(objective_value) = objective_fn(&config) {
394 population.push((config, objective_value));
395 }
396 }
397
398 for _generation in 0..n_generations {
400 if self.should_stop() {
401 break;
402 }
403
404 let mut new_population = Vec::new();
406 for _ in 0..population_size {
407 let parent1 = self.tournament_selection(&population, 3);
408 let parent2 = self.tournament_selection(&population, 3);
409
410 if let Ok(child) = self.crossover(&parent1.0, &parent2.0) {
411 let mutated = self.mutate(&child, 0.1)?;
412
413 if let Ok(objective_value) = objective_fn(&mutated) {
414 new_population.push((mutated, objective_value));
415 }
416 }
417 }
418
419 population = new_population;
420
421 if let Some((best_config, best_value)) = population
423 .iter()
424 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
425 {
426 let trial = Trial {
427 config: best_config.clone(),
428 objective_value: *best_value,
429 duration: Duration::from_secs(0), metadata: HashMap::new(),
431 };
432 self.update_best_trial(&trial);
433 }
434 }
435
436 self.get_best_config()
437 }
438
439 fn sample_random_config(&self) -> OptimizerResult<HyperparameterConfig> {
440 let mut config = HyperparameterConfig::new();
441
442 for (name, (min, max)) in &self.config.search_space.continuous {
444 let value = if self.config.search_space.log_scale.contains(name) {
445 let log_min = min.ln();
446 let log_max = max.ln();
447 let log_value = log_min + 0.5 * (log_max - log_min);
448 log_value.exp()
449 } else {
450 min + 0.5 * (max - min)
451 };
452 config.set_float(name, value);
453 }
454
455 for (name, (min, max)) in &self.config.search_space.discrete {
457 let value = (*min + *max) / 2;
458 config.set_int(name, value);
459 }
460
461 for (name, choices) in &self.config.search_space.categorical {
463 if !choices.is_empty() {
464 let idx = 0;
465 config.set_string(name, &choices[idx]);
466 }
467 }
468
469 Ok(config)
470 }
471
472 fn sample_around_config(
473 &self,
474 base_config: &HyperparameterConfig,
475 noise_scale: f32,
476 ) -> OptimizerResult<HyperparameterConfig> {
477 let mut config = base_config.clone();
478
479 for (name, (min, max)) in &self.config.search_space.continuous {
481 if let Ok(base_value) = base_config.get_float(name) {
482 let range = max - min;
483 let noise = 0.0;
484 let new_value = (base_value + noise).clamp(*min, *max);
485 config.set_float(name, new_value);
486 }
487 }
488
489 Ok(config)
490 }
491
492 fn generate_grid_points(
493 &self,
494 n_points_per_dim: usize,
495 ) -> OptimizerResult<Vec<HyperparameterConfig>> {
496 let mut configs = Vec::new();
497
498 if self.config.search_space.continuous.is_empty() {
500 return Ok(vec![self.sample_random_config()?]);
501 }
502
503 let param_names: Vec<_> = self
504 .config
505 .search_space
506 .continuous
507 .keys()
508 .cloned()
509 .collect();
510 let n_params = param_names.len();
511
512 let total_points = n_points_per_dim.pow(n_params as u32);
514
515 for i in 0..total_points {
516 let mut config = HyperparameterConfig::new();
517 let mut remaining = i;
518
519 for (_param_idx, param_name) in param_names.iter().enumerate() {
520 let (min, max) = self.config.search_space.continuous[param_name];
521 let grid_idx = remaining % n_points_per_dim;
522 remaining /= n_points_per_dim;
523
524 let value = if n_points_per_dim == 1 {
525 (min + max) / 2.0
526 } else {
527 min + (grid_idx as f32) * (max - min) / ((n_points_per_dim - 1) as f32)
528 };
529
530 config.set_float(param_name, value);
531 }
532
533 configs.push(config);
534 }
535
536 Ok(configs)
537 }
538
539 fn tournament_selection<'a>(
540 &self,
541 population: &'a [(HyperparameterConfig, f32)],
542 tournament_size: usize,
543 ) -> &'a (HyperparameterConfig, f32) {
544 let mut best = &population[0];
545
546 for _ in 1..tournament_size {
547 let candidate = &population[0];
548 if candidate.1 < best.1 {
549 best = candidate;
551 }
552 }
553
554 best
555 }
556
557 fn crossover(
558 &self,
559 parent1: &HyperparameterConfig,
560 parent2: &HyperparameterConfig,
561 ) -> OptimizerResult<HyperparameterConfig> {
562 let mut child = HyperparameterConfig::new();
563
564 for name in self.config.search_space.continuous.keys() {
566 if let (Ok(v1), Ok(v2)) = (parent1.get_float(name), parent2.get_float(name)) {
567 let alpha = 0.5;
568 let value = (1.0 - alpha) * v1 + alpha * v2;
569 child.set_float(name, value);
570 }
571 }
572
573 for name in self.config.search_space.discrete.keys() {
575 let value = if true {
576 parent1.get_int(name).unwrap_or(0)
577 } else {
578 parent2.get_int(name).unwrap_or(0)
579 };
580 child.set_int(name, value);
581 }
582
583 Ok(child)
584 }
585
586 fn mutate(
587 &self,
588 config: &HyperparameterConfig,
589 mutation_rate: f32,
590 ) -> OptimizerResult<HyperparameterConfig> {
591 let mut mutated = config.clone();
592
593 for (name, (min, max)) in &self.config.search_space.continuous {
594 if 0.1 < mutation_rate {
595 if let Ok(current_value) = config.get_float(name) {
596 let range = max - min;
597 let noise = 0.0;
598 let new_value = (current_value + noise).clamp(*min, *max);
599 mutated.set_float(name, new_value);
600 }
601 }
602 }
603
604 Ok(mutated)
605 }
606
607 fn should_stop(&self) -> bool {
608 if let (Some(start_time), Some(max_duration)) = (self.start_time, &self.config.max_duration)
609 {
610 if start_time.elapsed() > *max_duration {
611 return true;
612 }
613 }
614
615 if let Some(early_stopping) = &self.config.early_stopping {
617 if self.trials.len() >= early_stopping.patience {
618 let recent_trials = &self.trials[self.trials.len() - early_stopping.patience..];
619 let values: Vec<f32> = recent_trials.iter().map(|t| t.objective_value).collect();
620
621 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
622 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
623
624 if (max_val - min_val) < early_stopping.min_delta {
625 return true;
626 }
627 }
628 }
629
630 false
631 }
632
633 fn update_best_trial(&mut self, trial: &Trial) {
634 let is_better = match &self.best_trial {
635 None => true,
636 Some(best) => match &self.config.objective {
637 ObjectiveFunction::MinimizeValidationLoss => {
638 trial.objective_value < best.objective_value
639 }
640 ObjectiveFunction::MaximizeValidationAccuracy => {
641 trial.objective_value > best.objective_value
642 }
643 ObjectiveFunction::Custom(_) => trial.objective_value < best.objective_value, },
645 };
646
647 if is_better {
648 self.best_trial = Some(trial.clone());
649 }
650 }
651
652 fn get_best_config(&self) -> OptimizerResult<HyperparameterConfig> {
653 match &self.best_trial {
654 Some(trial) => Ok(trial.config.clone()),
655 None => Err(OptimizerError::StateError(
656 "No successful trials found".to_string(),
657 )),
658 }
659 }
660
661 pub fn get_results(&self) -> TuningResults {
663 TuningResults {
664 best_config: self.best_trial.as_ref().map(|t| t.config.clone()),
665 best_value: self.best_trial.as_ref().map(|t| t.objective_value),
666 trials: self.trials.clone(),
667 total_duration: self.start_time.map(|t| t.elapsed()),
668 }
669 }
670}
671
/// Summary of a finished tuning run, returned by
/// `HyperparameterTuner::get_results`.
#[derive(Debug, Clone)]
pub struct TuningResults {
    /// Best configuration found, if any trial succeeded.
    pub best_config: Option<HyperparameterConfig>,
    /// Objective value achieved by `best_config`.
    pub best_value: Option<f32>,
    /// All completed trials, in execution order.
    pub trials: Vec<Trial>,
    /// Wall-clock duration of the run, if it was started.
    pub total_duration: Option<Duration>,
}
680
681impl TuningResults {
682 pub fn convergence_history(&self) -> Vec<f32> {
684 let mut best_so_far = f32::INFINITY;
685 let mut history = Vec::new();
686
687 for trial in &self.trials {
688 if trial.objective_value < best_so_far {
689 best_so_far = trial.objective_value;
690 }
691 history.push(best_so_far);
692 }
693
694 history
695 }
696
697 pub fn parameter_importance(&self) -> HashMap<String, f32> {
699 let mut param_values: HashMap<String, Vec<f32>> = HashMap::new();
700
701 if self.trials.len() < 2 {
702 return HashMap::new();
703 }
704
705 for trial in &self.trials {
707 for (param_name, param_value) in &trial.config.parameters {
708 if let HyperparameterValue::Float(value) = param_value {
709 param_values
710 .entry(param_name.clone())
711 .or_insert(Vec::new())
712 .push(*value);
713 }
714 }
715 }
716
717 let mut variance_scores = HashMap::new();
719 for (param_name, values) in ¶m_values {
720 if values.len() > 1 {
721 let mean = values.iter().sum::<f32>() / values.len() as f32;
722 let variance =
723 values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
724 variance_scores.insert(param_name.clone(), variance);
725 }
726 }
727
728 let mut importance = HashMap::new();
730 let max_variance = variance_scores.values().fold(0.0f32, |a, &b| a.max(b));
731 if max_variance > 0.0 {
732 for (param_name, variance) in variance_scores {
733 importance.insert(param_name, variance / max_variance);
734 }
735 }
736
737 importance
738 }
739}
740
/// Ready-made search spaces for common optimizers.
pub mod presets {
    use super::*;

    /// Search space for Adam: learning rate and epsilon on a log scale, plus
    /// beta coefficients and weight decay.
    pub fn adam_space() -> HyperparameterSpace {
        HyperparameterSpace::new()
            .add_continuous("lr", 1e-5, 1e-1)
            .add_continuous("beta1", 0.8, 0.99)
            .add_continuous("beta2", 0.9, 0.999)
            .add_continuous("eps", 1e-10, 1e-6)
            .add_continuous("weight_decay", 0.0, 1e-2)
            .set_log_scale(vec!["lr", "eps"])
    }

    /// Search space for SGD: log-scaled learning rate, momentum, weight
    /// decay, and a categorical Nesterov toggle ("true"/"false").
    pub fn sgd_space() -> HyperparameterSpace {
        HyperparameterSpace::new()
            .add_continuous("lr", 1e-4, 1.0)
            .add_continuous("momentum", 0.0, 0.99)
            .add_continuous("weight_decay", 0.0, 1e-2)
            .add_categorical("nesterov", vec!["true", "false"])
            .set_log_scale(vec!["lr"])
    }

    /// Search space for RMSprop: learning rate and epsilon on a log scale,
    /// plus smoothing constant, weight decay, and momentum.
    pub fn rmsprop_space() -> HyperparameterSpace {
        HyperparameterSpace::new()
            .add_continuous("lr", 1e-5, 1e-1)
            .add_continuous("alpha", 0.9, 0.999)
            .add_continuous("eps", 1e-10, 1e-6)
            .add_continuous("weight_decay", 0.0, 1e-2)
            .add_continuous("momentum", 0.0, 0.1)
            .set_log_scale(vec!["lr", "eps"])
    }
}