sklears_semi_supervised/
self_training_classifier.rs

1//! Self-Training Classifier implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, Untrained},
7    types::Float,
8};
9use std::collections::HashSet;
10
11/// Self-Training Classifier
12///
13/// Self-training is a wrapper method for semi-supervised learning where a
14/// supervised classifier is trained on labeled data, then used to classify
15/// unlabeled data. The most confident predictions are added to the training set.
16///
17/// # Parameters
18///
19/// * `base_classifier` - The base classifier to use
20/// * `threshold` - Confidence threshold for pseudo-labeling
21/// * `criterion` - Criterion for selecting samples ('threshold' or 'k_best')
22/// * `k_best` - Number of best samples to select per iteration
23/// * `max_iter` - Maximum number of iterations
24/// * `verbose` - Whether to print progress information
25///
26/// # Examples
27///
28/// ```
29/// use scirs2_core::array;
30/// use sklears_semi_supervised::SelfTrainingClassifier;
31/// use sklears_core::traits::{Predict, Fit};
32///
33///
34/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
35/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
36///
37/// let stc = SelfTrainingClassifier::new()
38///     .threshold(0.75)
39///     .max_iter(10);
40/// let fitted = stc.fit(&X.view(), &y.view()).unwrap();
41/// let predictions = fitted.predict(&X.view()).unwrap();
42/// ```
43#[derive(Debug, Clone)]
44pub struct SelfTrainingClassifier<S = Untrained> {
45    state: S,
46    threshold: f64,
47    criterion: String,
48    k_best: usize,
49    max_iter: usize,
50    verbose: bool,
51}
52
53impl SelfTrainingClassifier<Untrained> {
54    /// Create a new SelfTrainingClassifier instance
55    pub fn new() -> Self {
56        Self {
57            state: Untrained,
58            threshold: 0.75,
59            criterion: "threshold".to_string(),
60            k_best: 10,
61            max_iter: 10,
62            verbose: false,
63        }
64    }
65
66    /// Set the confidence threshold
67    pub fn threshold(mut self, threshold: f64) -> Self {
68        self.threshold = threshold;
69        self
70    }
71
72    /// Set the selection criterion
73    pub fn criterion(mut self, criterion: String) -> Self {
74        self.criterion = criterion;
75        self
76    }
77
78    /// Set the number of best samples to select
79    pub fn k_best(mut self, k_best: usize) -> Self {
80        self.k_best = k_best;
81        self
82    }
83
84    /// Set the maximum number of iterations
85    pub fn max_iter(mut self, max_iter: usize) -> Self {
86        self.max_iter = max_iter;
87        self
88    }
89
90    /// Set verbosity
91    pub fn verbose(mut self, verbose: bool) -> Self {
92        self.verbose = verbose;
93        self
94    }
95}
96
97impl Default for SelfTrainingClassifier<Untrained> {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl Estimator for SelfTrainingClassifier<Untrained> {
104    type Config = ();
105    type Error = SklearsError;
106    type Float = Float;
107
108    fn config(&self) -> &Self::Config {
109        &()
110    }
111}
112
113impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for SelfTrainingClassifier<Untrained> {
114    type Fitted = SelfTrainingClassifier<SelfTrainingTrained>;
115
116    #[allow(non_snake_case)]
117    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
118        let X = X.to_owned();
119        let mut y = y.to_owned();
120
121        // Identify labeled and unlabeled samples
122        let mut labeled_mask = Array1::from_elem(y.len(), false);
123        let mut classes = HashSet::new();
124
125        for (i, &label) in y.iter().enumerate() {
126            if label != -1 {
127                labeled_mask[i] = true;
128                classes.insert(label);
129            }
130        }
131
132        if labeled_mask.iter().all(|&x| !x) {
133            return Err(SklearsError::InvalidInput(
134                "No labeled samples provided".to_string(),
135            ));
136        }
137
138        let classes: Vec<i32> = classes.into_iter().collect();
139
140        // Simple self-training iteration
141        for _iter in 0..self.max_iter {
142            // Find labeled samples
143            let labeled_indices: Vec<usize> = labeled_mask
144                .iter()
145                .enumerate()
146                .filter(|(_, &is_labeled)| is_labeled)
147                .map(|(i, _)| i)
148                .collect();
149
150            if labeled_indices.len() == y.len() {
151                break; // All samples are labeled
152            }
153
154            // Extract labeled data
155            let _X_labeled: Vec<Vec<f64>> =
156                labeled_indices.iter().map(|&i| X.row(i).to_vec()).collect();
157            let y_labeled: Vec<i32> = labeled_indices.iter().map(|&i| y[i]).collect();
158
159            // Simple nearest neighbor classifier for pseudo-labeling
160            let mut new_labels = Vec::new();
161            let mut confidences = Vec::new();
162
163            for (i, &is_labeled) in labeled_mask.iter().enumerate() {
164                if !is_labeled {
165                    // Find nearest labeled neighbor
166                    let mut min_dist = f64::INFINITY;
167                    let mut best_label = 0;
168
169                    for (j, &labeled_idx) in labeled_indices.iter().enumerate() {
170                        let diff = &X.row(i) - &X.row(labeled_idx);
171                        let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
172                        if dist < min_dist {
173                            min_dist = dist;
174                            best_label = y_labeled[j];
175                        }
176                    }
177
178                    // Simple confidence based on distance
179                    let confidence = 1.0 / (1.0 + min_dist);
180                    new_labels.push((i, best_label, confidence));
181                    confidences.push(confidence);
182                }
183            }
184
185            // Select samples to pseudo-label
186            let mut selected_indices = Vec::new();
187
188            match self.criterion.as_str() {
189                "threshold" => {
190                    for (i, label, confidence) in new_labels {
191                        if confidence >= self.threshold {
192                            selected_indices.push((i, label));
193                        }
194                    }
195                }
196                "k_best" => {
197                    new_labels.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
198                    for (i, label, _) in new_labels.into_iter().take(self.k_best) {
199                        selected_indices.push((i, label));
200                    }
201                }
202                _ => {
203                    return Err(SklearsError::InvalidInput(format!(
204                        "Unknown criterion: {}",
205                        self.criterion
206                    )));
207                }
208            }
209
210            if selected_indices.is_empty() {
211                break; // No confident predictions
212            }
213
214            // Add pseudo-labels
215            for (i, label) in selected_indices {
216                y[i] = label;
217                labeled_mask[i] = true;
218            }
219
220            if self.verbose {
221                let n_labeled = labeled_mask.iter().filter(|&&x| x).count();
222                println!("Iteration {}: {} labeled samples", _iter + 1, n_labeled);
223            }
224        }
225
226        Ok(SelfTrainingClassifier {
227            state: SelfTrainingTrained {
228                X_train: X.clone(),
229                y_train: y,
230                classes: Array1::from(classes),
231                labeled_mask,
232            },
233            threshold: self.threshold,
234            criterion: self.criterion,
235            k_best: self.k_best,
236            max_iter: self.max_iter,
237            verbose: self.verbose,
238        })
239    }
240}
241
242impl Predict<ArrayView2<'_, Float>, Array1<i32>> for SelfTrainingClassifier<SelfTrainingTrained> {
243    #[allow(non_snake_case)]
244    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
245        let X = X.to_owned();
246        let n_test = X.nrows();
247        let mut predictions = Array1::zeros(n_test);
248
249        // Get labeled training samples
250        let labeled_indices: Vec<usize> = self
251            .state
252            .labeled_mask
253            .iter()
254            .enumerate()
255            .filter(|(_, &is_labeled)| is_labeled)
256            .map(|(i, _)| i)
257            .collect();
258
259        for i in 0..n_test {
260            // Find nearest labeled neighbor
261            let mut min_dist = f64::INFINITY;
262            let mut best_label = 0;
263
264            for &labeled_idx in &labeled_indices {
265                let diff = &X.row(i) - &self.state.X_train.row(labeled_idx);
266                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
267                if dist < min_dist {
268                    min_dist = dist;
269                    best_label = self.state.y_train[labeled_idx];
270                }
271            }
272
273            predictions[i] = best_label;
274        }
275
276        Ok(predictions)
277    }
278}
279
280/// Trained state for SelfTrainingClassifier
281#[derive(Debug, Clone)]
282pub struct SelfTrainingTrained {
283    /// X_train
284    pub X_train: Array2<f64>,
285    /// y_train
286    pub y_train: Array1<i32>,
287    /// classes
288    pub classes: Array1<i32>,
289    /// labeled_mask
290    pub labeled_mask: Array1<bool>,
291}