1use super::{ParameterValue, SearchSpace, SearchStrategy, Trial, TrialHistory, TrialState};
8use anyhow::Result;
9use scirs2_core::random::*; use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12use std::sync::{Arc, Mutex};
13use std::time::{Duration, SystemTime};
14
/// Configuration for the advanced early-stopping policy.
///
/// NOTE(review): nothing in this file reads this configuration; the field descriptions
/// below are inferred from the field names — confirm them against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedEarlyStoppingConfig {
    /// Number of evaluations without sufficient improvement tolerated before stopping.
    pub patience: usize,
    /// Smallest metric change that counts as an improvement.
    pub min_delta: f64,
    /// Which early-stopping criterion to apply.
    pub strategy: EarlyStoppingStrategy,
    /// Whether `patience` may be adjusted while optimizing.
    pub adaptive_patience: bool,
    /// Minimum number of evaluation steps before stopping is considered.
    pub min_evaluation_steps: usize,
    /// Number of initial evaluations exempt from early stopping.
    pub grace_period: usize,
}

/// Criterion used to decide whether a trial should stop early.
///
/// NOTE(review): no code in this file matches on these variants; the descriptions are
/// inferred from the names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EarlyStoppingStrategy {
    /// Single-metric stopping driven by `patience` / `min_delta`.
    Standard,
    /// Stopping based on signals from the training dynamics.
    TrainingDynamics {
        /// Gradient-norm threshold (presumably: stop when exceeded).
        max_gradient_norm: f64,
        /// Threshold on oscillation of the loss curve.
        loss_oscillation_threshold: f64,
    },
    /// Stopping based on several metrics combined.
    MultiObjective {
        /// Name of the main metric.
        primary_metric: String,
        /// Names of additional metrics.
        secondary_metrics: Vec<String>,
        /// Per-metric weights, keyed by metric name.
        metric_weights: HashMap<String, f64>,
    },
    /// Probabilistic stopping decision.
    Bayesian {
        /// Confidence required before stopping.
        confidence_threshold: f64,
        /// Number of samples used for the estimate.
        num_samples: usize,
    },
}
60
/// Configuration for warm-starting an optimization run from earlier trials.
///
/// NOTE(review): nothing in this file reads this configuration; field meanings are
/// inferred from the names — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmStartConfig {
    /// How historical trials are selected or transferred.
    pub strategy: WarmStartStrategy,
    /// Where historical trials are loaded from.
    pub data_source: WarmStartDataSource,
    /// Number of historical trials used to seed the run.
    pub num_warm_start_trials: usize,
    /// Decay factor down-weighting older historical trials (presumably per step of age).
    pub historical_weight_decay: f64,
}

/// Policy for choosing which historical trials seed a new run.
///
/// NOTE(review): no code in this file matches on these variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WarmStartStrategy {
    /// Reuse the best-scoring historical trials.
    BestTrials,
    /// Reuse good trials while keeping them mutually diverse.
    DiverseBest {
        /// Minimum dissimilarity between selected trials.
        diversity_threshold: f64,
    },
    /// Transfer trials from a related task.
    TransferLearning {
        /// Minimum task similarity required to transfer.
        similarity_threshold: f64,
        /// Identifier of the mapping between source and target features.
        feature_mapping: String,
    },
    /// Learn a warm-start policy from meta-features of past tasks.
    MetaLearning {
        /// Names of the meta-features used.
        meta_features: Vec<String>,
        /// Number of meta-training epochs.
        meta_epochs: usize,
    },
}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub enum WarmStartDataSource {
101 LocalDatabase { path: String },
103 RemoteDatabase { url: String, auth_token: String },
105 FileStorage { directory: String },
107 InMemory,
109}
110
/// Multi-armed-bandit search strategy over a fixed set of candidate configurations.
///
/// The arms are generated once in `BanditOptimizer::new`. After that the optimizer only
/// chooses among them and tracks per-arm reward statistics.
#[derive(Debug, Clone)]
pub struct BanditOptimizer {
    /// Algorithm and arm-generation settings.
    config: BanditConfig,
    /// Candidate configurations; `arms[i]` is arm `i`.
    arms: Vec<HashMap<String, ParameterValue>>,
    /// Reward statistics, index-aligned with `arms` (one entry per arm).
    arm_stats: Vec<ArmStatistics>,
    // Set to 1.0 in `new` but not read anywhere yet, hence the lint allowance.
    #[allow(dead_code)]
    exploration_factor: f64,
}
124
/// Settings for `BanditOptimizer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BanditConfig {
    /// Arm-selection algorithm; `BanditOptimizer::select_arm` dispatches on it.
    pub algorithm: BanditAlgorithm,
    /// Exploration schedule.
    /// NOTE(review): not read by `BanditOptimizer` in this file.
    pub exploration: ExplorationStrategy,
    /// Mapping from trial metrics to rewards.
    /// NOTE(review): not read by `BanditOptimizer` in this file; its `SearchStrategy::update`
    /// credits arms with the `"objective"` metric directly.
    pub reward_function: RewardFunction,
    /// Number of arms (candidate configurations) generated at construction.
    pub num_arms: usize,
    /// How the candidate configurations are generated.
    pub arm_generation: ArmGenerationStrategy,
}

/// Arm-selection algorithm used by `BanditOptimizer::select_arm`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BanditAlgorithm {
    /// Upper confidence bound: average reward plus
    /// `confidence_parameter * sqrt(2 ln(total pulls) / arm pulls)`.
    UCB {
        /// Multiplier on the exploration bonus.
        confidence_parameter: f64,
    },
    /// Thompson sampling with Beta priors. The prior is combined with an arm's statistics as
    /// `alpha_prior + total_reward` and `beta_prior + pulls - total_reward`.
    ThompsonSampling {
        /// Prior pseudo-count added to the reward total.
        alpha_prior: f64,
        /// Prior pseudo-count added to `pulls - total_reward`.
        beta_prior: f64,
    },
    /// With probability `epsilon` pick a uniformly random arm, otherwise the best average.
    EpsilonGreedy {
        /// Probability of exploring a random arm.
        epsilon: f64,
        /// NOTE(review): not read by `select_arm`.
        decay_rate: f64,
    },
    /// Exponential-weights selection mixed with the uniform distribution.
    EXP3 {
        /// Weight of the uniform component in the selection probabilities.
        gamma: f64,
    },
    /// Confidence-bound selection named after LinUCB.
    LinUCB {
        /// Multiplier on the confidence width.
        alpha: f64,
        /// NOTE(review): not read by `select_arm`.
        context_dim: usize,
    },
}

/// Exploration schedule.
///
/// NOTE(review): no code in this file matches on these variants; semantics are inferred
/// from the names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExplorationStrategy {
    /// Constant exploration rate.
    Fixed { rate: f64 },
    /// Rate that decays from `initial_rate` by `decay_factor`, never below `min_rate`.
    Decaying {
        initial_rate: f64,
        decay_factor: f64,
        min_rate: f64,
    },
    /// Rate driven by model uncertainty.
    Adaptive { uncertainty_threshold: f64 },
}

/// How trial metrics are turned into a scalar bandit reward.
///
/// NOTE(review): no code in this file matches on these variants; semantics are inferred
/// from the names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RewardFunction {
    /// Use the named metric as-is.
    Direct { metric_name: String },
    /// Rescale the named metric using `[min_value, max_value]`.
    Normalized {
        metric_name: String,
        min_value: f64,
        max_value: f64,
    },
    /// Combine the named metric with evaluation time using `time_weight`.
    TimeWeighted {
        metric_name: String,
        time_weight: f64,
    },
    /// Weighted combination of several metrics (metric name -> weight).
    MultiObjective {
        metrics: HashMap<String, f64>,
    },
}

/// How `BanditOptimizer::new` generates its candidate arms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ArmGenerationStrategy {
    /// Independent draws via `SearchSpace::sample_random`.
    Random,
    /// Latin hypercube design via `SearchSpace::latin_hypercube_sample`.
    LatinHypercube,
    /// Low-discrepancy design via `SearchSpace::sobol_sample` (a radical-inverse /
    /// Halton-style sequence).
    Sobol,
    /// Designs produced by `SearchSpace::evolutionary_sample`.
    /// NOTE(review): these three fields are not forwarded; the sampler uses its own settings.
    Evolutionary {
        population_size: usize,
        mutation_rate: f64,
        crossover_rate: f64,
    },
}
223
/// Running reward statistics for one bandit arm, maintained by `BanditOptimizer::update_arm`.
#[derive(Debug, Clone)]
pub struct ArmStatistics {
    /// Number of rewards recorded for this arm.
    pub pulls: usize,
    /// Sum of all recorded rewards.
    pub total_reward: f64,
    /// `total_reward / pulls` (0.0 before the first pull).
    pub average_reward: f64,
    /// `average_reward ∓ sqrt(2 ln(pulls) / pulls)` as of the last update.
    pub confidence_bounds: (f64, f64),
    /// Wall-clock time of creation or of the most recent update.
    pub last_update: SystemTime,
}
237
/// Surrogate-model optimizer state: observed evaluations, a surrogate model and an
/// acquisition function.
///
/// NOTE(review): never constructed or used in this file, which is why `dead_code` is allowed.
#[allow(dead_code)]
pub struct SurrogateOptimizer {
    /// Surrogate settings.
    config: SurrogateConfig,
    /// Evaluated configurations paired with a scalar value (presumably the objective).
    observations: Vec<(HashMap<String, ParameterValue>, f64)>,
    /// Surrogate model of the objective.
    #[allow(dead_code)]
    model: Box<dyn SurrogateModel>,
    /// Acquisition function used to propose the next configuration.
    acquisition: Box<dyn AcquisitionFunction>,
}
251
252impl std::fmt::Debug for SurrogateOptimizer {
253 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 f.debug_struct("SurrogateOptimizer")
255 .field("config", &self.config)
256 .field("observations", &self.observations)
257 .field("model", &"<dyn SurrogateModel>")
258 .field("acquisition", &"<dyn AcquisitionFunction>")
259 .finish()
260 }
261}
262
/// Settings for surrogate-model optimization.
///
/// NOTE(review): nothing in this file reads this configuration; field meanings are
/// inferred from the names — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SurrogateConfig {
    /// Which surrogate model to fit.
    pub model_type: SurrogateModelType,
    /// Which acquisition function scores candidates.
    pub acquisition_function: AcquisitionFunctionType,
    /// Number of evaluations gathered before the surrogate is used.
    pub initial_samples: usize,
    /// How often (in evaluations) the surrogate is refit.
    pub update_frequency: usize,
    /// Budget for optimizing the acquisition function.
    pub optimization_budget: usize,
}

/// Surrogate model family and its hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SurrogateModelType {
    /// Gaussian-process regression.
    GaussianProcess {
        /// Covariance kernel.
        kernel: KernelType,
        /// Observation-noise level.
        noise_level: f64,
        /// Kernel length scales (presumably one per dimension).
        length_scales: Vec<f64>,
    },
    /// Random-forest regression.
    RandomForest {
        num_trees: usize,
        max_depth: usize,
        min_samples_leaf: usize,
    },
    /// Neural-network regression.
    NeuralNetwork {
        /// Hidden-layer widths.
        hidden_sizes: Vec<usize>,
        learning_rate: f64,
        epochs: usize,
    },
    /// Tree-structured Parzen estimator.
    TPE {
        /// Number of initial trials before the model is used.
        n_startup_trials: usize,
        /// Quantile separating "good" from "bad" observations.
        gamma: f64,
    },
}

/// Gaussian-process covariance kernel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum KernelType {
    /// Radial basis function (squared exponential).
    RBF,
    /// Matérn kernel with smoothness `nu`.
    Matern { nu: f64 },
    /// Linear kernel.
    Linear,
    /// Polynomial kernel of the given degree.
    Polynomial { degree: usize },
}

/// Acquisition function used to pick the next candidate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AcquisitionFunctionType {
    /// Expected improvement with exploration offset `xi`.
    ExpectedImprovement { xi: f64 },
    /// Probability of improvement with exploration offset `xi`.
    ProbabilityOfImprovement { xi: f64 },
    /// Upper confidence bound with exploration weight `beta`.
    UpperConfidenceBound { beta: f64 },
    /// Entropy search.
    EntropySearch,
    /// Knowledge gradient.
    KnowledgeGradient,
}
340
/// Coordinator for running trial evaluations in parallel.
///
/// NOTE(review): never constructed or used in this file, which is why `dead_code` is allowed.
#[allow(dead_code)]
pub struct ParallelEvaluator {
    /// Parallelism, resource and fault-tolerance settings.
    config: ParallelEvaluationConfig,
    /// In-flight jobs keyed by job id, shared behind a mutex.
    #[allow(dead_code)]
    active_jobs: Arc<Mutex<HashMap<String, EvaluationJob>>>,
    /// Finished results, shared behind a mutex.
    completed_jobs: Arc<Mutex<VecDeque<EvaluationResult>>>,
    /// Strategy that assigns jobs to resources.
    load_balancer: Box<dyn LoadBalancer>,
}
354
355impl std::fmt::Debug for ParallelEvaluator {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 f.debug_struct("ParallelEvaluator")
358 .field("config", &self.config)
359 .field(
360 "active_jobs",
361 &"<Arc<Mutex<HashMap<String, EvaluationJob>>>>",
362 )
363 .field(
364 "completed_jobs",
365 &"<Arc<Mutex<VecDeque<EvaluationResult>>>>",
366 )
367 .field("load_balancer", &"<dyn LoadBalancer>")
368 .finish()
369 }
370}
371
/// Settings for parallel trial evaluation.
///
/// NOTE(review): nothing in this file reads this configuration; field meanings are
/// inferred from the names — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelEvaluationConfig {
    /// Maximum number of evaluations running at once.
    pub max_parallel: usize,
    /// How evaluations are scheduled.
    pub strategy: ParallelStrategy,
    /// Resources granted to each evaluation.
    pub resource_allocation: ResourceAllocation,
    /// Retry and timeout behaviour.
    pub fault_tolerance: FaultToleranceConfig,
}

/// Scheduling mode for parallel evaluations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParallelStrategy {
    /// Each evaluation is scheduled on its own.
    Independent,
    /// Evaluations are grouped into batches of `batch_size`.
    Batch { batch_size: usize },
    /// New suggestions are issued before pending evaluations finish.
    Asynchronous {
        /// How many suggestions may be made ahead of completed results.
        speculation_depth: usize,
    },
    /// Multi-level scheduling.
    Hierarchical {
        /// Size of each level (interpretation unconfirmed).
        levels: Vec<usize>,
    },
}

/// Compute resources assigned to an evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceAllocation {
    /// Number of CPU cores.
    pub cpu_cores: usize,
    /// Memory in gigabytes.
    pub memory_gb: f64,
    /// GPU assignment.
    pub gpu_allocation: GPUAllocation,
    /// Priority tiers and their resource limits.
    pub priority_levels: Vec<PriorityLevel>,
}

/// GPU assignment for an evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GPUAllocation {
    /// No GPU.
    None,
    /// A fraction of a shared GPU's memory.
    Shared { memory_fraction: f64 },
    /// Exclusive use of `gpu_count` GPUs.
    Dedicated { gpu_count: usize },
}

/// One priority tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PriorityLevel {
    /// Priority value (ordering convention unconfirmed).
    pub priority: i32,
    /// Scale factor applied to the base resource allocation.
    pub resource_multiplier: f64,
    /// Maximum number of evaluations at this tier.
    pub max_evaluations: usize,
}

/// Failure-handling settings for evaluations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceConfig {
    /// Number of retries for a failed evaluation.
    pub max_retries: usize,
    /// Wall-clock limit for a single evaluation.
    pub evaluation_timeout: Duration,
    /// Interval between checkpoints.
    pub checkpoint_frequency: Duration,
}
443
/// A single evaluation submitted for (possibly parallel) execution.
#[derive(Debug, Clone)]
pub struct EvaluationJob {
    /// Unique job identifier.
    pub job_id: String,
    /// Configuration being evaluated.
    pub parameters: HashMap<String, ParameterValue>,
    /// When the job was created or started.
    pub start_time: SystemTime,
    /// Resources granted to the job.
    pub resources: ResourceAllocation,
    /// Current lifecycle state.
    pub status: JobStatus,
}

/// Lifecycle state of an `EvaluationJob`.
#[derive(Debug, Clone)]
pub enum JobStatus {
    Queued,
    Running,
    Completed,
    /// Evaluation failed with the given message.
    Failed { error: String },
    Cancelled,
}

/// Outcome of a finished evaluation job.
#[derive(Debug, Clone)]
pub struct EvaluationResult {
    /// Identifier of the job that produced this result.
    pub job_id: String,
    /// Configuration that was evaluated.
    pub parameters: HashMap<String, ParameterValue>,
    /// Reported metrics keyed by metric name.
    pub metrics: HashMap<String, f64>,
    /// Time the evaluation took.
    pub evaluation_time: Duration,
    /// Resources consumed while evaluating.
    pub resource_usage: ResourceUsage,
}

/// Resource consumption measurements.
///
/// NOTE(review): units and ranges are not established in this file (e.g. whether
/// utilisation is a fraction or a percentage) — confirm with the producer.
#[derive(Debug, Clone)]
pub struct ResourceUsage {
    pub cpu_utilization: f64,
    pub memory_usage: f64,
    pub gpu_utilization: f64,
    pub network_io: f64,
}
492
/// Model of the objective function used by surrogate-based optimization.
pub trait SurrogateModel: Send + Sync {
    /// Fits the model to `(configuration, value)` observations.
    fn fit(&mut self, observations: &[(HashMap<String, ParameterValue>, f64)]) -> Result<()>;

    /// Predicts at `parameters`, returning a pair — presumably `(mean, uncertainty)`;
    /// confirm against implementations.
    fn predict(&self, parameters: &HashMap<String, ParameterValue>) -> Result<(f64, f64)>;

    /// Incorporates a single new observation.
    fn update(&mut self, parameters: HashMap<String, ParameterValue>, value: f64) -> Result<()>;
}

/// Scores candidate configurations using a surrogate model.
pub trait AcquisitionFunction: Send + Sync {
    /// Acquisition value of `parameters` under `model`, given the best value seen so far.
    fn compute(
        &self,
        parameters: &HashMap<String, ParameterValue>,
        model: &dyn SurrogateModel,
        best_value: f64,
    ) -> Result<f64>;

    /// Searches `search_space` for a configuration that scores well under this function.
    fn optimize(
        &self,
        model: &dyn SurrogateModel,
        search_space: &SearchSpace,
        best_value: f64,
    ) -> Result<HashMap<String, ParameterValue>>;
}

/// Assigns evaluation jobs to compute resources.
pub trait LoadBalancer: Send + Sync {
    /// Chooses where `job` runs, returning an identifier (presumably a resource id).
    fn assign_job(&mut self, job: &EvaluationJob) -> Result<String>;

    /// Records the latest usage measurements for `resource_id`.
    fn update_resource_status(&mut self, resource_id: &str, usage: &ResourceUsage) -> Result<()>;

    /// Identifiers of resources currently able to accept work.
    fn get_available_resources(&self) -> Vec<String>;
}
534
535impl BanditOptimizer {
538 pub fn new(config: BanditConfig, search_space: &SearchSpace) -> Result<Self> {
539 let arms = Self::generate_arms(&config, search_space)?;
540 let arm_stats = vec![ArmStatistics::new(); arms.len()];
541
542 Ok(Self {
543 config,
544 arms,
545 arm_stats,
546 exploration_factor: 1.0,
547 })
548 }
549
550 pub fn select_arm(&mut self) -> Result<usize> {
551 match &self.config.algorithm {
552 BanditAlgorithm::UCB {
553 confidence_parameter,
554 } => self.ucb_select(*confidence_parameter),
555 BanditAlgorithm::ThompsonSampling {
556 alpha_prior,
557 beta_prior,
558 } => self.thompson_sampling_select(*alpha_prior, *beta_prior),
559 BanditAlgorithm::EpsilonGreedy {
560 epsilon,
561 decay_rate: _,
562 } => self.epsilon_greedy_select(*epsilon),
563 BanditAlgorithm::EXP3 { gamma } => self.exp3_select(*gamma),
564 BanditAlgorithm::LinUCB {
565 alpha,
566 context_dim: _,
567 } => self.linucb_select(*alpha),
568 }
569 }
570
571 pub fn update_arm(&mut self, arm_index: usize, reward: f64) -> Result<()> {
572 if arm_index >= self.arm_stats.len() {
573 return Err(anyhow::anyhow!("Invalid arm index"));
574 }
575
576 let stats = &mut self.arm_stats[arm_index];
577 stats.pulls += 1;
578 stats.total_reward += reward;
579 stats.average_reward = stats.total_reward / stats.pulls as f64;
580 stats.last_update = SystemTime::now();
581
582 let confidence_radius = (2.0 * (stats.pulls as f64).ln() / stats.pulls as f64).sqrt();
584 stats.confidence_bounds = (
585 stats.average_reward - confidence_radius,
586 stats.average_reward + confidence_radius,
587 );
588
589 Ok(())
590 }
591
592 fn generate_arms(
593 config: &BanditConfig,
594 search_space: &SearchSpace,
595 ) -> Result<Vec<HashMap<String, ParameterValue>>> {
596 let mut arms = Vec::new();
597
598 match &config.arm_generation {
599 ArmGenerationStrategy::Random => {
600 for _ in 0..config.num_arms {
601 arms.push(search_space.sample_random()?);
602 }
603 },
604 ArmGenerationStrategy::LatinHypercube => {
605 arms = search_space.latin_hypercube_sample(config.num_arms)?;
606 },
607 ArmGenerationStrategy::Sobol => {
608 arms = search_space.sobol_sample(config.num_arms)?;
609 },
610 ArmGenerationStrategy::Evolutionary { .. } => {
611 arms = search_space.evolutionary_sample(config.num_arms)?;
613 },
614 }
615
616 Ok(arms)
617 }
618
619 fn ucb_select(&self, confidence_parameter: f64) -> Result<usize> {
620 let total_pulls: usize = self.arm_stats.iter().map(|s| s.pulls).sum();
621
622 if total_pulls == 0 {
623 return Ok(0);
624 }
625
626 let mut best_arm = 0;
627 let mut best_value = f64::NEG_INFINITY;
628
629 for (i, stats) in self.arm_stats.iter().enumerate() {
630 if stats.pulls == 0 {
631 return Ok(i); }
633
634 let confidence_bound = confidence_parameter
635 * (2.0 * (total_pulls as f64).ln() / stats.pulls as f64).sqrt();
636 let ucb_value = stats.average_reward + confidence_bound;
637
638 if ucb_value > best_value {
639 best_value = ucb_value;
640 best_arm = i;
641 }
642 }
643
644 Ok(best_arm)
645 }
646
647 fn thompson_sampling_select(&self, alpha_prior: f64, beta_prior: f64) -> Result<usize> {
648 let mut rng = thread_rng();
649
650 let mut best_arm = 0;
651 let mut best_sample = f64::NEG_INFINITY;
652
653 for (i, stats) in self.arm_stats.iter().enumerate() {
654 let _alpha = alpha_prior + stats.total_reward;
656 let _beta = beta_prior + stats.pulls as f64 - stats.total_reward;
657
658 let sample = rng.random::<f64>(); if sample > best_sample {
662 best_sample = sample;
663 best_arm = i;
664 }
665 }
666
667 Ok(best_arm)
668 }
669
670 fn epsilon_greedy_select(&self, epsilon: f64) -> Result<usize> {
671 let mut rng = thread_rng();
672
673 if rng.random::<f64>() < epsilon {
674 Ok(rng.random_range(0..self.arms.len()))
676 } else {
677 let best_arm = self
679 .arm_stats
680 .iter()
681 .enumerate()
682 .max_by(|(_, a), (_, b)| {
683 a.average_reward
684 .partial_cmp(&b.average_reward)
685 .unwrap_or(std::cmp::Ordering::Equal)
686 })
687 .map(|(i, _)| i)
688 .unwrap_or(0);
689 Ok(best_arm)
690 }
691 }
692
693 fn exp3_select(&self, gamma: f64) -> Result<usize> {
694 let mut rng = thread_rng();
695
696 let num_arms = self.arms.len();
697 if num_arms == 0 {
698 return Err(anyhow::anyhow!("No arms available"));
699 }
700
701 let mut weights = vec![1.0; num_arms];
703 for (i, stats) in self.arm_stats.iter().enumerate() {
704 if stats.pulls > 0 {
705 weights[i] = (gamma * stats.average_reward / num_arms as f64).exp();
707 }
708 }
709
710 let weight_sum: f64 = weights.iter().sum();
712 let mut probabilities = vec![0.0; num_arms];
713
714 for i in 0..num_arms {
715 probabilities[i] = (1.0 - gamma) * weights[i] / weight_sum + gamma / num_arms as f64;
716 }
717
718 let mut cumulative_prob = 0.0;
720 let random_value = rng.random::<f64>();
721
722 for (i, &prob) in probabilities.iter().enumerate() {
723 cumulative_prob += prob;
724 if random_value <= cumulative_prob {
725 return Ok(i);
726 }
727 }
728
729 Ok(num_arms - 1)
731 }
732
733 fn linucb_select(&self, alpha: f64) -> Result<usize> {
734 let total_pulls: usize = self.arm_stats.iter().map(|s| s.pulls).sum();
739
740 if total_pulls == 0 {
741 return Ok(0);
742 }
743
744 let mut best_arm = 0;
745 let mut best_value = f64::NEG_INFINITY;
746
747 for (i, stats) in self.arm_stats.iter().enumerate() {
748 if stats.pulls == 0 {
749 return Ok(i); }
751
752 let confidence_width = alpha * (total_pulls as f64 / stats.pulls as f64).ln().sqrt();
754 let upper_bound = stats.average_reward + confidence_width;
755
756 if upper_bound > best_value {
757 best_value = upper_bound;
758 best_arm = i;
759 }
760 }
761
762 Ok(best_arm)
763 }
764}
765
766impl SearchStrategy for BanditOptimizer {
767 fn suggest(
768 &mut self,
769 _search_space: &SearchSpace,
770 _history: &TrialHistory,
771 ) -> Option<HashMap<String, ParameterValue>> {
772 match self.select_arm() {
773 Ok(arm_index) => Some(self.arms[arm_index].clone()),
774 Err(_) => None,
775 }
776 }
777
778 fn should_terminate(&self, _history: &TrialHistory) -> bool {
779 false }
781
782 fn name(&self) -> &str {
783 "BanditOptimizer"
784 }
785
786 fn update(&mut self, trial: &Trial) {
787 if let TrialState::Complete = trial.state {
788 if let Some(value) =
789 trial.result.as_ref().and_then(|r| r.metrics.metrics.get("objective"))
790 {
791 for (i, arm) in self.arms.iter().enumerate() {
793 if arm == &trial.params {
794 let _ = self.update_arm(i, *value);
795 break;
796 }
797 }
798 }
799 }
800 }
801}
802
803impl ArmStatistics {
804 fn new() -> Self {
805 Self {
806 pulls: 0,
807 total_reward: 0.0,
808 average_reward: 0.0,
809 confidence_bounds: (0.0, 0.0),
810 last_update: SystemTime::now(),
811 }
812 }
813}
814
815impl Default for AdvancedEarlyStoppingConfig {
816 fn default() -> Self {
817 Self {
818 patience: 10,
819 min_delta: 0.001,
820 strategy: EarlyStoppingStrategy::Standard,
821 adaptive_patience: false,
822 min_evaluation_steps: 100,
823 grace_period: 5,
824 }
825 }
826}
827
828impl Default for WarmStartConfig {
829 fn default() -> Self {
830 Self {
831 strategy: WarmStartStrategy::BestTrials,
832 data_source: WarmStartDataSource::InMemory,
833 num_warm_start_trials: 10,
834 historical_weight_decay: 0.9,
835 }
836 }
837}
838
839impl Default for BanditConfig {
840 fn default() -> Self {
841 Self {
842 algorithm: BanditAlgorithm::UCB {
843 confidence_parameter: 1.0,
844 },
845 exploration: ExplorationStrategy::Fixed { rate: 0.1 },
846 reward_function: RewardFunction::Direct {
847 metric_name: "objective".to_string(),
848 },
849 num_arms: 10,
850 arm_generation: ArmGenerationStrategy::Random,
851 }
852 }
853}
854
855impl Default for SurrogateConfig {
856 fn default() -> Self {
857 Self {
858 model_type: SurrogateModelType::GaussianProcess {
859 kernel: KernelType::RBF,
860 noise_level: 0.01,
861 length_scales: vec![1.0],
862 },
863 acquisition_function: AcquisitionFunctionType::ExpectedImprovement { xi: 0.01 },
864 initial_samples: 20,
865 update_frequency: 5,
866 optimization_budget: 1000,
867 }
868 }
869}
870
871impl Default for ParallelEvaluationConfig {
872 fn default() -> Self {
873 Self {
874 max_parallel: 4,
875 strategy: ParallelStrategy::Independent,
876 resource_allocation: ResourceAllocation {
877 cpu_cores: 2,
878 memory_gb: 4.0,
879 gpu_allocation: GPUAllocation::None,
880 priority_levels: vec![],
881 },
882 fault_tolerance: FaultToleranceConfig {
883 max_retries: 3,
884 evaluation_timeout: Duration::from_secs(3600),
885 checkpoint_frequency: Duration::from_secs(300),
886 },
887 }
888 }
889}
890
891impl SearchSpace {
893 pub fn sample_random(&self) -> Result<HashMap<String, ParameterValue>> {
894 let mut rng = thread_rng();
895 let mut params = HashMap::new();
896
897 for param in &self.parameters {
898 let value = match param {
899 super::search_space::HyperParameter::Continuous(p) => {
900 let val = rng.random_range(p.low..=p.high);
901 ParameterValue::Float(val)
902 },
903 super::search_space::HyperParameter::Log(p) => {
904 let log_low = p.low.ln();
905 let log_high = p.high.ln();
906 let log_val = rng.random_range(log_low..=log_high);
907 ParameterValue::Float(log_val.exp())
908 },
909 super::search_space::HyperParameter::Discrete(p) => {
910 let val = rng.random_range(p.low..=p.high);
911 ParameterValue::Int(val)
912 },
913 super::search_space::HyperParameter::Categorical(p) => {
914 let choice = &p.choices[rng.random_range(0..p.choices.len())];
915 ParameterValue::String(choice.clone())
916 },
917 };
918 params.insert(param.name().to_string(), value);
919 }
920
921 Ok(params)
922 }
923
924 pub fn latin_hypercube_sample(
925 &self,
926 n_samples: usize,
927 ) -> Result<Vec<HashMap<String, ParameterValue>>> {
928 let mut rng = thread_rng();
929 let mut samples = Vec::new();
930
931 if n_samples == 0 {
932 return Ok(samples);
933 }
934
935 let continuous_params: Vec<_> = self
937 .parameters
938 .iter()
939 .filter(|p| {
940 matches!(
941 p,
942 super::search_space::HyperParameter::Continuous(_)
943 | super::search_space::HyperParameter::Log(_)
944 )
945 })
946 .collect();
947
948 let n_dims = continuous_params.len();
949
950 if n_dims == 0 {
951 for _ in 0..n_samples {
953 samples.push(self.sample_random()?);
954 }
955 return Ok(samples);
956 }
957
958 let mut lhs_matrix = vec![vec![0.0; n_dims]; n_samples];
960
961 for dim in 0..n_dims {
962 let mut indices: Vec<usize> = (0..n_samples).collect();
963
964 for i in (1..indices.len()).rev() {
966 let j = rng.random_range(0..=i);
967 indices.swap(i, j);
968 }
969
970 for (i, &idx) in indices.iter().enumerate() {
971 let lower = idx as f64 / n_samples as f64;
972 let upper = (idx + 1) as f64 / n_samples as f64;
973 lhs_matrix[i][dim] = rng.random_range(lower..upper);
974 }
975 }
976
977 for i in 0..n_samples {
979 let mut params = HashMap::new();
980
981 for (dim, param) in continuous_params.iter().enumerate() {
983 let unit_value = lhs_matrix[i][dim];
984 let value = match param {
985 super::search_space::HyperParameter::Continuous(p) => {
986 let val = p.low + unit_value * (p.high - p.low);
987 ParameterValue::Float(val)
988 },
989 super::search_space::HyperParameter::Log(p) => {
990 let log_low = p.low.ln();
991 let log_high = p.high.ln();
992 let log_val = log_low + unit_value * (log_high - log_low);
993 ParameterValue::Float(log_val.exp())
994 },
995 _ => unreachable!(),
996 };
997 params.insert(param.name().to_string(), value);
998 }
999
1000 for param in &self.parameters {
1002 if !matches!(
1003 param,
1004 super::search_space::HyperParameter::Continuous(_)
1005 | super::search_space::HyperParameter::Log(_)
1006 ) {
1007 let value = match param {
1008 super::search_space::HyperParameter::Discrete(p) => {
1009 let val = rng.random_range(p.low..=p.high);
1010 ParameterValue::Int(val)
1011 },
1012 super::search_space::HyperParameter::Categorical(p) => {
1013 let choice = &p.choices[rng.random_range(0..p.choices.len())];
1014 ParameterValue::String(choice.clone())
1015 },
1016 _ => unreachable!(),
1017 };
1018 params.insert(param.name().to_string(), value);
1019 }
1020 }
1021
1022 samples.push(params);
1023 }
1024
1025 Ok(samples)
1026 }
1027
1028 pub fn sobol_sample(&self, n_samples: usize) -> Result<Vec<HashMap<String, ParameterValue>>> {
1029 let mut rng = thread_rng();
1032 let mut samples = Vec::new();
1033
1034 let continuous_params: Vec<_> = self
1035 .parameters
1036 .iter()
1037 .filter(|p| {
1038 matches!(
1039 p,
1040 super::search_space::HyperParameter::Continuous(_)
1041 | super::search_space::HyperParameter::Log(_)
1042 )
1043 })
1044 .collect();
1045
1046 let n_dims = continuous_params.len();
1047
1048 if n_dims == 0 {
1049 for _ in 0..n_samples {
1051 samples.push(self.sample_random()?);
1052 }
1053 return Ok(samples);
1054 }
1055
1056 for i in 0..n_samples {
1058 let mut params = HashMap::new();
1059
1060 for (dim, param) in continuous_params.iter().enumerate() {
1061 let unit_value = self.van_der_corput(i + 1, 2 + dim);
1063
1064 let value = match param {
1065 super::search_space::HyperParameter::Continuous(p) => {
1066 let val = p.low + unit_value * (p.high - p.low);
1067 ParameterValue::Float(val)
1068 },
1069 super::search_space::HyperParameter::Log(p) => {
1070 let log_low = p.low.ln();
1071 let log_high = p.high.ln();
1072 let log_val = log_low + unit_value * (log_high - log_low);
1073 ParameterValue::Float(log_val.exp())
1074 },
1075 _ => unreachable!(),
1076 };
1077 params.insert(param.name().to_string(), value);
1078 }
1079
1080 for param in &self.parameters {
1082 if !matches!(
1083 param,
1084 super::search_space::HyperParameter::Continuous(_)
1085 | super::search_space::HyperParameter::Log(_)
1086 ) {
1087 let value = match param {
1088 super::search_space::HyperParameter::Discrete(p) => {
1089 let val = rng.random_range(p.low..=p.high);
1090 ParameterValue::Int(val)
1091 },
1092 super::search_space::HyperParameter::Categorical(p) => {
1093 let choice = &p.choices[rng.random_range(0..p.choices.len())];
1094 ParameterValue::String(choice.clone())
1095 },
1096 _ => unreachable!(),
1097 };
1098 params.insert(param.name().to_string(), value);
1099 }
1100 }
1101
1102 samples.push(params);
1103 }
1104
1105 Ok(samples)
1106 }
1107
1108 pub fn evolutionary_sample(
1109 &self,
1110 n_samples: usize,
1111 ) -> Result<Vec<HashMap<String, ParameterValue>>> {
1112 let mut rng = thread_rng();
1113 let mut samples = Vec::new();
1114
1115 if n_samples == 0 {
1116 return Ok(samples);
1117 }
1118
1119 let population_size = (n_samples / 4).max(10);
1121 let mut population = Vec::new();
1122
1123 for _ in 0..population_size {
1124 population.push(self.sample_random()?);
1125 }
1126
1127 let generations = (n_samples / population_size).max(1);
1129 let mutation_rate = 0.1;
1130 let crossover_rate = 0.7;
1131
1132 for _gen in 0..generations {
1133 let mut new_population = Vec::new();
1134
1135 for _ in 0..population_size {
1137 if rng.random::<f64>() < crossover_rate && population.len() >= 2 {
1138 let parent1_idx = rng.random_range(0..population.len());
1140 let parent2_idx = rng.random_range(0..population.len());
1141 let offspring =
1142 self.crossover(&population[parent1_idx], &population[parent2_idx])?;
1143 new_population.push(offspring);
1144 } else {
1145 let parent_idx = rng.random_range(0..population.len());
1147 let mutated = self.mutate(&population[parent_idx], mutation_rate)?;
1148 new_population.push(mutated);
1149 }
1150 }
1151
1152 population = new_population;
1154
1155 for individual in &population {
1157 if samples.len() < n_samples {
1158 samples.push(individual.clone());
1159 }
1160 }
1161 }
1162
1163 while samples.len() < n_samples {
1165 samples.push(self.sample_random()?);
1166 }
1167
1168 samples.truncate(n_samples);
1169 Ok(samples)
1170 }
1171
1172 fn van_der_corput(&self, n: usize, base: usize) -> f64 {
1174 let mut result = 0.0;
1175 let mut denominator = 1.0;
1176 let mut num = n;
1177
1178 while num > 0 {
1179 denominator *= base as f64;
1180 result += (num % base) as f64 / denominator;
1181 num /= base;
1182 }
1183
1184 result
1185 }
1186
1187 fn crossover(
1189 &self,
1190 parent1: &HashMap<String, ParameterValue>,
1191 parent2: &HashMap<String, ParameterValue>,
1192 ) -> Result<HashMap<String, ParameterValue>> {
1193 let mut rng = thread_rng();
1194 let mut offspring = HashMap::new();
1195
1196 for param in &self.parameters {
1197 let param_name = param.name();
1198 let value = if rng.random::<f64>() < 0.5 {
1199 parent1.get(param_name).cloned()
1200 } else {
1201 parent2.get(param_name).cloned()
1202 };
1203
1204 if let Some(v) = value {
1205 offspring.insert(param_name.to_string(), v);
1206 } else {
1207 let random_value = match param {
1209 super::search_space::HyperParameter::Continuous(p) => {
1210 ParameterValue::Float(rng.random_range(p.low..=p.high))
1211 },
1212 super::search_space::HyperParameter::Log(p) => {
1213 let log_val = rng.random_range(p.low.ln()..=p.high.ln());
1214 ParameterValue::Float(log_val.exp())
1215 },
1216 super::search_space::HyperParameter::Discrete(p) => {
1217 ParameterValue::Int(rng.random_range(p.low..=p.high))
1218 },
1219 super::search_space::HyperParameter::Categorical(p) => {
1220 let choice = &p.choices[rng.random_range(0..p.choices.len())];
1221 ParameterValue::String(choice.clone())
1222 },
1223 };
1224 offspring.insert(param_name.to_string(), random_value);
1225 }
1226 }
1227
1228 Ok(offspring)
1229 }
1230
1231 fn mutate(
1233 &self,
1234 individual: &HashMap<String, ParameterValue>,
1235 mutation_rate: f64,
1236 ) -> Result<HashMap<String, ParameterValue>> {
1237 let mut rng = thread_rng();
1238 let mut mutated = individual.clone();
1239
1240 for param in &self.parameters {
1241 if rng.random::<f64>() < mutation_rate {
1242 let param_name = param.name();
1243 let new_value = match param {
1244 super::search_space::HyperParameter::Continuous(p) => {
1245 if let Some(ParameterValue::Float(current)) = individual.get(param_name) {
1246 let std_dev = (p.high - p.low) * 0.1;
1248 let noise = rng.random::<f64>() * 2.0 - 1.0; let new_val = (current + noise * std_dev).clamp(p.low, p.high);
1250 ParameterValue::Float(new_val)
1251 } else {
1252 ParameterValue::Float(rng.random_range(p.low..=p.high))
1253 }
1254 },
1255 super::search_space::HyperParameter::Log(p) => {
1256 if let Some(ParameterValue::Float(current)) = individual.get(param_name) {
1257 let log_current = current.ln();
1258 let log_std = (p.high.ln() - p.low.ln()) * 0.1;
1259 let noise = rng.random::<f64>() * 2.0 - 1.0;
1260 let new_log =
1261 (log_current + noise * log_std).clamp(p.low.ln(), p.high.ln());
1262 ParameterValue::Float(new_log.exp())
1263 } else {
1264 let log_val = rng.random_range(p.low.ln()..=p.high.ln());
1265 ParameterValue::Float(log_val.exp())
1266 }
1267 },
1268 super::search_space::HyperParameter::Discrete(p) => {
1269 ParameterValue::Int(rng.random_range(p.low..=p.high))
1270 },
1271 super::search_space::HyperParameter::Categorical(p) => {
1272 let choice = &p.choices[rng.random_range(0..p.choices.len())];
1273 ParameterValue::String(choice.clone())
1274 },
1275 };
1276 mutated.insert(param_name.to_string(), new_value);
1277 }
1278 }
1279
1280 Ok(mutated)
1281 }
1282}
1283
#[cfg(test)]
mod tests {
    use super::*;

    // Each test pins the documented defaults of one configuration type.

    #[test]
    fn test_advanced_early_stopping_config() {
        let AdvancedEarlyStoppingConfig {
            patience, strategy, ..
        } = AdvancedEarlyStoppingConfig::default();
        assert!(matches!(strategy, EarlyStoppingStrategy::Standard));
        assert_eq!(patience, 10);
    }

    #[test]
    fn test_warm_start_config() {
        let WarmStartConfig {
            strategy,
            num_warm_start_trials,
            ..
        } = WarmStartConfig::default();
        assert_eq!(num_warm_start_trials, 10);
        assert!(matches!(strategy, WarmStartStrategy::BestTrials));
    }

    #[test]
    fn test_bandit_config() {
        let BanditConfig {
            algorithm,
            num_arms,
            ..
        } = BanditConfig::default();
        assert_eq!(num_arms, 10);
        assert!(matches!(algorithm, BanditAlgorithm::UCB { .. }));
    }

    #[test]
    fn test_surrogate_config() {
        let SurrogateConfig {
            model_type,
            initial_samples,
            ..
        } = SurrogateConfig::default();
        assert_eq!(initial_samples, 20);
        assert!(matches!(
            model_type,
            SurrogateModelType::GaussianProcess { .. }
        ));
    }

    #[test]
    fn test_parallel_evaluation_config() {
        let ParallelEvaluationConfig {
            max_parallel,
            strategy,
            ..
        } = ParallelEvaluationConfig::default();
        assert!(matches!(strategy, ParallelStrategy::Independent));
        assert_eq!(max_parallel, 4);
    }

    #[test]
    fn test_arm_statistics() {
        let ArmStatistics {
            pulls,
            total_reward,
            average_reward,
            ..
        } = ArmStatistics::new();
        assert_eq!(pulls, 0);
        assert_eq!(total_reward, 0.0);
        assert_eq!(average_reward, 0.0);
    }
}