// sklears_model_selection/multi_fidelity_optimization.rs

//! Multi-Fidelity Bayesian Optimization
//!
//! This module provides multi-fidelity Bayesian optimization for efficient hyperparameter tuning
//! by leveraging multiple approximation levels (fidelities) of the objective function.
//! Lower-fidelity evaluations are cheaper but less accurate, while higher-fidelity evaluations
//! are more expensive but more accurate.

8use scirs2_core::ndarray::Array1;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::random::Rng;
11use scirs2_core::random::SeedableRng;
12use sklears_core::types::Float;
13use std::collections::HashMap;
14
15/// Fidelity levels for multi-fidelity optimization
16#[derive(Debug, Clone)]
17pub enum FidelityLevel {
18    /// Low fidelity (fast, less accurate)
19    Low {
20        sample_fraction: Float,
21
22        epochs_fraction: Float,
23
24        cv_folds: usize,
25    },
26    /// Medium fidelity (moderate speed and accuracy)
27    Medium {
28        sample_fraction: Float,
29
30        epochs_fraction: Float,
31        cv_folds: usize,
32    },
33    /// High fidelity (slow, most accurate)
34    High {
35        sample_fraction: Float,
36        epochs_fraction: Float,
37        cv_folds: usize,
38    },
39    /// Custom fidelity with user-defined parameters
40    Custom {
41        parameters: HashMap<String, Float>,
42        relative_cost: Float,
43        accuracy_estimate: Float,
44    },
45}
46
47/// Multi-fidelity optimization strategies
48#[derive(Debug, Clone)]
49pub enum MultiFidelityStrategy {
50    /// Successive Halving with multiple fidelities
51    SuccessiveHalving {
52        eta: Float,
53
54        min_fidelity: FidelityLevel,
55
56        max_fidelity: FidelityLevel,
57    },
58    /// Multi-fidelity Bayesian Optimization (MFBO)
59    BayesianOptimization {
60        acquisition_function: AcquisitionFunction,
61
62        fidelity_selection: FidelitySelectionMethod,
63        correlation_model: CorrelationModel,
64    },
65    /// Hyperband with multi-fidelity
66    Hyperband {
67        max_budget: Float,
68        eta: Float,
69        fidelities: Vec<FidelityLevel>,
70    },
71    /// BOHB (Bayesian Optimization and Hyperband)
72    BOHB {
73        min_budget: Float,
74        max_budget: Float,
75        eta: Float,
76        bandwidth_factor: Float,
77    },
78    /// Fabolas (Fast Bayesian Optimization on Large Datasets)
79    Fabolas {
80        min_dataset_fraction: Float,
81        max_dataset_fraction: Float,
82        cost_model: CostModel,
83    },
84    /// Multi-Task Gaussian Process
85    MultiTaskGP {
86        task_similarity: Float,
87        shared_hyperparameters: Vec<String>,
88    },
89}
90
91/// Acquisition functions for multi-fidelity optimization
92#[derive(Debug, Clone)]
93pub enum AcquisitionFunction {
94    /// Expected Improvement with fidelity consideration
95    ExpectedImprovement,
96    /// Upper Confidence Bound with fidelity adjustment
97    UpperConfidenceBound { beta: Float },
98    /// Probability of Improvement
99    ProbabilityOfImprovement,
100    /// Knowledge Gradient
101    KnowledgeGradient,
102    /// Entropy Search
103    EntropySearch,
104    /// Multi-fidelity Expected Improvement
105    MultiFidelityEI { fidelity_weight: Float },
106}
107
108/// Methods for selecting fidelity levels
109#[derive(Debug, Clone)]
110pub enum FidelitySelectionMethod {
111    /// Always start with lowest fidelity
112    LowestFirst,
113    /// Dynamic selection based on uncertainty
114    UncertaintyBased { threshold: Float },
115    /// Cost-aware selection
116    CostAware { budget_fraction: Float },
117    /// Performance-based selection
118    PerformanceBased { improvement_threshold: Float },
119    /// Information-theoretic selection
120    InformationTheoretic,
121}
122
123/// Models for correlation between fidelities
124#[derive(Debug, Clone)]
125pub enum CorrelationModel {
126    /// Linear correlation between fidelities
127    Linear { correlation_strength: Float },
128    /// Exponential correlation
129    Exponential { decay_rate: Float },
130    /// Learned correlation using Gaussian Process
131    GaussianProcess { kernel_type: String },
132    /// Rank correlation
133    RankCorrelation,
134}
135
136/// Cost models for different fidelity levels
137#[derive(Debug, Clone)]
138pub enum CostModel {
139    /// Polynomial cost model
140    Polynomial {
141        degree: usize,
142
143        coefficients: Vec<Float>,
144    },
145    /// Exponential cost model
146    Exponential { base: Float, scale: Float },
147    /// Linear cost model
148    Linear { slope: Float, intercept: Float },
149    /// Custom cost function
150    Custom { cost_function: String },
151}
152
153/// Multi-fidelity optimization configuration
154#[derive(Debug, Clone)]
155pub struct MultiFidelityConfig {
156    pub strategy: MultiFidelityStrategy,
157    pub max_evaluations: usize,
158    pub max_budget: Float,
159    pub early_stopping_patience: usize,
160    pub fidelity_progression: FidelityProgression,
161    pub random_state: Option<u64>,
162    pub parallel_evaluations: usize,
163}
164
165/// Fidelity progression strategies
166#[derive(Debug, Clone)]
167pub enum FidelityProgression {
168    /// Linear progression from low to high fidelity
169    Linear,
170    /// Exponential progression
171    Exponential { growth_rate: Float },
172    /// Adaptive progression based on performance
173    Adaptive { adaptation_rate: Float },
174    /// Conservative progression (slow increase)
175    Conservative,
176    /// Aggressive progression (fast increase)
177    Aggressive,
178}
179
180/// Evaluation result at a specific fidelity
181#[derive(Debug, Clone)]
182pub struct FidelityEvaluation {
183    pub hyperparameters: HashMap<String, Float>,
184    pub fidelity: FidelityLevel,
185    pub score: Float,
186    pub cost: Float,
187    pub evaluation_time: Float,
188    pub uncertainty: Option<Float>,
189    pub additional_metrics: HashMap<String, Float>,
190}
191
192/// Multi-fidelity optimization result
193#[derive(Debug, Clone)]
194pub struct MultiFidelityResult {
195    pub best_hyperparameters: HashMap<String, Float>,
196    pub best_score: Float,
197    pub best_fidelity: FidelityLevel,
198    pub optimization_history: Vec<FidelityEvaluation>,
199    pub total_cost: Float,
200    pub total_time: Float,
201    pub convergence_curve: Vec<Float>,
202    pub fidelity_usage: HashMap<String, usize>,
203    pub cost_efficiency: Float,
204}
205
206/// Multi-fidelity Bayesian optimizer
207#[derive(Debug)]
208pub struct MultiFidelityOptimizer {
209    config: MultiFidelityConfig,
210    gaussian_process: MultiFidelityGP,
211    evaluation_history: Vec<FidelityEvaluation>,
212    current_best: Option<FidelityEvaluation>,
213    rng: StdRng,
214}
215
216/// Multi-fidelity Gaussian Process
217#[derive(Debug, Clone)]
218pub struct MultiFidelityGP {
219    observations: Vec<(Array1<Float>, Float, Float)>, // (hyperparams, fidelity, score)
220    hyperparameters: GPHyperparameters,
221    trained: bool,
222}
223
224/// Gaussian Process hyperparameters
225#[derive(Debug, Clone)]
226pub struct GPHyperparameters {
227    pub length_scales: Array1<Float>,
228    pub signal_variance: Float,
229    pub noise_variance: Float,
230    pub fidelity_correlation: Float,
231}
232
233impl Default for MultiFidelityConfig {
234    fn default() -> Self {
235        Self {
236            strategy: MultiFidelityStrategy::BayesianOptimization {
237                acquisition_function: AcquisitionFunction::ExpectedImprovement,
238                fidelity_selection: FidelitySelectionMethod::UncertaintyBased { threshold: 0.1 },
239                correlation_model: CorrelationModel::Linear {
240                    correlation_strength: 0.8,
241                },
242            },
243            max_evaluations: 100,
244            max_budget: 1000.0,
245            early_stopping_patience: 10,
246            fidelity_progression: FidelityProgression::Adaptive {
247                adaptation_rate: 0.1,
248            },
249            random_state: None,
250            parallel_evaluations: 1,
251        }
252    }
253}
254
255impl MultiFidelityOptimizer {
256    /// Create a new multi-fidelity optimizer
257    pub fn new(config: MultiFidelityConfig) -> Self {
258        let rng = match config.random_state {
259            Some(seed) => StdRng::seed_from_u64(seed),
260            None => {
261                use scirs2_core::random::thread_rng;
262                StdRng::from_rng(&mut thread_rng())
263            }
264        };
265
266        let gaussian_process = MultiFidelityGP::new();
267
268        Self {
269            config,
270            gaussian_process,
271            evaluation_history: Vec::new(),
272            current_best: None,
273            rng,
274        }
275    }
276
277    /// Optimize hyperparameters using multi-fidelity approach
278    pub fn optimize<F>(
279        &mut self,
280        evaluation_fn: F,
281        parameter_bounds: &[(Float, Float)],
282    ) -> Result<MultiFidelityResult, Box<dyn std::error::Error>>
283    where
284        F: Fn(
285            &HashMap<String, Float>,
286            &FidelityLevel,
287        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
288    {
289        let start_time = std::time::Instant::now();
290        let mut total_cost = 0.0;
291        let mut convergence_curve = Vec::new();
292        let mut fidelity_usage = HashMap::new();
293
294        match &self.config.strategy {
295            MultiFidelityStrategy::SuccessiveHalving { .. } => {
296                self.successive_halving_optimize(
297                    &evaluation_fn,
298                    parameter_bounds,
299                    &mut total_cost,
300                    &mut convergence_curve,
301                    &mut fidelity_usage,
302                )?;
303            }
304            MultiFidelityStrategy::BayesianOptimization { .. } => {
305                self.bayesian_optimize(
306                    &evaluation_fn,
307                    parameter_bounds,
308                    &mut total_cost,
309                    &mut convergence_curve,
310                    &mut fidelity_usage,
311                )?;
312            }
313            MultiFidelityStrategy::Hyperband { .. } => {
314                self.hyperband_optimize(
315                    &evaluation_fn,
316                    parameter_bounds,
317                    &mut total_cost,
318                    &mut convergence_curve,
319                    &mut fidelity_usage,
320                )?;
321            }
322            MultiFidelityStrategy::BOHB { .. } => {
323                self.bohb_optimize(
324                    &evaluation_fn,
325                    parameter_bounds,
326                    &mut total_cost,
327                    &mut convergence_curve,
328                    &mut fidelity_usage,
329                )?;
330            }
331            MultiFidelityStrategy::Fabolas { .. } => {
332                self.fabolas_optimize(
333                    &evaluation_fn,
334                    parameter_bounds,
335                    &mut total_cost,
336                    &mut convergence_curve,
337                    &mut fidelity_usage,
338                )?;
339            }
340            MultiFidelityStrategy::MultiTaskGP { .. } => {
341                self.multi_task_gp_optimize(
342                    &evaluation_fn,
343                    parameter_bounds,
344                    &mut total_cost,
345                    &mut convergence_curve,
346                    &mut fidelity_usage,
347                )?;
348            }
349        }
350
351        let total_time = start_time.elapsed().as_secs_f64() as Float;
352        let cost_efficiency = if total_cost > 0.0 {
353            self.current_best.as_ref().map_or(0.0, |best| best.score) / total_cost
354        } else {
355            0.0
356        };
357
358        Ok(MultiFidelityResult {
359            best_hyperparameters: self
360                .current_best
361                .as_ref()
362                .map(|best| best.hyperparameters.clone())
363                .unwrap_or_default(),
364            best_score: self.current_best.as_ref().map_or(0.0, |best| best.score),
365            best_fidelity: self
366                .current_best
367                .as_ref()
368                .map(|best| best.fidelity.clone())
369                .unwrap_or(self.get_default_fidelity()),
370            optimization_history: self.evaluation_history.clone(),
371            total_cost,
372            total_time,
373            convergence_curve,
374            fidelity_usage,
375            cost_efficiency,
376        })
377    }
378
379    /// Successive halving with multi-fidelity
380    fn successive_halving_optimize<F>(
381        &mut self,
382        evaluation_fn: &F,
383        parameter_bounds: &[(Float, Float)],
384        total_cost: &mut Float,
385        convergence_curve: &mut Vec<Float>,
386        fidelity_usage: &mut HashMap<String, usize>,
387    ) -> Result<(), Box<dyn std::error::Error>>
388    where
389        F: Fn(
390            &HashMap<String, Float>,
391            &FidelityLevel,
392        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
393    {
394        let (eta, min_fidelity, max_fidelity) = match &self.config.strategy {
395            MultiFidelityStrategy::SuccessiveHalving {
396                eta,
397                min_fidelity,
398                max_fidelity,
399            } => (*eta, min_fidelity.clone(), max_fidelity.clone()),
400            _ => unreachable!(),
401        };
402
403        let mut configurations = self.generate_initial_configurations(parameter_bounds, 50)?;
404        let mut current_fidelity = min_fidelity;
405
406        while configurations.len() > 1 && !self.should_stop() {
407            let mut evaluations = Vec::new();
408
409            // Evaluate all configurations at current fidelity
410            for config in &configurations {
411                let evaluation = evaluation_fn(config, &current_fidelity)?;
412                *total_cost += evaluation.cost;
413                *fidelity_usage
414                    .entry(self.fidelity_to_string(&current_fidelity))
415                    .or_insert(0) += 1;
416
417                self.evaluation_history.push(evaluation.clone());
418                evaluations.push(evaluation.clone());
419
420                if self.update_best(&evaluation) {
421                    convergence_curve.push(self.current_best.as_ref().unwrap().score);
422                } else if let Some(best) = &self.current_best {
423                    convergence_curve.push(best.score);
424                }
425            }
426
427            // Keep top 1/eta configurations
428            evaluations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
429            let keep_count = (configurations.len() as Float / eta).max(1.0) as usize;
430
431            configurations = evaluations
432                .iter()
433                .take(keep_count)
434                .map(|eval| eval.hyperparameters.clone())
435                .collect();
436
437            // Increase fidelity
438            current_fidelity = self.increase_fidelity(&current_fidelity, &max_fidelity);
439        }
440
441        Ok(())
442    }
443
444    /// Bayesian optimization with multi-fidelity
445    fn bayesian_optimize<F>(
446        &mut self,
447        evaluation_fn: &F,
448        parameter_bounds: &[(Float, Float)],
449        total_cost: &mut Float,
450        convergence_curve: &mut Vec<Float>,
451        fidelity_usage: &mut HashMap<String, usize>,
452    ) -> Result<(), Box<dyn std::error::Error>>
453    where
454        F: Fn(
455            &HashMap<String, Float>,
456            &FidelityLevel,
457        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
458    {
459        let (acquisition_function, fidelity_selection, _correlation_model) =
460            match &self.config.strategy {
461                MultiFidelityStrategy::BayesianOptimization {
462                    acquisition_function,
463                    fidelity_selection,
464                    correlation_model,
465                } => (
466                    acquisition_function.clone(),
467                    fidelity_selection.clone(),
468                    correlation_model.clone(),
469                ),
470                _ => unreachable!(),
471            };
472
473        // Initialize with random evaluations
474        let init_evaluations = 5;
475        for _ in 0..init_evaluations {
476            let config = self.sample_random_configuration(parameter_bounds)?;
477            let fidelity = self.select_fidelity(&fidelity_selection, None)?;
478
479            let evaluation = evaluation_fn(&config, &fidelity)?;
480            *total_cost += evaluation.cost;
481            *fidelity_usage
482                .entry(self.fidelity_to_string(&fidelity))
483                .or_insert(0) += 1;
484
485            self.evaluation_history.push(evaluation.clone());
486            if self.update_best(&evaluation) {
487                convergence_curve.push(self.current_best.as_ref().unwrap().score);
488            } else if let Some(best) = &self.current_best {
489                convergence_curve.push(best.score);
490            }
491        }
492
493        // Update Gaussian Process
494        self.gaussian_process.update(&self.evaluation_history)?;
495
496        // Bayesian optimization loop
497        while self.evaluation_history.len() < self.config.max_evaluations && !self.should_stop() {
498            // Select next configuration and fidelity
499            let next_config = self.optimize_acquisition(&acquisition_function, parameter_bounds)?;
500            let next_fidelity = self.select_fidelity(&fidelity_selection, Some(&next_config))?;
501
502            let evaluation = evaluation_fn(&next_config, &next_fidelity)?;
503            *total_cost += evaluation.cost;
504            *fidelity_usage
505                .entry(self.fidelity_to_string(&next_fidelity))
506                .or_insert(0) += 1;
507
508            self.evaluation_history.push(evaluation.clone());
509            if self.update_best(&evaluation) {
510                convergence_curve.push(self.current_best.as_ref().unwrap().score);
511            } else if let Some(best) = &self.current_best {
512                convergence_curve.push(best.score);
513            }
514
515            // Update Gaussian Process periodically
516            if self.evaluation_history.len() % 5 == 0 {
517                self.gaussian_process.update(&self.evaluation_history)?;
518            }
519        }
520
521        Ok(())
522    }
523
524    /// Hyperband optimization
525    fn hyperband_optimize<F>(
526        &mut self,
527        evaluation_fn: &F,
528        parameter_bounds: &[(Float, Float)],
529        total_cost: &mut Float,
530        convergence_curve: &mut Vec<Float>,
531        fidelity_usage: &mut HashMap<String, usize>,
532    ) -> Result<(), Box<dyn std::error::Error>>
533    where
534        F: Fn(
535            &HashMap<String, Float>,
536            &FidelityLevel,
537        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
538    {
539        let (max_budget, eta, fidelities) = match &self.config.strategy {
540            MultiFidelityStrategy::Hyperband {
541                max_budget,
542                eta,
543                fidelities,
544            } => (*max_budget, *eta, fidelities.clone()),
545            _ => unreachable!(),
546        };
547
548        let log_eta = eta.ln();
549        let s_max = (max_budget.ln() / log_eta).floor() as usize;
550
551        for s in 0..=s_max {
552            let n = ((s_max + 1) as Float * eta.powi(s as i32) / (s + 1) as Float).ceil() as usize;
553            let r = max_budget * eta.powi(-(s as i32));
554
555            let mut configurations = self.generate_initial_configurations(parameter_bounds, n)?;
556            let current_budget = r;
557
558            for i in 0..=s {
559                let n_i = (n as Float * eta.powi(-(i as i32))).floor() as usize;
560                let r_i = current_budget * eta.powi(i as i32);
561
562                if configurations.len() > n_i {
563                    configurations.truncate(n_i);
564                }
565
566                let fidelity = self.budget_to_fidelity(r_i, &fidelities);
567                let mut evaluations = Vec::new();
568
569                for config in &configurations {
570                    let evaluation = evaluation_fn(config, &fidelity)?;
571                    *total_cost += evaluation.cost;
572                    *fidelity_usage
573                        .entry(self.fidelity_to_string(&fidelity))
574                        .or_insert(0) += 1;
575
576                    self.evaluation_history.push(evaluation.clone());
577                    evaluations.push(evaluation.clone());
578
579                    if self.update_best(&evaluation) {
580                        convergence_curve.push(self.current_best.as_ref().unwrap().score);
581                    } else if let Some(best) = &self.current_best {
582                        convergence_curve.push(best.score);
583                    }
584                }
585
586                // Keep top configurations
587                evaluations.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
588                configurations = evaluations
589                    .iter()
590                    .take(n_i)
591                    .map(|eval| eval.hyperparameters.clone())
592                    .collect();
593            }
594        }
595
596        Ok(())
597    }
598
599    /// BOHB optimization (Bayesian Optimization and Hyperband)
600    fn bohb_optimize<F>(
601        &mut self,
602        evaluation_fn: &F,
603        parameter_bounds: &[(Float, Float)],
604        total_cost: &mut Float,
605        convergence_curve: &mut Vec<Float>,
606        fidelity_usage: &mut HashMap<String, usize>,
607    ) -> Result<(), Box<dyn std::error::Error>>
608    where
609        F: Fn(
610            &HashMap<String, Float>,
611            &FidelityLevel,
612        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
613    {
614        // Simplified BOHB implementation combining Hyperband with Bayesian optimization
615        // Start with Hyperband for exploration
616        self.hyperband_optimize(
617            evaluation_fn,
618            parameter_bounds,
619            total_cost,
620            convergence_curve,
621            fidelity_usage,
622        )?;
623
624        // Continue with Bayesian optimization for exploitation
625        let remaining_budget = self.config.max_budget - *total_cost;
626        if remaining_budget > 0.0 {
627            self.bayesian_optimize(
628                evaluation_fn,
629                parameter_bounds,
630                total_cost,
631                convergence_curve,
632                fidelity_usage,
633            )?;
634        }
635
636        Ok(())
637    }
638
639    /// Fabolas optimization
640    fn fabolas_optimize<F>(
641        &mut self,
642        evaluation_fn: &F,
643        parameter_bounds: &[(Float, Float)],
644        total_cost: &mut Float,
645        convergence_curve: &mut Vec<Float>,
646        fidelity_usage: &mut HashMap<String, usize>,
647    ) -> Result<(), Box<dyn std::error::Error>>
648    where
649        F: Fn(
650            &HashMap<String, Float>,
651            &FidelityLevel,
652        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
653    {
654        // Simplified Fabolas implementation focusing on dataset size as fidelity
655        let (min_fraction, max_fraction, _cost_model) = match &self.config.strategy {
656            MultiFidelityStrategy::Fabolas {
657                min_dataset_fraction,
658                max_dataset_fraction,
659                cost_model,
660            } => (*min_dataset_fraction, *max_dataset_fraction, cost_model),
661            _ => unreachable!(),
662        };
663
664        let mut current_fraction = min_fraction;
665        let fraction_step = (max_fraction - min_fraction) / 10.0;
666
667        while current_fraction <= max_fraction && !self.should_stop() {
668            let fidelity = FidelityLevel::Custom {
669                parameters: {
670                    let mut params = HashMap::new();
671                    params.insert("dataset_fraction".to_string(), current_fraction);
672                    params
673                },
674                relative_cost: current_fraction,
675                accuracy_estimate: current_fraction.sqrt(),
676            };
677
678            let config = self.sample_random_configuration(parameter_bounds)?;
679            let evaluation = evaluation_fn(&config, &fidelity)?;
680
681            *total_cost += evaluation.cost;
682            *fidelity_usage
683                .entry(self.fidelity_to_string(&fidelity))
684                .or_insert(0) += 1;
685
686            self.evaluation_history.push(evaluation.clone());
687            if self.update_best(&evaluation) {
688                convergence_curve.push(self.current_best.as_ref().unwrap().score);
689            } else if let Some(best) = &self.current_best {
690                convergence_curve.push(best.score);
691            }
692
693            current_fraction += fraction_step;
694        }
695
696        Ok(())
697    }
698
699    /// Multi-task Gaussian Process optimization
700    fn multi_task_gp_optimize<F>(
701        &mut self,
702        evaluation_fn: &F,
703        parameter_bounds: &[(Float, Float)],
704        total_cost: &mut Float,
705        convergence_curve: &mut Vec<Float>,
706        fidelity_usage: &mut HashMap<String, usize>,
707    ) -> Result<(), Box<dyn std::error::Error>>
708    where
709        F: Fn(
710            &HashMap<String, Float>,
711            &FidelityLevel,
712        ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
713    {
714        // Simplified multi-task GP implementation
715        // Treat each fidelity as a separate task
716        let fidelities = vec![
717            FidelityLevel::Low {
718                sample_fraction: 0.1,
719                epochs_fraction: 0.1,
720                cv_folds: 3,
721            },
722            FidelityLevel::Medium {
723                sample_fraction: 0.5,
724                epochs_fraction: 0.5,
725                cv_folds: 5,
726            },
727            FidelityLevel::High {
728                sample_fraction: 1.0,
729                epochs_fraction: 1.0,
730                cv_folds: 10,
731            },
732        ];
733
734        while self.evaluation_history.len() < self.config.max_evaluations && !self.should_stop() {
735            for fidelity in &fidelities {
736                let config = self.sample_random_configuration(parameter_bounds)?;
737                let evaluation = evaluation_fn(&config, fidelity)?;
738
739                *total_cost += evaluation.cost;
740                *fidelity_usage
741                    .entry(self.fidelity_to_string(fidelity))
742                    .or_insert(0) += 1;
743
744                self.evaluation_history.push(evaluation.clone());
745                if self.update_best(&evaluation) {
746                    convergence_curve.push(self.current_best.as_ref().unwrap().score);
747                } else if let Some(best) = &self.current_best {
748                    convergence_curve.push(best.score);
749                }
750
751                if self.evaluation_history.len() >= self.config.max_evaluations {
752                    break;
753                }
754            }
755        }
756
757        Ok(())
758    }
759
760    /// Generate initial random configurations
761    fn generate_initial_configurations(
762        &mut self,
763        parameter_bounds: &[(Float, Float)],
764        n: usize,
765    ) -> Result<Vec<HashMap<String, Float>>, Box<dyn std::error::Error>> {
766        let mut configurations = Vec::new();
767
768        for _ in 0..n {
769            configurations.push(self.sample_random_configuration(parameter_bounds)?);
770        }
771
772        Ok(configurations)
773    }
774
775    /// Sample a random configuration
776    fn sample_random_configuration(
777        &mut self,
778        parameter_bounds: &[(Float, Float)],
779    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
780        let mut config = HashMap::new();
781
782        for (i, &(low, high)) in parameter_bounds.iter().enumerate() {
783            let value = self.rng.gen_range(low..high + 1.0);
784            config.insert(format!("param_{}", i), value);
785        }
786
787        Ok(config)
788    }
789
790    /// Select fidelity level based on strategy
791    fn select_fidelity(
792        &mut self,
793        method: &FidelitySelectionMethod,
794        _config: Option<&HashMap<String, Float>>,
795    ) -> Result<FidelityLevel, Box<dyn std::error::Error>> {
796        match method {
797            FidelitySelectionMethod::LowestFirst => Ok(FidelityLevel::Low {
798                sample_fraction: 0.1,
799                epochs_fraction: 0.1,
800                cv_folds: 3,
801            }),
802            FidelitySelectionMethod::UncertaintyBased { threshold } => {
803                // Use uncertainty to determine fidelity
804                if self.evaluation_history.len() < 5 {
805                    Ok(FidelityLevel::Low {
806                        sample_fraction: 0.1,
807                        epochs_fraction: 0.1,
808                        cv_folds: 3,
809                    })
810                } else {
811                    let avg_uncertainty = self
812                        .evaluation_history
813                        .iter()
814                        .filter_map(|eval| eval.uncertainty)
815                        .sum::<Float>()
816                        / self.evaluation_history.len() as Float;
817
818                    if avg_uncertainty > *threshold {
819                        Ok(FidelityLevel::High {
820                            sample_fraction: 1.0,
821                            epochs_fraction: 1.0,
822                            cv_folds: 10,
823                        })
824                    } else {
825                        Ok(FidelityLevel::Medium {
826                            sample_fraction: 0.5,
827                            epochs_fraction: 0.5,
828                            cv_folds: 5,
829                        })
830                    }
831                }
832            }
833            FidelitySelectionMethod::CostAware { budget_fraction } => {
834                let used_budget_fraction = self
835                    .evaluation_history
836                    .iter()
837                    .map(|e| e.cost)
838                    .sum::<Float>()
839                    / self.config.max_budget;
840
841                if used_budget_fraction < *budget_fraction {
842                    Ok(FidelityLevel::Low {
843                        sample_fraction: 0.1,
844                        epochs_fraction: 0.1,
845                        cv_folds: 3,
846                    })
847                } else {
848                    Ok(FidelityLevel::High {
849                        sample_fraction: 1.0,
850                        epochs_fraction: 1.0,
851                        cv_folds: 10,
852                    })
853                }
854            }
855            _ => Ok(FidelityLevel::Medium {
856                sample_fraction: 0.5,
857                epochs_fraction: 0.5,
858                cv_folds: 5,
859            }),
860        }
861    }
862
863    /// Optimize acquisition function
864    fn optimize_acquisition(
865        &mut self,
866        acquisition_function: &AcquisitionFunction,
867        parameter_bounds: &[(Float, Float)],
868    ) -> Result<HashMap<String, Float>, Box<dyn std::error::Error>> {
869        // Simplified acquisition optimization - random search
870        let n_candidates = 100;
871        let mut best_config = self.sample_random_configuration(parameter_bounds)?;
872        let mut best_acquisition_value = Float::NEG_INFINITY;
873
874        for _ in 0..n_candidates {
875            let candidate = self.sample_random_configuration(parameter_bounds)?;
876            let acquisition_value = self.evaluate_acquisition(&candidate, acquisition_function)?;
877
878            if acquisition_value > best_acquisition_value {
879                best_acquisition_value = acquisition_value;
880                best_config = candidate;
881            }
882        }
883
884        Ok(best_config)
885    }
886
887    /// Evaluate acquisition function
888    fn evaluate_acquisition(
889        &mut self,
890        config: &HashMap<String, Float>,
891        acquisition_function: &AcquisitionFunction,
892    ) -> Result<Float, Box<dyn std::error::Error>> {
893        // Simplified acquisition function evaluation
894        match acquisition_function {
895            AcquisitionFunction::ExpectedImprovement => {
896                // Mock EI calculation
897                let config_vec: Vec<Float> = config.values().cloned().collect();
898                let config_sum = config_vec.iter().sum::<Float>();
899                Ok(config_sum + self.rng.random::<Float>() * 0.1)
900            }
901            AcquisitionFunction::UpperConfidenceBound { beta } => {
902                // Mock UCB calculation
903                let config_vec: Vec<Float> = config.values().cloned().collect();
904                let config_sum = config_vec.iter().sum::<Float>();
905                Ok(config_sum + beta * self.rng.random::<Float>())
906            }
907            _ => {
908                // Default to random value
909                Ok(self.rng.random::<Float>())
910            }
911        }
912    }
913
914    /// Increase fidelity level
915    fn increase_fidelity(&self, current: &FidelityLevel, max: &FidelityLevel) -> FidelityLevel {
916        match (current, max) {
917            (FidelityLevel::Low { .. }, _) => FidelityLevel::Medium {
918                sample_fraction: 0.5,
919                epochs_fraction: 0.5,
920                cv_folds: 5,
921            },
922            (FidelityLevel::Medium { .. }, _) => FidelityLevel::High {
923                sample_fraction: 1.0,
924                epochs_fraction: 1.0,
925                cv_folds: 10,
926            },
927            _ => current.clone(),
928        }
929    }
930
931    /// Convert budget to fidelity level
932    fn budget_to_fidelity(&self, budget: Float, fidelities: &[FidelityLevel]) -> FidelityLevel {
933        if budget < 0.3 {
934            fidelities
935                .first()
936                .unwrap_or(&FidelityLevel::Low {
937                    sample_fraction: 0.1,
938                    epochs_fraction: 0.1,
939                    cv_folds: 3,
940                })
941                .clone()
942        } else if budget < 0.7 {
943            fidelities
944                .get(1)
945                .unwrap_or(&FidelityLevel::Medium {
946                    sample_fraction: 0.5,
947                    epochs_fraction: 0.5,
948                    cv_folds: 5,
949                })
950                .clone()
951        } else {
952            fidelities
953                .get(2)
954                .unwrap_or(&FidelityLevel::High {
955                    sample_fraction: 1.0,
956                    epochs_fraction: 1.0,
957                    cv_folds: 10,
958                })
959                .clone()
960        }
961    }
962
963    /// Convert fidelity to string for tracking
964    fn fidelity_to_string(&self, fidelity: &FidelityLevel) -> String {
965        match fidelity {
966            FidelityLevel::Low { .. } => "Low".to_string(),
967            FidelityLevel::Medium { .. } => "Medium".to_string(),
968            FidelityLevel::High { .. } => "High".to_string(),
969            FidelityLevel::Custom { .. } => "Custom".to_string(),
970        }
971    }
972
973    /// Update best configuration
974    fn update_best(&mut self, evaluation: &FidelityEvaluation) -> bool {
975        match &self.current_best {
976            Some(current) => {
977                if evaluation.score > current.score {
978                    self.current_best = Some(evaluation.clone());
979                    true
980                } else {
981                    false
982                }
983            }
984            None => {
985                self.current_best = Some(evaluation.clone());
986                true
987            }
988        }
989    }
990
991    /// Check if optimization should stop
992    fn should_stop(&self) -> bool {
993        self.evaluation_history.len() >= self.config.max_evaluations
994    }
995
996    /// Get default fidelity level
997    fn get_default_fidelity(&self) -> FidelityLevel {
998        FidelityLevel::Medium {
999            sample_fraction: 0.5,
1000            epochs_fraction: 0.5,
1001            cv_folds: 5,
1002        }
1003    }
1004}
1005
1006impl MultiFidelityGP {
1007    /// Create a new multi-fidelity Gaussian Process
1008    fn new() -> Self {
1009        Self {
1010            observations: Vec::new(),
1011            hyperparameters: GPHyperparameters {
1012                length_scales: Array1::from_elem(1, 1.0),
1013                signal_variance: 1.0,
1014                noise_variance: 0.1,
1015                fidelity_correlation: 0.8,
1016            },
1017            trained: false,
1018        }
1019    }
1020
1021    /// Update the GP with new observations
1022    fn update(
1023        &mut self,
1024        evaluations: &[FidelityEvaluation],
1025    ) -> Result<(), Box<dyn std::error::Error>> {
1026        self.observations.clear();
1027
1028        for eval in evaluations {
1029            let params: Vec<Float> = eval.hyperparameters.values().cloned().collect();
1030            let fidelity_value = self.fidelity_to_value(&eval.fidelity);
1031            self.observations
1032                .push((Array1::from_vec(params), fidelity_value, eval.score));
1033        }
1034
1035        // Simplified GP update
1036        self.trained = true;
1037        Ok(())
1038    }
1039
1040    /// Convert fidelity to numerical value
1041    fn fidelity_to_value(&self, fidelity: &FidelityLevel) -> Float {
1042        match fidelity {
1043            FidelityLevel::Low { .. } => 0.1,
1044            FidelityLevel::Medium { .. } => 0.5,
1045            FidelityLevel::High { .. } => 1.0,
1046            FidelityLevel::Custom { relative_cost, .. } => *relative_cost,
1047        }
1048    }
1049}
1050
1051/// Convenience function for multi-fidelity optimization
1052pub fn multi_fidelity_optimize<F>(
1053    evaluation_fn: F,
1054    parameter_bounds: &[(Float, Float)],
1055    config: Option<MultiFidelityConfig>,
1056) -> Result<MultiFidelityResult, Box<dyn std::error::Error>>
1057where
1058    F: Fn(
1059        &HashMap<String, Float>,
1060        &FidelityLevel,
1061    ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>>,
1062{
1063    let config = config.unwrap_or_default();
1064    let mut optimizer = MultiFidelityOptimizer::new(config);
1065    optimizer.optimize(evaluation_fn, parameter_bounds)
1066}
1067
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic stand-in objective.
    ///
    /// The score is 0.1 × the sum of the hyperparameter values. The cost is fixed per
    /// fidelity: Low 1, Medium 5, High 10, Custom 10 × `relative_cost`. Evaluation
    /// time mirrors the cost.
    fn mock_evaluation_function(
        hyperparameters: &HashMap<String, Float>,
        fidelity: &FidelityLevel,
    ) -> Result<FidelityEvaluation, Box<dyn std::error::Error>> {
        let total: Float = hyperparameters.values().sum();
        let score = total * 0.1;

        let cost: Float = match fidelity {
            FidelityLevel::Low { .. } => 1.0,
            FidelityLevel::Medium { .. } => 5.0,
            FidelityLevel::High { .. } => 10.0,
            FidelityLevel::Custom { relative_cost, .. } => *relative_cost * 10.0,
        };

        Ok(FidelityEvaluation {
            score,
            cost,
            evaluation_time: cost,
            uncertainty: Some(0.1),
            hyperparameters: hyperparameters.clone(),
            fidelity: fidelity.clone(),
            additional_metrics: HashMap::new(),
        })
    }

    /// Two parameters, each bounded to [0, 1].
    fn unit_square_bounds() -> Vec<(Float, Float)> {
        vec![(0.0, 1.0); 2]
    }

    #[test]
    fn test_multi_fidelity_optimizer_creation() {
        let optimizer = MultiFidelityOptimizer::new(MultiFidelityConfig::default());
        assert!(optimizer.evaluation_history.is_empty());
    }

    #[test]
    fn test_multi_fidelity_optimization() {
        let config = MultiFidelityConfig {
            max_evaluations: 10,
            max_budget: 100.0,
            ..Default::default()
        };
        let bounds = unit_square_bounds();

        let outcome = multi_fidelity_optimize(mock_evaluation_function, &bounds, Some(config))
            .expect("optimization with the mock objective should succeed");

        assert!(outcome.best_score >= 0.0);
        assert!(outcome.total_cost > 0.0);
        assert!(!outcome.optimization_history.is_empty());
    }

    #[test]
    fn test_fidelity_levels() {
        let cheapest = FidelityLevel::Low {
            sample_fraction: 0.1,
            epochs_fraction: 0.1,
            cv_folds: 3,
        };
        let params = HashMap::from([("param_0".to_string(), 0.5)]);

        let evaluation = mock_evaluation_function(&params, &cheapest)
            .expect("mock evaluation never fails");

        assert_eq!(evaluation.cost, 1.0);
    }

    #[test]
    fn test_successive_halving_strategy() {
        let strategy = MultiFidelityStrategy::SuccessiveHalving {
            eta: 2.0,
            min_fidelity: FidelityLevel::Low {
                sample_fraction: 0.1,
                epochs_fraction: 0.1,
                cv_folds: 3,
            },
            max_fidelity: FidelityLevel::High {
                sample_fraction: 1.0,
                epochs_fraction: 1.0,
                cv_folds: 10,
            },
        };
        let config = MultiFidelityConfig {
            strategy,
            max_evaluations: 20,
            max_budget: 200.0,
            ..Default::default()
        };
        let bounds = unit_square_bounds();

        let outcome = multi_fidelity_optimize(mock_evaluation_function, &bounds, Some(config))
            .expect("successive halving with the mock objective should succeed");

        assert!(outcome.best_score >= 0.0);
        assert!(!outcome.fidelity_usage.is_empty());
    }
}