1use crate::{CrossValidator, ParameterValue, Scoring};
8use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Dim, OwnedRepr};
9use scirs2_core::random::prelude::*;
10use sklears_core::{
11 error::{Result, SklearsError},
12 prelude::Predict,
13 traits::Fit,
14 traits::Score,
15 types::Float,
16};
17use std::collections::HashMap;
18
/// Configuration for multi-armed-bandit hyperparameter optimization.
#[derive(Debug, Clone)]
pub struct BanditConfig {
    /// Total number of arm pulls (parameter evaluations) to perform.
    pub n_iterations: usize,
    /// Number of purely random pulls per arm before the strategy takes over.
    pub n_initial_random: usize,
    /// Exploration coefficient for the UCB strategy.
    pub ucb_c: f64,
    /// Initial temperature for the Boltzmann (softmax) strategy.
    pub temperature: f64,
    /// Multiplicative decay applied to the temperature after each Boltzmann pull.
    pub temperature_decay: f64,
    /// Optional RNG seed for reproducible runs.
    pub random_state: Option<u64>,
}

impl Default for BanditConfig {
    /// Defaults: 100 iterations, 5 initial random pulls per arm, UCB
    /// coefficient 1.96, Boltzmann temperature 1.0 decaying by 0.95 per
    /// pull, and no fixed seed.
    fn default() -> Self {
        BanditConfig {
            random_state: None,
            temperature_decay: 0.95,
            temperature: 1.0,
            ucb_c: 1.96,
            n_initial_random: 5,
            n_iterations: 100,
        }
    }
}
48
/// Arm-selection strategy used by the bandit optimizers.
#[derive(Debug, Clone)]
pub enum BanditStrategy {
    /// Upper Confidence Bound: pick the arm maximizing mean reward plus an
    /// exploration bonus that shrinks as the arm accumulates pulls.
    UCB,
    /// With the given probability pick a random arm, otherwise exploit the
    /// arm with the best empirical mean.
    EpsilonGreedy(f64),
    /// Softmax over mean rewards at the given temperature.
    Boltzmann(f64),
    /// Probability matching: sample a plausible reward per arm from its
    /// observed history and pick the largest sample.
    ThompsonSampling,
}
61
/// Running reward statistics for a single bandit arm.
#[derive(Debug, Clone)]
struct ArmStats {
    /// Number of times this arm has been evaluated.
    n_pulls: usize,
    /// Sum of all observed rewards (for the running mean).
    sum_rewards: f64,
    /// Sum of squared rewards (for the running variance).
    sum_squared_rewards: f64,
    /// Best single reward observed so far.
    best_score: f64,
    /// Sliding window of the 10 most recent rewards.
    recent_scores: Vec<f64>,
}

impl ArmStats {
    /// Fresh statistics: zero pulls, best score at -inf so any reward improves it.
    fn new() -> Self {
        Self {
            n_pulls: 0,
            sum_rewards: 0.0,
            sum_squared_rewards: 0.0,
            best_score: f64::NEG_INFINITY,
            recent_scores: Vec::new(),
        }
    }

    /// Record one observed reward for this arm.
    fn update(&mut self, reward: f64) {
        self.n_pulls += 1;
        self.sum_rewards += reward;
        self.sum_squared_rewards += reward * reward;

        if reward > self.best_score {
            self.best_score = reward;
        }

        // Keep only the 10 most recent rewards (window is small, so the
        // O(n) front removal is negligible).
        self.recent_scores.push(reward);
        if self.recent_scores.len() > 10 {
            self.recent_scores.remove(0);
        }
    }

    /// Empirical mean reward; 0.0 for an arm that has never been pulled.
    fn mean_reward(&self) -> f64 {
        if self.n_pulls == 0 {
            0.0
        } else {
            self.sum_rewards / self.n_pulls as f64
        }
    }

    /// Population variance of the observed rewards, clamped at 0.0.
    ///
    /// The naive E[x^2] - E[x]^2 formula can dip fractionally below zero
    /// from floating-point cancellation; clamping prevents a NaN when
    /// callers take the square root (`confidence_interval`, Thompson
    /// sampling's standard deviation).
    fn variance(&self) -> f64 {
        if self.n_pulls <= 1 {
            0.0
        } else {
            let mean = self.mean_reward();
            ((self.sum_squared_rewards / self.n_pulls as f64) - (mean * mean)).max(0.0)
        }
    }

    /// Half-width of a confidence interval around the mean reward, scaled
    /// by `confidence`. Infinite for unpulled arms so they are explored first.
    fn confidence_interval(&self, confidence: f64) -> f64 {
        if self.n_pulls == 0 {
            f64::INFINITY
        } else {
            let std_err = (self.variance() / self.n_pulls as f64).sqrt();
            confidence * std_err
        }
    }
}
130
/// Outcome of a bandit optimization run.
#[derive(Debug, Clone)]
pub struct BanditOptimizationResult {
    /// Parameter value of the best-performing arm.
    pub best_params: ParameterValue,
    /// Best single score observed across all pulls.
    pub best_score: f64,
    /// Parameters evaluated per iteration (`BanditSearchCV::fit`) or the full
    /// parameter space (`BanditOptimization::fit`) — NOTE(review): the two
    /// producers populate this field with different semantics.
    pub all_params: Vec<ParameterValue>,
    /// Score per iteration, or per arm for `BanditOptimization::fit`
    /// (see `all_params` note).
    pub all_scores: Vec<f64>,
    /// Best score seen so far after each iteration.
    pub convergence_history: Vec<f64>,
    /// Number of iterations that were run.
    pub n_iterations: usize,
    /// Per-arm summary keyed "arm_{index}": (mean reward, number of pulls).
    pub arm_stats: HashMap<String, (f64, usize)>,
}
149
/// Core multi-armed-bandit engine: one arm per candidate parameter value.
///
/// The caller drives the loop: `select_arm` proposes an arm, the caller
/// evaluates it, and `update_arm` feeds the reward back.
pub struct BanditOptimizer {
    /// Candidate parameter values; index i is arm i.
    param_space: Vec<ParameterValue>,
    /// Arm-selection strategy.
    strategy: BanditStrategy,
    /// Iteration budget, exploration constants, and seeding.
    config: BanditConfig,
    /// Per-arm running statistics, parallel to `param_space`.
    arm_stats: Vec<ArmStats>,
    /// Source of randomness (seeded from config, or fixed fallback seed 42).
    rng: StdRng,
    /// Number of completed pulls; drives the initial-random phase and
    /// epsilon decay.
    current_iteration: usize,
    /// Boltzmann temperature, decayed after each Boltzmann pull.
    current_temperature: f64,
}
167
168impl BanditOptimizer {
169 pub fn new(
170 param_space: Vec<ParameterValue>,
171 strategy: BanditStrategy,
172 config: BanditConfig,
173 ) -> Self {
174 let rng = if let Some(seed) = config.random_state {
175 StdRng::seed_from_u64(seed)
176 } else {
177 StdRng::seed_from_u64(42)
178 };
179
180 let arm_stats = vec![ArmStats::new(); param_space.len()];
181 let current_temperature = config.temperature;
182
183 Self {
184 param_space,
185 strategy,
186 config,
187 arm_stats,
188 rng,
189 current_iteration: 0,
190 current_temperature,
191 }
192 }
193
194 fn select_arm(&mut self) -> usize {
196 if self.current_iteration < self.config.n_initial_random * self.param_space.len() {
198 return self.rng.random_range(0..self.param_space.len());
199 }
200
201 match &self.strategy {
202 BanditStrategy::UCB => self.select_ucb_arm(),
203 BanditStrategy::EpsilonGreedy(epsilon) => self.select_epsilon_greedy_arm(*epsilon),
204 BanditStrategy::Boltzmann(_) => self.select_boltzmann_arm(),
205 BanditStrategy::ThompsonSampling => self.select_thompson_sampling_arm(),
206 }
207 }
208
209 fn select_ucb_arm(&self) -> usize {
211 let total_pulls = self
212 .arm_stats
213 .iter()
214 .map(|stats| stats.n_pulls)
215 .sum::<usize>();
216 let log_total = (total_pulls as f64).ln();
217
218 let mut best_arm = 0;
219 let mut best_ucb = f64::NEG_INFINITY;
220
221 for (i, stats) in self.arm_stats.iter().enumerate() {
222 if stats.n_pulls == 0 {
223 return i; }
225
226 let mean_reward = stats.mean_reward();
227 let confidence_bonus = self.config.ucb_c * (log_total / stats.n_pulls as f64).sqrt();
228 let ucb_value = mean_reward + confidence_bonus;
229
230 if ucb_value > best_ucb {
231 best_ucb = ucb_value;
232 best_arm = i;
233 }
234 }
235
236 best_arm
237 }
238
239 fn select_epsilon_greedy_arm(&mut self, base_epsilon: f64) -> usize {
241 let decayed_epsilon = base_epsilon / (1.0 + self.current_iteration as f64 * 0.01);
242
243 if self.rng.random::<f64>() < decayed_epsilon {
244 self.rng.random_range(0..self.param_space.len())
246 } else {
247 let mut best_arm = 0;
249 let mut best_reward = f64::NEG_INFINITY;
250
251 for (i, stats) in self.arm_stats.iter().enumerate() {
252 let reward = if stats.n_pulls == 0 {
253 f64::INFINITY } else {
255 stats.mean_reward()
256 };
257
258 if reward > best_reward {
259 best_reward = reward;
260 best_arm = i;
261 }
262 }
263
264 best_arm
265 }
266 }
267
268 fn select_boltzmann_arm(&mut self) -> usize {
270 let mut unnormalized_probs = Vec::with_capacity(self.param_space.len());
271
272 for stats in &self.arm_stats {
273 let reward = if stats.n_pulls == 0 {
274 1.0 } else {
276 stats.mean_reward()
277 };
278
279 unnormalized_probs.push((reward / self.current_temperature).exp());
280 }
281
282 let total: f64 = unnormalized_probs.iter().sum();
284 let probs: Vec<f64> = unnormalized_probs.iter().map(|p| p / total).collect();
285
286 let mut cumsum = 0.0;
288 let random_val = self.rng.random::<f64>();
289
290 for (i, &prob) in probs.iter().enumerate() {
291 cumsum += prob;
292 if random_val <= cumsum {
293 return i;
294 }
295 }
296
297 self.param_space.len() - 1
299 }
300
301 fn select_thompson_sampling_arm(&mut self) -> usize {
303 let mut best_arm = 0;
304 let mut best_sample = f64::NEG_INFINITY;
305
306 for (i, stats) in self.arm_stats.iter().enumerate() {
307 let sample = if stats.n_pulls == 0 {
308 self.rng.random()
310 } else {
311 let mean = stats.mean_reward();
313 let std = (stats.variance() / stats.n_pulls as f64).sqrt().max(0.1);
314
315 use scirs2_core::random::RandNormal;
316 let normal = RandNormal::new(mean, std).expect("operation should succeed");
317 self.rng.sample(normal)
318 };
319
320 if sample > best_sample {
321 best_sample = sample;
322 best_arm = i;
323 }
324 }
325
326 best_arm
327 }
328
329 fn update_arm(&mut self, arm: usize, reward: f64) {
331 self.arm_stats[arm].update(reward);
332
333 if matches!(self.strategy, BanditStrategy::Boltzmann(_)) {
335 self.current_temperature *= self.config.temperature_decay;
336 }
337
338 self.current_iteration += 1;
339 }
340
341 fn best_arm(&self) -> (usize, f64) {
343 let mut best_arm = 0;
344 let mut best_score = f64::NEG_INFINITY;
345
346 for (i, stats) in self.arm_stats.iter().enumerate() {
347 if stats.best_score > best_score {
348 best_score = stats.best_score;
349 best_arm = i;
350 }
351 }
352
353 (best_arm, best_score)
354 }
355}
356
/// Cross-validated hyperparameter search driven by a multi-armed bandit.
///
/// Each candidate `ParameterValue` is one arm; pulling an arm runs
/// cross-validation on an estimator configured by `param_config_fn`.
pub struct BanditSearchCV<E> {
    /// Template estimator; cloned and reconfigured for every evaluation.
    estimator: E,
    /// Candidate parameter values (one bandit arm each).
    param_space: Vec<ParameterValue>,
    /// Arm-selection strategy (UCB by default).
    strategy: BanditStrategy,
    /// Bandit configuration: iterations, exploration constants, seeding.
    config: BanditConfig,
    /// Optional scoring override passed through to cross-validation.
    scoring: Option<Scoring>,
    /// Maps (estimator clone, parameter value) -> configured estimator.
    /// Must be set via `with_param_config`; `fit` errors when it is None.
    param_config_fn: Option<Box<dyn Fn(E, &ParameterValue) -> Result<E>>>,
}
372
impl<E> BanditSearchCV<E>
where
    E: Clone,
{
    /// Create a search over `param_space` with the UCB strategy and default
    /// `BanditConfig`. A parameter-configuration closure must still be
    /// supplied via `with_param_config` before calling `fit`.
    pub fn new(estimator: E, param_space: Vec<ParameterValue>) -> Self {
        Self {
            estimator,
            param_space,
            strategy: BanditStrategy::UCB,
            config: BanditConfig::default(),
            scoring: None,
            param_config_fn: None,
        }
    }

    /// Builder: set the arm-selection strategy.
    pub fn with_strategy(mut self, strategy: BanditStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Builder: set the bandit configuration.
    pub fn with_config(mut self, config: BanditConfig) -> Self {
        self.config = config;
        self
    }

    /// Builder: set the scoring metric forwarded to cross-validation.
    pub fn with_scoring(mut self, scoring: Scoring) -> Self {
        self.scoring = Some(scoring);
        self
    }

    /// Builder: set the closure that applies a `ParameterValue` to a clone
    /// of the template estimator. Required before `fit`.
    pub fn with_param_config<F>(mut self, func: F) -> Self
    where
        F: Fn(E, &ParameterValue) -> Result<E> + 'static,
    {
        self.param_config_fn = Some(Box::new(func));
        self
    }

    /// Run the bandit search: for `config.n_iterations` rounds, select an
    /// arm, configure a fresh estimator clone with that arm's parameter
    /// value, score it via cross-validation, and feed the mean CV score back
    /// to the bandit as the reward.
    ///
    /// # Errors
    /// Returns `InvalidInput` if the parameter space is empty or no
    /// parameter-configuration closure was set; propagates any error from
    /// configuration or cross-validation.
    pub fn fit<F, C>(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        cv: &C,
    ) -> Result<BanditOptimizationResult>
    where
        F: Clone,
        E: Fit<
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
            Fitted = F,
        >,
        F: Predict<
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
        >,
        F: Score<
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
            Float = f64,
        >,
        C: CrossValidator,
    {
        if self.param_space.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Parameter space cannot be empty".to_string(),
            ));
        }

        let param_config_fn = self.param_config_fn.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("Parameter configuration function not set".to_string())
        })?;

        let mut optimizer = BanditOptimizer::new(
            self.param_space.clone(),
            self.strategy.clone(),
            self.config.clone(),
        );

        let mut all_params = Vec::new();
        let mut all_scores = Vec::new();
        let mut convergence_history = Vec::new();

        for _ in 0..self.config.n_iterations {
            // Bandit proposes an arm; its parameter value configures a fresh
            // clone of the template estimator.
            let arm_idx = optimizer.select_arm();
            let param = &self.param_space[arm_idx];

            let configured_estimator = param_config_fn(self.estimator.clone(), param)?;

            let scores = crate::validation::cross_val_score(
                configured_estimator,
                x,
                y,
                cv,
                self.scoring.clone(),
                None,
            )?;

            // Mean CV score is the reward; an empty fold set degrades to 0.0.
            let mean_score = scores.mean().unwrap_or(0.0);

            optimizer.update_arm(arm_idx, mean_score);

            all_params.push(param.clone());
            all_scores.push(mean_score);

            // Best-so-far trace for convergence reporting.
            let (_, current_best_score) = optimizer.best_arm();
            convergence_history.push(current_best_score);
        }

        let (best_arm_idx, best_score) = optimizer.best_arm();
        let best_params = self.param_space[best_arm_idx].clone();

        // Summarize per-arm statistics as "arm_{i}" -> (mean reward, pulls).
        let mut arm_stats = HashMap::new();
        for (i, stats) in optimizer.arm_stats.iter().enumerate() {
            arm_stats.insert(format!("arm_{}", i), (stats.mean_reward(), stats.n_pulls));
        }

        Ok(BanditOptimizationResult {
            best_params,
            best_score,
            all_params,
            all_scores,
            convergence_history,
            n_iterations: self.config.n_iterations,
            arm_stats,
        })
    }
}
513
/// Alternative bandit optimizer that scores arms with a caller-supplied
/// closure on a fixed 80/20 holdout split (see `evaluate_arm`).
#[derive(Debug, Clone)]
pub struct BanditOptimization<E, S> {
    /// Template estimator, cloned for each evaluation.
    estimator: E,
    /// Candidate parameter values (one arm each).
    parameter_space: Vec<ParameterValue>,
    /// Scoring closure: (fitted model, x_test, y_test) -> score.
    scorer: Box<S>,
    /// NOTE(review): stored but never read in the visible code — `fit` uses a
    /// holdout split, not k folds. Confirm whether this is still needed.
    cv_folds: usize,
    /// Number of bandit iterations (default 100).
    n_iter: usize,
    /// Arm-selection strategy (default UCB).
    strategy: BanditStrategy,
    /// Optional RNG seed; `fit` falls back to a fixed seed of 42.
    random_state: Option<u64>,
    /// Per-arm statistics keyed by arm index, populated during `fit`.
    arm_stats: HashMap<usize, ArmStats>,
}
526
impl<E, S> BanditOptimization<E, S>
where
    E: Clone
        + Fit<
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
        > + Predict<
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
        >,
    S: Fn(
        &E::Fitted,
        &ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
        &ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
    ) -> Result<f64>,
{
    /// Create an optimizer with defaults: 100 iterations, UCB strategy,
    /// no fixed seed. `cv_folds` is stored but unused by the visible `fit`.
    pub fn new(
        estimator: E,
        parameter_space: Vec<ParameterValue>,
        scorer: Box<S>,
        cv_folds: usize,
    ) -> Self {
        Self {
            estimator,
            parameter_space,
            scorer,
            cv_folds,
            n_iter: 100,
            strategy: BanditStrategy::UCB,
            random_state: None,
            arm_stats: HashMap::new(),
        }
    }

    /// Set the number of bandit iterations.
    pub fn set_n_iter(&mut self, n_iter: usize) {
        self.n_iter = n_iter;
    }

    /// Set the arm-selection strategy.
    pub fn set_strategy(&mut self, strategy: BanditStrategy) {
        self.strategy = strategy;
    }

    /// Fix the RNG seed for reproducible runs.
    pub fn set_random_state(&mut self, seed: u64) {
        self.random_state = Some(seed);
    }

    /// Run the bandit loop: select an arm, evaluate it on a holdout split,
    /// and update that arm's statistics, for `n_iter` iterations.
    ///
    /// NOTE(review): the returned result's `all_scores` is collected from a
    /// HashMap, so its order is nondeterministic and does not align with
    /// `all_params` (the raw parameter space); `convergence_history` holds
    /// only the final best score. This differs from `BanditSearchCV::fit`.
    pub fn fit(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<BanditOptimizationResult> {
        // Unseeded runs use a fixed fallback seed, so they are reproducible.
        let mut rng = match self.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::seed_from_u64(42),
        };

        let mut best_score = f64::NEG_INFINITY;
        let mut best_arm = 0;
        let n_arms = self.parameter_space.len();

        // Reset per-arm statistics at the start of every fit.
        for i in 0..n_arms {
            self.arm_stats.insert(i, ArmStats::new());
        }

        for iteration in 0..self.n_iter {
            let selected_arm = self.select_arm(iteration, &mut rng)?;

            let score = self.evaluate_arm(selected_arm, x, y)?;

            if let Some(stats) = self.arm_stats.get_mut(&selected_arm) {
                stats.update(score);

                if score > best_score {
                    best_score = score;
                    best_arm = selected_arm;
                }
            }
        }

        // Convert arm-index keys to the "arm_{i}" -> (mean, pulls) summary.
        let arm_stats_converted: HashMap<String, (f64, usize)> = self
            .arm_stats
            .iter()
            .map(|(&idx, stats)| (format!("arm_{}", idx), (stats.mean_reward(), stats.n_pulls)))
            .collect();

        Ok(BanditOptimizationResult {
            best_params: self.parameter_space[best_arm].clone(),
            best_score,
            all_params: self.parameter_space.clone(),
            all_scores: self.arm_stats.values().map(|s| s.best_score).collect(),
            convergence_history: vec![best_score],
            n_iterations: self.n_iter,
            arm_stats: arm_stats_converted,
        })
    }

    /// Choose the next arm under the configured strategy.
    ///
    /// Arms with no stats entry are skipped by every branch except
    /// Boltzmann, which gives them a neutral weight of 1.0.
    fn select_arm(&self, iteration: usize, rng: &mut StdRng) -> Result<usize> {
        let n_arms = self.parameter_space.len();

        match &self.strategy {
            BanditStrategy::UCB => {
                let mut best_value = f64::NEG_INFINITY;
                let mut best_arm = 0;

                for arm in 0..n_arms {
                    if let Some(stats) = self.arm_stats.get(&arm) {
                        // Unpulled arms get an infinite UCB so each is tried
                        // at least once.
                        let value = if stats.n_pulls == 0 {
                            f64::INFINITY
                        } else {
                            // UCB1-style bonus based on the iteration count,
                            // scaled by a hardcoded 1.96 (NOTE(review): the
                            // other implementation uses config.ucb_c here).
                            let confidence =
                                (2.0 * (iteration as f64 + 1.0).ln() / stats.n_pulls as f64).sqrt();
                            stats.mean_reward() + 1.96 * confidence
                        };

                        if value > best_value {
                            best_value = value;
                            best_arm = arm;
                        }
                    }
                }

                Ok(best_arm)
            }
            BanditStrategy::EpsilonGreedy(epsilon) => {
                // Explore uniformly with probability epsilon (no decay here,
                // unlike BanditOptimizer's variant).
                if rng.random::<f64>() < *epsilon {
                    Ok(rng.random_range(0..n_arms))
                } else {
                    let mut best_mean = f64::NEG_INFINITY;
                    let mut best_arm = 0;

                    for arm in 0..n_arms {
                        if let Some(stats) = self.arm_stats.get(&arm) {
                            let mean = stats.mean_reward();
                            if mean > best_mean {
                                best_mean = mean;
                                best_arm = arm;
                            }
                        }
                    }

                    Ok(best_arm)
                }
            }
            BanditStrategy::ThompsonSampling => {
                let mut best_sample = f64::NEG_INFINITY;
                let mut best_arm = 0;

                for arm in 0..n_arms {
                    if let Some(stats) = self.arm_stats.get(&arm) {
                        // u^(1/alpha) inverse-CDF samples Beta(alpha, 1).
                        // NOTE(review): `_beta` is computed but discarded, so
                        // failures never sharpen the posterior — this assumes
                        // rewards in [0, 1]; confirm intended.
                        let alpha = stats.sum_rewards + 1.0;
                        let _beta = (stats.n_pulls as f64 - stats.sum_rewards) + 1.0;
                        let sample = rng.random::<f64>().powf(1.0 / alpha);
                        if sample > best_sample {
                            best_sample = sample;
                            best_arm = arm;
                        }
                    }
                }

                Ok(best_arm)
            }
            BanditStrategy::Boltzmann(temperature) => {
                let mut weights = Vec::new();
                let mut max_mean = f64::NEG_INFINITY;

                // Find the max mean so the exponentials are shifted into a
                // numerically safe range (softmax stabilization).
                for arm in 0..n_arms {
                    if let Some(stats) = self.arm_stats.get(&arm) {
                        let mean = stats.mean_reward();
                        if mean > max_mean {
                            max_mean = mean;
                        }
                    }
                }

                for arm in 0..n_arms {
                    if let Some(stats) = self.arm_stats.get(&arm) {
                        let mean = stats.mean_reward();
                        weights.push(((mean - max_mean) / temperature).exp());
                    } else {
                        // Arms without stats get a neutral weight.
                        weights.push(1.0);
                    }
                }

                // Roulette-wheel selection over the unnormalized weights.
                let sum: f64 = weights.iter().sum();
                let mut cumsum = 0.0;
                let threshold = rng.random::<f64>() * sum;

                for (arm, weight) in weights.iter().enumerate() {
                    cumsum += weight;
                    if cumsum >= threshold {
                        return Ok(arm);
                    }
                }

                // Floating-point slack fallback.
                Ok(n_arms - 1)
            }
        }
    }

    /// Score one arm on a fixed 80/20 train/holdout split of (x, y).
    ///
    /// NOTE(review): `_arm` is ignored — the base estimator is fitted
    /// unmodified, so every arm currently yields the same reward
    /// distribution and the bandit cannot distinguish parameters. Confirm
    /// whether parameter application was intended here.
    /// NOTE(review): with fewer than 5 samples the holdout set is empty;
    /// behavior then depends on the scorer.
    fn evaluate_arm(&self, _arm: usize, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
        let n_samples = x.nrows();
        let train_size = (n_samples as f64 * 0.8) as usize;

        let x_train_view = x.slice(scirs2_core::ndarray::s![..train_size, ..]);
        let y_train_view = y.slice(scirs2_core::ndarray::s![..train_size]);
        let x_test_view = x.slice(scirs2_core::ndarray::s![train_size.., ..]);
        let y_test_view = y.slice(scirs2_core::ndarray::s![train_size..]);

        // Materialize the views as owned arrays for the Fit/Predict traits.
        let x_train = Array2::from_shape_vec(
            (x_train_view.nrows(), x_train_view.ncols()),
            x_train_view.iter().copied().collect(),
        )?;
        let y_train = Array1::from_vec(y_train_view.iter().copied().collect());
        let x_test = Array2::from_shape_vec(
            (x_test_view.nrows(), x_test_view.ncols()),
            x_test_view.iter().copied().collect(),
        )?;
        let y_test = Array1::from_vec(y_test_view.iter().copied().collect());

        let estimator = self.estimator.clone();
        let fitted = estimator.fit(&x_train, &y_train)?;

        (self.scorer)(&fitted, &x_test, &y_test)
    }
}
776
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KFold;
    use scirs2_core::ndarray::array;

    /// Minimal estimator whose only "hyperparameter" is the constant value
    /// its fitted model predicts for every row.
    #[derive(Clone)]
    struct MockEstimator {
        param: f64,
    }

    /// Fitted counterpart of `MockEstimator`; carries the constant forward.
    #[derive(Clone)]
    struct MockFitted {
        param: f64,
    }

    impl Fit<Array2<Float>, Array1<Float>> for MockEstimator {
        type Fitted = MockFitted;

        // Training ignores the data entirely; the "model" is the parameter.
        fn fit(self, _x: &Array2<Float>, _y: &Array1<Float>) -> Result<Self::Fitted> {
            Ok(MockFitted { param: self.param })
        }
    }

    impl Predict<Array2<Float>, Array1<Float>> for MockFitted {
        // Predict the stored constant for every row of x.
        fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
            Ok(Array1::from_elem(x.nrows(), self.param))
        }
    }

    impl Score<Array2<Float>, Array1<Float>> for MockFitted {
        type Float = Float;

        // Score = 1 - mean absolute error, so params nearer the mean of y
        // score higher (maximizing, as the bandit expects).
        fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
            let y_pred = self.predict(x)?;
            let mean_abs_error = (&y_pred - y).mapv(|diff| diff.abs()).mean().unwrap_or(0.0);
            Ok(1.0 - mean_abs_error)
        }
    }

    #[test]
    fn test_bandit_optimization_ucb() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator { param: 0.0 };
        // 3.5 equals the mean of y, so it should emerge as the best arm.
        let param_space = vec![
            ParameterValue::Float(1.0),
            ParameterValue::Float(3.5),
            ParameterValue::Float(5.0),
        ];

        let param_config_fn = |_estimator: MockEstimator, param_value: &ParameterValue| {
            if let ParameterValue::Float(val) = param_value {
                Ok(MockEstimator { param: *val })
            } else {
                Err(SklearsError::InvalidInput(
                    "Expected float parameter".to_string(),
                ))
            }
        };

        let config = BanditConfig {
            n_iterations: 30,
            n_initial_random: 2,
            ..Default::default()
        };

        let search = BanditSearchCV::new(estimator, param_space)
            .with_strategy(BanditStrategy::UCB)
            .with_config(config)
            .with_param_config(param_config_fn);

        let cv = KFold::new(3);
        let result = search.fit(&x, &y, &cv).expect("operation should succeed");

        // Loose tolerance: any of the three arms is within 2.0 of 3.5 except
        // the worst ones; this mainly guards against gross mis-selection.
        if let ParameterValue::Float(best_val) = result.best_params {
            assert!((best_val - 3.5).abs() < 2.0);
        }

        assert_eq!(result.n_iterations, 30);
        assert_eq!(result.all_scores.len(), 30);
        assert_eq!(result.convergence_history.len(), 30);

        // The best-so-far trace must be monotone: late best >= early best.
        let early_best = result.convergence_history[..10]
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let late_best = result.convergence_history[20..]
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        assert!(late_best >= early_best);
    }

    #[test]
    fn test_bandit_optimization_epsilon_greedy() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator { param: 0.0 };
        let param_space = vec![
            ParameterValue::Float(1.0),
            ParameterValue::Float(3.5),
            ParameterValue::Float(5.0),
        ];

        let param_config_fn = |_estimator: MockEstimator, param_value: &ParameterValue| {
            if let ParameterValue::Float(val) = param_value {
                Ok(MockEstimator { param: *val })
            } else {
                Err(SklearsError::InvalidInput(
                    "Expected float parameter".to_string(),
                ))
            }
        };

        // Seeded so the epsilon-greedy exploration path is deterministic.
        let config = BanditConfig {
            n_iterations: 20,
            n_initial_random: 1,
            random_state: Some(42),
            ..Default::default()
        };

        let search = BanditSearchCV::new(estimator, param_space)
            .with_strategy(BanditStrategy::EpsilonGreedy(0.1))
            .with_config(config)
            .with_param_config(param_config_fn);

        let cv = KFold::new(2);
        let result = search.fit(&x, &y, &cv).expect("operation should succeed");

        // Smoke assertions: the run completes and produces sane bookkeeping.
        assert_eq!(result.n_iterations, 20);
        assert!(result.best_score.is_finite());
        assert!(!result.arm_stats.is_empty());
    }

    #[test]
    fn test_arm_stats() {
        let mut stats = ArmStats::new();

        // Unpulled arm: zero pulls, neutral mean.
        assert_eq!(stats.n_pulls, 0);
        assert_eq!(stats.mean_reward(), 0.0);

        stats.update(1.0);
        stats.update(2.0);
        stats.update(3.0);

        assert_eq!(stats.n_pulls, 3);
        assert_eq!(stats.mean_reward(), 2.0);
        assert_eq!(stats.best_score, 3.0);
        assert!(stats.variance() > 0.0);
    }
}