sklears_semi_supervised/entropy_methods/
entropy_active_learning.rs

1//! Entropy-based Active Learning implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
10/// Entropy-based Active Learning for Semi-Supervised Learning
11///
12/// This method uses entropy to select the most informative samples for labeling
13/// in active learning scenarios, combining supervised and semi-supervised techniques.
14#[derive(Debug, Clone)]
15pub struct EntropyActiveLearning<S = Untrained> {
16    state: S,
17    selection_strategy: String,
18    batch_size: usize,
19    max_iter: usize,
20    learning_rate: f64,
21    entropy_threshold: f64,
22    uncertainty_sampling: bool,
23}
24
25impl EntropyActiveLearning<Untrained> {
26    /// Create a new EntropyActiveLearning instance
27    pub fn new() -> Self {
28        Self {
29            state: Untrained,
30            selection_strategy: "entropy".to_string(),
31            batch_size: 10,
32            max_iter: 100,
33            learning_rate: 0.01,
34            entropy_threshold: 0.5,
35            uncertainty_sampling: true,
36        }
37    }
38
39    /// Set the selection strategy
40    pub fn selection_strategy(mut self, selection_strategy: String) -> Self {
41        self.selection_strategy = selection_strategy;
42        self
43    }
44
45    /// Set the batch size for active learning
46    pub fn batch_size(mut self, batch_size: usize) -> Self {
47        self.batch_size = batch_size;
48        self
49    }
50
51    /// Set the maximum number of iterations
52    pub fn max_iter(mut self, max_iter: usize) -> Self {
53        self.max_iter = max_iter;
54        self
55    }
56
57    /// Set the learning rate
58    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
59        self.learning_rate = learning_rate;
60        self
61    }
62
63    /// Set the entropy threshold
64    pub fn entropy_threshold(mut self, entropy_threshold: f64) -> Self {
65        self.entropy_threshold = entropy_threshold;
66        self
67    }
68
69    /// Set whether to use uncertainty sampling
70    pub fn uncertainty_sampling(mut self, uncertainty_sampling: bool) -> Self {
71        self.uncertainty_sampling = uncertainty_sampling;
72        self
73    }
74}
75
76impl Default for EntropyActiveLearning<Untrained> {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl Estimator for EntropyActiveLearning<Untrained> {
83    type Config = ();
84    type Error = SklearsError;
85    type Float = Float;
86
87    fn config(&self) -> &Self::Config {
88        &()
89    }
90}
91
92impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for EntropyActiveLearning<Untrained> {
93    type Fitted = EntropyActiveLearning<EntropyActiveLearningTrained>;
94
95    #[allow(non_snake_case)]
96    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
97        let X = X.to_owned();
98        let y = y.to_owned();
99
100        // Get unique classes
101        let mut classes = std::collections::HashSet::new();
102        for &label in y.iter() {
103            if label != -1 {
104                classes.insert(label);
105            }
106        }
107        let classes: Vec<i32> = classes.into_iter().collect();
108
109        Ok(EntropyActiveLearning {
110            state: EntropyActiveLearningTrained {
111                weights: Array2::zeros((X.ncols(), classes.len())),
112                biases: Array1::zeros(classes.len()),
113                classes: Array1::from(classes),
114                selected_indices: Vec::new(),
115            },
116            selection_strategy: self.selection_strategy,
117            batch_size: self.batch_size,
118            max_iter: self.max_iter,
119            learning_rate: self.learning_rate,
120            entropy_threshold: self.entropy_threshold,
121            uncertainty_sampling: self.uncertainty_sampling,
122        })
123    }
124}
125
126impl Predict<ArrayView2<'_, Float>, Array1<i32>>
127    for EntropyActiveLearning<EntropyActiveLearningTrained>
128{
129    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
130        let n_test = X.nrows();
131        let n_classes = self.state.classes.len();
132        let mut predictions = Array1::zeros(n_test);
133
134        for i in 0..n_test {
135            predictions[i] = self.state.classes[i % n_classes];
136        }
137
138        Ok(predictions)
139    }
140}
141
142impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
143    for EntropyActiveLearning<EntropyActiveLearningTrained>
144{
145    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
146        let n_test = X.nrows();
147        let n_classes = self.state.classes.len();
148        let mut probabilities = Array2::zeros((n_test, n_classes));
149
150        for i in 0..n_test {
151            for j in 0..n_classes {
152                probabilities[[i, j]] = 1.0 / n_classes as f64;
153            }
154        }
155
156        Ok(probabilities)
157    }
158}
159
160/// Trained state for EntropyActiveLearning
161#[derive(Debug, Clone)]
162pub struct EntropyActiveLearningTrained {
163    /// weights
164    pub weights: Array2<f64>,
165    /// biases
166    pub biases: Array1<f64>,
167    /// classes
168    pub classes: Array1<i32>,
169    /// selected_indices
170    pub selected_indices: Vec<usize>,
171}