Skip to main content

scirs2_metrics/active_learning/
mod.rs

1//! Active Learning Metrics
2//!
3//! This module provides metrics and selection strategies for active learning:
4//!
5//! - **Uncertainty sampling**: margin, entropy, least confidence
6//! - **Query-by-committee**: vote entropy, KL disagreement
7//! - **Expected model change**: gradient magnitude proxy
8//! - **Core-set selection**: greedy farthest-first traversal for diversity
9//! - **Batch-mode selection**: diversity-weighted batch selection
10//! - **Candidate ranking**: top-n selection by score
11
12use crate::error::{MetricsError, Result};
13
14// ─────────────────────────────────────────────────────────────────────────────
15// Configuration
16// ─────────────────────────────────────────────────────────────────────────────
17
/// Settings shared by active-learning experiments.
///
/// Both fields are plain counts. The [`Default`] implementation uses a
/// five-member committee and one hundred candidate samples.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ActiveLearningConfig {
    /// How many committee members take part in query-by-committee methods.
    pub n_committee: usize,
    /// How many candidate samples to consider.
    pub n_candidates: usize,
}

impl Default for ActiveLearningConfig {
    /// Returns `n_committee = 5` and `n_candidates = 100`.
    fn default() -> Self {
        let n_committee = 5;
        let n_candidates = 100;
        ActiveLearningConfig {
            n_committee,
            n_candidates,
        }
    }
}
36
/// Kind of informativeness score used by an active-learning strategy.
///
/// Each variant names one of the scoring families offered by this module.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UncertaintyScore {
    /// Margin sampling, scored as `1 - (p_max - p_second_max)`.
    MarginSampling,
    /// Entropy sampling, scored as `H(p) = -sum(p_i * log(p_i))`.
    EntropySampling,
    /// Least confidence, scored as `1 - p_max`.
    LeastConfidence,
    /// Disagreement among a committee of models.
    QueryByCommittee,
    /// Expected model change, approximated by a gradient-magnitude proxy.
    ExpectedModelChange,
    /// Diversity-driven core-set selection (farthest-first traversal).
    CoreSet,
}
54
55// ─────────────────────────────────────────────────────────────────────────────
56// Uncertainty Sampling — batch (vector-of-vectors) API
57// ─────────────────────────────────────────────────────────────────────────────
58
59/// Margin sampling over multiple candidates.
60///
61/// For each candidate, computes `1 - (p_max - p_second_max)`.
62/// A smaller margin means more uncertainty; returned score is higher when
63/// the model is more uncertain.
64///
65/// Each inner `Vec<f64>` must have at least 2 class probabilities.
66pub fn margin_sampling(probs: &[Vec<f64>]) -> Result<Vec<f64>> {
67    if probs.is_empty() {
68        return Err(MetricsError::InvalidInput(
69            "probs must not be empty".to_string(),
70        ));
71    }
72    probs
73        .iter()
74        .enumerate()
75        .map(|(i, p)| {
76            if p.len() < 2 {
77                return Err(MetricsError::InvalidInput(format!(
78                    "sample {i}: margin sampling requires at least 2 class probabilities"
79                )));
80            }
81            let mut sorted = p.clone();
82            sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
83            let margin = sorted[0] - sorted[1];
84            Ok(1.0 - margin)
85        })
86        .collect()
87}
88
89/// Entropy sampling over multiple candidates.
90///
91/// For each candidate, computes Shannon entropy `H(p) = -sum(p_i * ln(p_i))`.
92/// Higher entropy indicates more uncertainty (uniform distribution maximises it).
93pub fn entropy_sampling(probs: &[Vec<f64>]) -> Result<Vec<f64>> {
94    if probs.is_empty() {
95        return Err(MetricsError::InvalidInput(
96            "probs must not be empty".to_string(),
97        ));
98    }
99    probs
100        .iter()
101        .enumerate()
102        .map(|(i, p)| {
103            if p.is_empty() {
104                return Err(MetricsError::InvalidInput(format!(
105                    "sample {i}: probabilities must not be empty"
106                )));
107            }
108            let h: f64 = p
109                .iter()
110                .filter(|&&pi| pi > 0.0)
111                .map(|&pi| -pi * pi.ln())
112                .sum();
113            Ok(h)
114        })
115        .collect()
116}
117
118/// Least confidence over multiple candidates.
119///
120/// For each candidate, computes `1 - max(p_i)`.
121/// Lower maximum probability indicates more uncertainty.
122pub fn least_confidence(probs: &[Vec<f64>]) -> Result<Vec<f64>> {
123    if probs.is_empty() {
124        return Err(MetricsError::InvalidInput(
125            "probs must not be empty".to_string(),
126        ));
127    }
128    probs
129        .iter()
130        .enumerate()
131        .map(|(i, p)| {
132            if p.is_empty() {
133                return Err(MetricsError::InvalidInput(format!(
134                    "sample {i}: probabilities must not be empty"
135                )));
136            }
137            let p_max = p.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
138            Ok(1.0 - p_max)
139        })
140        .collect()
141}
142
143// ─────────────────────────────────────────────────────────────────────────────
144// Single-sample convenience functions (backwards-compatible)
145// ─────────────────────────────────────────────────────────────────────────────
146
147/// Margin sampling score for a single sample: `1 - (p_max - p_second_max)`.
148pub fn margin_sampling_score(probabilities: &[f64]) -> Result<f64> {
149    if probabilities.len() < 2 {
150        return Err(MetricsError::InvalidInput(
151            "margin sampling requires at least 2 class probabilities".to_string(),
152        ));
153    }
154    let mut sorted = probabilities.to_vec();
155    sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
156    let margin = sorted[0] - sorted[1];
157    Ok(1.0 - margin)
158}
159
160/// Entropy-based uncertainty for a single sample: `H(p) = -sum(p_i * ln(p_i))`.
161pub fn entropy_uncertainty(probabilities: &[f64]) -> Result<f64> {
162    if probabilities.is_empty() {
163        return Err(MetricsError::InvalidInput(
164            "probabilities must not be empty".to_string(),
165        ));
166    }
167    let h = probabilities
168        .iter()
169        .filter(|&&p| p > 0.0)
170        .map(|&p| -p * p.ln())
171        .sum::<f64>();
172    Ok(h)
173}
174
175/// Least confidence for a single sample: `1 - max(p_i)`.
176pub fn least_confidence_score(probabilities: &[f64]) -> Result<f64> {
177    if probabilities.is_empty() {
178        return Err(MetricsError::InvalidInput(
179            "probabilities must not be empty".to_string(),
180        ));
181    }
182    let p_max = probabilities
183        .iter()
184        .cloned()
185        .fold(f64::NEG_INFINITY, f64::max);
186    Ok(1.0 - p_max)
187}
188
189// ─────────────────────────────────────────────────────────────────────────────
190// Query-by-Committee
191// ─────────────────────────────────────────────────────────────────────────────
192
193/// Validate that a committee is non-empty and all members have the same
194/// number of classes.
195fn check_committee(committee_probs: &[Vec<f64>]) -> Result<usize> {
196    if committee_probs.is_empty() {
197        return Err(MetricsError::InvalidInput(
198            "committee must have at least one member".to_string(),
199        ));
200    }
201    let n_classes = committee_probs[0].len();
202    if n_classes == 0 {
203        return Err(MetricsError::InvalidInput(
204            "each committee member must supply at least one class probability".to_string(),
205        ));
206    }
207    for (i, member) in committee_probs.iter().enumerate() {
208        if member.len() != n_classes {
209            return Err(MetricsError::DimensionMismatch(format!(
210                "committee member {i} has {} classes, expected {n_classes}",
211                member.len()
212            )));
213        }
214    }
215    Ok(n_classes)
216}
217
218/// Query-by-committee disagreement via vote entropy (single sample).
219///
220/// Each committee member "votes" for the class with the highest probability.
221/// Returns the entropy of the resulting vote distribution.
222pub fn vote_entropy(committee_probs: &[Vec<f64>]) -> Result<f64> {
223    let n_classes = check_committee(committee_probs)?;
224    let n_members = committee_probs.len() as f64;
225
226    let mut votes = vec![0usize; n_classes];
227    for member in committee_probs {
228        let winner = member
229            .iter()
230            .enumerate()
231            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
232            .map(|(i, _)| i)
233            .unwrap_or(0);
234        votes[winner] += 1;
235    }
236
237    let h = votes
238        .iter()
239        .filter(|&&v| v > 0)
240        .map(|&v| {
241            let frac = v as f64 / n_members;
242            -frac * frac.ln()
243        })
244        .sum::<f64>();
245    Ok(h)
246}
247
248/// Query-by-committee disagreement via average KL divergence from consensus (single sample).
249pub fn qbc_kl_disagreement(committee_probs: &[Vec<f64>]) -> Result<f64> {
250    let n_classes = check_committee(committee_probs)?;
251    let n_members = committee_probs.len() as f64;
252
253    let mut consensus = vec![0.0_f64; n_classes];
254    for member in committee_probs {
255        for (c, &p) in consensus.iter_mut().zip(member) {
256            *c += p;
257        }
258    }
259    for c in &mut consensus {
260        *c /= n_members;
261    }
262
263    let mut total_kl = 0.0_f64;
264    for member in committee_probs {
265        let kl: f64 = member
266            .iter()
267            .zip(&consensus)
268            .map(|(&pi, &mi)| {
269                if pi <= 0.0 {
270                    0.0
271                } else if mi <= 0.0 {
272                    f64::INFINITY
273                } else {
274                    pi * (pi / mi).ln()
275                }
276            })
277            .sum();
278        if kl.is_infinite() {
279            return Err(MetricsError::CalculationError(
280                "KL divergence is infinite in committee disagreement".to_string(),
281            ));
282        }
283        total_kl += kl;
284    }
285    Ok(total_kl / n_members)
286}
287
288/// Query-by-committee for multiple candidates.
289///
290/// `committee_probs[m][s]` is committee member `m`'s probability vector for sample `s`.
291/// Returns a disagreement score per sample (vote entropy across committee members).
292pub fn query_by_committee(committee_probs: &[Vec<Vec<f64>>]) -> Result<Vec<f64>> {
293    if committee_probs.is_empty() {
294        return Err(MetricsError::InvalidInput(
295            "committee_probs must have at least one member".to_string(),
296        ));
297    }
298    let n_members = committee_probs.len();
299    let n_samples = committee_probs[0].len();
300
301    // Validate dimensions
302    for (m, member) in committee_probs.iter().enumerate() {
303        if member.len() != n_samples {
304            return Err(MetricsError::DimensionMismatch(format!(
305                "committee member {m} has {} samples, expected {n_samples}",
306                member.len()
307            )));
308        }
309    }
310
311    let mut scores = Vec::with_capacity(n_samples);
312    for s in 0..n_samples {
313        // Gather this sample's predictions from all committee members
314        let sample_probs: Vec<Vec<f64>> = (0..n_members)
315            .map(|m| committee_probs[m][s].clone())
316            .collect();
317        let ve = vote_entropy(&sample_probs)?;
318        scores.push(ve);
319    }
320    Ok(scores)
321}
322
323// ─────────────────────────────────────────────────────────────────────────────
324// Expected Model Change
325// ─────────────────────────────────────────────────────────────────────────────
326
327/// Expected model change: uses gradient norm as a proxy for informativeness.
328///
329/// `gradients[i]` is the gradient vector (or gradient magnitude proxy) for candidate `i`.
330/// Returns `||gradient_i||_2` for each candidate.
331pub fn expected_model_change(gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
332    if gradients.is_empty() {
333        return Err(MetricsError::InvalidInput(
334            "gradients must not be empty".to_string(),
335        ));
336    }
337    gradients
338        .iter()
339        .enumerate()
340        .map(|(i, g)| {
341            if g.is_empty() {
342                return Err(MetricsError::InvalidInput(format!(
343                    "sample {i} has empty gradient vector"
344                )));
345            }
346            let norm = g.iter().map(|&v| v * v).sum::<f64>().sqrt();
347            Ok(norm)
348        })
349        .collect()
350}
351
352/// Expected gradient magnitude proxy (probability-based): `||p - y_one_hot||_2`.
353///
354/// Approximated as the Euclidean distance from the predicted probability
355/// vector to the one-hot encoding of the predicted class (argmax).
356/// Returns one magnitude value per sample.
357pub fn expected_gradient_magnitude(probabilities: &[Vec<f64>]) -> Result<Vec<f64>> {
358    if probabilities.is_empty() {
359        return Err(MetricsError::InvalidInput(
360            "probabilities must not be empty".to_string(),
361        ));
362    }
363    probabilities
364        .iter()
365        .enumerate()
366        .map(|(i, p)| {
367            if p.is_empty() {
368                return Err(MetricsError::InvalidInput(format!(
369                    "sample {i} has empty probability vector"
370                )));
371            }
372            let argmax = p
373                .iter()
374                .enumerate()
375                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
376                .map(|(j, _)| j)
377                .unwrap_or(0);
378            let mag = p
379                .iter()
380                .enumerate()
381                .map(|(j, &pj)| {
382                    let one_hot = if j == argmax { 1.0 } else { 0.0 };
383                    (pj - one_hot).powi(2)
384                })
385                .sum::<f64>()
386                .sqrt();
387            Ok(mag)
388        })
389        .collect()
390}
391
392// ─────────────────────────────────────────────────────────────────────────────
393// Core-Set Selection
394// ─────────────────────────────────────────────────────────────────────────────
395
/// Euclidean (L2) distance between two feature vectors.
///
/// Coordinates are paired with `zip`. If the lengths differ, the trailing
/// coordinates of the longer vector are ignored.
fn euclidean_dist(a: &[f64], b: &[f64]) -> f64 {
    let squared: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y).powi(2))
        .sum();
    squared.sqrt()
}
404
405/// Core-set selection via greedy farthest-first traversal.
406///
407/// Given `embeddings` for all candidates and `selected` (indices of already-labeled
408/// points), selects `n_select` new points that maximise minimum distance to the
409/// already-selected set plus newly-chosen points.
410///
411/// If `selected` is empty, the first point (index 0) is used as the seed.
412pub fn core_set_selection(
413    embeddings: &[Vec<f64>],
414    selected: &[usize],
415    n_select: usize,
416) -> Result<Vec<usize>> {
417    if embeddings.is_empty() {
418        return Err(MetricsError::InvalidInput(
419            "embeddings must not be empty".to_string(),
420        ));
421    }
422    if n_select == 0 {
423        return Ok(vec![]);
424    }
425    let n = embeddings.len();
426    if n_select > n {
427        return Err(MetricsError::InvalidInput(format!(
428            "n_select ({n_select}) exceeds number of points ({n})"
429        )));
430    }
431
432    // Build initial set of centres
433    let mut centres: Vec<usize> = selected.to_vec();
434    // Mark already-selected as used
435    let mut used = vec![false; n];
436    for &idx in &centres {
437        if idx < n {
438            used[idx] = true;
439        }
440    }
441
442    // If no centres, seed with index 0
443    if centres.is_empty() {
444        centres.push(0);
445        used[0] = true;
446    }
447
448    // Compute initial min-dist from each point to nearest centre
449    let mut min_dists: Vec<f64> = (0..n)
450        .map(|i| {
451            if used[i] {
452                return 0.0;
453            }
454            centres
455                .iter()
456                .map(|&c| {
457                    if c < n {
458                        euclidean_dist(&embeddings[i], &embeddings[c])
459                    } else {
460                        f64::INFINITY
461                    }
462                })
463                .fold(f64::INFINITY, f64::min)
464        })
465        .collect();
466
467    let mut new_selected = Vec::with_capacity(n_select);
468
469    while new_selected.len() < n_select {
470        // Pick the point with the largest min-dist (farthest from all centres)
471        let next = min_dists
472            .iter()
473            .enumerate()
474            .filter(|(i, _)| !used[*i])
475            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
476            .map(|(i, _)| i);
477
478        match next {
479            Some(idx) => {
480                new_selected.push(idx);
481                used[idx] = true;
482                // Update min_dists
483                let new_centre = &embeddings[idx];
484                for (i, md) in min_dists.iter_mut().enumerate() {
485                    if !used[i] {
486                        let d = euclidean_dist(&embeddings[i], new_centre);
487                        if d < *md {
488                            *md = d;
489                        }
490                    }
491                }
492            }
493            None => break, // No more candidates
494        }
495    }
496
497    Ok(new_selected)
498}
499
500/// Greedy k-center core-set selection (legacy API).
501///
502/// Iteratively selects the point that is farthest from the current set of
503/// selected centres, maximising minimum coverage.
504///
505/// Returns `k` indices into `features`.
506pub fn greedy_k_center(
507    features: &[Vec<f64>],
508    k: usize,
509    seed_idx: Option<usize>,
510) -> Result<Vec<usize>> {
511    if features.is_empty() {
512        return Err(MetricsError::InvalidInput(
513            "features must not be empty".to_string(),
514        ));
515    }
516    if k == 0 {
517        return Err(MetricsError::InvalidInput(
518            "k must be at least 1".to_string(),
519        ));
520    }
521    if k > features.len() {
522        return Err(MetricsError::InvalidInput(format!(
523            "k ({k}) exceeds number of points ({})",
524            features.len()
525        )));
526    }
527
528    let n = features.len();
529    let first = seed_idx.unwrap_or(0).min(n - 1);
530
531    let mut selected = vec![first];
532    let mut min_dists: Vec<f64> = features
533        .iter()
534        .map(|f| euclidean_dist(f, &features[first]))
535        .collect();
536
537    while selected.len() < k {
538        let next = min_dists
539            .iter()
540            .enumerate()
541            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
542            .map(|(i, _)| i)
543            .unwrap_or(0);
544        selected.push(next);
545        let new_centre = &features[next];
546        for (i, md) in min_dists.iter_mut().enumerate() {
547            let d = euclidean_dist(&features[i], new_centre);
548            if d < *md {
549                *md = d;
550            }
551        }
552    }
553
554    Ok(selected)
555}
556
557// ─────────────────────────────────────────────────────────────────────────────
558// Ranking & Selection Utilities
559// ─────────────────────────────────────────────────────────────────────────────
560
/// Rank candidates by score (highest first) and return the top `n_select` indices.
///
/// If `n_select >= scores.len()`, returns all indices sorted by descending score.
/// The sort is stable, so equal scores keep ascending index order. `NaN`
/// scores are ranked after every real score.
///
/// The comparator is a genuine total order. A plain
/// `partial_cmp(..).unwrap_or(Equal)` is not a total order once `NaN` is
/// present, and `slice::sort_by` may panic on such comparators.
pub fn rank_candidates(scores: &[f64], n_select: usize) -> Vec<usize> {
    use std::cmp::Ordering;

    let mut indices: Vec<usize> = (0..scores.len()).collect();
    indices.sort_by(|&a, &b| {
        let (sa, sb) = (scores[a], scores[b]);
        match (sa.is_nan(), sb.is_nan()) {
            // Descending order for real values (never `None` here).
            (false, false) => sb.partial_cmp(&sa).unwrap_or(Ordering::Equal),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (true, true) => Ordering::Equal,
        }
    });
    indices.truncate(n_select);
    indices
}
574
/// Rank candidates by active learning score (highest score -> highest priority).
///
/// Returns all indices sorted from most to least uncertain. The sort is
/// stable, so equal scores keep ascending index order. `NaN` scores are
/// ranked after every real score.
///
/// The comparator is a genuine total order. A plain
/// `partial_cmp(..).unwrap_or(Equal)` is not a total order once `NaN` is
/// present, and `slice::sort_by` may panic on such comparators.
pub fn rank_by_uncertainty(scores: &[f64]) -> Vec<usize> {
    use std::cmp::Ordering;

    let mut indices: Vec<usize> = (0..scores.len()).collect();
    indices.sort_by(|&a, &b| {
        let (sa, sb) = (scores[a], scores[b]);
        match (sa.is_nan(), sb.is_nan()) {
            // Descending order for real values (never `None` here).
            (false, false) => sb.partial_cmp(&sa).unwrap_or(Ordering::Equal),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (true, true) => Ordering::Equal,
        }
    });
    indices
}
587
588// ─────────────────────────────────────────────────────────────────────────────
589// Batch-Mode Selection
590// ─────────────────────────────────────────────────────────────────────────────
591
/// Strategy for batch-mode active learning selection.
///
/// Like `UncertaintyScore`, this is a field-less enum. It therefore derives
/// `Eq` as well as `PartialEq`, so it can be used wherever full equality is
/// required.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchSelectionMethod {
    /// Select by entropy uncertainty score.
    Entropy,
    /// Select by margin sampling score.
    MarginSampling,
    /// Select by greedy k-center core-set.
    CoreSet,
}
603
604/// Configuration for batch-mode active learning selection.
605#[non_exhaustive]
606#[derive(Debug, Clone)]
607pub struct BatchSelectionConfig {
608    /// Number of samples to select.
609    pub n_select: usize,
610    /// Weight on diversity vs pure uncertainty in `[0, 1]`.
611    /// `0.0` = pure uncertainty, `1.0` = pure diversity (CoreSet).
612    pub diversity_weight: f64,
613    /// Selection method.
614    pub method: BatchSelectionMethod,
615}
616
617impl Default for BatchSelectionConfig {
618    fn default() -> Self {
619        Self {
620            n_select: 10,
621            diversity_weight: 0.5,
622            method: BatchSelectionMethod::Entropy,
623        }
624    }
625}
626
627/// Batch selection that balances uncertainty with diversity.
628///
629/// When `diversity_weight == 0.0`, this is equivalent to pure uncertainty ranking.
630/// When `diversity_weight == 1.0`, this is pure core-set (diversity) selection.
631/// Values in between produce a hybrid: candidates are scored by
632/// `(1 - diversity_weight) * normalized_uncertainty + diversity_weight * normalized_distance`.
633///
634/// `scores` is an uncertainty score per candidate (higher = more uncertain).
635/// `embeddings` is a feature vector per candidate (for computing distances).
636/// `n_select` is how many candidates to choose.
637pub fn batch_selection(
638    scores: &[f64],
639    embeddings: &[Vec<f64>],
640    n_select: usize,
641    diversity_weight: f64,
642) -> Result<Vec<usize>> {
643    if scores.is_empty() || embeddings.is_empty() {
644        return Err(MetricsError::InvalidInput(
645            "scores and embeddings must not be empty".to_string(),
646        ));
647    }
648    if scores.len() != embeddings.len() {
649        return Err(MetricsError::DimensionMismatch(format!(
650            "scores len {} != embeddings len {}",
651            scores.len(),
652            embeddings.len()
653        )));
654    }
655
656    let n = scores.len();
657    let k = n_select.min(n);
658
659    if k == 0 {
660        return Ok(vec![]);
661    }
662
663    // Pure uncertainty
664    let dw = diversity_weight.clamp(0.0, 1.0);
665    if dw < 1e-12 {
666        return Ok(rank_candidates(scores, k));
667    }
668
669    // Pure diversity
670    if (dw - 1.0).abs() < 1e-12 {
671        return core_set_selection(embeddings, &[], k);
672    }
673
674    // Hybrid: greedy selection balancing uncertainty + diversity
675    // Normalize uncertainty scores to [0, 1]
676    let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
677    let min_score = scores.iter().cloned().fold(f64::INFINITY, f64::min);
678    let score_range = max_score - min_score;
679    let norm_scores: Vec<f64> = if score_range > 1e-15 {
680        scores
681            .iter()
682            .map(|&s| (s - min_score) / score_range)
683            .collect()
684    } else {
685        vec![0.5; n]
686    };
687
688    let mut selected = Vec::with_capacity(k);
689    let mut used = vec![false; n];
690
691    // Seed with highest-uncertainty point
692    let seed = norm_scores
693        .iter()
694        .enumerate()
695        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
696        .map(|(i, _)| i)
697        .unwrap_or(0);
698    selected.push(seed);
699    used[seed] = true;
700
701    // Min distance from each point to selected set
702    let mut min_dists: Vec<f64> = (0..n)
703        .map(|i| {
704            if i == seed {
705                0.0
706            } else {
707                euclidean_dist(&embeddings[i], &embeddings[seed])
708            }
709        })
710        .collect();
711
712    while selected.len() < k {
713        // Normalize min_dists to [0, 1]
714        let max_dist = min_dists
715            .iter()
716            .enumerate()
717            .filter(|(i, _)| !used[*i])
718            .map(|(_, &d)| d)
719            .fold(f64::NEG_INFINITY, f64::max);
720        let min_dist_val = min_dists
721            .iter()
722            .enumerate()
723            .filter(|(i, _)| !used[*i])
724            .map(|(_, &d)| d)
725            .fold(f64::INFINITY, f64::min);
726        let dist_range = max_dist - min_dist_val;
727
728        // Combined score: (1-dw)*uncertainty + dw*diversity
729        let mut best_idx = 0;
730        let mut best_combined = f64::NEG_INFINITY;
731
732        for i in 0..n {
733            if used[i] {
734                continue;
735            }
736            let norm_dist = if dist_range > 1e-15 {
737                (min_dists[i] - min_dist_val) / dist_range
738            } else {
739                0.5
740            };
741            let combined = (1.0 - dw) * norm_scores[i] + dw * norm_dist;
742            if combined > best_combined {
743                best_combined = combined;
744                best_idx = i;
745            }
746        }
747
748        selected.push(best_idx);
749        used[best_idx] = true;
750
751        // Update min_dists
752        let new_centre = &embeddings[best_idx];
753        for (i, md) in min_dists.iter_mut().enumerate() {
754            if !used[i] {
755                let d = euclidean_dist(&embeddings[i], new_centre);
756                if d < *md {
757                    *md = d;
758                }
759            }
760        }
761    }
762
763    Ok(selected)
764}
765
766/// Select a batch of samples for labeling (legacy API).
767///
768/// Combines an uncertainty score with optional diversity (greedy spacing).
769pub fn batch_select(
770    features: &[Vec<f64>],
771    probabilities: &[Vec<f64>],
772    config: &BatchSelectionConfig,
773) -> Result<Vec<usize>> {
774    if features.is_empty() || probabilities.is_empty() {
775        return Err(MetricsError::InvalidInput(
776            "features and probabilities must not be empty".to_string(),
777        ));
778    }
779    if features.len() != probabilities.len() {
780        return Err(MetricsError::DimensionMismatch(format!(
781            "features len {} != probabilities len {}",
782            features.len(),
783            probabilities.len()
784        )));
785    }
786    let n = features.len();
787    let k = config.n_select.min(n);
788
789    match config.method {
790        BatchSelectionMethod::CoreSet => greedy_k_center(features, k, None),
791        BatchSelectionMethod::Entropy => {
792            let scores: Vec<f64> = probabilities
793                .iter()
794                .map(|p| entropy_uncertainty(p))
795                .collect::<Result<Vec<_>>>()?;
796            let ranked = rank_by_uncertainty(&scores);
797            Ok(ranked.into_iter().take(k).collect())
798        }
799        BatchSelectionMethod::MarginSampling => {
800            let scores: Vec<f64> = probabilities
801                .iter()
802                .map(|p| margin_sampling_score(p))
803                .collect::<Result<Vec<_>>>()?;
804            let ranked = rank_by_uncertainty(&scores);
805            Ok(ranked.into_iter().take(k).collect())
806        }
807    }
808}
809
810// ─────────────────────────────────────────────────────────────────────────────
811// Tests
812// ─────────────────────────────────────────────────────────────────────────────
813
814#[cfg(test)]
815mod tests {
816    use super::*;
817
818    // --- Margin sampling ---
819
820    #[test]
821    fn test_margin_sampling_uniform_score_zero() {
822        // Uniform probs: p_max == p_second => margin=0 => score=1.0 (most uncertain)
823        let probs = vec![vec![0.25, 0.25, 0.25, 0.25], vec![0.5, 0.5]];
824        let scores = margin_sampling(&probs).expect("should succeed");
825        assert!(
826            (scores[0] - 1.0).abs() < 1e-12,
827            "uniform 4-class: score should be 1.0, got {}",
828            scores[0]
829        );
830        assert!(
831            (scores[1] - 1.0).abs() < 1e-12,
832            "uniform 2-class: score should be 1.0, got {}",
833            scores[1]
834        );
835    }
836
837    #[test]
838    fn test_margin_sampling_peaked_close_to_one() {
839        // Peaked probs: p_max far from p_second => large margin => score close to 0
840        let probs = vec![vec![0.99, 0.01]];
841        let scores = margin_sampling(&probs).expect("should succeed");
842        assert!(
843            scores[0] < 0.05,
844            "peaked should have low uncertainty, got {}",
845            scores[0]
846        );
847    }
848
849    // --- Entropy sampling ---
850
851    #[test]
852    fn test_entropy_uniform_has_max() {
853        let n = 4;
854        let p = 1.0 / n as f64;
855        let probs = vec![vec![p; n]];
856        let scores = entropy_sampling(&probs).expect("should succeed");
857        let expected = (n as f64).ln();
858        assert!(
859            (scores[0] - expected).abs() < 1e-10,
860            "expected {expected}, got {}",
861            scores[0]
862        );
863    }
864
865    #[test]
866    fn test_entropy_point_mass_zero() {
867        let probs = vec![vec![1.0, 0.0, 0.0]];
868        let scores = entropy_sampling(&probs).expect("should succeed");
869        assert!(
870            scores[0].abs() < 1e-12,
871            "point mass entropy should be 0, got {}",
872            scores[0]
873        );
874    }
875
876    // --- Least confidence ---
877
878    #[test]
879    fn test_least_confidence_confident_low_score() {
880        let probs = vec![vec![0.95, 0.03, 0.02]];
881        let scores = least_confidence(&probs).expect("should succeed");
882        assert!(
883            scores[0] < 0.1,
884            "confident prediction should have low LC, got {}",
885            scores[0]
886        );
887    }
888
889    #[test]
890    fn test_least_confidence_uncertain_high_score() {
891        let probs = vec![vec![0.34, 0.33, 0.33]];
892        let scores = least_confidence(&probs).expect("should succeed");
893        assert!(
894            scores[0] > 0.5,
895            "uncertain prediction should have high LC, got {}",
896            scores[0]
897        );
898    }
899
900    // --- Query by committee ---
901
902    #[test]
903    fn test_qbc_unanimous_low_disagreement() {
904        // All committee members agree on class 0
905        let committee = vec![
906            vec![vec![0.9, 0.1], vec![0.8, 0.2]],     // member 0: 2 samples
907            vec![vec![0.85, 0.15], vec![0.75, 0.25]], // member 1
908            vec![vec![0.95, 0.05], vec![0.7, 0.3]],   // member 2
909        ];
910        let scores = query_by_committee(&committee).expect("should succeed");
911        // All members predict class 0 for sample 0 => vote entropy = 0
912        assert!(
913            scores[0].abs() < 1e-12,
914            "unanimous committee: disagreement should be 0, got {}",
915            scores[0]
916        );
917    }
918
919    #[test]
920    fn test_qbc_disagreeing_positive() {
921        // Committee members disagree
922        let committee = vec![
923            vec![vec![0.9, 0.1]], // predicts class 0
924            vec![vec![0.1, 0.9]], // predicts class 1
925        ];
926        let scores = query_by_committee(&committee).expect("should succeed");
927        assert!(
928            scores[0] > 0.0,
929            "disagreeing committee should have positive score, got {}",
930            scores[0]
931        );
932    }
933
934    // --- Expected model change ---
935
936    #[test]
937    fn test_expected_model_change_norm() {
938        let gradients = vec![
939            vec![3.0, 4.0],           // norm = 5.0
940            vec![0.0, 0.0],           // norm = 0.0
941            vec![1.0, 1.0, 1.0, 1.0], // norm = 2.0
942        ];
943        let scores = expected_model_change(&gradients).expect("should succeed");
944        assert!((scores[0] - 5.0).abs() < 1e-12);
945        assert!(scores[1].abs() < 1e-12);
946        assert!((scores[2] - 2.0).abs() < 1e-12);
947    }
948
949    // --- Core-set selection ---
950
951    #[test]
952    fn test_core_set_points_well_spread() {
953        // Points at 0, 10, 20, ..., 90 on a line
954        let embeddings: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64 * 10.0]).collect();
955        let selected = core_set_selection(&embeddings, &[], 3).expect("should succeed");
956        assert_eq!(selected.len(), 3);
957        // With seed at 0, second should be 9 (farthest), third should be 4 or 5
958        // Verify they're spread out: no two selected points within 15 of each other
959        for i in 0..selected.len() {
960            for j in (i + 1)..selected.len() {
961                let d = euclidean_dist(&embeddings[selected[i]], &embeddings[selected[j]]);
962                assert!(d >= 10.0, "selected points should be spread: dist={d}");
963            }
964        }
965    }
966
967    #[test]
968    fn test_core_set_with_existing_selected() {
969        let embeddings: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64]).collect();
970        let already_selected = vec![0, 9]; // endpoints
971        let new = core_set_selection(&embeddings, &already_selected, 1).expect("should succeed");
972        assert_eq!(new.len(), 1);
973        // The farthest from {0, 9} is either 4 or 5 (midpoint)
974        assert!(
975            new[0] >= 3 && new[0] <= 6,
976            "midpoint expected, got {}",
977            new[0]
978        );
979    }
980
981    // --- Rank candidates ---
982
983    #[test]
984    fn test_rank_candidates_top_n() {
985        let scores = vec![0.1, 0.9, 0.5, 0.3, 0.7];
986        let top3 = rank_candidates(&scores, 3);
987        assert_eq!(top3.len(), 3);
988        assert_eq!(top3[0], 1); // highest score
989        assert_eq!(top3[1], 4); // second highest
990        assert_eq!(top3[2], 2); // third highest
991    }
992
993    // --- Batch selection ---
994
995    #[test]
996    fn test_batch_selection_diversity_zero_matches_uncertainty() {
997        let scores = vec![0.1, 0.9, 0.5, 0.3, 0.7];
998        let embeddings: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64]).collect();
999
1000        let pure_unc = rank_candidates(&scores, 3);
1001        let batch = batch_selection(&scores, &embeddings, 3, 0.0).expect("should succeed");
1002        assert_eq!(
1003            batch, pure_unc,
1004            "diversity_weight=0 should match pure uncertainty ranking"
1005        );
1006    }
1007
1008    #[test]
1009    fn test_batch_selection_returns_correct_count() {
1010        let scores = vec![0.5; 20];
1011        let embeddings: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64, 0.0]).collect();
1012        let selected = batch_selection(&scores, &embeddings, 7, 0.5).expect("should succeed");
1013        assert_eq!(selected.len(), 7);
1014    }
1015
1016    #[test]
1017    fn test_batch_selection_respects_n_select_legacy() {
1018        let features: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64]).collect();
1019        let probs: Vec<Vec<f64>> = (0..20)
1020            .map(|i| {
1021                let p = i as f64 / 20.0;
1022                vec![p, 1.0 - p]
1023            })
1024            .collect();
1025        let cfg = BatchSelectionConfig {
1026            n_select: 7,
1027            ..Default::default()
1028        };
1029        let selected = batch_select(&features, &probs, &cfg).expect("should succeed");
1030        assert_eq!(selected.len(), 7, "should select exactly 7 samples");
1031    }
1032
1033    // --- Single-sample backwards compat ---
1034
1035    #[test]
1036    fn test_margin_sampling_score_compat() {
1037        let p = vec![0.25, 0.25, 0.25, 0.25];
1038        let s = margin_sampling_score(&p).expect("should succeed");
1039        assert!((s - 1.0).abs() < 1e-12);
1040    }
1041
1042    #[test]
1043    fn test_vote_entropy_unanimous_zero() {
1044        let committee = vec![vec![0.9, 0.1], vec![0.8, 0.2], vec![0.95, 0.05]];
1045        let ve = vote_entropy(&committee).expect("should succeed");
1046        assert!(
1047            ve.abs() < 1e-12,
1048            "unanimous vote should give entropy=0, got {ve}"
1049        );
1050    }
1051
1052    #[test]
1053    fn test_expected_gradient_magnitude_shape() {
1054        let probs = vec![vec![0.7, 0.2, 0.1], vec![0.3, 0.4, 0.3]];
1055        let mags = expected_gradient_magnitude(&probs).expect("should succeed");
1056        assert_eq!(mags.len(), 2);
1057        for m in &mags {
1058            assert!(*m >= 0.0, "magnitude must be non-negative, got {m}");
1059        }
1060    }
1061
1062    #[test]
1063    fn test_k_center_returns_k_points() {
1064        let features: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64, 0.0]).collect();
1065        let selected = greedy_k_center(&features, 5, None).expect("should succeed");
1066        assert_eq!(selected.len(), 5);
1067    }
1068
1069    #[test]
1070    fn test_default_config() {
1071        let cfg = ActiveLearningConfig::default();
1072        assert_eq!(cfg.n_committee, 5);
1073        assert_eq!(cfg.n_candidates, 100);
1074    }
1075}