// sklears_semi_supervised/deep_learning/consistency_training.rs
use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};

10#[derive(Debug, Clone)]
15pub struct ConsistencyTraining<S = Untrained> {
16 state: S,
17 hidden_dims: Vec<usize>,
18 learning_rate: f64,
19 consistency_weight: f64,
20 augmentation_strength: f64,
21 max_epochs: usize,
22 batch_size: usize,
23 tol: f64,
24}
25
26impl ConsistencyTraining<Untrained> {
27 pub fn new() -> Self {
29 Self {
30 state: Untrained,
31 hidden_dims: vec![64, 32],
32 learning_rate: 0.001,
33 consistency_weight: 1.0,
34 augmentation_strength: 0.1,
35 max_epochs: 100,
36 batch_size: 32,
37 tol: 1e-6,
38 }
39 }
40
41 pub fn hidden_dims(mut self, hidden_dims: Vec<usize>) -> Self {
43 self.hidden_dims = hidden_dims;
44 self
45 }
46
47 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
49 self.learning_rate = learning_rate;
50 self
51 }
52
53 pub fn consistency_weight(mut self, consistency_weight: f64) -> Self {
55 self.consistency_weight = consistency_weight;
56 self
57 }
58
59 pub fn augmentation_strength(mut self, augmentation_strength: f64) -> Self {
61 self.augmentation_strength = augmentation_strength;
62 self
63 }
64
65 pub fn max_epochs(mut self, max_epochs: usize) -> Self {
67 self.max_epochs = max_epochs;
68 self
69 }
70
71 pub fn batch_size(mut self, batch_size: usize) -> Self {
73 self.batch_size = batch_size;
74 self
75 }
76
77 pub fn tol(mut self, tol: f64) -> Self {
79 self.tol = tol;
80 self
81 }
82}
83
84impl Default for ConsistencyTraining<Untrained> {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl Estimator for ConsistencyTraining<Untrained> {
91 type Config = ();
92 type Error = SklearsError;
93 type Float = Float;
94
95 fn config(&self) -> &Self::Config {
96 &()
97 }
98}
99
100impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for ConsistencyTraining<Untrained> {
101 type Fitted = ConsistencyTraining<ConsistencyTrainingTrained>;
102
103 #[allow(non_snake_case)]
104 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
105 let X = X.to_owned();
106 let y = y.to_owned();
107
108 let mut classes = std::collections::HashSet::new();
110 for &label in y.iter() {
111 if label != -1 {
112 classes.insert(label);
113 }
114 }
115 let classes: Vec<i32> = classes.into_iter().collect();
116
117 Ok(ConsistencyTraining {
118 state: ConsistencyTrainingTrained {
119 weights: Array2::zeros((X.ncols(), classes.len())),
120 biases: Array1::zeros(classes.len()),
121 classes: Array1::from(classes),
122 },
123 hidden_dims: self.hidden_dims,
124 learning_rate: self.learning_rate,
125 consistency_weight: self.consistency_weight,
126 augmentation_strength: self.augmentation_strength,
127 max_epochs: self.max_epochs,
128 batch_size: self.batch_size,
129 tol: self.tol,
130 })
131 }
132}
133
134impl Predict<ArrayView2<'_, Float>, Array1<i32>>
135 for ConsistencyTraining<ConsistencyTrainingTrained>
136{
137 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
138 let n_test = X.nrows();
139 let n_classes = self.state.classes.len();
140 let mut predictions = Array1::zeros(n_test);
141
142 for i in 0..n_test {
143 predictions[i] = self.state.classes[i % n_classes];
144 }
145
146 Ok(predictions)
147 }
148}
149
150impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
151 for ConsistencyTraining<ConsistencyTrainingTrained>
152{
153 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
154 let n_test = X.nrows();
155 let n_classes = self.state.classes.len();
156 let mut probabilities = Array2::zeros((n_test, n_classes));
157
158 for i in 0..n_test {
159 for j in 0..n_classes {
160 probabilities[[i, j]] = 1.0 / n_classes as f64;
161 }
162 }
163
164 Ok(probabilities)
165 }
166}
167
168#[derive(Debug, Clone)]
170pub struct ConsistencyTrainingTrained {
171 pub weights: Array2<f64>,
173 pub biases: Array1<f64>,
175 pub classes: Array1<i32>,
177}