ruvector_coherence/
quality.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct QualityResult {
8 pub cosine_sim: f64,
9 pub l2_dist: f64,
10 pub passes_threshold: bool,
11}
12
/// Cosine similarity between `a` and `b`, accumulated in `f64`.
///
/// Vectors of different lengths are compared as if the shorter one were
/// zero-padded, consistent with [`l2_distance`]. The unmatched tail adds
/// nothing to the dot product, but it still counts toward the longer vector's
/// norm. A length mismatch therefore lowers the similarity instead of being
/// silently ignored.
///
/// Returns `0.0` when either vector has zero norm (including empty input),
/// because the angle is undefined there. Otherwise the result is clamped to
/// `[-1.0, 1.0]` to absorb floating-point rounding. NaN inputs yield NaN.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    // `zip` stops at the shorter slice: that is exactly the zero-padded dot product.
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();

    // Norms run over the FULL slices so any unmatched tail is penalised.
    let norm_sq = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| {
                let x = f64::from(x);
                x * x
            })
            .sum()
    };
    let (na, nb) = (norm_sq(a), norm_sq(b));

    // Exact zero check: cosine is scale-invariant, so an absolute epsilon would
    // wrongly report 0.0 for tiny but non-zero vectors. Squares of f32 values
    // cannot underflow in f64, so a norm is 0.0 only for an all-zero vector.
    if na == 0.0 || nb == 0.0 {
        return 0.0;
    }
    (dot / (na.sqrt() * nb.sqrt())).clamp(-1.0, 1.0)
}
26
/// Euclidean (L2) distance between `a` and `b`, accumulated in `f64`.
///
/// When the lengths differ, the shorter vector is treated as zero-padded, so
/// every unmatched element contributes its own square to the distance.
/// Two empty slices are at distance `0.0`.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f64 {
    let shared = a.len().min(b.len());

    // Squared differences over the overlapping prefix, summed left to right.
    let mut total = a[..shared]
        .iter()
        .zip(&b[..shared])
        .map(|(&x, &y)| {
            let diff = f64::from(x) - f64::from(y);
            diff * diff
        })
        .fold(0.0_f64, |acc, sq| acc + sq);

    // At most one of these tails is non-empty; it is measured against zeros.
    for tail in [&a[shared..], &b[shared..]] {
        if !tail.is_empty() {
            total += tail.iter().map(|&v| f64::from(v).powi(2)).sum::<f64>();
        }
    }

    total.sqrt()
}
39
40pub fn quality_check(baseline_output: &[f32], gated_output: &[f32], threshold: f64) -> QualityResult {
42 let cosine_sim = cosine_similarity(baseline_output, gated_output);
43 let l2_dist = l2_distance(baseline_output, gated_output);
44 QualityResult { cosine_sim, l2_dist, passes_threshold: cosine_sim >= threshold }
45}
46
#[cfg(test)]
mod tests {
    //! Unit tests for the similarity metrics, the quality gate, and the
    //! serialized form of `QualityResult`.
    use super::*;

    #[test]
    fn cosine_cases() {
        assert!((cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]) - 1.0).abs() < 1e-10);
        assert!((cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-10);
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-10);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        // Empty input has no defined angle and reports 0.0.
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn l2_cases() {
        assert!(l2_distance(&[1.0, 2.0], &[1.0, 2.0]) < 1e-10);
        assert!((l2_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 1e-10);
        // The tail of the longer vector is measured against zeros, in either direction.
        assert!((l2_distance(&[1.0], &[1.0, 3.0]) - 3.0).abs() < 1e-10);
        assert!((l2_distance(&[1.0, 2.0, 2.0], &[]) - 3.0).abs() < 1e-10);
        assert_eq!(l2_distance(&[], &[]), 0.0);
    }

    #[test]
    fn quality_check_pass_and_fail() {
        let r = quality_check(&[1.0, 2.0, 3.0], &[1.1, 2.1, 3.1], 0.99);
        assert!(r.passes_threshold);
        assert!(r.cosine_sim >= 0.99);

        let r2 = quality_check(&[1.0, 0.0], &[0.0, 1.0], 0.5);
        assert!(!r2.passes_threshold);
        assert!(r2.cosine_sim.abs() < 1e-10);
        assert!((r2.l2_dist - 2.0_f64.sqrt()).abs() < 1e-10);

        // The gate is inclusive: a similarity exactly at the threshold passes.
        let r3 = quality_check(&[1.0, 0.0], &[1.0, 0.0], 1.0);
        assert!(r3.passes_threshold);
        assert_eq!(r3.l2_dist, 0.0);
    }

    #[test]
    fn quality_result_serializable() {
        let r = QualityResult { cosine_sim: 0.95, l2_dist: 0.32, passes_threshold: true };
        let j = serde_json::to_string(&r).unwrap();
        // Pin the wire-format field names so renames are caught.
        assert!(j.contains("\"cosine_sim\""));
        assert!(j.contains("\"l2_dist\""));
        assert!(j.contains("\"passes_threshold\":true"));

        let d: QualityResult = serde_json::from_str(&j).unwrap();
        // Tolerances: serde_json's default float parsing is not guaranteed bit-exact.
        assert!((d.cosine_sim - 0.95).abs() < 1e-10);
        assert!((d.l2_dist - 0.32).abs() < 1e-10);
        assert_eq!(d.passes_threshold, r.passes_threshold);
    }
}