Skip to main content

ruvector_attention/topology/
coherence.rs

1//! Window Coherence Metrics
2//!
3//! Fast structural metrics for measuring attention window stability.
4//! These are permission signals, not similarity signals.
5
6use serde::{Deserialize, Serialize};
7
8/// Coherence metric type
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum CoherenceMetric {
11    /// k-NN graph boundary ratio
12    BoundaryMass,
13    /// Cut proxy score (edge cut estimate)
14    CutProxy,
15    /// Disagreement across neighbor labels
16    Disagreement,
17    /// Average neighbor similarity variance
18    SimilarityVariance,
19}
20
21/// Per-window coherence scores
22#[derive(Debug, Clone)]
23pub struct WindowCoherence {
24    /// Overall coherence score (0 = fragmented, 1 = coherent)
25    pub score: f32,
26    /// Individual metric scores
27    pub metric_scores: Vec<f32>,
28    /// Which metrics were used
29    pub metrics: Vec<CoherenceMetric>,
30    /// Number of keys in window
31    pub window_size: usize,
32    /// Whether this coherence is stale (needs update)
33    pub is_stale: bool,
34    /// Token count since last update
35    pub tokens_since_update: usize,
36}
37
38impl WindowCoherence {
39    /// Compute coherence from keys
40    pub fn compute(keys: &[&[f32]], k_neighbors: usize, metrics: &[CoherenceMetric]) -> Self {
41        let n = keys.len();
42        if n < 2 {
43            return Self {
44                score: 1.0,
45                metric_scores: vec![1.0],
46                metrics: metrics.to_vec(),
47                window_size: n,
48                is_stale: false,
49                tokens_since_update: 0,
50            };
51        }
52
53        // Build k-NN graph (fast approximate)
54        let knn_graph = Self::build_knn_graph(keys, k_neighbors);
55
56        // Compute each metric
57        let metric_scores: Vec<f32> = metrics
58            .iter()
59            .map(|m| Self::compute_metric(*m, keys, &knn_graph))
60            .collect();
61
62        // Average scores for overall coherence
63        let score = metric_scores.iter().sum::<f32>() / metric_scores.len() as f32;
64
65        Self {
66            score,
67            metric_scores,
68            metrics: metrics.to_vec(),
69            window_size: n,
70            is_stale: false,
71            tokens_since_update: 0,
72        }
73    }
74
75    /// Mark as stale (needs recomputation)
76    pub fn mark_stale(&mut self) {
77        self.is_stale = true;
78    }
79
80    /// Increment token counter
81    pub fn tick(&mut self) {
82        self.tokens_since_update += 1;
83    }
84
85    /// Check if update is needed based on period
86    pub fn needs_update(&self, update_period: usize) -> bool {
87        self.is_stale || self.tokens_since_update >= update_period
88    }
89
90    /// Build approximate k-NN graph
91    /// Returns [N × k] indices of nearest neighbors
92    fn build_knn_graph(keys: &[&[f32]], k: usize) -> Vec<Vec<usize>> {
93        let n = keys.len();
94        let k = k.min(n - 1);
95
96        keys.iter()
97            .enumerate()
98            .map(|(i, key)| {
99                let mut distances: Vec<(usize, f32)> = keys
100                    .iter()
101                    .enumerate()
102                    .filter(|(j, _)| *j != i)
103                    .map(|(j, k2)| (j, Self::squared_distance(key, k2)))
104                    .collect();
105
106                distances.sort_unstable_by(|a, b| {
107                    a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
108                });
109
110                distances.iter().take(k).map(|(j, _)| *j).collect()
111            })
112            .collect()
113    }
114
115    /// Squared Euclidean distance
116    #[inline]
117    fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
118        a.iter()
119            .zip(b.iter())
120            .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
121            .sum()
122    }
123
124    /// Compute specific metric
125    fn compute_metric(metric: CoherenceMetric, keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
126        match metric {
127            CoherenceMetric::BoundaryMass => Self::boundary_mass(knn_graph),
128            CoherenceMetric::CutProxy => Self::cut_proxy(knn_graph),
129            CoherenceMetric::Disagreement => Self::disagreement(keys, knn_graph),
130            CoherenceMetric::SimilarityVariance => Self::similarity_variance(keys, knn_graph),
131        }
132    }
133
134    /// Boundary mass: fraction of edges going to "far" neighbors
135    /// High coherence = most edges go to nearby neighbors
136    fn boundary_mass(knn_graph: &[Vec<usize>]) -> f32 {
137        if knn_graph.is_empty() {
138            return 1.0;
139        }
140
141        let n = knn_graph.len();
142        let mut internal_edges = 0;
143        let mut total_edges = 0;
144
145        for (i, neighbors) in knn_graph.iter().enumerate() {
146            for &j in neighbors {
147                total_edges += 1;
148                // "Internal" if neighbor is within n/4 positions
149                if (i as i32 - j as i32).unsigned_abs() as usize <= n / 4 {
150                    internal_edges += 1;
151                }
152            }
153        }
154
155        if total_edges == 0 {
156            return 1.0;
157        }
158
159        internal_edges as f32 / total_edges as f32
160    }
161
162    /// Cut proxy: estimate of graph cut cost
163    /// High coherence = low cut (well-connected)
164    fn cut_proxy(knn_graph: &[Vec<usize>]) -> f32 {
165        if knn_graph.is_empty() {
166            return 1.0;
167        }
168
169        let n = knn_graph.len();
170        let half = n / 2;
171
172        // Count edges crossing the midpoint
173        let mut crossing = 0;
174        let mut total = 0;
175
176        for (i, neighbors) in knn_graph.iter().enumerate() {
177            for &j in neighbors {
178                total += 1;
179                if (i < half) != (j < half) {
180                    crossing += 1;
181                }
182            }
183        }
184
185        if total == 0 {
186            return 1.0;
187        }
188
189        // Invert: high coherence = few crossings
190        1.0 - (crossing as f32 / total as f32)
191    }
192
193    /// Disagreement: variance in neighbor similarities
194    /// High coherence = neighbors have similar similarities
195    fn disagreement(keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
196        if knn_graph.is_empty() || keys.is_empty() {
197            return 1.0;
198        }
199
200        let mut total_variance = 0.0f32;
201        let mut count = 0;
202
203        for (i, neighbors) in knn_graph.iter().enumerate() {
204            if neighbors.is_empty() {
205                continue;
206            }
207
208            // Similarities to neighbors
209            let sims: Vec<f32> = neighbors
210                .iter()
211                .map(|&j| Self::cosine_similarity(keys[i], keys[j]))
212                .collect();
213
214            let mean: f32 = sims.iter().sum::<f32>() / sims.len() as f32;
215            let variance: f32 =
216                sims.iter().map(|s| (s - mean) * (s - mean)).sum::<f32>() / sims.len() as f32;
217
218            total_variance += variance;
219            count += 1;
220        }
221
222        if count == 0 {
223            return 1.0;
224        }
225
226        // Low variance = high coherence
227        let avg_variance = total_variance / count as f32;
228        1.0 - avg_variance.min(1.0)
229    }
230
231    /// Similarity variance across window
232    fn similarity_variance(keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
233        if knn_graph.is_empty() || keys.is_empty() {
234            return 1.0;
235        }
236
237        // Collect all neighbor similarities
238        let mut all_sims = Vec::new();
239        for (i, neighbors) in knn_graph.iter().enumerate() {
240            for &j in neighbors {
241                all_sims.push(Self::cosine_similarity(keys[i], keys[j]));
242            }
243        }
244
245        if all_sims.is_empty() {
246            return 1.0;
247        }
248
249        let mean: f32 = all_sims.iter().sum::<f32>() / all_sims.len() as f32;
250        let variance: f32 = all_sims
251            .iter()
252            .map(|s| (s - mean) * (s - mean))
253            .sum::<f32>()
254            / all_sims.len() as f32;
255
256        // Low variance + high mean = high coherence
257        let coherence = mean * (1.0 - variance.sqrt().min(1.0));
258        coherence.max(0.0).min(1.0)
259    }
260
261    /// Cosine similarity
262    #[inline]
263    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
264        let dot: f32 = a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum();
265        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
266        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
267
268        if norm_a < 1e-8 || norm_b < 1e-8 {
269            return 0.0;
270        }
271
272        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow each owned key as a slice, matching the input shape of `compute`.
    fn as_slices(owned: &[Vec<f32>]) -> Vec<&[f32]> {
        owned.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_coherence_computation() {
        // Twenty keys along a ramp of constant vectors.
        let owned: Vec<Vec<f32>> = (0..20).map(|step| vec![step as f32 * 0.1; 32]).collect();
        let requested = [
            CoherenceMetric::BoundaryMass,
            CoherenceMetric::SimilarityVariance,
        ];

        let window = WindowCoherence::compute(&as_slices(&owned), 5, &requested);

        assert!((0.0..=1.0).contains(&window.score));
        assert_eq!(window.window_size, 20);
    }

    #[test]
    fn test_coherent_window() {
        // Ten identical keys: neighbour similarities never disagree.
        let owned = vec![vec![0.5f32; 16]; 10];

        let window =
            WindowCoherence::compute(&as_slices(&owned), 3, &[CoherenceMetric::Disagreement]);

        assert!(window.score > 0.8);
    }

    #[test]
    fn test_stale_tracking() {
        let owned = vec![vec![1.0f32; 8]; 5];
        let mut window =
            WindowCoherence::compute(&as_slices(&owned), 2, &[CoherenceMetric::BoundaryMass]);

        // Freshly computed: no update due yet.
        assert!(!window.needs_update(4));

        for _ in 0..4 {
            window.tick();
        }

        // Four ticks reach the update period.
        assert!(window.needs_update(4));
    }
}
323        coherence.tick();
324
325        assert!(coherence.needs_update(4));
326    }
327}