// sklears_semi_supervised/batch_active_learning/batch_mode.rs

use super::{BatchActiveLearningError, *};

/// Batch-mode active learning selector that chooses a batch of unlabeled
/// samples by balancing model uncertainty against diversity within the batch.
#[derive(Debug, Clone)]
pub struct BatchModeActiveLearning {
    // Number of samples to select per query (validated > 0 by the builder).
    pub batch_size: usize,
    // Trade-off in [0, 1]: 0.0 = pure uncertainty, 1.0 = pure diversity.
    pub diversity_weight: f64,
    // Uncertainty strategy name: "uncertainty_diversity", "entropy",
    // "margin", or "least_confident"; anything else falls back to entropy.
    pub strategy: String,
    // Pairwise distance metric: "euclidean", "manhattan", or "cosine".
    pub distance_metric: String,
    // Optional RNG seed. NOTE(review): stored but not read by any method in
    // this file — presumably consumed elsewhere; confirm downstream usage.
    pub random_state: Option<u64>,
}
22
23impl Default for BatchModeActiveLearning {
24 fn default() -> Self {
25 Self {
26 batch_size: 10,
27 diversity_weight: 0.5,
28 strategy: "uncertainty_diversity".to_string(),
29 distance_metric: "euclidean".to_string(),
30 random_state: None,
31 }
32 }
33}
34
35impl BatchModeActiveLearning {
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
41 if batch_size == 0 {
42 return Err(BatchActiveLearningError::InvalidBatchSize(batch_size).into());
43 }
44 self.batch_size = batch_size;
45 Ok(self)
46 }
47
48 pub fn diversity_weight(mut self, diversity_weight: f64) -> Result<Self> {
49 if !(0.0..=1.0).contains(&diversity_weight) {
50 return Err(BatchActiveLearningError::InvalidDiversityWeight(diversity_weight).into());
51 }
52 self.diversity_weight = diversity_weight;
53 Ok(self)
54 }
55
56 pub fn strategy(mut self, strategy: String) -> Self {
57 self.strategy = strategy;
58 self
59 }
60
61 pub fn distance_metric(mut self, distance_metric: String) -> Self {
62 self.distance_metric = distance_metric;
63 self
64 }
65
66 pub fn random_state(mut self, random_state: u64) -> Self {
67 self.random_state = Some(random_state);
68 self
69 }
70
71 fn compute_uncertainty_scores(&self, probabilities: &ArrayView2<f64>) -> Result<Array1<f64>> {
72 let n_samples = probabilities.dim().0;
73 let mut uncertainty_scores = Array1::zeros(n_samples);
74
75 match self.strategy.as_str() {
76 "uncertainty_diversity" | "entropy" => {
77 for i in 0..n_samples {
79 let mut entropy = 0.0;
80 for prob in probabilities.row(i) {
81 if *prob > 0.0 {
82 entropy -= prob * prob.ln();
83 }
84 }
85 uncertainty_scores[i] = entropy;
86 }
87 }
88 "margin" => {
89 for i in 0..n_samples {
91 let mut probs: Vec<f64> = probabilities.row(i).to_vec();
92 probs.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
93 if probs.len() >= 2 {
94 uncertainty_scores[i] = 1.0 - (probs[0] - probs[1]);
95 } else {
96 uncertainty_scores[i] = 1.0 - probs[0];
97 }
98 }
99 }
100 "least_confident" => {
101 for i in 0..n_samples {
103 let max_prob = probabilities.row(i).fold(0.0f64, |a, &b| a.max(b));
104 uncertainty_scores[i] = 1.0 - max_prob;
105 }
106 }
107 _ => {
108 for i in 0..n_samples {
110 let mut entropy = 0.0;
111 for prob in probabilities.row(i) {
112 if *prob > 0.0 {
113 entropy -= prob * prob.ln();
114 }
115 }
116 uncertainty_scores[i] = entropy;
117 }
118 }
119 }
120
121 Ok(uncertainty_scores)
122 }
123
124 fn compute_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> Result<f64> {
125 match self.distance_metric.as_str() {
126 "euclidean" => {
127 let dist = x1
128 .iter()
129 .zip(x2.iter())
130 .map(|(a, b)| (a - b).powi(2))
131 .sum::<f64>()
132 .sqrt();
133 Ok(dist)
134 }
135 "manhattan" => {
136 let dist = x1
137 .iter()
138 .zip(x2.iter())
139 .map(|(a, b)| (a - b).abs())
140 .sum::<f64>();
141 Ok(dist)
142 }
143 "cosine" => {
144 let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
145 let norm1 = x1.iter().map(|&x| x * x).sum::<f64>().sqrt();
146 let norm2 = x2.iter().map(|&x| x * x).sum::<f64>().sqrt();
147
148 if norm1 == 0.0 || norm2 == 0.0 {
149 Ok(1.0)
150 } else {
151 Ok(1.0 - dot_product / (norm1 * norm2))
152 }
153 }
154 _ => Err(
155 BatchActiveLearningError::InvalidDistanceMetric(self.distance_metric.clone())
156 .into(),
157 ),
158 }
159 }
160
161 fn select_diverse_batch(
162 &self,
163 X: &ArrayView2<f64>,
164 uncertainty_scores: &ArrayView1<f64>,
165 ) -> Result<Vec<usize>> {
166 let n_samples = X.dim().0;
167 let mut selected_indices = Vec::new();
168 let mut remaining_indices: Vec<usize> = (0..n_samples).collect();
169
170 if remaining_indices.len() < self.batch_size {
171 return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
172 }
173
174 let first_idx = remaining_indices
176 .iter()
177 .max_by(|&&a, &&b| {
178 uncertainty_scores[a]
179 .partial_cmp(&uncertainty_scores[b])
180 .unwrap_or(std::cmp::Ordering::Equal)
181 })
182 .copied()
183 .unwrap();
184 selected_indices.push(first_idx);
185 remaining_indices.retain(|&x| x != first_idx);
186
187 while selected_indices.len() < self.batch_size && !remaining_indices.is_empty() {
189 let mut best_idx = 0;
190 let mut best_score = f64::NEG_INFINITY;
191
192 for &candidate_idx in remaining_indices.iter() {
193 let uncertainty_score = uncertainty_scores[candidate_idx];
194
195 let mut min_distance = f64::INFINITY;
197 for &selected_idx in selected_indices.iter() {
198 let distance =
199 self.compute_distance(&X.row(candidate_idx), &X.row(selected_idx))?;
200 min_distance = min_distance.min(distance);
201 }
202
203 let combined_score = (1.0 - self.diversity_weight) * uncertainty_score
205 + self.diversity_weight * min_distance;
206
207 if combined_score > best_score {
208 best_score = combined_score;
209 best_idx = candidate_idx;
210 }
211 }
212
213 selected_indices.push(best_idx);
214 remaining_indices.retain(|&x| x != best_idx);
215 }
216
217 Ok(selected_indices)
218 }
219
220 pub fn query(
221 &self,
222 X: &ArrayView2<f64>,
223 probabilities: &ArrayView2<f64>,
224 ) -> Result<Vec<usize>> {
225 let uncertainty_scores = self.compute_uncertainty_scores(probabilities)?;
226 self.select_diverse_batch(X, &uncertainty_scores.view())
227 }
228}