//! Multi-armed bandit algorithms for adaptive strategy selection in active learning.

use scirs2_core::ndarray_ext::{Array1, ArrayView1};
use scirs2_core::random::{Random, Rng};
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;
use thiserror::Error;

/// Errors produced by the bandit algorithms in this module.
#[derive(Error, Debug)]
pub enum BanditError {
    #[error("Invalid epsilon parameter: {0}")]
    InvalidEpsilon(f64),
    #[error("Invalid temperature parameter: {0}")]
    InvalidTemperature(f64),
    #[error("Invalid confidence parameter: {0}")]
    InvalidConfidence(f64),
    #[error("Invalid exploration parameter: {0}")]
    InvalidExploration(f64),
    #[error("Invalid arm index: {arm_idx} for {n_arms} arms")]
    InvalidArmIndex { arm_idx: usize, n_arms: usize },
    #[error("No arms available")]
    NoArmsAvailable,
    #[error("Insufficient data for arm: {0}")]
    InsufficientDataForArm(usize),
    #[error("Bandit computation failed: {0}")]
    BanditComputationFailed(String),
}

impl From<BanditError> for SklearsError {
    fn from(err: BanditError) -> Self {
        SklearsError::FitError(err.to_string())
    }
}

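/// Epsilon-greedy multi-armed bandit with multiplicative epsilon decay.
///
/// A minimal usage sketch (hedged: assumes this type is in scope and the
/// surrounding function returns `sklears_core::error::Result`):
///
/// ```ignore
/// let mut bandit = EpsilonGreedy::new(0.1)?.decay_rate(0.99).min_epsilon(0.05);
/// bandit.initialize(3); // three arms
/// for _ in 0..100 {
///     let arm = bandit.select_arm()?; // explore w.p. epsilon, else exploit
///     let reward = if arm == 0 { 1.0 } else { 0.0 };
///     bandit.update(arm, reward)?;
/// }
/// // (pull count, average reward, total reward) per arm:
/// let stats = bandit.get_arm_statistics();
/// ```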
#[derive(Debug, Clone)]
pub struct EpsilonGreedy {
    /// Probability of selecting a random arm (exploration).
    pub epsilon: f64,
    /// Multiplicative decay applied to epsilon as rounds accumulate.
    pub decay_rate: f64,
    /// Lower bound that epsilon never decays below.
    pub min_epsilon: f64,
    /// Optional seed for reproducible arm selection.
    pub random_state: Option<u64>,
    arm_counts: Vec<usize>,
    arm_rewards: Vec<f64>,
    total_rounds: usize,
}

impl EpsilonGreedy {
    /// Creates a new epsilon-greedy bandit; `epsilon` must lie in `[0, 1]`.
    pub fn new(epsilon: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&epsilon) {
            return Err(BanditError::InvalidEpsilon(epsilon).into());
        }
        Ok(Self {
            epsilon,
            decay_rate: 0.995,
            min_epsilon: 0.01,
            random_state: None,
            arm_counts: Vec::new(),
            arm_rewards: Vec::new(),
            total_rounds: 0,
        })
    }

    pub fn decay_rate(mut self, decay_rate: f64) -> Self {
        self.decay_rate = decay_rate;
        self
    }

    pub fn min_epsilon(mut self, min_epsilon: f64) -> Self {
        self.min_epsilon = min_epsilon;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Resets the bandit for `n_arms` arms, clearing all counts and rewards.
    pub fn initialize(&mut self, n_arms: usize) {
        self.arm_counts = vec![0; n_arms];
        self.arm_rewards = vec![0.0; n_arms];
        self.total_rounds = 0;
    }

    /// Selects an arm: with probability epsilon (decayed over rounds) a random
    /// arm is explored; otherwise the arm with the best average reward is chosen.
    pub fn select_arm(&mut self) -> Result<usize> {
        if self.arm_counts.is_empty() {
            return Err(BanditError::NoArmsAvailable.into());
        }

        // Derive a per-round seed: re-seeding with a constant on every call
        // would replay the same draw and never vary the exploration choice.
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed.wrapping_add(self.total_rounds as u64)),
            None => Random::seed(42u64.wrapping_add(self.total_rounds as u64)),
        };

        let current_epsilon =
            (self.epsilon * self.decay_rate.powi(self.total_rounds as i32)).max(self.min_epsilon);

        if rng.gen::<f64>() < current_epsilon {
            // Explore: pick a uniformly random arm.
            Ok(rng.gen_range(0..self.arm_counts.len()))
        } else {
            // Exploit: pick the arm with the highest empirical mean reward.
            let mut best_arm = 0;
            let mut best_reward = f64::NEG_INFINITY;

            for (arm_idx, &count) in self.arm_counts.iter().enumerate() {
                let avg_reward = if count > 0 {
                    self.arm_rewards[arm_idx] / count as f64
                } else {
                    0.0
                };

                if avg_reward > best_reward {
                    best_reward = avg_reward;
                    best_arm = arm_idx;
                }
            }

            Ok(best_arm)
        }
    }

    /// Records the observed `reward` for `arm_idx`.
    pub fn update(&mut self, arm_idx: usize, reward: f64) -> Result<()> {
        if arm_idx >= self.arm_counts.len() {
            return Err(BanditError::InvalidArmIndex {
                arm_idx,
                n_arms: self.arm_counts.len(),
            }
            .into());
        }

        self.arm_counts[arm_idx] += 1;
        self.arm_rewards[arm_idx] += reward;
        self.total_rounds += 1;

        Ok(())
    }

    /// Returns `(pull count, average reward, total reward)` for each arm.
    pub fn get_arm_statistics(&self) -> Vec<(usize, f64, f64)> {
        self.arm_counts
            .iter()
            .enumerate()
            .map(|(idx, &count)| {
                let avg_reward = if count > 0 {
                    self.arm_rewards[idx] / count as f64
                } else {
                    0.0
                };
                (count, avg_reward, self.arm_rewards[idx])
            })
            .collect()
    }
}

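/// Upper Confidence Bound (UCB1-style) multi-armed bandit.
///
/// A minimal usage sketch (hedged: assumes this type is in scope and the
/// surrounding function returns `sklears_core::error::Result`):
///
/// ```ignore
/// let mut bandit = UpperConfidenceBound::new(2.0)?;
/// bandit.initialize(3);
/// let arm = bandit.select_arm()?; // unplayed arms are tried first
/// bandit.update(arm, 0.8)?;
/// // (count, mean reward, confidence bonus, UCB value) per arm:
/// let stats = bandit.get_arm_statistics();
/// ```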
#[derive(Debug, Clone)]
pub struct UpperConfidenceBound {
    /// Exploration coefficient scaling the confidence bonus.
    pub confidence: f64,
    /// Optional seed kept for API consistency; selection itself is deterministic.
    pub random_state: Option<u64>,
    arm_counts: Vec<usize>,
    arm_rewards: Vec<f64>,
    total_rounds: usize,
}

impl UpperConfidenceBound {
    /// Creates a new UCB bandit; `confidence` must be strictly positive.
    pub fn new(confidence: f64) -> Result<Self> {
        if confidence <= 0.0 {
            return Err(BanditError::InvalidConfidence(confidence).into());
        }
        Ok(Self {
            confidence,
            random_state: None,
            arm_counts: Vec::new(),
            arm_rewards: Vec::new(),
            total_rounds: 0,
        })
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Resets the bandit for `n_arms` arms, clearing all counts and rewards.
    pub fn initialize(&mut self, n_arms: usize) {
        self.arm_counts = vec![0; n_arms];
        self.arm_rewards = vec![0.0; n_arms];
        self.total_rounds = 0;
    }

    /// Selects the arm with the highest upper confidence bound.
    pub fn select_arm(&mut self) -> Result<usize> {
        if self.arm_counts.is_empty() {
            return Err(BanditError::NoArmsAvailable.into());
        }

        // Play every arm once before applying the UCB criterion.
        for (arm_idx, &count) in self.arm_counts.iter().enumerate() {
            if count == 0 {
                return Ok(arm_idx);
            }
        }

        // UCB1-style rule: average reward plus a confidence bonus that shrinks
        // as an arm is pulled more often.
        let mut best_arm = 0;
        let mut best_ucb = f64::NEG_INFINITY;

        for (arm_idx, &count) in self.arm_counts.iter().enumerate() {
            let avg_reward = self.arm_rewards[arm_idx] / count as f64;
            let confidence_interval =
                self.confidence * ((self.total_rounds as f64).ln() / count as f64).sqrt();
            let ucb_value = avg_reward + confidence_interval;

            if ucb_value > best_ucb {
                best_ucb = ucb_value;
                best_arm = arm_idx;
            }
        }

        Ok(best_arm)
    }

    pub fn update(&mut self, arm_idx: usize, reward: f64) -> Result<()> {
        if arm_idx >= self.arm_counts.len() {
            return Err(BanditError::InvalidArmIndex {
                arm_idx,
                n_arms: self.arm_counts.len(),
            }
            .into());
        }

        self.arm_counts[arm_idx] += 1;
        self.arm_rewards[arm_idx] += reward;
        self.total_rounds += 1;

        Ok(())
    }

    /// Returns `(pull count, average reward, confidence bonus, UCB value)` for each arm.
    pub fn get_arm_statistics(&self) -> Vec<(usize, f64, f64, f64)> {
        self.arm_counts
            .iter()
            .enumerate()
            .map(|(idx, &count)| {
                let avg_reward = if count > 0 {
                    self.arm_rewards[idx] / count as f64
                } else {
                    0.0
                };
                let confidence_interval = if count > 0 && self.total_rounds > 0 {
                    self.confidence * ((self.total_rounds as f64).ln() / count as f64).sqrt()
                } else {
                    f64::INFINITY
                };
                let ucb_value = avg_reward + confidence_interval;
                (count, avg_reward, confidence_interval, ucb_value)
            })
            .collect()
    }
}

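/// Thompson sampling bandit with Beta posteriors over binarized rewards.
///
/// A minimal usage sketch (hedged: assumes this type is in scope and the
/// surrounding function returns `sklears_core::error::Result`):
///
/// ```ignore
/// let mut bandit = ThompsonSampling::new().random_state(42);
/// bandit.initialize(3); // Beta(1, 1) prior on every arm
/// let arm = bandit.select_arm()?;
/// bandit.update(arm, 0.9)?; // rewards are clamped to [0, 1]
/// ```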
#[derive(Debug, Clone)]
pub struct ThompsonSampling {
    /// Optional seed for reproducible sampling.
    pub random_state: Option<u64>,
    /// Per-arm Beta posterior `alpha` (success) parameters.
    alpha_params: Vec<f64>,
    /// Per-arm Beta posterior `beta` (failure) parameters.
    beta_params: Vec<f64>,
    total_rounds: usize,
}

impl Default for ThompsonSampling {
    fn default() -> Self {
        Self::new()
    }
}

impl ThompsonSampling {
    pub fn new() -> Self {
        Self {
            random_state: None,
            alpha_params: Vec::new(),
            beta_params: Vec::new(),
            total_rounds: 0,
        }
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Resets all arms to a uniform Beta(1, 1) prior.
    pub fn initialize(&mut self, n_arms: usize) {
        self.alpha_params = vec![1.0; n_arms];
        self.beta_params = vec![1.0; n_arms];
        self.total_rounds = 0;
    }

    /// Selects an arm by (approximate) posterior sampling.
    pub fn select_arm(&mut self) -> Result<usize> {
        if self.alpha_params.is_empty() {
            return Err(BanditError::NoArmsAvailable.into());
        }

        // Derive a per-round seed: re-seeding with a constant on every call
        // would replay the same perturbations each round.
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed.wrapping_add(self.total_rounds as u64)),
            None => Random::seed(42u64.wrapping_add(self.total_rounds as u64)),
        };

        let mut best_arm = 0;
        let mut best_sample = f64::NEG_INFINITY;

        for (arm_idx, (&alpha, &beta)) in self
            .alpha_params
            .iter()
            .zip(self.beta_params.iter())
            .enumerate()
        {
            // Approximation: use the Beta posterior mean plus a small uniform
            // perturbation instead of drawing from the Beta distribution itself.
            let sample = alpha / (alpha + beta) + rng.random_range(-0.1..0.1);

            if sample > best_sample {
                best_sample = sample;
                best_arm = arm_idx;
            }
        }

        Ok(best_arm)
    }

    /// Updates the Beta posterior for `arm_idx` from a reward in `[0, 1]`.
    pub fn update(&mut self, arm_idx: usize, reward: f64) -> Result<()> {
        if arm_idx >= self.alpha_params.len() {
            return Err(BanditError::InvalidArmIndex {
                arm_idx,
                n_arms: self.alpha_params.len(),
            }
            .into());
        }

        // Clamp the reward and binarize it: above 0.5 counts as a success.
        let normalized_reward = reward.clamp(0.0, 1.0);

        if normalized_reward > 0.5 {
            self.alpha_params[arm_idx] += 1.0;
        } else {
            self.beta_params[arm_idx] += 1.0;
        }

        self.total_rounds += 1;
        Ok(())
    }

    /// Returns `(alpha, beta, posterior mean, posterior variance)` for each arm.
    pub fn get_arm_statistics(&self) -> Vec<(f64, f64, f64, f64)> {
        self.alpha_params
            .iter()
            .zip(self.beta_params.iter())
            .map(|(&alpha, &beta)| {
                let mean = alpha / (alpha + beta);
                let variance = (alpha * beta) / ((alpha + beta).powi(2) * (alpha + beta + 1.0));
                (alpha, beta, mean, variance)
            })
            .collect()
    }
}

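/// Contextual bandit with one online linear reward model per arm.
///
/// A minimal usage sketch (hedged: assumes this type and the `array!` macro
/// are in scope and the surrounding function returns `sklears_core::error::Result`):
///
/// ```ignore
/// let mut bandit = ContextualBandit::new(0.1)?.learning_rate(0.05);
/// bandit.initialize(2, 3); // two arms, three-dimensional contexts
/// let context = array![1.0, 0.0, 0.0];
/// let arm = bandit.select_arm(&context.view())?;
/// bandit.update(arm, &context.view(), 1.0)?;
/// let predicted = bandit.predict_rewards(&context.view())?; // one value per arm
/// ```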
#[derive(Debug, Clone)]
pub struct ContextualBandit {
    /// Step size for the online weight updates.
    pub learning_rate: f64,
    /// Probability of selecting a random arm (exploration).
    pub exploration: f64,
    /// Optional seed for reproducible arm selection.
    pub random_state: Option<u64>,
    /// One linear weight vector per arm.
    arm_weights: Vec<Array1<f64>>,
    context_dim: usize,
    total_rounds: usize,
}

impl ContextualBandit {
    /// Creates a new contextual bandit; `exploration` must be non-negative.
    pub fn new(exploration: f64) -> Result<Self> {
        if exploration < 0.0 {
            return Err(BanditError::InvalidExploration(exploration).into());
        }
        Ok(Self {
            learning_rate: 0.1,
            exploration,
            random_state: None,
            arm_weights: Vec::new(),
            context_dim: 0,
            total_rounds: 0,
        })
    }

    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Resets the bandit with `n_arms` zero-initialized weight vectors of length `context_dim`.
    pub fn initialize(&mut self, n_arms: usize, context_dim: usize) {
        self.context_dim = context_dim;
        self.arm_weights = vec![Array1::zeros(context_dim); n_arms];
        self.total_rounds = 0;
    }

    /// Selects an arm: with probability `exploration` a random arm; otherwise the
    /// arm whose linear model predicts the highest reward for this context.
    pub fn select_arm(&mut self, context: &ArrayView1<f64>) -> Result<usize> {
        if self.arm_weights.is_empty() {
            return Err(BanditError::NoArmsAvailable.into());
        }

        if context.len() != self.context_dim {
            return Err(BanditError::BanditComputationFailed(format!(
                "Context dimension mismatch: expected {}, got {}",
                self.context_dim,
                context.len()
            ))
            .into());
        }

        // Derive a per-round seed: re-seeding with a constant on every call
        // would replay the same draw each round.
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed.wrapping_add(self.total_rounds as u64)),
            None => Random::seed(42u64.wrapping_add(self.total_rounds as u64)),
        };

        if rng.gen::<f64>() < self.exploration {
            // Explore: pick a uniformly random arm.
            Ok(rng.gen_range(0..self.arm_weights.len()))
        } else {
            // Exploit: pick the arm with the highest predicted reward.
            let mut best_arm = 0;
            let mut best_reward = f64::NEG_INFINITY;

            for (arm_idx, weights) in self.arm_weights.iter().enumerate() {
                let predicted_reward = weights.dot(context);
                if predicted_reward > best_reward {
                    best_reward = predicted_reward;
                    best_arm = arm_idx;
                }
            }

            Ok(best_arm)
        }
    }

    /// Updates the selected arm's weights with one online gradient step.
    pub fn update(&mut self, arm_idx: usize, context: &ArrayView1<f64>, reward: f64) -> Result<()> {
        if arm_idx >= self.arm_weights.len() {
            return Err(BanditError::InvalidArmIndex {
                arm_idx,
                n_arms: self.arm_weights.len(),
            }
            .into());
        }

        if context.len() != self.context_dim {
            return Err(BanditError::BanditComputationFailed(format!(
                "Context dimension mismatch: expected {}, got {}",
                self.context_dim,
                context.len()
            ))
            .into());
        }

        // Gradient step on the prediction error of the arm's linear model.
        let predicted_reward = self.arm_weights[arm_idx].dot(context);
        let error = reward - predicted_reward;

        for i in 0..self.context_dim {
            self.arm_weights[arm_idx][i] += self.learning_rate * error * context[i];
        }

        self.total_rounds += 1;
        Ok(())
    }

    pub fn get_arm_weights(&self) -> &Vec<Array1<f64>> {
        &self.arm_weights
    }

    /// Predicts the expected reward of every arm for the given context.
    pub fn predict_rewards(&self, context: &ArrayView1<f64>) -> Result<Array1<f64>> {
        if context.len() != self.context_dim {
            return Err(BanditError::BanditComputationFailed(format!(
                "Context dimension mismatch: expected {}, got {}",
                self.context_dim,
                context.len()
            ))
            .into());
        }

        let mut predicted_rewards = Array1::zeros(self.arm_weights.len());
        for (arm_idx, weights) in self.arm_weights.iter().enumerate() {
            predicted_rewards[arm_idx] = weights.dot(context);
        }

        Ok(predicted_rewards)
    }
}

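/// Bandit-based meta-selector that learns which active-learning query
/// strategy performs best.
///
/// A minimal usage sketch (hedged: assumes this type is in scope and the
/// surrounding function returns `sklears_core::error::Result`):
///
/// ```ignore
/// let strategies = vec!["entropy".to_string(), "margin".to_string()];
/// let mut learner = BanditBasedActiveLearning::new(strategies, "ucb".to_string());
/// learner.initialize(None, Some(1.5), None)?; // the confidence slot applies to UCB
/// let idx = learner.select_strategy(None)?;
/// learner.update_strategy(idx, 0.7, None)?; // feed back the observed reward
/// let perf = learner.get_strategy_performance()?;
/// ```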
#[derive(Debug, Clone)]
pub struct BanditBasedActiveLearning {
    /// Names of the candidate query strategies (one bandit arm each).
    pub strategy_names: Vec<String>,
    /// Which bandit to use: "epsilon_greedy", "ucb", "thompson_sampling", or "contextual".
    pub bandit_algorithm: String,
    /// Name of the reward signal used to score strategies.
    pub reward_function: String,
    /// Optional seed forwarded to the underlying bandit.
    pub random_state: Option<u64>,
    epsilon_greedy: Option<EpsilonGreedy>,
    ucb: Option<UpperConfidenceBound>,
    thompson: Option<ThompsonSampling>,
    contextual: Option<ContextualBandit>,
}

impl BanditBasedActiveLearning {
    pub fn new(strategy_names: Vec<String>, bandit_algorithm: String) -> Self {
        Self {
            strategy_names,
            bandit_algorithm,
            reward_function: "accuracy_improvement".to_string(),
            random_state: None,
            epsilon_greedy: None,
            ucb: None,
            thompson: None,
            contextual: None,
        }
    }

    pub fn reward_function(mut self, reward_function: String) -> Self {
        self.reward_function = reward_function;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Initializes the configured bandit. `epsilon`, `confidence`, and
    /// `exploration` apply to the epsilon-greedy, UCB, and contextual
    /// algorithms respectively; unused parameters are ignored.
    pub fn initialize(
        &mut self,
        epsilon: Option<f64>,
        confidence: Option<f64>,
        exploration: Option<f64>,
    ) -> Result<()> {
        let n_arms = self.strategy_names.len();

        match self.bandit_algorithm.as_str() {
            "epsilon_greedy" => {
                let eps = epsilon.unwrap_or(0.1);
                let mut eg = EpsilonGreedy::new(eps)?;
                if let Some(seed) = self.random_state {
                    eg = eg.random_state(seed);
                }
                eg.initialize(n_arms);
                self.epsilon_greedy = Some(eg);
            }
            "ucb" => {
                let conf = confidence.unwrap_or(2.0);
                let mut ucb = UpperConfidenceBound::new(conf)?;
                if let Some(seed) = self.random_state {
                    ucb = ucb.random_state(seed);
                }
                ucb.initialize(n_arms);
                self.ucb = Some(ucb);
            }
            "thompson_sampling" => {
                let mut ts = ThompsonSampling::new();
                if let Some(seed) = self.random_state {
                    ts = ts.random_state(seed);
                }
                ts.initialize(n_arms);
                self.thompson = Some(ts);
            }
            "contextual" => {
                let exp = exploration.unwrap_or(0.1);
                let mut cb = ContextualBandit::new(exp)?;
                if let Some(seed) = self.random_state {
                    cb = cb.random_state(seed);
                }
                // Deferred: the contextual bandit is fully initialized on the
                // first call to `select_strategy`, once the context length is known.
                self.contextual = Some(cb);
            }
            _ => {
                return Err(BanditError::BanditComputationFailed(format!(
                    "Unknown bandit algorithm: {}",
                    self.bandit_algorithm
                ))
                .into())
            }
        }

        Ok(())
    }

    /// Selects the next query strategy and returns its index. A `context` is
    /// required only for the contextual bandit.
    pub fn select_strategy(&mut self, context: Option<&ArrayView1<f64>>) -> Result<usize> {
        match self.bandit_algorithm.as_str() {
            "epsilon_greedy" => {
                if let Some(ref mut eg) = self.epsilon_greedy {
                    eg.select_arm()
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Epsilon-greedy not initialized".to_string(),
                    )
                    .into())
                }
            }
            "ucb" => {
                if let Some(ref mut ucb) = self.ucb {
                    ucb.select_arm()
                } else {
                    Err(
                        BanditError::BanditComputationFailed("UCB not initialized".to_string())
                            .into(),
                    )
                }
            }
            "thompson_sampling" => {
                if let Some(ref mut ts) = self.thompson {
                    ts.select_arm()
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Thompson sampling not initialized".to_string(),
                    )
                    .into())
                }
            }
            "contextual" => {
                if let Some(context) = context {
                    if let Some(ref mut cb) = self.contextual {
                        // Lazy initialization once the context dimension is known.
                        if cb.context_dim == 0 {
                            cb.initialize(self.strategy_names.len(), context.len());
                        }
                        cb.select_arm(context)
                    } else {
                        Err(BanditError::BanditComputationFailed(
                            "Contextual bandit not initialized".to_string(),
                        )
                        .into())
                    }
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Context required for contextual bandit".to_string(),
                    )
                    .into())
                }
            }
            _ => Err(BanditError::BanditComputationFailed(format!(
                "Unknown bandit algorithm: {}",
                self.bandit_algorithm
            ))
            .into()),
        }
    }

    /// Feeds the observed `reward` for `strategy_idx` back into the bandit.
    pub fn update_strategy(
        &mut self,
        strategy_idx: usize,
        reward: f64,
        context: Option<&ArrayView1<f64>>,
    ) -> Result<()> {
        match self.bandit_algorithm.as_str() {
            "epsilon_greedy" => {
                if let Some(ref mut eg) = self.epsilon_greedy {
                    eg.update(strategy_idx, reward)
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Epsilon-greedy not initialized".to_string(),
                    )
                    .into())
                }
            }
            "ucb" => {
                if let Some(ref mut ucb) = self.ucb {
                    ucb.update(strategy_idx, reward)
                } else {
                    Err(
                        BanditError::BanditComputationFailed("UCB not initialized".to_string())
                            .into(),
                    )
                }
            }
            "thompson_sampling" => {
                if let Some(ref mut ts) = self.thompson {
                    ts.update(strategy_idx, reward)
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Thompson sampling not initialized".to_string(),
                    )
                    .into())
                }
            }
            "contextual" => {
                if let Some(context) = context {
                    if let Some(ref mut cb) = self.contextual {
                        cb.update(strategy_idx, context, reward)
                    } else {
                        Err(BanditError::BanditComputationFailed(
                            "Contextual bandit not initialized".to_string(),
                        )
                        .into())
                    }
                } else {
                    Err(BanditError::BanditComputationFailed(
                        "Context required for contextual bandit".to_string(),
                    )
                    .into())
                }
            }
            _ => Err(BanditError::BanditComputationFailed(format!(
                "Unknown bandit algorithm: {}",
                self.bandit_algorithm
            ))
            .into()),
        }
    }

    /// Returns the estimated per-strategy performance, keyed by strategy name.
    pub fn get_strategy_performance(&self) -> Result<HashMap<String, f64>> {
        let mut performance = HashMap::new();

        match self.bandit_algorithm.as_str() {
            "epsilon_greedy" => {
                if let Some(ref eg) = self.epsilon_greedy {
                    let stats = eg.get_arm_statistics();
                    for (idx, (_, avg_reward, _)) in stats.iter().enumerate() {
                        if idx < self.strategy_names.len() {
                            performance.insert(self.strategy_names[idx].clone(), *avg_reward);
                        }
                    }
                }
            }
            "ucb" => {
                if let Some(ref ucb) = self.ucb {
                    let stats = ucb.get_arm_statistics();
                    for (idx, (_, avg_reward, _, _)) in stats.iter().enumerate() {
                        if idx < self.strategy_names.len() {
                            performance.insert(self.strategy_names[idx].clone(), *avg_reward);
                        }
                    }
                }
            }
            "thompson_sampling" => {
                if let Some(ref ts) = self.thompson {
                    let stats = ts.get_arm_statistics();
                    for (idx, (_, _, mean, _)) in stats.iter().enumerate() {
                        if idx < self.strategy_names.len() {
                            performance.insert(self.strategy_names[idx].clone(), *mean);
                        }
                    }
                }
            }
            "contextual" => {
                // Contextual rewards depend on the context, so a single scalar is
                // not well defined; report a neutral placeholder for each strategy.
                for name in self.strategy_names.iter() {
                    performance.insert(name.clone(), 0.5);
                }
            }
            _ => {
                return Err(BanditError::BanditComputationFailed(format!(
                    "Unknown bandit algorithm: {}",
                    self.bandit_algorithm
                ))
                .into())
            }
        }

        Ok(performance)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_epsilon_greedy_creation() {
        let eg = EpsilonGreedy::new(0.1).unwrap();
        assert_eq!(eg.epsilon, 0.1);
        assert_eq!(eg.decay_rate, 0.995);
        assert_eq!(eg.min_epsilon, 0.01);
    }

    #[test]
    fn test_epsilon_greedy_invalid_epsilon() {
        assert!(EpsilonGreedy::new(-0.1).is_err());
        assert!(EpsilonGreedy::new(1.5).is_err());
    }

    #[test]
    fn test_epsilon_greedy_basic_functionality() {
        let mut eg = EpsilonGreedy::new(0.5).unwrap().random_state(42);
        eg.initialize(3);

        // Run a few rounds where arm 0 is the only rewarding arm.
        for _ in 0..10 {
            let arm = eg.select_arm().unwrap();
            assert!(arm < 3);

            let reward = if arm == 0 { 1.0 } else { 0.0 };
            eg.update(arm, reward).unwrap();
        }

        let stats = eg.get_arm_statistics();
        assert_eq!(stats.len(), 3);

        // If arm 0 was ever pulled, it should have the best average reward.
        if stats[0].0 > 0 {
            assert!(stats[0].1 >= stats[1].1 && stats[0].1 >= stats[2].1);
        }
    }

    #[test]
    fn test_upper_confidence_bound_creation() {
        let ucb = UpperConfidenceBound::new(2.0).unwrap();
        assert_eq!(ucb.confidence, 2.0);
    }

    #[test]
    fn test_upper_confidence_bound_invalid_confidence() {
        assert!(UpperConfidenceBound::new(0.0).is_err());
        assert!(UpperConfidenceBound::new(-1.0).is_err());
    }

    #[test]
    fn test_upper_confidence_bound_basic_functionality() {
        let mut ucb = UpperConfidenceBound::new(2.0).unwrap().random_state(42);
        ucb.initialize(3);

        // Run a few rounds where arm 0 pays the highest reward.
        for _ in 0..10 {
            let arm = ucb.select_arm().unwrap();
            assert!(arm < 3);

            let reward = if arm == 0 { 0.8 } else { 0.2 };
            ucb.update(arm, reward).unwrap();
        }

        let stats = ucb.get_arm_statistics();
        assert_eq!(stats.len(), 3);

        // Every arm is played at least once before the UCB rule takes over.
        for (count, _, _, _) in stats.iter() {
            assert!(*count > 0);
        }
    }

    #[test]
    fn test_thompson_sampling_creation() {
        let ts = ThompsonSampling::new();
        assert!(ts.alpha_params.is_empty());
        assert!(ts.beta_params.is_empty());
    }

    #[test]
    fn test_thompson_sampling_basic_functionality() {
        let mut ts = ThompsonSampling::new().random_state(42);
        ts.initialize(3);

        // Run a few rounds where arm 0 is the clear winner.
        for _ in 0..10 {
            let arm = ts.select_arm().unwrap();
            assert!(arm < 3);

            let reward = if arm == 0 { 0.9 } else { 0.1 };
            ts.update(arm, reward).unwrap();
        }

        let stats = ts.get_arm_statistics();
        assert_eq!(stats.len(), 3);

        // Posterior parameters never drop below the Beta(1, 1) prior.
        for (alpha, beta, mean, _) in stats.iter() {
            assert!(*alpha >= 1.0);
            assert!(*beta >= 1.0);
            assert!(*mean >= 0.0 && *mean <= 1.0);
        }
    }

    #[test]
    fn test_contextual_bandit_creation() {
        let cb = ContextualBandit::new(0.1).unwrap();
        assert_eq!(cb.exploration, 0.1);
        assert_eq!(cb.learning_rate, 0.1);
    }

    #[test]
    fn test_contextual_bandit_invalid_exploration() {
        assert!(ContextualBandit::new(-0.1).is_err());
    }

    #[test]
    fn test_contextual_bandit_basic_functionality() {
        let mut cb = ContextualBandit::new(0.1).unwrap().random_state(42);
        cb.initialize(2, 3);

        let context1 = array![1.0, 0.0, 0.0];
        let context2 = array![0.0, 1.0, 0.0];

        // Alternate contexts; each arm is rewarded only in its own context.
        for i in 0..10 {
            let context = if i % 2 == 0 { &context1 } else { &context2 };
            let arm = cb.select_arm(&context.view()).unwrap();
            assert!(arm < 2);

            let reward = if (arm == 0 && i % 2 == 0) || (arm == 1 && i % 2 == 1) {
                1.0
            } else {
                0.0
            };
            cb.update(arm, &context.view(), reward).unwrap();
        }

        let weights = cb.get_arm_weights();
        assert_eq!(weights.len(), 2);
        assert_eq!(weights[0].len(), 3);
        assert_eq!(weights[1].len(), 3);

        // One predicted reward per arm.
        let predicted = cb.predict_rewards(&context1.view()).unwrap();
        assert_eq!(predicted.len(), 2);
    }

    #[test]
    fn test_bandit_based_active_learning() {
        let strategies = vec![
            "entropy".to_string(),
            "margin".to_string(),
            "random".to_string(),
        ];
        let mut bbal =
            BanditBasedActiveLearning::new(strategies.clone(), "epsilon_greedy".to_string())
                .random_state(42);

        bbal.initialize(Some(0.2), None, None).unwrap();

        // Run a few rounds where the first strategy pays the highest reward.
        for _ in 0..10 {
            let strategy_idx = bbal.select_strategy(None).unwrap();
            assert!(strategy_idx < strategies.len());

            let reward = if strategy_idx == 0 { 0.8 } else { 0.3 };
            bbal.update_strategy(strategy_idx, reward, None).unwrap();
        }

        let performance = bbal.get_strategy_performance().unwrap();
        assert_eq!(performance.len(), strategies.len());

        for strategy in strategies.iter() {
            assert!(performance.contains_key(strategy));
        }
    }

    #[test]
    fn test_bandit_based_active_learning_ucb() {
        let strategies = vec!["uncertainty".to_string(), "diversity".to_string()];
        let mut bbal =
            BanditBasedActiveLearning::new(strategies.clone(), "ucb".to_string()).random_state(42);

        bbal.initialize(None, Some(1.5), None).unwrap();

        let strategy_idx = bbal.select_strategy(None).unwrap();
        assert!(strategy_idx < strategies.len());

        bbal.update_strategy(strategy_idx, 0.5, None).unwrap();

        let performance = bbal.get_strategy_performance().unwrap();
        assert_eq!(performance.len(), strategies.len());
    }

    #[test]
    fn test_bandit_based_active_learning_thompson() {
        let strategies = vec!["query1".to_string(), "query2".to_string()];
        let mut bbal =
            BanditBasedActiveLearning::new(strategies.clone(), "thompson_sampling".to_string())
                .random_state(42);

        bbal.initialize(None, None, None).unwrap();

        let strategy_idx = bbal.select_strategy(None).unwrap();
        assert!(strategy_idx < strategies.len());

        bbal.update_strategy(strategy_idx, 0.7, None).unwrap();

        let performance = bbal.get_strategy_performance().unwrap();
        assert_eq!(performance.len(), strategies.len());
    }

    #[test]
    fn test_bandit_based_active_learning_contextual() {
        let strategies = vec![
            "context_strategy1".to_string(),
            "context_strategy2".to_string(),
        ];
        let mut bbal = BanditBasedActiveLearning::new(strategies.clone(), "contextual".to_string())
            .random_state(42);

        bbal.initialize(None, None, Some(0.15)).unwrap();

        let context = array![0.5, 1.0, 0.2];

        let strategy_idx = bbal.select_strategy(Some(&context.view())).unwrap();
        assert!(strategy_idx < strategies.len());

        bbal.update_strategy(strategy_idx, 0.6, Some(&context.view()))
            .unwrap();

        let performance = bbal.get_strategy_performance().unwrap();
        assert_eq!(performance.len(), strategies.len());
    }
}