Skip to main content

ruvector_coherence/
batch.rs

//! Batched evaluation of baseline/gated output pairs over multiple samples.
2
3use serde::{Deserialize, Serialize};
4
5use crate::metrics::delta_behavior;
6use crate::quality::quality_check;
7
8/// Aggregated results from evaluating a batch of baseline/gated output pairs.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct BatchResult {
11    pub mean_coherence_delta: f64,
12    pub std_coherence_delta: f64,
13    pub ci_95_lower: f64,
14    pub ci_95_upper: f64,
15    pub n_samples: usize,
16    pub pass_rate: f64,
17}
18
19/// Evaluates a batch of output pairs, producing mean/std/CI for coherence delta and pass rate.
20pub fn evaluate_batch(
21    baseline_outputs: &[Vec<f32>],
22    gated_outputs: &[Vec<f32>],
23    threshold: f64,
24) -> BatchResult {
25    let n = baseline_outputs.len().min(gated_outputs.len());
26    if n == 0 {
27        return BatchResult {
28            mean_coherence_delta: 0.0, std_coherence_delta: 0.0,
29            ci_95_lower: 0.0, ci_95_upper: 0.0, n_samples: 0, pass_rate: 0.0,
30        };
31    }
32
33    let mut deltas = Vec::with_capacity(n);
34    let mut passes = 0usize;
35    for i in 0..n {
36        deltas.push(delta_behavior(&baseline_outputs[i], &gated_outputs[i]).coherence_delta);
37        if quality_check(&baseline_outputs[i], &gated_outputs[i], threshold).passes_threshold {
38            passes += 1;
39        }
40    }
41
42    let mean = deltas.iter().sum::<f64>() / n as f64;
43    let var = if n > 1 {
44        deltas.iter().map(|d| (d - mean).powi(2)).sum::<f64>() / (n - 1) as f64
45    } else { 0.0 };
46    let std_dev = var.sqrt();
47    let margin = 1.96 * std_dev / (n as f64).sqrt();
48
49    BatchResult {
50        mean_coherence_delta: mean, std_coherence_delta: std_dev,
51        ci_95_lower: mean - margin, ci_95_upper: mean + margin,
52        n_samples: n, pass_rate: passes as f64 / n as f64,
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn batch_empty() {
        let r = evaluate_batch(&[], &[], 0.9);
        assert_eq!(r.n_samples, 0);
        // An empty batch reports all-zero statistics.
        assert_eq!(r.mean_coherence_delta, 0.0);
        assert_eq!(r.std_coherence_delta, 0.0);
        assert_eq!(r.ci_95_lower, 0.0);
        assert_eq!(r.ci_95_upper, 0.0);
        assert_eq!(r.pass_rate, 0.0);
    }

    #[test]
    fn batch_identical() {
        let bl = vec![vec![1.0, 2.0, 3.0]; 10];
        let r = evaluate_batch(&bl, &bl, 0.9);
        assert_eq!(r.n_samples, 10);
        assert!(r.mean_coherence_delta.abs() < 1e-10);
        // Identical pairs give identical deltas, so there is no spread.
        assert!(r.std_coherence_delta.abs() < 1e-10);
        assert!((r.pass_rate - 1.0).abs() < 1e-10);
    }

    #[test]
    fn batch_truncates_to_shorter_input() {
        let long = vec![vec![1.0, 2.0, 3.0]; 3];
        let short = vec![vec![1.0, 2.0, 3.0]; 2];
        assert_eq!(evaluate_batch(&long, &short, 0.9).n_samples, 2);
        assert_eq!(evaluate_batch(&short, &long, 0.9).n_samples, 2);
    }

    #[test]
    fn batch_single_sample_has_zero_spread() {
        let bl = vec![vec![1.0, 0.0]];
        let gt = vec![vec![1.1, 0.1]];
        let r = evaluate_batch(&bl, &gt, 0.9);
        assert_eq!(r.n_samples, 1);
        assert_eq!(r.std_coherence_delta, 0.0);
        assert_eq!(r.ci_95_lower, r.mean_coherence_delta);
        assert_eq!(r.ci_95_upper, r.mean_coherence_delta);
    }

    #[test]
    fn batch_ci_contains_mean() {
        let bl = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0], vec![2.0, 3.0]];
        let gt = vec![vec![1.1, 0.1], vec![0.1, 1.1], vec![1.2, 0.9], vec![2.1, 2.9]];
        let r = evaluate_batch(&bl, &gt, 0.9);
        assert!(r.std_coherence_delta >= 0.0);
        assert!(r.ci_95_lower <= r.mean_coherence_delta);
        assert!(r.ci_95_upper >= r.mean_coherence_delta);
    }

    #[test]
    fn batch_pass_rate_partial() {
        let bl = vec![vec![1.0, 0.0], vec![1.0, 0.0]];
        let gt = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = evaluate_batch(&bl, &gt, 0.5);
        assert!((r.pass_rate - 0.5).abs() < 1e-10);
    }

    #[test]
    fn batch_result_serializable() {
        let r = BatchResult {
            mean_coherence_delta: -0.05,
            std_coherence_delta: 0.02,
            ci_95_lower: -0.07,
            ci_95_upper: -0.03,
            n_samples: 100,
            pass_rate: 0.95,
        };
        let json = serde_json::to_string(&r).unwrap();
        let d: BatchResult = serde_json::from_str(&json).unwrap();
        assert_eq!(d.n_samples, r.n_samples);
        // Compare floats with a tolerance: serde_json's default float parser
        // is not guaranteed to round-trip bit-exactly.
        let pairs = [
            (d.mean_coherence_delta, r.mean_coherence_delta),
            (d.std_coherence_delta, r.std_coherence_delta),
            (d.ci_95_lower, r.ci_95_lower),
            (d.ci_95_upper, r.ci_95_upper),
            (d.pass_rate, r.pass_rate),
        ];
        for (got, want) in pairs.iter() {
            assert!((got - want).abs() < 1e-12, "{} != {}", got, want);
        }
    }
}