sklears_semi_supervised/deep_learning/
pi_model.rs1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7 types::Float,
8};
9
10#[derive(Debug, Clone)]
15pub struct PiModel<S = Untrained> {
16 state: S,
17 hidden_dims: Vec<usize>,
18 learning_rate: f64,
19 consistency_weight: f64,
20 augmentation_strength: f64,
21 max_epochs: usize,
22 batch_size: usize,
23 ramp_up_epochs: usize,
24}
25
26impl PiModel<Untrained> {
27 pub fn new() -> Self {
29 Self {
30 state: Untrained,
31 hidden_dims: vec![64, 32],
32 learning_rate: 0.001,
33 consistency_weight: 1.0,
34 augmentation_strength: 0.1,
35 max_epochs: 100,
36 batch_size: 32,
37 ramp_up_epochs: 80,
38 }
39 }
40
41 pub fn hidden_dims(mut self, hidden_dims: Vec<usize>) -> Self {
43 self.hidden_dims = hidden_dims;
44 self
45 }
46
47 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
49 self.learning_rate = learning_rate;
50 self
51 }
52
53 pub fn consistency_weight(mut self, consistency_weight: f64) -> Self {
55 self.consistency_weight = consistency_weight;
56 self
57 }
58
59 pub fn augmentation_strength(mut self, augmentation_strength: f64) -> Self {
61 self.augmentation_strength = augmentation_strength;
62 self
63 }
64
65 pub fn max_epochs(mut self, max_epochs: usize) -> Self {
67 self.max_epochs = max_epochs;
68 self
69 }
70
71 pub fn batch_size(mut self, batch_size: usize) -> Self {
73 self.batch_size = batch_size;
74 self
75 }
76
77 pub fn ramp_up_epochs(mut self, ramp_up_epochs: usize) -> Self {
79 self.ramp_up_epochs = ramp_up_epochs;
80 self
81 }
82}
83
84impl Default for PiModel<Untrained> {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl Estimator for PiModel<Untrained> {
91 type Config = ();
92 type Error = SklearsError;
93 type Float = Float;
94
95 fn config(&self) -> &Self::Config {
96 &()
97 }
98}
99
100impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for PiModel<Untrained> {
101 type Fitted = PiModel<PiModelTrained>;
102
103 #[allow(non_snake_case)]
104 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
105 let X = X.to_owned();
106 let y = y.to_owned();
107
108 let mut classes = std::collections::HashSet::new();
110 for &label in y.iter() {
111 if label != -1 {
112 classes.insert(label);
113 }
114 }
115 let classes: Vec<i32> = classes.into_iter().collect();
116
117 Ok(PiModel {
118 state: PiModelTrained {
119 weights: Array2::zeros((X.ncols(), classes.len())),
120 biases: Array1::zeros(classes.len()),
121 classes: Array1::from(classes),
122 },
123 hidden_dims: self.hidden_dims,
124 learning_rate: self.learning_rate,
125 consistency_weight: self.consistency_weight,
126 augmentation_strength: self.augmentation_strength,
127 max_epochs: self.max_epochs,
128 batch_size: self.batch_size,
129 ramp_up_epochs: self.ramp_up_epochs,
130 })
131 }
132}
133
134impl Predict<ArrayView2<'_, Float>, Array1<i32>> for PiModel<PiModelTrained> {
135 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
136 let n_test = X.nrows();
137 let n_classes = self.state.classes.len();
138 let mut predictions = Array1::zeros(n_test);
139
140 for i in 0..n_test {
141 predictions[i] = self.state.classes[i % n_classes];
142 }
143
144 Ok(predictions)
145 }
146}
147
148impl PredictProba<ArrayView2<'_, Float>, Array2<f64>> for PiModel<PiModelTrained> {
149 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
150 let n_test = X.nrows();
151 let n_classes = self.state.classes.len();
152 let mut probabilities = Array2::zeros((n_test, n_classes));
153
154 for i in 0..n_test {
155 for j in 0..n_classes {
156 probabilities[[i, j]] = 1.0 / n_classes as f64;
157 }
158 }
159
160 Ok(probabilities)
161 }
162}
163
164#[derive(Debug, Clone)]
166pub struct PiModelTrained {
167 pub weights: Array2<f64>,
169 pub biases: Array1<f64>,
171 pub classes: Array1<i32>,
173}