Skip to main content

ruvector_coherence/
comparison.rs

1//! Side-by-side comparison utilities for attention masks.
2
3use serde::{Deserialize, Serialize};
4
5/// Result of comparing two attention masks.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ComparisonResult {
8    pub jaccard: f64,
9    pub edge_flips: usize,
10    pub baseline_edges: usize,
11    pub gated_edges: usize,
12    pub sparsity_ratio: f64,
13}
14
/// Jaccard similarity: `|A & B| / |A | B|`.
///
/// Masks of different lengths are compared as if the shorter one were padded
/// with `false`. A `true` past the end of the shorter mask therefore enlarges
/// the union but can never be part of the intersection.
///
/// Returns `1.0` when the union is empty, i.e. when neither mask contains a
/// `true` (this includes two empty masks).
pub fn jaccard_similarity(mask_a: &[bool], mask_b: &[bool]) -> f64 {
    let n = mask_a.len().min(mask_b.len());
    let (mut inter, mut union) = (0usize, 0usize);
    // `zip` stops at the shorter mask, so this visits exactly the shared prefix.
    for (&a, &b) in mask_a.iter().zip(mask_b) {
        inter += usize::from(a && b);
        union += usize::from(a || b);
    }
    // `n` never exceeds either length, so both subslices are in bounds; at most
    // one of them is non-empty, and each `true` in it belongs to A or B only.
    union += mask_a[n..]
        .iter()
        .chain(&mask_b[n..])
        .filter(|&&v| v)
        .count();
    if union == 0 {
        1.0
    } else {
        inter as f64 / union as f64
    }
}
34
35/// Counts positions where the two masks disagree.
36pub fn edge_flip_count(mask_a: &[bool], mask_b: &[bool]) -> usize {
37    let n = mask_a.len().min(mask_b.len());
38    let mut flips = (0..n).filter(|&i| mask_a[i] != mask_b[i]).count();
39    flips += count_true_tail(mask_a, n) + count_true_tail(mask_b, n);
40    flips
41}
42
43/// Full comparison of two attention masks.
44pub fn compare_attention_masks(baseline: &[bool], gated: &[bool]) -> ComparisonResult {
45    let baseline_edges = baseline.iter().filter(|&&v| v).count();
46    let gated_edges = gated.iter().filter(|&&v| v).count();
47    let total = baseline.len().max(gated.len());
48    let bl_sp = if total > 0 {
49        1.0 - baseline_edges as f64 / total as f64
50    } else {
51        1.0
52    };
53    let gt_sp = if total > 0 {
54        1.0 - gated_edges as f64 / total as f64
55    } else {
56        1.0
57    };
58    ComparisonResult {
59        jaccard: jaccard_similarity(baseline, gated),
60        edge_flips: edge_flip_count(baseline, gated),
61        baseline_edges,
62        gated_edges,
63        sparsity_ratio: if bl_sp > f64::EPSILON {
64            gt_sp / bl_sp
65        } else {
66            gt_sp
67        },
68    }
69}
70
/// Counts the `true` entries of `mask` at indices `from..`.
///
/// Yields `0` when `from` is at or beyond the end of the slice.
fn count_true_tail(mask: &[bool], from: usize) -> usize {
    mask.get(from..)
        .map_or(0, |tail| tail.iter().filter(|&&v| v).count())
}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point assertions.
    const EPS: f64 = 1e-10;

    #[test]
    fn jaccard_cases() {
        let m = vec![true, false, true, true];
        assert!((jaccard_similarity(&m, &m) - 1.0).abs() < EPS);
        assert!(jaccard_similarity(&[true, false], &[false, true]).abs() < EPS);
        assert_eq!(jaccard_similarity(&[], &[]), 1.0);
        // partial: intersection=1, union=3
        let (a, b) = (
            vec![true, true, false, false],
            vec![true, false, true, false],
        );
        assert!((jaccard_similarity(&a, &b) - 1.0 / 3.0).abs() < EPS);
    }

    #[test]
    fn jaccard_unequal_lengths_pads_with_false() {
        // intersection={0}, union={0,2,3}: the longer mask's trailing `true`s
        // only grow the union.
        let (short, long) = ([true, false], [true, false, true, true]);
        assert!((jaccard_similarity(&short, &long) - 1.0 / 3.0).abs() < EPS);
        assert!((jaccard_similarity(&long, &short) - 1.0 / 3.0).abs() < EPS);
        // No `true` anywhere: the union is empty, so similarity is defined as 1.0.
        assert_eq!(jaccard_similarity(&[false, false], &[false]), 1.0);
    }

    #[test]
    fn edge_flip_cases() {
        assert_eq!(edge_flip_count(&[true, false], &[true, false]), 0);
        assert_eq!(
            edge_flip_count(&[true, false, true], &[false, true, false]),
            3
        );
        assert_eq!(
            edge_flip_count(&[true, false], &[true, false, true, true]),
            2
        );
        // Symmetric in its arguments, including the tail handling.
        assert_eq!(
            edge_flip_count(&[true, false, true, true], &[true, false]),
            2
        );
        assert_eq!(edge_flip_count(&[], &[]), 0);
    }

    #[test]
    fn compare_masks() {
        let bl = vec![true, true, false, false, true];
        let gt = vec![true, false, false, true, true];
        let r = compare_attention_masks(&bl, &gt);
        assert_eq!(r.baseline_edges, 3);
        assert_eq!(r.gated_edges, 3);
        assert_eq!(r.edge_flips, 2);
        assert!((r.jaccard - 0.5).abs() < EPS);
        // Equal edge counts over the same length give equal sparsity (0.4 each).
        assert!((r.sparsity_ratio - 1.0).abs() < EPS);
    }

    #[test]
    fn compare_unequal_lengths() {
        // total = 4; baseline sparsity 0.75, gated sparsity 0.25.
        let r = compare_attention_masks(&[true, false], &[true, false, true, true]);
        assert_eq!(r.baseline_edges, 1);
        assert_eq!(r.gated_edges, 3);
        assert_eq!(r.edge_flips, 2);
        assert!((r.jaccard - 1.0 / 3.0).abs() < EPS);
        assert!((r.sparsity_ratio - 1.0 / 3.0).abs() < EPS);
    }

    #[test]
    fn compare_dense_baseline_reports_gated_sparsity() {
        // Baseline sparsity is 0, so the ratio falls back to the gated sparsity (0.5).
        let r = compare_attention_masks(&[true, true], &[true, false]);
        assert_eq!(r.baseline_edges, 2);
        assert_eq!(r.gated_edges, 1);
        assert!((r.sparsity_ratio - 0.5).abs() < EPS);
    }

    #[test]
    fn compare_empty_masks() {
        let r = compare_attention_masks(&[], &[]);
        assert_eq!(r.jaccard, 1.0);
        assert_eq!(r.edge_flips, 0);
        assert_eq!(r.baseline_edges, 0);
        assert_eq!(r.gated_edges, 0);
        assert!((r.sparsity_ratio - 1.0).abs() < EPS);
    }
}
120}