Skip to main content

torsh_data/sampler/
active_learning.rs

1//! Active learning sampling functionality
2//!
3//! This module provides active learning samplers that prioritize uncertain or informative
4//! samples to maximize learning efficiency with minimal labeling effort.
5
6#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8use std::collections::HashSet;
9
10// ✅ SciRS2 Policy Compliant - Using scirs2_core for all random operations
11use scirs2_core::rand_prelude::SliceRandom;
12
13use super::core::{rng_utils, Sampler, SamplerIterator};
14
/// Acquisition strategies for choosing which unlabeled samples to query next
///
/// Each variant names a criterion for ranking unlabeled samples — uncertainty,
/// information gain, diversity, or a blend — used by an active learning sampler
/// when it picks the next batch to send for labeling.
#[derive(Clone, Debug, PartialEq)]
pub enum AcquisitionStrategy {
    /// Pick the samples with the highest uncertainty score
    ///
    /// The classic active learning criterion: query the samples the model is
    /// least sure about. This is the default strategy.
    UncertaintySampling,

    /// Pick the samples with the highest expected information gain
    ///
    /// Intended to favor samples that reveal the most about the underlying
    /// data distribution. The current implementation is a simplified
    /// heuristic derived from the per-sample uncertainty scores.
    ExpectedInformationGain,

    /// Spread the selection across several groups of samples
    ///
    /// The unlabeled pool is divided into `num_clusters` groups and the most
    /// uncertain samples of each group are taken. The current implementation
    /// forms the groups by randomly shuffling the pool (seeded by the
    /// sampler's generator, if any). A value of `0` falls back to plain
    /// uncertainty sampling.
    ///
    /// # Arguments
    ///
    /// * `num_clusters` - How many groups to partition the pool into
    DiversitySampling { num_clusters: usize },

    /// Blend uncertainty sampling with diversity sampling
    ///
    /// A share `alpha` of the budget is filled by uncertainty sampling and the
    /// rest by diversity sampling over the samples not already chosen.
    ///
    /// # Arguments
    ///
    /// * `alpha` - Share of the budget given to uncertainty (clamped to 0.0-1.0)
    ///   - 1.0 = uncertainty sampling only
    ///   - 0.0 = diversity sampling only
    UncertaintyDiversity { alpha: f64 },

    /// Query by committee: pick samples on which several models disagree
    ///
    /// A faithful implementation needs ensemble predictions; the current one
    /// uses the uncertainty scores as a stand-in for committee disagreement.
    QueryByCommittee,

    /// Pick samples expected to change the model the most once labeled
    ///
    /// The current implementation approximates the expected change from the
    /// uncertainty scores.
    ExpectedModelChange,
}

impl Default for AcquisitionStrategy {
    /// Returns [`AcquisitionStrategy::UncertaintySampling`].
    fn default() -> Self {
        Self::UncertaintySampling
    }
}
73
74/// Active learning sampler that prioritizes uncertain or informative samples
75///
76/// This sampler selects samples based on uncertainty estimates or other
77/// information-theoretic criteria to maximize learning efficiency with
78/// minimal labeling effort.
79///
80/// # Examples
81///
82/// ```rust,ignore
83/// use torsh_data::sampler::{ActiveLearningSampler, AcquisitionStrategy, Sampler};
84///
85/// let mut sampler = ActiveLearningSampler::new(
86///     1000,
87///     AcquisitionStrategy::UncertaintySampling,
88///     10
89/// ).with_generator(42);
90///
91/// // Update with uncertainty scores from model
92/// let uncertainties = vec![0.5; 1000]; // Mock uncertainties
93/// sampler.update_uncertainties(uncertainties);
94///
95/// // Get samples to label
96/// let indices: Vec<usize> = sampler.iter().collect();
97/// assert_eq!(indices.len(), 10);
98///
99/// // Add labeled samples
100/// sampler.add_labeled_samples(&indices);
101/// ```
102#[derive(Clone)]
103pub struct ActiveLearningSampler {
104    uncertainties: Vec<f64>,
105    acquisition_strategy: AcquisitionStrategy,
106    num_samples: usize,
107    budget_per_round: usize,
108    current_round: usize,
109    labeled_indices: HashSet<usize>,
110    generator: Option<u64>,
111}
112
113impl ActiveLearningSampler {
114    /// Create a new active learning sampler
115    ///
116    /// # Arguments
117    ///
118    /// * `dataset_size` - Total size of the dataset
119    /// * `acquisition_strategy` - Strategy for selecting samples
120    /// * `budget_per_round` - Number of samples to select per round
121    ///
122    /// # Examples
123    ///
124    /// ```rust,ignore
125    /// use torsh_data::sampler::{ActiveLearningSampler, AcquisitionStrategy};
126    ///
127    /// let sampler = ActiveLearningSampler::new(
128    ///     1000,
129    ///     AcquisitionStrategy::UncertaintySampling,
130    ///     20
131    /// );
132    /// ```
133    pub fn new(
134        dataset_size: usize,
135        acquisition_strategy: AcquisitionStrategy,
136        budget_per_round: usize,
137    ) -> Self {
138        Self {
139            uncertainties: vec![0.0; dataset_size],
140            acquisition_strategy,
141            num_samples: dataset_size,
142            budget_per_round,
143            current_round: 0,
144            labeled_indices: HashSet::new(),
145            generator: None,
146        }
147    }
148
149    /// Create an active learning sampler with initial labeled samples
150    ///
151    /// # Arguments
152    ///
153    /// * `dataset_size` - Total size of the dataset
154    /// * `acquisition_strategy` - Strategy for selecting samples
155    /// * `budget_per_round` - Number of samples to select per round
156    /// * `initial_labeled` - Initially labeled sample indices
157    pub fn with_initial_labeled(
158        dataset_size: usize,
159        acquisition_strategy: AcquisitionStrategy,
160        budget_per_round: usize,
161        initial_labeled: &[usize],
162    ) -> Self {
163        let mut sampler = Self::new(dataset_size, acquisition_strategy, budget_per_round);
164        for &idx in initial_labeled {
165            sampler.labeled_indices.insert(idx);
166        }
167        sampler
168    }
169
170    /// Update uncertainty scores for all samples
171    ///
172    /// # Arguments
173    ///
174    /// * `uncertainties` - Uncertainty scores for each sample (higher = more uncertain)
175    ///
176    /// # Panics
177    ///
178    /// Panics if the length of uncertainties doesn't match the dataset size.
179    pub fn update_uncertainties(&mut self, uncertainties: Vec<f64>) {
180        assert!(uncertainties.len() == self.num_samples, "assertion failed");
181        self.uncertainties = uncertainties;
182    }
183
184    /// Add newly labeled samples
185    ///
186    /// # Arguments
187    ///
188    /// * `indices` - Indices of newly labeled samples
189    pub fn add_labeled_samples(&mut self, indices: &[usize]) {
190        for &idx in indices {
191            if idx < self.num_samples {
192                self.labeled_indices.insert(idx);
193            }
194        }
195        self.current_round += 1;
196    }
197
198    /// Remove samples from labeled set (useful for experimental scenarios)
199    ///
200    /// # Arguments
201    ///
202    /// * `indices` - Indices to remove from labeled set
203    pub fn remove_labeled_samples(&mut self, indices: &[usize]) {
204        for &idx in indices {
205            self.labeled_indices.remove(&idx);
206        }
207    }
208
209    /// Set random generator seed
210    ///
211    /// # Arguments
212    ///
213    /// * `seed` - Random seed for reproducible sampling
214    pub fn with_generator(mut self, seed: u64) -> Self {
215        self.generator = Some(seed);
216        self
217    }
218
219    /// Get the current round number
220    pub fn current_round(&self) -> usize {
221        self.current_round
222    }
223
224    /// Get the budget per round
225    pub fn budget_per_round(&self) -> usize {
226        self.budget_per_round
227    }
228
229    /// Get the acquisition strategy
230    pub fn strategy(&self) -> &AcquisitionStrategy {
231        &self.acquisition_strategy
232    }
233
234    /// Get the number of labeled samples
235    pub fn num_labeled(&self) -> usize {
236        self.labeled_indices.len()
237    }
238
239    /// Get the number of unlabeled samples
240    pub fn num_unlabeled(&self) -> usize {
241        self.num_samples - self.labeled_indices.len()
242    }
243
244    /// Get the labeled sample indices
245    pub fn labeled_indices(&self) -> Vec<usize> {
246        self.labeled_indices.iter().copied().collect()
247    }
248
249    /// Get unlabeled sample indices
250    pub fn get_unlabeled_indices(&self) -> Vec<usize> {
251        (0..self.num_samples)
252            .filter(|idx| !self.labeled_indices.contains(idx))
253            .collect()
254    }
255
256    /// Check if a sample is labeled
257    pub fn is_labeled(&self, index: usize) -> bool {
258        self.labeled_indices.contains(&index)
259    }
260
261    /// Set a new acquisition strategy
262    ///
263    /// # Arguments
264    ///
265    /// * `strategy` - New acquisition strategy to use
266    pub fn set_strategy(&mut self, strategy: AcquisitionStrategy) {
267        self.acquisition_strategy = strategy;
268    }
269
270    /// Set a new budget per round
271    ///
272    /// # Arguments
273    ///
274    /// * `budget` - New budget per round
275    pub fn set_budget(&mut self, budget: usize) {
276        self.budget_per_round = budget;
277    }
278
279    /// Reset the sampler to initial state
280    pub fn reset(&mut self) {
281        self.labeled_indices.clear();
282        self.current_round = 0;
283    }
284
285    /// Get statistics about the current active learning state
286    pub fn active_learning_stats(&self) -> ActiveLearningStats {
287        let unlabeled_count = self.num_unlabeled();
288        let available_budget = self.budget_per_round.min(unlabeled_count);
289
290        ActiveLearningStats {
291            current_round: self.current_round,
292            num_labeled: self.num_labeled(),
293            num_unlabeled: unlabeled_count,
294            total_samples: self.num_samples,
295            budget_per_round: self.budget_per_round,
296            available_budget,
297            labeling_ratio: self.num_labeled() as f64 / self.num_samples as f64,
298        }
299    }
300
301    /// Select samples based on acquisition strategy
302    fn select_samples(&self) -> Vec<usize> {
303        let unlabeled = self.get_unlabeled_indices();
304        let budget = self.budget_per_round.min(unlabeled.len());
305
306        if budget == 0 {
307            return Vec::new();
308        }
309
310        match &self.acquisition_strategy {
311            AcquisitionStrategy::UncertaintySampling => {
312                self.uncertainty_sampling(&unlabeled, budget)
313            }
314            AcquisitionStrategy::ExpectedInformationGain => {
315                self.information_gain_sampling(&unlabeled, budget)
316            }
317            AcquisitionStrategy::DiversitySampling { num_clusters } => {
318                self.diversity_sampling(&unlabeled, budget, *num_clusters)
319            }
320            AcquisitionStrategy::UncertaintyDiversity { alpha } => {
321                self.uncertainty_diversity_sampling(&unlabeled, budget, *alpha)
322            }
323            AcquisitionStrategy::QueryByCommittee => self.query_by_committee(&unlabeled, budget),
324            AcquisitionStrategy::ExpectedModelChange => {
325                self.expected_model_change(&unlabeled, budget)
326            }
327        }
328    }
329
330    /// Uncertainty sampling: select most uncertain samples
331    fn uncertainty_sampling(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
332        let mut scored: Vec<_> = unlabeled
333            .iter()
334            .map(|&idx| (idx, self.uncertainties[idx]))
335            .collect();
336
337        // Sort by uncertainty (descending)
338        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
339
340        scored
341            .into_iter()
342            .take(budget)
343            .map(|(idx, _)| idx)
344            .collect()
345    }
346
347    /// Information gain sampling (simplified version)
348    fn information_gain_sampling(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
349        // This is a simplified implementation
350        // In practice, you'd calculate expected information gain more rigorously
351        let mut scored: Vec<_> = unlabeled
352            .iter()
353            .map(|&idx| {
354                let ig = self.uncertainties[idx] * (1.0 + (idx as f64).ln());
355                (idx, ig)
356            })
357            .collect();
358
359        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
360        scored
361            .into_iter()
362            .take(budget)
363            .map(|(idx, _)| idx)
364            .collect()
365    }
366
367    /// Diversity sampling using simple clustering
368    fn diversity_sampling(
369        &self,
370        unlabeled: &[usize],
371        budget: usize,
372        num_clusters: usize,
373    ) -> Vec<usize> {
374        if num_clusters == 0 {
375            return self.uncertainty_sampling(unlabeled, budget);
376        }
377
378        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
379        let mut rng = rng_utils::create_rng(self.generator);
380
381        // Simplified diversity sampling: randomly partition into clusters and sample from each
382        let mut indices = unlabeled.to_vec();
383        indices.shuffle(&mut rng);
384
385        let cluster_size = (unlabeled.len() / num_clusters).max(1);
386        let base_samples_per_cluster = budget / num_clusters;
387        let extra_samples = budget % num_clusters;
388
389        let mut selected = Vec::new();
390        let mut cluster_idx = 0;
391
392        for cluster_start in (0..indices.len()).step_by(cluster_size) {
393            let cluster_end = (cluster_start + cluster_size).min(indices.len());
394            let cluster = &indices[cluster_start..cluster_end];
395
396            // Calculate samples for this cluster (distribute remainder to first clusters)
397            let cluster_samples_count = if cluster_idx < extra_samples {
398                base_samples_per_cluster + 1
399            } else {
400                base_samples_per_cluster
401            };
402
403            if cluster_samples_count == 0 {
404                cluster_idx += 1;
405                continue;
406            }
407
408            // Sample from this cluster based on uncertainty
409            let mut cluster_scored: Vec<_> = cluster
410                .iter()
411                .map(|&idx| (idx, self.uncertainties[idx]))
412                .collect();
413
414            cluster_scored
415                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
416
417            let cluster_samples = cluster_scored
418                .into_iter()
419                .take(cluster_samples_count)
420                .map(|(idx, _)| idx);
421
422            selected.extend(cluster_samples);
423            cluster_idx += 1;
424
425            if selected.len() >= budget {
426                break;
427            }
428        }
429
430        selected.truncate(budget);
431        selected
432    }
433
434    /// Combine uncertainty and diversity
435    fn uncertainty_diversity_sampling(
436        &self,
437        unlabeled: &[usize],
438        budget: usize,
439        alpha: f64,
440    ) -> Vec<usize> {
441        let alpha = alpha.clamp(0.0, 1.0);
442
443        // Simplified combined approach
444        let uncertainty_count = (budget as f64 * alpha) as usize;
445        let diversity_count = budget - uncertainty_count;
446
447        let mut selected = self.uncertainty_sampling(unlabeled, uncertainty_count);
448
449        if diversity_count > 0 {
450            // Remove already selected from unlabeled for diversity sampling
451            let remaining: Vec<usize> = unlabeled
452                .iter()
453                .filter(|idx| !selected.contains(idx))
454                .copied()
455                .collect();
456
457            let diversity_samples = self.diversity_sampling(&remaining, diversity_count, 3);
458            selected.extend(diversity_samples);
459        }
460
461        selected
462    }
463
464    /// Query by committee (simplified)
465    fn query_by_committee(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
466        // In practice, you'd have multiple models and compute variance
467        // For now, use uncertainty as a proxy for committee disagreement
468        self.uncertainty_sampling(unlabeled, budget)
469    }
470
471    /// Expected model change (simplified)
472    fn expected_model_change(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
473        // Simplified: assume uncertainty correlates with model change
474        let mut scored: Vec<_> = unlabeled
475            .iter()
476            .map(|&idx| {
477                let change_score =
478                    self.uncertainties[idx] * (1.0 + idx as f64 / unlabeled.len() as f64);
479                (idx, change_score)
480            })
481            .collect();
482
483        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
484        scored
485            .into_iter()
486            .take(budget)
487            .map(|(idx, _)| idx)
488            .collect()
489    }
490}
491
492impl Sampler for ActiveLearningSampler {
493    type Iter = SamplerIterator;
494
495    fn iter(&self) -> Self::Iter {
496        let indices = self.select_samples();
497        SamplerIterator::new(indices)
498    }
499
500    fn len(&self) -> usize {
501        let unlabeled_count = self.get_unlabeled_indices().len();
502        self.budget_per_round.min(unlabeled_count)
503    }
504}
505
/// Statistics about the current active learning state
///
/// A snapshot produced by `ActiveLearningSampler::active_learning_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct ActiveLearningStats {
    /// Current active learning round (incremented by `add_labeled_samples`, zeroed by `reset`)
    pub current_round: usize,
    /// Number of labeled samples
    pub num_labeled: usize,
    /// Number of unlabeled samples
    pub num_unlabeled: usize,
    /// Total number of samples
    pub total_samples: usize,
    /// Budget per round
    pub budget_per_round: usize,
    /// Available budget for current round: `min(budget_per_round, num_unlabeled)`
    pub available_budget: usize,
    /// Ratio of labeled to total samples (`num_labeled / total_samples`)
    pub labeling_ratio: f64,
}
524
525/// Create an uncertainty sampling active learner
526///
527/// Convenience function for creating an active learning sampler with uncertainty sampling.
528///
529/// # Arguments
530///
531/// * `dataset_size` - Total size of the dataset
532/// * `budget_per_round` - Number of samples to select per round
533/// * `seed` - Optional random seed for reproducible sampling
534pub fn uncertainty_sampler(
535    dataset_size: usize,
536    budget_per_round: usize,
537    seed: Option<u64>,
538) -> ActiveLearningSampler {
539    let mut sampler = ActiveLearningSampler::new(
540        dataset_size,
541        AcquisitionStrategy::UncertaintySampling,
542        budget_per_round,
543    );
544    if let Some(s) = seed {
545        sampler = sampler.with_generator(s);
546    }
547    sampler
548}
549
550/// Create a diversity sampling active learner
551///
552/// Convenience function for creating an active learning sampler with diversity sampling.
553///
554/// # Arguments
555///
556/// * `dataset_size` - Total size of the dataset
557/// * `budget_per_round` - Number of samples to select per round
558/// * `num_clusters` - Number of clusters for diversity
559/// * `seed` - Optional random seed for reproducible sampling
560pub fn diversity_sampler(
561    dataset_size: usize,
562    budget_per_round: usize,
563    num_clusters: usize,
564    seed: Option<u64>,
565) -> ActiveLearningSampler {
566    let mut sampler = ActiveLearningSampler::new(
567        dataset_size,
568        AcquisitionStrategy::DiversitySampling { num_clusters },
569        budget_per_round,
570    );
571    if let Some(s) = seed {
572        sampler = sampler.with_generator(s);
573    }
574    sampler
575}
576
577/// Create a combined uncertainty-diversity active learner
578///
579/// Convenience function for creating an active learning sampler that combines
580/// uncertainty and diversity sampling.
581///
582/// # Arguments
583///
584/// * `dataset_size` - Total size of the dataset
585/// * `budget_per_round` - Number of samples to select per round
586/// * `alpha` - Weight for uncertainty vs diversity (0.0-1.0)
587/// * `seed` - Optional random seed for reproducible sampling
588pub fn uncertainty_diversity_sampler(
589    dataset_size: usize,
590    budget_per_round: usize,
591    alpha: f64,
592    seed: Option<u64>,
593) -> ActiveLearningSampler {
594    let mut sampler = ActiveLearningSampler::new(
595        dataset_size,
596        AcquisitionStrategy::UncertaintyDiversity { alpha },
597        budget_per_round,
598    );
599    if let Some(s) = seed {
600        sampler = sampler.with_generator(s);
601    }
602    sampler
603}
604
// Unit tests for the active learning sampler: selection per strategy,
// labeled/unlabeled bookkeeping, statistics, edge cases, and seeded reproducibility.
#[cfg(test)]
mod tests {
    use super::*;

    // Lifecycle check: constructor defaults, top-uncertainty selection, and the
    // counters after `add_labeled_samples`.
    #[test]
    fn test_active_learning_sampler_basic() {
        let mut sampler =
            ActiveLearningSampler::new(100, AcquisitionStrategy::UncertaintySampling, 10)
                .with_generator(42);

        assert_eq!(sampler.num_samples, 100);
        assert_eq!(sampler.budget_per_round(), 10);
        assert_eq!(sampler.current_round(), 0);
        assert_eq!(sampler.num_labeled(), 0);
        assert_eq!(sampler.num_unlabeled(), 100);
        assert_eq!(sampler.generator, Some(42));

        // Update uncertainties
        let uncertainties: Vec<f64> = (0..100).map(|i| i as f64 / 100.0).collect();
        sampler.update_uncertainties(uncertainties);

        // Should select highest uncertainty samples
        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 10);

        // With uncertainty sampling, should select samples with highest uncertainty (90-99)
        for &idx in &indices {
            assert!(idx >= 90); // Highest uncertainty samples
        }

        // Add labeled samples
        sampler.add_labeled_samples(&indices);
        assert_eq!(sampler.num_labeled(), 10);
        assert_eq!(sampler.num_unlabeled(), 90);
        assert_eq!(sampler.current_round(), 1);
    }

    // Every strategy must fill the budget; all but pure diversity should pick at
    // least one of the high-uncertainty samples (indices 0-9).
    #[test]
    fn test_acquisition_strategies() {
        let dataset_size = 50;
        let budget = 5;

        // Test different strategies
        let strategies = vec![
            AcquisitionStrategy::UncertaintySampling,
            AcquisitionStrategy::ExpectedInformationGain,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            AcquisitionStrategy::UncertaintyDiversity { alpha: 0.5 },
            AcquisitionStrategy::QueryByCommittee,
            AcquisitionStrategy::ExpectedModelChange,
        ];

        for strategy in strategies {
            let mut sampler = ActiveLearningSampler::new(dataset_size, strategy.clone(), budget)
                .with_generator(42);

            // Set up uncertainties with clear pattern
            let uncertainties: Vec<f64> = (0..dataset_size)
                .map(|i| if i < 10 { 0.9 } else { 0.1 })
                .collect();
            sampler.update_uncertainties(uncertainties);

            let indices: Vec<usize> = sampler.iter().collect();
            assert_eq!(indices.len(), budget);

            // All strategies should prefer high uncertainty samples to some degree
            // (except pure diversity which might sample differently)
            match strategy {
                AcquisitionStrategy::DiversitySampling { .. } => {
                    // Diversity sampling might select from anywhere
                    assert!(indices.iter().all(|&idx| idx < dataset_size));
                }
                _ => {
                    // Other strategies should prefer high uncertainty (indices 0-9)
                    let high_uncertainty_count = indices.iter().filter(|&&idx| idx < 10).count();
                    assert!(high_uncertainty_count > 0);
                }
            }
        }
    }

    // Samples passed to `with_initial_labeled` count as labeled and are never
    // selected again.
    #[test]
    fn test_active_learning_with_initial_labeled() {
        let initial_labeled = vec![0, 1, 2, 3, 4];
        let mut sampler = ActiveLearningSampler::with_initial_labeled(
            100,
            AcquisitionStrategy::UncertaintySampling,
            5,
            &initial_labeled,
        );

        assert_eq!(sampler.num_labeled(), 5);
        assert_eq!(sampler.num_unlabeled(), 95);

        for &idx in &initial_labeled {
            assert!(sampler.is_labeled(idx));
        }

        // Update uncertainties
        let uncertainties = vec![0.5; 100];
        sampler.update_uncertainties(uncertainties);

        // Should not select already labeled samples
        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 5);

        for &idx in &indices {
            assert!(!initial_labeled.contains(&idx));
        }
    }

    // The blended strategy should return a mix: some, but not only, top-quartile
    // uncertainty samples.
    #[test]
    fn test_uncertainty_diversity_sampling() {
        let mut sampler = ActiveLearningSampler::new(
            20,
            AcquisitionStrategy::UncertaintyDiversity { alpha: 0.6 },
            10,
        )
        .with_generator(42);

        // Set up uncertainties with clear pattern
        let uncertainties: Vec<f64> = (0..20).map(|i| i as f64 / 20.0).collect();
        sampler.update_uncertainties(uncertainties);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 10);

        // With alpha=0.6, should get 6 uncertainty samples + 4 diversity samples
        // The exact composition depends on implementation details, but should include mix
        let high_uncertainty_count = indices.iter().filter(|&&idx| idx >= 15).count();
        assert!(high_uncertainty_count > 0); // Should have some high uncertainty
        assert!(high_uncertainty_count < indices.len()); // But not all
    }

    // Diversity sampling fills the budget with valid indices even when every
    // uncertainty score ties.
    #[test]
    fn test_diversity_sampling() {
        let mut sampler = ActiveLearningSampler::new(
            30,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            9,
        )
        .with_generator(42);

        let uncertainties = vec![0.5; 30]; // Equal uncertainties
        sampler.update_uncertainties(uncertainties);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 9);

        // With equal uncertainties, diversity sampling should spread across the range
        // (exact behavior depends on clustering implementation)
        assert!(indices.iter().all(|&idx| idx < 30));
    }

    // `active_learning_stats` reflects the counters before and after one labeling round.
    #[test]
    fn test_active_learning_stats() {
        let mut sampler =
            ActiveLearningSampler::new(100, AcquisitionStrategy::UncertaintySampling, 15);

        let stats = sampler.active_learning_stats();
        assert_eq!(stats.current_round, 0);
        assert_eq!(stats.num_labeled, 0);
        assert_eq!(stats.num_unlabeled, 100);
        assert_eq!(stats.total_samples, 100);
        assert_eq!(stats.budget_per_round, 15);
        assert_eq!(stats.available_budget, 15);
        assert_eq!(stats.labeling_ratio, 0.0);

        // Add some labeled samples
        sampler.add_labeled_samples(&[0, 1, 2, 3, 4]);

        let stats = sampler.active_learning_stats();
        assert_eq!(stats.current_round, 1);
        assert_eq!(stats.num_labeled, 5);
        assert_eq!(stats.num_unlabeled, 95);
        assert_eq!(stats.labeling_ratio, 0.05);
    }

    // Accessors and mutators: labeled-set queries, removal, strategy/budget
    // changes, and `reset`.
    #[test]
    fn test_sampler_methods() {
        let mut sampler =
            ActiveLearningSampler::new(50, AcquisitionStrategy::UncertaintySampling, 10);

        // Test labeled/unlabeled tracking
        assert_eq!(sampler.get_unlabeled_indices().len(), 50);
        assert!(sampler.labeled_indices().is_empty());

        sampler.add_labeled_samples(&[5, 15, 25]);
        assert_eq!(sampler.num_labeled(), 3);
        assert!(sampler.is_labeled(5));
        assert!(sampler.is_labeled(15));
        assert!(sampler.is_labeled(25));
        assert!(!sampler.is_labeled(0));

        let labeled = sampler.labeled_indices();
        assert_eq!(labeled.len(), 3);
        assert!(labeled.contains(&5));
        assert!(labeled.contains(&15));
        assert!(labeled.contains(&25));

        // Test remove labeled samples
        sampler.remove_labeled_samples(&[15]);
        assert_eq!(sampler.num_labeled(), 2);
        assert!(!sampler.is_labeled(15));

        // Test strategy change
        sampler.set_strategy(AcquisitionStrategy::DiversitySampling { num_clusters: 4 });
        assert!(matches!(
            sampler.strategy(),
            AcquisitionStrategy::DiversitySampling { num_clusters: 4 }
        ));

        // Test budget change
        sampler.set_budget(5);
        assert_eq!(sampler.budget_per_round(), 5);

        // Test reset
        sampler.reset();
        assert_eq!(sampler.num_labeled(), 0);
        assert_eq!(sampler.current_round(), 0);
    }

    // The free-function constructors configure the expected strategy and budget.
    #[test]
    fn test_convenience_functions() {
        // Test uncertainty_sampler
        let uncertainty = uncertainty_sampler(100, 10, Some(42));
        assert!(matches!(
            uncertainty.strategy(),
            AcquisitionStrategy::UncertaintySampling
        ));
        assert_eq!(uncertainty.budget_per_round(), 10);

        // Test diversity_sampler
        let diversity = diversity_sampler(100, 10, 5, Some(42));
        assert!(matches!(
            diversity.strategy(),
            AcquisitionStrategy::DiversitySampling { num_clusters: 5 }
        ));

        // Test uncertainty_diversity_sampler
        let combined = uncertainty_diversity_sampler(100, 10, 0.7, Some(42));
        assert!(matches!(
            combined.strategy(),
            AcquisitionStrategy::UncertaintyDiversity { alpha } if (*alpha - 0.7).abs() < f64::EPSILON
        ));
    }

    // Boundary cases: zero budget, everything labeled, budget larger than the
    // pool, out-of-range alpha, and zero clusters.
    #[test]
    fn test_edge_cases() {
        // Empty budget
        let mut sampler =
            ActiveLearningSampler::new(10, AcquisitionStrategy::UncertaintySampling, 0);
        assert_eq!(sampler.len(), 0);
        let indices: Vec<usize> = sampler.iter().collect();
        assert!(indices.is_empty());

        // All samples labeled
        sampler.set_budget(5);
        sampler.add_labeled_samples(&(0..10).collect::<Vec<_>>());
        assert_eq!(sampler.num_unlabeled(), 0);
        assert_eq!(sampler.len(), 0);

        // Budget larger than unlabeled
        let large_budget =
            ActiveLearningSampler::new(5, AcquisitionStrategy::UncertaintySampling, 10);
        assert_eq!(large_budget.len(), 5); // Should be clamped to available

        // Invalid alpha should be clamped
        let mut clamped = ActiveLearningSampler::new(
            10,
            AcquisitionStrategy::UncertaintyDiversity { alpha: 1.5 },
            5,
        );
        let uncertainties = vec![0.5; 10];
        clamped.update_uncertainties(uncertainties);
        let indices: Vec<usize> = clamped.iter().collect();
        assert_eq!(indices.len(), 5); // Should still work

        // Zero clusters in diversity sampling
        let mut zero_clusters = ActiveLearningSampler::new(
            10,
            AcquisitionStrategy::DiversitySampling { num_clusters: 0 },
            3,
        );
        let uncertainties = vec![0.5; 10];
        zero_clusters.update_uncertainties(uncertainties);
        let indices: Vec<usize> = zero_clusters.iter().collect();
        assert_eq!(indices.len(), 3); // Should fallback to uncertainty sampling
    }

    // `update_uncertainties` rejects a score vector whose length differs from the
    // dataset size (matched by the "assertion failed" panic text).
    #[test]
    #[should_panic(expected = "assertion failed")]
    fn test_update_uncertainties_wrong_size() {
        let mut sampler =
            ActiveLearningSampler::new(10, AcquisitionStrategy::UncertaintySampling, 5);
        // Wrong size should panic
        sampler.update_uncertainties(vec![0.5; 5]);
    }

    // Derived `PartialEq` compares variants and their payloads.
    #[test]
    fn test_acquisition_strategy_equality() {
        assert_eq!(
            AcquisitionStrategy::UncertaintySampling,
            AcquisitionStrategy::UncertaintySampling
        );
        assert_eq!(
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 }
        );
        assert_ne!(
            AcquisitionStrategy::UncertaintySampling,
            AcquisitionStrategy::ExpectedInformationGain
        );
    }

    // The default strategy is uncertainty sampling.
    #[test]
    fn test_acquisition_strategy_default() {
        assert_eq!(
            AcquisitionStrategy::default(),
            AcquisitionStrategy::UncertaintySampling
        );
    }

    // Two samplers with the same seed and inputs must select the same indices
    // under the randomized diversity strategy.
    #[test]
    fn test_reproducibility() {
        let mut sampler1 = ActiveLearningSampler::new(
            20,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            5,
        )
        .with_generator(123);

        let mut sampler2 = ActiveLearningSampler::new(
            20,
            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
            5,
        )
        .with_generator(123);

        let uncertainties = vec![0.5; 20];
        sampler1.update_uncertainties(uncertainties.clone());
        sampler2.update_uncertainties(uncertainties);

        let indices1: Vec<usize> = sampler1.iter().collect();
        let indices2: Vec<usize> = sampler2.iter().collect();

        assert_eq!(indices1, indices2);
    }
}