ruvector_coherence/
quality.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct QualityResult {
8 pub cosine_sim: f64,
9 pub l2_dist: f64,
10 pub passes_threshold: bool,
11}
12
/// Cosine similarity between two `f32` vectors, accumulated in `f64`.
///
/// When the slices differ in length, the shorter one is treated as if it were
/// zero-padded to the length of the longer one. This is the same convention
/// used by [`l2_distance`]. Padding leaves the dot product unchanged, but the
/// surplus elements of the longer slice still count toward its norm, so a
/// length mismatch lowers the score rather than being silently ignored.
///
/// Returns `0.0` when either vector has zero norm (all zeros, or empty).
/// The result is clamped to `[-1.0, 1.0]` to absorb floating-point rounding.
/// A NaN anywhere in either input propagates to a NaN result.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    // Squared Euclidean norm over the *whole* slice (zero-padding semantics).
    let sq_norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| {
                let x = f64::from(x);
                x * x
            })
            .sum()
    };

    // `zip` stops at the shorter slice; the padded zeros would contribute
    // nothing to the dot product anyway.
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();

    let denom = sq_norm(a).sqrt() * sq_norm(b).sqrt();

    // Exact zero test on purpose. The square of any non-zero f32 stays far
    // above f64's underflow range, so `denom` is 0 only for a zero vector.
    // A magnitude cutoff such as `f64::EPSILON` would wrongly report 0 for
    // tiny but parallel vectors, even though cosine similarity is
    // scale-invariant.
    if denom == 0.0 {
        0.0
    } else {
        (dot / denom).clamp(-1.0, 1.0)
    }
}
30
/// Euclidean (L2) distance between two `f32` vectors, accumulated in `f64`.
///
/// Slices of different lengths are compared as if the shorter one were
/// padded with zeros: each surplus element of the longer slice adds its
/// square to the total. Two empty inputs yield `0.0`.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f64 {
    let common = a.len().min(b.len());

    // Sum of squared differences over the positions both slices share.
    let overlap = a.iter().zip(b).fold(0.0_f64, |acc, (&x, &y)| {
        let diff = f64::from(x) - f64::from(y);
        acc + diff * diff
    });

    // At most one slice extends past `common`; its surplus is measured
    // against zero. When the lengths match, this slice is empty.
    let surplus = if a.len() > common {
        &a[common..]
    } else {
        &b[common..]
    };
    let surplus_sq = surplus
        .iter()
        .map(|&v| f64::from(v).powi(2))
        .sum::<f64>();

    (overlap + surplus_sq).sqrt()
}
47
48pub fn quality_check(
50 baseline_output: &[f32],
51 gated_output: &[f32],
52 threshold: f64,
53) -> QualityResult {
54 let cosine_sim = cosine_similarity(baseline_output, gated_output);
55 let l2_dist = l2_distance(baseline_output, gated_output);
56 QualityResult {
57 cosine_sim,
58 l2_dist,
59 passes_threshold: cosine_sim >= threshold,
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_cases() {
        assert!((cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]) - 1.0).abs() < 1e-10);
        assert!((cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-10);
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-10);
        // Cosine similarity ignores the vectors' magnitudes.
        assert!((cosine_similarity(&[1.0, 2.0], &[2.0, 4.0]) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn cosine_degenerate_inputs_return_zero() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[0.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn l2_cases() {
        assert!(l2_distance(&[1.0, 2.0], &[1.0, 2.0]) < 1e-10);
        assert!((l2_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 1e-10);
        assert_eq!(l2_distance(&[], &[]), 0.0);
    }

    #[test]
    fn l2_zero_pads_shorter_input() {
        // Surplus elements count as distance from zero, whichever side is longer.
        assert!((l2_distance(&[1.0], &[1.0, 3.0]) - 3.0).abs() < 1e-10);
        assert!((l2_distance(&[1.0, 3.0], &[1.0]) - 3.0).abs() < 1e-10);
        assert!((l2_distance(&[], &[3.0, 4.0]) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn quality_check_pass_and_fail() {
        let r = quality_check(&[1.0, 2.0, 3.0], &[1.1, 2.1, 3.1], 0.99);
        assert!(r.passes_threshold);
        assert!(r.cosine_sim >= 0.99);

        let r2 = quality_check(&[1.0, 0.0], &[0.0, 1.0], 0.5);
        assert!(!r2.passes_threshold);
        assert!(r2.cosine_sim.abs() < 1e-10);
        assert!((r2.l2_dist - 2f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn quality_check_threshold_is_inclusive() {
        // Orthogonal vectors give a similarity of exactly 0.0, which must
        // pass a threshold of exactly 0.0.
        let r = quality_check(&[1.0, 0.0], &[0.0, 1.0], 0.0);
        assert!(r.passes_threshold);
    }

    #[test]
    fn quality_result_serializable() {
        let r = QualityResult {
            cosine_sim: 0.95,
            l2_dist: 0.32,
            passes_threshold: true,
        };
        let j = serde_json::to_string(&r).unwrap();
        let d: QualityResult = serde_json::from_str(&j).unwrap();
        // Check every field so a dropped or renamed field is caught.
        assert!((d.cosine_sim - 0.95).abs() < 1e-10);
        assert!((d.l2_dist - 0.32).abs() < 1e-10);
        assert!(d.passes_threshold);
    }
}