sklears_semi_supervised/batch_active_learning/batch_mode.rs

//! Batch Mode Active Learning implementation

use super::{BatchActiveLearningError, *};

/// Batch Mode Active Learning using uncertainty and diversity
///
/// This method selects a batch of samples by balancing uncertainty (information gain)
/// and diversity (avoiding redundant samples) using various strategies.
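///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` because the crate-level import path and
/// the array setup are assumed rather than shown here):
///
/// ```ignore
/// // Configure a selector that returns batches of 5 samples, weighing
/// // uncertainty and diversity equally.
/// let selector = BatchModeActiveLearning::new()
///     .batch_size(5)
///     .unwrap()
///     .diversity_weight(0.5)
///     .unwrap()
///     .strategy("entropy".to_string())
///     .distance_metric("euclidean".to_string());
///
/// // `x` holds the unlabeled feature matrix and `probs` the model's
/// // predicted class probabilities for those samples.
/// let chosen = selector.query(&x.view(), &probs.view()).unwrap();
/// ```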
#[derive(Debug, Clone)]
pub struct BatchModeActiveLearning {
    /// Number of samples to select in each query batch
    pub batch_size: usize,
    /// Weight in [0, 1] trading off diversity against uncertainty
    /// (0.0 = pure uncertainty sampling, 1.0 = pure diversity)
    pub diversity_weight: f64,
    /// Uncertainty strategy: "uncertainty_diversity", "entropy", "margin", or "least_confident"
    pub strategy: String,
    /// Distance metric for diversity: "euclidean", "manhattan", or "cosine"
    pub distance_metric: String,
    /// Optional random seed for reproducibility
    pub random_state: Option<u64>,
}

impl Default for BatchModeActiveLearning {
    fn default() -> Self {
        Self {
            batch_size: 10,
            diversity_weight: 0.5,
            strategy: "uncertainty_diversity".to_string(),
            distance_metric: "euclidean".to_string(),
            random_state: None,
        }
    }
}

impl BatchModeActiveLearning {
    /// Create a new selector with default settings
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the batch size; must be greater than zero
    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(BatchActiveLearningError::InvalidBatchSize(batch_size).into());
        }
        self.batch_size = batch_size;
        Ok(self)
    }

    /// Set the diversity weight; must lie in [0, 1]
    pub fn diversity_weight(mut self, diversity_weight: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&diversity_weight) {
            return Err(BatchActiveLearningError::InvalidDiversityWeight(diversity_weight).into());
        }
        self.diversity_weight = diversity_weight;
        Ok(self)
    }

    /// Set the uncertainty strategy ("uncertainty_diversity", "entropy", "margin", or "least_confident")
    pub fn strategy(mut self, strategy: String) -> Self {
        self.strategy = strategy;
        self
    }

    /// Set the distance metric used for diversity ("euclidean", "manhattan", or "cosine")
    pub fn distance_metric(mut self, distance_metric: String) -> Self {
        self.distance_metric = distance_metric;
        self
    }

    /// Set the random seed for reproducibility
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Compute a per-sample uncertainty score from the predicted class
    /// probabilities according to the configured strategy.
    fn compute_uncertainty_scores(&self, probabilities: &ArrayView2<f64>) -> Result<Array1<f64>> {
        let n_samples = probabilities.dim().0;
        let mut uncertainty_scores = Array1::zeros(n_samples);

        match self.strategy.as_str() {
            "margin" => {
                // Margin-based uncertainty: one minus the gap between the two
                // most probable classes.
                for i in 0..n_samples {
                    let mut probs: Vec<f64> = probabilities.row(i).to_vec();
                    probs.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
                    if probs.len() >= 2 {
                        uncertainty_scores[i] = 1.0 - (probs[0] - probs[1]);
                    } else {
                        uncertainty_scores[i] = 1.0 - probs[0];
                    }
                }
            }
            "least_confident" => {
                // Least-confident uncertainty: one minus the highest class probability.
                for i in 0..n_samples {
                    let max_prob = probabilities.row(i).fold(0.0f64, |a, &b| a.max(b));
                    uncertainty_scores[i] = 1.0 - max_prob;
                }
            }
            // "uncertainty_diversity", "entropy", and any unrecognized strategy
            // fall back to entropy-based uncertainty.
            _ => {
                for i in 0..n_samples {
                    let mut entropy = 0.0;
                    for prob in probabilities.row(i) {
                        if *prob > 0.0 {
                            entropy -= prob * prob.ln();
                        }
                    }
                    uncertainty_scores[i] = entropy;
                }
            }
        }

        Ok(uncertainty_scores)
    }

    /// Compute the distance between two feature vectors using the configured metric.
    fn compute_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> Result<f64> {
        match self.distance_metric.as_str() {
            "euclidean" => {
                let dist = x1
                    .iter()
                    .zip(x2.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt();
                Ok(dist)
            }
            "manhattan" => {
                let dist = x1
                    .iter()
                    .zip(x2.iter())
                    .map(|(a, b)| (a - b).abs())
                    .sum::<f64>();
                Ok(dist)
            }
            "cosine" => {
                // Cosine distance = 1 - cosine similarity; zero-norm vectors are
                // treated as maximally distant.
                let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
                let norm1 = x1.iter().map(|&x| x * x).sum::<f64>().sqrt();
                let norm2 = x2.iter().map(|&x| x * x).sum::<f64>().sqrt();

                if norm1 == 0.0 || norm2 == 0.0 {
                    Ok(1.0)
                } else {
                    Ok(1.0 - dot_product / (norm1 * norm2))
                }
            }
            _ => Err(
                BatchActiveLearningError::InvalidDistanceMetric(self.distance_metric.clone())
                    .into(),
            ),
        }
    }

    /// Greedily select a batch that balances high uncertainty with diversity:
    /// the first pick is the most uncertain sample, and each subsequent pick
    /// maximizes a weighted combination of uncertainty and the minimum distance
    /// to the samples already selected.
    fn select_diverse_batch(
        &self,
        X: &ArrayView2<f64>,
        uncertainty_scores: &ArrayView1<f64>,
    ) -> Result<Vec<usize>> {
        let n_samples = X.dim().0;
        let mut selected_indices = Vec::new();
        let mut remaining_indices: Vec<usize> = (0..n_samples).collect();

        if remaining_indices.len() < self.batch_size {
            return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
        }

        // Select the first sample with the highest uncertainty
        let first_idx = remaining_indices
            .iter()
            .max_by(|&&a, &&b| {
                uncertainty_scores[a]
                    .partial_cmp(&uncertainty_scores[b])
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .copied()
            .unwrap();
        selected_indices.push(first_idx);
        remaining_indices.retain(|&x| x != first_idx);

        // Select the remaining samples, balancing uncertainty and diversity
        while selected_indices.len() < self.batch_size && !remaining_indices.is_empty() {
            // Initialize to a valid remaining index so degenerate scores (e.g. NaN)
            // can never select an index that was already picked.
            let mut best_idx = remaining_indices[0];
            let mut best_score = f64::NEG_INFINITY;

            for &candidate_idx in remaining_indices.iter() {
                let uncertainty_score = uncertainty_scores[candidate_idx];

                // Compute the minimum distance to the already selected samples
                let mut min_distance = f64::INFINITY;
                for &selected_idx in selected_indices.iter() {
                    let distance =
                        self.compute_distance(&X.row(candidate_idx), &X.row(selected_idx))?;
                    min_distance = min_distance.min(distance);
                }

                // Combined score: (1 - w) * uncertainty + w * diversity
                let combined_score = (1.0 - self.diversity_weight) * uncertainty_score
                    + self.diversity_weight * min_distance;

                if combined_score > best_score {
                    best_score = combined_score;
                    best_idx = candidate_idx;
                }
            }

            selected_indices.push(best_idx);
            remaining_indices.retain(|&x| x != best_idx);
        }

        Ok(selected_indices)
    }

    /// Query a batch of sample indices to label: scores each unlabeled sample's
    /// uncertainty from `probabilities`, then greedily selects a diverse batch from `X`.
    pub fn query(
        &self,
        X: &ArrayView2<f64>,
        probabilities: &ArrayView2<f64>,
    ) -> Result<Vec<usize>> {
        let uncertainty_scores = self.compute_uncertainty_scores(probabilities)?;
        self.select_diverse_batch(X, &uncertainty_scores.view())
    }
}
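
// A minimal test sketch of the query flow. It assumes the array types re-exported
// by the parent module are ndarray's (`ArrayView2::from_shape` in particular);
// adjust the constructors if the glob import provides different aliases.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn query_returns_requested_batch_size() {
        // Four unlabeled samples with two features each.
        let x_data = [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let x = ArrayView2::from_shape((4, 2), &x_data).unwrap();

        // Predicted probabilities for two classes per sample.
        let p_data = [0.5, 0.5, 0.9, 0.1, 0.6, 0.4, 0.8, 0.2];
        let probabilities = ArrayView2::from_shape((4, 2), &p_data).unwrap();

        let selector = BatchModeActiveLearning::new()
            .batch_size(2)
            .unwrap()
            .diversity_weight(0.5)
            .unwrap();

        let selected = selector.query(&x, &probabilities).unwrap();

        // The batch has the requested size and contains distinct indices.
        assert_eq!(selected.len(), 2);
        assert_ne!(selected[0], selected[1]);
    }
}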