sklears_semi_supervised/
active_learning.rs

//! Active Learning Methods for Semi-Supervised Learning
//!
//! This module provides active learning algorithms that can be integrated with
//! semi-supervised learning methods to intelligently select the most informative
//! samples for labeling. These methods help optimize the labeling budget by
//! focusing on the most uncertain or diverse samples.
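//!
//! # Example
//!
//! A minimal sketch of one selection round (assuming class probabilities from
//! some upstream classifier and the `array!` macro from `scirs2_core`):
//!
//! ```rust,ignore
//! use scirs2_core::array;
//! use sklears_semi_supervised::UncertaintySampling;
//!
//! // Predicted class probabilities for four unlabeled samples.
//! let probas = array![[0.9, 0.1], [0.6, 0.4], [0.5, 0.5], [0.8, 0.2]];
//!
//! // Pick the two most uncertain samples to send for labeling.
//! let sampler = UncertaintySampling::new()
//!     .strategy("entropy".to_string())
//!     .n_samples(2);
//! let to_label = sampler.select_samples(&probas.view()).unwrap();
//! ```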
7
8use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use sklears_core::error::{Result as SklResult, SklearsError};
10
/// Uncertainty Sampling for Active Learning
///
/// Uncertainty sampling selects samples for labeling based on the model's
/// uncertainty about their predictions. It supports multiple uncertainty
/// measures including entropy, margin, and least confident sampling.
///
/// # Parameters
///
/// * `strategy` - Uncertainty sampling strategy ("entropy", "margin", "least_confident")
/// * `n_samples` - Number of samples to select for labeling
/// * `temperature` - Temperature scaling for probability calibration
/// * `diversity_weight` - Weight for diversity-based selection
/// * `batch_size` - Size of batches for efficient computation
///
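/// # Uncertainty measures
///
/// As a sketch of the scores computed below, with `p` a row of class
/// probabilities and `p_(1) >= p_(2)` its two largest entries:
///
/// * `entropy`: `H(p) = -sum_k p_k ln p_k` (larger = more uncertain)
/// * `margin`: `-(p_(1) - p_(2))`, the negated gap between the top two
///   probabilities, so smaller margins rank higher
/// * `least_confident`: `1 - p_(1)`
///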
/// # Examples
///
/// ```rust,ignore
/// use scirs2_core::array;
/// use sklears_semi_supervised::UncertaintySampling;
///
/// let probas = array![
///     [0.9, 0.1],
///     [0.6, 0.4],
///     [0.5, 0.5],
///     [0.8, 0.2]
/// ];
///
/// let us = UncertaintySampling::new()
///     .strategy("entropy".to_string())
///     .n_samples(2);
/// let selected_indices = us.select_samples(&probas.view()).unwrap();
/// ```
43#[derive(Debug, Clone)]
44pub struct UncertaintySampling {
45    strategy: String,
46    n_samples: usize,
47    temperature: f64,
48    diversity_weight: f64,
49    batch_size: usize,
50    random_tie_breaking: bool,
51}
52
53impl UncertaintySampling {
54    /// Create a new UncertaintySampling instance
55    pub fn new() -> Self {
56        Self {
57            strategy: "entropy".to_string(),
58            n_samples: 10,
59            temperature: 1.0,
60            diversity_weight: 0.0,
61            batch_size: 1000,
62            random_tie_breaking: true,
63        }
64    }
65
66    /// Set the uncertainty sampling strategy
67    pub fn strategy(mut self, strategy: String) -> Self {
68        self.strategy = strategy;
69        self
70    }
71
72    /// Set the number of samples to select
73    pub fn n_samples(mut self, n_samples: usize) -> Self {
74        self.n_samples = n_samples;
75        self
76    }
77
78    /// Set the temperature for probability calibration
79    pub fn temperature(mut self, temperature: f64) -> Self {
80        self.temperature = temperature;
81        self
82    }
83
84    /// Set the diversity weight
85    pub fn diversity_weight(mut self, weight: f64) -> Self {
86        self.diversity_weight = weight;
87        self
88    }
89
90    /// Set the batch size for computation
91    pub fn batch_size(mut self, batch_size: usize) -> Self {
92        self.batch_size = batch_size;
93        self
94    }
95
96    /// Enable/disable random tie breaking
97    pub fn random_tie_breaking(mut self, random: bool) -> Self {
98        self.random_tie_breaking = random;
99        self
100    }
101
102    /// Select samples based on uncertainty
103    pub fn select_samples(&self, probas: &ArrayView2<f64>) -> SklResult<Vec<usize>> {
104        let n_samples = probas.nrows();
105
106        // Validate strategy first
107        match self.strategy.as_str() {
108            "entropy" | "margin" | "least_confident" => {}
109            _ => {
110                return Err(SklearsError::InvalidInput(format!(
111                    "Unknown uncertainty strategy: {}",
112                    self.strategy
113                )))
114            }
115        }
116
117        if self.n_samples >= n_samples {
118            return Ok((0..n_samples).collect());
119        }
120
121        // Apply temperature scaling
122        let calibrated_probas = self.apply_temperature_scaling(probas);
123
124        // Compute uncertainty scores
125        let uncertainty_scores = match self.strategy.as_str() {
126            "entropy" => self.entropy_uncertainty(&calibrated_probas),
127            "margin" => self.margin_uncertainty(&calibrated_probas),
128            "least_confident" => self.least_confident_uncertainty(&calibrated_probas),
129            _ => unreachable!(), // Already validated above
130        }?;
131
132        // Select top uncertain samples
133        let mut indexed_scores: Vec<(usize, f64)> = uncertainty_scores
134            .iter()
135            .enumerate()
136            .map(|(i, &score)| (i, score))
137            .collect();
138
139        // Sort by uncertainty (descending)
140        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
141
142        // Apply diversity if weight > 0
143        let selected_indices = if self.diversity_weight > 0.0 {
144            self.diverse_selection(&indexed_scores, probas)?
145        } else {
146            indexed_scores
147                .iter()
148                .take(self.n_samples)
149                .map(|(idx, _)| *idx)
150                .collect()
151        };
152
153        Ok(selected_indices)
154    }
155
156    fn apply_temperature_scaling(&self, probas: &ArrayView2<f64>) -> Array2<f64> {
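        // Rescale each probability row as p_k^(1/T) and renormalize; T > 1 flattens the
        // distribution, T < 1 sharpens it, and T == 1 is left untouched by the early return.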
157        if (self.temperature - 1.0).abs() < 1e-10 {
158            return probas.to_owned();
159        }
160
161        let mut calibrated = Array2::zeros(probas.dim());
162        for (i, row) in probas.axis_iter(Axis(0)).enumerate() {
163            let scaled: Array1<f64> = row.mapv(|p| (p.ln() / self.temperature).exp());
164            let sum = scaled.sum();
165            if sum > 0.0 {
166                calibrated.row_mut(i).assign(&(scaled / sum));
167            } else {
168                calibrated.row_mut(i).assign(&row);
169            }
170        }
171        calibrated
172    }
173
174    fn entropy_uncertainty(&self, probas: &Array2<f64>) -> SklResult<Array1<f64>> {
175        let mut entropies = Array1::zeros(probas.nrows());
176
177        for (i, row) in probas.axis_iter(Axis(0)).enumerate() {
178            let mut entropy = 0.0;
179            for &p in row.iter() {
180                if p > 1e-15 {
181                    entropy -= p * p.ln();
182                }
183            }
184            entropies[i] = entropy;
185        }
186
187        Ok(entropies)
188    }
189
190    fn margin_uncertainty(&self, probas: &Array2<f64>) -> SklResult<Array1<f64>> {
191        let mut margins = Array1::zeros(probas.nrows());
192
193        for (i, row) in probas.axis_iter(Axis(0)).enumerate() {
194            let mut sorted_probs: Vec<f64> = row.iter().cloned().collect();
195            sorted_probs.sort_by(|a, b| b.partial_cmp(a).unwrap());
196
            if sorted_probs.len() >= 2 {
                // Negate the margin so that a smaller gap between the top two
                // probabilities (more uncertainty) yields a larger score
                margins[i] = -(sorted_probs[0] - sorted_probs[1]);
            } else {
                margins[i] = -sorted_probs[0];
            }
202        }
203
204        Ok(margins)
205    }
206
207    fn least_confident_uncertainty(&self, probas: &Array2<f64>) -> SklResult<Array1<f64>> {
208        let mut uncertainties = Array1::zeros(probas.nrows());
209
210        for (i, row) in probas.axis_iter(Axis(0)).enumerate() {
211            let max_prob = row.iter().fold(0.0f64, |a, &b| a.max(b));
212            uncertainties[i] = 1.0 - max_prob;
213        }
214
215        Ok(uncertainties)
216    }
217
218    fn diverse_selection(
219        &self,
220        indexed_scores: &[(usize, f64)],
221        probas: &ArrayView2<f64>,
222    ) -> SklResult<Vec<usize>> {
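        // Greedy selection: take the most uncertain sample first, then repeatedly add the
        // candidate maximizing (1 - diversity_weight) * uncertainty
        // + diversity_weight * (minimum JS distance to the samples selected so far).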
223        let mut selected = Vec::new();
224        let mut remaining: Vec<(usize, f64)> = indexed_scores.to_vec();
225
226        // Select first sample (most uncertain)
227        if let Some((first_idx, _)) = remaining.first().cloned() {
228            selected.push(first_idx);
229            remaining.retain(|(idx, _)| *idx != first_idx);
230        }
231
232        // Select remaining samples considering diversity
233        while selected.len() < self.n_samples && !remaining.is_empty() {
234            let mut best_idx = 0;
235            let mut best_score = f64::NEG_INFINITY;
236
237            for (candidate_idx, (sample_idx, uncertainty)) in remaining.iter().enumerate() {
238                // Compute diversity score (distance from already selected samples)
239                let mut min_distance = f64::INFINITY;
240                for &selected_idx in &selected {
241                    let distance =
242                        self.compute_distance(probas.row(*sample_idx), probas.row(selected_idx));
243                    min_distance = min_distance.min(distance);
244                }
245
246                // Combined score: uncertainty + diversity
247                let combined_score = (1.0 - self.diversity_weight) * uncertainty
248                    + self.diversity_weight * min_distance;
249
250                if combined_score > best_score {
251                    best_score = combined_score;
252                    best_idx = candidate_idx;
253                }
254            }
255
256            let (selected_sample_idx, _) = remaining.remove(best_idx);
257            selected.push(selected_sample_idx);
258        }
259
260        Ok(selected)
261    }
262
263    fn compute_distance(&self, prob1: ArrayView1<f64>, prob2: ArrayView1<f64>) -> f64 {
264        // Jensen-Shannon divergence
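        // JS(p, q) = (KL(p || m) + KL(q || m)) / 2, with m = (p + q) / 2.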
265        let m = (&prob1 + &prob2) / 2.0;
266        let kl1 = self.kl_divergence(&prob1, &m.view());
267        let kl2 = self.kl_divergence(&prob2, &m.view());
268        (kl1 + kl2) / 2.0
269    }
270
271    fn kl_divergence(&self, p: &ArrayView1<f64>, q: &ArrayView1<f64>) -> f64 {
272        let mut kl = 0.0;
273        for (pi, qi) in p.iter().zip(q.iter()) {
274            if *pi > 1e-15 && *qi > 1e-15 {
275                kl += pi * (pi / qi).ln();
276            }
277        }
278        kl
279    }
280}
281
/// Query by Committee for Active Learning
///
/// Query by Committee selects samples where multiple models disagree the most.
/// It maintains an ensemble of models and selects samples with the highest
/// disagreement among committee members.
///
/// # Parameters
///
/// * `n_committee_members` - Number of models in the committee
/// * `disagreement_measure` - Measure of disagreement ("vote_entropy", "kl_divergence", "variance")
/// * `n_samples` - Number of samples to select
/// * `diversity_weight` - Weight for diversity in selection
/// * `normalize_disagreement` - Whether to min-max normalize disagreement scores before ranking
///
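/// # Disagreement measures
///
/// As a sketch of the scores computed below: `vote_entropy` takes each member's
/// hard prediction (the argmax of its probability row), forms the vote
/// distribution `v` over classes, and scores the sample by
/// `H(v) = -sum_k v_k ln v_k`; `kl_divergence` averages the symmetric KL
/// divergence over all pairs of members' probability rows; `variance` sums the
/// per-class variance of the members' probabilities.
///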
/// # Examples
///
/// ```rust,ignore
/// use scirs2_core::array;
/// use sklears_semi_supervised::QueryByCommittee;
///
/// let committee_probas = vec![
///     array![[0.8, 0.2], [0.6, 0.4], [0.3, 0.7]],
///     array![[0.7, 0.3], [0.5, 0.5], [0.4, 0.6]],
///     array![[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]],
/// ];
///
/// let qbc = QueryByCommittee::new()
///     .n_committee_members(3)
///     .disagreement_measure("vote_entropy".to_string())
///     .n_samples(2);
/// let selected = qbc.select_samples(&committee_probas).unwrap();
/// ```
313#[derive(Debug, Clone)]
314pub struct QueryByCommittee {
315    n_committee_members: usize,
316    disagreement_measure: String,
317    n_samples: usize,
318    diversity_weight: f64,
319    normalize_disagreement: bool,
320}
321
322impl QueryByCommittee {
323    /// Create a new QueryByCommittee instance
324    pub fn new() -> Self {
325        Self {
326            n_committee_members: 3,
327            disagreement_measure: "vote_entropy".to_string(),
328            n_samples: 10,
329            diversity_weight: 0.0,
330            normalize_disagreement: true,
331        }
332    }
333
334    /// Set the number of committee members
335    pub fn n_committee_members(mut self, n_members: usize) -> Self {
336        self.n_committee_members = n_members;
337        self
338    }
339
340    /// Set the disagreement measure
341    pub fn disagreement_measure(mut self, measure: String) -> Self {
342        self.disagreement_measure = measure;
343        self
344    }
345
346    /// Set the number of samples to select
347    pub fn n_samples(mut self, n_samples: usize) -> Self {
348        self.n_samples = n_samples;
349        self
350    }
351
352    /// Set the diversity weight
353    pub fn diversity_weight(mut self, weight: f64) -> Self {
354        self.diversity_weight = weight;
355        self
356    }
357
358    /// Enable/disable disagreement normalization
359    pub fn normalize_disagreement(mut self, normalize: bool) -> Self {
360        self.normalize_disagreement = normalize;
361        self
362    }
363
364    /// Select samples based on committee disagreement
365    pub fn select_samples(&self, committee_probas: &[Array2<f64>]) -> SklResult<Vec<usize>> {
366        if committee_probas.is_empty() {
367            return Err(SklearsError::InvalidInput(
368                "Empty committee provided".to_string(),
369            ));
370        }
371
372        let n_samples = committee_probas[0].nrows();
373        let n_classes = committee_probas[0].ncols();
374
375        // Validate committee dimensions
376        for (i, probas) in committee_probas.iter().enumerate() {
377            if probas.dim() != (n_samples, n_classes) {
378                return Err(SklearsError::InvalidInput(format!(
379                    "Committee member {} has incompatible dimensions",
380                    i
381                )));
382            }
383        }
384
385        if self.n_samples >= n_samples {
386            return Ok((0..n_samples).collect());
387        }
388
389        // Compute disagreement scores
390        let disagreement_scores = match self.disagreement_measure.as_str() {
391            "vote_entropy" => self.vote_entropy_disagreement(committee_probas)?,
392            "kl_divergence" => self.kl_divergence_disagreement(committee_probas)?,
393            "variance" => self.variance_disagreement(committee_probas)?,
394            _ => {
395                return Err(SklearsError::InvalidInput(format!(
396                    "Unknown disagreement measure: {}",
397                    self.disagreement_measure
398                )))
399            }
400        };
401
402        // Normalize disagreement scores if requested
403        let normalized_scores = if self.normalize_disagreement {
404            self.normalize_scores(&disagreement_scores)
405        } else {
406            disagreement_scores
407        };
408
409        // Select samples with highest disagreement
410        let mut indexed_scores: Vec<(usize, f64)> = normalized_scores
411            .iter()
412            .enumerate()
413            .map(|(i, &score)| (i, score))
414            .collect();
415
416        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
417
418        let selected_indices: Vec<usize> = indexed_scores
419            .iter()
420            .take(self.n_samples)
421            .map(|(idx, _)| *idx)
422            .collect();
423
424        Ok(selected_indices)
425    }
426
427    fn vote_entropy_disagreement(
428        &self,
429        committee_probas: &[Array2<f64>],
430    ) -> SklResult<Array1<f64>> {
431        let n_samples = committee_probas[0].nrows();
432        let n_classes = committee_probas[0].ncols();
433        let mut disagreements = Array1::zeros(n_samples);
434
435        for sample_idx in 0..n_samples {
436            // Get predictions from all committee members
437            let mut class_votes = Array1::zeros(n_classes);
438
            for member_probas in committee_probas.iter() {
                let sample_probas = member_probas.row(sample_idx);
441                let predicted_class = sample_probas
442                    .iter()
443                    .enumerate()
444                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
445                    .unwrap()
446                    .0;
447                class_votes[predicted_class] += 1.0;
448            }
449
450            // Normalize vote counts to probabilities
451            let total_votes: f64 = class_votes.sum();
452            if total_votes > 0.0 {
453                class_votes /= total_votes;
454            }
455
456            // Compute entropy of vote distribution
457            let mut entropy = 0.0;
458            for &vote_prob in class_votes.iter() {
459                if vote_prob > 1e-15 {
460                    entropy -= vote_prob * vote_prob.ln();
461                }
462            }
463            disagreements[sample_idx] = entropy;
464        }
465
466        Ok(disagreements)
467    }
468
469    fn kl_divergence_disagreement(
470        &self,
471        committee_probas: &[Array2<f64>],
472    ) -> SklResult<Array1<f64>> {
473        let n_samples = committee_probas[0].nrows();
474        let mut disagreements = Array1::zeros(n_samples);
475
476        for sample_idx in 0..n_samples {
477            let mut total_disagreement = 0.0;
478            let mut pair_count = 0;
479
480            // Compute pairwise KL divergences
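            // Disagreement is the average over member pairs of the symmetric
            // KL divergence (KL(p_i || p_j) + KL(p_j || p_i)) / 2.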
481            for i in 0..committee_probas.len() {
482                for j in (i + 1)..committee_probas.len() {
483                    let p1 = committee_probas[i].row(sample_idx);
484                    let p2 = committee_probas[j].row(sample_idx);
485
486                    let kl_div = self.symmetric_kl_divergence(&p1, &p2);
487                    total_disagreement += kl_div;
488                    pair_count += 1;
489                }
490            }
491
492            if pair_count > 0 {
493                disagreements[sample_idx] = total_disagreement / pair_count as f64;
494            }
495        }
496
497        Ok(disagreements)
498    }
499
500    fn variance_disagreement(&self, committee_probas: &[Array2<f64>]) -> SklResult<Array1<f64>> {
501        let n_samples = committee_probas[0].nrows();
502        let n_classes = committee_probas[0].ncols();
503        let mut disagreements = Array1::zeros(n_samples);
504
505        for sample_idx in 0..n_samples {
506            let mut total_variance = 0.0;
507
508            for class_idx in 0..n_classes {
509                // Collect probabilities for this class from all committee members
510                let class_probs: Vec<f64> = committee_probas
511                    .iter()
512                    .map(|probas| probas[[sample_idx, class_idx]])
513                    .collect();
514
515                // Compute variance
516                let mean = class_probs.iter().sum::<f64>() / class_probs.len() as f64;
517                let variance = class_probs.iter().map(|&p| (p - mean).powi(2)).sum::<f64>()
518                    / class_probs.len() as f64;
519
520                total_variance += variance;
521            }
522
523            disagreements[sample_idx] = total_variance;
524        }
525
526        Ok(disagreements)
527    }
528
529    fn symmetric_kl_divergence(&self, p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
530        let kl1 = self.kl_divergence(p1, p2);
531        let kl2 = self.kl_divergence(p2, p1);
532        (kl1 + kl2) / 2.0
533    }
534
535    fn kl_divergence(&self, p: &ArrayView1<f64>, q: &ArrayView1<f64>) -> f64 {
536        let mut kl = 0.0;
537        for (pi, qi) in p.iter().zip(q.iter()) {
538            if *pi > 1e-15 && *qi > 1e-15 {
539                kl += pi * (pi / qi).ln();
540            }
541        }
542        kl
543    }
544
545    fn normalize_scores(&self, scores: &Array1<f64>) -> Array1<f64> {
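        // Min-max normalize scores into [0, 1]; if all scores are (nearly) equal,
        // fall back to a constant 0.5 so the ranking degenerates gracefully.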
546        let min_score = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
547        let max_score = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
548
549        if (max_score - min_score).abs() < 1e-15 {
550            Array1::from_elem(scores.len(), 0.5)
551        } else {
552            scores.mapv(|x| (x - min_score) / (max_score - min_score))
553        }
554    }
555}
556
557impl Default for UncertaintySampling {
558    fn default() -> Self {
559        Self::new()
560    }
561}
562
563impl Default for QueryByCommittee {
564    fn default() -> Self {
565        Self::new()
566    }
567}
568
/// Expected Model Change for Active Learning
///
/// Expected Model Change selects samples that are expected to cause the largest
/// change in the model parameters when added to the training set. This strategy
/// looks ahead to predict which samples would be most informative for updating the model.
///
/// # Parameters
///
/// * `n_samples` - Number of samples to select for labeling
/// * `approximation_method` - Method for approximating model change ("gradient_norm", "fisher_information", "parameter_variance")
/// * `learning_rate` - Learning rate for gradient approximation
/// * `epsilon` - Small value for numerical stability
/// * `normalize_scores` - Whether to normalize change scores
/// * `diversity_weight` - Weight for diversity-based selection
///
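/// # Approximation methods
///
/// As a sketch of the scores computed below, with `g_i` a per-sample gradient
/// row, `x_i` the feature row, and `eta` the learning rate: `gradient_norm`
/// scores `eta * ||g_i||_2`; `fisher_information` scores
/// `eta * ||g_i||_2^2 * (||x_i||_2 + epsilon)`; `parameter_variance` scores
/// `eta * (1/d) * sum_j (g_ij - gbar_j)^2`, where `gbar` is the mean gradient
/// across samples and `d` the number of gradient components.
///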
/// # Examples
///
/// ```rust,ignore
/// use scirs2_core::array;
/// use sklears_semi_supervised::ExpectedModelChange;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let gradients = array![[0.1, 0.2], [0.3, 0.1], [0.05, 0.4]];
///
/// let emc = ExpectedModelChange::new()
///     .approximation_method("gradient_norm".to_string())
///     .n_samples(2);
/// let selected = emc.select_samples(&X.view(), &gradients.view()).unwrap();
/// ```
597#[derive(Debug, Clone)]
598pub struct ExpectedModelChange {
599    n_samples: usize,
600    approximation_method: String,
601    learning_rate: f64,
602    epsilon: f64,
603    normalize_scores: bool,
604    diversity_weight: f64,
605    batch_size: usize,
606}
607
608impl ExpectedModelChange {
609    /// Create a new ExpectedModelChange instance
610    pub fn new() -> Self {
611        Self {
612            n_samples: 10,
613            approximation_method: "gradient_norm".to_string(),
614            learning_rate: 0.01,
615            epsilon: 1e-8,
616            normalize_scores: true,
617            diversity_weight: 0.0,
618            batch_size: 1000,
619        }
620    }
621
622    /// Set the number of samples to select
623    pub fn n_samples(mut self, n_samples: usize) -> Self {
624        self.n_samples = n_samples;
625        self
626    }
627
628    /// Set the approximation method for model change
629    pub fn approximation_method(mut self, method: String) -> Self {
630        self.approximation_method = method;
631        self
632    }
633
634    /// Set the learning rate for gradient approximation
635    pub fn learning_rate(mut self, lr: f64) -> Self {
636        self.learning_rate = lr;
637        self
638    }
639
640    /// Set epsilon for numerical stability
641    pub fn epsilon(mut self, epsilon: f64) -> Self {
642        self.epsilon = epsilon;
643        self
644    }
645
646    /// Set whether to normalize scores
647    pub fn normalize_scores(mut self, normalize: bool) -> Self {
648        self.normalize_scores = normalize;
649        self
650    }
651
652    /// Set diversity weight for selection
653    pub fn diversity_weight(mut self, weight: f64) -> Self {
654        self.diversity_weight = weight;
655        self
656    }
657
658    /// Set batch size for computation
659    pub fn batch_size(mut self, batch_size: usize) -> Self {
660        self.batch_size = batch_size;
661        self
662    }
663
664    /// Select samples based on expected model change
665    pub fn select_samples(
666        &self,
667        X: &ArrayView2<f64>,
668        gradients: &ArrayView2<f64>,
669    ) -> SklResult<Vec<usize>> {
670        let n_samples = X.nrows();
671
672        if gradients.nrows() != n_samples {
673            return Err(SklearsError::InvalidInput(
674                "Number of gradients must match number of samples".to_string(),
675            ));
676        }
677
678        if self.n_samples >= n_samples {
679            return Ok((0..n_samples).collect());
680        }
681
682        // Validate approximation method
683        match self.approximation_method.as_str() {
684            "gradient_norm" | "fisher_information" | "parameter_variance" => {}
685            _ => {
686                return Err(SklearsError::InvalidInput(format!(
687                    "Unknown approximation method: {}",
688                    self.approximation_method
689                )))
690            }
691        }
692
693        // Compute expected model change scores
694        let change_scores = match self.approximation_method.as_str() {
695            "gradient_norm" => self.gradient_norm_scores(gradients)?,
696            "fisher_information" => self.fisher_information_scores(X, gradients)?,
697            "parameter_variance" => self.parameter_variance_scores(gradients)?,
698            _ => unreachable!(),
699        };
700
701        // Normalize scores if requested
702        let final_scores = if self.normalize_scores {
703            self.normalize_change_scores(&change_scores)
704        } else {
705            change_scores
706        };
707
708        // Select samples with highest expected change
709        let mut indexed_scores: Vec<(usize, f64)> = final_scores
710            .iter()
711            .enumerate()
712            .map(|(i, &score)| (i, score))
713            .collect();
714
715        // Sort by expected change (descending)
716        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
717
718        // Apply diversity if weight > 0
719        let selected_indices = if self.diversity_weight > 0.0 {
720            self.diverse_model_change_selection(&indexed_scores, X)?
721        } else {
722            indexed_scores
723                .into_iter()
724                .take(self.n_samples)
725                .map(|(idx, _)| idx)
726                .collect()
727        };
728
729        Ok(selected_indices)
730    }
731
732    fn gradient_norm_scores(&self, gradients: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
733        let n_samples = gradients.nrows();
734        let mut scores = Array1::zeros(n_samples);
735
736        for i in 0..n_samples {
737            let gradient = gradients.row(i);
738            // L2 norm of gradient (expected parameter change)
739            let norm = gradient.iter().map(|&x| x * x).sum::<f64>().sqrt();
740            scores[i] = norm * self.learning_rate;
741        }
742
743        Ok(scores)
744    }
745
746    fn fisher_information_scores(
747        &self,
748        X: &ArrayView2<f64>,
749        gradients: &ArrayView2<f64>,
750    ) -> SklResult<Array1<f64>> {
751        let n_samples = X.nrows();
752        let mut scores = Array1::zeros(n_samples);
753
754        for i in 0..n_samples {
755            let gradient = gradients.row(i);
756            let features = X.row(i);
757
758            // Approximate Fisher Information as outer product of gradients
759            // weighted by feature magnitudes
760            let feature_weight = features.iter().map(|&x| x * x).sum::<f64>().sqrt() + self.epsilon;
761            let gradient_magnitude = gradient.iter().map(|&x| x * x).sum::<f64>();
762
763            scores[i] = gradient_magnitude * feature_weight * self.learning_rate;
764        }
765
766        Ok(scores)
767    }
768
769    fn parameter_variance_scores(&self, gradients: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
770        let n_samples = gradients.nrows();
771        let n_features = gradients.ncols();
772        let mut scores = Array1::zeros(n_samples);
773
774        // Compute mean gradient across samples
775        let mean_gradient = gradients.mean_axis(Axis(0)).unwrap();
776
777        for i in 0..n_samples {
778            let gradient = gradients.row(i);
779
780            // Compute variance from mean gradient
781            let mut variance = 0.0;
782            for j in 0..n_features {
783                let diff = gradient[j] - mean_gradient[j];
784                variance += diff * diff;
785            }
786
787            variance /= n_features as f64;
788            scores[i] = variance * self.learning_rate;
789        }
790
791        Ok(scores)
792    }
793
794    fn normalize_change_scores(&self, scores: &Array1<f64>) -> Array1<f64> {
795        let min_score = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
796        let max_score = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
797
798        if (max_score - min_score).abs() < self.epsilon {
799            Array1::from_elem(scores.len(), 0.5)
800        } else {
801            scores.mapv(|x| (x - min_score) / (max_score - min_score))
802        }
803    }
804
805    fn diverse_model_change_selection(
806        &self,
807        indexed_scores: &[(usize, f64)],
808        X: &ArrayView2<f64>,
809    ) -> SklResult<Vec<usize>> {
810        let mut selected = Vec::new();
811        let mut remaining: Vec<(usize, f64)> = indexed_scores.to_vec();
812
813        // Select first sample with highest change score
814        if let Some((first_idx, _)) = remaining.first() {
815            selected.push(*first_idx);
816            remaining.remove(0);
817        }
818
819        // Select remaining samples balancing change and diversity
820        while selected.len() < self.n_samples && !remaining.is_empty() {
821            let mut best_score = f64::NEG_INFINITY;
822            let mut best_idx = 0;
823
824            for (candidate_idx, (sample_idx, change_score)) in remaining.iter().enumerate() {
825                // Compute minimum distance to already selected samples
826                let mut min_distance = f64::INFINITY;
827                for &selected_idx in &selected {
828                    let dist = self.euclidean_distance(X.row(*sample_idx), X.row(selected_idx));
829                    min_distance = min_distance.min(dist);
830                }
831
832                // Combined score: model change + diversity
833                let combined_score = (1.0 - self.diversity_weight) * change_score
834                    + self.diversity_weight * min_distance;
835
836                if combined_score > best_score {
837                    best_score = combined_score;
838                    best_idx = candidate_idx;
839                }
840            }
841
842            let (selected_sample_idx, _) = remaining.remove(best_idx);
843            selected.push(selected_sample_idx);
844        }
845
846        Ok(selected)
847    }
848
849    fn euclidean_distance(&self, x1: ArrayView1<f64>, x2: ArrayView1<f64>) -> f64 {
850        x1.iter()
851            .zip(x2.iter())
852            .map(|(&a, &b)| (a - b).powi(2))
853            .sum::<f64>()
854            .sqrt()
855    }
856}
857
858impl Default for ExpectedModelChange {
859    fn default() -> Self {
860        Self::new()
861    }
862}
863
/// Information Density for Active Learning
///
/// Information Density methods combine uncertainty measures with density-based
/// sample selection. These methods prefer samples that are both uncertain and
/// located in dense regions of the feature space, as such samples are more
/// representative and likely to improve model performance.
///
/// # Parameters
///
/// * `uncertainty_measure` - Uncertainty measure to use ("entropy", "margin", "least_confident")
/// * `density_measure` - Density measure to use ("knn_density", "gaussian_density", "cosine_similarity")
/// * `n_samples` - Number of samples to select for labeling
/// * `density_weight` - Weight for the density component (0.0 = pure uncertainty, 1.0 = pure density)
/// * `bandwidth` - Bandwidth parameter for Gaussian density estimation
/// * `k_neighbors` - Number of neighbors for k-NN density estimation
/// * `temperature` - Temperature scaling for probability calibration
/// * `normalize_scores` - Whether to min-max normalize uncertainty and density scores before combining
///
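/// # Combined score
///
/// As a sketch of the selection rule below: with uncertainty score `u_i`,
/// density score `d_i`, and density weight `w`, each sample is ranked by
/// `(1 - w) * u_i + w * d_i`, and the `n_samples` highest-ranked indices are
/// returned.
///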
/// # Examples
///
/// ```rust,ignore
/// use scirs2_core::array;
/// use sklears_semi_supervised::InformationDensity;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let probas = array![[0.9, 0.1], [0.6, 0.4], [0.5, 0.5], [0.8, 0.2]];
///
/// let id = InformationDensity::new()
///     .uncertainty_measure("entropy".to_string())
///     .density_measure("knn_density".to_string())
///     .density_weight(0.5)
///     .n_samples(2);
/// let selected = id.select_samples(&X.view(), &probas.view()).unwrap();
/// ```
896#[derive(Debug, Clone)]
897pub struct InformationDensity {
898    uncertainty_measure: String,
899    density_measure: String,
900    n_samples: usize,
901    density_weight: f64,
902    bandwidth: f64,
903    k_neighbors: usize,
904    temperature: f64,
905    normalize_scores: bool,
906}
907
908impl InformationDensity {
909    /// Create a new InformationDensity instance
910    pub fn new() -> Self {
911        Self {
912            uncertainty_measure: "entropy".to_string(),
913            density_measure: "knn_density".to_string(),
914            n_samples: 10,
915            density_weight: 0.5,
916            bandwidth: 1.0,
917            k_neighbors: 5,
918            temperature: 1.0,
919            normalize_scores: true,
920        }
921    }
922
923    /// Set the uncertainty measure
924    pub fn uncertainty_measure(mut self, measure: String) -> Self {
925        self.uncertainty_measure = measure;
926        self
927    }
928
929    /// Set the density measure
930    pub fn density_measure(mut self, measure: String) -> Self {
931        self.density_measure = measure;
932        self
933    }
934
935    /// Set the number of samples to select
936    pub fn n_samples(mut self, n_samples: usize) -> Self {
937        self.n_samples = n_samples;
938        self
939    }
940
941    /// Set the density weight
942    pub fn density_weight(mut self, weight: f64) -> Self {
943        self.density_weight = weight;
944        self
945    }
946
947    /// Set the bandwidth for density estimation
948    pub fn bandwidth(mut self, bandwidth: f64) -> Self {
949        self.bandwidth = bandwidth;
950        self
951    }
952
953    /// Set the number of neighbors for k-NN density
954    pub fn k_neighbors(mut self, k: usize) -> Self {
955        self.k_neighbors = k;
956        self
957    }
958
959    /// Set temperature for probability calibration
960    pub fn temperature(mut self, temperature: f64) -> Self {
961        self.temperature = temperature;
962        self
963    }
964
965    /// Set whether to normalize scores
966    pub fn normalize_scores(mut self, normalize: bool) -> Self {
967        self.normalize_scores = normalize;
968        self
969    }
970
971    /// Select samples based on information density
972    pub fn select_samples(
973        &self,
974        X: &ArrayView2<f64>,
975        probas: &ArrayView2<f64>,
976    ) -> SklResult<Vec<usize>> {
977        let n_samples = X.nrows();
978
979        if probas.nrows() != n_samples {
980            return Err(SklearsError::InvalidInput(
981                "Number of probabilities must match number of samples".to_string(),
982            ));
983        }
984
985        if self.n_samples >= n_samples {
986            return Ok((0..n_samples).collect());
987        }
988
989        // Validate measures
990        match self.uncertainty_measure.as_str() {
991            "entropy" | "margin" | "least_confident" => {}
992            _ => {
993                return Err(SklearsError::InvalidInput(format!(
994                    "Unknown uncertainty measure: {}",
995                    self.uncertainty_measure
996                )))
997            }
998        }
999
1000        match self.density_measure.as_str() {
1001            "knn_density" | "gaussian_density" | "cosine_similarity" => {}
1002            _ => {
1003                return Err(SklearsError::InvalidInput(format!(
1004                    "Unknown density measure: {}",
1005                    self.density_measure
1006                )))
1007            }
1008        }
1009
1010        // Apply temperature scaling to probabilities
1011        let calibrated_probas = self.apply_temperature_scaling(probas);
1012
1013        // Compute uncertainty scores
1014        let uncertainty_scores = match self.uncertainty_measure.as_str() {
1015            "entropy" => self.entropy_uncertainty(&calibrated_probas)?,
1016            "margin" => self.margin_uncertainty(&calibrated_probas)?,
1017            "least_confident" => self.least_confident_uncertainty(&calibrated_probas)?,
1018            _ => unreachable!(),
1019        };
1020
1021        // Compute density scores
1022        let density_scores = match self.density_measure.as_str() {
1023            "knn_density" => self.knn_density_scores(X)?,
1024            "gaussian_density" => self.gaussian_density_scores(X)?,
1025            "cosine_similarity" => self.cosine_similarity_scores(X)?,
1026            _ => unreachable!(),
1027        };
1028
1029        // Normalize scores if requested
1030        let normalized_uncertainty = if self.normalize_scores {
1031            self.normalize_array(&uncertainty_scores)
1032        } else {
1033            uncertainty_scores
1034        };
1035
1036        let normalized_density = if self.normalize_scores {
1037            self.normalize_array(&density_scores)
1038        } else {
1039            density_scores
1040        };
1041
1042        // Combine uncertainty and density scores
1043        let mut combined_scores = Array1::zeros(n_samples);
1044        for i in 0..n_samples {
1045            combined_scores[i] = (1.0 - self.density_weight) * normalized_uncertainty[i]
1046                + self.density_weight * normalized_density[i];
1047        }
1048
1049        // Select samples with highest combined scores
1050        let mut indexed_scores: Vec<(usize, f64)> = combined_scores
1051            .iter()
1052            .enumerate()
1053            .map(|(i, &score)| (i, score))
1054            .collect();
1055
1056        // Sort by combined score (descending)
1057        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1058
1059        let selected_indices: Vec<usize> = indexed_scores
1060            .into_iter()
1061            .take(self.n_samples)
1062            .map(|(idx, _)| idx)
1063            .collect();
1064
1065        Ok(selected_indices)
1066    }
1067
1068    fn apply_temperature_scaling(&self, probas: &ArrayView2<f64>) -> Array2<f64> {
1069        if (self.temperature - 1.0).abs() < 1e-10 {
1070            return probas.to_owned();
1071        }
1072
1073        let mut calibrated = Array2::zeros(probas.dim());
1074        for i in 0..probas.nrows() {
1075            let row = probas.row(i);
1076            // Apply temperature scaling: p_i = exp(logit_i / T) / sum(exp(logit_j / T))
1077            let max_prob = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1078            let logits: Vec<f64> = row.iter().map(|&p| (p / max_prob).ln()).collect();
1079
1080            let scaled_logits: Vec<f64> = logits.iter().map(|&l| l / self.temperature).collect();
1081            let max_logit = scaled_logits
1082                .iter()
1083                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1084
1085            let exp_sum: f64 = scaled_logits.iter().map(|&l| (l - max_logit).exp()).sum();
1086
1087            for j in 0..probas.ncols() {
1088                calibrated[[i, j]] = (scaled_logits[j] - max_logit).exp() / exp_sum;
1089            }
1090        }
1091
1092        calibrated
1093    }
1094
1095    fn entropy_uncertainty(&self, probas: &Array2<f64>) -> SklResult<Array1<f64>> {
1096        let n_samples = probas.nrows();
1097        let mut entropies = Array1::zeros(n_samples);
1098
1099        for i in 0..n_samples {
1100            let mut entropy = 0.0;
1101            for &p in probas.row(i).iter() {
1102                if p > 1e-15 {
1103                    entropy -= p * p.ln();
1104                }
1105            }
1106            entropies[i] = entropy;
1107        }
1108
1109        Ok(entropies)
1110    }
1111
1112    fn margin_uncertainty(&self, probas: &Array2<f64>) -> SklResult<Array1<f64>> {
1113        let n_samples = probas.nrows();
1114        let mut margins = Array1::zeros(n_samples);
1115
1116        for i in 0..n_samples {
1117            let row = probas.row(i);
1118            let mut sorted_probs: Vec<f64> = row.iter().cloned().collect();
1119            sorted_probs.sort_by(|a, b| b.partial_cmp(a).unwrap());
1120
1121            let margin = if sorted_probs.len() >= 2 {
1122                sorted_probs[0] - sorted_probs[1] // largest - second largest
1123            } else {
1124                sorted_probs[0]
1125            };
1126            margins[i] = -margin; // Negative so higher uncertainty = higher score
1127        }
1128
1129        Ok(margins)
1130    }
1131
1132    fn least_confident_uncertainty(&self, probas: &Array2<f64>) -> SklResult<Array1<f64>> {
1133        let n_samples = probas.nrows();
1134        let mut uncertainties = Array1::zeros(n_samples);
1135
1136        for i in 0..n_samples {
1137            let max_prob = probas
1138                .row(i)
1139                .iter()
1140                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1141            uncertainties[i] = 1.0 - max_prob;
1142        }
1143
1144        Ok(uncertainties)
1145    }
1146
1147    fn knn_density_scores(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
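        // Density proxy: the inverse of the distance to the k-th nearest neighbor,
        // so samples in tight neighborhoods receive larger scores.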
1148        let n_samples = X.nrows();
1149        let mut density_scores = Array1::zeros(n_samples);
1150
1151        for i in 0..n_samples {
1152            let sample_i = X.row(i);
1153            let mut distances = Vec::new();
1154
1155            // Compute distances to all other samples
1156            for j in 0..n_samples {
1157                if i != j {
1158                    let sample_j = X.row(j);
1159                    let distance = self.euclidean_distance(sample_i, sample_j);
1160                    distances.push(distance);
1161                }
1162            }
1163
1164            // Sort distances and take k-th nearest neighbor distance
1165            distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
1166            let k_distance = if self.k_neighbors <= distances.len() {
1167                distances[self.k_neighbors - 1]
1168            } else {
1169                distances.last().cloned().unwrap_or(1.0)
1170            };
1171
1172            // Higher density = smaller distance to k-th neighbor
1173            density_scores[i] = 1.0 / (k_distance + 1e-8);
1174        }
1175
1176        Ok(density_scores)
1177    }
1178
1179    fn gaussian_density_scores(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
1180        let n_samples = X.nrows();
1181        let mut density_scores = Array1::zeros(n_samples);
1182
1183        for i in 0..n_samples {
1184            let sample_i = X.row(i);
1185            let mut density = 0.0;
1186
1187            // Compute Gaussian kernel density
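            // density(x_i) = (1 / (n - 1)) * sum_{j != i} exp(-||x_i - x_j||^2 / (2 * bandwidth^2))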
1188            for j in 0..n_samples {
1189                if i != j {
1190                    let sample_j = X.row(j);
1191                    let distance_sq = self.euclidean_distance_squared(sample_i, sample_j);
1192                    density += (-distance_sq / (2.0 * self.bandwidth * self.bandwidth)).exp();
1193                }
1194            }
1195
1196            density_scores[i] = density / (n_samples - 1) as f64;
1197        }
1198
1199        Ok(density_scores)
1200    }
1201
1202    fn cosine_similarity_scores(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
1203        let n_samples = X.nrows();
1204        let mut similarity_scores = Array1::zeros(n_samples);
1205
1206        for i in 0..n_samples {
1207            let sample_i = X.row(i);
1208            let mut total_similarity = 0.0;
1209
1210            for j in 0..n_samples {
1211                if i != j {
1212                    let sample_j = X.row(j);
1213                    let similarity = self.cosine_similarity(sample_i, sample_j);
1214                    total_similarity += similarity;
1215                }
1216            }
1217
1218            similarity_scores[i] = total_similarity / (n_samples - 1) as f64;
1219        }
1220
1221        Ok(similarity_scores)
1222    }
1223
1224    fn normalize_array(&self, array: &Array1<f64>) -> Array1<f64> {
1225        let min_val = array.iter().fold(f64::INFINITY, |a, &b| a.min(b));
1226        let max_val = array.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1227
1228        if (max_val - min_val).abs() < 1e-15 {
1229            Array1::from_elem(array.len(), 0.5)
1230        } else {
1231            array.mapv(|x| (x - min_val) / (max_val - min_val))
1232        }
1233    }
1234
1235    fn euclidean_distance(&self, x1: ArrayView1<f64>, x2: ArrayView1<f64>) -> f64 {
1236        x1.iter()
1237            .zip(x2.iter())
1238            .map(|(&a, &b)| (a - b).powi(2))
1239            .sum::<f64>()
1240            .sqrt()
1241    }
1242
1243    fn euclidean_distance_squared(&self, x1: ArrayView1<f64>, x2: ArrayView1<f64>) -> f64 {
1244        x1.iter()
1245            .zip(x2.iter())
1246            .map(|(&a, &b)| (a - b).powi(2))
1247            .sum::<f64>()
1248    }
1249
1250    fn cosine_similarity(&self, x1: ArrayView1<f64>, x2: ArrayView1<f64>) -> f64 {
1251        let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(&a, &b)| a * b).sum();
1252        let norm_x1 = x1.iter().map(|&x| x * x).sum::<f64>().sqrt();
1253        let norm_x2 = x2.iter().map(|&x| x * x).sum::<f64>().sqrt();
1254
1255        if norm_x1 < 1e-15 || norm_x2 < 1e-15 {
1256            0.0
1257        } else {
1258            dot_product / (norm_x1 * norm_x2)
1259        }
1260    }
1261}
1262
1263impl Default for InformationDensity {
1264    fn default() -> Self {
1265        Self::new()
1266    }
1267}
1268
1269#[allow(non_snake_case)]
1270#[cfg(test)]
1271mod tests {
1272    use super::*;
1273    use scirs2_core::array;
1274
1275    #[test]
1276    fn test_uncertainty_sampling_entropy() {
1277        let probas = array![
1278            [0.9, 0.1], // Low entropy (certain)
1279            [0.6, 0.4], // Medium entropy
1280            [0.5, 0.5], // High entropy (uncertain)
1281            [0.8, 0.2], // Low entropy
1282        ];
1283
1284        let us = UncertaintySampling::new()
1285            .strategy("entropy".to_string())
1286            .n_samples(2);
1287        let selected = us.select_samples(&probas.view()).unwrap();
1288
1289        assert_eq!(selected.len(), 2);
1290        // Most uncertain sample should be selected first (index 2: [0.5, 0.5])
1291        assert!(selected.contains(&2));
1292    }
1293
1294    #[test]
1295    fn test_uncertainty_sampling_margin() {
1296        let probas = array![
1297            [0.9, 0.1], // Large margin
1298            [0.6, 0.4], // Small margin
1299            [0.5, 0.5], // No margin (most uncertain)
1300            [0.8, 0.2], // Medium margin
1301        ];
1302
1303        let us = UncertaintySampling::new()
1304            .strategy("margin".to_string())
1305            .n_samples(2);
1306        let selected = us.select_samples(&probas.view()).unwrap();
1307
1308        assert_eq!(selected.len(), 2);
1309        // Sample with smallest margin should be selected (index 2)
1310        assert!(selected.contains(&2));
1311    }
1312
1313    #[test]
1314    fn test_uncertainty_sampling_least_confident() {
1315        let probas = array![
1316            [0.9, 0.1], // High confidence
1317            [0.6, 0.4], // Medium confidence
1318            [0.5, 0.5], // Least confident
1319            [0.8, 0.2], // High confidence
1320        ];
1321
1322        let us = UncertaintySampling::new()
1323            .strategy("least_confident".to_string())
1324            .n_samples(2);
1325        let selected = us.select_samples(&probas.view()).unwrap();
1326
1327        assert_eq!(selected.len(), 2);
1328        // Least confident sample should be selected (index 2)
1329        assert!(selected.contains(&2));
1330    }
1331
1332    #[test]
1333    fn test_uncertainty_sampling_temperature_scaling() {
1334        let us = UncertaintySampling::new().temperature(2.0);
1335        let probas = array![[0.8, 0.2], [0.6, 0.4]];
1336
1337        let calibrated = us.apply_temperature_scaling(&probas.view());
1338
1339        // Temperature scaling should make probabilities less extreme
1340        assert!(calibrated[[0, 0]] < 0.8);
1341        assert!(calibrated[[0, 1]] > 0.2);
1342
1343        // Check that probabilities still sum to 1
1344        for i in 0..calibrated.nrows() {
1345            let sum: f64 = calibrated.row(i).sum();
1346            assert!((sum - 1.0).abs() < 1e-10);
1347        }
1348    }
1349
1350    #[test]
1351    fn test_query_by_committee_vote_entropy() {
1352        let committee_probas = vec![
1353            array![[0.8, 0.2], [0.6, 0.4], [0.3, 0.7]], // Member 1
1354            array![[0.7, 0.3], [0.5, 0.5], [0.4, 0.6]], // Member 2
1355            array![[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]], // Member 3
1356        ];
1357
1358        let qbc = QueryByCommittee::new()
1359            .disagreement_measure("vote_entropy".to_string())
1360            .n_samples(2);
1361        let selected = qbc.select_samples(&committee_probas).unwrap();
1362
1363        assert_eq!(selected.len(), 2);
1364        // Should select samples where committee members disagree most
1365    }
1366
1367    #[test]
1368    fn test_query_by_committee_kl_divergence() {
1369        let committee_probas = vec![
1370            array![[0.8, 0.2], [0.6, 0.4]],
1371            array![[0.2, 0.8], [0.4, 0.6]], // Very different from first
1372        ];
1373
1374        let qbc = QueryByCommittee::new()
1375            .disagreement_measure("kl_divergence".to_string())
1376            .n_samples(1);
1377        let selected = qbc.select_samples(&committee_probas).unwrap();
1378
1379        assert_eq!(selected.len(), 1);
1380    }
1381
1382    #[test]
1383    fn test_query_by_committee_variance() {
1384        let committee_probas = vec![
1385            array![[0.9, 0.1], [0.5, 0.5]],
1386            array![[0.8, 0.2], [0.6, 0.4]],
1387            array![[0.7, 0.3], [0.4, 0.6]],
1388        ];
1389
1390        let qbc = QueryByCommittee::new()
1391            .disagreement_measure("variance".to_string())
1392            .n_samples(1);
1393        let selected = qbc.select_samples(&committee_probas).unwrap();
1394
1395        assert_eq!(selected.len(), 1);
1396    }
1397
1398    #[test]
1399    fn test_uncertainty_sampling_diversity() {
1400        let probas = array![
1401            [0.5, 0.5], // Uncertain
1402            [0.4, 0.6], // Uncertain and different
1403            [0.6, 0.4], // Uncertain and different
1404            [0.9, 0.1], // Certain
1405        ];
1406
1407        let us = UncertaintySampling::new()
1408            .strategy("entropy".to_string())
1409            .diversity_weight(0.5)
1410            .n_samples(2);
1411        let selected = us.select_samples(&probas.view()).unwrap();
1412
1413        assert_eq!(selected.len(), 2);
1414        // Should select diverse uncertain samples
1415    }
1416
1417    #[test]
1418    fn test_uncertainty_sampling_edge_cases() {
1419        // Test with n_samples >= total samples
1420        let probas = array![[0.7, 0.3], [0.6, 0.4]];
1421        let us = UncertaintySampling::new().n_samples(5);
1422        let selected = us.select_samples(&probas.view()).unwrap();
1423        assert_eq!(selected.len(), 2);
1424        assert_eq!(selected, vec![0, 1]);
1425
1426        // Test with invalid strategy
1427        let us_invalid = UncertaintySampling::new().strategy("invalid".to_string());
1428        let result = us_invalid.select_samples(&probas.view());
1429        assert!(result.is_err());
1430    }
1431
1432    #[test]
1433    fn test_query_by_committee_edge_cases() {
1434        // Test with empty committee
1435        let qbc = QueryByCommittee::new();
1436        let result = qbc.select_samples(&[]);
1437        assert!(result.is_err());
1438
1439        // Test with mismatched dimensions
1440        let committee_probas = vec![
1441            array![[0.8, 0.2], [0.6, 0.4]],
1442            array![[0.7, 0.3]], // Different number of samples
1443        ];
1444        let result = qbc.select_samples(&committee_probas);
1445        assert!(result.is_err());
1446    }
1447
1448    #[test]
1449    fn test_kl_divergence_computation() {
1450        let us = UncertaintySampling::new();
1451        let p1 = array![0.8, 0.2];
1452        let p2 = array![0.6, 0.4];
1453
1454        let kl = us.kl_divergence(&p1.view(), &p2.view());
1455        assert!(kl > 0.0);
1456
1457        // KL divergence with self should be 0
1458        let kl_self = us.kl_divergence(&p1.view(), &p1.view());
1459        assert!(kl_self.abs() < 1e-10);
1460    }
1461
1462    #[test]
1463    fn test_score_normalization() {
1464        let qbc = QueryByCommittee::new();
1465        let scores = array![1.0, 5.0, 3.0, 2.0];
1466
1467        let normalized = qbc.normalize_scores(&scores);
1468
1469        // Check range [0, 1]
1470        for &score in normalized.iter() {
1471            assert!(score >= 0.0 && score <= 1.0);
1472        }
1473
1474        // Check min and max
1475        let min_normalized = normalized.iter().fold(f64::INFINITY, |a, &b| a.min(b));
1476        let max_normalized = normalized.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1477        assert!((min_normalized - 0.0).abs() < 1e-10);
1478        assert!((max_normalized - 1.0).abs() < 1e-10);
1479    }
1480
1481    #[test]
1482    #[allow(non_snake_case)]
1483    fn test_expected_model_change_gradient_norm() {
1484        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
1485        let gradients = array![
1486            [0.1, 0.2],  // Small gradient
1487            [0.8, 0.6],  // Large gradient
1488            [0.05, 0.1], // Very small gradient
1489            [0.4, 0.3],  // Medium gradient
1490        ];
1491
1492        let emc = ExpectedModelChange::new()
1493            .approximation_method("gradient_norm".to_string())
1494            .n_samples(2);
1495        let selected = emc.select_samples(&X.view(), &gradients.view()).unwrap();
1496
1497        assert_eq!(selected.len(), 2);
1498        // Sample with largest gradient should be selected first (index 1)
1499        assert!(selected.contains(&1));
1500    }
1501
1502    #[test]
1503    #[allow(non_snake_case)]
1504    fn test_expected_model_change_fisher_information() {
1505        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
1506        let gradients = array![[0.1, 0.2], [0.3, 0.1], [0.05, 0.4], [0.2, 0.3],];
1507
1508        let emc = ExpectedModelChange::new()
1509            .approximation_method("fisher_information".to_string())
1510            .n_samples(2);
1511        let selected = emc.select_samples(&X.view(), &gradients.view()).unwrap();
1512
1513        assert_eq!(selected.len(), 2);
1514        // Check that valid indices are selected
1515        for &idx in &selected {
1516            assert!(idx < 4);
1517        }
1518    }
1519
1520    #[test]
1521    #[allow(non_snake_case)]
1522    fn test_expected_model_change_parameter_variance() {
1523        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
1524        let gradients = array![
1525            [0.1, 0.2],
1526            [0.8, 0.1], // Different from mean
1527            [0.1, 0.2],
1528            [0.1, 0.9], // Very different from mean
1529        ];
1530
1531        let emc = ExpectedModelChange::new()
1532            .approximation_method("parameter_variance".to_string())
1533            .n_samples(2);
1534        let selected = emc.select_samples(&X.view(), &gradients.view()).unwrap();
1535
1536        assert_eq!(selected.len(), 2);
1537        // Samples with high variance should be selected
1538        assert!(selected.contains(&3)); // High variance in second component
1539    }
1540
1541    #[test]
1542    #[allow(non_snake_case)]
1543    fn test_expected_model_change_diversity() {
1544        let X = array![
1545            [1.0, 1.0],
1546            [1.1, 1.1], // Close to first sample
1547            [5.0, 5.0], // Far from first samples
1548            [5.1, 5.1], // Close to third sample
1549        ];
1550        let gradients = array![
1551            [0.5, 0.5], // High gradient
1552            [0.4, 0.4], // High gradient, close to first
1553            [0.3, 0.3], // Medium gradient, far from others
1554            [0.2, 0.2], // Low gradient
1555        ];
1556
1557        let emc = ExpectedModelChange::new()
1558            .approximation_method("gradient_norm".to_string())
1559            .diversity_weight(0.5)
1560            .n_samples(2);
1561        let selected = emc.select_samples(&X.view(), &gradients.view()).unwrap();
1562
1563        assert_eq!(selected.len(), 2);
1564        // Should balance gradient magnitude with diversity
1565        // First sample (highest gradient) should be selected
1566        assert!(selected.contains(&0));
1567    }
1568
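    // Background note: diversity-aware selection is commonly done greedily: take the
    // highest-scoring sample first, then penalize candidates that lie close (in feature
    // space) to samples already chosen. Only index 0, the highest-gradient sample, is
    // asserted above because the second pick depends on how the trade-off is weighted.
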
1569    #[test]
1570    #[allow(non_snake_case)]
1571    fn test_expected_model_change_edge_cases() {
1572        let X = array![[1.0, 2.0], [2.0, 3.0]];
1573        let gradients = array![[0.1, 0.2], [0.3, 0.1]];
1574
1575        // Test with n_samples >= total samples
1576        let emc = ExpectedModelChange::new().n_samples(5);
1577        let selected = emc.select_samples(&X.view(), &gradients.view()).unwrap();
1578        assert_eq!(selected.len(), 2);
1579        assert_eq!(selected, vec![0, 1]);
1580
1581        // Test with mismatched dimensions
1582        let bad_gradients = array![[0.1, 0.2]]; // Only one gradient for two samples
1583        let result = emc.select_samples(&X.view(), &bad_gradients.view());
1584        assert!(result.is_err());
1585
1586        // Test with invalid approximation method
1587        let emc_invalid = ExpectedModelChange::new()
1588            .approximation_method("invalid".to_string())
1589            .n_samples(1); // Ensure validation happens
1590        let result = emc_invalid.select_samples(&X.view(), &gradients.view());
1591        assert!(result.is_err());
1592    }
1593
1594    #[test]
1595    #[allow(non_snake_case)]
1596    fn test_expected_model_change_normalization() {
1597        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
1598        let gradients = array![[1.0, 0.0], [10.0, 0.0], [5.0, 0.0]];
1599
1600        let emc_normalized = ExpectedModelChange::new()
1601            .normalize_scores(true)
1602            .n_samples(2);
1603        let selected_norm = emc_normalized
1604            .select_samples(&X.view(), &gradients.view())
1605            .unwrap();
1606
1607        let emc_unnormalized = ExpectedModelChange::new()
1608            .normalize_scores(false)
1609            .n_samples(2);
1610        let selected_unnorm = emc_unnormalized
1611            .select_samples(&X.view(), &gradients.view())
1612            .unwrap();
1613
1614        // Both should select same samples (highest gradients)
1615        assert_eq!(selected_norm.len(), 2);
1616        assert_eq!(selected_unnorm.len(), 2);
1617        assert!(selected_norm.contains(&1)); // Highest gradient
1618        assert!(selected_unnorm.contains(&1)); // Highest gradient
1619    }
1620
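    // Min-max normalization is a monotone rescaling of the scores, so it cannot change
    // their relative order; this is why the normalized and unnormalized configurations
    // must select the same samples.
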
1621    #[test]
1622    fn test_euclidean_distance_computation() {
1623        let emc = ExpectedModelChange::new();
1624        let x1 = array![1.0, 2.0];
1625        let x2 = array![4.0, 6.0];
1626
1627        let distance = emc.euclidean_distance(x1.view(), x2.view());
1628        let expected = ((4.0_f64 - 1.0).powi(2) + (6.0_f64 - 2.0).powi(2)).sqrt();
1629        assert!((distance - expected).abs() < 1e-10);
1630
1631        // Distance to self should be 0
1632        let distance_self = emc.euclidean_distance(x1.view(), x1.view());
1633        assert!(distance_self.abs() < 1e-10);
1634    }
1635
1636    #[test]
1637    #[allow(non_snake_case)]
1638    fn test_information_density_knn() {
1639        let X = array![
1640            [1.0, 1.0], // Clustered samples
1641            [1.1, 1.1],
1642            [1.2, 1.2],
1643            [10.0, 10.0], // Isolated sample
1644        ];
1645        let probas = array![
1646            [0.5, 0.5], // High uncertainty
1647            [0.6, 0.4], // Medium uncertainty
1648            [0.7, 0.3], // Lower uncertainty
1649            [0.9, 0.1], // Low uncertainty
1650        ];
1651
1652        let id = InformationDensity::new()
1653            .uncertainty_measure("entropy".to_string())
1654            .density_measure("knn_density".to_string())
1655            .density_weight(0.5)
1656            .k_neighbors(2)
1657            .n_samples(2);
1658        let selected = id.select_samples(&X.view(), &probas.view()).unwrap();
1659
1660        assert_eq!(selected.len(), 2);
1661        // Should prefer uncertain samples in dense regions
1662        // The clustered, high-uncertainty samples should be preferred over the isolated, low-uncertainty one
1663    }
1664
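    // Worked example (library-independent): Shannon entropy H(p) = -sum_k p_k * ln(p_k)
    // is the quantity the "entropy" uncertainty measure is named after, and the ranking
    // it induces does not depend on the logarithm base. This sketch hand-computes the
    // entropies of the probability rows used above and checks their relative ordering.
    #[test]
    fn test_entropy_ordering_of_probability_rows() {
        let probas = array![[0.5, 0.5], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]];

        // Hand-computed Shannon entropy (natural log) of each row.
        let h: Vec<f64> = probas
            .outer_iter()
            .map(|row| -row.iter().map(|&p| p * p.ln()).sum::<f64>())
            .collect();

        // The uniform row is maximally uncertain; each later row is strictly less uncertain.
        assert!(h[0] > h[1] && h[1] > h[2] && h[2] > h[3]);
        // A uniform distribution over two classes attains ln(2).
        assert!((h[0] - 2.0_f64.ln()).abs() < 1e-12);
    }
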
1665    #[test]
1666    #[allow(non_snake_case)]
1667    fn test_information_density_gaussian() {
1668        let X = array![[1.0, 1.0], [1.5, 1.5], [2.0, 2.0], [10.0, 10.0],];
1669        let probas = array![
1670            [0.5, 0.5], // High uncertainty
1671            [0.6, 0.4],
1672            [0.7, 0.3],
1673            [0.5, 0.5], // High uncertainty but isolated
1674        ];
1675
1676        let id = InformationDensity::new()
1677            .uncertainty_measure("entropy".to_string())
1678            .density_measure("gaussian_density".to_string())
1679            .density_weight(0.7)
1680            .bandwidth(1.0)
1681            .n_samples(2);
1682        let selected = id.select_samples(&X.view(), &probas.view()).unwrap();
1683
1684        assert_eq!(selected.len(), 2);
1685        // Should favor dense regions with high weight on density
1686    }
1687
1688    #[test]
1689    #[allow(non_snake_case)]
1690    fn test_information_density_cosine_similarity() {
1691        let X = array![
1692            [1.0, 0.0], // Orthogonal vectors
1693            [0.0, 1.0],
1694            [1.0, 1.0], // Similar to first two
1695            [0.5, 0.5],
1696        ];
1697        let probas = array![
1698            [0.5, 0.5], // High uncertainty
1699            [0.5, 0.5], // High uncertainty
1700            [0.6, 0.4],
1701            [0.7, 0.3],
1702        ];
1703
1704        let id = InformationDensity::new()
1705            .uncertainty_measure("entropy".to_string())
1706            .density_measure("cosine_similarity".to_string())
1707            .density_weight(0.3)
1708            .n_samples(2);
1709        let selected = id.select_samples(&X.view(), &probas.view()).unwrap();
1710
1711        assert_eq!(selected.len(), 2);
1712        // Should select based on both uncertainty and cosine similarity
1713    }
1714
1715    #[test]
1716    #[allow(non_snake_case)]
1717    fn test_information_density_margin_uncertainty() {
1718        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
1719        let probas = array![
1720            [0.9, 0.1],   // Large margin (certain)
1721            [0.6, 0.4],   // Small margin (uncertain)
1722            [0.55, 0.45], // Smallest margin (most uncertain)
1723            [0.8, 0.2],   // Large margin
1724        ];
1725
1726        let id = InformationDensity::new()
1727            .uncertainty_measure("margin".to_string())
1728            .density_measure("knn_density".to_string())
1729            .density_weight(0.0) // Pure uncertainty
1730            .n_samples(2);
1731        let selected = id.select_samples(&X.view(), &probas.view()).unwrap();
1732
1733        assert_eq!(selected.len(), 2);
1734        // Should select samples with smallest margins (most uncertain)
1735        assert!(selected.contains(&2)); // Smallest margin
1736    }
1737
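    // Worked example: the margin of a row is the gap between its two largest
    // probabilities, and a smaller margin means higher uncertainty. The margins of the
    // rows above are 0.8, 0.2, 0.1 and 0.6, so index 2 is the most uncertain sample,
    // matching the assertion in the previous test.
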
1738    #[test]
1739    #[allow(non_snake_case)]
1740    fn test_information_density_least_confident() {
1741        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
1742        let probas = array![
1743            [0.9, 0.1],   // Very confident
1744            [0.6, 0.4],   // Less confident
1745            [0.55, 0.45], // Least confident
1746            [0.8, 0.2],   // Confident
1747        ];
1748
1749        let id = InformationDensity::new()
1750            .uncertainty_measure("least_confident".to_string())
1751            .density_measure("knn_density".to_string())
1752            .density_weight(0.0) // Pure uncertainty
1753            .n_samples(2);
1754        let selected = id.select_samples(&X.view(), &probas.view()).unwrap();
1755
1756        assert_eq!(selected.len(), 2);
1757        // Should select least confident samples
1758        assert!(selected.contains(&2)); // Least confident
1759        assert!(selected.contains(&1)); // Second least confident
1760    }
1761
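    // Worked example: the least-confident score of a row is one minus its largest
    // probability. For the rows above these scores are 0.1, 0.4, 0.45 and 0.2, so
    // indices 2 and 1 are the two least confident samples, as asserted above.
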
1762    #[test]
1763    #[allow(non_snake_case)]
1764    fn test_information_density_temperature_scaling() {
1765        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
1766        let probas = array![[0.8, 0.2], [0.6, 0.4], [0.7, 0.3],];
1767
1768        // Test with different temperatures
1769        let id_low_temp = InformationDensity::new()
1770            .temperature(0.5) // Sharper probabilities
1771            .density_weight(0.0)
1772            .n_samples(2);
1773        let selected_low = id_low_temp
1774            .select_samples(&X.view(), &probas.view())
1775            .unwrap();
1776
1777        let id_high_temp = InformationDensity::new()
1778            .temperature(2.0) // Smoother probabilities
1779            .density_weight(0.0)
1780            .n_samples(2);
1781        let selected_high = id_high_temp
1782            .select_samples(&X.view(), &probas.view())
1783            .unwrap();
1784
1785        assert_eq!(selected_low.len(), 2);
1786        assert_eq!(selected_high.len(), 2);
1787    }
1788
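    // Background note: temperature scaling is commonly applied by raising each
    // probability to the power 1/T and renormalizing (equivalently, dividing logits by
    // T before the softmax). T < 1 sharpens the distribution toward its argmax and
    // T > 1 flattens it toward uniform, which is what the "sharper"/"smoother" comments
    // above refer to. The exact form is implementation-defined, so only the selection
    // size is asserted.
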
1789    #[test]
1790    #[allow(non_snake_case)]
1791    fn test_information_density_edge_cases() {
1792        let X = array![[1.0, 2.0], [2.0, 3.0]];
1793        let probas = array![[0.6, 0.4], [0.7, 0.3]];
1794
1795        // Test with n_samples >= total samples
1796        let id = InformationDensity::new().n_samples(5);
1797        let selected = id.select_samples(&X.view(), &probas.view()).unwrap();
1798        assert_eq!(selected.len(), 2);
1799        assert_eq!(selected, vec![0, 1]);
1800
1801        // Test with mismatched dimensions
1802        let bad_probas = array![[0.6, 0.4]]; // Only one probability for two samples
1803        let result = id.select_samples(&X.view(), &bad_probas.view());
1804        assert!(result.is_err());
1805
1806        // Test with invalid uncertainty measure
1807        let id_invalid = InformationDensity::new()
1808            .uncertainty_measure("invalid".to_string())
1809            .n_samples(1); // Ensure validation happens
1810        let result = id_invalid.select_samples(&X.view(), &probas.view());
1811        assert!(result.is_err());
1812
1813        // Test with invalid density measure
1814        let id_invalid = InformationDensity::new()
1815            .density_measure("invalid".to_string())
1816            .n_samples(1); // Ensure validation happens
1817        let result = id_invalid.select_samples(&X.view(), &probas.view());
1818        assert!(result.is_err());
1819    }
1820
1821    #[test]
1822    #[allow(non_snake_case)]
1823    fn test_information_density_pure_modes() {
1824        let X = array![
1825            [1.0, 1.0], // Dense cluster
1826            [1.1, 1.1],
1827            [10.0, 10.0], // Far from the first cluster
1828            [11.0, 11.0], // Second, looser cluster
1829        ];
1830        let probas = array![
1831            [0.5, 0.5], // High uncertainty
1832            [0.9, 0.1], // Low uncertainty
1833            [0.5, 0.5], // High uncertainty
1834            [0.8, 0.2], // Low uncertainty
1835        ];
1836
1837        // Pure uncertainty (density_weight = 0)
1838        let id_uncertainty = InformationDensity::new().density_weight(0.0).n_samples(2);
1839        let selected_unc = id_uncertainty
1840            .select_samples(&X.view(), &probas.view())
1841            .unwrap();
1842        assert!(selected_unc.contains(&0)); // High uncertainty
1843        assert!(selected_unc.contains(&2)); // High uncertainty
1844
1845        // Pure density (density_weight = 1)
1846        let id_density = InformationDensity::new()
1847            .density_weight(1.0)
1848            .k_neighbors(1)
1849            .n_samples(2);
1850        let selected_den = id_density
1851            .select_samples(&X.view(), &probas.view())
1852            .unwrap();
1853        // Should prefer samples in denser regions
1854        assert_eq!(selected_den.len(), 2);
1855    }
1856
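    // Background note: information-density methods typically combine the two criteria
    // as a weighted mixture of an uncertainty score and a density score. With
    // density_weight = 0.0 only uncertainty matters and with density_weight = 1.0 only
    // density matters, which is exactly what the two sub-cases above exercise.
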
1857    #[test]
1858    fn test_cosine_similarity_computation() {
1859        let id = InformationDensity::new();
1860
1861        // Test orthogonal vectors
1862        let x1 = array![1.0, 0.0];
1863        let x2 = array![0.0, 1.0];
1864        let sim = id.cosine_similarity(x1.view(), x2.view());
1865        assert!((sim - 0.0).abs() < 1e-10);
1866
1867        // Test identical vectors
1868        let x3 = array![1.0, 1.0];
1869        let x4 = array![1.0, 1.0];
1870        let sim_identical = id.cosine_similarity(x3.view(), x4.view());
1871        assert!((sim_identical - 1.0).abs() < 1e-10);
1872
1873        // Test opposite vectors
1874        let x5 = array![1.0, 0.0];
1875        let x6 = array![-1.0, 0.0];
1876        let sim_opposite = id.cosine_similarity(x5.view(), x6.view());
1877        assert!((sim_opposite - (-1.0)).abs() < 1e-10);
1878    }
1879
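    // Sketch: a non-axis-aligned pair, assuming cosine_similarity implements the
    // standard definition <x, y> / (||x|| * ||y||) that the orthogonal, identical and
    // opposite cases above already pin down.
    #[test]
    fn test_cosine_similarity_general_pair() {
        let id = InformationDensity::new();
        let x1 = array![1.0, 2.0];
        let x2 = array![2.0, 3.0];

        // Hand-computed cosine similarity: (1*2 + 2*3) / (sqrt(5) * sqrt(13)).
        let dot = 1.0 * 2.0 + 2.0 * 3.0;
        let expected = dot / (5.0_f64.sqrt() * 13.0_f64.sqrt());

        let sim = id.cosine_similarity(x1.view(), x2.view());
        assert!((sim - expected).abs() < 1e-10);
    }
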
1880    #[test]
1881    fn test_array_normalization() {
1882        let id = InformationDensity::new();
1883        let array = array![1.0, 5.0, 3.0, 2.0];
1884
1885        let normalized = id.normalize_array(&array);
1886
1887        // Check range [0, 1]
1888        for &value in normalized.iter() {
1889            assert!((0.0..=1.0).contains(&value));
1890        }
1891
1892        // Check min and max
1893        let min_norm = normalized.iter().fold(f64::INFINITY, |a, &b| a.min(b));
1894        let max_norm = normalized.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1895        assert!((min_norm - 0.0).abs() < 1e-10);
1896        assert!((max_norm - 1.0).abs() < 1e-10);
1897
1898        // Test with constant array
1899        let constant = array![5.0, 5.0, 5.0];
1900        let normalized_constant = id.normalize_array(&constant);
1901        for &value in normalized_constant.iter() {
1902            assert!((value - 0.5).abs() < 1e-10);
1903        }
1904    }
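
    // Background note: normalize_array applies min-max scaling, (x - min) / (max - min),
    // which maps the smallest value to 0 and the largest to 1. For a constant array the
    // range collapses (max == min), and mapping every value to 0.5 is the convention the
    // assertions above document.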
1905}