1#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8use std::collections::HashSet;
9
10use scirs2_core::rand_prelude::SliceRandom;
12
13use super::core::{rng_utils, Sampler, SamplerIterator};
14
/// Criterion the sampler uses to decide which unlabeled samples to query next.
///
/// The default strategy is [`AcquisitionStrategy::UncertaintySampling`].
#[derive(Clone, Debug, Default, PartialEq)]
pub enum AcquisitionStrategy {
    /// Select the unlabeled samples with the highest uncertainty scores.
    #[default]
    UncertaintySampling,

    /// Rank samples by uncertainty multiplied by a logarithmic, index-based weight.
    ExpectedInformationGain,

    /// Shuffle the unlabeled pool, split it into chunks and take the most
    /// uncertain samples from each chunk.
    DiversitySampling {
        /// Number of chunks the budget is spread over; `0` falls back to
        /// plain uncertainty sampling.
        num_clusters: usize,
    },

    /// Spend part of the budget on uncertainty sampling and the rest on
    /// diversity sampling.
    UncertaintyDiversity {
        /// Fraction of the budget given to uncertainty sampling; clamped to
        /// `[0.0, 1.0]` when samples are selected.
        alpha: f64,
    },

    /// Query-by-committee; the sampler currently selects exactly as
    /// [`AcquisitionStrategy::UncertaintySampling`] does.
    QueryByCommittee,

    /// Rank samples by uncertainty multiplied by a linear, index-based weight.
    ExpectedModelChange,
}
73
74#[derive(Clone)]
103pub struct ActiveLearningSampler {
104 uncertainties: Vec<f64>,
105 acquisition_strategy: AcquisitionStrategy,
106 num_samples: usize,
107 budget_per_round: usize,
108 current_round: usize,
109 labeled_indices: HashSet<usize>,
110 generator: Option<u64>,
111}
112
113impl ActiveLearningSampler {
114 pub fn new(
134 dataset_size: usize,
135 acquisition_strategy: AcquisitionStrategy,
136 budget_per_round: usize,
137 ) -> Self {
138 Self {
139 uncertainties: vec![0.0; dataset_size],
140 acquisition_strategy,
141 num_samples: dataset_size,
142 budget_per_round,
143 current_round: 0,
144 labeled_indices: HashSet::new(),
145 generator: None,
146 }
147 }
148
149 pub fn with_initial_labeled(
158 dataset_size: usize,
159 acquisition_strategy: AcquisitionStrategy,
160 budget_per_round: usize,
161 initial_labeled: &[usize],
162 ) -> Self {
163 let mut sampler = Self::new(dataset_size, acquisition_strategy, budget_per_round);
164 for &idx in initial_labeled {
165 sampler.labeled_indices.insert(idx);
166 }
167 sampler
168 }
169
170 pub fn update_uncertainties(&mut self, uncertainties: Vec<f64>) {
180 assert!(uncertainties.len() == self.num_samples, "assertion failed");
181 self.uncertainties = uncertainties;
182 }
183
184 pub fn add_labeled_samples(&mut self, indices: &[usize]) {
190 for &idx in indices {
191 if idx < self.num_samples {
192 self.labeled_indices.insert(idx);
193 }
194 }
195 self.current_round += 1;
196 }
197
198 pub fn remove_labeled_samples(&mut self, indices: &[usize]) {
204 for &idx in indices {
205 self.labeled_indices.remove(&idx);
206 }
207 }
208
209 pub fn with_generator(mut self, seed: u64) -> Self {
215 self.generator = Some(seed);
216 self
217 }
218
219 pub fn current_round(&self) -> usize {
221 self.current_round
222 }
223
224 pub fn budget_per_round(&self) -> usize {
226 self.budget_per_round
227 }
228
229 pub fn strategy(&self) -> &AcquisitionStrategy {
231 &self.acquisition_strategy
232 }
233
234 pub fn num_labeled(&self) -> usize {
236 self.labeled_indices.len()
237 }
238
239 pub fn num_unlabeled(&self) -> usize {
241 self.num_samples - self.labeled_indices.len()
242 }
243
244 pub fn labeled_indices(&self) -> Vec<usize> {
246 self.labeled_indices.iter().copied().collect()
247 }
248
249 pub fn get_unlabeled_indices(&self) -> Vec<usize> {
251 (0..self.num_samples)
252 .filter(|idx| !self.labeled_indices.contains(idx))
253 .collect()
254 }
255
256 pub fn is_labeled(&self, index: usize) -> bool {
258 self.labeled_indices.contains(&index)
259 }
260
261 pub fn set_strategy(&mut self, strategy: AcquisitionStrategy) {
267 self.acquisition_strategy = strategy;
268 }
269
270 pub fn set_budget(&mut self, budget: usize) {
276 self.budget_per_round = budget;
277 }
278
279 pub fn reset(&mut self) {
281 self.labeled_indices.clear();
282 self.current_round = 0;
283 }
284
285 pub fn active_learning_stats(&self) -> ActiveLearningStats {
287 let unlabeled_count = self.num_unlabeled();
288 let available_budget = self.budget_per_round.min(unlabeled_count);
289
290 ActiveLearningStats {
291 current_round: self.current_round,
292 num_labeled: self.num_labeled(),
293 num_unlabeled: unlabeled_count,
294 total_samples: self.num_samples,
295 budget_per_round: self.budget_per_round,
296 available_budget,
297 labeling_ratio: self.num_labeled() as f64 / self.num_samples as f64,
298 }
299 }
300
301 fn select_samples(&self) -> Vec<usize> {
303 let unlabeled = self.get_unlabeled_indices();
304 let budget = self.budget_per_round.min(unlabeled.len());
305
306 if budget == 0 {
307 return Vec::new();
308 }
309
310 match &self.acquisition_strategy {
311 AcquisitionStrategy::UncertaintySampling => {
312 self.uncertainty_sampling(&unlabeled, budget)
313 }
314 AcquisitionStrategy::ExpectedInformationGain => {
315 self.information_gain_sampling(&unlabeled, budget)
316 }
317 AcquisitionStrategy::DiversitySampling { num_clusters } => {
318 self.diversity_sampling(&unlabeled, budget, *num_clusters)
319 }
320 AcquisitionStrategy::UncertaintyDiversity { alpha } => {
321 self.uncertainty_diversity_sampling(&unlabeled, budget, *alpha)
322 }
323 AcquisitionStrategy::QueryByCommittee => self.query_by_committee(&unlabeled, budget),
324 AcquisitionStrategy::ExpectedModelChange => {
325 self.expected_model_change(&unlabeled, budget)
326 }
327 }
328 }
329
330 fn uncertainty_sampling(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
332 let mut scored: Vec<_> = unlabeled
333 .iter()
334 .map(|&idx| (idx, self.uncertainties[idx]))
335 .collect();
336
337 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
339
340 scored
341 .into_iter()
342 .take(budget)
343 .map(|(idx, _)| idx)
344 .collect()
345 }
346
347 fn information_gain_sampling(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
349 let mut scored: Vec<_> = unlabeled
352 .iter()
353 .map(|&idx| {
354 let ig = self.uncertainties[idx] * (1.0 + (idx as f64).ln());
355 (idx, ig)
356 })
357 .collect();
358
359 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
360 scored
361 .into_iter()
362 .take(budget)
363 .map(|(idx, _)| idx)
364 .collect()
365 }
366
367 fn diversity_sampling(
369 &self,
370 unlabeled: &[usize],
371 budget: usize,
372 num_clusters: usize,
373 ) -> Vec<usize> {
374 if num_clusters == 0 {
375 return self.uncertainty_sampling(unlabeled, budget);
376 }
377
378 let mut rng = rng_utils::create_rng(self.generator);
380
381 let mut indices = unlabeled.to_vec();
383 indices.shuffle(&mut rng);
384
385 let cluster_size = (unlabeled.len() / num_clusters).max(1);
386 let base_samples_per_cluster = budget / num_clusters;
387 let extra_samples = budget % num_clusters;
388
389 let mut selected = Vec::new();
390 let mut cluster_idx = 0;
391
392 for cluster_start in (0..indices.len()).step_by(cluster_size) {
393 let cluster_end = (cluster_start + cluster_size).min(indices.len());
394 let cluster = &indices[cluster_start..cluster_end];
395
396 let cluster_samples_count = if cluster_idx < extra_samples {
398 base_samples_per_cluster + 1
399 } else {
400 base_samples_per_cluster
401 };
402
403 if cluster_samples_count == 0 {
404 cluster_idx += 1;
405 continue;
406 }
407
408 let mut cluster_scored: Vec<_> = cluster
410 .iter()
411 .map(|&idx| (idx, self.uncertainties[idx]))
412 .collect();
413
414 cluster_scored
415 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
416
417 let cluster_samples = cluster_scored
418 .into_iter()
419 .take(cluster_samples_count)
420 .map(|(idx, _)| idx);
421
422 selected.extend(cluster_samples);
423 cluster_idx += 1;
424
425 if selected.len() >= budget {
426 break;
427 }
428 }
429
430 selected.truncate(budget);
431 selected
432 }
433
434 fn uncertainty_diversity_sampling(
436 &self,
437 unlabeled: &[usize],
438 budget: usize,
439 alpha: f64,
440 ) -> Vec<usize> {
441 let alpha = alpha.clamp(0.0, 1.0);
442
443 let uncertainty_count = (budget as f64 * alpha) as usize;
445 let diversity_count = budget - uncertainty_count;
446
447 let mut selected = self.uncertainty_sampling(unlabeled, uncertainty_count);
448
449 if diversity_count > 0 {
450 let remaining: Vec<usize> = unlabeled
452 .iter()
453 .filter(|idx| !selected.contains(idx))
454 .copied()
455 .collect();
456
457 let diversity_samples = self.diversity_sampling(&remaining, diversity_count, 3);
458 selected.extend(diversity_samples);
459 }
460
461 selected
462 }
463
464 fn query_by_committee(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
466 self.uncertainty_sampling(unlabeled, budget)
469 }
470
471 fn expected_model_change(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
473 let mut scored: Vec<_> = unlabeled
475 .iter()
476 .map(|&idx| {
477 let change_score =
478 self.uncertainties[idx] * (1.0 + idx as f64 / unlabeled.len() as f64);
479 (idx, change_score)
480 })
481 .collect();
482
483 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
484 scored
485 .into_iter()
486 .take(budget)
487 .map(|(idx, _)| idx)
488 .collect()
489 }
490}
491
492impl Sampler for ActiveLearningSampler {
493 type Iter = SamplerIterator;
494
495 fn iter(&self) -> Self::Iter {
496 let indices = self.select_samples();
497 SamplerIterator::new(indices)
498 }
499
500 fn len(&self) -> usize {
501 let unlabeled_count = self.get_unlabeled_indices().len();
502 self.budget_per_round.min(unlabeled_count)
503 }
504}
505
/// Snapshot of an active-learning sampler's progress, as produced by
/// `ActiveLearningSampler::active_learning_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct ActiveLearningStats {
    /// Round counter of the sampler at the time of the snapshot.
    pub current_round: usize,
    /// Number of labeled samples.
    pub num_labeled: usize,
    /// Number of samples not labeled yet.
    pub num_unlabeled: usize,
    /// Total number of samples in the dataset.
    pub total_samples: usize,
    /// Configured maximum number of samples per round.
    pub budget_per_round: usize,
    /// `budget_per_round` capped by `num_unlabeled`.
    pub available_budget: usize,
    /// `num_labeled` divided by `total_samples`.
    pub labeling_ratio: f64,
}
524
525pub fn uncertainty_sampler(
535 dataset_size: usize,
536 budget_per_round: usize,
537 seed: Option<u64>,
538) -> ActiveLearningSampler {
539 let mut sampler = ActiveLearningSampler::new(
540 dataset_size,
541 AcquisitionStrategy::UncertaintySampling,
542 budget_per_round,
543 );
544 if let Some(s) = seed {
545 sampler = sampler.with_generator(s);
546 }
547 sampler
548}
549
550pub fn diversity_sampler(
561 dataset_size: usize,
562 budget_per_round: usize,
563 num_clusters: usize,
564 seed: Option<u64>,
565) -> ActiveLearningSampler {
566 let mut sampler = ActiveLearningSampler::new(
567 dataset_size,
568 AcquisitionStrategy::DiversitySampling { num_clusters },
569 budget_per_round,
570 );
571 if let Some(s) = seed {
572 sampler = sampler.with_generator(s);
573 }
574 sampler
575}
576
577pub fn uncertainty_diversity_sampler(
589 dataset_size: usize,
590 budget_per_round: usize,
591 alpha: f64,
592 seed: Option<u64>,
593) -> ActiveLearningSampler {
594 let mut sampler = ActiveLearningSampler::new(
595 dataset_size,
596 AcquisitionStrategy::UncertaintyDiversity { alpha },
597 budget_per_round,
598 );
599 if let Some(s) = seed {
600 sampler = sampler.with_generator(s);
601 }
602 sampler
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction, one selection round and labeling bookkeeping.
    #[test]
    fn test_active_learning_sampler_basic() {
        let strategy = AcquisitionStrategy::UncertaintySampling;
        let mut sampler = ActiveLearningSampler::new(100, strategy, 10).with_generator(42);

        assert_eq!(sampler.num_samples, 100);
        assert_eq!(sampler.budget_per_round(), 10);
        assert_eq!(sampler.current_round(), 0);
        assert_eq!(sampler.num_labeled(), 0);
        assert_eq!(sampler.num_unlabeled(), 100);
        assert_eq!(sampler.generator, Some(42));

        // Uncertainty grows with the index, so the top ten are 90..100.
        sampler.update_uncertainties((0..100).map(|i| i as f64 / 100.0).collect());

        let picked: Vec<usize> = sampler.iter().collect();
        assert_eq!(picked.len(), 10);
        assert!(picked.iter().all(|&idx| idx >= 90));

        sampler.add_labeled_samples(&picked);
        assert_eq!(sampler.num_labeled(), 10);
        assert_eq!(sampler.num_unlabeled(), 90);
        assert_eq!(sampler.current_round(), 1);
    }

    /// Every strategy returns a full budget; all but pure diversity sampling
    /// must pick at least one of the high-uncertainty samples.
    #[test]
    fn test_acquisition_strategies() {
        let dataset_size = 50;
        let budget = 5;

        let strategies = [
            AcquisitionStrategy::UncertaintySampling,
            AcquisitionStrategy::ExpectedInformationGain,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            AcquisitionStrategy::UncertaintyDiversity { alpha: 0.5 },
            AcquisitionStrategy::QueryByCommittee,
            AcquisitionStrategy::ExpectedModelChange,
        ];

        for strategy in strategies {
            let is_diversity = matches!(strategy, AcquisitionStrategy::DiversitySampling { .. });
            let mut sampler =
                ActiveLearningSampler::new(dataset_size, strategy, budget).with_generator(42);

            // Only the first ten samples are highly uncertain.
            let scores: Vec<f64> = (0..dataset_size)
                .map(|i| if i < 10 { 0.9 } else { 0.1 })
                .collect();
            sampler.update_uncertainties(scores);

            let picked: Vec<usize> = sampler.iter().collect();
            assert_eq!(picked.len(), budget);

            if is_diversity {
                assert!(picked.iter().all(|&idx| idx < dataset_size));
            } else {
                assert!(picked.iter().any(|&idx| idx < 10));
            }
        }
    }

    /// Pre-labeled indices are tracked and never selected again.
    #[test]
    fn test_active_learning_with_initial_labeled() {
        let seeded = [0usize, 1, 2, 3, 4];
        let mut sampler = ActiveLearningSampler::with_initial_labeled(
            100,
            AcquisitionStrategy::UncertaintySampling,
            5,
            &seeded,
        );

        assert_eq!(sampler.num_labeled(), 5);
        assert_eq!(sampler.num_unlabeled(), 95);
        assert!(seeded.iter().all(|&idx| sampler.is_labeled(idx)));

        sampler.update_uncertainties(vec![0.5; 100]);

        let picked: Vec<usize> = sampler.iter().collect();
        assert_eq!(picked.len(), 5);
        assert!(picked.iter().all(|idx| !seeded.contains(idx)));
    }

    /// The hybrid strategy mixes high-uncertainty picks with other picks.
    #[test]
    fn test_uncertainty_diversity_sampling() {
        let strategy = AcquisitionStrategy::UncertaintyDiversity { alpha: 0.6 };
        let mut sampler = ActiveLearningSampler::new(20, strategy, 10).with_generator(42);

        sampler.update_uncertainties((0..20).map(|i| i as f64 / 20.0).collect());

        let picked: Vec<usize> = sampler.iter().collect();
        assert_eq!(picked.len(), 10);

        let top_picks = picked.iter().filter(|&&idx| idx >= 15).count();
        assert!(top_picks > 0);
        assert!(top_picks < picked.len());
    }

    /// Diversity sampling fills the budget with valid indices.
    #[test]
    fn test_diversity_sampling() {
        let strategy = AcquisitionStrategy::DiversitySampling { num_clusters: 3 };
        let mut sampler = ActiveLearningSampler::new(30, strategy, 9).with_generator(42);

        sampler.update_uncertainties(vec![0.5; 30]);

        let picked: Vec<usize> = sampler.iter().collect();
        assert_eq!(picked.len(), 9);
        assert!(picked.iter().all(|&idx| idx < 30));
    }

    /// Statistics before and after one labeling round.
    #[test]
    fn test_active_learning_stats() {
        let mut sampler =
            ActiveLearningSampler::new(100, AcquisitionStrategy::UncertaintySampling, 15);

        let before = sampler.active_learning_stats();
        assert_eq!(before.current_round, 0);
        assert_eq!(before.num_labeled, 0);
        assert_eq!(before.num_unlabeled, 100);
        assert_eq!(before.total_samples, 100);
        assert_eq!(before.budget_per_round, 15);
        assert_eq!(before.available_budget, 15);
        assert_eq!(before.labeling_ratio, 0.0);

        sampler.add_labeled_samples(&[0, 1, 2, 3, 4]);

        let after = sampler.active_learning_stats();
        assert_eq!(after.current_round, 1);
        assert_eq!(after.num_labeled, 5);
        assert_eq!(after.num_unlabeled, 95);
        assert_eq!(after.labeling_ratio, 0.05);
    }

    /// Accessors and mutators of the sampler.
    #[test]
    fn test_sampler_methods() {
        let mut sampler =
            ActiveLearningSampler::new(50, AcquisitionStrategy::UncertaintySampling, 10);

        assert_eq!(sampler.get_unlabeled_indices().len(), 50);
        assert!(sampler.labeled_indices().is_empty());

        sampler.add_labeled_samples(&[5, 15, 25]);
        assert_eq!(sampler.num_labeled(), 3);
        for idx in [5, 15, 25] {
            assert!(sampler.is_labeled(idx));
        }
        assert!(!sampler.is_labeled(0));

        let labeled = sampler.labeled_indices();
        assert_eq!(labeled.len(), 3);
        for idx in [5, 15, 25] {
            assert!(labeled.contains(&idx));
        }

        sampler.remove_labeled_samples(&[15]);
        assert_eq!(sampler.num_labeled(), 2);
        assert!(!sampler.is_labeled(15));

        sampler.set_strategy(AcquisitionStrategy::DiversitySampling { num_clusters: 4 });
        assert!(matches!(
            sampler.strategy(),
            AcquisitionStrategy::DiversitySampling { num_clusters: 4 }
        ));

        sampler.set_budget(5);
        assert_eq!(sampler.budget_per_round(), 5);

        sampler.reset();
        assert_eq!(sampler.num_labeled(), 0);
        assert_eq!(sampler.current_round(), 0);
    }

    /// The convenience constructors configure the expected strategy.
    #[test]
    fn test_convenience_functions() {
        let by_uncertainty = uncertainty_sampler(100, 10, Some(42));
        assert!(matches!(
            by_uncertainty.strategy(),
            AcquisitionStrategy::UncertaintySampling
        ));
        assert_eq!(by_uncertainty.budget_per_round(), 10);

        let by_diversity = diversity_sampler(100, 10, 5, Some(42));
        assert!(matches!(
            by_diversity.strategy(),
            AcquisitionStrategy::DiversitySampling { num_clusters: 5 }
        ));

        let hybrid = uncertainty_diversity_sampler(100, 10, 0.7, Some(42));
        assert!(matches!(
            hybrid.strategy(),
            AcquisitionStrategy::UncertaintyDiversity { alpha } if (*alpha - 0.7).abs() < f64::EPSILON
        ));
    }

    /// Zero budget, exhausted pool, oversized budget, out-of-range alpha and
    /// zero clusters.
    #[test]
    fn test_edge_cases() {
        // Zero budget selects nothing.
        let mut sampler =
            ActiveLearningSampler::new(10, AcquisitionStrategy::UncertaintySampling, 0);
        assert_eq!(sampler.len(), 0);
        assert!(sampler.iter().collect::<Vec<usize>>().is_empty());

        // A fully labeled pool selects nothing either.
        sampler.set_budget(5);
        let everything: Vec<usize> = (0..10).collect();
        sampler.add_labeled_samples(&everything);
        assert_eq!(sampler.num_unlabeled(), 0);
        assert_eq!(sampler.len(), 0);

        // The budget is capped by the dataset size.
        let large_budget =
            ActiveLearningSampler::new(5, AcquisitionStrategy::UncertaintySampling, 10);
        assert_eq!(large_budget.len(), 5);

        // An alpha above 1.0 still yields a full budget.
        let mut clamped = ActiveLearningSampler::new(
            10,
            AcquisitionStrategy::UncertaintyDiversity { alpha: 1.5 },
            5,
        );
        clamped.update_uncertainties(vec![0.5; 10]);
        assert_eq!(clamped.iter().collect::<Vec<usize>>().len(), 5);

        // Zero clusters still yields a full budget.
        let mut zero_clusters = ActiveLearningSampler::new(
            10,
            AcquisitionStrategy::DiversitySampling { num_clusters: 0 },
            3,
        );
        zero_clusters.update_uncertainties(vec![0.5; 10]);
        assert_eq!(zero_clusters.iter().collect::<Vec<usize>>().len(), 3);
    }

    /// A score vector of the wrong length is rejected.
    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_update_uncertainties_wrong_size() {
        let mut sampler =
            ActiveLearningSampler::new(10, AcquisitionStrategy::UncertaintySampling, 5);
        sampler.update_uncertainties(vec![0.5; 5]);
    }

    /// `PartialEq` distinguishes variants and compares payloads.
    #[test]
    fn test_acquisition_strategy_equality() {
        let plain = AcquisitionStrategy::UncertaintySampling;
        assert_eq!(plain, AcquisitionStrategy::UncertaintySampling);

        let clustered = AcquisitionStrategy::DiversitySampling { num_clusters: 3 };
        assert_eq!(
            clustered,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 }
        );

        assert_ne!(plain, AcquisitionStrategy::ExpectedInformationGain);
    }

    /// The default strategy is uncertainty sampling.
    #[test]
    fn test_acquisition_strategy_default() {
        let default_strategy = AcquisitionStrategy::default();
        assert_eq!(default_strategy, AcquisitionStrategy::UncertaintySampling);
    }

    /// Two samplers with the same seed select the same indices.
    #[test]
    fn test_reproducibility() {
        let build = || {
            ActiveLearningSampler::new(
                20,
                AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
                5,
            )
            .with_generator(123)
        };
        let mut first = build();
        let mut second = build();

        let scores = vec![0.5; 20];
        first.update_uncertainties(scores.clone());
        second.update_uncertainties(scores);

        let first_picks: Vec<usize> = first.iter().collect();
        let second_picks: Vec<usize> = second.iter().collect();
        assert_eq!(first_picks, second_picks);
    }
}