ruvector_coherence/
comparison.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ComparisonResult {
8 pub jaccard: f64,
9 pub edge_flips: usize,
10 pub baseline_edges: usize,
11 pub gated_edges: usize,
12 pub sparsity_ratio: f64,
13}
14
15pub fn jaccard_similarity(mask_a: &[bool], mask_b: &[bool]) -> f64 {
17 let n = mask_a.len().min(mask_b.len());
18 let (mut inter, mut union) = (0usize, 0usize);
19 for i in 0..n {
20 if mask_a[i] || mask_b[i] { union += 1; }
21 if mask_a[i] && mask_b[i] { inter += 1; }
22 }
23 union += count_true_tail(mask_a, n) + count_true_tail(mask_b, n);
24 if union == 0 { 1.0 } else { inter as f64 / union as f64 }
25}
26
27pub fn edge_flip_count(mask_a: &[bool], mask_b: &[bool]) -> usize {
29 let n = mask_a.len().min(mask_b.len());
30 let mut flips = (0..n).filter(|&i| mask_a[i] != mask_b[i]).count();
31 flips += count_true_tail(mask_a, n) + count_true_tail(mask_b, n);
32 flips
33}
34
35pub fn compare_attention_masks(baseline: &[bool], gated: &[bool]) -> ComparisonResult {
37 let baseline_edges = baseline.iter().filter(|&&v| v).count();
38 let gated_edges = gated.iter().filter(|&&v| v).count();
39 let total = baseline.len().max(gated.len());
40 let bl_sp = if total > 0 { 1.0 - baseline_edges as f64 / total as f64 } else { 1.0 };
41 let gt_sp = if total > 0 { 1.0 - gated_edges as f64 / total as f64 } else { 1.0 };
42 ComparisonResult {
43 jaccard: jaccard_similarity(baseline, gated),
44 edge_flips: edge_flip_count(baseline, gated),
45 baseline_edges,
46 gated_edges,
47 sparsity_ratio: if bl_sp > f64::EPSILON { gt_sp / bl_sp } else { gt_sp },
48 }
49}
50
/// Counts the `true` entries of `mask` at index `from` or later.
///
/// Returns `0` when `from` is at or past the end of the slice.
fn count_true_tail(mask: &[bool], from: usize) -> usize {
    mask.iter().skip(from).filter(|&&set| set).count()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    #[test]
    fn jaccard_cases() {
        let m = [true, false, true, true];
        assert!((jaccard_similarity(&m, &m) - 1.0).abs() < TOL);
        assert!(jaccard_similarity(&[true, false], &[false, true]).abs() < TOL);
        assert_eq!(jaccard_similarity(&[], &[]), 1.0);
        let (a, b) = ([true, true, false, false], [true, false, true, false]);
        assert!((jaccard_similarity(&a, &b) - 1.0 / 3.0).abs() < TOL);
    }

    #[test]
    fn jaccard_unequal_lengths() {
        // Entries past the shorter mask only ever grow the union.
        let j = jaccard_similarity(&[true, false], &[true, false, true, true]);
        assert!((j - 1.0 / 3.0).abs() < TOL);
        assert_eq!(jaccard_similarity(&[true], &[]), 0.0);
        assert_eq!(jaccard_similarity(&[false, false], &[false]), 1.0);
    }

    #[test]
    fn edge_flip_cases() {
        assert_eq!(edge_flip_count(&[true, false], &[true, false]), 0);
        assert_eq!(edge_flip_count(&[true, false, true], &[false, true, false]), 3);
        assert_eq!(edge_flip_count(&[true, false], &[true, false, true, true]), 2);
        assert_eq!(edge_flip_count(&[true, true, true], &[true]), 2);
        assert_eq!(edge_flip_count(&[], &[]), 0);
    }

    #[test]
    fn compare_masks() {
        // Named `gated` rather than `gt` so the `&gt` argument cannot be
        // mangled into an HTML entity when the file is rendered or pasted.
        let baseline = [true, true, false, false, true];
        let gated = [true, false, false, true, true];
        let r = compare_attention_masks(&baseline, &gated);
        assert_eq!(r.baseline_edges, 3);
        assert_eq!(r.gated_edges, 3);
        assert_eq!(r.edge_flips, 2);
        assert!((r.jaccard - 0.5).abs() < TOL);
        // Equal edge counts give equal sparsity, hence a ratio of 1.
        assert!((r.sparsity_ratio - 1.0).abs() < TOL);
    }

    #[test]
    fn compare_dense_baseline_reports_gated_sparsity() {
        // Baseline sparsity is 0, so the ratio falls back to gated sparsity (1 - 1/2).
        let r = compare_attention_masks(&[true, true], &[true, false]);
        assert!((r.sparsity_ratio - 0.5).abs() < TOL);
        assert_eq!(r.edge_flips, 1);
        assert!((r.jaccard - 0.5).abs() < TOL);
    }

    #[test]
    fn compare_empty_masks() {
        let r = compare_attention_masks(&[], &[]);
        assert_eq!(r.baseline_edges, 0);
        assert_eq!(r.gated_edges, 0);
        assert_eq!(r.edge_flips, 0);
        assert_eq!(r.jaccard, 1.0);
        assert!((r.sparsity_ratio - 1.0).abs() < TOL);
    }
}