// sklears_semi_supervised/batch_active_learning/diversity_based.rs
use super::{BatchActiveLearningError, *};
use scirs2_core::rand_prelude::IndexedRandom;
5
/// Configuration for diversity-based batch selection in active learning.
///
/// Values are normally set through the builder-style methods on this type,
/// starting from [`Default::default`] / `new()`.
///
/// NOTE(review): in the `query` implementation visible in this file only
/// `batch_size` and `random_state` are read; the remaining fields are stored
/// but not consulted there.
#[derive(Debug, Clone)]
pub struct DiversityBasedSampling {
    /// Number of samples returned by a single query. The `batch_size`
    /// builder rejects zero.
    pub batch_size: usize,
    /// Name of the diversity strategy (the default configuration uses `"mmr"`).
    pub diversity_method: String,
    /// Name of the distance metric (the default configuration uses `"euclidean"`).
    pub distance_metric: String,
    /// Kernel bandwidth setting; presumably for kernel-based diversity
    /// measures (not read by the visible `query`).
    pub kernel_bandwidth: f64,
    /// Regularization setting; presumably a numerical-stability term
    /// (not read by the visible `query`).
    pub regularization: f64,
    /// Iteration cap; presumably for iterative diversity methods
    /// (not read by the visible `query`).
    pub max_iter: usize,
    /// Optional RNG seed. `None` means no seed was supplied by the caller.
    pub random_state: Option<u64>,
}
28
29impl Default for DiversityBasedSampling {
30 fn default() -> Self {
31 Self {
32 batch_size: 10,
33 diversity_method: "mmr".to_string(),
34 distance_metric: "euclidean".to_string(),
35 kernel_bandwidth: 1.0,
36 regularization: 1e-6,
37 max_iter: 100,
38 random_state: None,
39 }
40 }
41}
42
43impl DiversityBasedSampling {
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
49 if batch_size == 0 {
50 return Err(BatchActiveLearningError::InvalidBatchSize(batch_size).into());
51 }
52 self.batch_size = batch_size;
53 Ok(self)
54 }
55
56 pub fn diversity_method(mut self, diversity_method: String) -> Self {
57 self.diversity_method = diversity_method;
58 self
59 }
60
61 pub fn distance_metric(mut self, distance_metric: String) -> Self {
62 self.distance_metric = distance_metric;
63 self
64 }
65
66 pub fn kernel_bandwidth(mut self, kernel_bandwidth: f64) -> Self {
67 self.kernel_bandwidth = kernel_bandwidth;
68 self
69 }
70
71 pub fn regularization(mut self, regularization: f64) -> Self {
72 self.regularization = regularization;
73 self
74 }
75
76 pub fn max_iter(mut self, max_iter: usize) -> Self {
77 self.max_iter = max_iter;
78 self
79 }
80
81 pub fn random_state(mut self, random_state: u64) -> Self {
82 self.random_state = Some(random_state);
83 self
84 }
85
86 pub fn query(
87 &self,
88 X: &ArrayView2<f64>,
89 _probabilities: &ArrayView2<f64>,
90 ) -> Result<Vec<usize>> {
91 let n_samples = X.dim().0;
92
93 if n_samples < self.batch_size {
94 return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
95 }
96
97 let mut rng = match self.random_state {
99 Some(seed) => Random::seed(seed),
100 None => Random::seed(42),
101 };
102
103 let indices: Vec<usize> = (0..n_samples).collect();
104 let selected: Vec<usize> = indices
105 .choose_multiple(&mut rng, self.batch_size)
106 .cloned()
107 .collect();
108
109 Ok(selected)
110 }
111}