// trustformers_training/hyperopt/tuner.rs
1//! Main hyperparameter tuner implementation
2
3use super::{
4    BayesianOptimization, Direction, EarlyStoppingConfig, ParameterValue, PruningConfig,
5    PruningStrategy, RandomSearch, SearchSpace, SearchStrategy, Trial, TrialHistory, TrialResult,
6};
7use crate::TrainingArguments;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::{Duration, Instant};
12use trustformers_core::errors::{file_not_found, invalid_format, Result};
13
14/// Direction for optimization
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub enum OptimizationDirection {
17    /// Minimize the objective (e.g., loss)
18    Minimize,
19    /// Maximize the objective (e.g., accuracy)
20    Maximize,
21}
22
23impl From<OptimizationDirection> for Direction {
24    fn from(dir: OptimizationDirection) -> Self {
25        match dir {
26            OptimizationDirection::Minimize => Direction::Minimize,
27            OptimizationDirection::Maximize => Direction::Maximize,
28        }
29    }
30}
31
/// Configuration for the hyperparameter tuner
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TunerConfig {
    /// Name of the study
    pub study_name: String,
    /// Direction of optimization (maximize or minimize the objective)
    pub direction: OptimizationDirection,
    /// Name of the metric to optimize (e.g. "eval_accuracy" or "loss")
    pub objective_metric: String,
    /// Maximum number of trials; `None` means no trial-count limit
    pub max_trials: Option<usize>,
    /// Maximum wall-clock time to spend on optimization; `None` means no limit
    pub max_duration: Option<Duration>,
    /// Early stopping configuration
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// Pruning configuration; `None` disables pruning entirely
    pub pruning: Option<PruningConfig>,
    /// Directory to save study results (created if missing)
    pub output_dir: PathBuf,
    /// Whether to save an intermediate checkpoint after every trial
    pub save_checkpoints: bool,
    /// Minimum number of completed trials before pruning is considered
    pub min_trials_for_pruning: usize,
    /// Random seed for reproducibility (used by the random-search strategy)
    pub seed: Option<u64>,
}
58
59impl Default for TunerConfig {
60    fn default() -> Self {
61        Self {
62            study_name: "hyperparameter_study".to_string(),
63            direction: OptimizationDirection::Maximize,
64            objective_metric: "eval_accuracy".to_string(),
65            max_trials: Some(100),
66            max_duration: None,
67            early_stopping: None,
68            pruning: None,
69            output_dir: PathBuf::from("./hyperopt_results"),
70            save_checkpoints: true,
71            min_trials_for_pruning: 10,
72            seed: None,
73        }
74    }
75}
76
77impl TunerConfig {
78    /// Create a new tuner configuration
79    pub fn new(study_name: impl Into<String>) -> Self {
80        Self {
81            study_name: study_name.into(),
82            ..Default::default()
83        }
84    }
85
86    /// Set the optimization direction
87    pub fn direction(mut self, direction: OptimizationDirection) -> Self {
88        self.direction = direction;
89        self
90    }
91
92    /// Set the objective metric name
93    pub fn objective_metric(mut self, metric: impl Into<String>) -> Self {
94        self.objective_metric = metric.into();
95        self
96    }
97
98    /// Set the maximum number of trials
99    pub fn max_trials(mut self, max_trials: usize) -> Self {
100        self.max_trials = Some(max_trials);
101        self
102    }
103
104    /// Set the maximum duration
105    pub fn max_duration(mut self, duration: Duration) -> Self {
106        self.max_duration = Some(duration);
107        self
108    }
109
110    /// Set early stopping configuration
111    pub fn early_stopping(mut self, config: EarlyStoppingConfig) -> Self {
112        self.early_stopping = Some(config);
113        self
114    }
115
116    /// Set pruning configuration
117    pub fn pruning(mut self, config: PruningConfig) -> Self {
118        self.pruning = Some(config);
119        self
120    }
121
122    /// Set the output directory
123    pub fn output_dir(mut self, dir: impl Into<PathBuf>) -> Self {
124        self.output_dir = dir.into();
125        self
126    }
127
128    /// Set random seed
129    pub fn seed(mut self, seed: u64) -> Self {
130        self.seed = Some(seed);
131        self
132    }
133}
134
/// Statistics about an optimization study
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StudyStatistics {
    /// Total number of trials
    pub total_trials: usize,
    /// Number of completed trials
    pub completed_trials: usize,
    /// Number of failed trials
    pub failed_trials: usize,
    /// Number of pruned trials
    pub pruned_trials: usize,
    /// Best objective value found (`None` if no trial completed)
    pub best_value: Option<f64>,
    /// Number of the best trial (`None` if no trial completed)
    pub best_trial_number: Option<usize>,
    /// Total wall-clock time spent on the study
    pub total_duration: Duration,
    /// Average trial duration
    pub average_trial_duration: Duration,
    /// Success rate (percentage)
    pub success_rate: f64,
    /// Pruning rate (percentage)
    pub pruning_rate: f64,
}
159
/// Callback for hyperparameter tuning events
///
/// Every method has an empty default implementation, so implementors only
/// need to override the events they are interested in.
pub trait HyperparameterCallback: Send + Sync {
    /// Called once when a study starts, before any trial runs
    fn on_study_start(&mut self, _config: &TunerConfig) {}

    /// Called once when a study ends, with the final statistics
    fn on_study_end(&mut self, _config: &TunerConfig, _statistics: &StudyStatistics) {}

    /// Called when a trial starts
    fn on_trial_start(&mut self, _trial: &Trial) {}

    /// Called when a trial completes (successfully or with a failure result)
    fn on_trial_complete(&mut self, _trial: &Trial) {}

    /// Called when a trial is pruned, with a human-readable reason
    fn on_trial_pruned(&mut self, _trial: &Trial, _reason: &str) {}

    /// Called when a new best trial is found; `improvement` is the absolute
    /// change versus the previous best (0.0 for the very first best)
    fn on_new_best(&mut self, _trial: &Trial, _improvement: f64) {}
}
180
181/// Default callback that logs events
182pub struct LoggingCallback;
183
184impl HyperparameterCallback for LoggingCallback {
185    fn on_study_start(&mut self, config: &TunerConfig) {
186        println!("Starting hyperparameter study: {}", config.study_name);
187        println!("Direction: {:?}", config.direction);
188        println!("Objective metric: {}", config.objective_metric);
189        if let Some(max_trials) = config.max_trials {
190            println!("Max trials: {}", max_trials);
191        }
192    }
193
194    fn on_study_end(&mut self, _config: &TunerConfig, statistics: &StudyStatistics) {
195        println!("\nHyperparameter study completed!");
196        println!("Total trials: {}", statistics.total_trials);
197        println!("Completed trials: {}", statistics.completed_trials);
198        println!("Success rate: {:.2}%", statistics.success_rate);
199        if let Some(best_value) = statistics.best_value {
200            println!("Best value: {:.6}", best_value);
201        }
202        println!("Total duration: {:?}", statistics.total_duration);
203    }
204
205    fn on_trial_start(&mut self, trial: &Trial) {
206        println!("Starting trial {}: {}", trial.number, trial.summary());
207    }
208
209    fn on_trial_complete(&mut self, trial: &Trial) {
210        println!("Completed trial {}: {}", trial.number, trial.summary());
211    }
212
213    fn on_trial_pruned(&mut self, trial: &Trial, reason: &str) {
214        println!(
215            "Pruned trial {} ({}): {}",
216            trial.number,
217            reason,
218            trial.summary()
219        );
220    }
221
222    fn on_new_best(&mut self, trial: &Trial, improvement: f64) {
223        println!(
224            "New best trial {}: improvement={:.6}, {}",
225            trial.number,
226            improvement,
227            trial.summary()
228        );
229    }
230}
231
/// Main hyperparameter tuner
///
/// Owns the study configuration, the search space, a pluggable search
/// strategy, and the accumulated trial history. Drive it via `optimize`.
pub struct HyperparameterTuner {
    /// Configuration
    config: TunerConfig,
    /// Search space
    search_space: SearchSpace,
    /// Search strategy (e.g. random search or Bayesian optimization)
    strategy: Box<dyn SearchStrategy>,
    /// Trial history
    history: TrialHistory,
    /// Start time of the study; set when `optimize` begins
    start_time: Option<Instant>,
    /// Callbacks, invoked on study/trial lifecycle events
    callbacks: Vec<Box<dyn HyperparameterCallback>>,
    /// Current trial number; incremented per trial, and restored from a
    /// checkpoint as the history length
    current_trial_number: usize,
}
249
250impl HyperparameterTuner {
251    /// Create a new hyperparameter tuner
252    pub fn new(
253        config: TunerConfig,
254        search_space: SearchSpace,
255        strategy: Box<dyn SearchStrategy>,
256    ) -> Self {
257        let direction = config.direction.clone().into();
258
259        Self {
260            config,
261            search_space,
262            strategy,
263            history: TrialHistory::new(direction),
264            start_time: None,
265            callbacks: vec![Box::new(LoggingCallback)],
266            current_trial_number: 0,
267        }
268    }
269
270    /// Create a tuner with random search strategy
271    pub fn with_random_search(config: TunerConfig, search_space: SearchSpace) -> Self {
272        let max_trials = config.max_trials.unwrap_or(100);
273        let strategy = if let Some(seed) = config.seed {
274            Box::new(RandomSearch::with_seed(max_trials, seed))
275        } else {
276            Box::new(RandomSearch::new(max_trials))
277        };
278
279        Self::new(config, search_space, strategy)
280    }
281
282    /// Create a tuner with Bayesian optimization strategy
283    pub fn with_bayesian_optimization(config: TunerConfig, search_space: SearchSpace) -> Self {
284        let max_trials = config.max_trials.unwrap_or(100);
285        let strategy = Box::new(BayesianOptimization::new(max_trials));
286
287        Self::new(config, search_space, strategy)
288    }
289
    /// Add a callback
    ///
    /// Builder-style: consumes the tuner and returns it with the callback
    /// appended after the default [`LoggingCallback`].
    pub fn add_callback(mut self, callback: Box<dyn HyperparameterCallback>) -> Self {
        self.callbacks.push(callback);
        self
    }

    /// Get the current best trial (`None` until a trial completes)
    pub fn best_trial(&self) -> Option<&Trial> {
        self.history.best_trial()
    }

    /// Get the best value found so far (`None` until a trial completes)
    pub fn best_value(&self) -> Option<f64> {
        self.history.best_value()
    }

    /// Get all trials recorded in the history so far
    pub fn trials(&self) -> &[Trial] {
        &self.history.trials
    }
310
311    /// Get study statistics
312    pub fn statistics(&self) -> StudyStatistics {
313        let trial_stats = self.history.statistics();
314        let total_duration =
315            self.start_time.map(|start| start.elapsed()).unwrap_or(Duration::from_secs(0));
316
317        StudyStatistics {
318            total_trials: trial_stats.total_trials,
319            completed_trials: trial_stats.completed_trials,
320            failed_trials: trial_stats.failed_trials,
321            pruned_trials: trial_stats.pruned_trials,
322            best_value: trial_stats.best_value,
323            best_trial_number: self.best_trial().map(|t| t.number),
324            total_duration,
325            average_trial_duration: trial_stats.average_trial_duration,
326            success_rate: trial_stats.success_rate(),
327            pruning_rate: trial_stats.pruning_rate(),
328        }
329    }
330
    /// Run the hyperparameter optimization study.
    ///
    /// Repeatedly asks the search strategy for parameter suggestions, runs
    /// `objective_fn` on each, applies pruning, and records the outcome in
    /// the trial history, until a stop condition holds (see
    /// `should_terminate`) or the strategy stops suggesting. Results are
    /// written to `config.output_dir`.
    ///
    /// # Errors
    /// Fails if the output directory cannot be created or the final results
    /// cannot be saved; individual trial failures do NOT abort the study.
    pub fn optimize<F>(&mut self, mut objective_fn: F) -> Result<super::OptimizationResult>
    where
        F: FnMut(HashMap<String, ParameterValue>) -> Result<TrialResult>,
    {
        self.start_time = Some(Instant::now());

        // Create output directory
        std::fs::create_dir_all(&self.config.output_dir)
            .map_err(|e| file_not_found(e.to_string()))?;

        // Notify callbacks
        for callback in &mut self.callbacks {
            callback.on_study_start(&self.config);
        }

        // Best objective value observed so far; tracked separately from the
        // history so `on_new_best` can report the improvement delta.
        let mut last_best_value = None;

        // Main optimization loop
        while !self.should_terminate() {
            // Get next suggestion
            if let Some(params) = self.strategy.suggest(&self.search_space, &self.history) {
                // Validate parameters
                // NOTE(review): `continue` records no trial, so a strategy
                // that keeps suggesting invalid parameters would spin here
                // without advancing the trial count — confirm strategies
                // cannot do that indefinitely.
                if let Err(e) = self.search_space.validate(&params) {
                    eprintln!("Warning: Invalid parameters suggested: {}", e);
                    continue;
                }

                // Create new trial
                let mut trial = Trial::new(self.current_trial_number, params);
                self.current_trial_number += 1;

                // Notify callbacks
                for callback in &mut self.callbacks {
                    callback.on_trial_start(&trial);
                }

                // Start the trial
                trial.start();

                // Run the objective function
                match objective_fn(trial.params.clone()) {
                    Ok(result) => {
                        // Check if we should prune this trial
                        if self.should_prune_trial(&trial, &result) {
                            trial.prune("Poor performance");
                            for callback in &mut self.callbacks {
                                callback.on_trial_pruned(&trial, "Poor performance");
                            }
                        } else {
                            // Complete the trial
                            trial.complete(result);

                            // Check for new best
                            if let Some(objective_value) = trial.objective_value() {
                                let is_new_best = match last_best_value {
                                    None => true,
                                    Some(prev_best) => match self.config.direction {
                                        OptimizationDirection::Maximize => {
                                            objective_value > prev_best
                                        },
                                        OptimizationDirection::Minimize => {
                                            objective_value < prev_best
                                        },
                                    },
                                };

                                if is_new_best {
                                    // The very first best reports 0.0 improvement.
                                    let improvement = match last_best_value {
                                        None => 0.0,
                                        Some(prev) => (objective_value - prev).abs(),
                                    };
                                    last_best_value = Some(objective_value);

                                    for callback in &mut self.callbacks {
                                        callback.on_new_best(&trial, improvement);
                                    }
                                }
                            }

                            for callback in &mut self.callbacks {
                                callback.on_trial_complete(&trial);
                            }
                        }
                    },
                    Err(e) => {
                        // Trial failed: record the error message as a failure
                        // result and still notify `on_trial_complete`.
                        let result = TrialResult::failure(e.to_string());
                        trial.complete(result);

                        for callback in &mut self.callbacks {
                            callback.on_trial_complete(&trial);
                        }
                    },
                }

                // Update strategy with completed trial
                self.strategy.update(&trial);

                // Add trial to history
                self.history.add_trial(trial);

                // Save checkpoint if enabled (checkpoint failures are non-fatal)
                if self.config.save_checkpoints {
                    if let Err(e) = self.save_checkpoint() {
                        eprintln!("Warning: Failed to save checkpoint: {}", e);
                    }
                }
            } else {
                // No more suggestions from strategy
                break;
            }
        }

        // Get final statistics
        let statistics = self.statistics();

        // Notify callbacks
        for callback in &mut self.callbacks {
            callback.on_study_end(&self.config, &statistics);
        }

        // Save final results
        self.save_results()?;

        // Create optimization result
        // NOTE(review): when no trial ever completed, a placeholder empty
        // Trial (number 0) is returned as `best_trial` — confirm callers can
        // distinguish this from a genuine result.
        Ok(super::OptimizationResult {
            best_trial: self.best_trial().unwrap_or(&Trial::new(0, HashMap::new())).clone(),
            trials: self.history.trials.clone(),
            completed_trials: statistics.completed_trials,
            failed_trials: statistics.failed_trials,
            total_duration: statistics.total_duration,
            statistics,
        })
    }
466
467    fn should_terminate(&self) -> bool {
468        // Check if strategy wants to terminate
469        if self.strategy.should_terminate(&self.history) {
470            return true;
471        }
472
473        // Check max trials
474        if let Some(max_trials) = self.config.max_trials {
475            if self.history.trials.len() >= max_trials {
476                return true;
477            }
478        }
479
480        // Check max duration
481        if let Some(max_duration) = self.config.max_duration {
482            if let Some(start_time) = self.start_time {
483                if start_time.elapsed() >= max_duration {
484                    return true;
485                }
486            }
487        }
488
489        false
490    }
491
    /// Decide whether `trial` should be pruned based on its intermediate
    /// metric values and the configured pruning strategy.
    ///
    /// Always returns false when pruning is not configured, when fewer than
    /// `min_trials_for_pruning` trials have completed, or when the result
    /// carries no intermediate values to compare against history.
    fn should_prune_trial(&self, trial: &Trial, result: &TrialResult) -> bool {
        if let Some(pruning_config) = &self.config.pruning {
            // Only prune if we have enough trials for comparison
            if self.history.completed_trials().len() < self.config.min_trials_for_pruning {
                return false;
            }

            // Check if we have intermediate values to evaluate
            if result.metrics.intermediate_values.is_empty() {
                return false;
            }

            match &pruning_config.strategy {
                PruningStrategy::None => false,
                // Median pruning is percentile pruning at the 50th percentile.
                PruningStrategy::Median => self.is_below_median(trial, result, pruning_config),
                PruningStrategy::Percentile(percentile) => {
                    self.is_below_percentile(trial, result, *percentile, pruning_config)
                },
                PruningStrategy::SuccessiveHalving => {
                    // Implement successive halving pruning logic
                    false // Simplified for now
                },
            }
        } else {
            false
        }
    }
519
    /// Median pruning: delegate to percentile pruning at the 50th percentile.
    fn is_below_median(
        &self,
        _trial: &Trial,
        result: &TrialResult,
        config: &PruningConfig,
    ) -> bool {
        self.is_below_percentile(_trial, result, 0.5, config)
    }
528
    /// Check whether this trial's most recent intermediate value falls on the
    /// losing side of the given percentile of completed trials' values at the
    /// same step (below it when maximizing, above it when minimizing).
    ///
    /// Returns false when the result has no intermediate value yet, the
    /// latest step is earlier than `config.min_steps`, or no completed trial
    /// reported a value at that step.
    fn is_below_percentile(
        &self,
        _trial: &Trial,
        result: &TrialResult,
        percentile: f64,
        config: &PruningConfig,
    ) -> bool {
        if let Some((latest_step, latest_value)) = result.metrics.intermediate_values.last() {
            // Too early in training to judge — never prune before min_steps.
            if *latest_step < config.min_steps {
                return false;
            }

            // Get intermediate values from other trials at the same step
            let mut values_at_step = Vec::new();
            for historical_trial in self.history.completed_trials() {
                if let Some(trial_result) = &historical_trial.result {
                    if let Some(value) =
                        trial_result.metrics.intermediate_value_at_step(*latest_step)
                    {
                        values_at_step.push(value);
                    }
                }
            }

            if values_at_step.is_empty() {
                return false;
            }

            // Sort values and find percentile (nearest-rank via index
            // truncation; NaNs compare as Equal and keep their position).
            values_at_step.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let percentile_index = (percentile * (values_at_step.len() - 1) as f64) as usize;
            let percentile_value = values_at_step[percentile_index];

            // Prune if current value is significantly below percentile
            match self.config.direction {
                OptimizationDirection::Maximize => *latest_value < percentile_value,
                OptimizationDirection::Minimize => *latest_value > percentile_value,
            }
        } else {
            false
        }
    }
571
572    fn save_checkpoint(&self) -> Result<()> {
573        let checkpoint_path = self.config.output_dir.join("checkpoint.json");
574        let checkpoint_data = serde_json::to_string_pretty(&self.history)
575            .map_err(|e| invalid_format("json", e.to_string()))?;
576        std::fs::write(checkpoint_path, checkpoint_data)
577            .map_err(|e| file_not_found(e.to_string()))?;
578        Ok(())
579    }
580
581    fn save_results(&self) -> Result<()> {
582        // Save trial history
583        let history_path = self.config.output_dir.join("trial_history.json");
584        let history_data = serde_json::to_string_pretty(&self.history)
585            .map_err(|e| invalid_format("json", e.to_string()))?;
586        std::fs::write(history_path, history_data).map_err(|e| file_not_found(e.to_string()))?;
587
588        // Save statistics
589        let stats_path = self.config.output_dir.join("statistics.json");
590        let statistics = self.statistics();
591        let stats_data = serde_json::to_string_pretty(&statistics)
592            .map_err(|e| invalid_format("json", e.to_string()))?;
593        std::fs::write(stats_path, stats_data).map_err(|e| file_not_found(e.to_string()))?;
594
595        // Save best parameters
596        if let Some(best_trial) = self.best_trial() {
597            let best_params_path = self.config.output_dir.join("best_parameters.json");
598            let params_data = serde_json::to_string_pretty(&best_trial.params)
599                .map_err(|e| invalid_format("json", e.to_string()))?;
600            std::fs::write(best_params_path, params_data)
601                .map_err(|e| file_not_found(e.to_string()))?;
602        }
603
604        Ok(())
605    }
606
607    /// Load a previous study from checkpoint
608    pub fn load_checkpoint(&mut self, checkpoint_path: &Path) -> Result<()> {
609        let checkpoint_data =
610            std::fs::read_to_string(checkpoint_path).map_err(|e| file_not_found(e.to_string()))?;
611        self.history = serde_json::from_str(&checkpoint_data)
612            .map_err(|e| invalid_format("json", e.to_string()))?;
613
614        // Update trial counter
615        self.current_trial_number = self.history.trials.len();
616
617        Ok(())
618    }
619}
620
621/// Helper function to create training arguments from hyperparameters
622pub fn hyperparams_to_training_args(
623    base_args: &TrainingArguments,
624    hyperparams: &HashMap<String, ParameterValue>,
625) -> TrainingArguments {
626    let mut args = base_args.clone();
627
628    // Update training arguments based on hyperparameters
629    for (name, value) in hyperparams {
630        match name.as_str() {
631            "learning_rate" => {
632                if let Some(lr) = value.as_float() {
633                    args.learning_rate = lr as f32;
634                }
635            },
636            "weight_decay" => {
637                if let Some(wd) = value.as_float() {
638                    args.weight_decay = wd as f32;
639                }
640            },
641            "per_device_train_batch_size" | "batch_size" => {
642                if let Some(bs) = value.as_int() {
643                    args.per_device_train_batch_size = bs as usize;
644                }
645            },
646            "num_train_epochs" => {
647                if let Some(epochs) = value.as_float() {
648                    args.num_train_epochs = epochs as f32;
649                }
650            },
651            "warmup_ratio" => {
652                if let Some(ratio) = value.as_float() {
653                    args.warmup_ratio = ratio as f32;
654                }
655            },
656            "adam_beta1" => {
657                if let Some(beta1) = value.as_float() {
658                    args.adam_beta1 = beta1 as f32;
659                }
660            },
661            "adam_beta2" => {
662                if let Some(beta2) = value.as_float() {
663                    args.adam_beta2 = beta2 as f32;
664                }
665            },
666            "max_grad_norm" => {
667                if let Some(norm) = value.as_float() {
668                    args.max_grad_norm = norm as f32;
669                }
670            },
671            "gradient_accumulation_steps" => {
672                if let Some(steps) = value.as_int() {
673                    args.gradient_accumulation_steps = steps as usize;
674                }
675            },
676            _ => {
677                // Unknown hyperparameter, ignore or log warning
678                eprintln!("Warning: Unknown hyperparameter: {}", name);
679            },
680        }
681    }
682
683    args
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hyperopt::search_space::SearchSpaceBuilder;
    use std::time::Duration;

    /// Builder methods on `TunerConfig` should store exactly what was set.
    #[test]
    fn test_tuner_config() {
        let config = TunerConfig::new("test_study")
            .direction(OptimizationDirection::Minimize)
            .objective_metric("loss")
            .max_trials(50)
            .max_duration(Duration::from_secs(3600))
            .seed(42);

        assert_eq!(config.study_name, "test_study");
        assert_eq!(config.direction, OptimizationDirection::Minimize);
        assert_eq!(config.objective_metric, "loss");
        assert_eq!(config.max_trials, Some(50));
        assert_eq!(config.max_duration, Some(Duration::from_secs(3600)));
        assert_eq!(config.seed, Some(42));
    }

    /// A freshly created tuner starts at trial 0 with an empty history.
    #[test]
    fn test_hyperparameter_tuner_creation() {
        let config = TunerConfig::new("test");
        let search_space = SearchSpaceBuilder::new()
            .continuous("learning_rate", 1e-5, 1e-1)
            .discrete("batch_size", 8, 64, 8)
            .build();

        let tuner = HyperparameterTuner::with_random_search(config, search_space);

        assert_eq!(tuner.config.study_name, "test");
        assert_eq!(tuner.current_trial_number, 0);
        assert!(tuner.history.trials.is_empty());
    }

    /// Known hyperparameter names should be applied onto `TrainingArguments`.
    #[test]
    fn test_hyperparams_to_training_args() {
        let base_args = TrainingArguments::default();
        let mut hyperparams = HashMap::new();
        hyperparams.insert("learning_rate".to_string(), ParameterValue::Float(0.001));
        hyperparams.insert("batch_size".to_string(), ParameterValue::Int(32));
        hyperparams.insert("num_train_epochs".to_string(), ParameterValue::Float(5.0));

        let updated_args = hyperparams_to_training_args(&base_args, &hyperparams);

        assert_eq!(updated_args.learning_rate, 0.001);
        assert_eq!(updated_args.per_device_train_batch_size, 32);
        assert_eq!(updated_args.num_train_epochs, 5.0);
    }

    /// `From<OptimizationDirection>` maps each variant onto the matching
    /// internal `Direction`.
    #[test]
    fn test_optimization_direction_conversion() {
        let max_dir: Direction = OptimizationDirection::Maximize.into();
        let min_dir: Direction = OptimizationDirection::Minimize.into();

        assert_eq!(max_dir, Direction::Maximize);
        assert_eq!(min_dir, Direction::Minimize);
    }
}