Skip to main content

trustformers_training/hyperopt/
efficiency.rs

1//! Efficiency features for hyperparameter optimization
2//!
3//! This module provides advanced efficiency features to accelerate hyperparameter
4//! optimization including warm starting, bandit algorithms, surrogate models,
5//! and parallel evaluation strategies.
6
7use super::{ParameterValue, SearchSpace, SearchStrategy, Trial, TrialHistory, TrialState};
8use anyhow::Result;
9use scirs2_core::random::*; // SciRS2 Integration Policy
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12use std::sync::{Arc, Mutex};
13use std::time::{Duration, SystemTime};
14
/// Advanced early stopping configuration with multiple strategies.
///
/// Only the declaration and its `Default` impl are visible in this module; the
/// consumer that interprets these fields lives elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedEarlyStoppingConfig {
    /// Basic patience configuration (presumably the number of non-improving
    /// evaluations tolerated before stopping — confirm against the consumer).
    pub patience: usize,
    /// Minimum improvement threshold; smaller changes are presumably not
    /// counted as an improvement.
    pub min_delta: f64,
    /// Early stopping strategy to apply.
    pub strategy: EarlyStoppingStrategy,
    /// Whether patience is adjusted adaptively.
    pub adaptive_patience: bool,
    /// Minimum evaluation steps before early stopping may trigger.
    pub min_evaluation_steps: usize,
    /// Grace period allowed for initial convergence.
    pub grace_period: usize,
}
31
/// Criterion used by [`AdvancedEarlyStoppingConfig`] to decide when to stop a trial.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EarlyStoppingStrategy {
    /// Standard early stopping based on validation loss
    Standard,
    /// Early stopping based on training dynamics
    TrainingDynamics {
        /// Maximum gradient norm threshold
        max_gradient_norm: f64,
        /// Loss oscillation threshold
        loss_oscillation_threshold: f64,
    },
    /// Multi-objective early stopping
    MultiObjective {
        /// Primary metric for early stopping
        primary_metric: String,
        /// Secondary metrics to consider
        secondary_metrics: Vec<String>,
        /// Weights for each metric, keyed by metric name
        metric_weights: HashMap<String, f64>,
    },
    /// Bayesian early stopping using posterior predictions
    Bayesian {
        /// Confidence threshold for stopping
        confidence_threshold: f64,
        /// Number of posterior samples
        num_samples: usize,
    },
}
60
/// Warm starting configuration for hyperparameter optimization.
///
/// Describes which historical trials to seed a new study with and where to
/// load them from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmStartConfig {
    /// Strategy for selecting warm start trials
    pub strategy: WarmStartStrategy,
    /// Source of historical data
    pub data_source: WarmStartDataSource,
    /// Number of warm start trials to use
    pub num_warm_start_trials: usize,
    /// Weight decay for historical data importance (presumably a per-step
    /// multiplicative factor in `(0, 1]`; the default is `0.9`)
    pub historical_weight_decay: f64,
}
73
/// How historical trials are chosen when warm starting a study.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WarmStartStrategy {
    /// Use best trials from previous studies
    BestTrials,
    /// Use diverse set of good trials
    DiverseBest {
        /// Diversity threshold
        diversity_threshold: f64,
    },
    /// Transfer learning from similar models
    TransferLearning {
        /// Similarity threshold
        similarity_threshold: f64,
        /// Feature mapping function, identified by name (the resolver is not
        /// visible in this module)
        feature_mapping: String,
    },
    /// Meta-learning based warm start
    MetaLearning {
        /// Meta-features to use
        meta_features: Vec<String>,
        /// Number of meta-learning epochs
        meta_epochs: usize,
    },
}
98
/// Where historical trial data for warm starting is stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WarmStartDataSource {
    /// Local database of previous runs, located at `path`
    LocalDatabase { path: String },
    /// Remote database or API reachable at `url`, authenticated with `auth_token`.
    ///
    /// NOTE(review): `auth_token` is a plain `String` inside a type that derives
    /// `Debug` and `Serialize`, so it is printed verbatim by `{:?}` and written
    /// into serialized configs. Confirm this is acceptable or redact it.
    RemoteDatabase { url: String, auth_token: String },
    /// File-based storage rooted at `directory`
    FileStorage { directory: String },
    /// In-memory cache
    InMemory,
}
110
/// Multi-armed bandit algorithms for hyperparameter optimization.
///
/// Each arm is a fixed hyperparameter configuration generated once at
/// construction; `arms[i]` and `arm_stats[i]` describe the same arm.
#[derive(Debug, Clone)]
pub struct BanditOptimizer {
    /// Bandit algorithm configuration
    config: BanditConfig,
    /// Arms (hyperparameter configurations)
    arms: Vec<HashMap<String, ParameterValue>>,
    /// Per-arm statistics, index-aligned with `arms`
    arm_stats: Vec<ArmStatistics>,
    /// Current exploration factor.
    ///
    /// NOTE(review): initialised to `1.0` by `new` and not read by any visible
    /// code, hence the `dead_code` allowance.
    #[allow(dead_code)]
    exploration_factor: f64,
}
124
/// Configuration for [`BanditOptimizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BanditConfig {
    /// Bandit algorithm type used to pick the next arm
    pub algorithm: BanditAlgorithm,
    /// Exploration strategy
    pub exploration: ExplorationStrategy,
    /// Reward function configuration (how trial metrics become a reward)
    pub reward_function: RewardFunction,
    /// Number of arms to generate from the search space
    pub num_arms: usize,
    /// Arm generation strategy
    pub arm_generation: ArmGenerationStrategy,
}
138
/// Arm-selection algorithm used by [`BanditOptimizer::select_arm`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BanditAlgorithm {
    /// Upper Confidence Bound
    UCB {
        /// Multiplier applied to the UCB exploration bonus
        confidence_parameter: f64,
    },
    /// Thompson Sampling
    ThompsonSampling {
        /// Prior parameters (pseudo-counts) for the Beta distribution
        alpha_prior: f64,
        beta_prior: f64,
    },
    /// Epsilon-Greedy
    EpsilonGreedy {
        /// Exploration probability
        epsilon: f64,
        /// Epsilon decay rate.
        ///
        /// NOTE(review): `select_arm` binds this as `_`, so epsilon is not decayed.
        decay_rate: f64,
    },
    /// EXP3 (Exponential-weight algorithm for Exploration and Exploitation)
    EXP3 {
        /// EXP3 gamma. It is used both as the uniform-mixing probability
        /// (`gamma / K`) and as the scale of the weight exponent; it is not a
        /// learning rate in the gradient sense.
        gamma: f64,
    },
    /// LinUCB for contextual bandits (currently a UCB-style fallback without context)
    LinUCB {
        /// Exploration coefficient: it multiplies the confidence width in
        /// `linucb_select` and is not used as a regularizer.
        alpha: f64,
        /// Context dimensionality.
        ///
        /// NOTE(review): `select_arm` binds this as `_`; no context vectors are kept.
        context_dim: usize,
    },
}
172
/// Schedule for the exploration rate of a bandit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExplorationStrategy {
    /// Fixed exploration rate
    Fixed { rate: f64 },
    /// Decaying exploration rate, starting at `initial_rate`, scaled by
    /// `decay_factor` (presumably per step) and floored at `min_rate`
    Decaying {
        initial_rate: f64,
        decay_factor: f64,
        min_rate: f64,
    },
    /// Adaptive exploration based on uncertainty
    Adaptive { uncertainty_threshold: f64 },
}
186
/// How a trial's metrics are converted into a scalar bandit reward.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RewardFunction {
    /// Direct performance metric: the value of `metric_name`
    Direct { metric_name: String },
    /// Performance of `metric_name` normalized with `min_value`/`max_value`
    Normalized {
        metric_name: String,
        min_value: f64,
        max_value: f64,
    },
    /// Time-weighted performance of `metric_name`
    TimeWeighted {
        metric_name: String,
        time_weight: f64,
    },
    /// Multi-objective reward
    MultiObjective {
        metrics: HashMap<String, f64>, // metric_name -> weight
    },
}
207
/// How the fixed set of bandit arms is drawn from the search space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ArmGenerationStrategy {
    /// Random sampling from search space
    Random,
    /// Latin Hypercube Sampling
    LatinHypercube,
    /// Sobol sequences
    Sobol,
    /// Evolutionary generation.
    ///
    /// NOTE(review): `BanditOptimizer::generate_arms` matches this variant with
    /// `{ .. }`, so these fields are not forwarded to the sampler.
    Evolutionary {
        population_size: usize,
        mutation_rate: f64,
        crossover_rate: f64,
    },
}
223
/// Running statistics for one bandit arm.
#[derive(Debug, Clone)]
pub struct ArmStatistics {
    /// Number of times this arm was pulled
    pub pulls: usize,
    /// Total reward accumulated
    pub total_reward: f64,
    /// Average reward (`total_reward / pulls`, or `0.0` before the first pull)
    pub average_reward: f64,
    /// Confidence bounds as `(lower, upper)` around `average_reward`
    pub confidence_bounds: (f64, f64),
    /// Last update timestamp
    pub last_update: SystemTime,
}
237
/// Surrogate model optimization for expensive hyperparameter evaluations.
///
/// Holds the observations gathered so far together with a pluggable surrogate
/// model and acquisition function. No methods are visible in this module.
// NOTE(review): the struct-level `allow(dead_code)` already covers every field,
// so the field-level allowance on `model` is redundant.
#[allow(dead_code)]
pub struct SurrogateOptimizer {
    /// Surrogate model configuration
    config: SurrogateConfig,
    /// Observed data points as `(parameters, objective value)` pairs
    observations: Vec<(HashMap<String, ParameterValue>, f64)>,
    /// Surrogate model
    #[allow(dead_code)]
    model: Box<dyn SurrogateModel>,
    /// Acquisition function
    acquisition: Box<dyn AcquisitionFunction>,
}
251
252impl std::fmt::Debug for SurrogateOptimizer {
253    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        f.debug_struct("SurrogateOptimizer")
255            .field("config", &self.config)
256            .field("observations", &self.observations)
257            .field("model", &"<dyn SurrogateModel>")
258            .field("acquisition", &"<dyn AcquisitionFunction>")
259            .finish()
260    }
261}
262
/// Configuration for [`SurrogateOptimizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SurrogateConfig {
    /// Surrogate model type
    pub model_type: SurrogateModelType,
    /// Acquisition function type
    pub acquisition_function: AcquisitionFunctionType,
    /// Number of initial random samples
    pub initial_samples: usize,
    /// Model update frequency (presumably in observations between refits)
    pub update_frequency: usize,
    /// Optimization budget per iteration
    pub optimization_budget: usize,
}
276
/// Family of surrogate model used to approximate the objective.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SurrogateModelType {
    /// Gaussian Process
    GaussianProcess {
        /// Kernel type
        kernel: KernelType,
        /// Noise level
        noise_level: f64,
        /// Length scales (presumably one per input dimension)
        length_scales: Vec<f64>,
    },
    /// Random Forest
    RandomForest {
        /// Number of trees
        num_trees: usize,
        /// Maximum depth
        max_depth: usize,
        /// Minimum samples per leaf
        min_samples_leaf: usize,
    },
    /// Neural Network
    NeuralNetwork {
        /// Hidden layer sizes
        hidden_sizes: Vec<usize>,
        /// Learning rate
        learning_rate: f64,
        /// Number of epochs
        epochs: usize,
    },
    /// Tree-structured Parzen Estimator
    TPE {
        /// Number of startup trials. By TPE convention, these are sampled before
        /// the density model is used.
        /// NOTE(review): the original doc called this the good/bad split, which
        /// conventionally is `gamma`'s role — confirm intended semantics.
        n_startup_trials: usize,
        /// Gamma parameter (in TPE conventionally the quantile separating
        /// "good" from "bad" observations)
        gamma: f64,
    },
}
314
/// Covariance kernel for a Gaussian-process surrogate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum KernelType {
    /// Radial Basis Function kernel
    RBF,
    /// Matern kernel with parameter `nu` (presumably the smoothness parameter ν)
    Matern { nu: f64 },
    /// Linear kernel
    Linear,
    /// Polynomial kernel of the given degree
    Polynomial { degree: usize },
}
326
/// Acquisition function used to rank candidate configurations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AcquisitionFunctionType {
    /// Expected Improvement with exploration offset `xi`
    ExpectedImprovement { xi: f64 },
    /// Probability of Improvement with exploration offset `xi`
    ProbabilityOfImprovement { xi: f64 },
    /// Upper Confidence Bound with exploration weight `beta`
    UpperConfidenceBound { beta: f64 },
    /// Entropy Search
    EntropySearch,
    /// Knowledge Gradient
    KnowledgeGradient,
}
340
/// Parallel evaluation strategies for hyperparameter optimization.
///
/// Job tables are wrapped in `Arc<Mutex<…>>` so they can be shared across
/// threads. No methods are visible in this module.
// NOTE(review): the struct-level `allow(dead_code)` already covers every field,
// so the field-level allowance on `active_jobs` is redundant.
#[allow(dead_code)]
pub struct ParallelEvaluator {
    /// Configuration for parallel evaluation
    config: ParallelEvaluationConfig,
    /// Active evaluation jobs, keyed by job id (presumably `EvaluationJob::job_id`)
    #[allow(dead_code)]
    active_jobs: Arc<Mutex<HashMap<String, EvaluationJob>>>,
    /// Completed jobs queue
    completed_jobs: Arc<Mutex<VecDeque<EvaluationResult>>>,
    /// Load balancer
    load_balancer: Box<dyn LoadBalancer>,
}
354
355impl std::fmt::Debug for ParallelEvaluator {
356    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357        f.debug_struct("ParallelEvaluator")
358            .field("config", &self.config)
359            .field(
360                "active_jobs",
361                &"<Arc<Mutex<HashMap<String, EvaluationJob>>>>",
362            )
363            .field(
364                "completed_jobs",
365                &"<Arc<Mutex<VecDeque<EvaluationResult>>>>",
366            )
367            .field("load_balancer", &"<dyn LoadBalancer>")
368            .finish()
369    }
370}
371
/// Configuration for [`ParallelEvaluator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelEvaluationConfig {
    /// Maximum number of parallel evaluations
    pub max_parallel: usize,
    /// Evaluation strategy
    pub strategy: ParallelStrategy,
    /// Resource allocation per evaluation
    pub resource_allocation: ResourceAllocation,
    /// Fault tolerance settings
    pub fault_tolerance: FaultToleranceConfig,
}
383
/// How parallel evaluations are scheduled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParallelStrategy {
    /// Independent parallel evaluations
    Independent,
    /// Synchronized batch evaluations of `batch_size` jobs
    Batch { batch_size: usize },
    /// Asynchronous with speculation
    Asynchronous {
        /// Maximum speculation depth
        speculation_depth: usize,
    },
    /// Hierarchical evaluation
    Hierarchical {
        /// Levels in hierarchy (one entry per level; meaning of each value is
        /// not established in this module)
        levels: Vec<usize>,
    },
}
401
/// Resources requested for a single evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceAllocation {
    /// CPU cores per evaluation
    pub cpu_cores: usize,
    /// Memory per evaluation (GB)
    pub memory_gb: f64,
    /// GPU allocation
    pub gpu_allocation: GPUAllocation,
    /// Priority levels
    pub priority_levels: Vec<PriorityLevel>,
}
413
/// GPU share requested for an evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GPUAllocation {
    /// No GPU
    None,
    /// Shared GPU, using `memory_fraction` of its memory (presumably in `(0, 1]`)
    Shared { memory_fraction: f64 },
    /// Dedicated GPU(s)
    Dedicated { gpu_count: usize },
}
423
/// One scheduling priority tier for evaluations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PriorityLevel {
    /// Priority value (ordering direction not established here — confirm
    /// whether higher means more urgent)
    pub priority: i32,
    /// Resource multiplier
    pub resource_multiplier: f64,
    /// Maximum evaluations at this priority
    pub max_evaluations: usize,
}
433
/// Retry, timeout and checkpoint settings for evaluations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceConfig {
    /// Maximum retries per evaluation
    pub max_retries: usize,
    /// Timeout per evaluation
    pub evaluation_timeout: Duration,
    /// Interval between checkpoints
    pub checkpoint_frequency: Duration,
}
443
/// A single hyperparameter evaluation tracked by [`ParallelEvaluator`].
#[derive(Debug, Clone)]
pub struct EvaluationJob {
    /// Job identifier
    pub job_id: String,
    /// Hyperparameters being evaluated
    pub parameters: HashMap<String, ParameterValue>,
    /// Start time
    pub start_time: SystemTime,
    /// Resource allocation
    pub resources: ResourceAllocation,
    /// Current status
    pub status: JobStatus,
}
457
/// Lifecycle state of an [`EvaluationJob`].
#[derive(Debug, Clone)]
pub enum JobStatus {
    /// Waiting to be scheduled
    Queued,
    /// Currently executing
    Running,
    /// Finished successfully
    Completed,
    /// Finished with an error described by `error`
    Failed { error: String },
    /// Cancelled before completion
    Cancelled,
}
466
/// Outcome of a finished [`EvaluationJob`].
#[derive(Debug, Clone)]
pub struct EvaluationResult {
    /// Job identifier
    pub job_id: String,
    /// Hyperparameters that were evaluated
    pub parameters: HashMap<String, ParameterValue>,
    /// Evaluation metrics, keyed by metric name
    pub metrics: HashMap<String, f64>,
    /// Evaluation time
    pub evaluation_time: Duration,
    /// Resource usage
    pub resource_usage: ResourceUsage,
}
480
/// Resource consumption reported for an evaluation.
///
/// NOTE(review): units and scale (fraction vs. percentage, bytes vs. GB) are not
/// established in this module — confirm with the producer of these values.
#[derive(Debug, Clone)]
pub struct ResourceUsage {
    /// CPU utilization
    pub cpu_utilization: f64,
    /// Memory usage
    pub memory_usage: f64,
    /// GPU utilization
    pub gpu_utilization: f64,
    /// Network I/O
    pub network_io: f64,
}
492
// Traits for extensibility

/// A model that approximates the objective from observed `(parameters, value)` pairs.
pub trait SurrogateModel: Send + Sync {
    /// Fit the model to observed data
    fn fit(&mut self, observations: &[(HashMap<String, ParameterValue>, f64)]) -> Result<()>;

    /// Predict mean and variance for given parameters, returned as `(mean, variance)`
    fn predict(&self, parameters: &HashMap<String, ParameterValue>) -> Result<(f64, f64)>;

    /// Update model with new observation
    fn update(&mut self, parameters: HashMap<String, ParameterValue>, value: f64) -> Result<()>;
}
505
/// Scores candidate configurations against a [`SurrogateModel`] to pick the next trial.
pub trait AcquisitionFunction: Send + Sync {
    /// Compute acquisition value for given parameters, relative to the best
    /// objective value observed so far (`best_value`)
    fn compute(
        &self,
        parameters: &HashMap<String, ParameterValue>,
        model: &dyn SurrogateModel,
        best_value: f64,
    ) -> Result<f64>;

    /// Optimize acquisition function over `search_space` to find next candidate
    fn optimize(
        &self,
        model: &dyn SurrogateModel,
        search_space: &SearchSpace,
        best_value: f64,
    ) -> Result<HashMap<String, ParameterValue>>;
}
523
/// Assigns evaluation jobs to compute resources identified by string ids.
pub trait LoadBalancer: Send + Sync {
    /// Assign job to best available resource, returning that resource's id
    fn assign_job(&mut self, job: &EvaluationJob) -> Result<String>;

    /// Update resource status with the latest measured usage
    fn update_resource_status(&mut self, resource_id: &str, usage: &ResourceUsage) -> Result<()>;

    /// Get ids of available resources
    fn get_available_resources(&self) -> Vec<String>;
}
534
535// Implementations
536
537impl BanditOptimizer {
538    pub fn new(config: BanditConfig, search_space: &SearchSpace) -> Result<Self> {
539        let arms = Self::generate_arms(&config, search_space)?;
540        let arm_stats = vec![ArmStatistics::new(); arms.len()];
541
542        Ok(Self {
543            config,
544            arms,
545            arm_stats,
546            exploration_factor: 1.0,
547        })
548    }
549
550    pub fn select_arm(&mut self) -> Result<usize> {
551        match &self.config.algorithm {
552            BanditAlgorithm::UCB {
553                confidence_parameter,
554            } => self.ucb_select(*confidence_parameter),
555            BanditAlgorithm::ThompsonSampling {
556                alpha_prior,
557                beta_prior,
558            } => self.thompson_sampling_select(*alpha_prior, *beta_prior),
559            BanditAlgorithm::EpsilonGreedy {
560                epsilon,
561                decay_rate: _,
562            } => self.epsilon_greedy_select(*epsilon),
563            BanditAlgorithm::EXP3 { gamma } => self.exp3_select(*gamma),
564            BanditAlgorithm::LinUCB {
565                alpha,
566                context_dim: _,
567            } => self.linucb_select(*alpha),
568        }
569    }
570
571    pub fn update_arm(&mut self, arm_index: usize, reward: f64) -> Result<()> {
572        if arm_index >= self.arm_stats.len() {
573            return Err(anyhow::anyhow!("Invalid arm index"));
574        }
575
576        let stats = &mut self.arm_stats[arm_index];
577        stats.pulls += 1;
578        stats.total_reward += reward;
579        stats.average_reward = stats.total_reward / stats.pulls as f64;
580        stats.last_update = SystemTime::now();
581
582        // Update confidence bounds
583        let confidence_radius = (2.0 * (stats.pulls as f64).ln() / stats.pulls as f64).sqrt();
584        stats.confidence_bounds = (
585            stats.average_reward - confidence_radius,
586            stats.average_reward + confidence_radius,
587        );
588
589        Ok(())
590    }
591
592    fn generate_arms(
593        config: &BanditConfig,
594        search_space: &SearchSpace,
595    ) -> Result<Vec<HashMap<String, ParameterValue>>> {
596        let mut arms = Vec::new();
597
598        match &config.arm_generation {
599            ArmGenerationStrategy::Random => {
600                for _ in 0..config.num_arms {
601                    arms.push(search_space.sample_random()?);
602                }
603            },
604            ArmGenerationStrategy::LatinHypercube => {
605                arms = search_space.latin_hypercube_sample(config.num_arms)?;
606            },
607            ArmGenerationStrategy::Sobol => {
608                arms = search_space.sobol_sample(config.num_arms)?;
609            },
610            ArmGenerationStrategy::Evolutionary { .. } => {
611                // Implement evolutionary arm generation
612                arms = search_space.evolutionary_sample(config.num_arms)?;
613            },
614        }
615
616        Ok(arms)
617    }
618
619    fn ucb_select(&self, confidence_parameter: f64) -> Result<usize> {
620        let total_pulls: usize = self.arm_stats.iter().map(|s| s.pulls).sum();
621
622        if total_pulls == 0 {
623            return Ok(0);
624        }
625
626        let mut best_arm = 0;
627        let mut best_value = f64::NEG_INFINITY;
628
629        for (i, stats) in self.arm_stats.iter().enumerate() {
630            if stats.pulls == 0 {
631                return Ok(i); // Explore unplayed arms first
632            }
633
634            let confidence_bound = confidence_parameter
635                * (2.0 * (total_pulls as f64).ln() / stats.pulls as f64).sqrt();
636            let ucb_value = stats.average_reward + confidence_bound;
637
638            if ucb_value > best_value {
639                best_value = ucb_value;
640                best_arm = i;
641            }
642        }
643
644        Ok(best_arm)
645    }
646
647    fn thompson_sampling_select(&self, alpha_prior: f64, beta_prior: f64) -> Result<usize> {
648        let mut rng = thread_rng();
649
650        let mut best_arm = 0;
651        let mut best_sample = f64::NEG_INFINITY;
652
653        for (i, stats) in self.arm_stats.iter().enumerate() {
654            // Beta distribution parameters
655            let _alpha = alpha_prior + stats.total_reward;
656            let _beta = beta_prior + stats.pulls as f64 - stats.total_reward;
657
658            // Sample from Beta distribution (simplified)
659            let sample = rng.random::<f64>(); // In practice, use proper Beta sampling
660
661            if sample > best_sample {
662                best_sample = sample;
663                best_arm = i;
664            }
665        }
666
667        Ok(best_arm)
668    }
669
670    fn epsilon_greedy_select(&self, epsilon: f64) -> Result<usize> {
671        let mut rng = thread_rng();
672
673        if rng.random::<f64>() < epsilon {
674            // Explore: select random arm
675            Ok(rng.random_range(0..self.arms.len()))
676        } else {
677            // Exploit: select best arm
678            let best_arm = self
679                .arm_stats
680                .iter()
681                .enumerate()
682                .max_by(|(_, a), (_, b)| {
683                    a.average_reward
684                        .partial_cmp(&b.average_reward)
685                        .unwrap_or(std::cmp::Ordering::Equal)
686                })
687                .map(|(i, _)| i)
688                .unwrap_or(0);
689            Ok(best_arm)
690        }
691    }
692
693    fn exp3_select(&self, gamma: f64) -> Result<usize> {
694        let mut rng = thread_rng();
695
696        let num_arms = self.arms.len();
697        if num_arms == 0 {
698            return Err(anyhow::anyhow!("No arms available"));
699        }
700
701        // Compute weights based on cumulative rewards
702        let mut weights = vec![1.0; num_arms];
703        for (i, stats) in self.arm_stats.iter().enumerate() {
704            if stats.pulls > 0 {
705                // EXP3 weight update: w_i = exp(gamma * average_reward / num_arms)
706                weights[i] = (gamma * stats.average_reward / num_arms as f64).exp();
707            }
708        }
709
710        // Compute probabilities
711        let weight_sum: f64 = weights.iter().sum();
712        let mut probabilities = vec![0.0; num_arms];
713
714        for i in 0..num_arms {
715            probabilities[i] = (1.0 - gamma) * weights[i] / weight_sum + gamma / num_arms as f64;
716        }
717
718        // Sample according to probabilities
719        let mut cumulative_prob = 0.0;
720        let random_value = rng.random::<f64>();
721
722        for (i, &prob) in probabilities.iter().enumerate() {
723            cumulative_prob += prob;
724            if random_value <= cumulative_prob {
725                return Ok(i);
726            }
727        }
728
729        // Fallback to last arm
730        Ok(num_arms - 1)
731    }
732
733    fn linucb_select(&self, alpha: f64) -> Result<usize> {
734        // LinUCB requires contextual information which isn't available in the current design
735        // For now, implement a simplified version that falls back to UCB
736        // In a full implementation, this would use feature vectors for each arm
737
738        let total_pulls: usize = self.arm_stats.iter().map(|s| s.pulls).sum();
739
740        if total_pulls == 0 {
741            return Ok(0);
742        }
743
744        let mut best_arm = 0;
745        let mut best_value = f64::NEG_INFINITY;
746
747        for (i, stats) in self.arm_stats.iter().enumerate() {
748            if stats.pulls == 0 {
749                return Ok(i); // Explore unplayed arms first
750            }
751
752            // Simplified LinUCB using parameter uncertainty as context
753            let confidence_width = alpha * (total_pulls as f64 / stats.pulls as f64).ln().sqrt();
754            let upper_bound = stats.average_reward + confidence_width;
755
756            if upper_bound > best_value {
757                best_value = upper_bound;
758                best_arm = i;
759            }
760        }
761
762        Ok(best_arm)
763    }
764}
765
766impl SearchStrategy for BanditOptimizer {
767    fn suggest(
768        &mut self,
769        _search_space: &SearchSpace,
770        _history: &TrialHistory,
771    ) -> Option<HashMap<String, ParameterValue>> {
772        match self.select_arm() {
773            Ok(arm_index) => Some(self.arms[arm_index].clone()),
774            Err(_) => None,
775        }
776    }
777
778    fn should_terminate(&self, _history: &TrialHistory) -> bool {
779        false // Bandit algorithms typically don't self-terminate
780    }
781
782    fn name(&self) -> &str {
783        "BanditOptimizer"
784    }
785
786    fn update(&mut self, trial: &Trial) {
787        if let TrialState::Complete = trial.state {
788            if let Some(value) =
789                trial.result.as_ref().and_then(|r| r.metrics.metrics.get("objective"))
790            {
791                // Find which arm this trial corresponds to
792                for (i, arm) in self.arms.iter().enumerate() {
793                    if arm == &trial.params {
794                        let _ = self.update_arm(i, *value);
795                        break;
796                    }
797                }
798            }
799        }
800    }
801}
802
803impl ArmStatistics {
804    fn new() -> Self {
805        Self {
806            pulls: 0,
807            total_reward: 0.0,
808            average_reward: 0.0,
809            confidence_bounds: (0.0, 0.0),
810            last_update: SystemTime::now(),
811        }
812    }
813}
814
815impl Default for AdvancedEarlyStoppingConfig {
816    fn default() -> Self {
817        Self {
818            patience: 10,
819            min_delta: 0.001,
820            strategy: EarlyStoppingStrategy::Standard,
821            adaptive_patience: false,
822            min_evaluation_steps: 100,
823            grace_period: 5,
824        }
825    }
826}
827
828impl Default for WarmStartConfig {
829    fn default() -> Self {
830        Self {
831            strategy: WarmStartStrategy::BestTrials,
832            data_source: WarmStartDataSource::InMemory,
833            num_warm_start_trials: 10,
834            historical_weight_decay: 0.9,
835        }
836    }
837}
838
839impl Default for BanditConfig {
840    fn default() -> Self {
841        Self {
842            algorithm: BanditAlgorithm::UCB {
843                confidence_parameter: 1.0,
844            },
845            exploration: ExplorationStrategy::Fixed { rate: 0.1 },
846            reward_function: RewardFunction::Direct {
847                metric_name: "objective".to_string(),
848            },
849            num_arms: 10,
850            arm_generation: ArmGenerationStrategy::Random,
851        }
852    }
853}
854
855impl Default for SurrogateConfig {
856    fn default() -> Self {
857        Self {
858            model_type: SurrogateModelType::GaussianProcess {
859                kernel: KernelType::RBF,
860                noise_level: 0.01,
861                length_scales: vec![1.0],
862            },
863            acquisition_function: AcquisitionFunctionType::ExpectedImprovement { xi: 0.01 },
864            initial_samples: 20,
865            update_frequency: 5,
866            optimization_budget: 1000,
867        }
868    }
869}
870
871impl Default for ParallelEvaluationConfig {
872    fn default() -> Self {
873        Self {
874            max_parallel: 4,
875            strategy: ParallelStrategy::Independent,
876            resource_allocation: ResourceAllocation {
877                cpu_cores: 2,
878                memory_gb: 4.0,
879                gpu_allocation: GPUAllocation::None,
880                priority_levels: vec![],
881            },
882            fault_tolerance: FaultToleranceConfig {
883                max_retries: 3,
884                evaluation_timeout: Duration::from_secs(3600),
885                checkpoint_frequency: Duration::from_secs(300),
886            },
887        }
888    }
889}
890
891// Extension methods for SearchSpace
892impl SearchSpace {
893    pub fn sample_random(&self) -> Result<HashMap<String, ParameterValue>> {
894        let mut rng = thread_rng();
895        let mut params = HashMap::new();
896
897        for param in &self.parameters {
898            let value = match param {
899                super::search_space::HyperParameter::Continuous(p) => {
900                    let val = rng.random_range(p.low..=p.high);
901                    ParameterValue::Float(val)
902                },
903                super::search_space::HyperParameter::Log(p) => {
904                    let log_low = p.low.ln();
905                    let log_high = p.high.ln();
906                    let log_val = rng.random_range(log_low..=log_high);
907                    ParameterValue::Float(log_val.exp())
908                },
909                super::search_space::HyperParameter::Discrete(p) => {
910                    let val = rng.random_range(p.low..=p.high);
911                    ParameterValue::Int(val)
912                },
913                super::search_space::HyperParameter::Categorical(p) => {
914                    let choice = &p.choices[rng.random_range(0..p.choices.len())];
915                    ParameterValue::String(choice.clone())
916                },
917            };
918            params.insert(param.name().to_string(), value);
919        }
920
921        Ok(params)
922    }
923
924    pub fn latin_hypercube_sample(
925        &self,
926        n_samples: usize,
927    ) -> Result<Vec<HashMap<String, ParameterValue>>> {
928        let mut rng = thread_rng();
929        let mut samples = Vec::new();
930
931        if n_samples == 0 {
932            return Ok(samples);
933        }
934
935        // Get only continuous/log parameters for LHS
936        let continuous_params: Vec<_> = self
937            .parameters
938            .iter()
939            .filter(|p| {
940                matches!(
941                    p,
942                    super::search_space::HyperParameter::Continuous(_)
943                        | super::search_space::HyperParameter::Log(_)
944                )
945            })
946            .collect();
947
948        let n_dims = continuous_params.len();
949
950        if n_dims == 0 {
951            // Fall back to random sampling for discrete/categorical only
952            for _ in 0..n_samples {
953                samples.push(self.sample_random()?);
954            }
955            return Ok(samples);
956        }
957
958        // Generate LHS matrix
959        let mut lhs_matrix = vec![vec![0.0; n_dims]; n_samples];
960
961        for dim in 0..n_dims {
962            let mut indices: Vec<usize> = (0..n_samples).collect();
963
964            // Shuffle indices
965            for i in (1..indices.len()).rev() {
966                let j = rng.random_range(0..=i);
967                indices.swap(i, j);
968            }
969
970            for (i, &idx) in indices.iter().enumerate() {
971                let lower = idx as f64 / n_samples as f64;
972                let upper = (idx + 1) as f64 / n_samples as f64;
973                lhs_matrix[i][dim] = rng.random_range(lower..upper);
974            }
975        }
976
977        // Convert LHS matrix to parameter samples
978        for i in 0..n_samples {
979            let mut params = HashMap::new();
980
981            // Handle continuous/log parameters with LHS
982            for (dim, param) in continuous_params.iter().enumerate() {
983                let unit_value = lhs_matrix[i][dim];
984                let value = match param {
985                    super::search_space::HyperParameter::Continuous(p) => {
986                        let val = p.low + unit_value * (p.high - p.low);
987                        ParameterValue::Float(val)
988                    },
989                    super::search_space::HyperParameter::Log(p) => {
990                        let log_low = p.low.ln();
991                        let log_high = p.high.ln();
992                        let log_val = log_low + unit_value * (log_high - log_low);
993                        ParameterValue::Float(log_val.exp())
994                    },
995                    _ => unreachable!(),
996                };
997                params.insert(param.name().to_string(), value);
998            }
999
1000            // Handle discrete/categorical parameters randomly
1001            for param in &self.parameters {
1002                if !matches!(
1003                    param,
1004                    super::search_space::HyperParameter::Continuous(_)
1005                        | super::search_space::HyperParameter::Log(_)
1006                ) {
1007                    let value = match param {
1008                        super::search_space::HyperParameter::Discrete(p) => {
1009                            let val = rng.random_range(p.low..=p.high);
1010                            ParameterValue::Int(val)
1011                        },
1012                        super::search_space::HyperParameter::Categorical(p) => {
1013                            let choice = &p.choices[rng.random_range(0..p.choices.len())];
1014                            ParameterValue::String(choice.clone())
1015                        },
1016                        _ => unreachable!(),
1017                    };
1018                    params.insert(param.name().to_string(), value);
1019                }
1020            }
1021
1022            samples.push(params);
1023        }
1024
1025        Ok(samples)
1026    }
1027
1028    pub fn sobol_sample(&self, n_samples: usize) -> Result<Vec<HashMap<String, ParameterValue>>> {
1029        // Simplified Sobol sequence implementation
1030        // For production use, consider using a proper Sobol sequence library
1031        let mut rng = thread_rng();
1032        let mut samples = Vec::new();
1033
1034        let continuous_params: Vec<_> = self
1035            .parameters
1036            .iter()
1037            .filter(|p| {
1038                matches!(
1039                    p,
1040                    super::search_space::HyperParameter::Continuous(_)
1041                        | super::search_space::HyperParameter::Log(_)
1042                )
1043            })
1044            .collect();
1045
1046        let n_dims = continuous_params.len();
1047
1048        if n_dims == 0 {
1049            // Fall back to random sampling
1050            for _ in 0..n_samples {
1051                samples.push(self.sample_random()?);
1052            }
1053            return Ok(samples);
1054        }
1055
1056        // Generate quasi-random Sobol-like sequence
1057        for i in 0..n_samples {
1058            let mut params = HashMap::new();
1059
1060            for (dim, param) in continuous_params.iter().enumerate() {
1061                // Simple Van der Corput sequence for each dimension
1062                let unit_value = self.van_der_corput(i + 1, 2 + dim);
1063
1064                let value = match param {
1065                    super::search_space::HyperParameter::Continuous(p) => {
1066                        let val = p.low + unit_value * (p.high - p.low);
1067                        ParameterValue::Float(val)
1068                    },
1069                    super::search_space::HyperParameter::Log(p) => {
1070                        let log_low = p.low.ln();
1071                        let log_high = p.high.ln();
1072                        let log_val = log_low + unit_value * (log_high - log_low);
1073                        ParameterValue::Float(log_val.exp())
1074                    },
1075                    _ => unreachable!(),
1076                };
1077                params.insert(param.name().to_string(), value);
1078            }
1079
1080            // Handle discrete/categorical parameters randomly
1081            for param in &self.parameters {
1082                if !matches!(
1083                    param,
1084                    super::search_space::HyperParameter::Continuous(_)
1085                        | super::search_space::HyperParameter::Log(_)
1086                ) {
1087                    let value = match param {
1088                        super::search_space::HyperParameter::Discrete(p) => {
1089                            let val = rng.random_range(p.low..=p.high);
1090                            ParameterValue::Int(val)
1091                        },
1092                        super::search_space::HyperParameter::Categorical(p) => {
1093                            let choice = &p.choices[rng.random_range(0..p.choices.len())];
1094                            ParameterValue::String(choice.clone())
1095                        },
1096                        _ => unreachable!(),
1097                    };
1098                    params.insert(param.name().to_string(), value);
1099                }
1100            }
1101
1102            samples.push(params);
1103        }
1104
1105        Ok(samples)
1106    }
1107
1108    pub fn evolutionary_sample(
1109        &self,
1110        n_samples: usize,
1111    ) -> Result<Vec<HashMap<String, ParameterValue>>> {
1112        let mut rng = thread_rng();
1113        let mut samples = Vec::new();
1114
1115        if n_samples == 0 {
1116            return Ok(samples);
1117        }
1118
1119        // Initialize population with random samples
1120        let population_size = (n_samples / 4).max(10);
1121        let mut population = Vec::new();
1122
1123        for _ in 0..population_size {
1124            population.push(self.sample_random()?);
1125        }
1126
1127        // Evolve population to generate samples
1128        let generations = (n_samples / population_size).max(1);
1129        let mutation_rate = 0.1;
1130        let crossover_rate = 0.7;
1131
1132        for _gen in 0..generations {
1133            let mut new_population = Vec::new();
1134
1135            // Selection and reproduction
1136            for _ in 0..population_size {
1137                if rng.random::<f64>() < crossover_rate && population.len() >= 2 {
1138                    // Crossover
1139                    let parent1_idx = rng.random_range(0..population.len());
1140                    let parent2_idx = rng.random_range(0..population.len());
1141                    let offspring =
1142                        self.crossover(&population[parent1_idx], &population[parent2_idx])?;
1143                    new_population.push(offspring);
1144                } else {
1145                    // Mutation
1146                    let parent_idx = rng.random_range(0..population.len());
1147                    let mutated = self.mutate(&population[parent_idx], mutation_rate)?;
1148                    new_population.push(mutated);
1149                }
1150            }
1151
1152            // Replace population
1153            population = new_population;
1154
1155            // Add best individuals to samples
1156            for individual in &population {
1157                if samples.len() < n_samples {
1158                    samples.push(individual.clone());
1159                }
1160            }
1161        }
1162
1163        // Fill remaining samples with random if needed
1164        while samples.len() < n_samples {
1165            samples.push(self.sample_random()?);
1166        }
1167
1168        samples.truncate(n_samples);
1169        Ok(samples)
1170    }
1171
1172    // Helper function for Van der Corput sequence
1173    fn van_der_corput(&self, n: usize, base: usize) -> f64 {
1174        let mut result = 0.0;
1175        let mut denominator = 1.0;
1176        let mut num = n;
1177
1178        while num > 0 {
1179            denominator *= base as f64;
1180            result += (num % base) as f64 / denominator;
1181            num /= base;
1182        }
1183
1184        result
1185    }
1186
1187    // Helper function for crossover in evolutionary sampling
1188    fn crossover(
1189        &self,
1190        parent1: &HashMap<String, ParameterValue>,
1191        parent2: &HashMap<String, ParameterValue>,
1192    ) -> Result<HashMap<String, ParameterValue>> {
1193        let mut rng = thread_rng();
1194        let mut offspring = HashMap::new();
1195
1196        for param in &self.parameters {
1197            let param_name = param.name();
1198            let value = if rng.random::<f64>() < 0.5 {
1199                parent1.get(param_name).cloned()
1200            } else {
1201                parent2.get(param_name).cloned()
1202            };
1203
1204            if let Some(v) = value {
1205                offspring.insert(param_name.to_string(), v);
1206            } else {
1207                // Fallback to random value if parent doesn't have this parameter
1208                let random_value = match param {
1209                    super::search_space::HyperParameter::Continuous(p) => {
1210                        ParameterValue::Float(rng.random_range(p.low..=p.high))
1211                    },
1212                    super::search_space::HyperParameter::Log(p) => {
1213                        let log_val = rng.random_range(p.low.ln()..=p.high.ln());
1214                        ParameterValue::Float(log_val.exp())
1215                    },
1216                    super::search_space::HyperParameter::Discrete(p) => {
1217                        ParameterValue::Int(rng.random_range(p.low..=p.high))
1218                    },
1219                    super::search_space::HyperParameter::Categorical(p) => {
1220                        let choice = &p.choices[rng.random_range(0..p.choices.len())];
1221                        ParameterValue::String(choice.clone())
1222                    },
1223                };
1224                offspring.insert(param_name.to_string(), random_value);
1225            }
1226        }
1227
1228        Ok(offspring)
1229    }
1230
1231    // Helper function for mutation in evolutionary sampling
1232    fn mutate(
1233        &self,
1234        individual: &HashMap<String, ParameterValue>,
1235        mutation_rate: f64,
1236    ) -> Result<HashMap<String, ParameterValue>> {
1237        let mut rng = thread_rng();
1238        let mut mutated = individual.clone();
1239
1240        for param in &self.parameters {
1241            if rng.random::<f64>() < mutation_rate {
1242                let param_name = param.name();
1243                let new_value = match param {
1244                    super::search_space::HyperParameter::Continuous(p) => {
1245                        if let Some(ParameterValue::Float(current)) = individual.get(param_name) {
1246                            // Gaussian mutation
1247                            let std_dev = (p.high - p.low) * 0.1;
1248                            let noise = rng.random::<f64>() * 2.0 - 1.0; // Simple noise
1249                            let new_val = (current + noise * std_dev).clamp(p.low, p.high);
1250                            ParameterValue::Float(new_val)
1251                        } else {
1252                            ParameterValue::Float(rng.random_range(p.low..=p.high))
1253                        }
1254                    },
1255                    super::search_space::HyperParameter::Log(p) => {
1256                        if let Some(ParameterValue::Float(current)) = individual.get(param_name) {
1257                            let log_current = current.ln();
1258                            let log_std = (p.high.ln() - p.low.ln()) * 0.1;
1259                            let noise = rng.random::<f64>() * 2.0 - 1.0;
1260                            let new_log =
1261                                (log_current + noise * log_std).clamp(p.low.ln(), p.high.ln());
1262                            ParameterValue::Float(new_log.exp())
1263                        } else {
1264                            let log_val = rng.random_range(p.low.ln()..=p.high.ln());
1265                            ParameterValue::Float(log_val.exp())
1266                        }
1267                    },
1268                    super::search_space::HyperParameter::Discrete(p) => {
1269                        ParameterValue::Int(rng.random_range(p.low..=p.high))
1270                    },
1271                    super::search_space::HyperParameter::Categorical(p) => {
1272                        let choice = &p.choices[rng.random_range(0..p.choices.len())];
1273                        ParameterValue::String(choice.clone())
1274                    },
1275                };
1276                mutated.insert(param_name.to_string(), new_value);
1277            }
1278        }
1279
1280        Ok(mutated)
1281    }
1282}
1283
#[cfg(test)]
mod tests {
    //! Sanity checks for the default configurations and the initial bandit-arm
    //! statistics defined in this module.

    use super::*;

    #[test]
    fn test_advanced_early_stopping_config() {
        let AdvancedEarlyStoppingConfig {
            patience, strategy, ..
        } = AdvancedEarlyStoppingConfig::default();

        assert_eq!(patience, 10);
        assert!(matches!(strategy, EarlyStoppingStrategy::Standard));
    }

    #[test]
    fn test_warm_start_config() {
        let WarmStartConfig {
            strategy,
            num_warm_start_trials,
            ..
        } = WarmStartConfig::default();

        assert!(matches!(strategy, WarmStartStrategy::BestTrials));
        assert_eq!(num_warm_start_trials, 10);
    }

    #[test]
    fn test_bandit_config() {
        let BanditConfig {
            algorithm, num_arms, ..
        } = BanditConfig::default();

        assert!(matches!(algorithm, BanditAlgorithm::UCB { .. }));
        assert_eq!(num_arms, 10);
    }

    #[test]
    fn test_surrogate_config() {
        let SurrogateConfig {
            model_type,
            initial_samples,
            ..
        } = SurrogateConfig::default();

        assert!(matches!(
            model_type,
            SurrogateModelType::GaussianProcess { .. }
        ));
        assert_eq!(initial_samples, 20);
    }

    #[test]
    fn test_parallel_evaluation_config() {
        let ParallelEvaluationConfig {
            max_parallel,
            strategy,
            ..
        } = ParallelEvaluationConfig::default();

        assert_eq!(max_parallel, 4);
        assert!(matches!(strategy, ParallelStrategy::Independent));
    }

    #[test]
    fn test_arm_statistics() {
        let ArmStatistics {
            pulls,
            total_reward,
            average_reward,
            ..
        } = ArmStatistics::new();

        assert_eq!(pulls, 0);
        assert_eq!(total_reward, 0.0);
        assert_eq!(average_reward, 0.0);
    }
}