ruvector_attention/topology/
coherence.rs

1//! Window Coherence Metrics
2//!
3//! Fast structural metrics for measuring attention window stability.
4//! These are permission signals, not similarity signals.
5
6use serde::{Deserialize, Serialize};
7
/// Coherence metric type
///
/// Selects which structural score `WindowCoherence::compute` evaluates on the
/// k-NN graph built over a window's keys. Every metric yields a value where
/// higher means "more coherent".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CoherenceMetric {
    /// k-NN graph boundary ratio: the fraction of k-NN edges whose two
    /// endpoints are at most `n / 4` index positions apart (`n` = window size).
    BoundaryMass,
    /// Cut proxy score (edge cut estimate): one minus the fraction of k-NN
    /// edges that cross the index midpoint of the window.
    CutProxy,
    /// Disagreement among neighbors: one minus the average (capped at 1) of the
    /// per-key variance of cosine similarities between a key and its neighbors.
    Disagreement,
    /// Average neighbor similarity variance: the mean cosine similarity over all
    /// k-NN edges, scaled by one minus their standard deviation (capped at 1),
    /// then limited to `[0, 1]`.
    SimilarityVariance,
}
20
/// Per-window coherence scores
///
/// Produced by `WindowCoherence::compute`; afterwards the staleness fields are
/// driven by `mark_stale`, `tick` and queried through `needs_update`.
#[derive(Debug, Clone)]
pub struct WindowCoherence {
    /// Overall coherence score (0 = fragmented, 1 = coherent)
    pub score: f32,
    /// Individual metric scores. For windows with at least two keys, `compute`
    /// stores one score per entry of `metrics`, in the same order.
    pub metric_scores: Vec<f32>,
    /// Which metrics were used (a copy of the slice passed to `compute`)
    pub metrics: Vec<CoherenceMetric>,
    /// Number of keys in window
    pub window_size: usize,
    /// Whether this coherence is stale (needs update); set by `mark_stale`,
    /// `false` on a freshly computed value
    pub is_stale: bool,
    /// Token count since last update; incremented by `tick`, zero on a freshly
    /// computed value
    pub tokens_since_update: usize,
}
37
38impl WindowCoherence {
39    /// Compute coherence from keys
40    pub fn compute(
41        keys: &[&[f32]],
42        k_neighbors: usize,
43        metrics: &[CoherenceMetric],
44    ) -> Self {
45        let n = keys.len();
46        if n < 2 {
47            return Self {
48                score: 1.0,
49                metric_scores: vec![1.0],
50                metrics: metrics.to_vec(),
51                window_size: n,
52                is_stale: false,
53                tokens_since_update: 0,
54            };
55        }
56
57        // Build k-NN graph (fast approximate)
58        let knn_graph = Self::build_knn_graph(keys, k_neighbors);
59
60        // Compute each metric
61        let metric_scores: Vec<f32> = metrics
62            .iter()
63            .map(|m| Self::compute_metric(*m, keys, &knn_graph))
64            .collect();
65
66        // Average scores for overall coherence
67        let score = metric_scores.iter().sum::<f32>() / metric_scores.len() as f32;
68
69        Self {
70            score,
71            metric_scores,
72            metrics: metrics.to_vec(),
73            window_size: n,
74            is_stale: false,
75            tokens_since_update: 0,
76        }
77    }
78
79    /// Mark as stale (needs recomputation)
80    pub fn mark_stale(&mut self) {
81        self.is_stale = true;
82    }
83
84    /// Increment token counter
85    pub fn tick(&mut self) {
86        self.tokens_since_update += 1;
87    }
88
89    /// Check if update is needed based on period
90    pub fn needs_update(&self, update_period: usize) -> bool {
91        self.is_stale || self.tokens_since_update >= update_period
92    }
93
94    /// Build approximate k-NN graph
95    /// Returns [N × k] indices of nearest neighbors
96    fn build_knn_graph(keys: &[&[f32]], k: usize) -> Vec<Vec<usize>> {
97        let n = keys.len();
98        let k = k.min(n - 1);
99
100        keys.iter()
101            .enumerate()
102            .map(|(i, key)| {
103                let mut distances: Vec<(usize, f32)> = keys
104                    .iter()
105                    .enumerate()
106                    .filter(|(j, _)| *j != i)
107                    .map(|(j, k2)| (j, Self::squared_distance(key, k2)))
108                    .collect();
109
110                distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
111
112                distances.iter().take(k).map(|(j, _)| *j).collect()
113            })
114            .collect()
115    }
116
117    /// Squared Euclidean distance
118    #[inline]
119    fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
120        a.iter()
121            .zip(b.iter())
122            .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
123            .sum()
124    }
125
126    /// Compute specific metric
127    fn compute_metric(
128        metric: CoherenceMetric,
129        keys: &[&[f32]],
130        knn_graph: &[Vec<usize>],
131    ) -> f32 {
132        match metric {
133            CoherenceMetric::BoundaryMass => Self::boundary_mass(knn_graph),
134            CoherenceMetric::CutProxy => Self::cut_proxy(knn_graph),
135            CoherenceMetric::Disagreement => Self::disagreement(keys, knn_graph),
136            CoherenceMetric::SimilarityVariance => Self::similarity_variance(keys, knn_graph),
137        }
138    }
139
140    /// Boundary mass: fraction of edges going to "far" neighbors
141    /// High coherence = most edges go to nearby neighbors
142    fn boundary_mass(knn_graph: &[Vec<usize>]) -> f32 {
143        if knn_graph.is_empty() {
144            return 1.0;
145        }
146
147        let n = knn_graph.len();
148        let mut internal_edges = 0;
149        let mut total_edges = 0;
150
151        for (i, neighbors) in knn_graph.iter().enumerate() {
152            for &j in neighbors {
153                total_edges += 1;
154                // "Internal" if neighbor is within n/4 positions
155                if (i as i32 - j as i32).unsigned_abs() as usize <= n / 4 {
156                    internal_edges += 1;
157                }
158            }
159        }
160
161        if total_edges == 0 {
162            return 1.0;
163        }
164
165        internal_edges as f32 / total_edges as f32
166    }
167
168    /// Cut proxy: estimate of graph cut cost
169    /// High coherence = low cut (well-connected)
170    fn cut_proxy(knn_graph: &[Vec<usize>]) -> f32 {
171        if knn_graph.is_empty() {
172            return 1.0;
173        }
174
175        let n = knn_graph.len();
176        let half = n / 2;
177
178        // Count edges crossing the midpoint
179        let mut crossing = 0;
180        let mut total = 0;
181
182        for (i, neighbors) in knn_graph.iter().enumerate() {
183            for &j in neighbors {
184                total += 1;
185                if (i < half) != (j < half) {
186                    crossing += 1;
187                }
188            }
189        }
190
191        if total == 0 {
192            return 1.0;
193        }
194
195        // Invert: high coherence = few crossings
196        1.0 - (crossing as f32 / total as f32)
197    }
198
199    /// Disagreement: variance in neighbor similarities
200    /// High coherence = neighbors have similar similarities
201    fn disagreement(keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
202        if knn_graph.is_empty() || keys.is_empty() {
203            return 1.0;
204        }
205
206        let mut total_variance = 0.0f32;
207        let mut count = 0;
208
209        for (i, neighbors) in knn_graph.iter().enumerate() {
210            if neighbors.is_empty() {
211                continue;
212            }
213
214            // Similarities to neighbors
215            let sims: Vec<f32> = neighbors
216                .iter()
217                .map(|&j| Self::cosine_similarity(keys[i], keys[j]))
218                .collect();
219
220            let mean: f32 = sims.iter().sum::<f32>() / sims.len() as f32;
221            let variance: f32 = sims.iter().map(|s| (s - mean) * (s - mean)).sum::<f32>() / sims.len() as f32;
222
223            total_variance += variance;
224            count += 1;
225        }
226
227        if count == 0 {
228            return 1.0;
229        }
230
231        // Low variance = high coherence
232        let avg_variance = total_variance / count as f32;
233        1.0 - avg_variance.min(1.0)
234    }
235
236    /// Similarity variance across window
237    fn similarity_variance(keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
238        if knn_graph.is_empty() || keys.is_empty() {
239            return 1.0;
240        }
241
242        // Collect all neighbor similarities
243        let mut all_sims = Vec::new();
244        for (i, neighbors) in knn_graph.iter().enumerate() {
245            for &j in neighbors {
246                all_sims.push(Self::cosine_similarity(keys[i], keys[j]));
247            }
248        }
249
250        if all_sims.is_empty() {
251            return 1.0;
252        }
253
254        let mean: f32 = all_sims.iter().sum::<f32>() / all_sims.len() as f32;
255        let variance: f32 = all_sims.iter().map(|s| (s - mean) * (s - mean)).sum::<f32>() / all_sims.len() as f32;
256
257        // Low variance + high mean = high coherence
258        let coherence = mean * (1.0 - variance.sqrt().min(1.0));
259        coherence.max(0.0).min(1.0)
260    }
261
262    /// Cosine similarity
263    #[inline]
264    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
265        let dot: f32 = a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum();
266        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
267        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
268
269        if norm_a < 1e-8 || norm_b < 1e-8 {
270            return 0.0;
271        }
272
273        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows every owned key as a slice, the shape `compute` expects.
    fn as_slices(owned: &[Vec<f32>]) -> Vec<&[f32]> {
        owned.iter().map(Vec::as_slice).collect()
    }

    /// The averaged score of a 20-key ramp stays inside `[0, 1]` and the
    /// window size is recorded.
    #[test]
    fn test_coherence_computation() {
        let mut owned: Vec<Vec<f32>> = Vec::with_capacity(20);
        for i in 0..20 {
            owned.push(vec![i as f32 * 0.1; 32]);
        }
        let window = as_slices(&owned);
        let selected = [
            CoherenceMetric::BoundaryMass,
            CoherenceMetric::SimilarityVariance,
        ];

        let result = WindowCoherence::compute(&window, 5, &selected);

        assert!((0.0..=1.0).contains(&result.score));
        assert_eq!(result.window_size, 20);
    }

    /// Identical keys must come out as a very coherent window.
    #[test]
    fn test_coherent_window() {
        let owned: Vec<Vec<f32>> = vec![vec![0.5f32; 16]; 10];
        let window = as_slices(&owned);

        let result = WindowCoherence::compute(&window, 3, &[CoherenceMetric::Disagreement]);

        assert!(result.score > 0.8);
    }

    /// A fresh value needs no update; after `update_period` ticks it does.
    #[test]
    fn test_stale_tracking() {
        let owned: Vec<Vec<f32>> = vec![vec![1.0; 8]; 5];
        let window = as_slices(&owned);
        let mut tracked =
            WindowCoherence::compute(&window, 2, &[CoherenceMetric::BoundaryMass]);

        assert!(!tracked.needs_update(4));

        for _ in 0..4 {
            tracked.tick();
        }

        assert!(tracked.needs_update(4));
    }
}