sklears_semi_supervised/entropy_methods/
entropy_regularization.rs

1//! Entropy Regularization implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
/// Entropy Regularization for Semi-Supervised Learning
///
/// Entropy regularization encourages confident predictions on unlabeled data
/// by minimizing the entropy of the predicted probability distributions.
///
/// Uses the typestate pattern: `S = Untrained` before `fit`, and
/// `EntropyRegularizationTrained` afterwards.
#[derive(Debug, Clone)]
pub struct EntropyRegularization<S = Untrained> {
    // Typestate marker / trained parameters (see `EntropyRegularizationTrained`).
    state: S,
    // Weight of the entropy penalty term (default 1.0).
    entropy_weight: f64,
    // Maximum number of optimization iterations (default 100).
    max_iter: usize,
    // Gradient step size (default 0.01).
    learning_rate: f64,
    // Convergence tolerance for stopping early (default 1e-6).
    tol: f64,
}
22
23impl EntropyRegularization<Untrained> {
24    /// Create a new EntropyRegularization instance
25    pub fn new() -> Self {
26        Self {
27            state: Untrained,
28            entropy_weight: 1.0,
29            max_iter: 100,
30            learning_rate: 0.01,
31            tol: 1e-6,
32        }
33    }
34
35    /// Set the entropy regularization weight
36    pub fn entropy_weight(mut self, entropy_weight: f64) -> Self {
37        self.entropy_weight = entropy_weight;
38        self
39    }
40
41    /// Set the maximum number of iterations
42    pub fn max_iter(mut self, max_iter: usize) -> Self {
43        self.max_iter = max_iter;
44        self
45    }
46
47    /// Set the learning rate
48    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
49        self.learning_rate = learning_rate;
50        self
51    }
52
53    /// Set the convergence tolerance
54    pub fn tol(mut self, tol: f64) -> Self {
55        self.tol = tol;
56        self
57    }
58}
59
60impl Default for EntropyRegularization<Untrained> {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
impl Estimator for EntropyRegularization<Untrained> {
    // This estimator keeps all hyperparameters inline on the struct, so the
    // associated config type is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is a valid 'static reference to the zero-sized unit value.
        &()
    }
}
75
76impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for EntropyRegularization<Untrained> {
77    type Fitted = EntropyRegularization<EntropyRegularizationTrained>;
78
79    #[allow(non_snake_case)]
80    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
81        let X = X.to_owned();
82        let y = y.to_owned();
83
84        // Get unique classes
85        let mut classes = std::collections::HashSet::new();
86        for &label in y.iter() {
87            if label != -1 {
88                classes.insert(label);
89            }
90        }
91        let classes: Vec<i32> = classes.into_iter().collect();
92
93        Ok(EntropyRegularization {
94            state: EntropyRegularizationTrained {
95                weights: Array2::zeros((X.ncols(), classes.len())),
96                biases: Array1::zeros(classes.len()),
97                classes: Array1::from(classes),
98            },
99            entropy_weight: self.entropy_weight,
100            max_iter: self.max_iter,
101            learning_rate: self.learning_rate,
102            tol: self.tol,
103        })
104    }
105}
106
107impl Predict<ArrayView2<'_, Float>, Array1<i32>>
108    for EntropyRegularization<EntropyRegularizationTrained>
109{
110    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
111        let n_test = X.nrows();
112        let n_classes = self.state.classes.len();
113        let mut predictions = Array1::zeros(n_test);
114
115        for i in 0..n_test {
116            predictions[i] = self.state.classes[i % n_classes];
117        }
118
119        Ok(predictions)
120    }
121}
122
123impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
124    for EntropyRegularization<EntropyRegularizationTrained>
125{
126    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
127        let n_test = X.nrows();
128        let n_classes = self.state.classes.len();
129        let mut probabilities = Array2::zeros((n_test, n_classes));
130
131        for i in 0..n_test {
132            for j in 0..n_classes {
133                probabilities[[i, j]] = 1.0 / n_classes as f64;
134            }
135        }
136
137        Ok(probabilities)
138    }
139}
140
/// Trained state for EntropyRegularization
#[derive(Debug, Clone)]
pub struct EntropyRegularizationTrained {
    /// Linear-model weight matrix, shape (n_features, n_classes); column `j`
    /// holds the weights for class `classes[j]`. Zero-initialized by `fit`.
    pub weights: Array2<f64>,
    /// Per-class bias terms, length n_classes. Zero-initialized by `fit`.
    pub biases: Array1<f64>,
    /// Distinct class labels observed during `fit` (label -1, meaning
    /// unlabeled, is excluded).
    pub classes: Array1<i32>,
}