ruvector_coherence/
batch.rs1use serde::{Deserialize, Serialize};
4
5use crate::metrics::delta_behavior;
6use crate::quality::quality_check;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct BatchResult {
11 pub mean_coherence_delta: f64,
12 pub std_coherence_delta: f64,
13 pub ci_95_lower: f64,
14 pub ci_95_upper: f64,
15 pub n_samples: usize,
16 pub pass_rate: f64,
17}
18
19pub fn evaluate_batch(
21 baseline_outputs: &[Vec<f32>],
22 gated_outputs: &[Vec<f32>],
23 threshold: f64,
24) -> BatchResult {
25 let n = baseline_outputs.len().min(gated_outputs.len());
26 if n == 0 {
27 return BatchResult {
28 mean_coherence_delta: 0.0,
29 std_coherence_delta: 0.0,
30 ci_95_lower: 0.0,
31 ci_95_upper: 0.0,
32 n_samples: 0,
33 pass_rate: 0.0,
34 };
35 }
36
37 let mut deltas = Vec::with_capacity(n);
38 let mut passes = 0usize;
39 for i in 0..n {
40 deltas.push(delta_behavior(&baseline_outputs[i], &gated_outputs[i]).coherence_delta);
41 if quality_check(&baseline_outputs[i], &gated_outputs[i], threshold).passes_threshold {
42 passes += 1;
43 }
44 }
45
46 let mean = deltas.iter().sum::<f64>() / n as f64;
47 let var = if n > 1 {
48 deltas.iter().map(|d| (d - mean).powi(2)).sum::<f64>() / (n - 1) as f64
49 } else {
50 0.0
51 };
52 let std_dev = var.sqrt();
53 let margin = 1.96 * std_dev / (n as f64).sqrt();
54
55 BatchResult {
56 mean_coherence_delta: mean,
57 std_coherence_delta: std_dev,
58 ci_95_lower: mean - margin,
59 ci_95_upper: mean + margin,
60 n_samples: n,
61 pass_rate: passes as f64 / n as f64,
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty batch returns the zeroed result instead of dividing by zero.
    #[test]
    fn batch_empty() {
        let r = evaluate_batch(&[], &[], 0.9);
        assert_eq!(r.n_samples, 0);
        assert_eq!(r.pass_rate, 0.0);
    }

    /// Comparing a batch with itself should give no delta and a full pass rate.
    #[test]
    fn batch_identical() {
        let bl = vec![vec![1.0, 2.0, 3.0]; 10];
        // Both arguments are only borrowed, so the same batch can be passed twice.
        let r = evaluate_batch(&bl, &bl, 0.9);
        assert_eq!(r.n_samples, 10);
        assert!(r.mean_coherence_delta.abs() < 1e-10);
        assert!((r.pass_rate - 1.0).abs() < 1e-10);
    }

    /// The confidence interval must always bracket the reported mean.
    #[test]
    fn batch_ci_contains_mean() {
        let bl = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 3.0],
        ];
        let gt = vec![
            vec![1.1, 0.1],
            vec![0.1, 1.1],
            vec![1.2, 0.9],
            vec![2.1, 2.9],
        ];
        let r = evaluate_batch(&bl, &gt, 0.9);
        assert!(r.ci_95_lower <= r.mean_coherence_delta);
        assert!(r.ci_95_upper >= r.mean_coherence_delta);
    }

    /// One matching pair and one mismatched pair should pass exactly half the time.
    #[test]
    fn batch_pass_rate_partial() {
        let bl = vec![vec![1.0, 0.0], vec![1.0, 0.0]];
        let gt = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = evaluate_batch(&bl, &gt, 0.5);
        assert!((r.pass_rate - 0.5).abs() < 1e-10);
    }

    /// A result survives a JSON serialise/deserialise round trip.
    #[test]
    fn batch_result_serializable() {
        let r = BatchResult {
            mean_coherence_delta: -0.05,
            std_coherence_delta: 0.02,
            ci_95_lower: -0.07,
            ci_95_upper: -0.03,
            n_samples: 100,
            pass_rate: 0.95,
        };
        let d: BatchResult = serde_json::from_str(&serde_json::to_string(&r).unwrap()).unwrap();
        assert_eq!(d.n_samples, 100);
    }
}