ruvector_coherence/
comparison.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ComparisonResult {
8 pub jaccard: f64,
9 pub edge_flips: usize,
10 pub baseline_edges: usize,
11 pub gated_edges: usize,
12 pub sparsity_ratio: f64,
13}
14
/// Jaccard similarity (intersection over union) of the `true` entries of two
/// boolean edge masks.
///
/// Masks of different lengths are compared as if the shorter one were padded
/// with `false`. `true` entries in the longer mask's tail therefore count
/// towards the union but never towards the intersection.
///
/// Returns `1.0` when the union is empty (neither mask has a `true` entry),
/// treating two empty edge sets as identical.
pub fn jaccard_similarity(mask_a: &[bool], mask_b: &[bool]) -> f64 {
    let n = mask_a.len().min(mask_b.len());

    // Walk the overlapping prefix with `zip` rather than indexing. This avoids
    // per-element bounds checks and clippy::needless_range_loop.
    let (inter, overlap_union) = mask_a
        .iter()
        .zip(mask_b)
        .fold((0usize, 0usize), |(inter, union), (&a, &b)| {
            (inter + usize::from(a && b), union + usize::from(a || b))
        });

    // `n` never exceeds either length, so both tail slices are in bounds and
    // at most one of them is non-empty.
    let tail = mask_a[n..]
        .iter()
        .chain(&mask_b[n..])
        .filter(|&&v| v)
        .count();

    let union = overlap_union + tail;
    if union == 0 {
        1.0
    } else {
        inter as f64 / union as f64
    }
}
34
35pub fn edge_flip_count(mask_a: &[bool], mask_b: &[bool]) -> usize {
37 let n = mask_a.len().min(mask_b.len());
38 let mut flips = (0..n).filter(|&i| mask_a[i] != mask_b[i]).count();
39 flips += count_true_tail(mask_a, n) + count_true_tail(mask_b, n);
40 flips
41}
42
43pub fn compare_attention_masks(baseline: &[bool], gated: &[bool]) -> ComparisonResult {
45 let baseline_edges = baseline.iter().filter(|&&v| v).count();
46 let gated_edges = gated.iter().filter(|&&v| v).count();
47 let total = baseline.len().max(gated.len());
48 let bl_sp = if total > 0 {
49 1.0 - baseline_edges as f64 / total as f64
50 } else {
51 1.0
52 };
53 let gt_sp = if total > 0 {
54 1.0 - gated_edges as f64 / total as f64
55 } else {
56 1.0
57 };
58 ComparisonResult {
59 jaccard: jaccard_similarity(baseline, gated),
60 edge_flips: edge_flip_count(baseline, gated),
61 baseline_edges,
62 gated_edges,
63 sparsity_ratio: if bl_sp > f64::EPSILON {
64 gt_sp / bl_sp
65 } else {
66 gt_sp
67 },
68 }
69}
70
/// Count the `true` entries of `mask` at indices `from..`.
///
/// Returns 0 when `from` is at or beyond the end of `mask`.
fn count_true_tail(mask: &[bool], from: usize) -> usize {
    mask.get(from..)
        .map_or(0, |tail| tail.iter().filter(|&&set| set).count())
}
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn jaccard_cases() {
        let m = vec![true, false, true, true];
        assert!((jaccard_similarity(&m, &m) - 1.0).abs() < 1e-10);
        assert!(jaccard_similarity(&[true, false], &[false, true]).abs() < 1e-10);
        assert_eq!(jaccard_similarity(&[], &[]), 1.0);
        let (a, b) = (
            vec![true, true, false, false],
            vec![true, false, true, false],
        );
        assert!((jaccard_similarity(&a, &b) - 1.0 / 3.0).abs() < 1e-10);
        // Unequal lengths: the longer mask's tail counts towards the union only.
        assert!((jaccard_similarity(&[true], &[true, true]) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn edge_flip_cases() {
        assert_eq!(edge_flip_count(&[true, false], &[true, false]), 0);
        assert_eq!(
            edge_flip_count(&[true, false, true], &[false, true, false]),
            3
        );
        assert_eq!(
            edge_flip_count(&[true, false], &[true, false, true, true]),
            2
        );
    }

    #[test]
    fn compare_masks() {
        let bl = vec![true, true, false, false, true];
        let gt = vec![true, false, false, true, true];
        let r = compare_attention_masks(&bl, &gt);
        assert_eq!(r.baseline_edges, 3);
        assert_eq!(r.gated_edges, 3);
        assert_eq!(r.edge_flips, 2);
        assert!((r.jaccard - 0.5).abs() < 1e-10);
        // Same edge count over the same length gives equal sparsity.
        assert!((r.sparsity_ratio - 1.0).abs() < 1e-10);
    }

    #[test]
    fn compare_masks_edge_cases() {
        // Empty masks: sparsity defaults to 1.0 on both sides.
        let r = compare_attention_masks(&[], &[]);
        assert_eq!(r.baseline_edges, 0);
        assert_eq!(r.gated_edges, 0);
        assert_eq!(r.edge_flips, 0);
        assert_eq!(r.jaccard, 1.0);
        assert_eq!(r.sparsity_ratio, 1.0);

        // Fully dense baseline: the ratio falls back to the gated sparsity.
        let r = compare_attention_masks(&[true, true], &[true, false]);
        assert!((r.sparsity_ratio - 0.5).abs() < 1e-10);
    }

    #[test]
    fn count_true_tail_cases() {
        assert_eq!(count_true_tail(&[true, false, true], 0), 2);
        assert_eq!(count_true_tail(&[true, false, true], 2), 1);
        assert_eq!(count_true_tail(&[true, false, true], 3), 0);
        assert_eq!(count_true_tail(&[true], 5), 0);
    }
}