Skip to main content

ruvector_coherence/
comparison.rs

1//! Side-by-side comparison utilities for attention masks.
2
3use serde::{Deserialize, Serialize};
4
/// Result of comparing two attention masks.
///
/// The field descriptions below reflect how `compare_attention_masks` fills
/// this struct. Every field is public, so values built elsewhere may not
/// follow the same conventions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// Jaccard similarity of the two masks, as returned by `jaccard_similarity`.
    pub jaccard: f64,
    /// Number of positions where the masks disagree, as returned by `edge_flip_count`.
    pub edge_flips: usize,
    /// Count of `true` entries in the baseline mask.
    pub baseline_edges: usize,
    /// Count of `true` entries in the gated mask.
    pub gated_edges: usize,
    /// Gated sparsity divided by baseline sparsity, where sparsity is
    /// `1 - edges / max(len(baseline), len(gated))`. When the baseline sparsity
    /// is not greater than `f64::EPSILON`, this holds the gated sparsity itself.
    pub sparsity_ratio: f64,
}
14
/// Jaccard similarity of two boolean masks: `|A & B| / |A | B|`.
///
/// Each mask is read as the set of indices holding `true`. The masks may
/// differ in length: positions past the end of the shorter mask count as
/// `false`, so `true` entries in the longer mask's tail enlarge the union but
/// never the intersection.
///
/// Returns `1.0` when the union is empty (both masks empty, or containing only
/// `false`); otherwise returns a value in `[0.0, 1.0]`.
pub fn jaccard_similarity(mask_a: &[bool], mask_b: &[bool]) -> f64 {
    let count_set = |mask: &[bool]| mask.iter().filter(|&&bit| bit).count();

    // `zip` stops at the shorter mask, which is the only region where both
    // masks can be set at the same index.
    let inter = mask_a
        .iter()
        .zip(mask_b)
        .filter(|&(&a, &b)| a && b)
        .count();

    // Inclusion-exclusion: |A | B| = |A| + |B| - |A & B|. Tail entries of the
    // longer mask are included through its own count. The subtraction cannot
    // underflow because `inter` never exceeds either individual count.
    let union = count_set(mask_a) + count_set(mask_b) - inter;

    if union == 0 {
        1.0
    } else {
        inter as f64 / union as f64
    }
}
26
27/// Counts positions where the two masks disagree.
28pub fn edge_flip_count(mask_a: &[bool], mask_b: &[bool]) -> usize {
29    let n = mask_a.len().min(mask_b.len());
30    let mut flips = (0..n).filter(|&i| mask_a[i] != mask_b[i]).count();
31    flips += count_true_tail(mask_a, n) + count_true_tail(mask_b, n);
32    flips
33}
34
35/// Full comparison of two attention masks.
36pub fn compare_attention_masks(baseline: &[bool], gated: &[bool]) -> ComparisonResult {
37    let baseline_edges = baseline.iter().filter(|&&v| v).count();
38    let gated_edges = gated.iter().filter(|&&v| v).count();
39    let total = baseline.len().max(gated.len());
40    let bl_sp = if total > 0 { 1.0 - baseline_edges as f64 / total as f64 } else { 1.0 };
41    let gt_sp = if total > 0 { 1.0 - gated_edges as f64 / total as f64 } else { 1.0 };
42    ComparisonResult {
43        jaccard: jaccard_similarity(baseline, gated),
44        edge_flips: edge_flip_count(baseline, gated),
45        baseline_edges,
46        gated_edges,
47        sparsity_ratio: if bl_sp > f64::EPSILON { gt_sp / bl_sp } else { gt_sp },
48    }
49}
50
/// Counts the `true` entries of `mask` at index `from` and beyond.
///
/// Yields `0` when `from` is at or past the end of `mask`.
fn count_true_tail(mask: &[bool], from: usize) -> usize {
    mask.iter().skip(from).filter(|&&set| set).count()
}
54
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-10;

    #[test]
    fn jaccard_cases() {
        let m = vec![true, false, true, true];
        assert!((jaccard_similarity(&m, &m) - 1.0).abs() < TOL);
        assert!(jaccard_similarity(&[true, false], &[false, true]).abs() < TOL);
        assert_eq!(jaccard_similarity(&[], &[]), 1.0);
        // partial: intersection=1, union=3
        let (a, b) = (vec![true, true, false, false], vec![true, false, true, false]);
        assert!((jaccard_similarity(&a, &b) - 1.0 / 3.0).abs() < TOL);
    }

    #[test]
    fn jaccard_unequal_lengths() {
        // Tail `true` entries of the longer mask join the union only:
        // intersection=1, union=3, regardless of argument order.
        let short = [true, false];
        let long = [true, false, true, true];
        assert!((jaccard_similarity(&short, &long) - 1.0 / 3.0).abs() < TOL);
        assert!((jaccard_similarity(&long, &short) - 1.0 / 3.0).abs() < TOL);
        // All-false masks have an empty union and compare as identical.
        assert_eq!(jaccard_similarity(&[false, false], &[false]), 1.0);
    }

    #[test]
    fn edge_flip_cases() {
        assert_eq!(edge_flip_count(&[true, false], &[true, false]), 0);
        assert_eq!(edge_flip_count(&[true, false, true], &[false, true, false]), 3);
        assert_eq!(edge_flip_count(&[true, false], &[true, false, true, true]), 2);
    }

    #[test]
    fn compare_masks() {
        let bl = vec![true, true, false, false, true];
        let gt = vec![true, false, false, true, true];
        let r = compare_attention_masks(&bl, &gt);
        assert_eq!(r.baseline_edges, 3);
        assert_eq!(r.gated_edges, 3);
        assert_eq!(r.edge_flips, 2);
        assert!((r.jaccard - 0.5).abs() < TOL);
        // Equal edge counts over the same length give equal sparsities.
        assert!((r.sparsity_ratio - 1.0).abs() < TOL);
    }

    #[test]
    fn compare_masks_degenerate() {
        // Two empty masks: both sparsities default to 1.0.
        let empty = compare_attention_masks(&[], &[]);
        assert_eq!(empty.baseline_edges, 0);
        assert_eq!(empty.gated_edges, 0);
        assert_eq!(empty.edge_flips, 0);
        assert_eq!(empty.jaccard, 1.0);
        assert!((empty.sparsity_ratio - 1.0).abs() < TOL);

        // Fully dense baseline: baseline sparsity is 0, so the ratio falls
        // back to the gated sparsity (1 - 1/2).
        let dense = compare_attention_masks(&[true, true], &[true, false]);
        assert_eq!(dense.baseline_edges, 2);
        assert_eq!(dense.gated_edges, 1);
        assert_eq!(dense.edge_flips, 1);
        assert!((dense.jaccard - 0.5).abs() < TOL);
        assert!((dense.sparsity_ratio - 0.5).abs() < TOL);
    }
}
87}