sklears_semi_supervised/tri_training.rs

//! Tri-Training implementation for semi-supervised learning

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::{HashMap, HashSet};

/// Tri-Training classifier for semi-supervised learning
///
/// Tri-training uses three classifiers trained on different bootstrap samples
/// of the labeled data. Two classifiers vote to label examples for the third.
///
/// # Parameters
///
/// * `max_iter` - Maximum number of iterations
/// * `verbose` - Whether to print progress information
/// * `theta` - Threshold for noise rate estimation
///
/// # Examples
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::TriTraining;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let tt = TriTraining::new()
///     .max_iter(10)
///     .verbose(true);
/// let fitted = tt.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
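///
/// # References
///
/// Zhou, Z.-H. and Li, M. (2005). "Tri-Training: Exploiting Unlabeled Data
/// Using Three Classifiers." IEEE Transactions on Knowledge and Data
/// Engineering, 17(11), 1529-1541.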
#[derive(Debug, Clone)]
pub struct TriTraining<S = Untrained> {
    state: S,
    max_iter: usize,
    verbose: bool,
    theta: f64,
}

impl TriTraining<Untrained> {
    /// Create a new TriTraining instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            max_iter: 30,
            verbose: false,
            theta: 0.1,
        }
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set verbosity
    pub fn verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Set theta parameter for noise rate estimation
    pub fn theta(mut self, theta: f64) -> Self {
        self.theta = theta;
        self
    }

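    /// Draw a bootstrap sample (sampling with replacement) from the labeled
    /// subset of the data; each of the three base classifiers is trained on
    /// its own bootstrap sample.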
    fn bootstrap_sample(
        &self,
        X: &Array2<f64>,
        y: &Array1<i32>,
        labeled_indices: &[usize],
        seed: u64,
    ) -> (Array2<f64>, Array1<i32>) {
        let n_labeled = labeled_indices.len();
        let mut bootstrap_X = Array2::zeros((n_labeled, X.ncols()));
        let mut bootstrap_y = Array1::zeros(n_labeled);

        // Bootstrap sampling with replacement; callers pass distinct seeds so
        // the three classifiers are not trained on identical resamples
        let mut rng = Random::seed(seed);
        for i in 0..n_labeled {
            let random_idx = rng.gen_range(0..n_labeled);
            let idx = labeled_indices[random_idx];
            bootstrap_X.row_mut(i).assign(&X.row(idx));
            bootstrap_y[i] = y[idx];
        }

        (bootstrap_X, bootstrap_y)
    }

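    /// Built-in base learner: a k-nearest-neighbour classifier (k <= 5,
    /// Euclidean distance, majority vote) "trained" on `X_train`/`y_train`
    /// and evaluated on `X_test`.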
    fn simple_classifier_fit_predict(
        &self,
        X_train: &Array2<f64>,
        y_train: &Array1<i32>,
        X_test: &Array2<f64>,
    ) -> Array1<i32> {
        let n_test = X_test.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Simple k-NN classifier
            let mut distances: Vec<(f64, i32)> = Vec::new();
            for j in 0..X_train.nrows() {
                let diff = &X_test.row(i) - &X_train.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();
                distances.push((dist, y_train[j]));
            }

            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            // Use k=5 nearest neighbors with majority vote
            let k = distances.len().clamp(1, 5);
            let mut class_votes: HashMap<i32, usize> = HashMap::new();

            for &(_, label) in distances.iter().take(k) {
                *class_votes.entry(label).or_insert(0) += 1;
            }

            let best_class = class_votes
                .iter()
                .max_by_key(|(_, &count)| count)
                .map(|(&class, _)| class)
                .unwrap_or(y_train[0]);

            predictions[i] = best_class;
        }

        predictions
    }

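    /// Estimate the error rate of a pair of classifiers on the labeled data:
    /// among the labeled samples on which both classifiers agree, the fraction
    /// they label incorrectly. Returns 1.0 if they never agree.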
    fn estimate_error_rate(
        &self,
        classifier_i: &Array2<f64>,
        y_i: &Array1<i32>,
        classifier_j: &Array2<f64>,
        y_j: &Array1<i32>,
        X_labeled: &Array2<f64>,
        y_labeled: &Array1<i32>,
    ) -> f64 {
        let n_labeled = X_labeled.nrows();
        let mut errors = 0;
        let mut total = 0;

        for k in 0..n_labeled {
            let test_sample = X_labeled
                .row(k)
                .to_owned()
                .insert_axis(scirs2_core::ndarray::Axis(0));
            let pred_i = self.simple_classifier_fit_predict(classifier_i, y_i, &test_sample);
            let pred_j = self.simple_classifier_fit_predict(classifier_j, y_j, &test_sample);

            if pred_i[0] == pred_j[0] {
                total += 1;
                if pred_i[0] != y_labeled[k] {
                    errors += 1;
                }
            }
        }

        if total > 0 {
            errors as f64 / total as f64
        } else {
            1.0 // Conservative estimate if no agreement
        }
    }
}

impl Default for TriTraining<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for TriTraining<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for TriTraining<Untrained> {
    type Fitted = TriTraining<TriTrainingTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let mut y = y.to_owned();

        // Identify labeled and unlabeled samples
        let mut labeled_indices: Vec<usize> = y
            .iter()
            .enumerate()
            .filter(|(_, &label)| label != -1)
            .map(|(i, _)| i)
            .collect();

        let mut unlabeled_indices: Vec<usize> = y
            .iter()
            .enumerate()
            .filter(|(_, &label)| label == -1)
            .map(|(i, _)| i)
            .collect();

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        let mut classes = HashSet::new();
        for &idx in &labeled_indices {
            classes.insert(y[idx]);
        }
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable(); // HashSet iteration order is arbitrary; sort for determinism

        // Initialize three classifiers with bootstrap samples (one seed per
        // classifier so the three training sets differ)
        let mut classifiers: Vec<(Array2<f64>, Array1<i32>)> = Vec::new();
        for i in 0..3 {
            let (bootstrap_X, bootstrap_y) =
                self.bootstrap_sample(&X, &y, &labeled_indices, i as u64);
            classifiers.push((bootstrap_X, bootstrap_y));
        }

        let mut e_prime = [0.5; 3]; // Previous error rates
        let mut l_prime = [0; 3]; // Previous unlabeled set sizes

        // Tri-training iterations
        for iter in 0..self.max_iter {
            let mut any_changes = false;

            for i in 0..3 {
                let j = (i + 1) % 3;
                let k = (i + 2) % 3;

                // Extract labeled data for current training set
                let X_labeled_i: Vec<Vec<f64>> = labeled_indices
                    .iter()
                    .map(|&idx| X.row(idx).to_vec())
                    .collect();
                let y_labeled_i: Vec<i32> = labeled_indices.iter().map(|&idx| y[idx]).collect();

                let X_labeled_array = Array2::from_shape_vec(
                    (X_labeled_i.len(), X.ncols()),
                    X_labeled_i.into_iter().flatten().collect(),
                )
                .map_err(|_| {
                    SklearsError::InvalidInput("Failed to create labeled training data".to_string())
                })?;

                let y_labeled_array = Array1::from(y_labeled_i);

                // Estimate error rate between classifiers j and k
                let e_jk = self.estimate_error_rate(
                    &classifiers[j].0,
                    &classifiers[j].1,
                    &classifiers[k].0,
                    &classifiers[k].1,
                    &X_labeled_array,
                    &y_labeled_array,
                );

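                // Classifier i only accepts new pseudo-labels when the joint error of
                // j and k has improved on its previous value and is below theta.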
                if e_jk < e_prime[i] && e_jk < self.theta {
                    // Get predictions from classifiers j and k on unlabeled data
                    if !unlabeled_indices.is_empty() {
                        let X_unlabeled: Vec<Vec<f64>> = unlabeled_indices
                            .iter()
                            .map(|&idx| X.row(idx).to_vec())
                            .collect();

                        let X_unlabeled_array = Array2::from_shape_vec(
                            (X_unlabeled.len(), X.ncols()),
                            X_unlabeled.into_iter().flatten().collect(),
                        )
                        .map_err(|_| {
                            SklearsError::InvalidInput(
                                "Failed to create unlabeled data".to_string(),
                            )
                        })?;

                        let pred_j = self.simple_classifier_fit_predict(
                            &classifiers[j].0,
                            &classifiers[j].1,
                            &X_unlabeled_array,
                        );
                        let pred_k = self.simple_classifier_fit_predict(
                            &classifiers[k].0,
                            &classifiers[k].1,
                            &X_unlabeled_array,
                        );

                        // Find samples where j and k agree
                        let mut new_labeled_for_i = Vec::new();
                        for (idx, (&p_j, &p_k)) in pred_j.iter().zip(pred_k.iter()).enumerate() {
                            if p_j == p_k {
                                let original_idx = unlabeled_indices[idx];
                                new_labeled_for_i.push((original_idx, p_j));
                            }
                        }

                        if !new_labeled_for_i.is_empty() {
                            // Add agreed-upon labels
                            for (idx, label) in new_labeled_for_i {
                                y[idx] = label;
                                labeled_indices.push(idx);
                                any_changes = true;
                            }

                            // Update unlabeled indices
                            unlabeled_indices.retain(|&idx| y[idx] == -1);

                            // Retrain classifier i on the enlarged labeled set, with a
                            // fresh seed so the resample differs from earlier rounds
                            let (new_bootstrap_X, new_bootstrap_y) = self.bootstrap_sample(
                                &X,
                                &y,
                                &labeled_indices,
                                ((iter + 1) * 3 + i) as u64,
                            );
                            classifiers[i] = (new_bootstrap_X, new_bootstrap_y);

                            e_prime[i] = e_jk;
                            l_prime[i] = unlabeled_indices.len();
                        }
                    }
                }
            }

            if !any_changes {
                if self.verbose {
                    println!("Iteration {}: No changes, stopping", iter + 1);
                }
                break;
            }

            if self.verbose {
                let n_labeled = labeled_indices.len();
                let n_unlabeled = unlabeled_indices.len();
                println!(
                    "Iteration {}: {} labeled, {} unlabeled",
                    iter + 1,
                    n_labeled,
                    n_unlabeled
                );
            }

            if unlabeled_indices.is_empty() {
                if self.verbose {
                    println!("All samples labeled, stopping");
                }
                break;
            }
        }

        Ok(TriTraining {
            state: TriTrainingTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                classifiers,
            },
            max_iter: self.max_iter,
            verbose: self.verbose,
            theta: self.theta,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>> for TriTraining<TriTrainingTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        // Use ensemble prediction from all three classifiers
        for i in 0..n_test {
            let test_sample = X
                .row(i)
                .to_owned()
                .insert_axis(scirs2_core::ndarray::Axis(0));
            let mut votes: HashMap<i32, usize> = HashMap::new();

            // Get prediction from each classifier
            for (classifier_X, classifier_y) in &self.state.classifiers {
                let pred = Self::simple_classifier_fit_predict_static(
                    classifier_X,
                    classifier_y,
                    &test_sample,
                );
                *votes.entry(pred[0]).or_insert(0) += 1;
            }

            // Majority vote
            let best_class = votes
                .iter()
                .max_by_key(|(_, &count)| count)
                .map(|(&class, _)| class)
                .unwrap_or(self.state.classes[0]);

            predictions[i] = best_class;
        }

        Ok(predictions)
    }
}

impl TriTraining<TriTrainingTrained> {
    /// Static version of the k-NN base classifier, used when making ensemble predictions
    fn simple_classifier_fit_predict_static(
        X_train: &Array2<f64>,
        y_train: &Array1<i32>,
        X_test: &Array2<f64>,
    ) -> Array1<i32> {
        let n_test = X_test.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Simple k-NN classifier
            let mut distances: Vec<(f64, i32)> = Vec::new();
            for j in 0..X_train.nrows() {
                let diff = &X_test.row(i) - &X_train.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();
                distances.push((dist, y_train[j]));
            }

            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            // Use k=5 nearest neighbors with majority vote
            let k = distances.len().clamp(1, 5);
            let mut class_votes: HashMap<i32, usize> = HashMap::new();

            for &(_, label) in distances.iter().take(k) {
                *class_votes.entry(label).or_insert(0) += 1;
            }

            let best_class = class_votes
                .iter()
                .max_by_key(|(_, &count)| count)
                .map(|(&class, _)| class)
                .unwrap_or(y_train[0]);

            predictions[i] = best_class;
        }

        predictions
    }
}

/// Trained state for TriTraining
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct TriTrainingTrained {
    /// Training features (including pseudo-labeled samples)
    pub X_train: Array2<f64>,
    /// Training labels after pseudo-labeling (-1 for any samples left unlabeled)
    pub y_train: Array1<i32>,
    /// Unique class labels observed in the labeled data
    pub classes: Array1<i32>,
    /// Bootstrap training sets (features, labels) for the three base classifiers
    pub classifiers: Vec<(Array2<f64>, Array1<i32>)>,
}
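
#[cfg(test)]
mod tests {
    //! A minimal smoke-test sketch for `TriTraining`. It reuses the data and the
    //! calls from the doc example above; the `array!` macro is assumed to be
    //! re-exported by `scirs2_core` exactly as in that example.

    use super::*;
    use scirs2_core::array;
    use sklears_core::traits::{Fit, Predict};

    #[test]
    #[allow(non_snake_case)]
    fn fit_and_predict_label_every_sample() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples

        let fitted = TriTraining::new()
            .max_iter(5)
            .fit(&X.view(), &y.view())
            .unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        // Every sample gets a prediction, and every prediction is one of the
        // classes observed in the labeled data.
        assert_eq!(predictions.len(), X.nrows());
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    }
}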