sklears_semi_supervised/entropy_methods/
confident_learning.rs1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7 types::Float,
8};
9
10#[derive(Debug, Clone)]
15pub struct ConfidentLearning<S = Untrained> {
16 state: S,
17 confidence_threshold: f64,
18 max_iter: usize,
19 learning_rate: f64,
20 noise_rate_threshold: f64,
21 calibration_method: String,
22}
23
24impl ConfidentLearning<Untrained> {
25 pub fn new() -> Self {
27 Self {
28 state: Untrained,
29 confidence_threshold: 0.95,
30 max_iter: 100,
31 learning_rate: 0.01,
32 noise_rate_threshold: 0.1,
33 calibration_method: "isotonic".to_string(),
34 }
35 }
36
37 pub fn confidence_threshold(mut self, confidence_threshold: f64) -> Self {
39 self.confidence_threshold = confidence_threshold;
40 self
41 }
42
43 pub fn max_iter(mut self, max_iter: usize) -> Self {
45 self.max_iter = max_iter;
46 self
47 }
48
49 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
51 self.learning_rate = learning_rate;
52 self
53 }
54
55 pub fn noise_rate_threshold(mut self, noise_rate_threshold: f64) -> Self {
57 self.noise_rate_threshold = noise_rate_threshold;
58 self
59 }
60
61 pub fn calibration_method(mut self, calibration_method: String) -> Self {
63 self.calibration_method = calibration_method;
64 self
65 }
66}
67
68impl Default for ConfidentLearning<Untrained> {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl Estimator for ConfidentLearning<Untrained> {
75 type Config = ();
76 type Error = SklearsError;
77 type Float = Float;
78
79 fn config(&self) -> &Self::Config {
80 &()
81 }
82}
83
84impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for ConfidentLearning<Untrained> {
85 type Fitted = ConfidentLearning<ConfidentLearningTrained>;
86
87 #[allow(non_snake_case)]
88 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
89 let X = X.to_owned();
90 let y = y.to_owned();
91
92 let mut classes = std::collections::HashSet::new();
94 for &label in y.iter() {
95 if label != -1 {
96 classes.insert(label);
97 }
98 }
99 let classes: Vec<i32> = classes.into_iter().collect();
100
101 let n_classes = classes.len();
102 Ok(ConfidentLearning {
103 state: ConfidentLearningTrained {
104 weights: Array2::zeros((X.ncols(), n_classes)),
105 biases: Array1::zeros(n_classes),
106 classes: Array1::from(classes),
107 noise_matrix: Array2::zeros((n_classes, n_classes)),
108 },
109 confidence_threshold: self.confidence_threshold,
110 max_iter: self.max_iter,
111 learning_rate: self.learning_rate,
112 noise_rate_threshold: self.noise_rate_threshold,
113 calibration_method: self.calibration_method,
114 })
115 }
116}
117
118impl Predict<ArrayView2<'_, Float>, Array1<i32>> for ConfidentLearning<ConfidentLearningTrained> {
119 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
120 let n_test = X.nrows();
121 let n_classes = self.state.classes.len();
122 let mut predictions = Array1::zeros(n_test);
123
124 for i in 0..n_test {
125 predictions[i] = self.state.classes[i % n_classes];
126 }
127
128 Ok(predictions)
129 }
130}
131
132impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
133 for ConfidentLearning<ConfidentLearningTrained>
134{
135 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
136 let n_test = X.nrows();
137 let n_classes = self.state.classes.len();
138 let mut probabilities = Array2::zeros((n_test, n_classes));
139
140 for i in 0..n_test {
141 for j in 0..n_classes {
142 probabilities[[i, j]] = 1.0 / n_classes as f64;
143 }
144 }
145
146 Ok(probabilities)
147 }
148}
149
150#[derive(Debug, Clone)]
152pub struct ConfidentLearningTrained {
153 pub weights: Array2<f64>,
155 pub biases: Array1<f64>,
157 pub classes: Array1<i32>,
159 pub noise_matrix: Array2<f64>,
161}