sklears_semi_supervised/deep_learning/
consistency_training.rs

1//! Consistency Training implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
/// Consistency Training for Semi-Supervised Learning
///
/// Consistency training uses data augmentation and encourages the model
/// to produce consistent predictions for augmented versions of the same input.
///
/// The type parameter `S` holds the estimator state: `Untrained` for a freshly
/// configured instance and `ConsistencyTrainingTrained` once `fit` has run.
///
/// NOTE(review): in the code visible in this file the hyperparameters below
/// are only stored and carried through `fit`; none of them influences the
/// fitted state yet. Confirm whether a real training loop lives elsewhere.
#[derive(Debug, Clone)]
pub struct ConsistencyTraining<S = Untrained> {
    /// Current state of the estimator (untrained marker or fitted parameters).
    state: S,
    /// Hidden layer dimensions (default `[64, 32]`).
    hidden_dims: Vec<usize>,
    /// Learning rate (default `0.001`).
    learning_rate: f64,
    /// Weight of the consistency loss (default `1.0`).
    consistency_weight: f64,
    /// Strength of the data augmentation (default `0.1`).
    augmentation_strength: f64,
    /// Maximum number of epochs (default `100`).
    max_epochs: usize,
    /// Batch size (default `32`).
    batch_size: usize,
    /// Convergence tolerance (default `1e-6`).
    tol: f64,
}
25
26impl ConsistencyTraining<Untrained> {
27    /// Create a new ConsistencyTraining instance
28    pub fn new() -> Self {
29        Self {
30            state: Untrained,
31            hidden_dims: vec![64, 32],
32            learning_rate: 0.001,
33            consistency_weight: 1.0,
34            augmentation_strength: 0.1,
35            max_epochs: 100,
36            batch_size: 32,
37            tol: 1e-6,
38        }
39    }
40
41    /// Set the hidden layer dimensions
42    pub fn hidden_dims(mut self, hidden_dims: Vec<usize>) -> Self {
43        self.hidden_dims = hidden_dims;
44        self
45    }
46
47    /// Set the learning rate
48    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
49        self.learning_rate = learning_rate;
50        self
51    }
52
53    /// Set the consistency loss weight
54    pub fn consistency_weight(mut self, consistency_weight: f64) -> Self {
55        self.consistency_weight = consistency_weight;
56        self
57    }
58
59    /// Set the data augmentation strength
60    pub fn augmentation_strength(mut self, augmentation_strength: f64) -> Self {
61        self.augmentation_strength = augmentation_strength;
62        self
63    }
64
65    /// Set the maximum number of epochs
66    pub fn max_epochs(mut self, max_epochs: usize) -> Self {
67        self.max_epochs = max_epochs;
68        self
69    }
70
71    /// Set the batch size
72    pub fn batch_size(mut self, batch_size: usize) -> Self {
73        self.batch_size = batch_size;
74        self
75    }
76
77    /// Set the convergence tolerance
78    pub fn tol(mut self, tol: f64) -> Self {
79        self.tol = tol;
80        self
81    }
82}
83
84impl Default for ConsistencyTraining<Untrained> {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl Estimator for ConsistencyTraining<Untrained> {
91    type Config = ();
92    type Error = SklearsError;
93    type Float = Float;
94
95    fn config(&self) -> &Self::Config {
96        &()
97    }
98}
99
100impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for ConsistencyTraining<Untrained> {
101    type Fitted = ConsistencyTraining<ConsistencyTrainingTrained>;
102
103    #[allow(non_snake_case)]
104    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
105        let X = X.to_owned();
106        let y = y.to_owned();
107
108        // Get unique classes
109        let mut classes = std::collections::HashSet::new();
110        for &label in y.iter() {
111            if label != -1 {
112                classes.insert(label);
113            }
114        }
115        let classes: Vec<i32> = classes.into_iter().collect();
116
117        Ok(ConsistencyTraining {
118            state: ConsistencyTrainingTrained {
119                weights: Array2::zeros((X.ncols(), classes.len())),
120                biases: Array1::zeros(classes.len()),
121                classes: Array1::from(classes),
122            },
123            hidden_dims: self.hidden_dims,
124            learning_rate: self.learning_rate,
125            consistency_weight: self.consistency_weight,
126            augmentation_strength: self.augmentation_strength,
127            max_epochs: self.max_epochs,
128            batch_size: self.batch_size,
129            tol: self.tol,
130        })
131    }
132}
133
134impl Predict<ArrayView2<'_, Float>, Array1<i32>>
135    for ConsistencyTraining<ConsistencyTrainingTrained>
136{
137    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
138        let n_test = X.nrows();
139        let n_classes = self.state.classes.len();
140        let mut predictions = Array1::zeros(n_test);
141
142        for i in 0..n_test {
143            predictions[i] = self.state.classes[i % n_classes];
144        }
145
146        Ok(predictions)
147    }
148}
149
150impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
151    for ConsistencyTraining<ConsistencyTrainingTrained>
152{
153    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
154        let n_test = X.nrows();
155        let n_classes = self.state.classes.len();
156        let mut probabilities = Array2::zeros((n_test, n_classes));
157
158        for i in 0..n_test {
159            for j in 0..n_classes {
160                probabilities[[i, j]] = 1.0 / n_classes as f64;
161            }
162        }
163
164        Ok(probabilities)
165    }
166}
167
/// Trained state for ConsistencyTraining
///
/// Created by `fit` on the untrained estimator and read by `predict` and
/// `predict_proba`.
#[derive(Debug, Clone)]
pub struct ConsistencyTrainingTrained {
    /// Weight matrix of shape `(n_features, n_classes)`; `fit` initialises it
    /// to zeros.
    pub weights: Array2<f64>,
    /// Bias vector of length `n_classes`; `fit` initialises it to zeros.
    pub biases: Array1<f64>,
    /// Distinct class labels seen during `fit`, excluding the `-1` marker used
    /// for unlabeled samples.
    pub classes: Array1<i32>,
}