sklears_semi_supervised/batch_active_learning/
diversity_based.rs

//! Diversity-based sampling implementation for active learning.
2
3use super::{BatchActiveLearningError, *};
4use scirs2_core::rand_prelude::IndexedRandom;
5
/// Configuration for diversity-oriented batch selection in active learning.
///
/// The struct carries the settings a diversity strategy would need (strategy
/// name, distance metric, kernel bandwidth, regularization, iteration cap) plus
/// the batch size and an optional RNG seed.
///
/// NOTE(review): the `query` method in this file is marked as a placeholder and
/// only reads `batch_size` and `random_state`; the remaining fields are stored
/// but not consulted there.
#[derive(Debug, Clone)]
pub struct DiversityBasedSampling {
    /// Number of sample indices requested from a single `query` call
    /// (default: 10).
    pub batch_size: usize,
    /// Name of the diversity strategy as a free-form string (default: `"mmr"`).
    /// No validation of the value is performed when it is set.
    pub diversity_method: String,
    /// Name of the distance metric as a free-form string
    /// (default: `"euclidean"`). No validation is performed when it is set.
    pub distance_metric: String,
    /// Kernel bandwidth setting (default: 1.0).
    /// Presumably used by kernel-based similarity measures; verify against the
    /// code that consumes it.
    pub kernel_bandwidth: f64,
    /// Regularization setting (default: 1e-6).
    /// Presumably a numerical-stability term; verify against the code that
    /// consumes it.
    pub regularization: f64,
    /// Iteration cap setting (default: 100).
    pub max_iter: usize,
    /// Optional seed for the random number generator. When `None`, `query`
    /// falls back to the fixed seed 42, so results are reproducible either way.
    pub random_state: Option<u64>,
}
28
29impl Default for DiversityBasedSampling {
30    fn default() -> Self {
31        Self {
32            batch_size: 10,
33            diversity_method: "mmr".to_string(),
34            distance_metric: "euclidean".to_string(),
35            kernel_bandwidth: 1.0,
36            regularization: 1e-6,
37            max_iter: 100,
38            random_state: None,
39        }
40    }
41}
42
43impl DiversityBasedSampling {
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
49        if batch_size == 0 {
50            return Err(BatchActiveLearningError::InvalidBatchSize(batch_size).into());
51        }
52        self.batch_size = batch_size;
53        Ok(self)
54    }
55
56    pub fn diversity_method(mut self, diversity_method: String) -> Self {
57        self.diversity_method = diversity_method;
58        self
59    }
60
61    pub fn distance_metric(mut self, distance_metric: String) -> Self {
62        self.distance_metric = distance_metric;
63        self
64    }
65
66    pub fn kernel_bandwidth(mut self, kernel_bandwidth: f64) -> Self {
67        self.kernel_bandwidth = kernel_bandwidth;
68        self
69    }
70
71    pub fn regularization(mut self, regularization: f64) -> Self {
72        self.regularization = regularization;
73        self
74    }
75
76    pub fn max_iter(mut self, max_iter: usize) -> Self {
77        self.max_iter = max_iter;
78        self
79    }
80
81    pub fn random_state(mut self, random_state: u64) -> Self {
82        self.random_state = Some(random_state);
83        self
84    }
85
86    pub fn query(
87        &self,
88        X: &ArrayView2<f64>,
89        _probabilities: &ArrayView2<f64>,
90    ) -> Result<Vec<usize>> {
91        let n_samples = X.dim().0;
92
93        if n_samples < self.batch_size {
94            return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
95        }
96
97        // Simple placeholder implementation - select random diverse samples
98        let mut rng = match self.random_state {
99            Some(seed) => Random::seed(seed),
100            None => Random::seed(42),
101        };
102
103        let indices: Vec<usize> = (0..n_samples).collect();
104        let selected: Vec<usize> = indices
105            .choose_multiple(&mut rng, self.batch_size)
106            .cloned()
107            .collect();
108
109        Ok(selected)
110    }
111}