sklears_semi_supervised/batch_active_learning/
gradient_embedding.rs1use super::{BatchActiveLearningError, *};
4
/// Configuration for the gradient-embedding batch active learning strategy.
///
/// Values are normally set through the consuming builder methods on this
/// type. Of the fields below, only `batch_size` is read by the `query`
/// implementation that accompanies this struct; the remaining fields are
/// stored configuration whose meaning is inferred from their names
/// (NOTE(review): confirm against the intended algorithm).
#[derive(Debug, Clone, PartialEq)]
pub struct GradientEmbeddingMethods {
    /// Number of sample indices returned by one `query` call (default `10`).
    pub batch_size: usize,
    /// Dimensionality of the gradient embedding — presumably; not read by
    /// `query` (default `128`).
    pub embedding_dim: usize,
    /// Name of the gradient scoring method, e.g. `"gradnorm"` (the default).
    /// Stored as free-form text; not interpreted by `query`.
    pub gradient_method: String,
    /// Similarity cut-off — presumably used for diversity filtering; not read
    /// by `query` (default `0.8`).
    pub similarity_threshold: f64,
    /// Learning rate — presumably for gradient computation; not read by
    /// `query` (default `0.001`).
    pub learning_rate: f64,
    /// Iteration cap — presumably for an iterative procedure; not read by
    /// `query` (default `100`).
    pub max_iter: usize,
    /// Optional RNG seed; `None` (the default) means no seed was supplied.
    pub random_state: Option<u64>,
}
27
28impl Default for GradientEmbeddingMethods {
29 fn default() -> Self {
30 Self {
31 batch_size: 10,
32 embedding_dim: 128,
33 gradient_method: "gradnorm".to_string(),
34 similarity_threshold: 0.8,
35 learning_rate: 0.001,
36 max_iter: 100,
37 random_state: None,
38 }
39 }
40}
41
42impl GradientEmbeddingMethods {
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn batch_size(mut self, batch_size: usize) -> Result<Self> {
48 if batch_size == 0 {
49 return Err(BatchActiveLearningError::InvalidBatchSize(batch_size).into());
50 }
51 self.batch_size = batch_size;
52 Ok(self)
53 }
54
55 pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
56 self.embedding_dim = embedding_dim;
57 self
58 }
59
60 pub fn gradient_method(mut self, gradient_method: String) -> Self {
61 self.gradient_method = gradient_method;
62 self
63 }
64
65 pub fn similarity_threshold(mut self, similarity_threshold: f64) -> Self {
66 self.similarity_threshold = similarity_threshold;
67 self
68 }
69
70 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
71 self.learning_rate = learning_rate;
72 self
73 }
74
75 pub fn max_iter(mut self, max_iter: usize) -> Self {
76 self.max_iter = max_iter;
77 self
78 }
79
80 pub fn random_state(mut self, random_state: u64) -> Self {
81 self.random_state = Some(random_state);
82 self
83 }
84
85 pub fn query(
86 &self,
87 X: &ArrayView2<f64>,
88 probabilities: &ArrayView2<f64>,
89 ) -> Result<Vec<usize>> {
90 let n_samples = X.dim().0;
91
92 if n_samples < self.batch_size {
93 return Err(BatchActiveLearningError::InsufficientUnlabeledSamples.into());
94 }
95
96 let mut uncertainty_scores = Vec::new();
98 for i in 0..n_samples {
99 let mut entropy = 0.0;
100 for prob in probabilities.row(i) {
101 if *prob > 0.0 {
102 entropy -= prob * prob.ln();
103 }
104 }
105 uncertainty_scores.push((i, entropy));
106 }
107
108 uncertainty_scores
110 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
111
112 let selected: Vec<usize> = uncertainty_scores
113 .into_iter()
114 .take(self.batch_size)
115 .map(|(idx, _)| idx)
116 .collect();
117
118 Ok(selected)
119 }
120}