sklears_semi_supervised/multi_armed_bandits.rs

//! Multi-armed bandits for active learning
//!
//! This module implements various multi-armed bandit strategies for adaptive
//! sample selection in active learning scenarios, treating different query
//! strategies as arms in a bandit problem.
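//!
//! # Example
//!
//! A minimal sketch of the select/update loop shared by the strategies below
//! (doctest ignored; the reward value is a placeholder for a real signal such
//! as observed accuracy improvement):
//!
//! ```ignore
//! let mut bandit = EpsilonGreedy::new(0.1)?.random_state(42);
//! bandit.initialize(3); // three candidate query strategies
//! for _ in 0..100 {
//!     let arm = bandit.select_arm()?;
//!     let reward = 0.5; // placeholder reward
//!     bandit.update(arm, reward)?;
//! }
//! ```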

use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::random::{Random, Rng};
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum BanditError {
    #[error("Invalid epsilon parameter: {0}")]
    InvalidEpsilon(f64),
    #[error("Invalid temperature parameter: {0}")]
    InvalidTemperature(f64),
    #[error("Invalid confidence parameter: {0}")]
    InvalidConfidence(f64),
    #[error("Invalid exploration parameter: {0}")]
    InvalidExploration(f64),
    #[error("Invalid arm index: {arm_idx} for {n_arms} arms")]
    InvalidArmIndex { arm_idx: usize, n_arms: usize },
    #[error("No arms available")]
    NoArmsAvailable,
    #[error("Insufficient data for arm: {0}")]
    InsufficientDataForArm(usize),
    #[error("Bandit computation failed: {0}")]
    BanditComputationFailed(String),
}

impl From<BanditError> for SklearsError {
    fn from(err: BanditError) -> Self {
        SklearsError::FitError(err.to_string())
    }
}

/// Epsilon-Greedy strategy for multi-armed bandit active learning
///
/// This strategy selects the best-performing query strategy with probability (1-ε)
/// and explores a random strategy with probability ε.
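///
/// # Example
///
/// A minimal usage sketch (doctest ignored; the reward value is illustrative):
///
/// ```ignore
/// let mut eg = EpsilonGreedy::new(0.1)?.random_state(42);
/// eg.initialize(3);
/// let arm = eg.select_arm()?;
/// eg.update(arm, 1.0)?; // reward observed for the selected arm
/// ```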
#[derive(Debug, Clone)]
pub struct EpsilonGreedy {
    /// Initial exploration probability ε in [0, 1]
    pub epsilon: f64,
    /// Multiplicative decay applied to ε after each round
    pub decay_rate: f64,
    /// Lower bound that ε never decays below
    pub min_epsilon: f64,
    /// Optional seed for reproducible arm selection
    pub random_state: Option<u64>,
    arm_counts: Vec<usize>,
    arm_rewards: Vec<f64>,
    total_rounds: usize,
}

impl EpsilonGreedy {
    pub fn new(epsilon: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&epsilon) {
            return Err(BanditError::InvalidEpsilon(epsilon).into());
        }
        Ok(Self {
            epsilon,
            decay_rate: 0.995,
            min_epsilon: 0.01,
            random_state: None,
            arm_counts: Vec::new(),
            arm_rewards: Vec::new(),
            total_rounds: 0,
        })
    }

    pub fn decay_rate(mut self, decay_rate: f64) -> Self {
        self.decay_rate = decay_rate;
        self
    }

    pub fn min_epsilon(mut self, min_epsilon: f64) -> Self {
        self.min_epsilon = min_epsilon;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn initialize(&mut self, n_arms: usize) {
        self.arm_counts = vec![0; n_arms];
        self.arm_rewards = vec![0.0; n_arms];
        self.total_rounds = 0;
    }

    pub fn select_arm(&mut self) -> Result<usize> {
        if self.arm_counts.is_empty() {
            return Err(BanditError::NoArmsAvailable.into());
        }

        // The RNG is rebuilt on every call, so mix the round counter into the
        // seed; with a fixed seed alone, every exploration step would draw the
        // same values and repeatedly pick the same arm.
        let base_seed = self.random_state.unwrap_or(42);
        let mut rng = Random::seed(base_seed.wrapping_add(self.total_rounds as u64));

        let current_epsilon =
            (self.epsilon * self.decay_rate.powi(self.total_rounds as i32)).max(self.min_epsilon);

        if rng.gen::<f64>() < current_epsilon {
            // Explore: select random arm
            Ok(rng.gen_range(0..self.arm_counts.len()))
        } else {
            // Exploit: select arm with highest average reward
            let mut best_arm = 0;
            let mut best_reward = f64::NEG_INFINITY;

            for (arm_idx, &count) in self.arm_counts.iter().enumerate() {
                let avg_reward = if count > 0 {
                    self.arm_rewards[arm_idx] / count as f64
                } else {
                    0.0
                };

                if avg_reward > best_reward {
                    best_reward = avg_reward;
                    best_arm = arm_idx;
                }
            }

            Ok(best_arm)
        }
    }

    pub fn update(&mut self, arm_idx: usize, reward: f64) -> Result<()> {
        if arm_idx >= self.arm_counts.len() {
            return Err(BanditError::InvalidArmIndex {
                arm_idx,
                n_arms: self.arm_counts.len(),
            }
            .into());
        }

        self.arm_counts[arm_idx] += 1;
        self.arm_rewards[arm_idx] += reward;
        self.total_rounds += 1;

        Ok(())
    }

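    /// Returns per-arm statistics as `(count, average reward, cumulative reward)` tuples.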
    pub fn get_arm_statistics(&self) -> Vec<(usize, f64, f64)> {
        self.arm_counts
            .iter()
            .enumerate()
            .map(|(idx, &count)| {
                let avg_reward = if count > 0 {
                    self.arm_rewards[idx] / count as f64
                } else {
                    0.0
                };
                (count, avg_reward, self.arm_rewards[idx])
            })
            .collect()
    }
}

/// Upper Confidence Bound (UCB) strategy for multi-armed bandit active learning
///
/// This strategy selects arms based on an upper confidence bound that balances
/// exploitation of good arms with exploration of uncertain arms.
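///
/// # Example
///
/// A minimal usage sketch (doctest ignored; the reward value is illustrative):
///
/// ```ignore
/// let mut ucb = UpperConfidenceBound::new(2.0)?.random_state(42);
/// ucb.initialize(3);
/// let arm = ucb.select_arm()?;
/// ucb.update(arm, 0.8)?;
/// ```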
#[derive(Debug, Clone)]
pub struct UpperConfidenceBound {
    /// Exploration coefficient `c` scaling the confidence term (must be > 0)
    pub confidence: f64,
    /// Optional seed (kept for API consistency; UCB arm selection is deterministic)
    pub random_state: Option<u64>,
    arm_counts: Vec<usize>,
    arm_rewards: Vec<f64>,
    total_rounds: usize,
}

impl UpperConfidenceBound {
    pub fn new(confidence: f64) -> Result<Self> {
        if confidence <= 0.0 {
            return Err(BanditError::InvalidConfidence(confidence).into());
        }
        Ok(Self {
            confidence,
            random_state: None,
            arm_counts: Vec::new(),
            arm_rewards: Vec::new(),
            total_rounds: 0,
        })
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn initialize(&mut self, n_arms: usize) {
        self.arm_counts = vec![0; n_arms];
        self.arm_rewards = vec![0.0; n_arms];
        self.total_rounds = 0;
    }

    pub fn select_arm(&mut self) -> Result<usize> {
        if self.arm_counts.is_empty() {
            return Err(BanditError::NoArmsAvailable.into());
        }

        // If any arm has not been tried, select it first
        for (arm_idx, &count) in self.arm_counts.iter().enumerate() {
            if count == 0 {
                return Ok(arm_idx);
            }
        }

        // Calculate UCB values for all arms
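        // UCB1 rule: ucb_i = mean_i + c * sqrt(ln(t) / n_i), with t the total
        // number of rounds, n_i the pull count of arm i, and c = self.confidence.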
        let mut best_arm = 0;
        let mut best_ucb = f64::NEG_INFINITY;

        for (arm_idx, &count) in self.arm_counts.iter().enumerate() {
            let avg_reward = self.arm_rewards[arm_idx] / count as f64;
            let confidence_interval =
                self.confidence * ((self.total_rounds as f64).ln() / count as f64).sqrt();
            let ucb_value = avg_reward + confidence_interval;

            if ucb_value > best_ucb {
                best_ucb = ucb_value;
                best_arm = arm_idx;
            }
        }

        Ok(best_arm)
    }

    pub fn update(&mut self, arm_idx: usize, reward: f64) -> Result<()> {
        if arm_idx >= self.arm_counts.len() {
            return Err(BanditError::InvalidArmIndex {
                arm_idx,
                n_arms: self.arm_counts.len(),
            }
            .into());
        }

        self.arm_counts[arm_idx] += 1;
        self.arm_rewards[arm_idx] += reward;
        self.total_rounds += 1;

        Ok(())
    }

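    /// Returns per-arm statistics as `(count, average reward, confidence interval, UCB value)` tuples.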
    pub fn get_arm_statistics(&self) -> Vec<(usize, f64, f64, f64)> {
        self.arm_counts
            .iter()
            .enumerate()
            .map(|(idx, &count)| {
                let avg_reward = if count > 0 {
                    self.arm_rewards[idx] / count as f64
                } else {
                    0.0
                };
                let confidence_interval = if count > 0 && self.total_rounds > 0 {
                    self.confidence * ((self.total_rounds as f64).ln() / count as f64).sqrt()
                } else {
                    f64::INFINITY
                };
                let ucb_value = avg_reward + confidence_interval;
                (count, avg_reward, confidence_interval, ucb_value)
            })
            .collect()
    }
}

/// Thompson Sampling strategy for multi-armed bandit active learning
///
/// This strategy uses Bayesian inference to maintain probability distributions
/// over the reward rates of each arm and samples from these distributions.
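///
/// # Example
///
/// A minimal usage sketch (doctest ignored; rewards are expected in [0, 1]):
///
/// ```ignore
/// let mut ts = ThompsonSampling::new().random_state(42);
/// ts.initialize(3);
/// let arm = ts.select_arm()?;
/// ts.update(arm, 0.9)?;
/// ```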
#[derive(Debug, Clone)]
pub struct ThompsonSampling {
    /// Optional seed for reproducible arm selection
    pub random_state: Option<u64>,
    alpha_params: Vec<f64>,
    beta_params: Vec<f64>,
    total_rounds: usize,
}

impl Default for ThompsonSampling {
    fn default() -> Self {
        Self::new()
    }
}

impl ThompsonSampling {
    pub fn new() -> Self {
        Self {
            random_state: None,
            alpha_params: Vec::new(),
            beta_params: Vec::new(),
            total_rounds: 0,
        }
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn initialize(&mut self, n_arms: usize) {
        // Initialize with Beta(1, 1) priors (uniform distribution)
        self.alpha_params = vec![1.0; n_arms];
        self.beta_params = vec![1.0; n_arms];
        self.total_rounds = 0;
    }

    pub fn select_arm(&mut self) -> Result<usize> {
        if self.alpha_params.is_empty() {
            return Err(BanditError::NoArmsAvailable.into());
        }

        // Mix the round counter into the seed so repeated calls with a fixed
        // seed do not draw identical perturbations.
        let base_seed = self.random_state.unwrap_or(42);
        let mut rng = Random::seed(base_seed.wrapping_add(self.total_rounds as u64));

        let mut best_arm = 0;
        let mut best_sample = f64::NEG_INFINITY;

        // Sample from Beta distributions for each arm
        for (arm_idx, (&alpha, &beta)) in self
            .alpha_params
            .iter()
            .zip(self.beta_params.iter())
            .enumerate()
        {
            // Simplified Thompson sampling using uniform distribution
            // In practice, this would use proper Beta(alpha, beta) sampling
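            // (note on exact sampling: since alpha and beta stay integer-valued
            //  here, a Beta(alpha, beta) draw is possible from uniforms alone:
            //  x = -(ln u_1 + ... + ln u_alpha) ~ Gamma(alpha, 1),
            //  y = -(ln v_1 + ... + ln v_beta)  ~ Gamma(beta, 1),
            //  and x / (x + y) ~ Beta(alpha, beta); the simplified draw below
            //  is left unchanged)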
            let sample = alpha / (alpha + beta) + rng.random_range(-0.1..0.1);

            if sample > best_sample {
                best_sample = sample;
                best_arm = arm_idx;
            }
        }

        Ok(best_arm)
    }

    pub fn update(&mut self, arm_idx: usize, reward: f64) -> Result<()> {
        if arm_idx >= self.alpha_params.len() {
            return Err(BanditError::InvalidArmIndex {
                arm_idx,
                n_arms: self.alpha_params.len(),
            }
            .into());
        }

        // Update Beta distribution parameters
        // Reward should be normalized to [0, 1] for Beta distribution
        let normalized_reward = reward.clamp(0.0, 1.0);

        if normalized_reward > 0.5 {
            // Treat as success
            self.alpha_params[arm_idx] += 1.0;
        } else {
            // Treat as failure
            self.beta_params[arm_idx] += 1.0;
        }

        self.total_rounds += 1;
        Ok(())
    }

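    /// Returns per-arm posterior statistics as `(alpha, beta, posterior mean, posterior variance)` tuples.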
    pub fn get_arm_statistics(&self) -> Vec<(f64, f64, f64, f64)> {
        self.alpha_params
            .iter()
            .zip(self.beta_params.iter())
            .map(|(&alpha, &beta)| {
                let mean = alpha / (alpha + beta);
                let variance = (alpha * beta) / ((alpha + beta).powi(2) * (alpha + beta + 1.0));
                (alpha, beta, mean, variance)
            })
            .collect()
    }
}

/// Contextual Bandit for active learning with context features
///
/// This extends multi-armed bandits to include contextual information,
/// making decisions based on both historical performance and current context.
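///
/// # Example
///
/// A minimal usage sketch (doctest ignored; the context vector and reward are
/// illustrative):
///
/// ```ignore
/// let mut cb = ContextualBandit::new(0.1)?.random_state(42);
/// cb.initialize(2, 3); // 2 arms, 3-dimensional context
/// let context = array![1.0, 0.0, 0.0];
/// let arm = cb.select_arm(&context.view())?;
/// cb.update(arm, &context.view(), 1.0)?;
/// ```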
#[derive(Debug, Clone)]
pub struct ContextualBandit {
    /// Step size α for the linear reward-model weight updates
    pub learning_rate: f64,
    /// Probability of selecting a random arm (ε-greedy exploration, must be >= 0)
    pub exploration: f64,
    /// Optional seed for reproducible arm selection
    pub random_state: Option<u64>,
    arm_weights: Vec<Array1<f64>>,
    context_dim: usize,
    total_rounds: usize,
}

impl ContextualBandit {
    pub fn new(exploration: f64) -> Result<Self> {
        if exploration < 0.0 {
            return Err(BanditError::InvalidExploration(exploration).into());
        }
        Ok(Self {
            learning_rate: 0.1,
            exploration,
            random_state: None,
            arm_weights: Vec::new(),
            context_dim: 0,
            total_rounds: 0,
        })
    }

    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn initialize(&mut self, n_arms: usize, context_dim: usize) {
        self.context_dim = context_dim;
        self.arm_weights = vec![Array1::zeros(context_dim); n_arms];
        self.total_rounds = 0;
    }

    pub fn select_arm(&mut self, context: &ArrayView1<f64>) -> Result<usize> {
        if self.arm_weights.is_empty() {
            return Err(BanditError::NoArmsAvailable.into());
        }

        if context.len() != self.context_dim {
            return Err(BanditError::BanditComputationFailed(format!(
                "Context dimension mismatch: expected {}, got {}",
                self.context_dim,
                context.len()
            ))
            .into());
        }

        // Mix the round counter into the seed so repeated calls with a fixed
        // seed do not always explore the same arm.
        let base_seed = self.random_state.unwrap_or(42);
        let mut rng = Random::seed(base_seed.wrapping_add(self.total_rounds as u64));

        // Epsilon-greedy with linear contextual bandits
        if rng.gen::<f64>() < self.exploration {
            // Explore: select random arm
            Ok(rng.gen_range(0..self.arm_weights.len()))
        } else {
            // Exploit: select arm with highest predicted reward
            let mut best_arm = 0;
            let mut best_reward = f64::NEG_INFINITY;

            for (arm_idx, weights) in self.arm_weights.iter().enumerate() {
                let predicted_reward = weights.dot(context);
                if predicted_reward > best_reward {
                    best_reward = predicted_reward;
                    best_arm = arm_idx;
                }
            }

            Ok(best_arm)
        }
    }

    pub fn update(&mut self, arm_idx: usize, context: &ArrayView1<f64>, reward: f64) -> Result<()> {
        if arm_idx >= self.arm_weights.len() {
            return Err(BanditError::InvalidArmIndex {
                arm_idx,
                n_arms: self.arm_weights.len(),
            }
            .into());
        }

        if context.len() != self.context_dim {
            return Err(BanditError::BanditComputationFailed(format!(
                "Context dimension mismatch: expected {}, got {}",
                self.context_dim,
                context.len()
            ))
            .into());
        }

        // Gradient update for linear contextual bandit
        let predicted_reward = self.arm_weights[arm_idx].dot(context);
        let error = reward - predicted_reward;

        // Update weights: w = w + α * error * context
        for i in 0..self.context_dim {
            self.arm_weights[arm_idx][i] += self.learning_rate * error * context[i];
        }

        self.total_rounds += 1;
        Ok(())
    }

    pub fn get_arm_weights(&self) -> &Vec<Array1<f64>> {
        &self.arm_weights
    }

    pub fn predict_rewards(&self, context: &ArrayView1<f64>) -> Result<Array1<f64>> {
        if context.len() != self.context_dim {
            return Err(BanditError::BanditComputationFailed(format!(
                "Context dimension mismatch: expected {}, got {}",
                self.context_dim,
                context.len()
            ))
            .into());
        }

        let mut predicted_rewards = Array1::zeros(self.arm_weights.len());
        for (arm_idx, weights) in self.arm_weights.iter().enumerate() {
            predicted_rewards[arm_idx] = weights.dot(context);
        }

        Ok(predicted_rewards)
    }
}

/// Bandit-based Active Learning coordinator
///
/// This coordinates multiple query strategies using bandit algorithms,
/// treating each strategy as an arm and adaptively selecting strategies
/// based on their performance.
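///
/// # Example
///
/// A minimal usage sketch (doctest ignored; strategy names and the reward value
/// are illustrative):
///
/// ```ignore
/// let strategies = vec!["entropy".to_string(), "margin".to_string()];
/// let mut bbal = BanditBasedActiveLearning::new(strategies, "ucb".to_string());
/// bbal.initialize(None, Some(2.0), None)?;
/// let idx = bbal.select_strategy(None)?;
/// bbal.update_strategy(idx, 0.7, None)?;
/// ```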
#[derive(Debug, Clone)]
pub struct BanditBasedActiveLearning {
    /// Names of the query strategies treated as arms
    pub strategy_names: Vec<String>,
    /// Bandit algorithm to use: "epsilon_greedy", "ucb", "thompson_sampling", or "contextual"
    pub bandit_algorithm: String,
    /// Name of the reward signal used to score strategies (e.g. "accuracy_improvement")
    pub reward_function: String,
    /// Optional seed forwarded to the underlying bandit
    pub random_state: Option<u64>,
    epsilon_greedy: Option<EpsilonGreedy>,
    ucb: Option<UpperConfidenceBound>,
    thompson: Option<ThompsonSampling>,
    contextual: Option<ContextualBandit>,
}

impl BanditBasedActiveLearning {
    pub fn new(strategy_names: Vec<String>, bandit_algorithm: String) -> Self {
        Self {
            strategy_names,
            bandit_algorithm,
            reward_function: "accuracy_improvement".to_string(),
            random_state: None,
            epsilon_greedy: None,
            ucb: None,
            thompson: None,
            contextual: None,
        }
    }

    pub fn reward_function(mut self, reward_function: String) -> Self {
        self.reward_function = reward_function;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn initialize(
        &mut self,
        epsilon: Option<f64>,
        confidence: Option<f64>,
        exploration: Option<f64>,
    ) -> Result<()> {
        let n_arms = self.strategy_names.len();

        match self.bandit_algorithm.as_str() {
            "epsilon_greedy" => {
                let eps = epsilon.unwrap_or(0.1);
                let mut eg = EpsilonGreedy::new(eps)?;
                if let Some(seed) = self.random_state {
                    eg = eg.random_state(seed);
                }
                eg.initialize(n_arms);
                self.epsilon_greedy = Some(eg);
            }
            "ucb" => {
                let conf = confidence.unwrap_or(2.0);
                let mut ucb = UpperConfidenceBound::new(conf)?;
                if let Some(seed) = self.random_state {
                    ucb = ucb.random_state(seed);
                }
                ucb.initialize(n_arms);
                self.ucb = Some(ucb);
            }
            "thompson_sampling" => {
                let mut ts = ThompsonSampling::new();
                if let Some(seed) = self.random_state {
                    ts = ts.random_state(seed);
                }
                ts.initialize(n_arms);
                self.thompson = Some(ts);
            }
            "contextual" => {
                let exp = exploration.unwrap_or(0.1);
                let mut cb = ContextualBandit::new(exp)?;
                if let Some(seed) = self.random_state {
                    cb = cb.random_state(seed);
                }
                // Context dimension will be set when first context is provided
                self.contextual = Some(cb);
            }
            _ => {
                return Err(BanditError::BanditComputationFailed(format!(
                    "Unknown bandit algorithm: {}",
                    self.bandit_algorithm
                ))
                .into())
            }
        }

        Ok(())
    }

    pub fn select_strategy(&mut self, context: Option<&ArrayView1<f64>>) -> Result<usize> {
        match self.bandit_algorithm.as_str() {
            "epsilon_greedy" => {
                if let Some(ref mut eg) = self.epsilon_greedy {
                    eg.select_arm()
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Epsilon-greedy not initialized".to_string(),
                    )
                    .into())
                }
            }
            "ucb" => {
                if let Some(ref mut ucb) = self.ucb {
                    ucb.select_arm()
                } else {
                    Err(
                        BanditError::BanditComputationFailed("UCB not initialized".to_string())
                            .into(),
                    )
                }
            }
            "thompson_sampling" => {
                if let Some(ref mut ts) = self.thompson {
                    ts.select_arm()
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Thompson sampling not initialized".to_string(),
                    )
                    .into())
                }
            }
            "contextual" => {
                if let Some(context) = context {
                    if let Some(ref mut cb) = self.contextual {
                        if cb.context_dim == 0 {
                            cb.initialize(self.strategy_names.len(), context.len());
                        }
                        cb.select_arm(context)
                    } else {
                        Err(BanditError::BanditComputationFailed(
                            "Contextual bandit not initialized".to_string(),
                        )
                        .into())
                    }
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Context required for contextual bandit".to_string(),
                    )
                    .into())
                }
            }
            _ => Err(BanditError::BanditComputationFailed(format!(
                "Unknown bandit algorithm: {}",
                self.bandit_algorithm
            ))
            .into()),
        }
    }

    pub fn update_strategy(
        &mut self,
        strategy_idx: usize,
        reward: f64,
        context: Option<&ArrayView1<f64>>,
    ) -> Result<()> {
        match self.bandit_algorithm.as_str() {
            "epsilon_greedy" => {
                if let Some(ref mut eg) = self.epsilon_greedy {
                    eg.update(strategy_idx, reward)
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Epsilon-greedy not initialized".to_string(),
                    )
                    .into())
                }
            }
            "ucb" => {
                if let Some(ref mut ucb) = self.ucb {
                    ucb.update(strategy_idx, reward)
                } else {
                    Err(
                        BanditError::BanditComputationFailed("UCB not initialized".to_string())
                            .into(),
                    )
                }
            }
            "thompson_sampling" => {
                if let Some(ref mut ts) = self.thompson {
                    ts.update(strategy_idx, reward)
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Thompson sampling not initialized".to_string(),
                    )
                    .into())
                }
            }
            "contextual" => {
                if let Some(context) = context {
                    if let Some(ref mut cb) = self.contextual {
                        cb.update(strategy_idx, context, reward)
                    } else {
                        Err(BanditError::BanditComputationFailed(
                            "Contextual bandit not initialized".to_string(),
                        )
                        .into())
                    }
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Context required for contextual bandit".to_string(),
                    )
                    .into())
                }
            }
            _ => Err(BanditError::BanditComputationFailed(format!(
                "Unknown bandit algorithm: {}",
                self.bandit_algorithm
            ))
            .into()),
        }
    }

    pub fn get_strategy_performance(&self) -> Result<HashMap<String, f64>> {
        let mut performance = HashMap::new();

        match self.bandit_algorithm.as_str() {
            "epsilon_greedy" => {
                if let Some(ref eg) = self.epsilon_greedy {
                    let stats = eg.get_arm_statistics();
                    for (idx, (_, avg_reward, _)) in stats.iter().enumerate() {
                        if idx < self.strategy_names.len() {
                            performance.insert(self.strategy_names[idx].clone(), *avg_reward);
                        }
                    }
                }
            }
            "ucb" => {
                if let Some(ref ucb) = self.ucb {
                    let stats = ucb.get_arm_statistics();
                    for (idx, (_, avg_reward, _, _)) in stats.iter().enumerate() {
                        if idx < self.strategy_names.len() {
                            performance.insert(self.strategy_names[idx].clone(), *avg_reward);
                        }
                    }
                }
            }
            "thompson_sampling" => {
                if let Some(ref ts) = self.thompson {
                    let stats = ts.get_arm_statistics();
                    for (idx, (_, _, mean, _)) in stats.iter().enumerate() {
                        if idx < self.strategy_names.len() {
                            performance.insert(self.strategy_names[idx].clone(), *mean);
                        }
                    }
                }
            }
            "contextual" => {
                // For contextual bandits, performance is context-dependent
                // Return uniform performance for now
                for name in self.strategy_names.iter() {
                    performance.insert(name.clone(), 0.5);
                }
            }
            _ => {
                return Err(BanditError::BanditComputationFailed(format!(
                    "Unknown bandit algorithm: {}",
                    self.bandit_algorithm
                ))
                .into())
            }
        }

        Ok(performance)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_epsilon_greedy_creation() {
        let eg = EpsilonGreedy::new(0.1).unwrap();
        assert_eq!(eg.epsilon, 0.1);
        assert_eq!(eg.decay_rate, 0.995);
        assert_eq!(eg.min_epsilon, 0.01);
    }

    #[test]
    fn test_epsilon_greedy_invalid_epsilon() {
        assert!(EpsilonGreedy::new(-0.1).is_err());
        assert!(EpsilonGreedy::new(1.5).is_err());
    }

    #[test]
    fn test_epsilon_greedy_basic_functionality() {
        let mut eg = EpsilonGreedy::new(0.5).unwrap().random_state(42);
        eg.initialize(3);

        // Select arms and update rewards
        for _ in 0..10 {
            let arm = eg.select_arm().unwrap();
            assert!(arm < 3);

            let reward = if arm == 0 { 1.0 } else { 0.0 }; // Arm 0 is best
            eg.update(arm, reward).unwrap();
        }

        let stats = eg.get_arm_statistics();
        assert_eq!(stats.len(), 3);

        // Check that arm 0 has highest average reward (it should be selected more often)
        if stats[0].0 > 0 {
            assert!(stats[0].1 >= stats[1].1 && stats[0].1 >= stats[2].1);
        }
    }

    #[test]
    fn test_upper_confidence_bound_creation() {
        let ucb = UpperConfidenceBound::new(2.0).unwrap();
        assert_eq!(ucb.confidence, 2.0);
    }

    #[test]
    fn test_upper_confidence_bound_invalid_confidence() {
        assert!(UpperConfidenceBound::new(0.0).is_err());
        assert!(UpperConfidenceBound::new(-1.0).is_err());
    }

    #[test]
    fn test_upper_confidence_bound_basic_functionality() {
        let mut ucb = UpperConfidenceBound::new(2.0).unwrap().random_state(42);
        ucb.initialize(3);

        // Select arms and update rewards
        for _ in 0..10 {
            let arm = ucb.select_arm().unwrap();
            assert!(arm < 3);

            let reward = if arm == 0 { 0.8 } else { 0.2 }; // Arm 0 is best
            ucb.update(arm, reward).unwrap();
        }

        let stats = ucb.get_arm_statistics();
        assert_eq!(stats.len(), 3);

        // All arms should have been tried at least once
        for (count, _, _, _) in stats.iter() {
            assert!(*count > 0);
        }
    }

    #[test]
    fn test_thompson_sampling_creation() {
        let ts = ThompsonSampling::new();
        assert!(ts.alpha_params.is_empty());
        assert!(ts.beta_params.is_empty());
    }

    #[test]
    fn test_thompson_sampling_basic_functionality() {
        let mut ts = ThompsonSampling::new().random_state(42);
        ts.initialize(3);

        // Select arms and update rewards
        for _ in 0..10 {
            let arm = ts.select_arm().unwrap();
            assert!(arm < 3);

            let reward = if arm == 0 { 0.9 } else { 0.1 }; // Arm 0 is best
            ts.update(arm, reward).unwrap();
        }

        let stats = ts.get_arm_statistics();
        assert_eq!(stats.len(), 3);

        // Check that parameters have been updated
        for (alpha, beta, mean, _) in stats.iter() {
            assert!(*alpha >= 1.0);
            assert!(*beta >= 1.0);
            assert!(*mean >= 0.0 && *mean <= 1.0);
        }
    }

    #[test]
    fn test_contextual_bandit_creation() {
        let cb = ContextualBandit::new(0.1).unwrap();
        assert_eq!(cb.exploration, 0.1);
        assert_eq!(cb.learning_rate, 0.1);
    }

    #[test]
    fn test_contextual_bandit_invalid_exploration() {
        assert!(ContextualBandit::new(-0.1).is_err());
    }

    #[test]
    fn test_contextual_bandit_basic_functionality() {
        let mut cb = ContextualBandit::new(0.1).unwrap().random_state(42);
        cb.initialize(2, 3);

        let context1 = array![1.0, 0.0, 0.0];
        let context2 = array![0.0, 1.0, 0.0];

        // Select arms and update rewards
        for i in 0..10 {
            let context = if i % 2 == 0 { &context1 } else { &context2 };
            let arm = cb.select_arm(&context.view()).unwrap();
            assert!(arm < 2);

            let reward = if (arm == 0 && i % 2 == 0) || (arm == 1 && i % 2 == 1) {
                1.0
            } else {
                0.0
            };
            cb.update(arm, &context.view(), reward).unwrap();
        }

        let weights = cb.get_arm_weights();
        assert_eq!(weights.len(), 2);
        assert_eq!(weights[0].len(), 3);
        assert_eq!(weights[1].len(), 3);

        // Test prediction
        let predicted = cb.predict_rewards(&context1.view()).unwrap();
        assert_eq!(predicted.len(), 2);
    }

    #[test]
    fn test_bandit_based_active_learning() {
        let strategies = vec![
            "entropy".to_string(),
            "margin".to_string(),
            "random".to_string(),
        ];
        let mut bbal =
            BanditBasedActiveLearning::new(strategies.clone(), "epsilon_greedy".to_string())
                .random_state(42);

        bbal.initialize(Some(0.2), None, None).unwrap();

        // Select strategies and update rewards
        for _ in 0..10 {
            let strategy_idx = bbal.select_strategy(None).unwrap();
            assert!(strategy_idx < strategies.len());

            let reward = if strategy_idx == 0 { 0.8 } else { 0.3 }; // Entropy is best
            bbal.update_strategy(strategy_idx, reward, None).unwrap();
        }

        let performance = bbal.get_strategy_performance().unwrap();
        assert_eq!(performance.len(), strategies.len());

        for strategy in strategies.iter() {
            assert!(performance.contains_key(strategy));
        }
    }

    #[test]
    fn test_bandit_based_active_learning_ucb() {
        let strategies = vec!["uncertainty".to_string(), "diversity".to_string()];
        let mut bbal =
            BanditBasedActiveLearning::new(strategies.clone(), "ucb".to_string()).random_state(42);

        bbal.initialize(None, Some(1.5), None).unwrap();

        // Test basic functionality
        let strategy_idx = bbal.select_strategy(None).unwrap();
        assert!(strategy_idx < strategies.len());

        bbal.update_strategy(strategy_idx, 0.5, None).unwrap();

        let performance = bbal.get_strategy_performance().unwrap();
        assert_eq!(performance.len(), strategies.len());
    }

    #[test]
    fn test_bandit_based_active_learning_thompson() {
        let strategies = vec!["query1".to_string(), "query2".to_string()];
        let mut bbal =
            BanditBasedActiveLearning::new(strategies.clone(), "thompson_sampling".to_string())
                .random_state(42);

        bbal.initialize(None, None, None).unwrap();

        // Test basic functionality
        let strategy_idx = bbal.select_strategy(None).unwrap();
        assert!(strategy_idx < strategies.len());

        bbal.update_strategy(strategy_idx, 0.7, None).unwrap();

        let performance = bbal.get_strategy_performance().unwrap();
        assert_eq!(performance.len(), strategies.len());
    }

    #[test]
    fn test_bandit_based_active_learning_contextual() {
        let strategies = vec![
            "context_strategy1".to_string(),
            "context_strategy2".to_string(),
        ];
        let mut bbal = BanditBasedActiveLearning::new(strategies.clone(), "contextual".to_string())
            .random_state(42);

        bbal.initialize(None, None, Some(0.15)).unwrap();

        let context = array![0.5, 1.0, 0.2];

        // Test basic functionality
        let strategy_idx = bbal.select_strategy(Some(&context.view())).unwrap();
        assert!(strategy_idx < strategies.len());

        bbal.update_strategy(strategy_idx, 0.6, Some(&context.view()))
            .unwrap();

        let performance = bbal.get_strategy_performance().unwrap();
        assert_eq!(performance.len(), strategies.len());
    }
}