Skip to main content

ruvector_coherence/
quality.rs

1//! Quality guardrails for attention mechanism output comparison.
2
3use serde::{Deserialize, Serialize};
4
5/// Result of a quality check comparing baseline and gated outputs.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct QualityResult {
8    pub cosine_sim: f64,
9    pub l2_dist: f64,
10    pub passes_threshold: bool,
11}
12
/// Cosine similarity between two vectors.
///
/// Accumulation is done in `f64`. If the slices differ in length, the shorter
/// one is treated as zero-padded: the extra elements of the longer slice
/// contribute to its norm but not to the dot product. This matches how
/// [`l2_distance`] handles mismatched lengths.
///
/// # Returns
///
/// A value nominally in `[-1.0, 1.0]`, or `0.0` when either vector has zero
/// magnitude (including empty input).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    // Dot product only spans the overlapping prefix; a missing element is a zero.
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();

    // Norms cover each full slice so a long tail is not silently ignored.
    let norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| {
                let x = f64::from(x);
                x * x
            })
            .sum::<f64>()
            .sqrt()
    };
    let denom = norm(a) * norm(b);

    // Only an exactly zero denominator means "zero magnitude". Comparing against
    // `f64::EPSILON` would misclassify small but valid vectors, since EPSILON is
    // a relative spacing near 1.0, not an absolute floor.
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
26
/// Euclidean (L2) distance between two vectors, accumulated in `f64`.
///
/// When the slices differ in length, the elements of the longer slice that
/// have no counterpart contribute their own squared value, i.e. the shorter
/// slice behaves as if it were padded with zeros.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f64 {
    let shared = a.len().min(b.len());

    // Squared differences over the overlapping prefix.
    let mut total = a.iter().zip(b.iter()).fold(0.0_f64, |acc, (&x, &y)| {
        let diff = x as f64 - y as f64;
        acc + diff * diff
    });

    // At most one of the slices extends past the shared prefix.
    let tail = if a.len() > shared { &a[shared..] } else { &b[shared..] };
    if !tail.is_empty() {
        let tail_sq = tail.iter().fold(0.0_f64, |acc, &v| acc + (v as f64).powi(2));
        total += tail_sq;
    }

    total.sqrt()
}
39
40/// Quality gate: passes when `cosine_similarity >= threshold`.
41pub fn quality_check(baseline_output: &[f32], gated_output: &[f32], threshold: f64) -> QualityResult {
42    let cosine_sim = cosine_similarity(baseline_output, gated_output);
43    let l2_dist = l2_distance(baseline_output, gated_output);
44    QualityResult { cosine_sim, l2_dist, passes_threshold: cosine_sim >= threshold }
45}
46
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating-point comparisons below.
    const TOL: f64 = 1e-10;

    #[test]
    fn cosine_cases() {
        // Identical vectors are perfectly aligned.
        let identical = cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!((identical - 1.0).abs() < TOL);

        // Opposite vectors are perfectly anti-aligned.
        let opposite = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((opposite + 1.0).abs() < TOL);

        // Orthogonal vectors have no similarity.
        let orthogonal = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(orthogonal.abs() < TOL);

        // A zero-magnitude input yields exactly 0.0.
        let with_zero = cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]);
        assert_eq!(with_zero, 0.0);
    }

    #[test]
    fn l2_cases() {
        let same = l2_distance(&[1.0, 2.0], &[1.0, 2.0]);
        assert!(same < TOL);

        // Classic 3-4-5 triangle.
        let hypotenuse = l2_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((hypotenuse - 5.0).abs() < TOL);

        // Mismatched lengths: the unmatched element counts in full.
        let ragged = l2_distance(&[1.0], &[1.0, 3.0]);
        assert!((ragged - 3.0).abs() < TOL);
    }

    #[test]
    fn quality_check_pass_and_fail() {
        let near = quality_check(&[1.0, 2.0, 3.0], &[1.1, 2.1, 3.1], 0.99);
        assert!(near.passes_threshold);

        let orthogonal = quality_check(&[1.0, 0.0], &[0.0, 1.0], 0.5);
        assert!(!orthogonal.passes_threshold);
    }

    #[test]
    fn quality_result_serializable() {
        let original = QualityResult {
            cosine_sim: 0.95,
            l2_dist: 0.32,
            passes_threshold: true,
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: QualityResult = serde_json::from_str(&json).unwrap();
        assert!((decoded.cosine_sim - 0.95).abs() < TOL);
    }
}