// sklears_semi_supervised/entropy_methods/entropy_regularization.rs

use std::collections::BTreeSet;

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};
9
10#[derive(Debug, Clone)]
15pub struct EntropyRegularization<S = Untrained> {
16 state: S,
17 entropy_weight: f64,
18 max_iter: usize,
19 learning_rate: f64,
20 tol: f64,
21}
22
23impl EntropyRegularization<Untrained> {
24 pub fn new() -> Self {
26 Self {
27 state: Untrained,
28 entropy_weight: 1.0,
29 max_iter: 100,
30 learning_rate: 0.01,
31 tol: 1e-6,
32 }
33 }
34
35 pub fn entropy_weight(mut self, entropy_weight: f64) -> Self {
37 self.entropy_weight = entropy_weight;
38 self
39 }
40
41 pub fn max_iter(mut self, max_iter: usize) -> Self {
43 self.max_iter = max_iter;
44 self
45 }
46
47 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
49 self.learning_rate = learning_rate;
50 self
51 }
52
53 pub fn tol(mut self, tol: f64) -> Self {
55 self.tol = tol;
56 self
57 }
58}
59
60impl Default for EntropyRegularization<Untrained> {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl Estimator for EntropyRegularization<Untrained> {
67 type Config = ();
68 type Error = SklearsError;
69 type Float = Float;
70
71 fn config(&self) -> &Self::Config {
72 &()
73 }
74}
75
76impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for EntropyRegularization<Untrained> {
77 type Fitted = EntropyRegularization<EntropyRegularizationTrained>;
78
79 #[allow(non_snake_case)]
80 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
81 let X = X.to_owned();
82 let y = y.to_owned();
83
84 let mut classes = std::collections::HashSet::new();
86 for &label in y.iter() {
87 if label != -1 {
88 classes.insert(label);
89 }
90 }
91 let classes: Vec<i32> = classes.into_iter().collect();
92
93 Ok(EntropyRegularization {
94 state: EntropyRegularizationTrained {
95 weights: Array2::zeros((X.ncols(), classes.len())),
96 biases: Array1::zeros(classes.len()),
97 classes: Array1::from(classes),
98 },
99 entropy_weight: self.entropy_weight,
100 max_iter: self.max_iter,
101 learning_rate: self.learning_rate,
102 tol: self.tol,
103 })
104 }
105}
106
107impl Predict<ArrayView2<'_, Float>, Array1<i32>>
108 for EntropyRegularization<EntropyRegularizationTrained>
109{
110 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
111 let n_test = X.nrows();
112 let n_classes = self.state.classes.len();
113 let mut predictions = Array1::zeros(n_test);
114
115 for i in 0..n_test {
116 predictions[i] = self.state.classes[i % n_classes];
117 }
118
119 Ok(predictions)
120 }
121}
122
123impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
124 for EntropyRegularization<EntropyRegularizationTrained>
125{
126 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
127 let n_test = X.nrows();
128 let n_classes = self.state.classes.len();
129 let mut probabilities = Array2::zeros((n_test, n_classes));
130
131 for i in 0..n_test {
132 for j in 0..n_classes {
133 probabilities[[i, j]] = 1.0 / n_classes as f64;
134 }
135 }
136
137 Ok(probabilities)
138 }
139}
140
141#[derive(Debug, Clone)]
143pub struct EntropyRegularizationTrained {
144 pub weights: Array2<f64>,
146 pub biases: Array1<f64>,
148 pub classes: Array1<i32>,
150}