torsh_data/sampler/
active_learning.rs

1//! Active learning sampling functionality
2//!
3//! This module provides active learning samplers that prioritize uncertain or informative
4//! samples to maximize learning efficiency with minimal labeling effort.
5
6#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8use std::collections::HashSet;
9
10// ✅ SciRS2 Policy Compliant - Using scirs2_core for all random operations
11use scirs2_core::rand_prelude::SliceRandom;
12use scirs2_core::random::{Random, SeedableRng};
13
14use super::core::{rng_utils, Sampler, SamplerIterator};
15
/// Strategy used by an active learner to decide which unlabeled samples to query next.
///
/// Each variant ranks the unlabeled pool by a different criterion (uncertainty,
/// estimated information gain, diversity, or a blend of these).
#[derive(Debug, Clone, PartialEq)]
pub enum AcquisitionStrategy {
    /// Query the samples the model is least sure about.
    ///
    /// This is the classic, most widely used active learning criterion.
    UncertaintySampling,

    /// Query the samples expected to be the most informative.
    ///
    /// Intended to favour samples that reveal the most about the underlying
    /// data distribution.
    ExpectedInformationGain,

    /// Query a spread-out set of samples.
    ///
    /// The unlabeled pool is split into groups and samples are drawn from each
    /// group so the selection is not concentrated in one region.
    DiversitySampling {
        /// How many groups the unlabeled pool is split into.
        num_clusters: usize,
    },

    /// Blend uncertainty-driven and diversity-driven selection.
    ///
    /// Part of the budget goes to the most uncertain samples and the rest to
    /// diversity sampling.
    UncertaintyDiversity {
        /// Share of the budget given to uncertainty sampling, in `0.0..=1.0`:
        /// `1.0` means only uncertainty, `0.0` means only diversity.
        alpha: f64,
    },

    /// Query the samples on which a committee of models disagrees most.
    ///
    /// Meant to be driven by ensemble predictions.
    QueryByCommittee,

    /// Query the samples expected to move the model parameters the most
    /// once they are added to the training set.
    ExpectedModelChange,
}

impl Default for AcquisitionStrategy {
    /// Plain uncertainty sampling is the default strategy.
    fn default() -> Self {
        Self::UncertaintySampling
    }
}
74
75/// Active learning sampler that prioritizes uncertain or informative samples
76///
77/// This sampler selects samples based on uncertainty estimates or other
78/// information-theoretic criteria to maximize learning efficiency with
79/// minimal labeling effort.
80///
81/// # Examples
82///
83/// ```rust
84/// use torsh_data::sampler::{ActiveLearningSampler, AcquisitionStrategy, Sampler};
85///
86/// let mut sampler = ActiveLearningSampler::new(
87///     1000,
88///     AcquisitionStrategy::UncertaintySampling,
89///     10
90/// ).with_generator(42);
91///
92/// // Update with uncertainty scores from model
93/// let uncertainties = vec![0.5; 1000]; // Mock uncertainties
94/// sampler.update_uncertainties(uncertainties);
95///
96/// // Get samples to label
97/// let indices: Vec<usize> = sampler.iter().collect();
98/// assert_eq!(indices.len(), 10);
99///
100/// // Add labeled samples
101/// sampler.add_labeled_samples(&indices);
102/// ```
103#[derive(Clone)]
104pub struct ActiveLearningSampler {
105    uncertainties: Vec<f64>,
106    acquisition_strategy: AcquisitionStrategy,
107    num_samples: usize,
108    budget_per_round: usize,
109    current_round: usize,
110    labeled_indices: HashSet<usize>,
111    generator: Option<u64>,
112}
113
114impl ActiveLearningSampler {
115    /// Create a new active learning sampler
116    ///
117    /// # Arguments
118    ///
119    /// * `dataset_size` - Total size of the dataset
120    /// * `acquisition_strategy` - Strategy for selecting samples
121    /// * `budget_per_round` - Number of samples to select per round
122    ///
123    /// # Examples
124    ///
125    /// ```rust
126    /// use torsh_data::sampler::{ActiveLearningSampler, AcquisitionStrategy};
127    ///
128    /// let sampler = ActiveLearningSampler::new(
129    ///     1000,
130    ///     AcquisitionStrategy::UncertaintySampling,
131    ///     20
132    /// );
133    /// ```
134    pub fn new(
135        dataset_size: usize,
136        acquisition_strategy: AcquisitionStrategy,
137        budget_per_round: usize,
138    ) -> Self {
139        Self {
140            uncertainties: vec![0.0; dataset_size],
141            acquisition_strategy,
142            num_samples: dataset_size,
143            budget_per_round,
144            current_round: 0,
145            labeled_indices: HashSet::new(),
146            generator: None,
147        }
148    }
149
150    /// Create an active learning sampler with initial labeled samples
151    ///
152    /// # Arguments
153    ///
154    /// * `dataset_size` - Total size of the dataset
155    /// * `acquisition_strategy` - Strategy for selecting samples
156    /// * `budget_per_round` - Number of samples to select per round
157    /// * `initial_labeled` - Initially labeled sample indices
158    pub fn with_initial_labeled(
159        dataset_size: usize,
160        acquisition_strategy: AcquisitionStrategy,
161        budget_per_round: usize,
162        initial_labeled: &[usize],
163    ) -> Self {
164        let mut sampler = Self::new(dataset_size, acquisition_strategy, budget_per_round);
165        for &idx in initial_labeled {
166            sampler.labeled_indices.insert(idx);
167        }
168        sampler
169    }
170
171    /// Update uncertainty scores for all samples
172    ///
173    /// # Arguments
174    ///
175    /// * `uncertainties` - Uncertainty scores for each sample (higher = more uncertain)
176    ///
177    /// # Panics
178    ///
179    /// Panics if the length of uncertainties doesn't match the dataset size.
180    pub fn update_uncertainties(&mut self, uncertainties: Vec<f64>) {
181        assert!(uncertainties.len() == self.num_samples, "assertion failed");
182        self.uncertainties = uncertainties;
183    }
184
185    /// Add newly labeled samples
186    ///
187    /// # Arguments
188    ///
189    /// * `indices` - Indices of newly labeled samples
190    pub fn add_labeled_samples(&mut self, indices: &[usize]) {
191        for &idx in indices {
192            if idx < self.num_samples {
193                self.labeled_indices.insert(idx);
194            }
195        }
196        self.current_round += 1;
197    }
198
199    /// Remove samples from labeled set (useful for experimental scenarios)
200    ///
201    /// # Arguments
202    ///
203    /// * `indices` - Indices to remove from labeled set
204    pub fn remove_labeled_samples(&mut self, indices: &[usize]) {
205        for &idx in indices {
206            self.labeled_indices.remove(&idx);
207        }
208    }
209
210    /// Set random generator seed
211    ///
212    /// # Arguments
213    ///
214    /// * `seed` - Random seed for reproducible sampling
215    pub fn with_generator(mut self, seed: u64) -> Self {
216        self.generator = Some(seed);
217        self
218    }
219
220    /// Get the current round number
221    pub fn current_round(&self) -> usize {
222        self.current_round
223    }
224
225    /// Get the budget per round
226    pub fn budget_per_round(&self) -> usize {
227        self.budget_per_round
228    }
229
230    /// Get the acquisition strategy
231    pub fn strategy(&self) -> &AcquisitionStrategy {
232        &self.acquisition_strategy
233    }
234
235    /// Get the number of labeled samples
236    pub fn num_labeled(&self) -> usize {
237        self.labeled_indices.len()
238    }
239
240    /// Get the number of unlabeled samples
241    pub fn num_unlabeled(&self) -> usize {
242        self.num_samples - self.labeled_indices.len()
243    }
244
245    /// Get the labeled sample indices
246    pub fn labeled_indices(&self) -> Vec<usize> {
247        self.labeled_indices.iter().copied().collect()
248    }
249
250    /// Get unlabeled sample indices
251    pub fn get_unlabeled_indices(&self) -> Vec<usize> {
252        (0..self.num_samples)
253            .filter(|idx| !self.labeled_indices.contains(idx))
254            .collect()
255    }
256
257    /// Check if a sample is labeled
258    pub fn is_labeled(&self, index: usize) -> bool {
259        self.labeled_indices.contains(&index)
260    }
261
262    /// Set a new acquisition strategy
263    ///
264    /// # Arguments
265    ///
266    /// * `strategy` - New acquisition strategy to use
267    pub fn set_strategy(&mut self, strategy: AcquisitionStrategy) {
268        self.acquisition_strategy = strategy;
269    }
270
271    /// Set a new budget per round
272    ///
273    /// # Arguments
274    ///
275    /// * `budget` - New budget per round
276    pub fn set_budget(&mut self, budget: usize) {
277        self.budget_per_round = budget;
278    }
279
280    /// Reset the sampler to initial state
281    pub fn reset(&mut self) {
282        self.labeled_indices.clear();
283        self.current_round = 0;
284    }
285
286    /// Get statistics about the current active learning state
287    pub fn active_learning_stats(&self) -> ActiveLearningStats {
288        let unlabeled_count = self.num_unlabeled();
289        let available_budget = self.budget_per_round.min(unlabeled_count);
290
291        ActiveLearningStats {
292            current_round: self.current_round,
293            num_labeled: self.num_labeled(),
294            num_unlabeled: unlabeled_count,
295            total_samples: self.num_samples,
296            budget_per_round: self.budget_per_round,
297            available_budget,
298            labeling_ratio: self.num_labeled() as f64 / self.num_samples as f64,
299        }
300    }
301
302    /// Select samples based on acquisition strategy
303    fn select_samples(&self) -> Vec<usize> {
304        let unlabeled = self.get_unlabeled_indices();
305        let budget = self.budget_per_round.min(unlabeled.len());
306
307        if budget == 0 {
308            return Vec::new();
309        }
310
311        match &self.acquisition_strategy {
312            AcquisitionStrategy::UncertaintySampling => {
313                self.uncertainty_sampling(&unlabeled, budget)
314            }
315            AcquisitionStrategy::ExpectedInformationGain => {
316                self.information_gain_sampling(&unlabeled, budget)
317            }
318            AcquisitionStrategy::DiversitySampling { num_clusters } => {
319                self.diversity_sampling(&unlabeled, budget, *num_clusters)
320            }
321            AcquisitionStrategy::UncertaintyDiversity { alpha } => {
322                self.uncertainty_diversity_sampling(&unlabeled, budget, *alpha)
323            }
324            AcquisitionStrategy::QueryByCommittee => self.query_by_committee(&unlabeled, budget),
325            AcquisitionStrategy::ExpectedModelChange => {
326                self.expected_model_change(&unlabeled, budget)
327            }
328        }
329    }
330
331    /// Uncertainty sampling: select most uncertain samples
332    fn uncertainty_sampling(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
333        let mut scored: Vec<_> = unlabeled
334            .iter()
335            .map(|&idx| (idx, self.uncertainties[idx]))
336            .collect();
337
338        // Sort by uncertainty (descending)
339        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
340
341        scored
342            .into_iter()
343            .take(budget)
344            .map(|(idx, _)| idx)
345            .collect()
346    }
347
348    /// Information gain sampling (simplified version)
349    fn information_gain_sampling(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
350        // This is a simplified implementation
351        // In practice, you'd calculate expected information gain more rigorously
352        let mut scored: Vec<_> = unlabeled
353            .iter()
354            .map(|&idx| {
355                let ig = self.uncertainties[idx] * (1.0 + (idx as f64).ln());
356                (idx, ig)
357            })
358            .collect();
359
360        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
361        scored
362            .into_iter()
363            .take(budget)
364            .map(|(idx, _)| idx)
365            .collect()
366    }
367
368    /// Diversity sampling using simple clustering
369    fn diversity_sampling(
370        &self,
371        unlabeled: &[usize],
372        budget: usize,
373        num_clusters: usize,
374    ) -> Vec<usize> {
375        if num_clusters == 0 {
376            return self.uncertainty_sampling(unlabeled, budget);
377        }
378
379        // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
380        let mut rng = rng_utils::create_rng(self.generator);
381
382        // Simplified diversity sampling: randomly partition into clusters and sample from each
383        let mut indices = unlabeled.to_vec();
384        indices.shuffle(&mut rng);
385
386        let cluster_size = (unlabeled.len() / num_clusters).max(1);
387        let base_samples_per_cluster = budget / num_clusters;
388        let extra_samples = budget % num_clusters;
389
390        let mut selected = Vec::new();
391        let mut cluster_idx = 0;
392
393        for cluster_start in (0..indices.len()).step_by(cluster_size) {
394            let cluster_end = (cluster_start + cluster_size).min(indices.len());
395            let cluster = &indices[cluster_start..cluster_end];
396
397            // Calculate samples for this cluster (distribute remainder to first clusters)
398            let cluster_samples_count = if cluster_idx < extra_samples {
399                base_samples_per_cluster + 1
400            } else {
401                base_samples_per_cluster
402            };
403
404            if cluster_samples_count == 0 {
405                cluster_idx += 1;
406                continue;
407            }
408
409            // Sample from this cluster based on uncertainty
410            let mut cluster_scored: Vec<_> = cluster
411                .iter()
412                .map(|&idx| (idx, self.uncertainties[idx]))
413                .collect();
414
415            cluster_scored
416                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
417
418            let cluster_samples = cluster_scored
419                .into_iter()
420                .take(cluster_samples_count)
421                .map(|(idx, _)| idx);
422
423            selected.extend(cluster_samples);
424            cluster_idx += 1;
425
426            if selected.len() >= budget {
427                break;
428            }
429        }
430
431        selected.truncate(budget);
432        selected
433    }
434
435    /// Combine uncertainty and diversity
436    fn uncertainty_diversity_sampling(
437        &self,
438        unlabeled: &[usize],
439        budget: usize,
440        alpha: f64,
441    ) -> Vec<usize> {
442        let alpha = alpha.clamp(0.0, 1.0);
443
444        // Simplified combined approach
445        let uncertainty_count = (budget as f64 * alpha) as usize;
446        let diversity_count = budget - uncertainty_count;
447
448        let mut selected = self.uncertainty_sampling(unlabeled, uncertainty_count);
449
450        if diversity_count > 0 {
451            // Remove already selected from unlabeled for diversity sampling
452            let remaining: Vec<usize> = unlabeled
453                .iter()
454                .filter(|idx| !selected.contains(idx))
455                .copied()
456                .collect();
457
458            let diversity_samples = self.diversity_sampling(&remaining, diversity_count, 3);
459            selected.extend(diversity_samples);
460        }
461
462        selected
463    }
464
465    /// Query by committee (simplified)
466    fn query_by_committee(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
467        // In practice, you'd have multiple models and compute variance
468        // For now, use uncertainty as a proxy for committee disagreement
469        self.uncertainty_sampling(unlabeled, budget)
470    }
471
472    /// Expected model change (simplified)
473    fn expected_model_change(&self, unlabeled: &[usize], budget: usize) -> Vec<usize> {
474        // Simplified: assume uncertainty correlates with model change
475        let mut scored: Vec<_> = unlabeled
476            .iter()
477            .map(|&idx| {
478                let change_score =
479                    self.uncertainties[idx] * (1.0 + idx as f64 / unlabeled.len() as f64);
480                (idx, change_score)
481            })
482            .collect();
483
484        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
485        scored
486            .into_iter()
487            .take(budget)
488            .map(|(idx, _)| idx)
489            .collect()
490    }
491}
492
493impl Sampler for ActiveLearningSampler {
494    type Iter = SamplerIterator;
495
496    fn iter(&self) -> Self::Iter {
497        let indices = self.select_samples();
498        SamplerIterator::new(indices)
499    }
500
501    fn len(&self) -> usize {
502        let unlabeled_count = self.get_unlabeled_indices().len();
503        self.budget_per_round.min(unlabeled_count)
504    }
505}
506
507/// Statistics about the current active learning state
508#[derive(Debug, Clone, PartialEq)]
509pub struct ActiveLearningStats {
510    /// Current active learning round
511    pub current_round: usize,
512    /// Number of labeled samples
513    pub num_labeled: usize,
514    /// Number of unlabeled samples
515    pub num_unlabeled: usize,
516    /// Total number of samples
517    pub total_samples: usize,
518    /// Budget per round
519    pub budget_per_round: usize,
520    /// Available budget for current round
521    pub available_budget: usize,
522    /// Ratio of labeled to total samples
523    pub labeling_ratio: f64,
524}
525
526/// Create an uncertainty sampling active learner
527///
528/// Convenience function for creating an active learning sampler with uncertainty sampling.
529///
530/// # Arguments
531///
532/// * `dataset_size` - Total size of the dataset
533/// * `budget_per_round` - Number of samples to select per round
534/// * `seed` - Optional random seed for reproducible sampling
535pub fn uncertainty_sampler(
536    dataset_size: usize,
537    budget_per_round: usize,
538    seed: Option<u64>,
539) -> ActiveLearningSampler {
540    let mut sampler = ActiveLearningSampler::new(
541        dataset_size,
542        AcquisitionStrategy::UncertaintySampling,
543        budget_per_round,
544    );
545    if let Some(s) = seed {
546        sampler = sampler.with_generator(s);
547    }
548    sampler
549}
550
551/// Create a diversity sampling active learner
552///
553/// Convenience function for creating an active learning sampler with diversity sampling.
554///
555/// # Arguments
556///
557/// * `dataset_size` - Total size of the dataset
558/// * `budget_per_round` - Number of samples to select per round
559/// * `num_clusters` - Number of clusters for diversity
560/// * `seed` - Optional random seed for reproducible sampling
561pub fn diversity_sampler(
562    dataset_size: usize,
563    budget_per_round: usize,
564    num_clusters: usize,
565    seed: Option<u64>,
566) -> ActiveLearningSampler {
567    let mut sampler = ActiveLearningSampler::new(
568        dataset_size,
569        AcquisitionStrategy::DiversitySampling { num_clusters },
570        budget_per_round,
571    );
572    if let Some(s) = seed {
573        sampler = sampler.with_generator(s);
574    }
575    sampler
576}
577
578/// Create a combined uncertainty-diversity active learner
579///
580/// Convenience function for creating an active learning sampler that combines
581/// uncertainty and diversity sampling.
582///
583/// # Arguments
584///
585/// * `dataset_size` - Total size of the dataset
586/// * `budget_per_round` - Number of samples to select per round
587/// * `alpha` - Weight for uncertainty vs diversity (0.0-1.0)
588/// * `seed` - Optional random seed for reproducible sampling
589pub fn uncertainty_diversity_sampler(
590    dataset_size: usize,
591    budget_per_round: usize,
592    alpha: f64,
593    seed: Option<u64>,
594) -> ActiveLearningSampler {
595    let mut sampler = ActiveLearningSampler::new(
596        dataset_size,
597        AcquisitionStrategy::UncertaintyDiversity { alpha },
598        budget_per_round,
599    );
600    if let Some(s) = seed {
601        sampler = sampler.with_generator(s);
602    }
603    sampler
604}
605
606#[cfg(test)]
607mod tests {
608    use super::*;
609
610    #[test]
611    fn test_active_learning_sampler_basic() {
612        let mut sampler =
613            ActiveLearningSampler::new(100, AcquisitionStrategy::UncertaintySampling, 10)
614                .with_generator(42);
615
616        assert_eq!(sampler.num_samples, 100);
617        assert_eq!(sampler.budget_per_round(), 10);
618        assert_eq!(sampler.current_round(), 0);
619        assert_eq!(sampler.num_labeled(), 0);
620        assert_eq!(sampler.num_unlabeled(), 100);
621        assert_eq!(sampler.generator, Some(42));
622
623        // Update uncertainties
624        let uncertainties: Vec<f64> = (0..100).map(|i| i as f64 / 100.0).collect();
625        sampler.update_uncertainties(uncertainties);
626
627        // Should select highest uncertainty samples
628        let indices: Vec<usize> = sampler.iter().collect();
629        assert_eq!(indices.len(), 10);
630
631        // With uncertainty sampling, should select samples with highest uncertainty (90-99)
632        for &idx in &indices {
633            assert!(idx >= 90); // Highest uncertainty samples
634        }
635
636        // Add labeled samples
637        sampler.add_labeled_samples(&indices);
638        assert_eq!(sampler.num_labeled(), 10);
639        assert_eq!(sampler.num_unlabeled(), 90);
640        assert_eq!(sampler.current_round(), 1);
641    }
642
643    #[test]
644    fn test_acquisition_strategies() {
645        let dataset_size = 50;
646        let budget = 5;
647
648        // Test different strategies
649        let strategies = vec![
650            AcquisitionStrategy::UncertaintySampling,
651            AcquisitionStrategy::ExpectedInformationGain,
652            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
653            AcquisitionStrategy::UncertaintyDiversity { alpha: 0.5 },
654            AcquisitionStrategy::QueryByCommittee,
655            AcquisitionStrategy::ExpectedModelChange,
656        ];
657
658        for strategy in strategies {
659            let mut sampler = ActiveLearningSampler::new(dataset_size, strategy.clone(), budget)
660                .with_generator(42);
661
662            // Set up uncertainties with clear pattern
663            let uncertainties: Vec<f64> = (0..dataset_size)
664                .map(|i| if i < 10 { 0.9 } else { 0.1 })
665                .collect();
666            sampler.update_uncertainties(uncertainties);
667
668            let indices: Vec<usize> = sampler.iter().collect();
669            assert_eq!(indices.len(), budget);
670
671            // All strategies should prefer high uncertainty samples to some degree
672            // (except pure diversity which might sample differently)
673            match strategy {
674                AcquisitionStrategy::DiversitySampling { .. } => {
675                    // Diversity sampling might select from anywhere
676                    assert!(indices.iter().all(|&idx| idx < dataset_size));
677                }
678                _ => {
679                    // Other strategies should prefer high uncertainty (indices 0-9)
680                    let high_uncertainty_count = indices.iter().filter(|&&idx| idx < 10).count();
681                    assert!(high_uncertainty_count > 0);
682                }
683            }
684        }
685    }
686
687    #[test]
688    fn test_active_learning_with_initial_labeled() {
689        let initial_labeled = vec![0, 1, 2, 3, 4];
690        let mut sampler = ActiveLearningSampler::with_initial_labeled(
691            100,
692            AcquisitionStrategy::UncertaintySampling,
693            5,
694            &initial_labeled,
695        );
696
697        assert_eq!(sampler.num_labeled(), 5);
698        assert_eq!(sampler.num_unlabeled(), 95);
699
700        for &idx in &initial_labeled {
701            assert!(sampler.is_labeled(idx));
702        }
703
704        // Update uncertainties
705        let uncertainties = vec![0.5; 100];
706        sampler.update_uncertainties(uncertainties);
707
708        // Should not select already labeled samples
709        let indices: Vec<usize> = sampler.iter().collect();
710        assert_eq!(indices.len(), 5);
711
712        for &idx in &indices {
713            assert!(!initial_labeled.contains(&idx));
714        }
715    }
716
717    #[test]
718    fn test_uncertainty_diversity_sampling() {
719        let mut sampler = ActiveLearningSampler::new(
720            20,
721            AcquisitionStrategy::UncertaintyDiversity { alpha: 0.6 },
722            10,
723        )
724        .with_generator(42);
725
726        // Set up uncertainties with clear pattern
727        let uncertainties: Vec<f64> = (0..20).map(|i| i as f64 / 20.0).collect();
728        sampler.update_uncertainties(uncertainties);
729
730        let indices: Vec<usize> = sampler.iter().collect();
731        assert_eq!(indices.len(), 10);
732
733        // With alpha=0.6, should get 6 uncertainty samples + 4 diversity samples
734        // The exact composition depends on implementation details, but should include mix
735        let high_uncertainty_count = indices.iter().filter(|&&idx| idx >= 15).count();
736        assert!(high_uncertainty_count > 0); // Should have some high uncertainty
737        assert!(high_uncertainty_count < indices.len()); // But not all
738    }
739
740    #[test]
741    fn test_diversity_sampling() {
742        let mut sampler = ActiveLearningSampler::new(
743            30,
744            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
745            9,
746        )
747        .with_generator(42);
748
749        let uncertainties = vec![0.5; 30]; // Equal uncertainties
750        sampler.update_uncertainties(uncertainties);
751
752        let indices: Vec<usize> = sampler.iter().collect();
753        assert_eq!(indices.len(), 9);
754
755        // With equal uncertainties, diversity sampling should spread across the range
756        // (exact behavior depends on clustering implementation)
757        assert!(indices.iter().all(|&idx| idx < 30));
758    }
759
760    #[test]
761    fn test_active_learning_stats() {
762        let mut sampler =
763            ActiveLearningSampler::new(100, AcquisitionStrategy::UncertaintySampling, 15);
764
765        let stats = sampler.active_learning_stats();
766        assert_eq!(stats.current_round, 0);
767        assert_eq!(stats.num_labeled, 0);
768        assert_eq!(stats.num_unlabeled, 100);
769        assert_eq!(stats.total_samples, 100);
770        assert_eq!(stats.budget_per_round, 15);
771        assert_eq!(stats.available_budget, 15);
772        assert_eq!(stats.labeling_ratio, 0.0);
773
774        // Add some labeled samples
775        sampler.add_labeled_samples(&[0, 1, 2, 3, 4]);
776
777        let stats = sampler.active_learning_stats();
778        assert_eq!(stats.current_round, 1);
779        assert_eq!(stats.num_labeled, 5);
780        assert_eq!(stats.num_unlabeled, 95);
781        assert_eq!(stats.labeling_ratio, 0.05);
782    }
783
784    #[test]
785    fn test_sampler_methods() {
786        let mut sampler =
787            ActiveLearningSampler::new(50, AcquisitionStrategy::UncertaintySampling, 10);
788
789        // Test labeled/unlabeled tracking
790        assert_eq!(sampler.get_unlabeled_indices().len(), 50);
791        assert!(sampler.labeled_indices().is_empty());
792
793        sampler.add_labeled_samples(&[5, 15, 25]);
794        assert_eq!(sampler.num_labeled(), 3);
795        assert!(sampler.is_labeled(5));
796        assert!(sampler.is_labeled(15));
797        assert!(sampler.is_labeled(25));
798        assert!(!sampler.is_labeled(0));
799
800        let labeled = sampler.labeled_indices();
801        assert_eq!(labeled.len(), 3);
802        assert!(labeled.contains(&5));
803        assert!(labeled.contains(&15));
804        assert!(labeled.contains(&25));
805
806        // Test remove labeled samples
807        sampler.remove_labeled_samples(&[15]);
808        assert_eq!(sampler.num_labeled(), 2);
809        assert!(!sampler.is_labeled(15));
810
811        // Test strategy change
812        sampler.set_strategy(AcquisitionStrategy::DiversitySampling { num_clusters: 4 });
813        assert!(matches!(
814            sampler.strategy(),
815            AcquisitionStrategy::DiversitySampling { num_clusters: 4 }
816        ));
817
818        // Test budget change
819        sampler.set_budget(5);
820        assert_eq!(sampler.budget_per_round(), 5);
821
822        // Test reset
823        sampler.reset();
824        assert_eq!(sampler.num_labeled(), 0);
825        assert_eq!(sampler.current_round(), 0);
826    }
827
828    #[test]
829    fn test_convenience_functions() {
830        // Test uncertainty_sampler
831        let uncertainty = uncertainty_sampler(100, 10, Some(42));
832        assert!(matches!(
833            uncertainty.strategy(),
834            AcquisitionStrategy::UncertaintySampling
835        ));
836        assert_eq!(uncertainty.budget_per_round(), 10);
837
838        // Test diversity_sampler
839        let diversity = diversity_sampler(100, 10, 5, Some(42));
840        assert!(matches!(
841            diversity.strategy(),
842            AcquisitionStrategy::DiversitySampling { num_clusters: 5 }
843        ));
844
845        // Test uncertainty_diversity_sampler
846        let combined = uncertainty_diversity_sampler(100, 10, 0.7, Some(42));
847        assert!(matches!(
848            combined.strategy(),
849            AcquisitionStrategy::UncertaintyDiversity { alpha } if (*alpha - 0.7).abs() < f64::EPSILON
850        ));
851    }
852
853    #[test]
854    fn test_edge_cases() {
855        // Empty budget
856        let mut sampler =
857            ActiveLearningSampler::new(10, AcquisitionStrategy::UncertaintySampling, 0);
858        assert_eq!(sampler.len(), 0);
859        let indices: Vec<usize> = sampler.iter().collect();
860        assert!(indices.is_empty());
861
862        // All samples labeled
863        sampler.set_budget(5);
864        sampler.add_labeled_samples(&(0..10).collect::<Vec<_>>());
865        assert_eq!(sampler.num_unlabeled(), 0);
866        assert_eq!(sampler.len(), 0);
867
868        // Budget larger than unlabeled
869        let mut large_budget =
870            ActiveLearningSampler::new(5, AcquisitionStrategy::UncertaintySampling, 10);
871        assert_eq!(large_budget.len(), 5); // Should be clamped to available
872
873        // Invalid alpha should be clamped
874        let mut clamped = ActiveLearningSampler::new(
875            10,
876            AcquisitionStrategy::UncertaintyDiversity { alpha: 1.5 },
877            5,
878        );
879        let uncertainties = vec![0.5; 10];
880        clamped.update_uncertainties(uncertainties);
881        let indices: Vec<usize> = clamped.iter().collect();
882        assert_eq!(indices.len(), 5); // Should still work
883
884        // Zero clusters in diversity sampling
885        let mut zero_clusters = ActiveLearningSampler::new(
886            10,
887            AcquisitionStrategy::DiversitySampling { num_clusters: 0 },
888            3,
889        );
890        let uncertainties = vec![0.5; 10];
891        zero_clusters.update_uncertainties(uncertainties);
892        let indices: Vec<usize> = zero_clusters.iter().collect();
893        assert_eq!(indices.len(), 3); // Should fallback to uncertainty sampling
894    }
895
896    #[test]
897    #[should_panic(expected = "assertion failed")]
898    fn test_update_uncertainties_wrong_size() {
899        let mut sampler =
900            ActiveLearningSampler::new(10, AcquisitionStrategy::UncertaintySampling, 5);
901        // Wrong size should panic
902        sampler.update_uncertainties(vec![0.5; 5]);
903    }
904
905    #[test]
906    fn test_acquisition_strategy_equality() {
907        assert_eq!(
908            AcquisitionStrategy::UncertaintySampling,
909            AcquisitionStrategy::UncertaintySampling
910        );
911        assert_eq!(
912            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
913            AcquisitionStrategy::DiversitySampling { num_clusters: 3 }
914        );
915        assert_ne!(
916            AcquisitionStrategy::UncertaintySampling,
917            AcquisitionStrategy::ExpectedInformationGain
918        );
919    }
920
921    #[test]
922    fn test_acquisition_strategy_default() {
923        assert_eq!(
924            AcquisitionStrategy::default(),
925            AcquisitionStrategy::UncertaintySampling
926        );
927    }
928
929    #[test]
930    fn test_reproducibility() {
931        let mut sampler1 = ActiveLearningSampler::new(
932            20,
933            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
934            5,
935        )
936        .with_generator(123);
937
938        let mut sampler2 = ActiveLearningSampler::new(
939            20,
940            AcquisitionStrategy::DiversitySampling { num_clusters: 3 },
941            5,
942        )
943        .with_generator(123);
944
945        let uncertainties = vec![0.5; 20];
946        sampler1.update_uncertainties(uncertainties.clone());
947        sampler2.update_uncertainties(uncertainties);
948
949        let indices1: Vec<usize> = sampler1.iter().collect();
950        let indices2: Vec<usize> = sampler2.iter().collect();
951
952        assert_eq!(indices1, indices2);
953    }
954}