//! Multi-fidelity hyperparameter optimization.
//!
//! Provides fidelity levels, budget-allocation strategies (successive halving,
//! Hyperband, BOHB, FABOLAS-style subsampling, multi-task search), and a
//! [`MultiFidelityOptimizer`] that drives them.

use scirs2_core::ndarray::Array1;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::SeedableRng;
use sklears_core::types::Float;
use std::collections::HashMap;
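
/// Fidelity level at which a hyperparameter configuration is evaluated.
///
/// Lower fidelities trade accuracy for cheaper evaluations (smaller sample
/// fractions, fewer epochs, fewer CV folds); `Custom` carries arbitrary
/// fidelity parameters together with a relative cost and an accuracy estimate.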
#[derive(Debug, Clone)]
pub enum FidelityLevel {
    Low {
        sample_fraction: Float,
        epochs_fraction: Float,
        cv_folds: usize,
    },
    Medium {
        sample_fraction: Float,
        epochs_fraction: Float,
        cv_folds: usize,
    },
    High {
        sample_fraction: Float,
        epochs_fraction: Float,
        cv_folds: usize,
    },
    Custom {
        parameters: HashMap<String, Float>,
        relative_cost: Float,
        accuracy_estimate: Float,
    },
}
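
/// Strategy for allocating the evaluation budget across fidelity levels.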
#[derive(Debug, Clone)]
pub enum MultiFidelityStrategy {
    SuccessiveHalving {
        eta: Float,
        min_fidelity: FidelityLevel,
        max_fidelity: FidelityLevel,
    },
    BayesianOptimization {
        acquisition_function: AcquisitionFunction,
        fidelity_selection: FidelitySelectionMethod,
        correlation_model: CorrelationModel,
    },
    Hyperband {
        max_budget: Float,
        eta: Float,
        fidelities: Vec<FidelityLevel>,
    },
    BOHB {
        min_budget: Float,
        max_budget: Float,
        eta: Float,
        bandwidth_factor: Float,
    },
    Fabolas {
        min_dataset_fraction: Float,
        max_dataset_fraction: Float,
        cost_model: CostModel,
    },
    MultiTaskGP {
        task_similarity: Float,
        shared_hyperparameters: Vec<String>,
    },
}
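
/// Acquisition function used to rank candidate configurations.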
#[derive(Debug, Clone)]
pub enum AcquisitionFunction {
    ExpectedImprovement,
    UpperConfidenceBound { beta: Float },
    ProbabilityOfImprovement,
    KnowledgeGradient,
    EntropySearch,
    MultiFidelityEI { fidelity_weight: Float },
}
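
/// Rule for choosing which fidelity to evaluate a candidate at.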
#[derive(Debug, Clone)]
pub enum FidelitySelectionMethod {
    LowestFirst,
    UncertaintyBased { threshold: Float },
    CostAware { budget_fraction: Float },
    PerformanceBased { improvement_threshold: Float },
    InformationTheoretic,
}
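
/// Model of how scores correlate across fidelity levels.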
#[derive(Debug, Clone)]
pub enum CorrelationModel {
    Linear { correlation_strength: Float },
    Exponential { decay_rate: Float },
    GaussianProcess { kernel_type: String },
    RankCorrelation,
}
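
/// Model of evaluation cost as a function of fidelity.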
#[derive(Debug, Clone)]
pub enum CostModel {
    Polynomial {
        degree: usize,
        coefficients: Vec<Float>,
    },
    Exponential { base: Float, scale: Float },
    Linear { slope: Float, intercept: Float },
    Custom { cost_function: String },
}
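
/// Configuration for a multi-fidelity optimization run.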
#[derive(Debug, Clone)]
pub struct MultiFidelityConfig {
    pub strategy: MultiFidelityStrategy,
    pub max_evaluations: usize,
    pub max_budget: Float,
    pub early_stopping_patience: usize,
    pub fidelity_progression: FidelityProgression,
    pub random_state: Option<u64>,
    pub parallel_evaluations: usize,
}
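
/// Schedule controlling how quickly fidelity is increased during the run.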
#[derive(Debug, Clone)]
pub enum FidelityProgression {
    Linear,
    Exponential { growth_rate: Float },
    Adaptive { adaptation_rate: Float },
    Conservative,
    Aggressive,
}
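
/// Outcome of evaluating one hyperparameter configuration at one fidelity.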
#[derive(Debug, Clone)]
pub struct FidelityEvaluation {
    pub hyperparameters: HashMap<String, Float>,
    pub fidelity: FidelityLevel,
    pub score: Float,
    pub cost: Float,
    pub evaluation_time: Float,
    pub uncertainty: Option<Float>,
    pub additional_metrics: HashMap<String, Float>,
}
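
/// Summary of a completed multi-fidelity optimization run.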
#[derive(Debug, Clone)]
pub struct MultiFidelityResult {
    pub best_hyperparameters: HashMap<String, Float>,
    pub best_score: Float,
    pub best_fidelity: FidelityLevel,
    pub optimization_history: Vec<FidelityEvaluation>,
    pub total_cost: Float,
    pub total_time: Float,
    pub convergence_curve: Vec<Float>,
    pub fidelity_usage: HashMap<String, usize>,
    pub cost_efficiency: Float,
}
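
/// Optimizer that drives the configured [`MultiFidelityStrategy`].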
#[derive(Debug)]
pub struct MultiFidelityOptimizer {
    config: MultiFidelityConfig,
    gaussian_process: MultiFidelityGP,
    evaluation_history: Vec<FidelityEvaluation>,
    current_best: Option<FidelityEvaluation>,
    rng: StdRng,
}
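
/// Simplified multi-fidelity Gaussian-process surrogate; observations are
/// stored as `(parameter vector, fidelity value, score)` triples.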
#[derive(Debug, Clone)]
pub struct MultiFidelityGP {
    observations: Vec<(Array1<Float>, Float, Float)>,
    hyperparameters: GPHyperparameters,
    trained: bool,
}
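
/// Hyperparameters of the multi-fidelity GP surrogate.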
#[derive(Debug, Clone)]
pub struct GPHyperparameters {
    pub length_scales: Array1<Float>,
    pub signal_variance: Float,
    pub noise_variance: Float,
    pub fidelity_correlation: Float,
}

impl Default for MultiFidelityConfig {
    fn default() -> Self {
        Self {
            strategy: MultiFidelityStrategy::BayesianOptimization {
                acquisition_function: AcquisitionFunction::ExpectedImprovement,
                fidelity_selection: FidelitySelectionMethod::UncertaintyBased { threshold: 0.1 },
                correlation_model: CorrelationModel::Linear {
                    correlation_strength: 0.8,
                },
            },
            max_evaluations: 100,
            max_budget: 1000.0,
            early_stopping_patience: 10,
            fidelity_progression: FidelityProgression::Adaptive {
                adaptation_rate: 0.1,
            },
            random_state: None,
            parallel_evaluations: 1,
        }
    }
}

impl MultiFidelityOptimizer {
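    /// Create a new optimizer for the given configuration, seeding the RNG
    /// from `config.random_state` when provided.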
    pub fn new(config: MultiFidelityConfig) -> Self {
        let rng = match config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        };

        let gaussian_process = MultiFidelityGP::new();

        Self {
            config,
            gaussian_process,
            evaluation_history: Vec::new(),
            current_best: None,
            rng,
        }
    }
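
    /// Run the configured multi-fidelity strategy, calling `evaluation_fn`
    /// for each (hyperparameters, fidelity) pair and tracking total cost, the
    /// convergence curve, and per-fidelity usage counts.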
    pub fn optimize<F>(
        &mut self,
        evaluation_fn: F,
        parameter_bounds: &[(Float, Float)],
    ) -> Result<MultiFidelityResult, Box<dyn std::error::Error>>
    where
        F: Fn(
            &HashMap<String, Float>,
            &FidelityLevel,
        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
    {
        let start_time = std::time::Instant::now();
        let mut total_cost = 0.0;
        let mut convergence_curve = Vec::new();
        let mut fidelity_usage = HashMap::new();

        match &self.config.strategy {
            MultiFidelityStrategy::SuccessiveHalving { .. } => {
                self.successive_halving_optimize(
                    &evaluation_fn,
                    parameter_bounds,
                    &mut total_cost,
                    &mut convergence_curve,
                    &mut fidelity_usage,
                )?;
            }
            MultiFidelityStrategy::BayesianOptimization { .. } => {
                self.bayesian_optimize(
                    &evaluation_fn,
                    parameter_bounds,
                    &mut total_cost,
                    &mut convergence_curve,
                    &mut fidelity_usage,
                )?;
            }
            MultiFidelityStrategy::Hyperband { .. } => {
                self.hyperband_optimize(
                    &evaluation_fn,
                    parameter_bounds,
                    &mut total_cost,
                    &mut convergence_curve,
                    &mut fidelity_usage,
                )?;
            }
            MultiFidelityStrategy::BOHB { .. } => {
                self.bohb_optimize(
                    &evaluation_fn,
                    parameter_bounds,
                    &mut total_cost,
                    &mut convergence_curve,
                    &mut fidelity_usage,
                )?;
            }
            MultiFidelityStrategy::Fabolas { .. } => {
                self.fabolas_optimize(
                    &evaluation_fn,
                    parameter_bounds,
                    &mut total_cost,
                    &mut convergence_curve,
                    &mut fidelity_usage,
                )?;
            }
            MultiFidelityStrategy::MultiTaskGP { .. } => {
                self.multi_task_gp_optimize(
                    &evaluation_fn,
                    parameter_bounds,
                    &mut total_cost,
                    &mut convergence_curve,
                    &mut fidelity_usage,
                )?;
            }
        }

        let total_time = start_time.elapsed().as_secs_f64() as Float;
        let cost_efficiency = if total_cost > 0.0 {
            self.current_best.as_ref().map_or(0.0, |best| best.score) / total_cost
        } else {
            0.0
        };

        Ok(MultiFidelityResult {
            best_hyperparameters: self
                .current_best
                .as_ref()
                .map(|best| best.hyperparameters.clone())
                .unwrap_or_default(),
            best_score: self.current_best.as_ref().map_or(0.0, |best| best.score),
            best_fidelity: self
                .current_best
                .as_ref()
                .map(|best| best.fidelity.clone())
                .unwrap_or(self.get_default_fidelity()),
            optimization_history: self.evaluation_history.clone(),
            total_cost,
            total_time,
            convergence_curve,
            fidelity_usage,
            cost_efficiency,
        })
    }
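
    /// Successive halving: evaluate a pool of random configurations, keep the
    /// top `1/eta` fraction, and re-evaluate survivors at increasing fidelity.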
    fn successive_halving_optimize<F>(
        &mut self,
        evaluation_fn: &F,
        parameter_bounds: &[(Float, Float)],
        total_cost: &mut Float,
        convergence_curve: &mut Vec<Float>,
        fidelity_usage: &mut HashMap<String, usize>,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: Fn(
            &HashMap<String, Float>,
            &FidelityLevel,
        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
    {
        let (eta, min_fidelity, max_fidelity) = match &self.config.strategy {
            MultiFidelityStrategy::SuccessiveHalving {
                eta,
                min_fidelity,
                max_fidelity,
            } => (*eta, min_fidelity.clone(), max_fidelity.clone()),
            _ => unreachable!(),
        };

        let mut configurations = self.generate_initial_configurations(parameter_bounds, 50)?;
        let mut current_fidelity = min_fidelity;

        while configurations.len() > 1 && !self.should_stop() {
            let mut evaluations = Vec::new();

            for config in &configurations {
                let evaluation = evaluation_fn(config, &current_fidelity)?;
                *total_cost += evaluation.cost;
                *fidelity_usage
                    .entry(self.fidelity_to_string(&current_fidelity))
                    .or_insert(0) += 1;

                self.evaluation_history.push(evaluation.clone());
                evaluations.push(evaluation.clone());

                if self.update_best(&evaluation) {
                    convergence_curve.push(self.current_best.as_ref().unwrap().score);
                } else if let Some(best) = &self.current_best {
                    convergence_curve.push(best.score);
                }
            }

            evaluations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
            let keep_count = (configurations.len() as Float / eta).max(1.0) as usize;

            configurations = evaluations
                .iter()
                .take(keep_count)
                .map(|eval| eval.hyperparameters.clone())
                .collect();

            current_fidelity = self.increase_fidelity(&current_fidelity, &max_fidelity);
        }

        Ok(())
    }
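
    /// Multi-fidelity Bayesian optimization: random initial design, then
    /// acquisition-driven candidates with a periodically refreshed GP surrogate.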
    fn bayesian_optimize<F>(
        &mut self,
        evaluation_fn: &F,
        parameter_bounds: &[(Float, Float)],
        total_cost: &mut Float,
        convergence_curve: &mut Vec<Float>,
        fidelity_usage: &mut HashMap<String, usize>,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: Fn(
            &HashMap<String, Float>,
            &FidelityLevel,
        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
    {
        let (acquisition_function, fidelity_selection, _correlation_model) =
            match &self.config.strategy {
                MultiFidelityStrategy::BayesianOptimization {
                    acquisition_function,
                    fidelity_selection,
                    correlation_model,
                } => (
                    acquisition_function.clone(),
                    fidelity_selection.clone(),
                    correlation_model.clone(),
                ),
                // Also reached from `bohb_optimize`; fall back to the default
                // Bayesian settings instead of panicking on other strategies.
                _ => (
                    AcquisitionFunction::ExpectedImprovement,
                    FidelitySelectionMethod::UncertaintyBased { threshold: 0.1 },
                    CorrelationModel::Linear {
                        correlation_strength: 0.8,
                    },
                ),
            };

        let init_evaluations = 5;
        for _ in 0..init_evaluations {
            let config = self.sample_random_configuration(parameter_bounds)?;
            let fidelity = self.select_fidelity(&fidelity_selection, None)?;

            let evaluation = evaluation_fn(&config, &fidelity)?;
            *total_cost += evaluation.cost;
            *fidelity_usage
                .entry(self.fidelity_to_string(&fidelity))
                .or_insert(0) += 1;

            self.evaluation_history.push(evaluation.clone());
            if self.update_best(&evaluation) {
                convergence_curve.push(self.current_best.as_ref().unwrap().score);
            } else if let Some(best) = &self.current_best {
                convergence_curve.push(best.score);
            }
        }

        self.gaussian_process.update(&self.evaluation_history)?;

        while self.evaluation_history.len() < self.config.max_evaluations && !self.should_stop() {
            let next_config = self.optimize_acquisition(&acquisition_function, parameter_bounds)?;
            let next_fidelity = self.select_fidelity(&fidelity_selection, Some(&next_config))?;

            let evaluation = evaluation_fn(&next_config, &next_fidelity)?;
            *total_cost += evaluation.cost;
            *fidelity_usage
                .entry(self.fidelity_to_string(&next_fidelity))
                .or_insert(0) += 1;

            self.evaluation_history.push(evaluation.clone());
            if self.update_best(&evaluation) {
                convergence_curve.push(self.current_best.as_ref().unwrap().score);
            } else if let Some(best) = &self.current_best {
                convergence_curve.push(best.score);
            }

            if self.evaluation_history.len() % 5 == 0 {
                self.gaussian_process.update(&self.evaluation_history)?;
            }
        }

        Ok(())
    }
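
    /// Hyperband-style bracketing: successive-halving brackets over a range of
    /// budgets, mapping each budget to a fidelity level.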
    fn hyperband_optimize<F>(
        &mut self,
        evaluation_fn: &F,
        parameter_bounds: &[(Float, Float)],
        total_cost: &mut Float,
        convergence_curve: &mut Vec<Float>,
        fidelity_usage: &mut HashMap<String, usize>,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: Fn(
            &HashMap<String, Float>,
            &FidelityLevel,
        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
    {
        let (max_budget, eta, fidelities) = match &self.config.strategy {
            MultiFidelityStrategy::Hyperband {
                max_budget,
                eta,
                fidelities,
            } => (*max_budget, *eta, fidelities.clone()),
            // BOHB reuses the Hyperband bracketing with its own budget and eta.
            MultiFidelityStrategy::BOHB {
                max_budget, eta, ..
            } => (*max_budget, *eta, Vec::new()),
            _ => unreachable!(),
        };

        let log_eta = eta.ln();
        let s_max = (max_budget.ln() / log_eta).floor() as usize;

        for s in 0..=s_max {
            let n = ((s_max + 1) as Float * eta.powi(s as i32) / (s + 1) as Float).ceil() as usize;
            let r = max_budget * eta.powi(-(s as i32));

            let mut configurations = self.generate_initial_configurations(parameter_bounds, n)?;
            let current_budget = r;

            for i in 0..=s {
                let n_i = (n as Float * eta.powi(-(i as i32))).floor() as usize;
                let r_i = current_budget * eta.powi(i as i32);

                if configurations.len() > n_i {
                    configurations.truncate(n_i);
                }

                let fidelity = self.budget_to_fidelity(r_i, &fidelities);
                let mut evaluations = Vec::new();

                for config in &configurations {
                    let evaluation = evaluation_fn(config, &fidelity)?;
                    *total_cost += evaluation.cost;
                    *fidelity_usage
                        .entry(self.fidelity_to_string(&fidelity))
                        .or_insert(0) += 1;

                    self.evaluation_history.push(evaluation.clone());
                    evaluations.push(evaluation.clone());

                    if self.update_best(&evaluation) {
                        convergence_curve.push(self.current_best.as_ref().unwrap().score);
                    } else if let Some(best) = &self.current_best {
                        convergence_curve.push(best.score);
                    }
                }

                evaluations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
                configurations = evaluations
                    .iter()
                    .take(n_i)
                    .map(|eval| eval.hyperparameters.clone())
                    .collect();
            }
        }

        Ok(())
    }
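
    /// BOHB: Hyperband-style bracketing followed by model-based (Bayesian)
    /// refinement with any remaining budget.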
    fn bohb_optimize<F>(
        &mut self,
        evaluation_fn: &F,
        parameter_bounds: &[(Float, Float)],
        total_cost: &mut Float,
        convergence_curve: &mut Vec<Float>,
        fidelity_usage: &mut HashMap<String, usize>,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: Fn(
            &HashMap<String, Float>,
            &FidelityLevel,
        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
    {
        self.hyperband_optimize(
            evaluation_fn,
            parameter_bounds,
            total_cost,
            convergence_curve,
            fidelity_usage,
        )?;

        let remaining_budget = self.config.max_budget - *total_cost;
        if remaining_budget > 0.0 {
            self.bayesian_optimize(
                evaluation_fn,
                parameter_bounds,
                total_cost,
                convergence_curve,
                fidelity_usage,
            )?;
        }

        Ok(())
    }
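
    /// FABOLAS-style search: sweep the dataset fraction from the configured
    /// minimum to maximum, evaluating random configurations at each fraction.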
    fn fabolas_optimize<F>(
        &mut self,
        evaluation_fn: &F,
        parameter_bounds: &[(Float, Float)],
        total_cost: &mut Float,
        convergence_curve: &mut Vec<Float>,
        fidelity_usage: &mut HashMap<String, usize>,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: Fn(
            &HashMap<String, Float>,
            &FidelityLevel,
        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
    {
        let (min_fraction, max_fraction, _cost_model) = match &self.config.strategy {
            MultiFidelityStrategy::Fabolas {
                min_dataset_fraction,
                max_dataset_fraction,
                cost_model,
            } => (*min_dataset_fraction, *max_dataset_fraction, cost_model),
            _ => unreachable!(),
        };

        let mut current_fraction = min_fraction;
        let fraction_step = (max_fraction - min_fraction) / 10.0;

        while current_fraction <= max_fraction && !self.should_stop() {
            let fidelity = FidelityLevel::Custom {
                parameters: {
                    let mut params = HashMap::new();
                    params.insert("dataset_fraction".to_string(), current_fraction);
                    params
                },
                relative_cost: current_fraction,
                accuracy_estimate: current_fraction.sqrt(),
            };

            let config = self.sample_random_configuration(parameter_bounds)?;
            let evaluation = evaluation_fn(&config, &fidelity)?;

            *total_cost += evaluation.cost;
            *fidelity_usage
                .entry(self.fidelity_to_string(&fidelity))
                .or_insert(0) += 1;

            self.evaluation_history.push(evaluation.clone());
            if self.update_best(&evaluation) {
                convergence_curve.push(self.current_best.as_ref().unwrap().score);
            } else if let Some(best) = &self.current_best {
                convergence_curve.push(best.score);
            }

            current_fraction += fraction_step;
        }

        Ok(())
    }
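
    /// Simplified multi-task strategy: cycle random configurations through
    /// low, medium, and high fidelities until the evaluation budget runs out.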
    fn multi_task_gp_optimize<F>(
        &mut self,
        evaluation_fn: &F,
        parameter_bounds: &[(Float, Float)],
        total_cost: &mut Float,
        convergence_curve: &mut Vec<Float>,
        fidelity_usage: &mut HashMap<String, usize>,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: Fn(
            &HashMap<String, Float>,
            &FidelityLevel,
        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
    {
        let fidelities = vec![
            FidelityLevel::Low {
                sample_fraction: 0.1,
                epochs_fraction: 0.1,
                cv_folds: 3,
            },
            FidelityLevel::Medium {
                sample_fraction: 0.5,
                epochs_fraction: 0.5,
                cv_folds: 5,
            },
            FidelityLevel::High {
                sample_fraction: 1.0,
                epochs_fraction: 1.0,
                cv_folds: 10,
            },
        ];

        while self.evaluation_history.len() < self.config.max_evaluations && !self.should_stop() {
            for fidelity in &fidelities {
                let config = self.sample_random_configuration(parameter_bounds)?;
                let evaluation = evaluation_fn(&config, fidelity)?;

                *total_cost += evaluation.cost;
                *fidelity_usage
                    .entry(self.fidelity_to_string(fidelity))
                    .or_insert(0) += 1;

                self.evaluation_history.push(evaluation.clone());
                if self.update_best(&evaluation) {
                    convergence_curve.push(self.current_best.as_ref().unwrap().score);
                } else if let Some(best) = &self.current_best {
                    convergence_curve.push(best.score);
                }

                if self.evaluation_history.len() >= self.config.max_evaluations {
                    break;
                }
            }
        }

        Ok(())
    }

    fn generate_initial_configurations(
        &mut self,
        parameter_bounds: &[(Float, Float)],
        n: usize,
    ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
        let mut configurations = Vec::new();

        for _ in 0..n {
            configurations.push(self.sample_random_configuration(parameter_bounds)?);
        }

        Ok(configurations)
    }

    fn sample_random_configuration(
        &mut self,
        parameter_bounds: &[(Float, Float)],
    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
        let mut config = HashMap::new();

        for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
            // Sample uniformly within the inclusive bounds instead of
            // overshooting the upper bound by 1.0.
            let value = self.rng.gen_range(low..=high);
            config.insert(format!("param_{}", i), value);
        }

        Ok(config)
    }

    fn select_fidelity(
        &mut self,
        method: &FidelitySelectionMethod,
        _config: Option<&HashMap<String, Float>>,
    ) -> Result<FidelityLevel, Box<dyn std::error::Error>> {
        match method {
            FidelitySelectionMethod::LowestFirst => Ok(FidelityLevel::Low {
                sample_fraction: 0.1,
                epochs_fraction: 0.1,
                cv_folds: 3,
            }),
            FidelitySelectionMethod::UncertaintyBased { threshold } => {
                if self.evaluation_history.len() < 5 {
                    Ok(FidelityLevel::Low {
                        sample_fraction: 0.1,
                        epochs_fraction: 0.1,
                        cv_folds: 3,
                    })
                } else {
                    // Average only over evaluations that reported an uncertainty.
                    let uncertainties: Vec<Float> = self
                        .evaluation_history
                        .iter()
                        .filter_map(|eval| eval.uncertainty)
                        .collect();
                    let avg_uncertainty = if uncertainties.is_empty() {
                        0.0
                    } else {
                        uncertainties.iter().sum::<Float>() / uncertainties.len() as Float
                    };

                    if avg_uncertainty > *threshold {
                        Ok(FidelityLevel::High {
                            sample_fraction: 1.0,
                            epochs_fraction: 1.0,
                            cv_folds: 10,
                        })
                    } else {
                        Ok(FidelityLevel::Medium {
                            sample_fraction: 0.5,
                            epochs_fraction: 0.5,
                            cv_folds: 5,
                        })
                    }
                }
            }
            FidelitySelectionMethod::CostAware { budget_fraction } => {
                let used_budget_fraction = self
                    .evaluation_history
                    .iter()
                    .map(|e| e.cost)
                    .sum::<Float>()
                    / self.config.max_budget;

                if used_budget_fraction < *budget_fraction {
                    Ok(FidelityLevel::Low {
                        sample_fraction: 0.1,
                        epochs_fraction: 0.1,
                        cv_folds: 3,
                    })
                } else {
                    Ok(FidelityLevel::High {
                        sample_fraction: 1.0,
                        epochs_fraction: 1.0,
                        cv_folds: 10,
                    })
                }
            }
            _ => Ok(FidelityLevel::Medium {
                sample_fraction: 0.5,
                epochs_fraction: 0.5,
                cv_folds: 5,
            }),
        }
    }

    fn optimize_acquisition(
        &mut self,
        acquisition_function: &AcquisitionFunction,
        parameter_bounds: &[(Float, Float)],
    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
        let n_candidates = 100;
        let mut best_config = self.sample_random_configuration(parameter_bounds)?;
        let mut best_acquisition_value = Float::NEG_INFINITY;

        for _ in 0..n_candidates {
            let candidate = self.sample_random_configuration(parameter_bounds)?;
            let acquisition_value = self.evaluate_acquisition(&candidate, acquisition_function)?;

            if acquisition_value > best_acquisition_value {
                best_acquisition_value = acquisition_value;
                best_config = candidate;
            }
        }

        Ok(best_config)
    }
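
    /// Score a candidate with the configured acquisition function.
    ///
    /// Note: this is a lightweight heuristic stand-in; it does not query the
    /// GP posterior and instead combines the parameter values with random
    /// exploration noise.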
    fn evaluate_acquisition(
        &mut self,
        config: &HashMap<String, Float>,
        acquisition_function: &AcquisitionFunction,
    ) -> Result<Float, Box<dyn std::error::Error>> {
        match acquisition_function {
            AcquisitionFunction::ExpectedImprovement => {
                let config_vec: Vec<Float> = config.values().cloned().collect();
                let config_sum = config_vec.iter().sum::<Float>();
                Ok(config_sum + self.rng.random::<Float>() * 0.1)
            }
            AcquisitionFunction::UpperConfidenceBound { beta } => {
                let config_vec: Vec<Float> = config.values().cloned().collect();
                let config_sum = config_vec.iter().sum::<Float>();
                Ok(config_sum + beta * self.rng.random::<Float>())
            }
            _ => Ok(self.rng.random::<Float>()),
        }
    }

    fn increase_fidelity(&self, current: &FidelityLevel, max: &FidelityLevel) -> FidelityLevel {
        match (current, max) {
            (FidelityLevel::Low { .. }, _) => FidelityLevel::Medium {
                sample_fraction: 0.5,
                epochs_fraction: 0.5,
                cv_folds: 5,
            },
            (FidelityLevel::Medium { .. }, _) => FidelityLevel::High {
                sample_fraction: 1.0,
                epochs_fraction: 1.0,
                cv_folds: 10,
            },
            _ => current.clone(),
        }
    }

    fn budget_to_fidelity(&self, budget: Float, fidelities: &[FidelityLevel]) -> FidelityLevel {
        if budget < 0.3 {
            fidelities
                .first()
                .unwrap_or(&FidelityLevel::Low {
                    sample_fraction: 0.1,
                    epochs_fraction: 0.1,
                    cv_folds: 3,
                })
                .clone()
        } else if budget < 0.7 {
            fidelities
                .get(1)
                .unwrap_or(&FidelityLevel::Medium {
                    sample_fraction: 0.5,
                    epochs_fraction: 0.5,
                    cv_folds: 5,
                })
                .clone()
        } else {
            fidelities
                .get(2)
                .unwrap_or(&FidelityLevel::High {
                    sample_fraction: 1.0,
                    epochs_fraction: 1.0,
                    cv_folds: 10,
                })
                .clone()
        }
    }

    fn fidelity_to_string(&self, fidelity: &FidelityLevel) -> String {
        match fidelity {
            FidelityLevel::Low { .. } => "Low".to_string(),
            FidelityLevel::Medium { .. } => "Medium".to_string(),
            FidelityLevel::High { .. } => "High".to_string(),
            FidelityLevel::Custom { .. } => "Custom".to_string(),
        }
    }

    fn update_best(&mut self, evaluation: &FidelityEvaluation) -> bool {
        match &self.current_best {
            Some(current) => {
                if evaluation.score > current.score {
                    self.current_best = Some(evaluation.clone());
                    true
                } else {
                    false
                }
            }
            None => {
                self.current_best = Some(evaluation.clone());
                true
            }
        }
    }

    fn should_stop(&self) -> bool {
        self.evaluation_history.len() >= self.config.max_evaluations
    }

    fn get_default_fidelity(&self) -> FidelityLevel {
        FidelityLevel::Medium {
            sample_fraction: 0.5,
            epochs_fraction: 0.5,
            cv_folds: 5,
        }
    }
}

impl MultiFidelityGP {
    fn new() -> Self {
        Self {
            observations: Vec::new(),
            hyperparameters: GPHyperparameters {
                length_scales: Array1::from_elem(1, 1.0),
                signal_variance: 1.0,
                noise_variance: 0.1,
                fidelity_correlation: 0.8,
            },
            trained: false,
        }
    }

    fn update(
        &mut self,
        evaluations: &[FidelityEvaluation],
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.observations.clear();

        for eval in evaluations {
            let params: Vec<Float> = eval.hyperparameters.values().cloned().collect();
            let fidelity_value = self.fidelity_to_value(&eval.fidelity);
            self.observations
                .push((Array1::from_vec(params), fidelity_value, eval.score));
        }

        self.trained = true;
        Ok(())
    }

    fn fidelity_to_value(&self, fidelity: &FidelityLevel) -> Float {
        match fidelity {
            FidelityLevel::Low { .. } => 0.1,
            FidelityLevel::Medium { .. } => 0.5,
            FidelityLevel::High { .. } => 1.0,
            FidelityLevel::Custom { relative_cost, .. } => *relative_cost,
        }
    }
}
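
/// Convenience entry point: run multi-fidelity optimization of `evaluation_fn`
/// over `parameter_bounds`, using `config` or the default configuration.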
pub fn multi_fidelity_optimize<F>(
    evaluation_fn: F,
    parameter_bounds: &[(Float, Float)],
    config: Option<MultiFidelityConfig>,
) -> Result<MultiFidelityResult, Box<dyn std::error::Error>>
where
    F: Fn(
        &HashMap<String, Float>,
        &FidelityLevel,
    ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
{
    let config = config.unwrap_or_default();
    let mut optimizer = MultiFidelityOptimizer::new(config);
    optimizer.optimize(evaluation_fn, parameter_bounds)
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    fn mock_evaluation_function(
        hyperparameters: &HashMap<String, Float>,
        fidelity: &FidelityLevel,
    ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>> {
        let score = hyperparameters.values().sum::<Float>() * 0.1;
        let cost = match fidelity {
            FidelityLevel::Low { .. } => 1.0,
            FidelityLevel::Medium { .. } => 5.0,
            FidelityLevel::High { .. } => 10.0,
            FidelityLevel::Custom { relative_cost, .. } => *relative_cost * 10.0,
        };

        Ok(FidelityEvaluation {
            hyperparameters: hyperparameters.clone(),
            fidelity: fidelity.clone(),
            score,
            cost,
            evaluation_time: cost,
            uncertainty: Some(0.1),
            additional_metrics: HashMap::new(),
        })
    }

    #[test]
    fn test_multi_fidelity_optimizer_creation() {
        let config = MultiFidelityConfig::default();
        let optimizer = MultiFidelityOptimizer::new(config);
        assert_eq!(optimizer.evaluation_history.len(), 0);
    }

    #[test]
    fn test_multi_fidelity_optimization() {
        let config = MultiFidelityConfig {
            max_evaluations: 10,
            max_budget: 100.0,
            ..Default::default()
        };

        let parameter_bounds = vec![(0.0, 1.0), (0.0, 1.0)];

        let result =
            multi_fidelity_optimize(mock_evaluation_function, &parameter_bounds, Some(config))
                .unwrap();

        assert!(result.best_score >= 0.0);
        assert!(result.total_cost > 0.0);
        assert!(!result.optimization_history.is_empty());
    }

    #[test]
    fn test_fidelity_levels() {
        let low_fidelity = FidelityLevel::Low {
            sample_fraction: 0.1,
            epochs_fraction: 0.1,
            cv_folds: 3,
        };

        let evaluation = mock_evaluation_function(
            &HashMap::from([("param_0".to_string(), 0.5)]),
            &low_fidelity,
        )
        .unwrap();

        assert_eq!(evaluation.cost, 1.0);
    }

    #[test]
    fn test_successive_halving_strategy() {
        let config = MultiFidelityConfig {
            strategy: MultiFidelityStrategy::SuccessiveHalving {
                eta: 2.0,
                min_fidelity: FidelityLevel::Low {
                    sample_fraction: 0.1,
                    epochs_fraction: 0.1,
                    cv_folds: 3,
                },
                max_fidelity: FidelityLevel::High {
                    sample_fraction: 1.0,
                    epochs_fraction: 1.0,
                    cv_folds: 10,
                },
            },
            max_evaluations: 20,
            max_budget: 200.0,
            ..Default::default()
        };

        let parameter_bounds = vec![(0.0, 1.0), (0.0, 1.0)];

        let result =
            multi_fidelity_optimize(mock_evaluation_function, &parameter_bounds, Some(config))
                .unwrap();

        assert!(result.best_score >= 0.0);
        assert!(!result.fidelity_usage.is_empty());
    }
}