Skip to main content

torsh_optim/
hyperparameter_tuning.rs

1//! Automatic hyperparameter tuning for optimizers
2//!
3//! This module provides tools for automatically tuning optimizer hyperparameters
4//! using various search strategies including grid search, random search, Bayesian optimization,
5//! and evolutionary algorithms.
6
use crate::{OptimizerError, OptimizerResult};
use serde::{Deserialize, Serialize};
use std::collections::hash_map::RandomState;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hasher};
use std::time::{Duration, Instant};
11
/// Hyperparameter search space definition
///
/// Declares which parameters a tuner may vary and the values each may take.
/// Built fluently with the chainable `add_*` / `set_log_scale` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSpace {
    /// Continuous parameters with [min, max] bounds
    pub continuous: HashMap<String, (f32, f32)>,
    /// Discrete integer parameters with [min, max] bounds
    pub discrete: HashMap<String, (i32, i32)>,
    /// Categorical parameters with list of choices
    pub categorical: HashMap<String, Vec<String>>,
    /// Log-scale parameters (will be sampled in log space)
    ///
    /// Names listed here must also appear in `continuous`; other entries
    /// have no effect.
    pub log_scale: Vec<String>,
}
24
25impl HyperparameterSpace {
26    pub fn new() -> Self {
27        Self {
28            continuous: HashMap::new(),
29            discrete: HashMap::new(),
30            categorical: HashMap::new(),
31            log_scale: Vec::new(),
32        }
33    }
34
35    pub fn add_continuous(mut self, name: &str, min: f32, max: f32) -> Self {
36        self.continuous.insert(name.to_string(), (min, max));
37        self
38    }
39
40    pub fn add_discrete(mut self, name: &str, min: i32, max: i32) -> Self {
41        self.discrete.insert(name.to_string(), (min, max));
42        self
43    }
44
45    pub fn add_categorical(mut self, name: &str, choices: Vec<&str>) -> Self {
46        self.categorical.insert(
47            name.to_string(),
48            choices.iter().map(|s| s.to_string()).collect(),
49        );
50        self
51    }
52
53    pub fn set_log_scale(mut self, names: Vec<&str>) -> Self {
54        self.log_scale = names.iter().map(|s| s.to_string()).collect();
55        self
56    }
57}
58
/// Hyperparameter configuration
///
/// A concrete assignment of values to parameter names — one candidate point
/// inside a `HyperparameterSpace`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterConfig {
    /// Parameter name -> concrete value chosen for this configuration.
    pub parameters: HashMap<String, HyperparameterValue>,
}
64
/// A single hyperparameter value: continuous, discrete, or categorical.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HyperparameterValue {
    /// Continuous value.
    Float(f32),
    /// Discrete integer value.
    Int(i32),
    /// Categorical choice.
    String(String),
}
71
72impl HyperparameterConfig {
73    pub fn new() -> Self {
74        Self {
75            parameters: HashMap::new(),
76        }
77    }
78
79    pub fn set_float(&mut self, name: &str, value: f32) {
80        self.parameters
81            .insert(name.to_string(), HyperparameterValue::Float(value));
82    }
83
84    pub fn set_int(&mut self, name: &str, value: i32) {
85        self.parameters
86            .insert(name.to_string(), HyperparameterValue::Int(value));
87    }
88
89    pub fn set_string(&mut self, name: &str, value: &str) {
90        self.parameters.insert(
91            name.to_string(),
92            HyperparameterValue::String(value.to_string()),
93        );
94    }
95
96    pub fn get_float(&self, name: &str) -> OptimizerResult<f32> {
97        match self.parameters.get(name) {
98            Some(HyperparameterValue::Float(v)) => Ok(*v),
99            _ => Err(OptimizerError::InvalidParameter(format!(
100                "Parameter {} not found or not float",
101                name
102            ))),
103        }
104    }
105
106    pub fn get_int(&self, name: &str) -> OptimizerResult<i32> {
107        match self.parameters.get(name) {
108            Some(HyperparameterValue::Int(v)) => Ok(*v),
109            _ => Err(OptimizerError::InvalidParameter(format!(
110                "Parameter {} not found or not int",
111                name
112            ))),
113        }
114    }
115
116    pub fn get_string(&self, name: &str) -> OptimizerResult<&str> {
117        match self.parameters.get(name) {
118            Some(HyperparameterValue::String(v)) => Ok(v),
119            _ => Err(OptimizerError::InvalidParameter(format!(
120                "Parameter {} not found or not string",
121                name
122            ))),
123        }
124    }
125}
126
/// Search strategy for hyperparameter optimization
///
/// Chosen in `TuningConfig` and dispatched by `HyperparameterTuner::optimize`.
#[derive(Debug, Clone)]
pub enum SearchStrategy {
    /// Random search with specified number of trials
    Random { n_trials: usize },
    /// Grid search with specified resolution per dimension
    Grid { n_points_per_dim: usize },
    /// Bayesian optimization using Gaussian processes
    ///
    /// NOTE(review): the current implementation approximates this with random
    /// exploration followed by local search around the incumbent;
    /// `acquisition_function` is not yet consulted (it is matched with `..`
    /// in `optimize`).
    Bayesian {
        n_trials: usize,
        acquisition_function: AcquisitionFunction,
    },
    /// Evolutionary algorithm
    Evolutionary {
        /// Number of individuals kept per generation.
        population_size: usize,
        /// Number of selection/crossover/mutation rounds.
        n_generations: usize,
        /// Per-gene mutation probability.
        mutation_rate: f32,
    },
}
146
/// Acquisition function for Bayesian optimization, trading off exploration
/// against exploitation.
#[derive(Debug, Clone)]
pub enum AcquisitionFunction {
    /// Expected improvement over the current best observation.
    ExpectedImprovement,
    /// Probability of improving on the current best observation.
    ProbabilityOfImprovement,
    /// Upper confidence bound; `kappa` weights the exploration term.
    UpperConfidenceBound { kappa: f32 },
}
153
/// Trial result from hyperparameter optimization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trial {
    /// The configuration that was evaluated.
    pub config: HyperparameterConfig,
    /// Value the objective function returned for `config`.
    pub objective_value: f32,
    /// Wall-clock time the evaluation took.
    pub duration: Duration,
    /// Free-form extra information attached to the trial.
    pub metadata: HashMap<String, String>,
}
162
/// Configuration for hyperparameter tuning
#[derive(Debug, Clone)]
pub struct TuningConfig {
    /// How candidate configurations are generated.
    pub search_strategy: SearchStrategy,
    /// What quantity is optimized, and in which direction.
    pub objective: ObjectiveFunction,
    /// The parameters to tune and their allowed values.
    pub search_space: HyperparameterSpace,
    /// Optional wall-clock budget for the whole search.
    pub max_duration: Option<Duration>,
    /// Optional convergence-based stopping rule.
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// Desired number of trials to evaluate in parallel.
    /// NOTE(review): not consulted by any of the current search routines —
    /// all strategies run trials sequentially.
    pub parallel_trials: usize,
}
173
/// Objective used to compare trials (see `update_best_trial`): lower is
/// better for `MinimizeValidationLoss` and `Custom`, higher is better for
/// `MaximizeValidationAccuracy`.
#[derive(Debug, Clone)]
pub enum ObjectiveFunction {
    /// Minimize validation loss
    MinimizeValidationLoss,
    /// Maximize validation accuracy
    MaximizeValidationAccuracy,
    /// Custom objective function
    ///
    /// Treated as a minimization target when comparing trials.
    Custom(fn(&HyperparameterConfig) -> f32),
}
183
/// Convergence-based early-stopping rule
///
/// The search stops once the objective spread (max - min) over the last
/// `patience` recorded trials falls below `min_delta` (see `should_stop`).
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Number of most-recent trials inspected for convergence.
    pub patience: usize,
    /// Objective spread below which the search is considered converged.
    pub min_delta: f32,
}
189
/// Hyperparameter tuner
///
/// Drives a search strategy over a search space, recording trials and
/// tracking the best one seen so far under the configured objective.
pub struct HyperparameterTuner {
    /// Strategy, objective, search space, and stopping rules.
    config: TuningConfig,
    /// Trials recorded during the search, in execution order.
    trials: Vec<Trial>,
    /// Best trial so far; `None` until at least one trial succeeds.
    best_trial: Option<Trial>,
    /// Set when `optimize` starts; drives the wall-clock budget check.
    start_time: Option<Instant>,
}
197
198impl HyperparameterTuner {
199    pub fn new(config: TuningConfig) -> Self {
200        Self {
201            config,
202            trials: Vec::new(),
203            best_trial: None,
204            start_time: None,
205        }
206    }
207
    /// Run hyperparameter optimization
    ///
    /// Records the start time (used by the duration budget in `should_stop`)
    /// and dispatches to the search routine matching the configured strategy.
    /// `objective_fn` evaluates a single configuration and returns its
    /// objective value; how evaluation errors are handled is up to each
    /// strategy (typically the trial is logged and skipped).
    ///
    /// Returns the best configuration found, or an error if no trial
    /// succeeded.
    pub fn optimize<F>(&mut self, objective_fn: F) -> OptimizerResult<HyperparameterConfig>
    where
        F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
    {
        self.start_time = Some(Instant::now());

        match &self.config.search_strategy {
            SearchStrategy::Random { n_trials } => self.random_search(*n_trials, objective_fn),
            SearchStrategy::Grid { n_points_per_dim } => {
                self.grid_search(*n_points_per_dim, objective_fn)
            }
            // The acquisition function is deliberately not forwarded here
            // (note the `..`); the current implementation does not use it.
            SearchStrategy::Bayesian { n_trials, .. } => {
                self.bayesian_optimization(*n_trials, objective_fn)
            }
            SearchStrategy::Evolutionary {
                population_size,
                n_generations,
                ..
            } => self.evolutionary_search(*population_size, *n_generations, objective_fn),
        }
    }
230
    /// Random search: evaluate up to `n_trials` independently sampled
    /// configurations and return the best one found.
    ///
    /// Stops early when `should_stop` fires. Failed evaluations are logged
    /// and skipped; only successful evaluations are recorded as trials.
    fn random_search<F>(
        &mut self,
        n_trials: usize,
        objective_fn: F,
    ) -> OptimizerResult<HyperparameterConfig>
    where
        F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
    {
        for _ in 0..n_trials {
            // Respect the wall-clock budget / convergence-based early stop.
            if self.should_stop() {
                break;
            }

            let config = self.sample_random_config()?;
            let trial_start = Instant::now();

            match objective_fn(&config) {
                Ok(objective_value) => {
                    let trial = Trial {
                        config: config.clone(),
                        objective_value,
                        duration: trial_start.elapsed(),
                        metadata: HashMap::new(),
                    };

                    self.update_best_trial(&trial);
                    self.trials.push(trial);
                }
                Err(e) => {
                    // Best-effort: a failing configuration is skipped, not fatal.
                    log::warn!("Trial failed: {:?}", e);
                }
            }
        }

        self.get_best_config()
    }
267
    /// Grid search: evaluate every point of the regular grid produced by
    /// `generate_grid_points` and return the best configuration.
    ///
    /// Stops early when `should_stop` fires. Failed evaluations are logged
    /// and skipped; only successful evaluations are recorded as trials.
    fn grid_search<F>(
        &mut self,
        n_points_per_dim: usize,
        objective_fn: F,
    ) -> OptimizerResult<HyperparameterConfig>
    where
        F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
    {
        // The full grid is materialized up front.
        let grid_points = self.generate_grid_points(n_points_per_dim)?;

        for config in grid_points {
            if self.should_stop() {
                break;
            }

            let trial_start = Instant::now();

            match objective_fn(&config) {
                Ok(objective_value) => {
                    let trial = Trial {
                        config: config.clone(),
                        objective_value,
                        duration: trial_start.elapsed(),
                        metadata: HashMap::new(),
                    };

                    self.update_best_trial(&trial);
                    self.trials.push(trial);
                }
                Err(e) => {
                    // Skip failing configurations; the search continues.
                    log::warn!("Trial failed: {:?}", e);
                }
            }
        }

        self.get_best_config()
    }
305
    /// Simplified "Bayesian" optimization.
    ///
    /// A true implementation would fit a Gaussian-process surrogate and
    /// maximize an acquisition function; this placeholder instead spends the
    /// first ~30% of the trial budget on random exploration and the rest
    /// calling `sample_around_config` with `noise_scale = 0.1` on the best
    /// configuration found so far (falling back to random sampling until a
    /// trial has succeeded).
    fn bayesian_optimization<F>(
        &mut self,
        n_trials: usize,
        objective_fn: F,
    ) -> OptimizerResult<HyperparameterConfig>
    where
        F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
    {
        // Simplified Bayesian optimization (in practice, would use GP library)
        // For now, implement as random search with some exploitation

        // Start with random exploration
        let n_random = (n_trials as f32 * 0.3) as usize;
        for _ in 0..n_random {
            if self.should_stop() {
                break;
            }

            let config = self.sample_random_config()?;
            let trial_start = Instant::now();

            match objective_fn(&config) {
                Ok(objective_value) => {
                    let trial = Trial {
                        config: config.clone(),
                        objective_value,
                        duration: trial_start.elapsed(),
                        metadata: HashMap::new(),
                    };

                    self.update_best_trial(&trial);
                    self.trials.push(trial);
                }
                Err(e) => {
                    // Failed evaluations are skipped rather than aborting the search.
                    log::warn!("Trial failed: {:?}", e);
                }
            }
        }

        // Then exploit around best configurations
        for _ in n_random..n_trials {
            if self.should_stop() {
                break;
            }

            // Perturb the incumbent if one exists; otherwise keep exploring.
            let config = if let Some(best) = &self.best_trial {
                self.sample_around_config(&best.config, 0.1)?
            } else {
                self.sample_random_config()?
            };

            let trial_start = Instant::now();

            match objective_fn(&config) {
                Ok(objective_value) => {
                    let trial = Trial {
                        config: config.clone(),
                        objective_value,
                        duration: trial_start.elapsed(),
                        metadata: HashMap::new(),
                    };

                    self.update_best_trial(&trial);
                    self.trials.push(trial);
                }
                Err(e) => {
                    log::warn!("Trial failed: {:?}", e);
                }
            }
        }

        self.get_best_config()
    }
379
380    fn evolutionary_search<F>(
381        &mut self,
382        population_size: usize,
383        n_generations: usize,
384        objective_fn: F,
385    ) -> OptimizerResult<HyperparameterConfig>
386    where
387        F: Fn(&HyperparameterConfig) -> OptimizerResult<f32>,
388    {
389        // Initialize population
390        let mut population = Vec::new();
391        for _ in 0..population_size {
392            let config = self.sample_random_config()?;
393            if let Ok(objective_value) = objective_fn(&config) {
394                population.push((config, objective_value));
395            }
396        }
397
398        // Evolution loop
399        for _generation in 0..n_generations {
400            if self.should_stop() {
401                break;
402            }
403
404            // Selection (tournament selection)
405            let mut new_population = Vec::new();
406            for _ in 0..population_size {
407                let parent1 = self.tournament_selection(&population, 3);
408                let parent2 = self.tournament_selection(&population, 3);
409
410                if let Ok(child) = self.crossover(&parent1.0, &parent2.0) {
411                    let mutated = self.mutate(&child, 0.1)?;
412
413                    if let Ok(objective_value) = objective_fn(&mutated) {
414                        new_population.push((mutated, objective_value));
415                    }
416                }
417            }
418
419            population = new_population;
420
421            // Update best trial
422            if let Some((best_config, best_value)) = population
423                .iter()
424                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
425            {
426                let trial = Trial {
427                    config: best_config.clone(),
428                    objective_value: *best_value,
429                    duration: Duration::from_secs(0), // Approximate
430                    metadata: HashMap::new(),
431                };
432                self.update_best_trial(&trial);
433            }
434        }
435
436        self.get_best_config()
437    }
438
439    fn sample_random_config(&self) -> OptimizerResult<HyperparameterConfig> {
440        let mut config = HyperparameterConfig::new();
441
442        // Sample continuous parameters
443        for (name, (min, max)) in &self.config.search_space.continuous {
444            let value = if self.config.search_space.log_scale.contains(name) {
445                let log_min = min.ln();
446                let log_max = max.ln();
447                let log_value = log_min + 0.5 * (log_max - log_min);
448                log_value.exp()
449            } else {
450                min + 0.5 * (max - min)
451            };
452            config.set_float(name, value);
453        }
454
455        // Sample discrete parameters
456        for (name, (min, max)) in &self.config.search_space.discrete {
457            let value = (*min + *max) / 2;
458            config.set_int(name, value);
459        }
460
461        // Sample categorical parameters
462        for (name, choices) in &self.config.search_space.categorical {
463            if !choices.is_empty() {
464                let idx = 0;
465                config.set_string(name, &choices[idx]);
466            }
467        }
468
469        Ok(config)
470    }
471
472    fn sample_around_config(
473        &self,
474        base_config: &HyperparameterConfig,
475        noise_scale: f32,
476    ) -> OptimizerResult<HyperparameterConfig> {
477        let mut config = base_config.clone();
478
479        // Add noise to continuous parameters
480        for (name, (min, max)) in &self.config.search_space.continuous {
481            if let Ok(base_value) = base_config.get_float(name) {
482                let range = max - min;
483                let noise = 0.0;
484                let new_value = (base_value + noise).clamp(*min, *max);
485                config.set_float(name, new_value);
486            }
487        }
488
489        Ok(config)
490    }
491
492    fn generate_grid_points(
493        &self,
494        n_points_per_dim: usize,
495    ) -> OptimizerResult<Vec<HyperparameterConfig>> {
496        let mut configs = Vec::new();
497
498        // For simplicity, only handle continuous parameters in grid search
499        if self.config.search_space.continuous.is_empty() {
500            return Ok(vec![self.sample_random_config()?]);
501        }
502
503        let param_names: Vec<_> = self
504            .config
505            .search_space
506            .continuous
507            .keys()
508            .cloned()
509            .collect();
510        let n_params = param_names.len();
511
512        // Generate all combinations
513        let total_points = n_points_per_dim.pow(n_params as u32);
514
515        for i in 0..total_points {
516            let mut config = HyperparameterConfig::new();
517            let mut remaining = i;
518
519            for (_param_idx, param_name) in param_names.iter().enumerate() {
520                let (min, max) = self.config.search_space.continuous[param_name];
521                let grid_idx = remaining % n_points_per_dim;
522                remaining /= n_points_per_dim;
523
524                let value = if n_points_per_dim == 1 {
525                    (min + max) / 2.0
526                } else {
527                    min + (grid_idx as f32) * (max - min) / ((n_points_per_dim - 1) as f32)
528                };
529
530                config.set_float(param_name, value);
531            }
532
533            configs.push(config);
534        }
535
536        Ok(configs)
537    }
538
539    fn tournament_selection<'a>(
540        &self,
541        population: &'a [(HyperparameterConfig, f32)],
542        tournament_size: usize,
543    ) -> &'a (HyperparameterConfig, f32) {
544        let mut best = &population[0];
545
546        for _ in 1..tournament_size {
547            let candidate = &population[0];
548            if candidate.1 < best.1 {
549                // Assuming minimization
550                best = candidate;
551            }
552        }
553
554        best
555    }
556
557    fn crossover(
558        &self,
559        parent1: &HyperparameterConfig,
560        parent2: &HyperparameterConfig,
561    ) -> OptimizerResult<HyperparameterConfig> {
562        let mut child = HyperparameterConfig::new();
563
564        // Blend crossover for continuous parameters
565        for name in self.config.search_space.continuous.keys() {
566            if let (Ok(v1), Ok(v2)) = (parent1.get_float(name), parent2.get_float(name)) {
567                let alpha = 0.5;
568                let value = (1.0 - alpha) * v1 + alpha * v2;
569                child.set_float(name, value);
570            }
571        }
572
573        // Random selection for categorical/discrete
574        for name in self.config.search_space.discrete.keys() {
575            let value = if true {
576                parent1.get_int(name).unwrap_or(0)
577            } else {
578                parent2.get_int(name).unwrap_or(0)
579            };
580            child.set_int(name, value);
581        }
582
583        Ok(child)
584    }
585
586    fn mutate(
587        &self,
588        config: &HyperparameterConfig,
589        mutation_rate: f32,
590    ) -> OptimizerResult<HyperparameterConfig> {
591        let mut mutated = config.clone();
592
593        for (name, (min, max)) in &self.config.search_space.continuous {
594            if 0.1 < mutation_rate {
595                if let Ok(current_value) = config.get_float(name) {
596                    let range = max - min;
597                    let noise = 0.0;
598                    let new_value = (current_value + noise).clamp(*min, *max);
599                    mutated.set_float(name, new_value);
600                }
601            }
602        }
603
604        Ok(mutated)
605    }
606
    /// Whether the search should terminate early.
    ///
    /// Returns `true` when either:
    /// - a wall-clock budget is configured and has been exceeded since
    ///   `optimize` recorded `start_time`, or
    /// - early stopping is configured, at least `patience` trials have been
    ///   recorded, and the objective spread (max - min) over the last
    ///   `patience` trials is below `min_delta`, i.e. the search plateaued.
    fn should_stop(&self) -> bool {
        if let (Some(start_time), Some(max_duration)) = (self.start_time, &self.config.max_duration)
        {
            if start_time.elapsed() > *max_duration {
                return true;
            }
        }

        // Early stopping based on convergence
        if let Some(early_stopping) = &self.config.early_stopping {
            if self.trials.len() >= early_stopping.patience {
                let recent_trials = &self.trials[self.trials.len() - early_stopping.patience..];
                let values: Vec<f32> = recent_trials.iter().map(|t| t.objective_value).collect();

                // Fold with f32::min/max because f32 is not Ord.
                let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
                let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

                if (max_val - min_val) < early_stopping.min_delta {
                    return true;
                }
            }
        }

        false
    }
632
633    fn update_best_trial(&mut self, trial: &Trial) {
634        let is_better = match &self.best_trial {
635            None => true,
636            Some(best) => match &self.config.objective {
637                ObjectiveFunction::MinimizeValidationLoss => {
638                    trial.objective_value < best.objective_value
639                }
640                ObjectiveFunction::MaximizeValidationAccuracy => {
641                    trial.objective_value > best.objective_value
642                }
643                ObjectiveFunction::Custom(_) => trial.objective_value < best.objective_value, // Assume minimization
644            },
645        };
646
647        if is_better {
648            self.best_trial = Some(trial.clone());
649        }
650    }
651
652    fn get_best_config(&self) -> OptimizerResult<HyperparameterConfig> {
653        match &self.best_trial {
654            Some(trial) => Ok(trial.config.clone()),
655            None => Err(OptimizerError::StateError(
656                "No successful trials found".to_string(),
657            )),
658        }
659    }
660
    /// Get optimization results and statistics
    ///
    /// Snapshots the tuner state: best configuration and value (`None` if no
    /// trial succeeded), a copy of all recorded trials, and the wall-clock
    /// time elapsed since `optimize` started (`None` if it never ran).
    pub fn get_results(&self) -> TuningResults {
        TuningResults {
            best_config: self.best_trial.as_ref().map(|t| t.config.clone()),
            best_value: self.best_trial.as_ref().map(|t| t.objective_value),
            trials: self.trials.clone(),
            total_duration: self.start_time.map(|t| t.elapsed()),
        }
    }
670}
671
/// Results from hyperparameter tuning
#[derive(Debug, Clone)]
pub struct TuningResults {
    /// Best configuration found, if any trial succeeded.
    pub best_config: Option<HyperparameterConfig>,
    /// Objective value of the best trial, if any trial succeeded.
    pub best_value: Option<f32>,
    /// All recorded trials, in execution order.
    pub trials: Vec<Trial>,
    /// Wall-clock duration of the whole search, if it was started.
    pub total_duration: Option<Duration>,
}
680
681impl TuningResults {
682    /// Get convergence history
683    pub fn convergence_history(&self) -> Vec<f32> {
684        let mut best_so_far = f32::INFINITY;
685        let mut history = Vec::new();
686
687        for trial in &self.trials {
688            if trial.objective_value < best_so_far {
689                best_so_far = trial.objective_value;
690            }
691            history.push(best_so_far);
692        }
693
694        history
695    }
696
697    /// Get parameter importance analysis
698    pub fn parameter_importance(&self) -> HashMap<String, f32> {
699        let mut param_values: HashMap<String, Vec<f32>> = HashMap::new();
700
701        if self.trials.len() < 2 {
702            return HashMap::new();
703        }
704
705        // Simple variance-based importance
706        for trial in &self.trials {
707            for (param_name, param_value) in &trial.config.parameters {
708                if let HyperparameterValue::Float(value) = param_value {
709                    param_values
710                        .entry(param_name.clone())
711                        .or_insert(Vec::new())
712                        .push(*value);
713                }
714            }
715        }
716
717        // Calculate normalized variance for each parameter
718        let mut variance_scores = HashMap::new();
719        for (param_name, values) in &param_values {
720            if values.len() > 1 {
721                let mean = values.iter().sum::<f32>() / values.len() as f32;
722                let variance =
723                    values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
724                variance_scores.insert(param_name.clone(), variance);
725            }
726        }
727
728        // Normalize scores
729        let mut importance = HashMap::new();
730        let max_variance = variance_scores.values().fold(0.0f32, |a, &b| a.max(b));
731        if max_variance > 0.0 {
732            for (param_name, variance) in variance_scores {
733                importance.insert(param_name, variance / max_variance);
734            }
735        }
736
737        importance
738    }
739}
740
/// Utility functions for common optimizer hyperparameter spaces
pub mod presets {
    use super::*;

    /// Adam optimizer hyperparameter space
    ///
    /// Learning rate and epsilon are sampled on a log scale.
    pub fn adam_space() -> HyperparameterSpace {
        HyperparameterSpace::new()
            .add_continuous("lr", 1e-5, 1e-1)
            .add_continuous("beta1", 0.8, 0.99)
            .add_continuous("beta2", 0.9, 0.999)
            .add_continuous("eps", 1e-10, 1e-6)
            .add_continuous("weight_decay", 0.0, 1e-2)
            .set_log_scale(vec!["lr", "eps"])
    }

    /// SGD optimizer hyperparameter space
    ///
    /// `nesterov` is modeled as a categorical "true"/"false" flag; the
    /// learning rate is sampled on a log scale.
    pub fn sgd_space() -> HyperparameterSpace {
        HyperparameterSpace::new()
            .add_continuous("lr", 1e-4, 1.0)
            .add_continuous("momentum", 0.0, 0.99)
            .add_continuous("weight_decay", 0.0, 1e-2)
            .add_categorical("nesterov", vec!["true", "false"])
            .set_log_scale(vec!["lr"])
    }

    /// RMSprop optimizer hyperparameter space
    ///
    /// Learning rate and epsilon are sampled on a log scale.
    pub fn rmsprop_space() -> HyperparameterSpace {
        HyperparameterSpace::new()
            .add_continuous("lr", 1e-5, 1e-1)
            .add_continuous("alpha", 0.9, 0.999)
            .add_continuous("eps", 1e-10, 1e-6)
            .add_continuous("weight_decay", 0.0, 1e-2)
            .add_continuous("momentum", 0.0, 0.1)
            .set_log_scale(vec!["lr", "eps"])
    }
}