Skip to main content

ruvector_coherence/
batch.rs

//! Batched evaluation over multiple samples.
2
3use serde::{Deserialize, Serialize};
4
5use crate::metrics::delta_behavior;
6use crate::quality::quality_check;
7
8/// Aggregated results from evaluating a batch of baseline/gated output pairs.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct BatchResult {
11    pub mean_coherence_delta: f64,
12    pub std_coherence_delta: f64,
13    pub ci_95_lower: f64,
14    pub ci_95_upper: f64,
15    pub n_samples: usize,
16    pub pass_rate: f64,
17}
18
19/// Evaluates a batch of output pairs, producing mean/std/CI for coherence delta and pass rate.
20pub fn evaluate_batch(
21    baseline_outputs: &[Vec<f32>],
22    gated_outputs: &[Vec<f32>],
23    threshold: f64,
24) -> BatchResult {
25    let n = baseline_outputs.len().min(gated_outputs.len());
26    if n == 0 {
27        return BatchResult {
28            mean_coherence_delta: 0.0,
29            std_coherence_delta: 0.0,
30            ci_95_lower: 0.0,
31            ci_95_upper: 0.0,
32            n_samples: 0,
33            pass_rate: 0.0,
34        };
35    }
36
37    let mut deltas = Vec::with_capacity(n);
38    let mut passes = 0usize;
39    for i in 0..n {
40        deltas.push(delta_behavior(&baseline_outputs[i], &gated_outputs[i]).coherence_delta);
41        if quality_check(&baseline_outputs[i], &gated_outputs[i], threshold).passes_threshold {
42            passes += 1;
43        }
44    }
45
46    let mean = deltas.iter().sum::<f64>() / n as f64;
47    let var = if n > 1 {
48        deltas.iter().map(|d| (d - mean).powi(2)).sum::<f64>() / (n - 1) as f64
49    } else {
50        0.0
51    };
52    let std_dev = var.sqrt();
53    let margin = 1.96 * std_dev / (n as f64).sqrt();
54
55    BatchResult {
56        mean_coherence_delta: mean,
57        std_coherence_delta: std_dev,
58        ci_95_lower: mean - margin,
59        ci_95_upper: mean + margin,
60        n_samples: n,
61        pass_rate: passes as f64 / n as f64,
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty batch takes the early-return path and reports zeros in every field.
    #[test]
    fn batch_empty() {
        let r = evaluate_batch(&[], &[], 0.9);
        assert_eq!(r.n_samples, 0);
        assert_eq!(r.mean_coherence_delta, 0.0);
        assert_eq!(r.std_coherence_delta, 0.0);
        assert_eq!(r.ci_95_lower, 0.0);
        assert_eq!(r.ci_95_upper, 0.0);
        assert_eq!(r.pass_rate, 0.0);
    }

    /// Comparing a batch against itself gives no coherence change and a full pass rate.
    #[test]
    fn batch_identical() {
        let bl = vec![vec![1.0, 2.0, 3.0]; 10];
        // Both parameters are shared borrows, so the same batch can be passed twice.
        let r = evaluate_batch(&bl, &bl, 0.9);
        assert_eq!(r.n_samples, 10);
        assert!(r.mean_coherence_delta.abs() < 1e-10);
        assert!((r.pass_rate - 1.0).abs() < 1e-10);
    }

    /// The reported interval always brackets the reported mean.
    #[test]
    fn batch_ci_contains_mean() {
        let bl = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 3.0],
        ];
        let gt = vec![
            vec![1.1, 0.1],
            vec![0.1, 1.1],
            vec![1.2, 0.9],
            vec![2.1, 2.9],
        ];
        let r = evaluate_batch(&bl, &gt, 0.9);
        assert!(r.ci_95_lower <= r.mean_coherence_delta);
        assert!(r.ci_95_upper >= r.mean_coherence_delta);
    }

    /// One matching pair and one orthogonal pair yield a 50% pass rate.
    #[test]
    fn batch_pass_rate_partial() {
        let bl = vec![vec![1.0, 0.0], vec![1.0, 0.0]];
        let gt = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = evaluate_batch(&bl, &gt, 0.5);
        assert!((r.pass_rate - 0.5).abs() < 1e-10);
    }

    /// Inputs of different lengths are truncated to the shorter one.
    #[test]
    fn batch_mismatched_lengths() {
        let bl: Vec<Vec<f32>> = vec![vec![1.0, 0.0]; 3];
        let gt: Vec<Vec<f32>> = vec![vec![1.0, 0.0]];
        let r = evaluate_batch(&bl, &gt, 0.5);
        assert_eq!(r.n_samples, 1);
        // A single sample has no spread.
        assert_eq!(r.std_coherence_delta, 0.0);
    }

    /// A result survives a JSON round trip through serde.
    #[test]
    fn batch_result_serializable() {
        let r = BatchResult {
            mean_coherence_delta: -0.05,
            std_coherence_delta: 0.02,
            ci_95_lower: -0.07,
            ci_95_upper: -0.03,
            n_samples: 100,
            pass_rate: 0.95,
        };
        let d: BatchResult = serde_json::from_str(&serde_json::to_string(&r).unwrap()).unwrap();
        assert_eq!(d.n_samples, 100);
    }
}