1use super::search_space::HyperParameter;
4use super::{Direction, ParameterValue, SearchSpace, Trial, TrialHistory};
5use scirs2_core::random::*; use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Tunable knobs shared by the samplers in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplerConfig {
    /// Optional RNG seed; `None` seeds from the thread-local RNG instead.
    pub seed: Option<u64>,
    /// Number of purely random trials before model-based sampling kicks in.
    pub n_startup_trials: usize,
    /// Number of candidates generated and scored per model-based suggestion.
    pub n_ei_candidates: usize,
}
19
20impl Default for SamplerConfig {
21 fn default() -> Self {
22 Self {
23 seed: None,
24 n_startup_trials: 10,
25 n_ei_candidates: 24,
26 }
27 }
28}
29
/// Strategy interface for proposing hyperparameter configurations.
///
/// Implementations must be `Send + Sync` so a study can drive sampling
/// from multiple threads.
pub trait Sampler: Send + Sync {
    /// Proposes a full configuration for the next trial, consulting the
    /// trial history where the strategy supports it.
    fn sample(
        &mut self,
        search_space: &SearchSpace,
        trial_history: &TrialHistory,
    ) -> HashMap<String, ParameterValue>;

    /// Hook invoked when a trial finishes; the default is a no-op.
    fn update(&mut self, _trial: &Trial) {}

    /// Human-readable sampler name (used for logging/diagnostics).
    fn name(&self) -> &str;
}
45
/// Sampler that draws every parameter independently at random.
pub struct RandomSampler {
    // RNG used for all draws; fixed via `with_seed` for reproducibility.
    rng: StdRng,
    // Display name, fixed at construction.
    name: String,
}
51
52impl RandomSampler {
53 pub fn new() -> Self {
55 Self {
56 rng: StdRng::seed_from_u64(thread_rng().random()),
57 name: "RandomSampler".to_string(),
58 }
59 }
60
61 pub fn with_seed(seed: u64) -> Self {
63 Self {
64 rng: StdRng::seed_from_u64(seed),
65 name: format!("RandomSampler(seed={})", seed),
66 }
67 }
68}
69
70impl Sampler for RandomSampler {
71 fn sample(
72 &mut self,
73 search_space: &SearchSpace,
74 _trial_history: &TrialHistory,
75 ) -> HashMap<String, ParameterValue> {
76 search_space.sample(&mut self.rng).unwrap_or_else(|e| {
77 log::warn!(
78 "Failed to sample from search space: {}. Using empty configuration.",
79 e
80 );
81 HashMap::new()
82 })
83 }
84
85 fn name(&self) -> &str {
86 &self.name
87 }
88}
89
90impl Default for RandomSampler {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
/// Tree-structured Parzen Estimator–style sampler.
///
/// After a random warm-up phase it partitions completed trials into
/// "good" and "bad" groups by objective and proposes perturbations of
/// good trials, scored by a similarity-based expected-improvement proxy.
pub struct TPESampler {
    config: SamplerConfig,
    rng: StdRng,
    name: String,
    // Parameter maps of the top-`percentile` completed trials.
    good_trials: Vec<HashMap<String, ParameterValue>>,
    // Parameter maps of the remaining completed trials.
    bad_trials: Vec<HashMap<String, ParameterValue>>,
    // Fraction of trials considered "good" (0.1 = top 10%).
    percentile: f64,
}
109
110impl TPESampler {
111 pub fn new() -> Self {
113 Self::with_config(SamplerConfig::default())
114 }
115
116 pub fn with_config(config: SamplerConfig) -> Self {
118 let rng = if let Some(seed) = config.seed {
119 StdRng::seed_from_u64(seed)
120 } else {
121 StdRng::seed_from_u64(thread_rng().random())
122 };
123
124 let name = format!(
125 "TPESampler(startup={}, candidates={})",
126 config.n_startup_trials, config.n_ei_candidates
127 );
128
129 Self {
130 config,
131 rng,
132 name,
133 good_trials: Vec::new(),
134 bad_trials: Vec::new(),
135 percentile: 0.1, }
137 }
138}
139
140impl Sampler for TPESampler {
141 fn sample(
142 &mut self,
143 search_space: &SearchSpace,
144 trial_history: &TrialHistory,
145 ) -> HashMap<String, ParameterValue> {
146 let completed_trials = trial_history.completed_trials();
147
148 if completed_trials.len() < self.config.n_startup_trials {
150 return search_space.sample(&mut self.rng).unwrap_or_else(|e| {
151 log::warn!(
152 "Failed to sample from search space: {}. Using empty configuration.",
153 e
154 );
155 HashMap::new()
156 });
157 }
158
159 self.update_trial_groups(&completed_trials, trial_history.direction.clone());
161
162 if self.good_trials.is_empty() {
164 return search_space.sample(&mut self.rng).unwrap_or_else(|e| {
165 log::warn!(
166 "Failed to sample from search space: {}. Using empty configuration.",
167 e
168 );
169 HashMap::new()
170 });
171 }
172
173 let mut best_candidate = None;
175 let mut best_score = f64::NEG_INFINITY;
176
177 for _ in 0..self.config.n_ei_candidates {
178 let candidate = self.sample_from_good_trials(search_space);
179 let score = self.compute_expected_improvement(&candidate, search_space);
180
181 if score > best_score {
182 best_score = score;
183 best_candidate = Some(candidate);
184 }
185 }
186
187 best_candidate.unwrap_or_else(|| {
188 search_space.sample(&mut self.rng).unwrap_or_else(|e| {
189 log::warn!(
190 "Failed to sample from search space: {}. Using empty configuration.",
191 e
192 );
193 HashMap::new()
194 })
195 })
196 }
197
198 fn update(&mut self, _trial: &Trial) {
199 }
201
202 fn name(&self) -> &str {
203 &self.name
204 }
205}
206
207impl TPESampler {
208 fn update_trial_groups(&mut self, trials: &[&Trial], direction: Direction) {
209 if trials.is_empty() {
210 return;
211 }
212
213 let mut sorted_trials = trials.to_vec();
215 sorted_trials.sort_by(|a, b| {
216 let a_val = a.objective_value().unwrap_or(f64::NEG_INFINITY);
217 let b_val = b.objective_value().unwrap_or(f64::NEG_INFINITY);
218
219 match direction {
220 Direction::Maximize => {
221 b_val.partial_cmp(&a_val).unwrap_or(std::cmp::Ordering::Equal)
222 },
223 Direction::Minimize => {
224 a_val.partial_cmp(&b_val).unwrap_or(std::cmp::Ordering::Equal)
225 },
226 }
227 });
228
229 let split_idx = ((trials.len() as f64 * self.percentile).ceil() as usize).max(1);
231
232 self.good_trials = sorted_trials[..split_idx].iter().map(|t| t.params.clone()).collect();
233
234 self.bad_trials = sorted_trials[split_idx..].iter().map(|t| t.params.clone()).collect();
235 }
236
237 fn sample_from_good_trials(
238 &mut self,
239 search_space: &SearchSpace,
240 ) -> HashMap<String, ParameterValue> {
241 if self.good_trials.is_empty() {
242 return search_space.sample(&mut self.rng).unwrap_or_else(|e| {
243 log::warn!(
244 "Failed to sample from search space: {}. Using empty configuration.",
245 e
246 );
247 HashMap::new()
248 });
249 }
250
251 let base_trial_index = self.rng.random_range(0..self.good_trials.len());
253 let base_trial = self.good_trials[base_trial_index].clone();
254 let mut result = HashMap::new();
255
256 for param in &search_space.parameters {
257 let param_name = param.name();
258
259 if let Some(base_value) = base_trial.get(param_name) {
260 let new_value = self.add_noise_to_parameter(param, base_value);
262 result.insert(param_name.to_string(), new_value);
263 } else {
264 result.insert(
266 param_name.to_string(),
267 param.sample(&mut self.rng).unwrap_or_else(|e| {
268 log::warn!(
269 "Failed to sample parameter '{}': {}. Using default value.",
270 param.name(),
271 e
272 );
273 match param {
275 HyperParameter::Categorical(_) => {
276 ParameterValue::String("default".to_string())
277 },
278 HyperParameter::Continuous(_) => ParameterValue::Float(0.0),
279 HyperParameter::Discrete(_) => ParameterValue::Int(0),
280 HyperParameter::Log(_) => ParameterValue::Float(1e-3),
281 }
282 }),
283 );
284 }
285 }
286
287 result
288 }
289
290 fn add_noise_to_parameter(
291 &mut self,
292 param: &super::search_space::HyperParameter,
293 base_value: &ParameterValue,
294 ) -> ParameterValue {
295 use super::search_space::HyperParameter;
296
297 match param {
298 HyperParameter::Categorical(_) => {
299 param.sample(&mut self.rng).unwrap_or_else(|e| {
301 log::warn!(
302 "Failed to sample parameter '{}': {}. Using default value.",
303 param.name(),
304 e
305 );
306 match param {
308 HyperParameter::Categorical(_) => {
309 ParameterValue::String("default".to_string())
310 },
311 HyperParameter::Continuous(_) => ParameterValue::Float(0.0),
312 HyperParameter::Discrete(_) => ParameterValue::Int(0),
313 HyperParameter::Log(_) => ParameterValue::Float(1e-3),
314 }
315 })
316 },
317 HyperParameter::Continuous(p) => {
318 if let Some(base_float) = base_value.as_float() {
319 let noise_std = (p.high - p.low) * 0.1; let normal = Normal::new(0.0, noise_std).unwrap_or_else(|_| {
323 Normal::new(0.0, 1.0)
324 .expect("Standard normal distribution should always be valid")
325 });
326 let noisy_value = base_float + normal.sample(&mut self.rng);
327 let clamped_value = noisy_value.clamp(p.low, p.high);
328 ParameterValue::Float(clamped_value)
329 } else {
330 param.sample(&mut self.rng).unwrap_or_else(|e| {
331 log::warn!(
332 "Failed to sample parameter '{}': {}. Using default value.",
333 param.name(),
334 e
335 );
336 match param {
338 HyperParameter::Categorical(_) => {
339 ParameterValue::String("default".to_string())
340 },
341 HyperParameter::Continuous(_) => ParameterValue::Float(0.0),
342 HyperParameter::Discrete(_) => ParameterValue::Int(0),
343 HyperParameter::Log(_) => ParameterValue::Float(1e-3),
344 }
345 })
346 }
347 },
348 HyperParameter::Discrete(p) => {
349 if let Some(base_int) = base_value.as_int() {
350 let noise_range = ((p.high - p.low) / 10).max(p.step); let noise = self.rng.random_range(-noise_range..=noise_range);
353 let noisy_value = base_int + noise;
354 let clamped_value = noisy_value.clamp(p.low, p.high);
355 let stepped_value = p.low + ((clamped_value - p.low) / p.step) * p.step;
357 ParameterValue::Int(stepped_value)
358 } else {
359 param.sample(&mut self.rng).unwrap_or_else(|e| {
360 log::warn!(
361 "Failed to sample parameter '{}': {}. Using default value.",
362 param.name(),
363 e
364 );
365 match param {
367 HyperParameter::Categorical(_) => {
368 ParameterValue::String("default".to_string())
369 },
370 HyperParameter::Continuous(_) => ParameterValue::Float(0.0),
371 HyperParameter::Discrete(_) => ParameterValue::Int(0),
372 HyperParameter::Log(_) => ParameterValue::Float(1e-3),
373 }
374 })
375 }
376 },
377 HyperParameter::Log(p) => {
378 if let Some(base_float) = base_value.as_float() {
379 let log_base = base_float.log(p.base);
381 let log_low = p.low.log(p.base);
382 let log_high = p.high.log(p.base);
383 let noise_std = (log_high - log_low) * 0.1;
384
385 let normal = Normal::new(0.0, noise_std).unwrap_or_else(|_| {
387 Normal::new(0.0, 1.0)
388 .expect("Standard normal distribution should always be valid")
389 });
390 let noisy_log = log_base + normal.sample(&mut self.rng);
391 let clamped_log = noisy_log.clamp(log_low, log_high);
392 let new_value = p.base.powf(clamped_log);
393 ParameterValue::Float(new_value)
394 } else {
395 param.sample(&mut self.rng).unwrap_or_else(|e| {
396 log::warn!(
397 "Failed to sample parameter '{}': {}. Using default value.",
398 param.name(),
399 e
400 );
401 match param {
403 HyperParameter::Categorical(_) => {
404 ParameterValue::String("default".to_string())
405 },
406 HyperParameter::Continuous(_) => ParameterValue::Float(0.0),
407 HyperParameter::Discrete(_) => ParameterValue::Int(0),
408 HyperParameter::Log(_) => ParameterValue::Float(1e-3),
409 }
410 })
411 }
412 },
413 }
414 }
415
416 fn compute_expected_improvement(
417 &mut self,
418 candidate: &HashMap<String, ParameterValue>,
419 _search_space: &SearchSpace,
420 ) -> f64 {
421 if self.good_trials.is_empty() || self.bad_trials.is_empty() {
426 return self.rng.random::<f64>(); }
428
429 let good_similarity = self.compute_similarity(candidate, &self.good_trials);
431 let bad_similarity = self.compute_similarity(candidate, &self.bad_trials);
432
433 good_similarity - bad_similarity
435 }
436
437 fn compute_similarity(
438 &self,
439 candidate: &HashMap<String, ParameterValue>,
440 trials: &[HashMap<String, ParameterValue>],
441 ) -> f64 {
442 if trials.is_empty() {
443 return 0.0;
444 }
445
446 let total_similarity: f64 =
447 trials.iter().map(|trial| self.parameter_similarity(candidate, trial)).sum();
448
449 total_similarity / trials.len() as f64
450 }
451
452 fn parameter_similarity(
453 &self,
454 a: &HashMap<String, ParameterValue>,
455 b: &HashMap<String, ParameterValue>,
456 ) -> f64 {
457 let mut total_similarity = 0.0;
458 let mut count = 0;
459
460 for (name, value_a) in a {
461 if let Some(value_b) = b.get(name) {
462 let similarity = match (value_a, value_b) {
463 (ParameterValue::Float(a), ParameterValue::Float(b)) => {
464 let diff = (a - b).abs();
465 1.0 / (1.0 + diff) },
467 (ParameterValue::Int(a), ParameterValue::Int(b)) => {
468 let diff = (a - b).abs() as f64;
469 1.0 / (1.0 + diff)
470 },
471 (ParameterValue::String(a), ParameterValue::String(b)) if a == b => 1.0,
472 (ParameterValue::Bool(a), ParameterValue::Bool(b)) if a == b => 1.0,
473 _ => 0.0, };
475
476 total_similarity += similarity;
477 count += 1;
478 }
479 }
480
481 if count > 0 {
482 total_similarity / count as f64
483 } else {
484 0.0
485 }
486 }
487}
488
489impl Default for TPESampler {
490 fn default() -> Self {
491 Self::new()
492 }
493}
494
/// Similarity-based surrogate sampler (GP-inspired; no actual Gaussian
/// process is fitted — see `acquisition_function`).
pub struct GPSampler {
    config: SamplerConfig,
    rng: StdRng,
    name: String,
    // (parameter map, objective value) pairs; rebuilt from the trial
    // history on every `sample` call.
    trials: Vec<(HashMap<String, ParameterValue>, f64)>,
}
503
504impl GPSampler {
505 pub fn new() -> Self {
507 Self::with_config(SamplerConfig::default())
508 }
509
510 pub fn with_config(config: SamplerConfig) -> Self {
512 let rng = if let Some(seed) = config.seed {
513 StdRng::seed_from_u64(seed)
514 } else {
515 StdRng::seed_from_u64(thread_rng().random())
516 };
517
518 let name = format!("GPSampler(startup={})", config.n_startup_trials);
519
520 Self {
521 config,
522 rng,
523 name,
524 trials: Vec::new(),
525 }
526 }
527}
528
529impl Sampler for GPSampler {
530 fn sample(
531 &mut self,
532 search_space: &SearchSpace,
533 trial_history: &TrialHistory,
534 ) -> HashMap<String, ParameterValue> {
535 let completed_trials = trial_history.completed_trials();
536
537 if completed_trials.len() < self.config.n_startup_trials {
539 return search_space.sample(&mut self.rng).unwrap_or_else(|e| {
540 log::warn!(
541 "Failed to sample from search space: {}. Using empty configuration.",
542 e
543 );
544 HashMap::new()
545 });
546 }
547
548 self.trials.clear();
550 for trial in completed_trials {
551 if let Some(objective) = trial.objective_value() {
552 self.trials.push((trial.params.clone(), objective));
553 }
554 }
555
556 let mut best_candidate = None;
561 let mut best_score = f64::NEG_INFINITY;
562
563 for _ in 0..self.config.n_ei_candidates {
565 let candidate = match search_space.sample(&mut self.rng) {
566 Ok(c) => c,
567 Err(e) => {
568 log::warn!("Failed to sample candidate: {}. Skipping.", e);
569 continue;
570 },
571 };
572 let score = self.acquisition_function(&candidate);
573
574 if score > best_score {
575 best_score = score;
576 best_candidate = Some(candidate);
577 }
578 }
579
580 best_candidate.unwrap_or_else(|| {
581 search_space.sample(&mut self.rng).unwrap_or_else(|e| {
582 log::warn!(
583 "Failed to sample from search space: {}. Using empty configuration.",
584 e
585 );
586 HashMap::new()
587 })
588 })
589 }
590
591 fn name(&self) -> &str {
592 &self.name
593 }
594}
595
596impl GPSampler {
597 fn acquisition_function(&mut self, candidate: &HashMap<String, ParameterValue>) -> f64 {
598 if self.trials.is_empty() {
602 return self.rng.random::<f64>();
603 }
604
605 let mut best_similarity = 0.0;
607 let mut corresponding_objective = 0.0;
608
609 for (trial_params, objective) in &self.trials {
610 let similarity = self.compute_similarity(candidate, trial_params);
611 if similarity > best_similarity {
612 best_similarity = similarity;
613 corresponding_objective = *objective;
614 }
615 }
616
617 let exploration = 1.0 - best_similarity; let exploitation = corresponding_objective;
620
621 exploitation + 0.1 * exploration }
623
624 fn compute_similarity(
625 &self,
626 a: &HashMap<String, ParameterValue>,
627 b: &HashMap<String, ParameterValue>,
628 ) -> f64 {
629 let mut total_similarity = 0.0;
630 let mut count = 0;
631
632 for (name, value_a) in a {
633 if let Some(value_b) = b.get(name) {
634 let similarity = match (value_a, value_b) {
635 (ParameterValue::Float(a), ParameterValue::Float(b)) => {
636 let diff = (a - b).abs();
637 (-diff).exp() },
639 (ParameterValue::Int(a), ParameterValue::Int(b)) => {
640 let diff = (a - b).abs() as f64;
641 (-diff).exp()
642 },
643 (ParameterValue::String(a), ParameterValue::String(b)) if a == b => 1.0,
644 (ParameterValue::Bool(a), ParameterValue::Bool(b)) if a == b => 1.0,
645 _ => 0.0,
646 };
647
648 total_similarity += similarity;
649 count += 1;
650 }
651 }
652
653 if count > 0 {
654 total_similarity / count as f64
655 } else {
656 0.0
657 }
658 }
659}
660
661impl Default for GPSampler {
662 fn default() -> Self {
663 Self::new()
664 }
665}
666
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hyperopt::{search_space::SearchSpaceBuilder, Trial, TrialMetrics, TrialResult};

    /// Small three-parameter space shared by all sampler tests.
    fn create_test_search_space() -> SearchSpace {
        SearchSpaceBuilder::new()
            .continuous("learning_rate", 1e-5, 1e-1)
            .discrete("batch_size", 8, 128, 8)
            .categorical("optimizer", vec!["adam", "sgd", "adamw"])
            .build()
    }

    /// History with five completed trials whose objectives increase with
    /// the trial index (0.80, 0.82, ..., 0.88); direction is Maximize.
    fn create_test_trial_history() -> TrialHistory {
        let mut history = TrialHistory::new(Direction::Maximize);

        for i in 0..5 {
            let mut params = HashMap::new();
            params.insert(
                "learning_rate".to_string(),
                ParameterValue::Float(0.01 * (i + 1) as f64),
            );
            params.insert("batch_size".to_string(), ParameterValue::Int(32));
            params.insert(
                "optimizer".to_string(),
                ParameterValue::String("adam".to_string()),
            );

            let mut trial = Trial::new(i, params);
            trial.complete(TrialResult::success(TrialMetrics::new(
                0.8 + i as f64 * 0.02,
            )));
            history.add_trial(trial);
        }

        history
    }

    #[test]
    fn test_random_sampler() {
        let mut sampler = RandomSampler::with_seed(42);
        let search_space = create_test_search_space();
        let history = TrialHistory::new(Direction::Maximize);

        let config = sampler.sample(&search_space, &history);

        // One value per parameter in the space.
        assert_eq!(config.len(), 3);
        assert!(config.contains_key("learning_rate"));
        assert!(config.contains_key("batch_size"));
        assert!(config.contains_key("optimizer"));

        // Every sampled value must lie within its declared range/choices.
        assert!(search_space.validate(&config).is_ok());
    }

    #[test]
    fn test_tpe_sampler() {
        // startup=3 < 5 completed trials, so the model-based path is taken.
        let mut sampler = TPESampler::with_config(SamplerConfig {
            seed: Some(42),
            n_startup_trials: 3,
            n_ei_candidates: 5,
        });

        let search_space = create_test_search_space();
        let history = create_test_trial_history();

        let config = sampler.sample(&search_space, &history);

        assert_eq!(config.len(), 3);
        assert!(search_space.validate(&config).is_ok());
        // Name encodes the configuration it was built from.
        assert_eq!(sampler.name(), "TPESampler(startup=3, candidates=5)");
    }

    #[test]
    fn test_gp_sampler() {
        // startup=3 < 5 completed trials, so the surrogate path is taken.
        let mut sampler = GPSampler::with_config(SamplerConfig {
            seed: Some(42),
            n_startup_trials: 3,
            n_ei_candidates: 5,
        });

        let search_space = create_test_search_space();
        let history = create_test_trial_history();

        let config = sampler.sample(&search_space, &history);

        assert_eq!(config.len(), 3);
        assert!(search_space.validate(&config).is_ok());
    }

    #[test]
    fn test_sampler_with_insufficient_trials() {
        // startup=10 > 5 completed trials: TPE must fall back to random
        // sampling and still produce a valid configuration.
        let mut sampler = TPESampler::with_config(SamplerConfig {
            seed: Some(42),
            n_startup_trials: 10,
            n_ei_candidates: 5,
        });

        let search_space = create_test_search_space();
        let history = create_test_trial_history();
        let config = sampler.sample(&search_space, &history);
        assert!(search_space.validate(&config).is_ok());
    }
}