// ruvector_coherence/batch.rs
use serde::{Deserialize, Serialize};

use crate::metrics::delta_behavior;
use crate::quality::quality_check;

8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct BatchResult {
11 pub mean_coherence_delta: f64,
12 pub std_coherence_delta: f64,
13 pub ci_95_lower: f64,
14 pub ci_95_upper: f64,
15 pub n_samples: usize,
16 pub pass_rate: f64,
17}
19pub fn evaluate_batch(
21 baseline_outputs: &[Vec<f32>],
22 gated_outputs: &[Vec<f32>],
23 threshold: f64,
24) -> BatchResult {
25 let n = baseline_outputs.len().min(gated_outputs.len());
26 if n == 0 {
27 return BatchResult {
28 mean_coherence_delta: 0.0, std_coherence_delta: 0.0,
29 ci_95_lower: 0.0, ci_95_upper: 0.0, n_samples: 0, pass_rate: 0.0,
30 };
31 }
32
33 let mut deltas = Vec::with_capacity(n);
34 let mut passes = 0usize;
35 for i in 0..n {
36 deltas.push(delta_behavior(&baseline_outputs[i], &gated_outputs[i]).coherence_delta);
37 if quality_check(&baseline_outputs[i], &gated_outputs[i], threshold).passes_threshold {
38 passes += 1;
39 }
40 }
41
42 let mean = deltas.iter().sum::<f64>() / n as f64;
43 let var = if n > 1 {
44 deltas.iter().map(|d| (d - mean).powi(2)).sum::<f64>() / (n - 1) as f64
45 } else { 0.0 };
46 let std_dev = var.sqrt();
47 let margin = 1.96 * std_dev / (n as f64).sqrt();
48
49 BatchResult {
50 mean_coherence_delta: mean, std_coherence_delta: std_dev,
51 ci_95_lower: mean - margin, ci_95_upper: mean + margin,
52 n_samples: n, pass_rate: passes as f64 / n as f64,
53 }
54}
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty inputs yield a result with no samples.
    #[test]
    fn batch_empty() {
        let r = evaluate_batch(&[], &[], 0.9);
        assert_eq!(r.n_samples, 0);
    }

    /// Comparing a batch against itself gives a zero mean delta and a full pass rate.
    #[test]
    fn batch_identical() {
        let bl = vec![vec![1.0, 2.0, 3.0]; 10];
        // Both parameters are shared borrows, so the same batch can be passed twice.
        let r = evaluate_batch(&bl, &bl, 0.9);
        assert_eq!(r.n_samples, 10);
        assert!(r.mean_coherence_delta.abs() < 1e-10);
        assert!((r.pass_rate - 1.0).abs() < 1e-10);
    }

    /// The reported interval always brackets the reported mean.
    #[test]
    fn batch_ci_contains_mean() {
        let bl = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0], vec![2.0, 3.0]];
        let gt = vec![vec![1.1, 0.1], vec![0.1, 1.1], vec![1.2, 0.9], vec![2.1, 2.9]];
        let r = evaluate_batch(&bl, &gt, 0.9);
        assert!(r.ci_95_lower <= r.mean_coherence_delta);
        assert!(r.ci_95_upper >= r.mean_coherence_delta);
    }

    /// One matching pair and one mismatched pair are expected to give a 50% pass rate.
    #[test]
    fn batch_pass_rate_partial() {
        let bl = vec![vec![1.0, 0.0], vec![1.0, 0.0]];
        let gt = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = evaluate_batch(&bl, &gt, 0.5);
        assert!((r.pass_rate - 0.5).abs() < 1e-10);
    }

    /// A `BatchResult` survives a JSON serialise/deserialise round trip.
    #[test]
    fn batch_result_serializable() {
        let r = BatchResult {
            mean_coherence_delta: -0.05,
            std_coherence_delta: 0.02,
            ci_95_lower: -0.07,
            ci_95_upper: -0.03,
            n_samples: 100,
            pass_rate: 0.95,
        };
        let json = serde_json::to_string(&r).unwrap();
        let d: BatchResult = serde_json::from_str(&json).unwrap();
        assert_eq!(d.n_samples, 100);
    }
}