sklears_semi_supervised/batch_active_learning/
gradient_embedding.rs

//! Gradient Embedding Methods implementation for batch active learning

use super::{BatchActiveLearningError, *};

/// Gradient Embedding Methods for active learning
///
/// This method uses gradient information from the model to select informative batches.
/// It considers both the gradients with respect to model parameters and embedding
/// representations to identify samples that would provide maximum learning benefit.
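///
/// # Examples
///
/// A minimal usage sketch. The import path below is an assumption based on this
/// file's location, and the inputs are illustrative; adjust both to the
/// surrounding crate.
///
/// ```ignore
/// // Assumed import path -- not confirmed by this file.
/// use sklears_semi_supervised::batch_active_learning::GradientEmbeddingMethods;
/// use ndarray::array;
///
/// let selector = GradientEmbeddingMethods::new()
///     .batch_size(2)
///     .expect("batch_size must be greater than zero")
///     .embedding_dim(64)
///     .random_state(42);
///
/// // Unlabeled pool features (n_samples x n_features) and the model's
/// // predicted class probabilities (n_samples x n_classes).
/// let x = array![[0.0, 1.0], [1.0, 0.0], [0.5, 0.5], [0.2, 0.8]];
/// let probs = array![[0.9, 0.1], [0.8, 0.2], [0.5, 0.5], [0.6, 0.4]];
///
/// let selected = selector.query(&x.view(), &probs.view()).unwrap();
/// assert_eq!(selected.len(), 2);
/// ```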
#[derive(Debug, Clone)]
pub struct GradientEmbeddingMethods {
    /// Number of samples to select in each query batch
    pub batch_size: usize,
    /// Dimensionality of the gradient embedding space
    pub embedding_dim: usize,
    /// Gradient scoring method (e.g. "gradnorm")
    pub gradient_method: String,
    /// Similarity threshold used when comparing candidate samples
    pub similarity_threshold: f64,
    /// Learning rate used in gradient computations
    pub learning_rate: f64,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Optional random seed for reproducible selection
    pub random_state: Option<u64>,
}

impl Default for GradientEmbeddingMethods {
    fn default() -> Self {
        Self {
            batch_size: 10,
            embedding_dim: 128,
            gradient_method: "gradnorm".to_string(),
            similarity_threshold: 0.8,
            learning_rate: 0.001,
            max_iter: 100,
            random_state: None,
        }
    }
}

impl GradientEmbeddingMethods {
    /// Create a new selector with default parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the number of samples to select per batch; must be greater than zero.
    pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(BatchActiveLearningError::InvalidBatchSize(batch_size).into());
        }
        self.batch_size = batch_size;
        Ok(self)
    }

    /// Set the dimensionality of the gradient embedding space.
    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
        self.embedding_dim = embedding_dim;
        self
    }

    /// Set the gradient scoring method.
    pub fn gradient_method(mut self, gradient_method: String) -> Self {
        self.gradient_method = gradient_method;
        self
    }

    /// Set the similarity threshold.
    pub fn similarity_threshold(mut self, similarity_threshold: f64) -> Self {
        self.similarity_threshold = similarity_threshold;
        self
    }

    /// Set the learning rate.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random seed for reproducible selection.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Select a batch of `batch_size` sample indices to query for labels.
    ///
    /// `X` holds the unlabeled pool features (n_samples x n_features) and
    /// `probabilities` the model's predicted class probabilities
    /// (n_samples x n_classes). The current implementation is a placeholder
    /// that ranks samples by predictive entropy rather than by gradient
    /// embeddings.
    pub fn query(
        &self,
        X: &ArrayView2<f64>,
        probabilities: &ArrayView2<f64>,
    ) -> Result<Vec<usize>> {
        let n_samples = X.dim().0;

        if n_samples < self.batch_size {
            return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
        }

        // Placeholder selection: score each sample by the Shannon entropy of its
        // predicted class distribution, H = -sum_k p_k * ln(p_k), which is
        // largest when the model is most uncertain.
        let mut uncertainty_scores = Vec::new();
        for i in 0..n_samples {
            let mut entropy = 0.0;
            for prob in probabilities.row(i) {
                if *prob > 0.0 {
                    entropy -= prob * prob.ln();
                }
            }
            uncertainty_scores.push((i, entropy));
        }

        // Sort by descending uncertainty and keep the top `batch_size` samples.
        uncertainty_scores
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let selected: Vec<usize> = uncertainty_scores
            .into_iter()
            .take(self.batch_size)
            .map(|(idx, _)| idx)
            .collect();

        Ok(selected)
    }
}
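
// A minimal test sketch for the placeholder `query` implementation. It assumes
// `ndarray` is available as a direct dependency (for the `array!` macro) and
// that the crate's `Result` error type implements `Debug` (required by
// `expect`); adjust the imports if the crate re-exports ndarray differently.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn query_selects_most_uncertain_samples() {
        let selector = GradientEmbeddingMethods::new()
            .batch_size(2)
            .expect("a positive batch size is valid")
            .random_state(42);

        // Four unlabeled samples; rows 2 and 3 have the highest-entropy
        // (most uncertain) predicted distributions.
        let x = array![[0.0, 1.0], [1.0, 0.0], [0.5, 0.5], [0.4, 0.6]];
        let probs = array![[0.95, 0.05], [0.9, 0.1], [0.5, 0.5], [0.55, 0.45]];

        let selected = selector
            .query(&x.view(), &probs.view())
            .expect("pool is large enough for the requested batch");

        assert_eq!(selected.len(), 2);
        assert!(selected.contains(&2));
        assert!(selected.contains(&3));
    }
}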