1use crate::error::{MetricsError, Result};
13
/// Tunable parameters for the active-learning utilities.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal; start from [`ActiveLearningConfig::default`] and
/// overwrite the public fields as needed.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ActiveLearningConfig {
    /// Committee size (default `5`). Presumably the number of models used for
    /// query-by-committee; not read by the code visible here — verify against
    /// callers.
    pub n_committee: usize,
    /// Candidate count (default `100`). Presumably the size of the pool scored
    /// per round; not read by the code visible here — verify against callers.
    pub n_candidates: usize,
}

impl Default for ActiveLearningConfig {
    /// Returns the stock configuration: `n_committee = 5`, `n_candidates = 100`.
    fn default() -> Self {
        let n_committee = 5;
        let n_candidates = 100;
        Self {
            n_committee,
            n_candidates,
        }
    }
}
36
/// Names of the query strategies an active learner can use.
///
/// The enum is `#[non_exhaustive]`: downstream `match` expressions need a
/// wildcard arm. Nothing in the code visible here dispatches on it; the
/// per-variant notes below map names to same-named functions in this file and
/// should be confirmed against callers.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UncertaintyScore {
    /// Gap between the two most probable classes (presumably `margin_sampling`).
    MarginSampling,
    /// Shannon entropy of the predictive distribution (presumably `entropy_sampling`).
    EntropySampling,
    /// One minus the top class probability (presumably `least_confidence`).
    LeastConfidence,
    /// Committee disagreement (presumably `query_by_committee`).
    QueryByCommittee,
    /// Gradient-norm based criterion (presumably `expected_model_change`).
    ExpectedModelChange,
    /// Coverage-based k-center selection (presumably `core_set_selection`).
    CoreSet,
}
54
55pub fn margin_sampling(probs: &[Vec<f64>]) -> Result<Vec<f64>> {
67 if probs.is_empty() {
68 return Err(MetricsError::InvalidInput(
69 "probs must not be empty".to_string(),
70 ));
71 }
72 probs
73 .iter()
74 .enumerate()
75 .map(|(i, p)| {
76 if p.len() < 2 {
77 return Err(MetricsError::InvalidInput(format!(
78 "sample {i}: margin sampling requires at least 2 class probabilities"
79 )));
80 }
81 let mut sorted = p.clone();
82 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
83 let margin = sorted[0] - sorted[1];
84 Ok(1.0 - margin)
85 })
86 .collect()
87}
88
89pub fn entropy_sampling(probs: &[Vec<f64>]) -> Result<Vec<f64>> {
94 if probs.is_empty() {
95 return Err(MetricsError::InvalidInput(
96 "probs must not be empty".to_string(),
97 ));
98 }
99 probs
100 .iter()
101 .enumerate()
102 .map(|(i, p)| {
103 if p.is_empty() {
104 return Err(MetricsError::InvalidInput(format!(
105 "sample {i}: probabilities must not be empty"
106 )));
107 }
108 let h: f64 = p
109 .iter()
110 .filter(|&&pi| pi > 0.0)
111 .map(|&pi| -pi * pi.ln())
112 .sum();
113 Ok(h)
114 })
115 .collect()
116}
117
118pub fn least_confidence(probs: &[Vec<f64>]) -> Result<Vec<f64>> {
123 if probs.is_empty() {
124 return Err(MetricsError::InvalidInput(
125 "probs must not be empty".to_string(),
126 ));
127 }
128 probs
129 .iter()
130 .enumerate()
131 .map(|(i, p)| {
132 if p.is_empty() {
133 return Err(MetricsError::InvalidInput(format!(
134 "sample {i}: probabilities must not be empty"
135 )));
136 }
137 let p_max = p.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
138 Ok(1.0 - p_max)
139 })
140 .collect()
141}
142
143pub fn margin_sampling_score(probabilities: &[f64]) -> Result<f64> {
149 if probabilities.len() < 2 {
150 return Err(MetricsError::InvalidInput(
151 "margin sampling requires at least 2 class probabilities".to_string(),
152 ));
153 }
154 let mut sorted = probabilities.to_vec();
155 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
156 let margin = sorted[0] - sorted[1];
157 Ok(1.0 - margin)
158}
159
160pub fn entropy_uncertainty(probabilities: &[f64]) -> Result<f64> {
162 if probabilities.is_empty() {
163 return Err(MetricsError::InvalidInput(
164 "probabilities must not be empty".to_string(),
165 ));
166 }
167 let h = probabilities
168 .iter()
169 .filter(|&&p| p > 0.0)
170 .map(|&p| -p * p.ln())
171 .sum::<f64>();
172 Ok(h)
173}
174
175pub fn least_confidence_score(probabilities: &[f64]) -> Result<f64> {
177 if probabilities.is_empty() {
178 return Err(MetricsError::InvalidInput(
179 "probabilities must not be empty".to_string(),
180 ));
181 }
182 let p_max = probabilities
183 .iter()
184 .cloned()
185 .fold(f64::NEG_INFINITY, f64::max);
186 Ok(1.0 - p_max)
187}
188
189fn check_committee(committee_probs: &[Vec<f64>]) -> Result<usize> {
196 if committee_probs.is_empty() {
197 return Err(MetricsError::InvalidInput(
198 "committee must have at least one member".to_string(),
199 ));
200 }
201 let n_classes = committee_probs[0].len();
202 if n_classes == 0 {
203 return Err(MetricsError::InvalidInput(
204 "each committee member must supply at least one class probability".to_string(),
205 ));
206 }
207 for (i, member) in committee_probs.iter().enumerate() {
208 if member.len() != n_classes {
209 return Err(MetricsError::DimensionMismatch(format!(
210 "committee member {i} has {} classes, expected {n_classes}",
211 member.len()
212 )));
213 }
214 }
215 Ok(n_classes)
216}
217
218pub fn vote_entropy(committee_probs: &[Vec<f64>]) -> Result<f64> {
223 let n_classes = check_committee(committee_probs)?;
224 let n_members = committee_probs.len() as f64;
225
226 let mut votes = vec![0usize; n_classes];
227 for member in committee_probs {
228 let winner = member
229 .iter()
230 .enumerate()
231 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
232 .map(|(i, _)| i)
233 .unwrap_or(0);
234 votes[winner] += 1;
235 }
236
237 let h = votes
238 .iter()
239 .filter(|&&v| v > 0)
240 .map(|&v| {
241 let frac = v as f64 / n_members;
242 -frac * frac.ln()
243 })
244 .sum::<f64>();
245 Ok(h)
246}
247
248pub fn qbc_kl_disagreement(committee_probs: &[Vec<f64>]) -> Result<f64> {
250 let n_classes = check_committee(committee_probs)?;
251 let n_members = committee_probs.len() as f64;
252
253 let mut consensus = vec![0.0_f64; n_classes];
254 for member in committee_probs {
255 for (c, &p) in consensus.iter_mut().zip(member) {
256 *c += p;
257 }
258 }
259 for c in &mut consensus {
260 *c /= n_members;
261 }
262
263 let mut total_kl = 0.0_f64;
264 for member in committee_probs {
265 let kl: f64 = member
266 .iter()
267 .zip(&consensus)
268 .map(|(&pi, &mi)| {
269 if pi <= 0.0 {
270 0.0
271 } else if mi <= 0.0 {
272 f64::INFINITY
273 } else {
274 pi * (pi / mi).ln()
275 }
276 })
277 .sum();
278 if kl.is_infinite() {
279 return Err(MetricsError::CalculationError(
280 "KL divergence is infinite in committee disagreement".to_string(),
281 ));
282 }
283 total_kl += kl;
284 }
285 Ok(total_kl / n_members)
286}
287
288pub fn query_by_committee(committee_probs: &[Vec<Vec<f64>>]) -> Result<Vec<f64>> {
293 if committee_probs.is_empty() {
294 return Err(MetricsError::InvalidInput(
295 "committee_probs must have at least one member".to_string(),
296 ));
297 }
298 let n_members = committee_probs.len();
299 let n_samples = committee_probs[0].len();
300
301 for (m, member) in committee_probs.iter().enumerate() {
303 if member.len() != n_samples {
304 return Err(MetricsError::DimensionMismatch(format!(
305 "committee member {m} has {} samples, expected {n_samples}",
306 member.len()
307 )));
308 }
309 }
310
311 let mut scores = Vec::with_capacity(n_samples);
312 for s in 0..n_samples {
313 let sample_probs: Vec<Vec<f64>> = (0..n_members)
315 .map(|m| committee_probs[m][s].clone())
316 .collect();
317 let ve = vote_entropy(&sample_probs)?;
318 scores.push(ve);
319 }
320 Ok(scores)
321}
322
323pub fn expected_model_change(gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
332 if gradients.is_empty() {
333 return Err(MetricsError::InvalidInput(
334 "gradients must not be empty".to_string(),
335 ));
336 }
337 gradients
338 .iter()
339 .enumerate()
340 .map(|(i, g)| {
341 if g.is_empty() {
342 return Err(MetricsError::InvalidInput(format!(
343 "sample {i} has empty gradient vector"
344 )));
345 }
346 let norm = g.iter().map(|&v| v * v).sum::<f64>().sqrt();
347 Ok(norm)
348 })
349 .collect()
350}
351
352pub fn expected_gradient_magnitude(probabilities: &[Vec<f64>]) -> Result<Vec<f64>> {
358 if probabilities.is_empty() {
359 return Err(MetricsError::InvalidInput(
360 "probabilities must not be empty".to_string(),
361 ));
362 }
363 probabilities
364 .iter()
365 .enumerate()
366 .map(|(i, p)| {
367 if p.is_empty() {
368 return Err(MetricsError::InvalidInput(format!(
369 "sample {i} has empty probability vector"
370 )));
371 }
372 let argmax = p
373 .iter()
374 .enumerate()
375 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
376 .map(|(j, _)| j)
377 .unwrap_or(0);
378 let mag = p
379 .iter()
380 .enumerate()
381 .map(|(j, &pj)| {
382 let one_hot = if j == argmax { 1.0 } else { 0.0 };
383 (pj - one_hot).powi(2)
384 })
385 .sum::<f64>()
386 .sqrt();
387 Ok(mag)
388 })
389 .collect()
390}
391
/// Returns the Euclidean distance between `a` and `b`.
///
/// Coordinates are paired with `zip`, so if the slices differ in length the
/// surplus coordinates of the longer one are ignored.
fn euclidean_dist(a: &[f64], b: &[f64]) -> f64 {
    let squared_sum: f64 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
404
405pub fn core_set_selection(
413 embeddings: &[Vec<f64>],
414 selected: &[usize],
415 n_select: usize,
416) -> Result<Vec<usize>> {
417 if embeddings.is_empty() {
418 return Err(MetricsError::InvalidInput(
419 "embeddings must not be empty".to_string(),
420 ));
421 }
422 if n_select == 0 {
423 return Ok(vec![]);
424 }
425 let n = embeddings.len();
426 if n_select > n {
427 return Err(MetricsError::InvalidInput(format!(
428 "n_select ({n_select}) exceeds number of points ({n})"
429 )));
430 }
431
432 let mut centres: Vec<usize> = selected.to_vec();
434 let mut used = vec![false; n];
436 for &idx in ¢res {
437 if idx < n {
438 used[idx] = true;
439 }
440 }
441
442 if centres.is_empty() {
444 centres.push(0);
445 used[0] = true;
446 }
447
448 let mut min_dists: Vec<f64> = (0..n)
450 .map(|i| {
451 if used[i] {
452 return 0.0;
453 }
454 centres
455 .iter()
456 .map(|&c| {
457 if c < n {
458 euclidean_dist(&embeddings[i], &embeddings[c])
459 } else {
460 f64::INFINITY
461 }
462 })
463 .fold(f64::INFINITY, f64::min)
464 })
465 .collect();
466
467 let mut new_selected = Vec::with_capacity(n_select);
468
469 while new_selected.len() < n_select {
470 let next = min_dists
472 .iter()
473 .enumerate()
474 .filter(|(i, _)| !used[*i])
475 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
476 .map(|(i, _)| i);
477
478 match next {
479 Some(idx) => {
480 new_selected.push(idx);
481 used[idx] = true;
482 let new_centre = &embeddings[idx];
484 for (i, md) in min_dists.iter_mut().enumerate() {
485 if !used[i] {
486 let d = euclidean_dist(&embeddings[i], new_centre);
487 if d < *md {
488 *md = d;
489 }
490 }
491 }
492 }
493 None => break, }
495 }
496
497 Ok(new_selected)
498}
499
500pub fn greedy_k_center(
507 features: &[Vec<f64>],
508 k: usize,
509 seed_idx: Option<usize>,
510) -> Result<Vec<usize>> {
511 if features.is_empty() {
512 return Err(MetricsError::InvalidInput(
513 "features must not be empty".to_string(),
514 ));
515 }
516 if k == 0 {
517 return Err(MetricsError::InvalidInput(
518 "k must be at least 1".to_string(),
519 ));
520 }
521 if k > features.len() {
522 return Err(MetricsError::InvalidInput(format!(
523 "k ({k}) exceeds number of points ({})",
524 features.len()
525 )));
526 }
527
528 let n = features.len();
529 let first = seed_idx.unwrap_or(0).min(n - 1);
530
531 let mut selected = vec![first];
532 let mut min_dists: Vec<f64> = features
533 .iter()
534 .map(|f| euclidean_dist(f, &features[first]))
535 .collect();
536
537 while selected.len() < k {
538 let next = min_dists
539 .iter()
540 .enumerate()
541 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
542 .map(|(i, _)| i)
543 .unwrap_or(0);
544 selected.push(next);
545 let new_centre = &features[next];
546 for (i, md) in min_dists.iter_mut().enumerate() {
547 let d = euclidean_dist(&features[i], new_centre);
548 if d < *md {
549 *md = d;
550 }
551 }
552 }
553
554 Ok(selected)
555}
556
/// Returns the indices of the `n_select` highest scores, best first.
///
/// Equal scores keep their original relative order (the sort is stable). `NaN`
/// scores rank after every real score, so an undefined score is never preferred
/// over a defined one. If `n_select` exceeds `scores.len()`, all indices are
/// returned.
pub fn rank_candidates(scores: &[f64], n_select: usize) -> Vec<usize> {
    use std::cmp::Ordering;

    let mut indices: Vec<usize> = (0..scores.len()).collect();
    indices.sort_by(|&a, &b| {
        let (sa, sb) = (scores[a], scores[b]);
        // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN is
        // involved, which `sort_by` requires; order NaN explicitly instead.
        match (sa.is_nan(), sb.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => sb.partial_cmp(&sa).unwrap_or(Ordering::Equal),
        }
    });
    indices.truncate(n_select);
    indices
}
574
/// Returns all sample indices ordered from highest to lowest score.
///
/// Equal scores keep their original relative order (the sort is stable). `NaN`
/// scores rank after every real score, so an undefined score is never preferred
/// over a defined one.
pub fn rank_by_uncertainty(scores: &[f64]) -> Vec<usize> {
    use std::cmp::Ordering;

    let mut indices: Vec<usize> = (0..scores.len()).collect();
    indices.sort_by(|&a, &b| {
        let (sa, sb) = (scores[a], scores[b]);
        // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN is
        // involved, which `sort_by` requires; order NaN explicitly instead.
        match (sa.is_nan(), sb.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => sb.partial_cmp(&sa).unwrap_or(Ordering::Equal),
        }
    });
    indices
}
587
/// Strategy used by `batch_select` to choose a batch.
///
/// The enum is `#[non_exhaustive]`: downstream `match` expressions need a
/// wildcard arm. It derives `Eq` alongside `PartialEq`, since a field-less enum
/// has a total equality (and this matches `UncertaintyScore`).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchSelectionMethod {
    /// Rank samples by the entropy of their class probabilities.
    Entropy,
    /// Rank samples by their margin-sampling score.
    MarginSampling,
    /// Choose samples by greedy k-center coverage of the feature space.
    CoreSet,
}
603
604#[non_exhaustive]
606#[derive(Debug, Clone)]
607pub struct BatchSelectionConfig {
608 pub n_select: usize,
610 pub diversity_weight: f64,
613 pub method: BatchSelectionMethod,
615}
616
617impl Default for BatchSelectionConfig {
618 fn default() -> Self {
619 Self {
620 n_select: 10,
621 diversity_weight: 0.5,
622 method: BatchSelectionMethod::Entropy,
623 }
624 }
625}
626
627pub fn batch_selection(
638 scores: &[f64],
639 embeddings: &[Vec<f64>],
640 n_select: usize,
641 diversity_weight: f64,
642) -> Result<Vec<usize>> {
643 if scores.is_empty() || embeddings.is_empty() {
644 return Err(MetricsError::InvalidInput(
645 "scores and embeddings must not be empty".to_string(),
646 ));
647 }
648 if scores.len() != embeddings.len() {
649 return Err(MetricsError::DimensionMismatch(format!(
650 "scores len {} != embeddings len {}",
651 scores.len(),
652 embeddings.len()
653 )));
654 }
655
656 let n = scores.len();
657 let k = n_select.min(n);
658
659 if k == 0 {
660 return Ok(vec![]);
661 }
662
663 let dw = diversity_weight.clamp(0.0, 1.0);
665 if dw < 1e-12 {
666 return Ok(rank_candidates(scores, k));
667 }
668
669 if (dw - 1.0).abs() < 1e-12 {
671 return core_set_selection(embeddings, &[], k);
672 }
673
674 let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
677 let min_score = scores.iter().cloned().fold(f64::INFINITY, f64::min);
678 let score_range = max_score - min_score;
679 let norm_scores: Vec<f64> = if score_range > 1e-15 {
680 scores
681 .iter()
682 .map(|&s| (s - min_score) / score_range)
683 .collect()
684 } else {
685 vec![0.5; n]
686 };
687
688 let mut selected = Vec::with_capacity(k);
689 let mut used = vec![false; n];
690
691 let seed = norm_scores
693 .iter()
694 .enumerate()
695 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
696 .map(|(i, _)| i)
697 .unwrap_or(0);
698 selected.push(seed);
699 used[seed] = true;
700
701 let mut min_dists: Vec<f64> = (0..n)
703 .map(|i| {
704 if i == seed {
705 0.0
706 } else {
707 euclidean_dist(&embeddings[i], &embeddings[seed])
708 }
709 })
710 .collect();
711
712 while selected.len() < k {
713 let max_dist = min_dists
715 .iter()
716 .enumerate()
717 .filter(|(i, _)| !used[*i])
718 .map(|(_, &d)| d)
719 .fold(f64::NEG_INFINITY, f64::max);
720 let min_dist_val = min_dists
721 .iter()
722 .enumerate()
723 .filter(|(i, _)| !used[*i])
724 .map(|(_, &d)| d)
725 .fold(f64::INFINITY, f64::min);
726 let dist_range = max_dist - min_dist_val;
727
728 let mut best_idx = 0;
730 let mut best_combined = f64::NEG_INFINITY;
731
732 for i in 0..n {
733 if used[i] {
734 continue;
735 }
736 let norm_dist = if dist_range > 1e-15 {
737 (min_dists[i] - min_dist_val) / dist_range
738 } else {
739 0.5
740 };
741 let combined = (1.0 - dw) * norm_scores[i] + dw * norm_dist;
742 if combined > best_combined {
743 best_combined = combined;
744 best_idx = i;
745 }
746 }
747
748 selected.push(best_idx);
749 used[best_idx] = true;
750
751 let new_centre = &embeddings[best_idx];
753 for (i, md) in min_dists.iter_mut().enumerate() {
754 if !used[i] {
755 let d = euclidean_dist(&embeddings[i], new_centre);
756 if d < *md {
757 *md = d;
758 }
759 }
760 }
761 }
762
763 Ok(selected)
764}
765
766pub fn batch_select(
770 features: &[Vec<f64>],
771 probabilities: &[Vec<f64>],
772 config: &BatchSelectionConfig,
773) -> Result<Vec<usize>> {
774 if features.is_empty() || probabilities.is_empty() {
775 return Err(MetricsError::InvalidInput(
776 "features and probabilities must not be empty".to_string(),
777 ));
778 }
779 if features.len() != probabilities.len() {
780 return Err(MetricsError::DimensionMismatch(format!(
781 "features len {} != probabilities len {}",
782 features.len(),
783 probabilities.len()
784 )));
785 }
786 let n = features.len();
787 let k = config.n_select.min(n);
788
789 match config.method {
790 BatchSelectionMethod::CoreSet => greedy_k_center(features, k, None),
791 BatchSelectionMethod::Entropy => {
792 let scores: Vec<f64> = probabilities
793 .iter()
794 .map(|p| entropy_uncertainty(p))
795 .collect::<Result<Vec<_>>>()?;
796 let ranked = rank_by_uncertainty(&scores);
797 Ok(ranked.into_iter().take(k).collect())
798 }
799 BatchSelectionMethod::MarginSampling => {
800 let scores: Vec<f64> = probabilities
801 .iter()
802 .map(|p| margin_sampling_score(p))
803 .collect::<Result<Vec<_>>>()?;
804 let ranked = rank_by_uncertainty(&scores);
805 Ok(ranked.into_iter().take(k).collect())
806 }
807 }
808}
809
/// Unit tests for the scoring, ranking and batch-selection functions above.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- margin sampling --------------------------------------------------

    /// A uniform distribution has zero margin, so the score is exactly 1.
    #[test]
    fn test_margin_sampling_uniform_score_one() {
        let probs = vec![vec![0.25, 0.25, 0.25, 0.25], vec![0.5, 0.5]];
        let scores = margin_sampling(&probs).expect("should succeed");
        assert!(
            (scores[0] - 1.0).abs() < 1e-12,
            "uniform 4-class: score should be 1.0, got {}",
            scores[0]
        );
        assert!(
            (scores[1] - 1.0).abs() < 1e-12,
            "uniform 2-class: score should be 1.0, got {}",
            scores[1]
        );
    }

    /// A peaked distribution has a large margin, so the score is close to 0.
    #[test]
    fn test_margin_sampling_peaked_close_to_zero() {
        let probs = vec![vec![0.99, 0.01]];
        let scores = margin_sampling(&probs).expect("should succeed");
        assert!(
            scores[0] < 0.05,
            "peaked should have low uncertainty, got {}",
            scores[0]
        );
    }

    // ---- entropy sampling -------------------------------------------------

    /// The entropy of a uniform distribution over `n` classes is `ln(n)`.
    #[test]
    fn test_entropy_uniform_has_max() {
        let n = 4;
        let p = 1.0 / n as f64;
        let probs = vec![vec![p; n]];
        let scores = entropy_sampling(&probs).expect("should succeed");
        let expected = (n as f64).ln();
        assert!(
            (scores[0] - expected).abs() < 1e-10,
            "expected {expected}, got {}",
            scores[0]
        );
    }

    #[test]
    fn test_entropy_point_mass_zero() {
        let probs = vec![vec![1.0, 0.0, 0.0]];
        let scores = entropy_sampling(&probs).expect("should succeed");
        assert!(
            scores[0].abs() < 1e-12,
            "point mass entropy should be 0, got {}",
            scores[0]
        );
    }

    // ---- least confidence -------------------------------------------------

    #[test]
    fn test_least_confidence_confident_low_score() {
        let probs = vec![vec![0.95, 0.03, 0.02]];
        let scores = least_confidence(&probs).expect("should succeed");
        assert!(
            scores[0] < 0.1,
            "confident prediction should have low LC, got {}",
            scores[0]
        );
    }

    #[test]
    fn test_least_confidence_uncertain_high_score() {
        let probs = vec![vec![0.34, 0.33, 0.33]];
        let scores = least_confidence(&probs).expect("should succeed");
        assert!(
            scores[0] > 0.5,
            "uncertain prediction should have high LC, got {}",
            scores[0]
        );
    }

    // ---- query by committee -----------------------------------------------

    /// Layout is `[member][sample][class]`; every member votes for class 0.
    #[test]
    fn test_qbc_unanimous_low_disagreement() {
        let committee = vec![
            vec![vec![0.9, 0.1], vec![0.8, 0.2]],
            vec![vec![0.85, 0.15], vec![0.75, 0.25]],
            vec![vec![0.95, 0.05], vec![0.7, 0.3]],
        ];
        let scores = query_by_committee(&committee).expect("should succeed");
        assert!(
            scores[0].abs() < 1e-12,
            "unanimous committee: disagreement should be 0, got {}",
            scores[0]
        );
    }

    #[test]
    fn test_qbc_disagreeing_positive() {
        let committee = vec![
            vec![vec![0.9, 0.1]],
            vec![vec![0.1, 0.9]],
        ];
        let scores = query_by_committee(&committee).expect("should succeed");
        assert!(
            scores[0] > 0.0,
            "disagreeing committee should have positive score, got {}",
            scores[0]
        );
    }

    // ---- expected model change --------------------------------------------

    #[test]
    fn test_expected_model_change_norm() {
        let gradients = vec![
            vec![3.0, 4.0],
            vec![0.0, 0.0],
            vec![1.0, 1.0, 1.0, 1.0],
        ];
        let scores = expected_model_change(&gradients).expect("should succeed");
        assert!((scores[0] - 5.0).abs() < 1e-12);
        assert!(scores[1].abs() < 1e-12);
        assert!((scores[2] - 2.0).abs() < 1e-12);
    }

    // ---- core-set selection -----------------------------------------------

    #[test]
    fn test_core_set_points_well_spread() {
        let embeddings: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64 * 10.0]).collect();
        let selected = core_set_selection(&embeddings, &[], 3).expect("should succeed");
        assert_eq!(selected.len(), 3);
        // Adjacent points are 10 apart, so any two distinct picks satisfy this.
        for i in 0..selected.len() {
            for j in (i + 1)..selected.len() {
                let d = euclidean_dist(&embeddings[selected[i]], &embeddings[selected[j]]);
                assert!(d >= 10.0, "selected points should be spread: dist={d}");
            }
        }
    }

    #[test]
    fn test_core_set_with_existing_selected() {
        let embeddings: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64]).collect();
        let already_selected = vec![0, 9];
        let new = core_set_selection(&embeddings, &already_selected, 1).expect("should succeed");
        assert_eq!(new.len(), 1);
        assert!(
            new[0] >= 3 && new[0] <= 6,
            "midpoint expected, got {}",
            new[0]
        );
    }

    // ---- ranking ----------------------------------------------------------

    #[test]
    fn test_rank_candidates_top_n() {
        let scores = vec![0.1, 0.9, 0.5, 0.3, 0.7];
        let top3 = rank_candidates(&scores, 3);
        assert_eq!(top3.len(), 3);
        assert_eq!(top3[0], 1);
        assert_eq!(top3[1], 4);
        assert_eq!(top3[2], 2);
    }

    // ---- batch selection --------------------------------------------------

    #[test]
    fn test_batch_selection_diversity_zero_matches_uncertainty() {
        let scores = vec![0.1, 0.9, 0.5, 0.3, 0.7];
        let embeddings: Vec<Vec<f64>> = (0..5).map(|i| vec![i as f64]).collect();

        let pure_unc = rank_candidates(&scores, 3);
        let batch = batch_selection(&scores, &embeddings, 3, 0.0).expect("should succeed");
        assert_eq!(
            batch, pure_unc,
            "diversity_weight=0 should match pure uncertainty ranking"
        );
    }

    #[test]
    fn test_batch_selection_returns_correct_count() {
        let scores = vec![0.5; 20];
        let embeddings: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64, 0.0]).collect();
        let selected = batch_selection(&scores, &embeddings, 7, 0.5).expect("should succeed");
        assert_eq!(selected.len(), 7);
    }

    #[test]
    fn test_batch_selection_respects_n_select_legacy() {
        let features: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64]).collect();
        let probs: Vec<Vec<f64>> = (0..20)
            .map(|i| {
                let p = i as f64 / 20.0;
                vec![p, 1.0 - p]
            })
            .collect();
        let cfg = BatchSelectionConfig {
            n_select: 7,
            ..Default::default()
        };
        let selected = batch_select(&features, &probs, &cfg).expect("should succeed");
        assert_eq!(selected.len(), 7, "should select exactly 7 samples");
    }

    // ---- single-sample helpers and misc -----------------------------------

    #[test]
    fn test_margin_sampling_score_compat() {
        let p = vec![0.25, 0.25, 0.25, 0.25];
        let s = margin_sampling_score(&p).expect("should succeed");
        assert!((s - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_vote_entropy_unanimous_zero() {
        let committee = vec![vec![0.9, 0.1], vec![0.8, 0.2], vec![0.95, 0.05]];
        let ve = vote_entropy(&committee).expect("should succeed");
        assert!(
            ve.abs() < 1e-12,
            "unanimous vote should give entropy=0, got {ve}"
        );
    }

    #[test]
    fn test_expected_gradient_magnitude_shape() {
        let probs = vec![vec![0.7, 0.2, 0.1], vec![0.3, 0.4, 0.3]];
        let mags = expected_gradient_magnitude(&probs).expect("should succeed");
        assert_eq!(mags.len(), 2);
        for m in &mags {
            assert!(*m >= 0.0, "magnitude must be non-negative, got {m}");
        }
    }

    #[test]
    fn test_k_center_returns_k_points() {
        let features: Vec<Vec<f64>> = (0..20).map(|i| vec![i as f64, 0.0]).collect();
        let selected = greedy_k_center(&features, 5, None).expect("should succeed");
        assert_eq!(selected.len(), 5);
    }

    #[test]
    fn test_default_config() {
        let cfg = ActiveLearningConfig::default();
        assert_eq!(cfg.n_committee, 5);
        assert_eq!(cfg.n_candidates, 100);
    }
}