1use super::{
4 BayesianOptimization, Direction, EarlyStoppingConfig, ParameterValue, PruningConfig,
5 PruningStrategy, RandomSearch, SearchSpace, SearchStrategy, Trial, TrialHistory, TrialResult,
6};
7use crate::TrainingArguments;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::time::{Duration, Instant};
12use trustformers_core::errors::{file_not_found, invalid_format, Result};
13
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub enum OptimizationDirection {
17 Minimize,
19 Maximize,
21}
22
23impl From<OptimizationDirection> for Direction {
24 fn from(dir: OptimizationDirection) -> Self {
25 match dir {
26 OptimizationDirection::Minimize => Direction::Minimize,
27 OptimizationDirection::Maximize => Direction::Maximize,
28 }
29 }
30}
31
/// Configuration for a hyperparameter optimization study.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TunerConfig {
    /// Human-readable identifier for the study.
    pub study_name: String,
    /// Whether the objective metric is minimized or maximized.
    pub direction: OptimizationDirection,
    /// Name of the metric used as the optimization objective.
    pub objective_metric: String,
    /// Maximum number of trials to run (`None` = no trial limit).
    pub max_trials: Option<usize>,
    /// Wall-clock budget for the whole study (`None` = no time limit).
    pub max_duration: Option<Duration>,
    /// Optional early-stopping settings (see `EarlyStoppingConfig`).
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// Optional pruning settings used to terminate weak trials early.
    pub pruning: Option<PruningConfig>,
    /// Directory where checkpoints and result files are written.
    pub output_dir: PathBuf,
    /// When true, the trial history is checkpointed after every trial.
    pub save_checkpoints: bool,
    /// Minimum number of completed trials before pruning is considered.
    pub min_trials_for_pruning: usize,
    /// Optional RNG seed for reproducible search strategies.
    pub seed: Option<u64>,
}
58
59impl Default for TunerConfig {
60 fn default() -> Self {
61 Self {
62 study_name: "hyperparameter_study".to_string(),
63 direction: OptimizationDirection::Maximize,
64 objective_metric: "eval_accuracy".to_string(),
65 max_trials: Some(100),
66 max_duration: None,
67 early_stopping: None,
68 pruning: None,
69 output_dir: PathBuf::from("./hyperopt_results"),
70 save_checkpoints: true,
71 min_trials_for_pruning: 10,
72 seed: None,
73 }
74 }
75}
76
77impl TunerConfig {
78 pub fn new(study_name: impl Into<String>) -> Self {
80 Self {
81 study_name: study_name.into(),
82 ..Default::default()
83 }
84 }
85
86 pub fn direction(mut self, direction: OptimizationDirection) -> Self {
88 self.direction = direction;
89 self
90 }
91
92 pub fn objective_metric(mut self, metric: impl Into<String>) -> Self {
94 self.objective_metric = metric.into();
95 self
96 }
97
98 pub fn max_trials(mut self, max_trials: usize) -> Self {
100 self.max_trials = Some(max_trials);
101 self
102 }
103
104 pub fn max_duration(mut self, duration: Duration) -> Self {
106 self.max_duration = Some(duration);
107 self
108 }
109
110 pub fn early_stopping(mut self, config: EarlyStoppingConfig) -> Self {
112 self.early_stopping = Some(config);
113 self
114 }
115
116 pub fn pruning(mut self, config: PruningConfig) -> Self {
118 self.pruning = Some(config);
119 self
120 }
121
122 pub fn output_dir(mut self, dir: impl Into<PathBuf>) -> Self {
124 self.output_dir = dir.into();
125 self
126 }
127
128 pub fn seed(mut self, seed: u64) -> Self {
130 self.seed = Some(seed);
131 self
132 }
133}
134
/// Aggregate statistics describing a finished (or in-progress) study.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StudyStatistics {
    /// Number of trials recorded, regardless of outcome.
    pub total_trials: usize,
    /// Trials that ran to completion.
    pub completed_trials: usize,
    /// Trials whose objective function returned an error.
    pub failed_trials: usize,
    /// Trials terminated early by the pruning strategy.
    pub pruned_trials: usize,
    /// Best objective value observed, if any trial produced one.
    pub best_value: Option<f64>,
    /// Trial number of the best trial, if any.
    pub best_trial_number: Option<usize>,
    /// Wall-clock time elapsed since the study started.
    pub total_duration: Duration,
    /// Mean duration of a single trial.
    pub average_trial_duration: Duration,
    /// Success rate across trials (logged with a `%` suffix by
    /// `LoggingCallback`, so presumably already scaled to 0-100 — confirm in
    /// `TrialHistory::statistics`).
    pub success_rate: f64,
    /// Rate of trials that were pruned.
    pub pruning_rate: f64,
}
159
/// Observer hooks invoked at key points of a hyperparameter study.
///
/// All methods have empty default implementations, so implementors only
/// override the events they care about.
pub trait HyperparameterCallback: Send + Sync {
    /// Called once before the first trial of a study runs.
    fn on_study_start(&mut self, _config: &TunerConfig) {}

    /// Called once after the study loop has finished.
    fn on_study_end(&mut self, _config: &TunerConfig, _statistics: &StudyStatistics) {}

    /// Called just before a trial's objective function is evaluated.
    fn on_trial_start(&mut self, _trial: &Trial) {}

    /// Called after a trial completes (successfully or with a failure result).
    fn on_trial_complete(&mut self, _trial: &Trial) {}

    /// Called when a trial is pruned; `_reason` is a short human-readable cause.
    fn on_trial_pruned(&mut self, _trial: &Trial, _reason: &str) {}

    /// Called when a trial improves on the best value seen so far;
    /// `_improvement` is the absolute change in the objective value.
    fn on_new_best(&mut self, _trial: &Trial, _improvement: f64) {}
}
180
/// Built-in callback that prints study and trial progress to stdout.
pub struct LoggingCallback;

impl HyperparameterCallback for LoggingCallback {
    // Announce the study configuration before the first trial.
    fn on_study_start(&mut self, config: &TunerConfig) {
        println!("Starting hyperparameter study: {}", config.study_name);
        println!("Direction: {:?}", config.direction);
        println!("Objective metric: {}", config.objective_metric);
        if let Some(max_trials) = config.max_trials {
            println!("Max trials: {}", max_trials);
        }
    }

    // Summarize outcomes once the study loop has finished.
    fn on_study_end(&mut self, _config: &TunerConfig, statistics: &StudyStatistics) {
        println!("\nHyperparameter study completed!");
        println!("Total trials: {}", statistics.total_trials);
        println!("Completed trials: {}", statistics.completed_trials);
        println!("Success rate: {:.2}%", statistics.success_rate);
        if let Some(best_value) = statistics.best_value {
            println!("Best value: {:.6}", best_value);
        }
        println!("Total duration: {:?}", statistics.total_duration);
    }

    fn on_trial_start(&mut self, trial: &Trial) {
        println!("Starting trial {}: {}", trial.number, trial.summary());
    }

    fn on_trial_complete(&mut self, trial: &Trial) {
        println!("Completed trial {}: {}", trial.number, trial.summary());
    }

    fn on_trial_pruned(&mut self, trial: &Trial, reason: &str) {
        println!(
            "Pruned trial {} ({}): {}",
            trial.number,
            reason,
            trial.summary()
        );
    }

    fn on_new_best(&mut self, trial: &Trial, improvement: f64) {
        println!(
            "New best trial {}: improvement={:.6}, {}",
            trial.number,
            improvement,
            trial.summary()
        );
    }
}
231
/// Orchestrates a hyperparameter study: asks the search strategy for
/// parameter sets, evaluates them, records trials, and reports progress
/// through callbacks.
pub struct HyperparameterTuner {
    /// Study-level configuration.
    config: TunerConfig,
    /// Definition of the tunable parameters and their ranges.
    search_space: SearchSpace,
    /// Pluggable suggestion strategy (random search, Bayesian, ...).
    strategy: Box<dyn SearchStrategy>,
    /// Record of all trials run so far.
    history: TrialHistory,
    /// Set when `optimize` starts; used for duration budgeting and statistics.
    start_time: Option<Instant>,
    /// Observers notified of study/trial events.
    callbacks: Vec<Box<dyn HyperparameterCallback>>,
    /// Number that will be assigned to the next trial.
    current_trial_number: usize,
}
249
250impl HyperparameterTuner {
251 pub fn new(
253 config: TunerConfig,
254 search_space: SearchSpace,
255 strategy: Box<dyn SearchStrategy>,
256 ) -> Self {
257 let direction = config.direction.clone().into();
258
259 Self {
260 config,
261 search_space,
262 strategy,
263 history: TrialHistory::new(direction),
264 start_time: None,
265 callbacks: vec![Box::new(LoggingCallback)],
266 current_trial_number: 0,
267 }
268 }
269
270 pub fn with_random_search(config: TunerConfig, search_space: SearchSpace) -> Self {
272 let max_trials = config.max_trials.unwrap_or(100);
273 let strategy = if let Some(seed) = config.seed {
274 Box::new(RandomSearch::with_seed(max_trials, seed))
275 } else {
276 Box::new(RandomSearch::new(max_trials))
277 };
278
279 Self::new(config, search_space, strategy)
280 }
281
282 pub fn with_bayesian_optimization(config: TunerConfig, search_space: SearchSpace) -> Self {
284 let max_trials = config.max_trials.unwrap_or(100);
285 let strategy = Box::new(BayesianOptimization::new(max_trials));
286
287 Self::new(config, search_space, strategy)
288 }
289
290 pub fn add_callback(mut self, callback: Box<dyn HyperparameterCallback>) -> Self {
292 self.callbacks.push(callback);
293 self
294 }
295
296 pub fn best_trial(&self) -> Option<&Trial> {
298 self.history.best_trial()
299 }
300
301 pub fn best_value(&self) -> Option<f64> {
303 self.history.best_value()
304 }
305
306 pub fn trials(&self) -> &[Trial] {
308 &self.history.trials
309 }
310
311 pub fn statistics(&self) -> StudyStatistics {
313 let trial_stats = self.history.statistics();
314 let total_duration =
315 self.start_time.map(|start| start.elapsed()).unwrap_or(Duration::from_secs(0));
316
317 StudyStatistics {
318 total_trials: trial_stats.total_trials,
319 completed_trials: trial_stats.completed_trials,
320 failed_trials: trial_stats.failed_trials,
321 pruned_trials: trial_stats.pruned_trials,
322 best_value: trial_stats.best_value,
323 best_trial_number: self.best_trial().map(|t| t.number),
324 total_duration,
325 average_trial_duration: trial_stats.average_trial_duration,
326 success_rate: trial_stats.success_rate(),
327 pruning_rate: trial_stats.pruning_rate(),
328 }
329 }
330
331 pub fn optimize<F>(&mut self, mut objective_fn: F) -> Result<super::OptimizationResult>
333 where
334 F: FnMut(HashMap<String, ParameterValue>) -> Result<TrialResult>,
335 {
336 self.start_time = Some(Instant::now());
337
338 std::fs::create_dir_all(&self.config.output_dir)
340 .map_err(|e| file_not_found(e.to_string()))?;
341
342 for callback in &mut self.callbacks {
344 callback.on_study_start(&self.config);
345 }
346
347 let mut last_best_value = None;
348
349 while !self.should_terminate() {
351 if let Some(params) = self.strategy.suggest(&self.search_space, &self.history) {
353 if let Err(e) = self.search_space.validate(¶ms) {
355 eprintln!("Warning: Invalid parameters suggested: {}", e);
356 continue;
357 }
358
359 let mut trial = Trial::new(self.current_trial_number, params);
361 self.current_trial_number += 1;
362
363 for callback in &mut self.callbacks {
365 callback.on_trial_start(&trial);
366 }
367
368 trial.start();
370
371 match objective_fn(trial.params.clone()) {
373 Ok(result) => {
374 if self.should_prune_trial(&trial, &result) {
376 trial.prune("Poor performance");
377 for callback in &mut self.callbacks {
378 callback.on_trial_pruned(&trial, "Poor performance");
379 }
380 } else {
381 trial.complete(result);
383
384 if let Some(objective_value) = trial.objective_value() {
386 let is_new_best = match last_best_value {
387 None => true,
388 Some(prev_best) => match self.config.direction {
389 OptimizationDirection::Maximize => {
390 objective_value > prev_best
391 },
392 OptimizationDirection::Minimize => {
393 objective_value < prev_best
394 },
395 },
396 };
397
398 if is_new_best {
399 let improvement = match last_best_value {
400 None => 0.0,
401 Some(prev) => (objective_value - prev).abs(),
402 };
403 last_best_value = Some(objective_value);
404
405 for callback in &mut self.callbacks {
406 callback.on_new_best(&trial, improvement);
407 }
408 }
409 }
410
411 for callback in &mut self.callbacks {
412 callback.on_trial_complete(&trial);
413 }
414 }
415 },
416 Err(e) => {
417 let result = TrialResult::failure(e.to_string());
419 trial.complete(result);
420
421 for callback in &mut self.callbacks {
422 callback.on_trial_complete(&trial);
423 }
424 },
425 }
426
427 self.strategy.update(&trial);
429
430 self.history.add_trial(trial);
432
433 if self.config.save_checkpoints {
435 if let Err(e) = self.save_checkpoint() {
436 eprintln!("Warning: Failed to save checkpoint: {}", e);
437 }
438 }
439 } else {
440 break;
442 }
443 }
444
445 let statistics = self.statistics();
447
448 for callback in &mut self.callbacks {
450 callback.on_study_end(&self.config, &statistics);
451 }
452
453 self.save_results()?;
455
456 Ok(super::OptimizationResult {
458 best_trial: self.best_trial().unwrap_or(&Trial::new(0, HashMap::new())).clone(),
459 trials: self.history.trials.clone(),
460 completed_trials: statistics.completed_trials,
461 failed_trials: statistics.failed_trials,
462 total_duration: statistics.total_duration,
463 statistics,
464 })
465 }
466
467 fn should_terminate(&self) -> bool {
468 if self.strategy.should_terminate(&self.history) {
470 return true;
471 }
472
473 if let Some(max_trials) = self.config.max_trials {
475 if self.history.trials.len() >= max_trials {
476 return true;
477 }
478 }
479
480 if let Some(max_duration) = self.config.max_duration {
482 if let Some(start_time) = self.start_time {
483 if start_time.elapsed() >= max_duration {
484 return true;
485 }
486 }
487 }
488
489 false
490 }
491
492 fn should_prune_trial(&self, trial: &Trial, result: &TrialResult) -> bool {
493 if let Some(pruning_config) = &self.config.pruning {
494 if self.history.completed_trials().len() < self.config.min_trials_for_pruning {
496 return false;
497 }
498
499 if result.metrics.intermediate_values.is_empty() {
501 return false;
502 }
503
504 match &pruning_config.strategy {
505 PruningStrategy::None => false,
506 PruningStrategy::Median => self.is_below_median(trial, result, pruning_config),
507 PruningStrategy::Percentile(percentile) => {
508 self.is_below_percentile(trial, result, *percentile, pruning_config)
509 },
510 PruningStrategy::SuccessiveHalving => {
511 false },
514 }
515 } else {
516 false
517 }
518 }
519
520 fn is_below_median(
521 &self,
522 _trial: &Trial,
523 result: &TrialResult,
524 config: &PruningConfig,
525 ) -> bool {
526 self.is_below_percentile(_trial, result, 0.5, config)
527 }
528
529 fn is_below_percentile(
530 &self,
531 _trial: &Trial,
532 result: &TrialResult,
533 percentile: f64,
534 config: &PruningConfig,
535 ) -> bool {
536 if let Some((latest_step, latest_value)) = result.metrics.intermediate_values.last() {
537 if *latest_step < config.min_steps {
538 return false;
539 }
540
541 let mut values_at_step = Vec::new();
543 for historical_trial in self.history.completed_trials() {
544 if let Some(trial_result) = &historical_trial.result {
545 if let Some(value) =
546 trial_result.metrics.intermediate_value_at_step(*latest_step)
547 {
548 values_at_step.push(value);
549 }
550 }
551 }
552
553 if values_at_step.is_empty() {
554 return false;
555 }
556
557 values_at_step.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
559 let percentile_index = (percentile * (values_at_step.len() - 1) as f64) as usize;
560 let percentile_value = values_at_step[percentile_index];
561
562 match self.config.direction {
564 OptimizationDirection::Maximize => *latest_value < percentile_value,
565 OptimizationDirection::Minimize => *latest_value > percentile_value,
566 }
567 } else {
568 false
569 }
570 }
571
572 fn save_checkpoint(&self) -> Result<()> {
573 let checkpoint_path = self.config.output_dir.join("checkpoint.json");
574 let checkpoint_data = serde_json::to_string_pretty(&self.history)
575 .map_err(|e| invalid_format("json", e.to_string()))?;
576 std::fs::write(checkpoint_path, checkpoint_data)
577 .map_err(|e| file_not_found(e.to_string()))?;
578 Ok(())
579 }
580
581 fn save_results(&self) -> Result<()> {
582 let history_path = self.config.output_dir.join("trial_history.json");
584 let history_data = serde_json::to_string_pretty(&self.history)
585 .map_err(|e| invalid_format("json", e.to_string()))?;
586 std::fs::write(history_path, history_data).map_err(|e| file_not_found(e.to_string()))?;
587
588 let stats_path = self.config.output_dir.join("statistics.json");
590 let statistics = self.statistics();
591 let stats_data = serde_json::to_string_pretty(&statistics)
592 .map_err(|e| invalid_format("json", e.to_string()))?;
593 std::fs::write(stats_path, stats_data).map_err(|e| file_not_found(e.to_string()))?;
594
595 if let Some(best_trial) = self.best_trial() {
597 let best_params_path = self.config.output_dir.join("best_parameters.json");
598 let params_data = serde_json::to_string_pretty(&best_trial.params)
599 .map_err(|e| invalid_format("json", e.to_string()))?;
600 std::fs::write(best_params_path, params_data)
601 .map_err(|e| file_not_found(e.to_string()))?;
602 }
603
604 Ok(())
605 }
606
607 pub fn load_checkpoint(&mut self, checkpoint_path: &Path) -> Result<()> {
609 let checkpoint_data =
610 std::fs::read_to_string(checkpoint_path).map_err(|e| file_not_found(e.to_string()))?;
611 self.history = serde_json::from_str(&checkpoint_data)
612 .map_err(|e| invalid_format("json", e.to_string()))?;
613
614 self.current_trial_number = self.history.trials.len();
616
617 Ok(())
618 }
619}
620
621pub fn hyperparams_to_training_args(
623 base_args: &TrainingArguments,
624 hyperparams: &HashMap<String, ParameterValue>,
625) -> TrainingArguments {
626 let mut args = base_args.clone();
627
628 for (name, value) in hyperparams {
630 match name.as_str() {
631 "learning_rate" => {
632 if let Some(lr) = value.as_float() {
633 args.learning_rate = lr as f32;
634 }
635 },
636 "weight_decay" => {
637 if let Some(wd) = value.as_float() {
638 args.weight_decay = wd as f32;
639 }
640 },
641 "per_device_train_batch_size" | "batch_size" => {
642 if let Some(bs) = value.as_int() {
643 args.per_device_train_batch_size = bs as usize;
644 }
645 },
646 "num_train_epochs" => {
647 if let Some(epochs) = value.as_float() {
648 args.num_train_epochs = epochs as f32;
649 }
650 },
651 "warmup_ratio" => {
652 if let Some(ratio) = value.as_float() {
653 args.warmup_ratio = ratio as f32;
654 }
655 },
656 "adam_beta1" => {
657 if let Some(beta1) = value.as_float() {
658 args.adam_beta1 = beta1 as f32;
659 }
660 },
661 "adam_beta2" => {
662 if let Some(beta2) = value.as_float() {
663 args.adam_beta2 = beta2 as f32;
664 }
665 },
666 "max_grad_norm" => {
667 if let Some(norm) = value.as_float() {
668 args.max_grad_norm = norm as f32;
669 }
670 },
671 "gradient_accumulation_steps" => {
672 if let Some(steps) = value.as_int() {
673 args.gradient_accumulation_steps = steps as usize;
674 }
675 },
676 _ => {
677 eprintln!("Warning: Unknown hyperparameter: {}", name);
679 },
680 }
681 }
682
683 args
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hyperopt::search_space::SearchSpaceBuilder;
    use std::time::Duration;

    // Each builder method should set its field; untouched fields keep their
    // defaults.
    #[test]
    fn test_tuner_config() {
        let config = TunerConfig::new("test_study")
            .direction(OptimizationDirection::Minimize)
            .objective_metric("loss")
            .max_trials(50)
            .max_duration(Duration::from_secs(3600))
            .seed(42);

        assert_eq!(config.study_name, "test_study");
        assert_eq!(config.direction, OptimizationDirection::Minimize);
        assert_eq!(config.objective_metric, "loss");
        assert_eq!(config.max_trials, Some(50));
        assert_eq!(config.max_duration, Some(Duration::from_secs(3600)));
        assert_eq!(config.seed, Some(42));
    }

    // A freshly constructed tuner starts with an empty history and its trial
    // counter at zero.
    #[test]
    fn test_hyperparameter_tuner_creation() {
        let config = TunerConfig::new("test");
        let search_space = SearchSpaceBuilder::new()
            .continuous("learning_rate", 1e-5, 1e-1)
            .discrete("batch_size", 8, 64, 8)
            .build();

        let tuner = HyperparameterTuner::with_random_search(config, search_space);

        assert_eq!(tuner.config.study_name, "test");
        assert_eq!(tuner.current_trial_number, 0);
        assert!(tuner.history.trials.is_empty());
    }

    // Known hyperparameter keys map onto the matching TrainingArguments
    // fields; "batch_size" aliases "per_device_train_batch_size".
    #[test]
    fn test_hyperparams_to_training_args() {
        let base_args = TrainingArguments::default();
        let mut hyperparams = HashMap::new();
        hyperparams.insert("learning_rate".to_string(), ParameterValue::Float(0.001));
        hyperparams.insert("batch_size".to_string(), ParameterValue::Int(32));
        hyperparams.insert("num_train_epochs".to_string(), ParameterValue::Float(5.0));

        let updated_args = hyperparams_to_training_args(&base_args, &hyperparams);

        assert_eq!(updated_args.learning_rate, 0.001);
        assert_eq!(updated_args.per_device_train_batch_size, 32);
        assert_eq!(updated_args.num_train_epochs, 5.0);
    }

    // The From impl must map each variant onto the matching Direction.
    #[test]
    fn test_optimization_direction_conversion() {
        let max_dir: Direction = OptimizationDirection::Maximize.into();
        let min_dir: Direction = OptimizationDirection::Minimize.into();

        assert_eq!(max_dir, Direction::Maximize);
        assert_eq!(min_dir, Direction::Minimize);
    }
}