sklears_semi_supervised/
self_training_classifier.rs1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, Untrained},
7 types::Float,
8};
9use std::collections::HashSet;
10
11#[derive(Debug, Clone)]
44pub struct SelfTrainingClassifier<S = Untrained> {
45 state: S,
46 threshold: f64,
47 criterion: String,
48 k_best: usize,
49 max_iter: usize,
50 verbose: bool,
51}
52
53impl SelfTrainingClassifier<Untrained> {
54 pub fn new() -> Self {
56 Self {
57 state: Untrained,
58 threshold: 0.75,
59 criterion: "threshold".to_string(),
60 k_best: 10,
61 max_iter: 10,
62 verbose: false,
63 }
64 }
65
66 pub fn threshold(mut self, threshold: f64) -> Self {
68 self.threshold = threshold;
69 self
70 }
71
72 pub fn criterion(mut self, criterion: String) -> Self {
74 self.criterion = criterion;
75 self
76 }
77
78 pub fn k_best(mut self, k_best: usize) -> Self {
80 self.k_best = k_best;
81 self
82 }
83
84 pub fn max_iter(mut self, max_iter: usize) -> Self {
86 self.max_iter = max_iter;
87 self
88 }
89
90 pub fn verbose(mut self, verbose: bool) -> Self {
92 self.verbose = verbose;
93 self
94 }
95}
96
97impl Default for SelfTrainingClassifier<Untrained> {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl Estimator for SelfTrainingClassifier<Untrained> {
104 type Config = ();
105 type Error = SklearsError;
106 type Float = Float;
107
108 fn config(&self) -> &Self::Config {
109 &()
110 }
111}
112
113impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for SelfTrainingClassifier<Untrained> {
114 type Fitted = SelfTrainingClassifier<SelfTrainingTrained>;
115
116 #[allow(non_snake_case)]
117 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
118 let X = X.to_owned();
119 let mut y = y.to_owned();
120
121 let mut labeled_mask = Array1::from_elem(y.len(), false);
123 let mut classes = HashSet::new();
124
125 for (i, &label) in y.iter().enumerate() {
126 if label != -1 {
127 labeled_mask[i] = true;
128 classes.insert(label);
129 }
130 }
131
132 if labeled_mask.iter().all(|&x| !x) {
133 return Err(SklearsError::InvalidInput(
134 "No labeled samples provided".to_string(),
135 ));
136 }
137
138 let classes: Vec<i32> = classes.into_iter().collect();
139
140 for _iter in 0..self.max_iter {
142 let labeled_indices: Vec<usize> = labeled_mask
144 .iter()
145 .enumerate()
146 .filter(|(_, &is_labeled)| is_labeled)
147 .map(|(i, _)| i)
148 .collect();
149
150 if labeled_indices.len() == y.len() {
151 break; }
153
154 let _X_labeled: Vec<Vec<f64>> =
156 labeled_indices.iter().map(|&i| X.row(i).to_vec()).collect();
157 let y_labeled: Vec<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
158
159 let mut new_labels = Vec::new();
161 let mut confidences = Vec::new();
162
163 for (i, &is_labeled) in labeled_mask.iter().enumerate() {
164 if !is_labeled {
165 let mut min_dist = f64::INFINITY;
167 let mut best_label = 0;
168
169 for (j, &labeled_idx) in labeled_indices.iter().enumerate() {
170 let diff = &X.row(i) - &X.row(labeled_idx);
171 let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
172 if dist < min_dist {
173 min_dist = dist;
174 best_label = y_labeled[j];
175 }
176 }
177
178 let confidence = 1.0 / (1.0 + min_dist);
180 new_labels.push((i, best_label, confidence));
181 confidences.push(confidence);
182 }
183 }
184
185 let mut selected_indices = Vec::new();
187
188 match self.criterion.as_str() {
189 "threshold" => {
190 for (i, label, confidence) in new_labels {
191 if confidence >= self.threshold {
192 selected_indices.push((i, label));
193 }
194 }
195 }
196 "k_best" => {
197 new_labels.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
198 for (i, label, _) in new_labels.into_iter().take(self.k_best) {
199 selected_indices.push((i, label));
200 }
201 }
202 _ => {
203 return Err(SklearsError::InvalidInput(format!(
204 "Unknown criterion: {}",
205 self.criterion
206 )));
207 }
208 }
209
210 if selected_indices.is_empty() {
211 break; }
213
214 for (i, label) in selected_indices {
216 y[i] = label;
217 labeled_mask[i] = true;
218 }
219
220 if self.verbose {
221 let n_labeled = labeled_mask.iter().filter(|&&x| x).count();
222 println!("Iteration {}: {} labeled samples", _iter + 1, n_labeled);
223 }
224 }
225
226 Ok(SelfTrainingClassifier {
227 state: SelfTrainingTrained {
228 X_train: X.clone(),
229 y_train: y,
230 classes: Array1::from(classes),
231 labeled_mask,
232 },
233 threshold: self.threshold,
234 criterion: self.criterion,
235 k_best: self.k_best,
236 max_iter: self.max_iter,
237 verbose: self.verbose,
238 })
239 }
240}
241
242impl Predict<ArrayView2<'_, Float>, Array1<i32>> for SelfTrainingClassifier<SelfTrainingTrained> {
243 #[allow(non_snake_case)]
244 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
245 let X = X.to_owned();
246 let n_test = X.nrows();
247 let mut predictions = Array1::zeros(n_test);
248
249 let labeled_indices: Vec<usize> = self
251 .state
252 .labeled_mask
253 .iter()
254 .enumerate()
255 .filter(|(_, &is_labeled)| is_labeled)
256 .map(|(i, _)| i)
257 .collect();
258
259 for i in 0..n_test {
260 let mut min_dist = f64::INFINITY;
262 let mut best_label = 0;
263
264 for &labeled_idx in &labeled_indices {
265 let diff = &X.row(i) - &self.state.X_train.row(labeled_idx);
266 let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
267 if dist < min_dist {
268 min_dist = dist;
269 best_label = self.state.y_train[labeled_idx];
270 }
271 }
272
273 predictions[i] = best_label;
274 }
275
276 Ok(predictions)
277 }
278}
279
280#[derive(Debug, Clone)]
282pub struct SelfTrainingTrained {
283 pub X_train: Array2<f64>,
285 pub y_train: Array1<i32>,
287 pub classes: Array1<i32>,
289 pub labeled_mask: Array1<bool>,
291}