// sklears_semi_supervised/entropy_methods/entropy_active_learning.rs
use std::collections::BTreeSet;

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};

10#[derive(Debug, Clone)]
15pub struct EntropyActiveLearning<S = Untrained> {
16 state: S,
17 selection_strategy: String,
18 batch_size: usize,
19 max_iter: usize,
20 learning_rate: f64,
21 entropy_threshold: f64,
22 uncertainty_sampling: bool,
23}
24
25impl EntropyActiveLearning<Untrained> {
26 pub fn new() -> Self {
28 Self {
29 state: Untrained,
30 selection_strategy: "entropy".to_string(),
31 batch_size: 10,
32 max_iter: 100,
33 learning_rate: 0.01,
34 entropy_threshold: 0.5,
35 uncertainty_sampling: true,
36 }
37 }
38
39 pub fn selection_strategy(mut self, selection_strategy: String) -> Self {
41 self.selection_strategy = selection_strategy;
42 self
43 }
44
45 pub fn batch_size(mut self, batch_size: usize) -> Self {
47 self.batch_size = batch_size;
48 self
49 }
50
51 pub fn max_iter(mut self, max_iter: usize) -> Self {
53 self.max_iter = max_iter;
54 self
55 }
56
57 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
59 self.learning_rate = learning_rate;
60 self
61 }
62
63 pub fn entropy_threshold(mut self, entropy_threshold: f64) -> Self {
65 self.entropy_threshold = entropy_threshold;
66 self
67 }
68
69 pub fn uncertainty_sampling(mut self, uncertainty_sampling: bool) -> Self {
71 self.uncertainty_sampling = uncertainty_sampling;
72 self
73 }
74}
75
76impl Default for EntropyActiveLearning<Untrained> {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl Estimator for EntropyActiveLearning<Untrained> {
83 type Config = ();
84 type Error = SklearsError;
85 type Float = Float;
86
87 fn config(&self) -> &Self::Config {
88 &()
89 }
90}
91
92impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for EntropyActiveLearning<Untrained> {
93 type Fitted = EntropyActiveLearning<EntropyActiveLearningTrained>;
94
95 #[allow(non_snake_case)]
96 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
97 let X = X.to_owned();
98 let y = y.to_owned();
99
100 let mut classes = std::collections::HashSet::new();
102 for &label in y.iter() {
103 if label != -1 {
104 classes.insert(label);
105 }
106 }
107 let classes: Vec<i32> = classes.into_iter().collect();
108
109 Ok(EntropyActiveLearning {
110 state: EntropyActiveLearningTrained {
111 weights: Array2::zeros((X.ncols(), classes.len())),
112 biases: Array1::zeros(classes.len()),
113 classes: Array1::from(classes),
114 selected_indices: Vec::new(),
115 },
116 selection_strategy: self.selection_strategy,
117 batch_size: self.batch_size,
118 max_iter: self.max_iter,
119 learning_rate: self.learning_rate,
120 entropy_threshold: self.entropy_threshold,
121 uncertainty_sampling: self.uncertainty_sampling,
122 })
123 }
124}
125
126impl Predict<ArrayView2<'_, Float>, Array1<i32>>
127 for EntropyActiveLearning<EntropyActiveLearningTrained>
128{
129 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
130 let n_test = X.nrows();
131 let n_classes = self.state.classes.len();
132 let mut predictions = Array1::zeros(n_test);
133
134 for i in 0..n_test {
135 predictions[i] = self.state.classes[i % n_classes];
136 }
137
138 Ok(predictions)
139 }
140}
141
142impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
143 for EntropyActiveLearning<EntropyActiveLearningTrained>
144{
145 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
146 let n_test = X.nrows();
147 let n_classes = self.state.classes.len();
148 let mut probabilities = Array2::zeros((n_test, n_classes));
149
150 for i in 0..n_test {
151 for j in 0..n_classes {
152 probabilities[[i, j]] = 1.0 / n_classes as f64;
153 }
154 }
155
156 Ok(probabilities)
157 }
158}
159
160#[derive(Debug, Clone)]
162pub struct EntropyActiveLearningTrained {
163 pub weights: Array2<f64>,
165 pub biases: Array1<f64>,
167 pub classes: Array1<i32>,
169 pub selected_indices: Vec<usize>,
171}