// sklears_model_selection/bandit_optimization.rs
1//! Multi-Armed Bandit algorithms for hyperparameter optimization
2//!
3//! This module implements various bandit-based optimization algorithms that can be used
4//! for efficient hyperparameter optimization by treating each parameter configuration
5//! as an arm in a multi-armed bandit problem.
6
7use crate::{CrossValidator, ParameterValue, Scoring};
8use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Dim, OwnedRepr};
9use scirs2_core::random::prelude::*;
10use sklears_core::{
11    error::{Result, SklearsError},
12    prelude::Predict,
13    traits::Fit,
14    traits::Score,
15    types::Float,
16};
17use std::collections::HashMap;
18
19/// Configuration for bandit-based optimization
/// Configuration for bandit-based optimization.
///
/// Shared by `BanditOptimizer` and `BanditSearchCV`; see
/// `BanditConfig::default` for the default values.
#[derive(Debug, Clone)]
pub struct BanditConfig {
    /// Total number of arm pulls (parameter evaluations) to run
    pub n_iterations: usize,
    /// Number of initial purely-random pulls per arm before the strategy takes over
    pub n_initial_random: usize,
    /// Exploration coefficient for the UCB confidence bonus
    pub ucb_c: f64,
    /// Initial temperature for Boltzmann (softmax) exploration
    pub temperature: f64,
    /// Multiplicative decay applied to the temperature after each arm update
    pub temperature_decay: f64,
    /// RNG seed for reproducibility; `None` falls back to a fixed default seed
    pub random_state: Option<u64>,
}
35
36impl Default for BanditConfig {
37    fn default() -> Self {
38        Self {
39            n_iterations: 100,
40            n_initial_random: 5,
41            ucb_c: 1.96, // 95% confidence interval
42            temperature: 1.0,
43            temperature_decay: 0.95,
44            random_state: None,
45        }
46    }
47}
48
/// Strategy for arm selection in multi-armed bandit
#[derive(Debug, Clone)]
pub enum BanditStrategy {
    /// Upper Confidence Bound: pick the arm maximizing mean reward plus an
    /// exploration bonus that shrinks as the arm is pulled more often
    UCB,
    /// Epsilon-greedy; the payload is the base exploration probability
    /// (decayed over iterations by `BanditOptimizer`)
    EpsilonGreedy(f64),
    /// Boltzmann (softmax) exploration; the payload is the temperature.
    /// NOTE(review): `BanditOptimizer` ignores this payload and uses
    /// `BanditConfig::temperature` with cooling instead, while
    /// `BanditOptimization` uses the payload directly — confirm intended.
    Boltzmann(f64),
    /// Thompson sampling (Bayesian approach)
    ThompsonSampling,
}
61
/// Arm statistics for bandit algorithms.
///
/// Tracks running sums so the mean and variance can be computed in O(1)
/// without storing the full reward history.
#[derive(Debug, Clone)]
struct ArmStats {
    /// Number of times this arm was pulled
    n_pulls: usize,
    /// Sum of rewards received
    sum_rewards: f64,
    /// Sum of squared rewards (for variance estimation)
    sum_squared_rewards: f64,
    /// Best score achieved (`NEG_INFINITY` until the first pull)
    best_score: f64,
    /// Up to the 10 most recent scores, for trend analysis
    recent_scores: Vec<f64>,
}

impl ArmStats {
    /// Create empty statistics for an arm that has never been pulled.
    fn new() -> Self {
        Self {
            n_pulls: 0,
            sum_rewards: 0.0,
            sum_squared_rewards: 0.0,
            best_score: f64::NEG_INFINITY,
            recent_scores: Vec::new(),
        }
    }

    /// Record one observed reward for this arm.
    fn update(&mut self, reward: f64) {
        self.n_pulls += 1;
        self.sum_rewards += reward;
        self.sum_squared_rewards += reward * reward;

        if reward > self.best_score {
            self.best_score = reward;
        }

        // Keep only the 10 most recent scores for trend analysis.
        self.recent_scores.push(reward);
        if self.recent_scores.len() > 10 {
            self.recent_scores.remove(0);
        }
    }

    /// Mean observed reward; 0.0 for an arm that has never been pulled.
    fn mean_reward(&self) -> f64 {
        if self.n_pulls == 0 {
            0.0
        } else {
            self.sum_rewards / self.n_pulls as f64
        }
    }

    /// Population variance of the observed rewards, E[X²] − E[X]².
    ///
    /// Clamped at 0.0: the sum-of-squares formula can come out slightly
    /// negative from floating-point rounding (e.g. many identical rewards),
    /// which would make the `sqrt` in `confidence_interval` return NaN.
    fn variance(&self) -> f64 {
        if self.n_pulls <= 1 {
            0.0
        } else {
            let mean = self.mean_reward();
            ((self.sum_squared_rewards / self.n_pulls as f64) - (mean * mean)).max(0.0)
        }
    }

    /// Half-width of the confidence interval: `confidence * standard_error`.
    /// Infinite for an arm with no pulls (maximal uncertainty).
    fn confidence_interval(&self, confidence: f64) -> f64 {
        if self.n_pulls == 0 {
            f64::INFINITY
        } else {
            let std_err = (self.variance() / self.n_pulls as f64).sqrt();
            confidence * std_err
        }
    }
}
130
/// Result of bandit optimization.
///
/// Returned by both `BanditSearchCV::fit` and `BanditOptimization::fit`.
#[derive(Debug, Clone)]
pub struct BanditOptimizationResult {
    /// Best parameter configuration found
    pub best_params: ParameterValue,
    /// Best score achieved
    pub best_score: f64,
    /// Parameter configurations tried (per-iteration for `BanditSearchCV`;
    /// the whole space for `BanditOptimization`)
    pub all_params: Vec<ParameterValue>,
    /// Scores matching `all_params`
    pub all_scores: Vec<f64>,
    /// Best score seen so far at each iteration (monotonically non-decreasing)
    pub convergence_history: Vec<f64>,
    /// Number of iterations run
    pub n_iterations: usize,
    /// Final arm statistics keyed "arm_{i}"
    pub arm_stats: HashMap<String, (f64, usize)>, // (mean_reward, n_pulls)
}
149
150/// Multi-Armed Bandit hyperparameter optimizer
/// Multi-Armed Bandit hyperparameter optimizer.
///
/// Treats each entry of `param_space` as one bandit arm; `select_arm` /
/// `update_arm` drive the explore–exploit loop, `best_arm` reports the winner.
pub struct BanditOptimizer {
    /// Parameter space to search (one arm per entry)
    param_space: Vec<ParameterValue>,
    /// Bandit strategy to use
    strategy: BanditStrategy,
    /// Configuration
    config: BanditConfig,
    /// Per-arm statistics, indexed in parallel with `param_space`
    arm_stats: Vec<ArmStats>,
    /// Random number generator (seeded from `config.random_state`, or 42)
    rng: StdRng,
    /// Current iteration (incremented by `update_arm`)
    current_iteration: usize,
    /// Current temperature (for Boltzmann; cooled after each update)
    current_temperature: f64,
}
167
168impl BanditOptimizer {
169    pub fn new(
170        param_space: Vec<ParameterValue>,
171        strategy: BanditStrategy,
172        config: BanditConfig,
173    ) -> Self {
174        let rng = if let Some(seed) = config.random_state {
175            StdRng::seed_from_u64(seed)
176        } else {
177            StdRng::seed_from_u64(42)
178        };
179
180        let arm_stats = vec![ArmStats::new(); param_space.len()];
181        let current_temperature = config.temperature;
182
183        Self {
184            param_space,
185            strategy,
186            config,
187            arm_stats,
188            rng,
189            current_iteration: 0,
190            current_temperature,
191        }
192    }
193
194    /// Select next arm to pull based on strategy
195    fn select_arm(&mut self) -> usize {
196        // Initial random exploration
197        if self.current_iteration < self.config.n_initial_random * self.param_space.len() {
198            return self.rng.random_range(0..self.param_space.len());
199        }
200
201        match &self.strategy {
202            BanditStrategy::UCB => self.select_ucb_arm(),
203            BanditStrategy::EpsilonGreedy(epsilon) => self.select_epsilon_greedy_arm(*epsilon),
204            BanditStrategy::Boltzmann(_) => self.select_boltzmann_arm(),
205            BanditStrategy::ThompsonSampling => self.select_thompson_sampling_arm(),
206        }
207    }
208
209    /// UCB arm selection
210    fn select_ucb_arm(&self) -> usize {
211        let total_pulls = self
212            .arm_stats
213            .iter()
214            .map(|stats| stats.n_pulls)
215            .sum::<usize>();
216        let log_total = (total_pulls as f64).ln();
217
218        let mut best_arm = 0;
219        let mut best_ucb = f64::NEG_INFINITY;
220
221        for (i, stats) in self.arm_stats.iter().enumerate() {
222            if stats.n_pulls == 0 {
223                return i; // Prefer unvisited arms
224            }
225
226            let mean_reward = stats.mean_reward();
227            let confidence_bonus = self.config.ucb_c * (log_total / stats.n_pulls as f64).sqrt();
228            let ucb_value = mean_reward + confidence_bonus;
229
230            if ucb_value > best_ucb {
231                best_ucb = ucb_value;
232                best_arm = i;
233            }
234        }
235
236        best_arm
237    }
238
239    /// Epsilon-greedy arm selection with decaying epsilon
240    fn select_epsilon_greedy_arm(&mut self, base_epsilon: f64) -> usize {
241        let decayed_epsilon = base_epsilon / (1.0 + self.current_iteration as f64 * 0.01);
242
243        if self.rng.random::<f64>() < decayed_epsilon {
244            // Explore: random arm
245            self.rng.random_range(0..self.param_space.len())
246        } else {
247            // Exploit: best arm
248            let mut best_arm = 0;
249            let mut best_reward = f64::NEG_INFINITY;
250
251            for (i, stats) in self.arm_stats.iter().enumerate() {
252                let reward = if stats.n_pulls == 0 {
253                    f64::INFINITY // Prefer unvisited arms
254                } else {
255                    stats.mean_reward()
256                };
257
258                if reward > best_reward {
259                    best_reward = reward;
260                    best_arm = i;
261                }
262            }
263
264            best_arm
265        }
266    }
267
268    /// Boltzmann exploration arm selection
269    fn select_boltzmann_arm(&mut self) -> usize {
270        let mut unnormalized_probs = Vec::with_capacity(self.param_space.len());
271
272        for stats in &self.arm_stats {
273            let reward = if stats.n_pulls == 0 {
274                1.0 // Give unvisited arms high probability
275            } else {
276                stats.mean_reward()
277            };
278
279            unnormalized_probs.push((reward / self.current_temperature).exp());
280        }
281
282        // Normalize probabilities
283        let total: f64 = unnormalized_probs.iter().sum();
284        let probs: Vec<f64> = unnormalized_probs.iter().map(|p| p / total).collect();
285
286        // Sample according to probabilities
287        let mut cumsum = 0.0;
288        let random_val = self.rng.random::<f64>();
289
290        for (i, &prob) in probs.iter().enumerate() {
291            cumsum += prob;
292            if random_val <= cumsum {
293                return i;
294            }
295        }
296
297        // Fallback
298        self.param_space.len() - 1
299    }
300
301    /// Thompson sampling arm selection
302    fn select_thompson_sampling_arm(&mut self) -> usize {
303        let mut best_arm = 0;
304        let mut best_sample = f64::NEG_INFINITY;
305
306        for (i, stats) in self.arm_stats.iter().enumerate() {
307            let sample = if stats.n_pulls == 0 {
308                // Sample from uninformative prior
309                self.rng.random()
310            } else {
311                // Sample from posterior (assuming Gaussian with known variance)
312                let mean = stats.mean_reward();
313                let std = (stats.variance() / stats.n_pulls as f64).sqrt().max(0.1);
314
315                use scirs2_core::random::RandNormal;
316                let normal = RandNormal::new(mean, std).expect("operation should succeed");
317                self.rng.sample(normal)
318            };
319
320            if sample > best_sample {
321                best_sample = sample;
322                best_arm = i;
323            }
324        }
325
326        best_arm
327    }
328
329    /// Update arm statistics with new reward
330    fn update_arm(&mut self, arm: usize, reward: f64) {
331        self.arm_stats[arm].update(reward);
332
333        // Update temperature for Boltzmann
334        if matches!(self.strategy, BanditStrategy::Boltzmann(_)) {
335            self.current_temperature *= self.config.temperature_decay;
336        }
337
338        self.current_iteration += 1;
339    }
340
341    /// Get the best arm so far
342    fn best_arm(&self) -> (usize, f64) {
343        let mut best_arm = 0;
344        let mut best_score = f64::NEG_INFINITY;
345
346        for (i, stats) in self.arm_stats.iter().enumerate() {
347            if stats.best_score > best_score {
348                best_score = stats.best_score;
349                best_arm = i;
350            }
351        }
352
353        (best_arm, best_score)
354    }
355}
356
357/// Bandit-based cross-validation optimizer
/// Bandit-based cross-validation optimizer.
///
/// Builder-style wrapper that drives a `BanditOptimizer` with
/// cross-validated scores as arm rewards.
pub struct BanditSearchCV<E> {
    /// Base estimator, cloned and reconfigured for every evaluation
    estimator: E,
    /// Parameter space (one bandit arm per entry)
    param_space: Vec<ParameterValue>,
    /// Bandit strategy (defaults to UCB)
    strategy: BanditStrategy,
    /// Configuration
    config: BanditConfig,
    /// Scoring method; `None` uses the CV routine's default
    scoring: Option<Scoring>,
    /// Applies a parameter configuration to a fresh estimator clone;
    /// must be set via `with_param_config` before calling `fit`
    param_config_fn: Option<Box<dyn Fn(E, &ParameterValue) -> Result<E>>>,
}
372
impl<E> BanditSearchCV<E>
where
    E: Clone,
{
    /// Create a new bandit search CV with UCB strategy and default config.
    /// A parameter configuration function must still be supplied via
    /// `with_param_config` before `fit` can run.
    pub fn new(estimator: E, param_space: Vec<ParameterValue>) -> Self {
        Self {
            estimator,
            param_space,
            strategy: BanditStrategy::UCB,
            config: BanditConfig::default(),
            scoring: None,
            param_config_fn: None,
        }
    }

    /// Set the bandit strategy (builder-style).
    pub fn with_strategy(mut self, strategy: BanditStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Set the configuration (builder-style).
    pub fn with_config(mut self, config: BanditConfig) -> Self {
        self.config = config;
        self
    }

    /// Set the scoring method (builder-style).
    pub fn with_scoring(mut self, scoring: Scoring) -> Self {
        self.scoring = Some(scoring);
        self
    }

    /// Set the parameter configuration function: given an estimator clone
    /// and a parameter value, returns the configured estimator.
    pub fn with_param_config<F>(mut self, func: F) -> Self
    where
        F: Fn(E, &ParameterValue) -> Result<E> + 'static,
    {
        self.param_config_fn = Some(Box::new(func));
        self
    }

    /// Fit the bandit search optimizer.
    ///
    /// Runs `config.n_iterations` rounds; each round selects an arm,
    /// cross-validates the correspondingly configured estimator, and feeds
    /// the mean CV score back as the arm's reward.
    ///
    /// # Errors
    /// Returns an error if the parameter space is empty, if no parameter
    /// configuration function was set, or if configuration / CV scoring
    /// fails during an iteration.
    pub fn fit<F, C>(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        cv: &C,
    ) -> Result<BanditOptimizationResult>
    where
        F: Clone,
        E: Fit<
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
            Fitted = F,
        >,
        F: Predict<
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
        >,
        F: Score<
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
            Float = f64,
        >,
        C: CrossValidator,
    {
        if self.param_space.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Parameter space cannot be empty".to_string(),
            ));
        }

        let param_config_fn = self.param_config_fn.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("Parameter configuration function not set".to_string())
        })?;

        let mut optimizer = BanditOptimizer::new(
            self.param_space.clone(),
            self.strategy.clone(),
            self.config.clone(),
        );

        let mut all_params = Vec::new();
        let mut all_scores = Vec::new();
        let mut convergence_history = Vec::new();

        for _ in 0..self.config.n_iterations {
            // Select arm (parameter configuration)
            let arm_idx = optimizer.select_arm();
            let param = &self.param_space[arm_idx];

            // Configure a fresh estimator clone with the selected parameter
            let configured_estimator = param_config_fn(self.estimator.clone(), param)?;

            // Evaluate using cross-validation
            let scores = crate::validation::cross_val_score(
                configured_estimator,
                x,
                y,
                cv,
                self.scoring.clone(),
                None,
            )?;

            // Mean CV score is the bandit reward for this arm.
            let mean_score = scores.mean().unwrap_or(0.0);

            // Update bandit statistics
            optimizer.update_arm(arm_idx, mean_score);

            // Track progress
            all_params.push(param.clone());
            all_scores.push(mean_score);

            // Record the running best for the convergence curve.
            let (_, current_best_score) = optimizer.best_arm();
            convergence_history.push(current_best_score);
        }

        // Get final results
        let (best_arm_idx, best_score) = optimizer.best_arm();
        let best_params = self.param_space[best_arm_idx].clone();

        // Collect per-arm statistics for reporting, keyed "arm_{i}".
        let mut arm_stats = HashMap::new();
        for (i, stats) in optimizer.arm_stats.iter().enumerate() {
            arm_stats.insert(format!("arm_{}", i), (stats.mean_reward(), stats.n_pulls));
        }

        Ok(BanditOptimizationResult {
            best_params,
            best_score,
            all_params,
            all_scores,
            convergence_history,
            n_iterations: self.config.n_iterations,
            arm_stats,
        })
    }
}
513
514/// Bandit-based hyperparameter optimization with cross-validation
/// Bandit-based hyperparameter optimization with cross-validation.
///
/// Self-contained variant of `BanditSearchCV` that owns its scorer and
/// evaluates arms with an internal train/test split.
// NOTE(review): `derive(Debug, Clone)` only compiles when the scorer type
// `S` itself implements `Debug + Clone` — confirm callers satisfy this.
#[derive(Debug, Clone)]
pub struct BanditOptimization<E, S> {
    // Base estimator, cloned for each evaluation.
    estimator: E,
    // Parameter space; one bandit arm per entry.
    parameter_space: Vec<ParameterValue>,
    // Scoring closure applied to (fitted, x_test, y_test).
    scorer: Box<S>,
    // NOTE(review): stored but not read by `fit`/`evaluate_arm`, which use
    // a single 80/20 split instead of k-fold CV — placeholder behavior.
    cv_folds: usize,
    // Number of bandit iterations (default 100).
    n_iter: usize,
    // Arm-selection strategy (default UCB).
    strategy: BanditStrategy,
    // RNG seed; `None` falls back to a fixed default seed.
    random_state: Option<u64>,
    // Per-arm statistics keyed by arm index, populated by `fit`.
    arm_stats: HashMap<usize, ArmStats>,
}
526
527impl<E, S> BanditOptimization<E, S>
528where
529    E: Clone
530        + Fit<
531            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
532            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
533        > + Predict<
534            ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
535            ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
536        >,
537    S: Fn(
538        &E::Fitted,
539        &ArrayBase<OwnedRepr<Float>, Dim<[usize; 2]>, Float>,
540        &ArrayBase<OwnedRepr<Float>, Dim<[usize; 1]>, Float>,
541    ) -> Result<f64>,
542{
543    /// Create a new bandit optimization
544    pub fn new(
545        estimator: E,
546        parameter_space: Vec<ParameterValue>,
547        scorer: Box<S>,
548        cv_folds: usize,
549    ) -> Self {
550        Self {
551            estimator,
552            parameter_space,
553            scorer,
554            cv_folds,
555            n_iter: 100,
556            strategy: BanditStrategy::UCB,
557            random_state: None,
558            arm_stats: HashMap::new(),
559        }
560    }
561
562    /// Set number of iterations
563    pub fn set_n_iter(&mut self, n_iter: usize) {
564        self.n_iter = n_iter;
565    }
566
567    /// Set bandit strategy
568    pub fn set_strategy(&mut self, strategy: BanditStrategy) {
569        self.strategy = strategy;
570    }
571
572    /// Set random state for reproducibility
573    pub fn set_random_state(&mut self, seed: u64) {
574        self.random_state = Some(seed);
575    }
576
577    /// Run bandit optimization
578    pub fn fit(
579        &mut self,
580        x: &Array2<Float>,
581        y: &Array1<Float>,
582    ) -> Result<BanditOptimizationResult> {
583        let mut rng = match self.random_state {
584            Some(seed) => StdRng::seed_from_u64(seed),
585            None => StdRng::seed_from_u64(42),
586        };
587
588        let mut best_score = f64::NEG_INFINITY;
589        let mut best_arm = 0;
590        let n_arms = self.parameter_space.len();
591
592        // Initialize arm statistics
593        for i in 0..n_arms {
594            self.arm_stats.insert(i, ArmStats::new());
595        }
596
597        // Bandit optimization loop
598        for iteration in 0..self.n_iter {
599            // Select arm based on strategy
600            let selected_arm = self.select_arm(iteration, &mut rng)?;
601
602            // Evaluate selected arm
603            let score = self.evaluate_arm(selected_arm, x, y)?;
604
605            // Update arm statistics
606            if let Some(stats) = self.arm_stats.get_mut(&selected_arm) {
607                stats.update(score);
608
609                if score > best_score {
610                    best_score = score;
611                    best_arm = selected_arm;
612                }
613            }
614        }
615
616        // Convert arm_stats to the expected format
617        let arm_stats_converted: HashMap<String, (f64, usize)> = self
618            .arm_stats
619            .iter()
620            .map(|(&idx, stats)| (format!("arm_{}", idx), (stats.mean_reward(), stats.n_pulls)))
621            .collect();
622
623        Ok(BanditOptimizationResult {
624            best_params: self.parameter_space[best_arm].clone(),
625            best_score,
626            all_params: self.parameter_space.clone(),
627            all_scores: self.arm_stats.values().map(|s| s.best_score).collect(),
628            convergence_history: vec![best_score], // Simplified
629            n_iterations: self.n_iter,
630            arm_stats: arm_stats_converted,
631        })
632    }
633
634    /// Select arm based on bandit strategy
635    fn select_arm(&self, iteration: usize, rng: &mut StdRng) -> Result<usize> {
636        let n_arms = self.parameter_space.len();
637
638        match &self.strategy {
639            BanditStrategy::UCB => {
640                let mut best_value = f64::NEG_INFINITY;
641                let mut best_arm = 0;
642
643                for arm in 0..n_arms {
644                    if let Some(stats) = self.arm_stats.get(&arm) {
645                        let value = if stats.n_pulls == 0 {
646                            f64::INFINITY
647                        } else {
648                            let confidence =
649                                (2.0 * (iteration as f64 + 1.0).ln() / stats.n_pulls as f64).sqrt();
650                            stats.mean_reward() + 1.96 * confidence // UCB with c=1.96
651                        };
652
653                        if value > best_value {
654                            best_value = value;
655                            best_arm = arm;
656                        }
657                    }
658                }
659
660                Ok(best_arm)
661            }
662            BanditStrategy::EpsilonGreedy(epsilon) => {
663                if rng.random::<f64>() < *epsilon {
664                    // Explore: random arm
665                    Ok(rng.random_range(0..n_arms))
666                } else {
667                    // Exploit: best arm so far
668                    let mut best_mean = f64::NEG_INFINITY;
669                    let mut best_arm = 0;
670
671                    for arm in 0..n_arms {
672                        if let Some(stats) = self.arm_stats.get(&arm) {
673                            let mean = stats.mean_reward();
674                            if mean > best_mean {
675                                best_mean = mean;
676                                best_arm = arm;
677                            }
678                        }
679                    }
680
681                    Ok(best_arm)
682                }
683            }
684            BanditStrategy::ThompsonSampling => {
685                // Simplified Thompson sampling (assuming beta distribution)
686                let mut best_sample = f64::NEG_INFINITY;
687                let mut best_arm = 0;
688
689                for arm in 0..n_arms {
690                    if let Some(stats) = self.arm_stats.get(&arm) {
691                        // Sample from posterior (simplified)
692                        let alpha = stats.sum_rewards + 1.0;
693                        let _beta = (stats.n_pulls as f64 - stats.sum_rewards) + 1.0;
694                        let sample = rng.random::<f64>().powf(1.0 / alpha); // Simplified beta sampling
695
696                        if sample > best_sample {
697                            best_sample = sample;
698                            best_arm = arm;
699                        }
700                    }
701                }
702
703                Ok(best_arm)
704            }
705            BanditStrategy::Boltzmann(temperature) => {
706                // Softmax selection
707                let mut weights = Vec::new();
708                let mut max_mean = f64::NEG_INFINITY;
709
710                // Find max for numerical stability
711                for arm in 0..n_arms {
712                    if let Some(stats) = self.arm_stats.get(&arm) {
713                        let mean = stats.mean_reward();
714                        if mean > max_mean {
715                            max_mean = mean;
716                        }
717                    }
718                }
719
720                // Calculate softmax weights
721                for arm in 0..n_arms {
722                    if let Some(stats) = self.arm_stats.get(&arm) {
723                        let mean = stats.mean_reward();
724                        weights.push(((mean - max_mean) / temperature).exp());
725                    } else {
726                        weights.push(1.0);
727                    }
728                }
729
730                // Sample from categorical distribution
731                let sum: f64 = weights.iter().sum();
732                let mut cumsum = 0.0;
733                let threshold = rng.random::<f64>() * sum;
734
735                for (arm, weight) in weights.iter().enumerate() {
736                    cumsum += weight;
737                    if cumsum >= threshold {
738                        return Ok(arm);
739                    }
740                }
741
742                Ok(n_arms - 1) // Fallback
743            }
744        }
745    }
746
747    /// Evaluate a specific arm (parameter configuration)
748    fn evaluate_arm(&self, _arm: usize, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
749        // Simple train-test split for evaluation (placeholder for full CV)
750        let n_samples = x.nrows();
751        let train_size = (n_samples as f64 * 0.8) as usize;
752
753        let x_train_view = x.slice(scirs2_core::ndarray::s![..train_size, ..]);
754        let y_train_view = y.slice(scirs2_core::ndarray::s![..train_size]);
755        let x_test_view = x.slice(scirs2_core::ndarray::s![train_size.., ..]);
756        let y_test_view = y.slice(scirs2_core::ndarray::s![train_size..]);
757
758        let x_train = Array2::from_shape_vec(
759            (x_train_view.nrows(), x_train_view.ncols()),
760            x_train_view.iter().copied().collect(),
761        )?;
762        let y_train = Array1::from_vec(y_train_view.iter().copied().collect());
763        let x_test = Array2::from_shape_vec(
764            (x_test_view.nrows(), x_test_view.ncols()),
765            x_test_view.iter().copied().collect(),
766        )?;
767        let y_test = Array1::from_vec(y_test_view.iter().copied().collect());
768
769        // Configure estimator with selected parameter (simplified)
770        let estimator = self.estimator.clone();
771        let fitted = estimator.fit(&x_train, &y_train)?;
772
773        (self.scorer)(&fitted, &x_test, &y_test)
774    }
775}
776
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::KFold;
    use scirs2_core::ndarray::array;

    // Mock estimator for testing: "fitting" only captures `param`, and the
    // fitted model predicts the constant `param` for every row.
    #[derive(Clone)]
    struct MockEstimator {
        param: f64,
    }

    #[derive(Clone)]
    struct MockFitted {
        param: f64,
    }

    impl Fit<Array2<Float>, Array1<Float>> for MockEstimator {
        type Fitted = MockFitted;

        fn fit(self, _x: &Array2<Float>, _y: &Array1<Float>) -> Result<Self::Fitted> {
            Ok(MockFitted { param: self.param })
        }
    }

    impl Predict<Array2<Float>, Array1<Float>> for MockFitted {
        fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
            // Simple prediction based on parameter value
            Ok(Array1::from_elem(x.nrows(), self.param))
        }
    }

    impl Score<Array2<Float>, Array1<Float>> for MockFitted {
        type Float = Float;

        fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
            let y_pred = self.predict(x)?;
            // Score is higher when predictions stay close to targets
            let mean_abs_error = (&y_pred - y).mapv(|diff| diff.abs()).mean().unwrap_or(0.0);
            Ok(1.0 - mean_abs_error) // Higher score for lower error
        }
    }

    // UCB should converge toward the arm whose constant prediction (3.5)
    // matches the target mean, since that minimizes the mean abs error.
    #[test]
    fn test_bandit_optimization_ucb() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // Mean = 3.5

        let estimator = MockEstimator { param: 0.0 };
        let param_space = vec![
            ParameterValue::Float(1.0),
            ParameterValue::Float(3.5), // Should be best
            ParameterValue::Float(5.0),
        ];

        // Maps a ParameterValue::Float onto the mock's `param` field.
        let param_config_fn = |_estimator: MockEstimator, param_value: &ParameterValue| {
            if let ParameterValue::Float(val) = param_value {
                Ok(MockEstimator { param: *val })
            } else {
                Err(SklearsError::InvalidInput(
                    "Expected float parameter".to_string(),
                ))
            }
        };

        let config = BanditConfig {
            n_iterations: 30,
            n_initial_random: 2,
            ..Default::default()
        };

        let search = BanditSearchCV::new(estimator, param_space)
            .with_strategy(BanditStrategy::UCB)
            .with_config(config)
            .with_param_config(param_config_fn);

        let cv = KFold::new(3);
        let result = search.fit(&x, &y, &cv).expect("operation should succeed");

        // The best parameter should be close to 3.5
        if let ParameterValue::Float(best_val) = result.best_params {
            assert!((best_val - 3.5).abs() < 2.0); // Allow some tolerance
        }

        assert_eq!(result.n_iterations, 30);
        assert_eq!(result.all_scores.len(), 30);
        assert_eq!(result.convergence_history.len(), 30);

        // Convergence (running best) should improve over time
        let early_best = result.convergence_history[..10]
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let late_best = result.convergence_history[20..]
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        assert!(late_best >= early_best);
    }

    // Epsilon-greedy smoke test: seeded run completes and produces
    // well-formed results.
    #[test]
    fn test_bandit_optimization_epsilon_greedy() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let estimator = MockEstimator { param: 0.0 };
        let param_space = vec![
            ParameterValue::Float(1.0),
            ParameterValue::Float(3.5),
            ParameterValue::Float(5.0),
        ];

        let param_config_fn = |_estimator: MockEstimator, param_value: &ParameterValue| {
            if let ParameterValue::Float(val) = param_value {
                Ok(MockEstimator { param: *val })
            } else {
                Err(SklearsError::InvalidInput(
                    "Expected float parameter".to_string(),
                ))
            }
        };

        let config = BanditConfig {
            n_iterations: 20,
            n_initial_random: 1,
            random_state: Some(42),
            ..Default::default()
        };

        let search = BanditSearchCV::new(estimator, param_space)
            .with_strategy(BanditStrategy::EpsilonGreedy(0.1))
            .with_config(config)
            .with_param_config(param_config_fn);

        let cv = KFold::new(2);
        let result = search.fit(&x, &y, &cv).expect("operation should succeed");

        assert_eq!(result.n_iterations, 20);
        assert!(result.best_score.is_finite());
        assert!(!result.arm_stats.is_empty());
    }

    // Unit test for the running-statistics container.
    #[test]
    fn test_arm_stats() {
        let mut stats = ArmStats::new();

        assert_eq!(stats.n_pulls, 0);
        assert_eq!(stats.mean_reward(), 0.0);

        stats.update(1.0);
        stats.update(2.0);
        stats.update(3.0);

        assert_eq!(stats.n_pulls, 3);
        assert_eq!(stats.mean_reward(), 2.0);
        assert_eq!(stats.best_score, 3.0);
        assert!(stats.variance() > 0.0);
    }
}