1#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8use std::collections::HashSet;
9
10use scirs2_core::rand_prelude::SliceRandom;
12use scirs2_core::random::{Random, SeedableRng};
13
14use super::core::{rng_utils, Sampler, SamplerIterator};
15
/// Rule used by `ActiveLearningSampler` to rank unlabeled samples in each round.
///
/// Every variant is currently driven by the per-sample uncertainty scores held by the
/// sampler; some additionally weight by sample index or use a seeded random grouping.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum AcquisitionStrategy {
    /// Take the unlabeled samples with the highest uncertainty scores (the default).
    #[default]
    UncertaintySampling,

    /// Rank by an information-gain style score derived from the uncertainty scores.
    ExpectedInformationGain,

    /// Shuffle the unlabeled pool, cut it into groups sized from `num_clusters`, and
    /// take the most uncertain samples from each group. `num_clusters == 0` falls back
    /// to plain uncertainty sampling.
    DiversitySampling { num_clusters: usize },

    /// Fill roughly `alpha` (clamped to `[0, 1]`, rounded down) of the budget with
    /// uncertainty sampling and the remainder with diversity sampling.
    UncertaintyDiversity { alpha: f64 },

    /// Query-by-committee acquisition (currently backed by uncertainty sampling).
    QueryByCommittee,

    /// Expected-model-change acquisition.
    ExpectedModelChange,
}
74
75#[derive(Clone)]
104pub struct ActiveLearningSampler {
105 uncertainties: Vec<f64>,
106 acquisition_strategy: AcquisitionStrategy,
107 num_samples: usize,
108 budget_per_round: usize,
109 current_round: usize,
110 labeled_indices: HashSet<usize>,
111 generator: Option<u64>,
112}
113
114impl ActiveLearningSampler {
115 pub fn new(
135 dataset_size: usize,
136 acquisition_strategy: AcquisitionStrategy,
137 budget_per_round: usize,
138 ) -> Self {
139 Self {
140 uncertainties: vec![0.0; dataset_size],
141 acquisition_strategy,
142 num_samples: dataset_size,
143 budget_per_round,
144 current_round: 0,
145 labeled_indices: HashSet::new(),
146 generator: None,
147 }
148 }
149
150 pub fn with_initial_labeled(
159 dataset_size: usize,
160 acquisition_strategy: AcquisitionStrategy,
161 budget_per_round: usize,
162 initial_labeled: &[usize],
163 ) -> Self {
164 let mut sampler = Self::new(dataset_size, acquisition_strategy, budget_per_round);
165 for &idx in initial_labeled {
166 sampler.labeled_indices.insert(idx);
167 }
168 sampler
169 }
170
171 pub fn update_uncertainties(&mut self, uncertainties: Vec<f64>) {
181 assert!(uncertainties.len() == self.num_samples, "assertion failed");
182 self.uncertainties = uncertainties;
183 }
184
185 pub fn add_labeled_samples(&mut self, indices: &[usize]) {
191 for &idx in indices {
192 if idx < self.num_samples {
193 self.labeled_indices.insert(idx);
194 }
195 }
196 self.current_round += 1;
197 }
198
199 pub fn remove_labeled_samples(&mut self, indices: &[usize]) {
205 for &idx in indices {
206 self.labeled_indices.remove(&idx);
207 }
208 }
209
210 pub fn with_generator(mut self, seed: u64) -> Self {
216 self.generator = Some(seed);
217 self
218 }
219
220 pub fn current_round(&self) -> usize {
222 self.current_round
223 }
224
225 pub fn budget_per_round(&self) -> usize {
227 self.budget_per_round
228 }
229
230 pub fn strategy(&self) -> &AcquisitionStrategy {
232 &self.acquisition_strategy
233 }
234
235 pub fn num_labeled(&self) -> usize {
237 self.labeled_indices.len()
238 }
239
240 pub fn num_unlabeled(&self) -> usize {
242 self.num_samples - self.labeled_indices.len()
243 }
244
245 pub fn labeled_indices(&self) -> Vec<usize> {
247 self.labeled_indices.iter().copied().collect()
248 }
249
250 pub fn get_unlabeled_indices(&self) -> Vec<usize> {
252 (0..self.num_samples)
253 .filter(|idx| !self.labeled_indices.contains(idx))
254 .collect()
255 }
256
257 pub fn is_labeled(&self, index: usize) -> bool {
259 self.labeled_indices.contains(&index)
260 }
261
262 pub fn set_strategy(&mut self, strategy: AcquisitionStrategy) {
268 self.acquisition_strategy = strategy;
269 }
270
271 pub fn set_budget(&mut self, budget: usize) {
277 self.budget_per_round = budget;
278 }
279
280 pub fn reset(&mut self) {
282 self.labeled_indices.clear();
283 self.current_round = 0;
284 }
285
286 pub fn active_learning_stats(&self) -> ActiveLearningStats {
288 let unlabeled_count = self.num_unlabeled();
289 let available_budget = self.budget_per_round.min(unlabeled_count);
290
291 ActiveLearningStats {
292 current_round: self.current_round,
293 num_labeled: self.num_labeled(),
294 num_unlabeled: unlabeled_count,
295 total_samples: self.num_samples,
296 budget_per_round: self.budget_per_round,
297 available_budget,
298 labeling_ratio: self.num_labeled() as f64 / self.num_samples as f64,
299 }
300 }
301
302 fn select_samples(&self) -> Vec<usize> {
304 let unlabeled = self.get_unlabeled_indices();
305 let budget = self.budget_per_round.min(unlabeled.len());
306
307 if budget == 0 {
308 return Vec::new();
309 }
310
311 match &self.acquisition_strategy {
312 AcquisitionStrategy::UncertaintySampling => {
313 self.uncertainty_sampling(&unlabeled, budget)
314 }
315 AcquisitionStrategy::ExpectedInformationGain => {
316 self.information_gain_sampling(&unlabeled, budget)
317 }
318 AcquisitionStrategy::DiversitySampling { num_clusters } => {
319 self.diversity_sampling(&unlabeled, budget, *num_clusters)
320 }
321 AcquisitionStrategy::UncertaintyDiversity { alpha } => {
322 self.uncertainty_diversity_sampling(&unlabeled, budget, *alpha)
323 }
324 AcquisitionStrategy::QueryByCommittee => self.query_by_committee(&unlabeled, budget),
325 AcquisitionStrategy::ExpectedModelChange => {
326 self.expected_model_change(&unlabeled, budget)
327 }
328 }
329 }
330
331 fn uncertainty_sampling(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
333 let mut scored: Vec<_> = unlabeled
334 .iter()
335 .map(|&idx| (idx, self.uncertainties[idx]))
336 .collect();
337
338 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
340
341 scored
342 .into_iter()
343 .take(budget)
344 .map(|(idx, _)| idx)
345 .collect()
346 }
347
348 fn information_gain_sampling(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
350 let mut scored: Vec<_> = unlabeled
353 .iter()
354 .map(|&idx| {
355 let ig = self.uncertainties[idx] * (1.0 + (idx as f64).ln());
356 (idx, ig)
357 })
358 .collect();
359
360 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
361 scored
362 .into_iter()
363 .take(budget)
364 .map(|(idx, _)| idx)
365 .collect()
366 }
367
368 fn diversity_sampling(
370 &self,
371 unlabeled: &[usize],
372 budget: usize,
373 num_clusters: usize,
374 ) -> Vec<usize> {
375 if num_clusters == 0 {
376 return self.uncertainty_sampling(unlabeled, budget);
377 }
378
379 let mut rng = rng_utils::create_rng(self.generator);
381
382 let mut indices = unlabeled.to_vec();
384 indices.shuffle(&mut rng);
385
386 let cluster_size = (unlabeled.len() / num_clusters).max(1);
387 let base_samples_per_cluster = budget / num_clusters;
388 let extra_samples = budget % num_clusters;
389
390 let mut selected = Vec::new();
391 let mut cluster_idx = 0;
392
393 for cluster_start in (0..indices.len()).step_by(cluster_size) {
394 let cluster_end = (cluster_start + cluster_size).min(indices.len());
395 let cluster = &indices[cluster_start..cluster_end];
396
397 let cluster_samples_count = if cluster_idx < extra_samples {
399 base_samples_per_cluster + 1
400 } else {
401 base_samples_per_cluster
402 };
403
404 if cluster_samples_count == 0 {
405 cluster_idx += 1;
406 continue;
407 }
408
409 let mut cluster_scored: Vec<_> = cluster
411 .iter()
412 .map(|&idx| (idx, self.uncertainties[idx]))
413 .collect();
414
415 cluster_scored
416 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
417
418 let cluster_samples = cluster_scored
419 .into_iter()
420 .take(cluster_samples_count)
421 .map(|(idx, _)| idx);
422
423 selected.extend(cluster_samples);
424 cluster_idx += 1;
425
426 if selected.len() >= budget {
427 break;
428 }
429 }
430
431 selected.truncate(budget);
432 selected
433 }
434
435 fn uncertainty_diversity_sampling(
437 &self,
438 unlabeled: &[usize],
439 budget: usize,
440 alpha: f64,
441 ) -> Vec<usize> {
442 let alpha = alpha.clamp(0.0, 1.0);
443
444 let uncertainty_count = (budget as f64 * alpha) as usize;
446 let diversity_count = budget - uncertainty_count;
447
448 let mut selected = self.uncertainty_sampling(unlabeled, uncertainty_count);
449
450 if diversity_count > 0 {
451 let remaining: Vec<usize> = unlabeled
453 .iter()
454 .filter(|idx| !selected.contains(idx))
455 .copied()
456 .collect();
457
458 let diversity_samples = self.diversity_sampling(&remaining, diversity_count, 3);
459 selected.extend(diversity_samples);
460 }
461
462 selected
463 }
464
465 fn query_by_committee(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
467 self.uncertainty_sampling(unlabeled, budget)
470 }
471
472 fn expected_model_change(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
474 let mut scored: Vec<_> = unlabeled
476 .iter()
477 .map(|&idx| {
478 let change_score =
479 self.uncertainties[idx] * (1.0 + idx as f64 / unlabeled.len() as f64);
480 (idx, change_score)
481 })
482 .collect();
483
484 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
485 scored
486 .into_iter()
487 .take(budget)
488 .map(|(idx, _)| idx)
489 .collect()
490 }
491}
492
493impl Sampler for ActiveLearningSampler {
494 type Iter = SamplerIterator;
495
496 fn iter(&self) -> Self::Iter {
497 let indices = self.select_samples();
498 SamplerIterator::new(indices)
499 }
500
501 fn len(&self) -> usize {
502 let unlabeled_count = self.get_unlabeled_indices().len();
503 self.budget_per_round.min(unlabeled_count)
504 }
505}
506
/// Snapshot of an active-learning sampler's labeling progress.
#[derive(Clone, Debug, PartialEq)]
pub struct ActiveLearningStats {
    /// Number of completed labeling rounds.
    pub current_round: usize,
    /// Number of samples already labeled.
    pub num_labeled: usize,
    /// Number of samples still unlabeled.
    pub num_unlabeled: usize,
    /// Dataset size.
    pub total_samples: usize,
    /// Configured per-round budget.
    pub budget_per_round: usize,
    /// Budget that can actually be spent this round: `min(budget_per_round, num_unlabeled)`.
    pub available_budget: usize,
    /// Fraction of the dataset that is labeled.
    pub labeling_ratio: f64,
}
525
526pub fn uncertainty_sampler(
536 dataset_size: usize,
537 budget_per_round: usize,
538 seed: Option<u64>,
539) -> ActiveLearningSampler {
540 let mut sampler = ActiveLearningSampler::new(
541 dataset_size,
542 AcquisitionStrategy::UncertaintySampling,
543 budget_per_round,
544 );
545 if let Some(s) = seed {
546 sampler = sampler.with_generator(s);
547 }
548 sampler
549}
550
551pub fn diversity_sampler(
562 dataset_size: usize,
563 budget_per_round: usize,
564 num_clusters: usize,
565 seed: Option<u64>,
566) -> ActiveLearningSampler {
567 let mut sampler = ActiveLearningSampler::new(
568 dataset_size,
569 AcquisitionStrategy::DiversitySampling { num_clusters },
570 budget_per_round,
571 );
572 if let Some(s) = seed {
573 sampler = sampler.with_generator(s);
574 }
575 sampler
576}
577
578pub fn uncertainty_diversity_sampler(
590 dataset_size: usize,
591 budget_per_round: usize,
592 alpha: f64,
593 seed: Option<u64>,
594) -> ActiveLearningSampler {
595 let mut sampler = ActiveLearningSampler::new(
596 dataset_size,
597 AcquisitionStrategy::UncertaintyDiversity { alpha },
598 budget_per_round,
599 );
600 if let Some(s) = seed {
601 sampler = sampler.with_generator(s);
602 }
603 sampler
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic bookkeeping plus one uncertainty-sampling round on a linear score ramp.
    #[test]
    fn test_active_learning_sampler_basic() {
        let mut sampler =
            ActiveLearningSampler::new(100, AcquisitionStrategy::UncertaintySampling, 10)
                .with_generator(42);

        assert_eq!(sampler.num_samples, 100);
        assert_eq!(sampler.budget_per_round(), 10);
        assert_eq!(sampler.current_round(), 0);
        assert_eq!(sampler.num_labeled(), 0);
        assert_eq!(sampler.num_unlabeled(), 100);
        assert_eq!(sampler.generator, Some(42));

        // Score i is i / 100, so the ten most uncertain samples are 90..100.
        let uncertainties: Vec<f64> = (0..100).map(|i| i as f64 / 100.0).collect();
        sampler.update_uncertainties(uncertainties);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 10);
        for &idx in &indices {
            assert!(idx >= 90);
        }

        // Labeling the selection moves it out of the pool and advances the round.
        sampler.add_labeled_samples(&indices);
        assert_eq!(sampler.num_labeled(), 10);
        assert_eq!(sampler.num_unlabeled(), 90);
        assert_eq!(sampler.current_round(), 1);
    }

    /// Every strategy fills the budget; non-random ones favour uncertain samples.
    #[test]
    fn test_acquisition_strategies() {
        let dataset_size = 50;
        let budget = 5;

        let strategies = vec![
            AcquisitionStrategy::UncertaintySampling,
            AcquisitionStrategy::ExpectedInformationGain,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            AcquisitionStrategy::UncertaintyDiversity { alpha: 0.5 },
            AcquisitionStrategy::QueryByCommittee,
            AcquisitionStrategy::ExpectedModelChange,
        ];

        for strategy in strategies {
            let mut sampler = ActiveLearningSampler::new(dataset_size, strategy.clone(), budget)
                .with_generator(42);

            // Samples 0..10 are highly uncertain; the rest are not.
            let uncertainties: Vec<f64> = (0..dataset_size)
                .map(|i| if i < 10 { 0.9 } else { 0.1 })
                .collect();
            sampler.update_uncertainties(uncertainties);

            let indices: Vec<usize> = sampler.iter().collect();
            assert_eq!(indices.len(), budget);

            match strategy {
                // Random grouping gives no guarantee about which samples are picked.
                AcquisitionStrategy::DiversitySampling { .. } => {
                    assert!(indices.iter().all(|&idx| idx < dataset_size));
                }
                _ => {
                    let high_uncertainty_count =
                        indices.iter().filter(|&&idx| idx < 10).count();
                    assert!(high_uncertainty_count > 0);
                }
            }
        }
    }

    /// Initially labeled samples are never re-selected.
    #[test]
    fn test_active_learning_with_initial_labeled() {
        let initial_labeled = vec![0, 1, 2, 3, 4];
        let mut sampler = ActiveLearningSampler::with_initial_labeled(
            100,
            AcquisitionStrategy::UncertaintySampling,
            5,
            &initial_labeled,
        );

        assert_eq!(sampler.num_labeled(), 5);
        assert_eq!(sampler.num_unlabeled(), 95);

        for &idx in &initial_labeled {
            assert!(sampler.is_labeled(idx));
        }

        // Uniform scores: any unlabeled sample is an equally valid pick.
        let uncertainties = vec![0.5; 100];
        sampler.update_uncertainties(uncertainties);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 5);

        for &idx in &indices {
            assert!(!initial_labeled.contains(&idx));
        }
    }

    /// The mixed strategy picks some, but not only, top-uncertainty samples.
    #[test]
    fn test_uncertainty_diversity_sampling() {
        let mut sampler = ActiveLearningSampler::new(
            20,
            AcquisitionStrategy::UncertaintyDiversity { alpha: 0.6 },
            10,
        )
        .with_generator(42);

        let uncertainties: Vec<f64> = (0..20).map(|i| i as f64 / 20.0).collect();
        sampler.update_uncertainties(uncertainties);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 10);

        // Indices 15..20 carry the highest scores.
        let high_uncertainty_count = indices.iter().filter(|&&idx| idx >= 15).count();
        assert!(high_uncertainty_count > 0);
        assert!(high_uncertainty_count < indices.len());
    }

    /// Diversity sampling fills the budget with valid indices.
    #[test]
    fn test_diversity_sampling() {
        let mut sampler = ActiveLearningSampler::new(
            30,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            9,
        )
        .with_generator(42);

        let uncertainties = vec![0.5; 30];
        sampler.update_uncertainties(uncertainties);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 9);

        assert!(indices.iter().all(|&idx| idx < 30));
    }

    /// Stats reflect labeling progress across a round.
    #[test]
    fn test_active_learning_stats() {
        let mut sampler =
            ActiveLearningSampler::new(100, AcquisitionStrategy::UncertaintySampling, 15);

        let stats = sampler.active_learning_stats();
        assert_eq!(stats.current_round, 0);
        assert_eq!(stats.num_labeled, 0);
        assert_eq!(stats.num_unlabeled, 100);
        assert_eq!(stats.total_samples, 100);
        assert_eq!(stats.budget_per_round, 15);
        assert_eq!(stats.available_budget, 15);
        assert_eq!(stats.labeling_ratio, 0.0);

        sampler.add_labeled_samples(&[0, 1, 2, 3, 4]);

        let stats = sampler.active_learning_stats();
        assert_eq!(stats.current_round, 1);
        assert_eq!(stats.num_labeled, 5);
        assert_eq!(stats.num_unlabeled, 95);
        // 5.0 / 100.0 rounds to the same f64 as the literal 0.05.
        assert_eq!(stats.labeling_ratio, 0.05);
    }

    /// Accessors and mutators keep the labeled set and configuration consistent.
    #[test]
    fn test_sampler_methods() {
        let mut sampler =
            ActiveLearningSampler::new(50, AcquisitionStrategy::UncertaintySampling, 10);

        assert_eq!(sampler.get_unlabeled_indices().len(), 50);
        assert!(sampler.labeled_indices().is_empty());

        sampler.add_labeled_samples(&[5, 15, 25]);
        assert_eq!(sampler.num_labeled(), 3);
        assert!(sampler.is_labeled(5));
        assert!(sampler.is_labeled(15));
        assert!(sampler.is_labeled(25));
        assert!(!sampler.is_labeled(0));

        let labeled = sampler.labeled_indices();
        assert_eq!(labeled.len(), 3);
        assert!(labeled.contains(&5));
        assert!(labeled.contains(&15));
        assert!(labeled.contains(&25));

        sampler.remove_labeled_samples(&[15]);
        assert_eq!(sampler.num_labeled(), 2);
        assert!(!sampler.is_labeled(15));

        sampler.set_strategy(AcquisitionStrategy::DiversitySampling { num_clusters: 4 });
        assert!(matches!(
            sampler.strategy(),
            AcquisitionStrategy::DiversitySampling { num_clusters: 4 }
        ));

        sampler.set_budget(5);
        assert_eq!(sampler.budget_per_round(), 5);

        sampler.reset();
        assert_eq!(sampler.num_labeled(), 0);
        assert_eq!(sampler.current_round(), 0);
    }

    /// Convenience constructors configure the expected strategy and budget.
    #[test]
    fn test_convenience_functions() {
        let uncertainty = uncertainty_sampler(100, 10, Some(42));
        assert!(matches!(
            uncertainty.strategy(),
            AcquisitionStrategy::UncertaintySampling
        ));
        assert_eq!(uncertainty.budget_per_round(), 10);

        let diversity = diversity_sampler(100, 10, 5, Some(42));
        assert!(matches!(
            diversity.strategy(),
            AcquisitionStrategy::DiversitySampling { num_clusters: 5 }
        ));

        let combined = uncertainty_diversity_sampler(100, 10, 0.7, Some(42));
        assert!(matches!(
            combined.strategy(),
            AcquisitionStrategy::UncertaintyDiversity { alpha } if (*alpha - 0.7).abs() < f64::EPSILON
        ));
    }

    /// Zero budget, exhausted pool, oversized budget, out-of-range alpha, zero clusters.
    #[test]
    fn test_edge_cases() {
        // Zero budget selects nothing.
        let mut sampler =
            ActiveLearningSampler::new(10, AcquisitionStrategy::UncertaintySampling, 0);
        assert_eq!(sampler.len(), 0);
        let indices: Vec<usize> = sampler.iter().collect();
        assert!(indices.is_empty());

        // A fully labeled pool has nothing left to select.
        sampler.set_budget(5);
        sampler.add_labeled_samples(&(0..10).collect::<Vec<_>>());
        assert_eq!(sampler.num_unlabeled(), 0);
        assert_eq!(sampler.len(), 0);

        // The budget is capped by the pool size.
        let large_budget =
            ActiveLearningSampler::new(5, AcquisitionStrategy::UncertaintySampling, 10);
        assert_eq!(large_budget.len(), 5);

        // alpha > 1 is clamped rather than over-filling the budget.
        let mut clamped = ActiveLearningSampler::new(
            10,
            AcquisitionStrategy::UncertaintyDiversity { alpha: 1.5 },
            5,
        );
        let uncertainties = vec![0.5; 10];
        clamped.update_uncertainties(uncertainties);
        let indices: Vec<usize> = clamped.iter().collect();
        assert_eq!(indices.len(), 5);

        // Zero clusters falls back to uncertainty sampling.
        let mut zero_clusters = ActiveLearningSampler::new(
            10,
            AcquisitionStrategy::DiversitySampling { num_clusters: 0 },
            3,
        );
        let uncertainties = vec![0.5; 10];
        zero_clusters.update_uncertainties(uncertainties);
        let indices: Vec<usize> = zero_clusters.iter().collect();
        assert_eq!(indices.len(), 3);
    }

    /// Uncertainty vectors must match the dataset size.
    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_update_uncertainties_wrong_size() {
        let mut sampler =
            ActiveLearningSampler::new(10, AcquisitionStrategy::UncertaintySampling, 5);
        sampler.update_uncertainties(vec![0.5; 5]);
    }

    #[test]
    fn test_acquisition_strategy_equality() {
        assert_eq!(
            AcquisitionStrategy::UncertaintySampling,
            AcquisitionStrategy::UncertaintySampling
        );
        assert_eq!(
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 }
        );
        assert_ne!(
            AcquisitionStrategy::UncertaintySampling,
            AcquisitionStrategy::ExpectedInformationGain
        );
    }

    #[test]
    fn test_acquisition_strategy_default() {
        assert_eq!(
            AcquisitionStrategy::default(),
            AcquisitionStrategy::UncertaintySampling
        );
    }

    /// The same seed must yield the same diversity selection.
    #[test]
    fn test_reproducibility() {
        let mut sampler1 = ActiveLearningSampler::new(
            20,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            5,
        )
        .with_generator(123);

        let mut sampler2 = ActiveLearningSampler::new(
            20,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            5,
        )
        .with_generator(123);

        let uncertainties = vec![0.5; 20];
        sampler1.update_uncertainties(uncertainties.clone());
        sampler2.update_uncertainties(uncertainties);

        let indices1: Vec<usize> = sampler1.iter().collect();
        let indices2: Vec<usize> = sampler2.iter().collect();

        assert_eq!(indices1, indices2);
    }
}