sklears_semi_supervised/batch_active_learning/core_set.rs

//! Core-set selection strategies for batch active learning.

use super::{BatchActiveLearningError, *};

/// Core-set based batch selection: picks a batch of unlabeled samples that
/// covers the feature space, via farthest-first traversal or k-center greedy.
#[derive(Debug, Clone)]
pub struct CoreSetApproach {
    /// Number of samples to select per query batch.
    pub batch_size: usize,
    /// Distance metric used for coverage computations ("euclidean" or "manhattan").
    pub distance_metric: String,
    /// Selection strategy: "farthest_first" or "k_center_greedy".
    pub initialization: String,
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Optional seed for reproducible selection.
    pub random_state: Option<u64>,
}

impl Default for CoreSetApproach {
    fn default() -> Self {
        Self {
            batch_size: 10,
            distance_metric: "euclidean".to_string(),
            initialization: "farthest_first".to_string(),
            max_iter: 100,
            random_state: None,
        }
    }
}

impl CoreSetApproach {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(BatchActiveLearningError::InvalidBatchSize(batch_size).into());
        }
        self.batch_size = batch_size;
        Ok(self)
    }

    pub fn distance_metric(mut self, distance_metric: String) -> Self {
        self.distance_metric = distance_metric;
        self
    }

    pub fn initialization(mut self, initialization: String) -> Self {
        self.initialization = initialization;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Computes the distance between two feature vectors under the configured metric.
    fn compute_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> Result<f64> {
        match self.distance_metric.as_str() {
            "euclidean" => {
                let dist = x1
                    .iter()
                    .zip(x2.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt();
                Ok(dist)
            }
            "manhattan" => {
                let dist = x1
                    .iter()
                    .zip(x2.iter())
                    .map(|(a, b)| (a - b).abs())
                    .sum::<f64>();
                Ok(dist)
            }
            _ => Err(
                BatchActiveLearningError::InvalidDistanceMetric(self.distance_metric.clone())
                    .into(),
            ),
        }
    }

    /// Selects `batch_size` samples by farthest-first traversal: start from a
    /// random sample, then repeatedly add the sample farthest from the current
    /// selection (maximizing the minimum distance to any selected point).
    fn farthest_first_initialization(&self, X: &ArrayView2<f64>) -> Result<Vec<usize>> {
        let n_samples = X.dim().0;
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42),
        };

        if n_samples < self.batch_size {
            return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
        }

        let mut selected_indices = Vec::new();
        let mut distances = vec![f64::INFINITY; n_samples];

        // Pick the first center at random.
        let first_idx = rng.gen_range(0..n_samples);
        selected_indices.push(first_idx);

        // Initialize each sample's distance to the first center.
        for (i, dist) in distances.iter_mut().enumerate() {
            if i != first_idx {
                *dist = self.compute_distance(&X.row(i), &X.row(first_idx))?;
            }
        }

        for _ in 1..self.batch_size {
            // Choose the unselected sample farthest from the selected set.
            let mut max_distance = 0.0;
            let mut best_idx = 0;

            for (i, &dist) in distances.iter().enumerate() {
                if !selected_indices.contains(&i) && dist > max_distance {
                    max_distance = dist;
                    best_idx = i;
                }
            }

            selected_indices.push(best_idx);

            // Update each remaining sample's minimum distance to the selected set.
            for (i, dist) in distances.iter_mut().enumerate() {
                if !selected_indices.contains(&i) {
                    let new_distance = self.compute_distance(&X.row(i), &X.row(best_idx))?;
                    *dist = (*dist).min(new_distance);
                }
            }
        }

        Ok(selected_indices)
    }

    /// Selects `batch_size` samples with a k-center greedy heuristic: start from
    /// the sample closest to the data centroid, then repeatedly add the sample
    /// farthest from the current selection.
    fn k_center_greedy(&self, X: &ArrayView2<f64>) -> Result<Vec<usize>> {
        let n_samples = X.dim().0;
        let mut selected_indices = Vec::new();
        let mut distances = vec![f64::INFINITY; n_samples];

        if n_samples < self.batch_size {
            return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
        }

        // Compute the mean of all samples.
        let mut centroid = Array1::zeros(X.dim().1);
        for i in 0..n_samples {
            centroid = centroid + X.row(i);
        }
        centroid /= n_samples as f64;

        // Start from the sample closest to the centroid.
        let mut min_distance = f64::INFINITY;
        let mut first_idx = 0;
        for i in 0..n_samples {
            let distance = self.compute_distance(&X.row(i), &centroid.view())?;
            if distance < min_distance {
                min_distance = distance;
                first_idx = i;
            }
        }

        selected_indices.push(first_idx);

        // Initialize each sample's distance to the first center.
        for (i, dist) in distances.iter_mut().enumerate() {
            if i != first_idx {
                *dist = self.compute_distance(&X.row(i), &X.row(first_idx))?;
            }
        }

        for _ in 1..self.batch_size {
            // Choose the unselected sample farthest from the selected set.
            let mut max_distance = 0.0;
            let mut best_idx = 0;

            for (i, &dist) in distances.iter().enumerate() {
                if !selected_indices.contains(&i) && dist > max_distance {
                    max_distance = dist;
                    best_idx = i;
                }
            }

            selected_indices.push(best_idx);

            // Update each remaining sample's minimum distance to the selected set.
            for (i, dist) in distances.iter_mut().enumerate() {
                if !selected_indices.contains(&i) {
                    let new_distance = self.compute_distance(&X.row(i), &X.row(best_idx))?;
                    *dist = (*dist).min(new_distance);
                }
            }
        }

        Ok(selected_indices)
    }

    /// Selects a batch of unlabeled samples to query. The class probabilities are
    /// ignored: core-set selection is purely geometric.
    pub fn query(
        &self,
        X: &ArrayView2<f64>,
        _probabilities: &ArrayView2<f64>,
    ) -> Result<Vec<usize>> {
        match self.initialization.as_str() {
            "farthest_first" => self.farthest_first_initialization(X),
            "k_center_greedy" => self.k_center_greedy(X),
            // Fall back to farthest-first for unrecognized strategies.
            _ => self.farthest_first_initialization(X),
        }
    }
}
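
// A minimal usage sketch of the builder and `query` API. It assumes `ndarray`
// is the array backend behind `ArrayView2`/`Array1` and that the crate's
// `Result` error type implements `Debug`; the probability matrix is a dummy,
// since core-set selection only looks at the feature matrix.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2};

    #[test]
    fn farthest_first_returns_distinct_batch() {
        let x = array![
            [0.0, 0.0],
            [0.0, 1.0],
            [5.0, 5.0],
            [5.0, 6.0],
            [10.0, 0.0],
        ];
        let probabilities = Array2::<f64>::zeros((5, 2));

        let strategy = CoreSetApproach::new()
            .batch_size(3)
            .expect("batch size must be non-zero")
            .distance_metric("euclidean".to_string())
            .random_state(0);

        let selected = strategy.query(&x.view(), &probabilities.view()).unwrap();
        assert_eq!(selected.len(), 3);

        // The greedy selection should never pick the same index twice.
        let mut unique = selected.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), selected.len());
    }
}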