sklears_semi_supervised/batch_active_learning/
core_set.rs

1//! Core-set approach implementation for batch active learning
2
3use super::{BatchActiveLearningError, *};
4
/// Core-set approach for batch active learning.
///
/// Selects a batch of samples that best represents the unlabeled data
/// distribution, acting as a "core-set" (summary) of the data. Selection is a
/// greedy k-center procedure: after a seed sample is chosen, each further pick is
/// the sample whose distance to its nearest already-selected sample is largest.
#[derive(Debug, Clone)]
pub struct CoreSetApproach {
    /// Number of sample indices returned per query. The `batch_size` builder
    /// method rejects 0.
    pub batch_size: usize,
    /// Distance between samples: `"euclidean"` or `"manhattan"`. Any other value
    /// yields an `InvalidDistanceMetric` error when a distance is computed.
    pub distance_metric: String,
    /// How the seed sample is chosen: `"farthest_first"` (a randomly drawn sample)
    /// or `"k_center_greedy"` (the sample closest to the data centroid).
    /// Unrecognized values fall back to `"farthest_first"` in `query`.
    pub initialization: String,
    /// Maximum iteration count.
    /// NOTE(review): not read by the selection routines in this file — confirm
    /// whether anything else uses it.
    pub max_iter: usize,
    /// Seed for the random seed-sample draw under `"farthest_first"`. When `None`,
    /// a fixed seed (42) is used, so selection is still deterministic.
    pub random_state: Option<u64>,
}
22
23impl Default for CoreSetApproach {
24    fn default() -> Self {
25        Self {
26            batch_size: 10,
27            distance_metric: "euclidean".to_string(),
28            initialization: "farthest_first".to_string(),
29            max_iter: 100,
30            random_state: None,
31        }
32    }
33}
34
35impl CoreSetApproach {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
41        if batch_size == 0 {
42            return Err(BatchActiveLearningError::InvalidBatchSize(batch_size).into());
43        }
44        self.batch_size = batch_size;
45        Ok(self)
46    }
47
48    pub fn distance_metric(mut self, distance_metric: String) -> Self {
49        self.distance_metric = distance_metric;
50        self
51    }
52
53    pub fn initialization(mut self, initialization: String) -> Self {
54        self.initialization = initialization;
55        self
56    }
57
58    pub fn max_iter(mut self, max_iter: usize) -> Self {
59        self.max_iter = max_iter;
60        self
61    }
62
63    pub fn random_state(mut self, random_state: u64) -> Self {
64        self.random_state = Some(random_state);
65        self
66    }
67
68    fn compute_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> Result<f64> {
69        match self.distance_metric.as_str() {
70            "euclidean" => {
71                let dist = x1
72                    .iter()
73                    .zip(x2.iter())
74                    .map(|(a, b)| (a - b).powi(2))
75                    .sum::<f64>()
76                    .sqrt();
77                Ok(dist)
78            }
79            "manhattan" => {
80                let dist = x1
81                    .iter()
82                    .zip(x2.iter())
83                    .map(|(a, b)| (a - b).abs())
84                    .sum::<f64>();
85                Ok(dist)
86            }
87            _ => Err(
88                BatchActiveLearningError::InvalidDistanceMetric(self.distance_metric.clone())
89                    .into(),
90            ),
91        }
92    }
93
94    fn farthest_first_initialization(&self, X: &ArrayView2<f64>) -> Result<Vec<usize>> {
95        let n_samples = X.dim().0;
96        let mut rng = match self.random_state {
97            Some(seed) => Random::seed(seed),
98            None => Random::seed(42),
99        };
100
101        if n_samples < self.batch_size {
102            return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
103        }
104
105        let mut selected_indices = Vec::new();
106        let mut distances = vec![f64::INFINITY; n_samples];
107
108        // Select first point randomly
109        let first_idx = rng.gen_range(0..n_samples);
110        selected_indices.push(first_idx);
111
112        // Update distances to first point
113        for (i, dist) in distances.iter_mut().enumerate() {
114            if i != first_idx {
115                *dist = self.compute_distance(&X.row(i), &X.row(first_idx))?;
116            }
117        }
118
119        // Select remaining points
120        for _ in 1..self.batch_size {
121            // Find point with maximum distance to nearest selected point
122            let mut max_distance = 0.0;
123            let mut best_idx = 0;
124
125            for (i, &dist) in distances.iter().enumerate() {
126                if !selected_indices.contains(&i) && dist > max_distance {
127                    max_distance = dist;
128                    best_idx = i;
129                }
130            }
131
132            selected_indices.push(best_idx);
133
134            // Update distances
135            for (i, dist) in distances.iter_mut().enumerate() {
136                if !selected_indices.contains(&i) {
137                    let new_distance = self.compute_distance(&X.row(i), &X.row(best_idx))?;
138                    *dist = (*dist).min(new_distance);
139                }
140            }
141        }
142
143        Ok(selected_indices)
144    }
145
146    fn k_center_greedy(&self, X: &ArrayView2<f64>) -> Result<Vec<usize>> {
147        let n_samples = X.dim().0;
148        let mut selected_indices = Vec::new();
149        let mut distances = vec![f64::INFINITY; n_samples];
150
151        if n_samples < self.batch_size {
152            return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
153        }
154
155        // Select first point (center of data)
156        let mut centroid = Array1::zeros(X.dim().1);
157        for i in 0..n_samples {
158            centroid = centroid + X.row(i);
159        }
160        centroid /= n_samples as f64;
161
162        // Find closest point to centroid
163        let mut min_distance = f64::INFINITY;
164        let mut first_idx = 0;
165        for i in 0..n_samples {
166            let distance = self.compute_distance(&X.row(i), &centroid.view())?;
167            if distance < min_distance {
168                min_distance = distance;
169                first_idx = i;
170            }
171        }
172
173        selected_indices.push(first_idx);
174
175        // Update distances to first point
176        for (i, dist) in distances.iter_mut().enumerate() {
177            if i != first_idx {
178                *dist = self.compute_distance(&X.row(i), &X.row(first_idx))?;
179            }
180        }
181
182        // Greedily select remaining points
183        for _ in 1..self.batch_size {
184            let mut max_distance = 0.0;
185            let mut best_idx = 0;
186
187            for (i, &dist) in distances.iter().enumerate() {
188                if !selected_indices.contains(&i) && dist > max_distance {
189                    max_distance = dist;
190                    best_idx = i;
191                }
192            }
193
194            selected_indices.push(best_idx);
195
196            // Update distances
197            for (i, dist) in distances.iter_mut().enumerate() {
198                if !selected_indices.contains(&i) {
199                    let new_distance = self.compute_distance(&X.row(i), &X.row(best_idx))?;
200                    *dist = (*dist).min(new_distance);
201                }
202            }
203        }
204
205        Ok(selected_indices)
206    }
207
208    pub fn query(
209        &self,
210        X: &ArrayView2<f64>,
211        _probabilities: &ArrayView2<f64>,
212    ) -> Result<Vec<usize>> {
213        match self.initialization.as_str() {
214            "farthest_first" => self.farthest_first_initialization(X),
215            "k_center_greedy" => self.k_center_greedy(X),
216            _ => self.farthest_first_initialization(X), // default
217        }
218    }
219}