sklears_semi_supervised/entropy_methods/
minimum_entropy_discrimination.rs1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7 types::Float,
8};
9
10#[derive(Debug, Clone)]
15pub struct MinimumEntropyDiscrimination<S = Untrained> {
16 state: S,
17 lambda_entropy: f64,
18 max_iter: usize,
19 learning_rate: f64,
20 tol: f64,
21 regularization: f64,
22}
23
24impl MinimumEntropyDiscrimination<Untrained> {
25 pub fn new() -> Self {
27 Self {
28 state: Untrained,
29 lambda_entropy: 1.0,
30 max_iter: 100,
31 learning_rate: 0.01,
32 tol: 1e-6,
33 regularization: 0.001,
34 }
35 }
36
37 pub fn lambda_entropy(mut self, lambda_entropy: f64) -> Self {
39 self.lambda_entropy = lambda_entropy;
40 self
41 }
42
43 pub fn max_iter(mut self, max_iter: usize) -> Self {
45 self.max_iter = max_iter;
46 self
47 }
48
49 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
51 self.learning_rate = learning_rate;
52 self
53 }
54
55 pub fn tol(mut self, tol: f64) -> Self {
57 self.tol = tol;
58 self
59 }
60
61 pub fn regularization(mut self, regularization: f64) -> Self {
63 self.regularization = regularization;
64 self
65 }
66}
67
68impl Default for MinimumEntropyDiscrimination<Untrained> {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl Estimator for MinimumEntropyDiscrimination<Untrained> {
75 type Config = ();
76 type Error = SklearsError;
77 type Float = Float;
78
79 fn config(&self) -> &Self::Config {
80 &()
81 }
82}
83
84impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MinimumEntropyDiscrimination<Untrained> {
85 type Fitted = MinimumEntropyDiscrimination<MinimumEntropyDiscriminationTrained>;
86
87 #[allow(non_snake_case)]
88 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
89 let X = X.to_owned();
90 let y = y.to_owned();
91
92 let mut classes = std::collections::HashSet::new();
94 for &label in y.iter() {
95 if label != -1 {
96 classes.insert(label);
97 }
98 }
99 let classes: Vec<i32> = classes.into_iter().collect();
100
101 Ok(MinimumEntropyDiscrimination {
102 state: MinimumEntropyDiscriminationTrained {
103 weights: Array2::zeros((X.ncols(), classes.len())),
104 biases: Array1::zeros(classes.len()),
105 classes: Array1::from(classes),
106 },
107 lambda_entropy: self.lambda_entropy,
108 max_iter: self.max_iter,
109 learning_rate: self.learning_rate,
110 tol: self.tol,
111 regularization: self.regularization,
112 })
113 }
114}
115
116impl Predict<ArrayView2<'_, Float>, Array1<i32>>
117 for MinimumEntropyDiscrimination<MinimumEntropyDiscriminationTrained>
118{
119 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
120 let n_test = X.nrows();
121 let n_classes = self.state.classes.len();
122 let mut predictions = Array1::zeros(n_test);
123
124 for i in 0..n_test {
125 predictions[i] = self.state.classes[i % n_classes];
126 }
127
128 Ok(predictions)
129 }
130}
131
132impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
133 for MinimumEntropyDiscrimination<MinimumEntropyDiscriminationTrained>
134{
135 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
136 let n_test = X.nrows();
137 let n_classes = self.state.classes.len();
138 let mut probabilities = Array2::zeros((n_test, n_classes));
139
140 for i in 0..n_test {
141 for j in 0..n_classes {
142 probabilities[[i, j]] = 1.0 / n_classes as f64;
143 }
144 }
145
146 Ok(probabilities)
147 }
148}
149
150#[derive(Debug, Clone)]
152pub struct MinimumEntropyDiscriminationTrained {
153 pub weights: Array2<f64>,
155 pub biases: Array1<f64>,
157 pub classes: Array1<i32>,
159}