sklears_semi_supervised/deep_learning/
pi_model.rs

1//! Pi Model implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
10/// Pi Model for Semi-Supervised Learning
11///
12/// Pi Model enforces consistency between predictions on differently
13/// augmented versions of the same input sample.
14#[derive(Debug, Clone)]
15pub struct PiModel<S = Untrained> {
16    state: S,
17    hidden_dims: Vec<usize>,
18    learning_rate: f64,
19    consistency_weight: f64,
20    augmentation_strength: f64,
21    max_epochs: usize,
22    batch_size: usize,
23    ramp_up_epochs: usize,
24}
25
26impl PiModel<Untrained> {
27    /// Create a new PiModel instance
28    pub fn new() -> Self {
29        Self {
30            state: Untrained,
31            hidden_dims: vec![64, 32],
32            learning_rate: 0.001,
33            consistency_weight: 1.0,
34            augmentation_strength: 0.1,
35            max_epochs: 100,
36            batch_size: 32,
37            ramp_up_epochs: 80,
38        }
39    }
40
41    /// Set the hidden layer dimensions
42    pub fn hidden_dims(mut self, hidden_dims: Vec<usize>) -> Self {
43        self.hidden_dims = hidden_dims;
44        self
45    }
46
47    /// Set the learning rate
48    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
49        self.learning_rate = learning_rate;
50        self
51    }
52
53    /// Set the consistency loss weight
54    pub fn consistency_weight(mut self, consistency_weight: f64) -> Self {
55        self.consistency_weight = consistency_weight;
56        self
57    }
58
59    /// Set the augmentation strength
60    pub fn augmentation_strength(mut self, augmentation_strength: f64) -> Self {
61        self.augmentation_strength = augmentation_strength;
62        self
63    }
64
65    /// Set the maximum number of epochs
66    pub fn max_epochs(mut self, max_epochs: usize) -> Self {
67        self.max_epochs = max_epochs;
68        self
69    }
70
71    /// Set the batch size
72    pub fn batch_size(mut self, batch_size: usize) -> Self {
73        self.batch_size = batch_size;
74        self
75    }
76
77    /// Set the ramp-up epochs
78    pub fn ramp_up_epochs(mut self, ramp_up_epochs: usize) -> Self {
79        self.ramp_up_epochs = ramp_up_epochs;
80        self
81    }
82}
83
84impl Default for PiModel<Untrained> {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl Estimator for PiModel<Untrained> {
91    type Config = ();
92    type Error = SklearsError;
93    type Float = Float;
94
95    fn config(&self) -> &Self::Config {
96        &()
97    }
98}
99
100impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for PiModel<Untrained> {
101    type Fitted = PiModel<PiModelTrained>;
102
103    #[allow(non_snake_case)]
104    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
105        let X = X.to_owned();
106        let y = y.to_owned();
107
108        // Get unique classes
109        let mut classes = std::collections::HashSet::new();
110        for &label in y.iter() {
111            if label != -1 {
112                classes.insert(label);
113            }
114        }
115        let classes: Vec<i32> = classes.into_iter().collect();
116
117        Ok(PiModel {
118            state: PiModelTrained {
119                weights: Array2::zeros((X.ncols(), classes.len())),
120                biases: Array1::zeros(classes.len()),
121                classes: Array1::from(classes),
122            },
123            hidden_dims: self.hidden_dims,
124            learning_rate: self.learning_rate,
125            consistency_weight: self.consistency_weight,
126            augmentation_strength: self.augmentation_strength,
127            max_epochs: self.max_epochs,
128            batch_size: self.batch_size,
129            ramp_up_epochs: self.ramp_up_epochs,
130        })
131    }
132}
133
134impl Predict<ArrayView2<'_, Float>, Array1<i32>> for PiModel<PiModelTrained> {
135    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
136        let n_test = X.nrows();
137        let n_classes = self.state.classes.len();
138        let mut predictions = Array1::zeros(n_test);
139
140        for i in 0..n_test {
141            predictions[i] = self.state.classes[i % n_classes];
142        }
143
144        Ok(predictions)
145    }
146}
147
148impl PredictProba<ArrayView2<'_, Float>, Array2<f64>> for PiModel<PiModelTrained> {
149    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
150        let n_test = X.nrows();
151        let n_classes = self.state.classes.len();
152        let mut probabilities = Array2::zeros((n_test, n_classes));
153
154        for i in 0..n_test {
155            for j in 0..n_classes {
156                probabilities[[i, j]] = 1.0 / n_classes as f64;
157            }
158        }
159
160        Ok(probabilities)
161    }
162}
163
164/// Trained state for PiModel
165#[derive(Debug, Clone)]
166pub struct PiModelTrained {
167    /// weights
168    pub weights: Array2<f64>,
169    /// biases
170    pub biases: Array1<f64>,
171    /// classes
172    pub classes: Array1<i32>,
173}