1use crate::Dataset;
7use tenflowers_core::{Result, Tensor, TensorError};
8
/// Strategy used to score how uncertain the model is about a sample's
/// prediction; higher scores mean the sample is more informative to label.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UncertaintyStrategy {
    /// Shannon entropy of the (normalized) prediction distribution.
    Entropy,
    /// Negated margin between the two most confident predictions.
    Margin,
    /// One minus the maximum predicted probability.
    LeastConfident,
    /// Committee disagreement; currently scored via entropy as a
    /// single-model proxy (see `calculate_uncertainty_scores`).
    QueryByCommittee,
}
21
/// Strategy used to score how diverse / representative each sample is
/// relative to the rest of the unlabeled pool.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DiversityStrategy {
    /// Distance to the nearest k-means centroid (farther = more novel).
    KMeansClustering,
    /// Distance from the mean feature vector of the whole pool.
    Representative,
    /// Weighted combination of max-normalized uncertainty and diversity
    /// scores; weights are applied in `combine_scores`.
    Hybrid {
        uncertainty_weight: f32,
        diversity_weight: f32,
    },
}
35
36pub struct ActiveLearningSampler {
38 uncertainty_strategy: UncertaintyStrategy,
39 diversity_strategy: Option<DiversityStrategy>,
40 batch_size: usize,
41}
42
43impl ActiveLearningSampler {
44 pub fn new(uncertainty_strategy: UncertaintyStrategy, batch_size: usize) -> Self {
46 Self {
47 uncertainty_strategy,
48 diversity_strategy: None,
49 batch_size,
50 }
51 }
52
53 pub fn with_diversity(mut self, diversity_strategy: DiversityStrategy) -> Self {
55 self.diversity_strategy = Some(diversity_strategy);
56 self
57 }
58
59 pub fn select_samples<T, D: Dataset<T>>(
61 &self,
62 dataset: &D,
63 predictions: &[Vec<f32>], features: Option<&[Vec<f32>]>, ) -> Result<Vec<usize>>
66 where
67 T: Clone + Default + Send + Sync + 'static,
68 {
69 if predictions.len() != dataset.len() {
70 return Err(TensorError::invalid_argument(
71 "Number of predictions must match dataset size".to_string(),
72 ));
73 }
74
75 let uncertainty_scores = self.calculate_uncertainty_scores(predictions)?;
77
78 let diversity_scores = if let Some(ref diversity_strategy) = self.diversity_strategy {
80 if let Some(features) = features {
81 self.calculate_diversity_scores(features, diversity_strategy)?
82 } else {
83 return Err(TensorError::invalid_argument(
84 "Features required for diversity sampling".to_string(),
85 ));
86 }
87 } else {
88 vec![0.0; dataset.len()]
89 };
90
91 let combined_scores = self.combine_scores(&uncertainty_scores, &diversity_scores)?;
93
94 let mut indexed_scores: Vec<(usize, f32)> =
96 combined_scores.into_iter().enumerate().collect();
97
98 indexed_scores.sort_by(|a, b| {
100 b.1.partial_cmp(&a.1)
101 .expect("partial_cmp should not return None for valid values")
102 });
103
104 Ok(indexed_scores
106 .into_iter()
107 .take(self.batch_size)
108 .map(|(idx, _)| idx)
109 .collect())
110 }
111
112 fn calculate_uncertainty_scores(&self, predictions: &[Vec<f32>]) -> Result<Vec<f32>> {
114 let mut scores = Vec::with_capacity(predictions.len());
115
116 for pred in predictions {
117 let score = match self.uncertainty_strategy {
118 UncertaintyStrategy::Entropy => self.calculate_entropy(pred)?,
119 UncertaintyStrategy::Margin => self.calculate_margin(pred)?,
120 UncertaintyStrategy::LeastConfident => self.calculate_least_confident(pred)?,
121 UncertaintyStrategy::QueryByCommittee => {
122 self.calculate_entropy(pred)?
124 }
125 };
126 scores.push(score);
127 }
128
129 Ok(scores)
130 }
131
132 fn calculate_entropy(&self, predictions: &[f32]) -> Result<f32> {
134 let mut entropy = 0.0;
135 let sum: f32 = predictions.iter().sum();
136
137 if sum == 0.0 {
138 return Ok(0.0);
139 }
140
141 for &p in predictions {
142 let normalized_p = p / sum;
143 if normalized_p > 0.0 {
144 entropy -= normalized_p * normalized_p.ln();
145 }
146 }
147
148 Ok(entropy)
149 }
150
151 fn calculate_margin(&self, predictions: &[f32]) -> Result<f32> {
153 if predictions.len() < 2 {
154 return Ok(0.0);
155 }
156
157 let mut sorted_preds = predictions.to_vec();
158 sorted_preds.sort_by(|a, b| {
159 b.partial_cmp(a)
160 .expect("partial_cmp should not return None for valid values")
161 });
162
163 Ok(-(sorted_preds[0] - sorted_preds[1]))
165 }
166
167 fn calculate_least_confident(&self, predictions: &[f32]) -> Result<f32> {
169 let max_pred = predictions.iter().max_by(|a, b| {
170 a.partial_cmp(b)
171 .expect("partial_cmp should not return None for valid values")
172 });
173 match max_pred {
174 Some(max_val) => Ok(1.0 - max_val), None => Ok(0.0),
176 }
177 }
178
179 fn calculate_diversity_scores(
181 &self,
182 features: &[Vec<f32>],
183 strategy: &DiversityStrategy,
184 ) -> Result<Vec<f32>> {
185 match strategy {
186 DiversityStrategy::KMeansClustering => self.calculate_kmeans_diversity_scores(features),
187 DiversityStrategy::Representative => self.calculate_representative_scores(features),
188 DiversityStrategy::Hybrid { .. } => {
189 self.calculate_kmeans_diversity_scores(features)
191 }
192 }
193 }
194
195 fn calculate_kmeans_diversity_scores(&self, features: &[Vec<f32>]) -> Result<Vec<f32>> {
197 let k = ((features.len() as f32).sqrt() as usize).max(2);
199 let centroids = self.simple_kmeans(features, k)?;
200
201 let mut scores = Vec::with_capacity(features.len());
202 for feature in features {
203 let min_distance = centroids
205 .iter()
206 .map(|centroid| self.euclidean_distance(feature, centroid))
207 .min_by(|a, b| {
208 a.partial_cmp(b)
209 .expect("partial_cmp should not return None for valid values")
210 })
211 .unwrap_or(0.0);
212
213 scores.push(min_distance);
214 }
215
216 Ok(scores)
217 }
218
219 fn calculate_representative_scores(&self, features: &[Vec<f32>]) -> Result<Vec<f32>> {
221 if features.is_empty() {
222 return Ok(vec![]);
223 }
224
225 let feature_dim = features[0].len();
226 let mut centroid = vec![0.0; feature_dim];
227
228 for feature in features {
230 for (i, &val) in feature.iter().enumerate() {
231 centroid[i] += val;
232 }
233 }
234
235 let n = features.len() as f32;
236 for val in centroid.iter_mut() {
237 *val /= n;
238 }
239
240 let mut scores = Vec::with_capacity(features.len());
242 for feature in features {
243 let distance = self.euclidean_distance(feature, ¢roid);
244 scores.push(distance);
245 }
246
247 Ok(scores)
248 }
249
250 fn simple_kmeans(&self, features: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
252 if features.is_empty() || k == 0 {
253 return Ok(vec![]);
254 }
255
256 let feature_dim = features[0].len();
257 let mut centroids = Vec::with_capacity(k);
258
259 use scirs2_core::random::rand_prelude::*;
261 let mut rng = scirs2_core::random::rng();
262 for _ in 0..k {
263 let random_val: f64 = rng.random();
264 let idx = (random_val * features.len() as f64) as usize;
265 let idx = idx.min(features.len() - 1);
266 centroids.push(features[idx].clone());
267 }
268
269 for _ in 0..10 {
271 let mut new_centroids = vec![vec![0.0; feature_dim]; k];
272 let mut counts = vec![0; k];
273
274 for feature in features {
276 let nearest_idx = centroids
277 .iter()
278 .enumerate()
279 .min_by(|(_, a), (_, b)| {
280 let dist_a = self.euclidean_distance(feature, a);
281 let dist_b = self.euclidean_distance(feature, b);
282 dist_a
283 .partial_cmp(&dist_b)
284 .expect("partial_cmp should not return None for valid values")
285 })
286 .map(|(idx, _)| idx)
287 .unwrap_or(0);
288
289 counts[nearest_idx] += 1;
290 for (i, &val) in feature.iter().enumerate() {
291 new_centroids[nearest_idx][i] += val;
292 }
293 }
294
295 for (i, centroid) in new_centroids.iter_mut().enumerate() {
297 if counts[i] > 0 {
298 for val in centroid.iter_mut() {
299 *val /= counts[i] as f32;
300 }
301 }
302 }
303
304 centroids = new_centroids;
305 }
306
307 Ok(centroids)
308 }
309
310 fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
312 a.iter()
313 .zip(b.iter())
314 .map(|(x, y)| (x - y).powi(2))
315 .sum::<f32>()
316 .sqrt()
317 }
318
319 fn combine_scores(
321 &self,
322 uncertainty_scores: &[f32],
323 diversity_scores: &[f32],
324 ) -> Result<Vec<f32>> {
325 if uncertainty_scores.len() != diversity_scores.len() {
326 return Err(TensorError::invalid_argument(
327 "Uncertainty and diversity scores must have same length".to_string(),
328 ));
329 }
330
331 let mut combined_scores = Vec::with_capacity(uncertainty_scores.len());
332
333 match &self.diversity_strategy {
334 Some(DiversityStrategy::Hybrid {
335 uncertainty_weight,
336 diversity_weight,
337 }) => {
338 let max_uncertainty = uncertainty_scores
340 .iter()
341 .max_by(|a, b| {
342 a.partial_cmp(b)
343 .expect("partial_cmp should not return None for valid values")
344 })
345 .unwrap_or(&1.0);
346 let max_diversity = diversity_scores
347 .iter()
348 .max_by(|a, b| {
349 a.partial_cmp(b)
350 .expect("partial_cmp should not return None for valid values")
351 })
352 .unwrap_or(&1.0);
353
354 for (u_score, d_score) in uncertainty_scores.iter().zip(diversity_scores.iter()) {
355 let normalized_u = u_score / max_uncertainty;
356 let normalized_d = d_score / max_diversity;
357 let combined =
358 uncertainty_weight * normalized_u + diversity_weight * normalized_d;
359 combined_scores.push(combined);
360 }
361 }
362 Some(_) => {
363 for (u_score, d_score) in uncertainty_scores.iter().zip(diversity_scores.iter()) {
365 combined_scores.push(u_score + d_score);
366 }
367 }
368 None => {
369 combined_scores.extend_from_slice(uncertainty_scores);
371 }
372 }
373
374 Ok(combined_scores)
375 }
376}
377
/// Wraps a dataset and tracks which sample indices have been labeled so
/// far, supporting the label-query loop of active learning.
pub struct ActiveLearningDataset<T, D: Dataset<T>> {
    /// The underlying full dataset (labeled and unlabeled samples).
    dataset: D,
    /// Indices into `dataset` that have been labeled.
    labeled_indices: Vec<usize>,
    /// Indices into `dataset` still awaiting labels.
    unlabeled_indices: Vec<usize>,
    /// Marks the sample type `T` without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
385
386impl<T, D: Dataset<T>> ActiveLearningDataset<T, D> {
387 pub fn new(dataset: D, initial_labeled_indices: Vec<usize>) -> Self {
389 let total_len = dataset.len();
390 let labeled_set: std::collections::HashSet<usize> =
391 initial_labeled_indices.iter().cloned().collect();
392 let unlabeled_indices: Vec<usize> = (0..total_len)
393 .filter(|i| !labeled_set.contains(i))
394 .collect();
395
396 Self {
397 dataset,
398 labeled_indices: initial_labeled_indices,
399 unlabeled_indices,
400 _phantom: std::marker::PhantomData,
401 }
402 }
403
404 pub fn add_labeled_samples(&mut self, indices: Vec<usize>) {
406 let indices_set: std::collections::HashSet<usize> = indices.iter().cloned().collect();
407
408 self.labeled_indices.extend(indices);
410
411 self.unlabeled_indices
413 .retain(|&i| !indices_set.contains(&i));
414 }
415
416 pub fn get_labeled_dataset(&self) -> LabeledSubset<'_, T, D>
418 where
419 D: Clone,
420 {
421 LabeledSubset {
422 dataset: self.dataset.clone(),
423 indices: &self.labeled_indices,
424 _phantom: std::marker::PhantomData,
425 }
426 }
427
428 pub fn get_unlabeled_dataset(&self) -> UnlabeledSubset<'_, T, D>
430 where
431 D: Clone,
432 {
433 UnlabeledSubset {
434 dataset: self.dataset.clone(),
435 indices: &self.unlabeled_indices,
436 _phantom: std::marker::PhantomData,
437 }
438 }
439
440 pub fn labeled_indices(&self) -> &[usize] {
442 &self.labeled_indices
443 }
444
445 pub fn unlabeled_indices(&self) -> &[usize] {
447 &self.unlabeled_indices
448 }
449}
450
/// Dataset view exposing only the labeled samples of a parent
/// `ActiveLearningDataset`; positions are remapped through `indices`.
pub struct LabeledSubset<'a, T, D: Dataset<T>> {
    /// The underlying dataset the indices point into.
    dataset: D,
    /// Labeled indices borrowed from the parent `ActiveLearningDataset`.
    indices: &'a [usize],
    /// Marks the sample type `T` without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
457
458impl<'a, T, D: Dataset<T>> Dataset<T> for LabeledSubset<'a, T, D> {
459 fn len(&self) -> usize {
460 self.indices.len()
461 }
462
463 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
464 if index >= self.indices.len() {
465 return Err(TensorError::invalid_argument(format!(
466 "Index {} out of bounds for labeled subset of length {}",
467 index,
468 self.indices.len()
469 )));
470 }
471
472 let actual_index = self.indices[index];
473 self.dataset.get(actual_index)
474 }
475}
476
/// Dataset view exposing only the unlabeled samples of a parent
/// `ActiveLearningDataset`; positions are remapped through `indices`.
pub struct UnlabeledSubset<'a, T, D: Dataset<T>> {
    /// The underlying dataset the indices point into.
    dataset: D,
    /// Unlabeled indices borrowed from the parent `ActiveLearningDataset`.
    indices: &'a [usize],
    /// Marks the sample type `T` without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
483
484impl<'a, T, D: Dataset<T>> Dataset<T> for UnlabeledSubset<'a, T, D> {
485 fn len(&self) -> usize {
486 self.indices.len()
487 }
488
489 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
490 if index >= self.indices.len() {
491 return Err(TensorError::invalid_argument(format!(
492 "Index {} out of bounds for unlabeled subset of length {}",
493 index,
494 self.indices.len()
495 )));
496 }
497
498 let actual_index = self.indices[index];
499 self.dataset.get(actual_index)
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    // Entropy scoring: a near-uniform prediction must score as more
    // uncertain than a confident one.
    #[test]
    fn test_uncertainty_sampling() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Entropy, 2);

        let predictions = vec![
            vec![0.9, 0.1],
            vec![0.5, 0.5],
            vec![0.8, 0.2],
            vec![0.6, 0.4],
        ];

        let scores = sampler
            .calculate_uncertainty_scores(&predictions)
            .expect("test: uncertainty scores should succeed");

        // [0.5, 0.5] beats [0.9, 0.1]; [0.6, 0.4] beats [0.8, 0.2].
        assert!(scores[1] > scores[0]);
        assert!(scores[3] > scores[2]);
    }

    // Bookkeeping of the labeled/unlabeled index sets as samples get
    // labeled, including the subset views' lengths.
    #[test]
    fn test_active_learning_dataset() {
        let features =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2])
                .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 0.0, 1.0], &[4])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Start with samples 0 and 1 labeled; 2 and 3 remain unlabeled.
        let mut al_dataset = ActiveLearningDataset::new(dataset, vec![0, 1]);

        assert_eq!(al_dataset.labeled_indices().len(), 2);
        assert_eq!(al_dataset.unlabeled_indices().len(), 2);

        // Labeling sample 2 moves it from the unlabeled to the labeled set.
        al_dataset.add_labeled_samples(vec![2]);

        assert_eq!(al_dataset.labeled_indices().len(), 3);
        assert_eq!(al_dataset.unlabeled_indices().len(), 1);

        let labeled_subset = al_dataset.get_labeled_dataset();
        assert_eq!(labeled_subset.len(), 3);

        let unlabeled_subset = al_dataset.get_unlabeled_dataset();
        assert_eq!(unlabeled_subset.len(), 1);
    }

    // Representative diversity scores are distances from the pool mean, so
    // outliers score higher than points near the centroid.
    #[test]
    fn test_diversity_sampling() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Entropy, 2)
            .with_diversity(DiversityStrategy::Representative);

        let features = vec![
            vec![0.0, 0.0],
            vec![2.0, 2.0],
            vec![0.1, 0.1],
            vec![1.5, 1.5],
        ];

        let scores = sampler
            .calculate_diversity_scores(&features, &DiversityStrategy::Representative)
            .expect("test: operation should succeed");

        // (2.0, 2.0) is the farthest point from the mean.
        assert!(scores[1] > scores[2]);
        assert!(scores[1] > scores[0]);
        assert!(scores.len() == 4);
        // Euclidean distances are never negative.
        assert!(scores.iter().all(|&s| s >= 0.0));
    }

    // Margin scoring: a smaller gap between the top two predictions must
    // yield a larger (less negative) uncertainty score.
    #[test]
    fn test_margin_uncertainty() {
        let sampler = ActiveLearningSampler::new(UncertaintyStrategy::Margin, 2);

        let predictions = vec![
            vec![0.9, 0.1],
            vec![0.51, 0.49],
            vec![0.8, 0.2],
        ];

        let scores = sampler
            .calculate_uncertainty_scores(&predictions)
            .expect("test: uncertainty scores should succeed");

        assert!(scores[1] > scores[0]);
        assert!(scores[2] > scores[0]);
    }
}