// ruvector_attention/topology/coherence.rs
use serde::{Deserialize, Serialize};

8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum CoherenceMetric {
11 BoundaryMass,
13 CutProxy,
15 Disagreement,
17 SimilarityVariance,
19}
20
21#[derive(Debug, Clone)]
23pub struct WindowCoherence {
24 pub score: f32,
26 pub metric_scores: Vec<f32>,
28 pub metrics: Vec<CoherenceMetric>,
30 pub window_size: usize,
32 pub is_stale: bool,
34 pub tokens_since_update: usize,
36}
37
38impl WindowCoherence {
39 pub fn compute(
41 keys: &[&[f32]],
42 k_neighbors: usize,
43 metrics: &[CoherenceMetric],
44 ) -> Self {
45 let n = keys.len();
46 if n < 2 {
47 return Self {
48 score: 1.0,
49 metric_scores: vec![1.0],
50 metrics: metrics.to_vec(),
51 window_size: n,
52 is_stale: false,
53 tokens_since_update: 0,
54 };
55 }
56
57 let knn_graph = Self::build_knn_graph(keys, k_neighbors);
59
60 let metric_scores: Vec<f32> = metrics
62 .iter()
63 .map(|m| Self::compute_metric(*m, keys, &knn_graph))
64 .collect();
65
66 let score = metric_scores.iter().sum::<f32>() / metric_scores.len() as f32;
68
69 Self {
70 score,
71 metric_scores,
72 metrics: metrics.to_vec(),
73 window_size: n,
74 is_stale: false,
75 tokens_since_update: 0,
76 }
77 }
78
79 pub fn mark_stale(&mut self) {
81 self.is_stale = true;
82 }
83
84 pub fn tick(&mut self) {
86 self.tokens_since_update += 1;
87 }
88
89 pub fn needs_update(&self, update_period: usize) -> bool {
91 self.is_stale || self.tokens_since_update >= update_period
92 }
93
94 fn build_knn_graph(keys: &[&[f32]], k: usize) -> Vec<Vec<usize>> {
97 let n = keys.len();
98 let k = k.min(n - 1);
99
100 keys.iter()
101 .enumerate()
102 .map(|(i, key)| {
103 let mut distances: Vec<(usize, f32)> = keys
104 .iter()
105 .enumerate()
106 .filter(|(j, _)| *j != i)
107 .map(|(j, k2)| (j, Self::squared_distance(key, k2)))
108 .collect();
109
110 distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
111
112 distances.iter().take(k).map(|(j, _)| *j).collect()
113 })
114 .collect()
115 }
116
117 #[inline]
119 fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
120 a.iter()
121 .zip(b.iter())
122 .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
123 .sum()
124 }
125
126 fn compute_metric(
128 metric: CoherenceMetric,
129 keys: &[&[f32]],
130 knn_graph: &[Vec<usize>],
131 ) -> f32 {
132 match metric {
133 CoherenceMetric::BoundaryMass => Self::boundary_mass(knn_graph),
134 CoherenceMetric::CutProxy => Self::cut_proxy(knn_graph),
135 CoherenceMetric::Disagreement => Self::disagreement(keys, knn_graph),
136 CoherenceMetric::SimilarityVariance => Self::similarity_variance(keys, knn_graph),
137 }
138 }
139
140 fn boundary_mass(knn_graph: &[Vec<usize>]) -> f32 {
143 if knn_graph.is_empty() {
144 return 1.0;
145 }
146
147 let n = knn_graph.len();
148 let mut internal_edges = 0;
149 let mut total_edges = 0;
150
151 for (i, neighbors) in knn_graph.iter().enumerate() {
152 for &j in neighbors {
153 total_edges += 1;
154 if (i as i32 - j as i32).unsigned_abs() as usize <= n / 4 {
156 internal_edges += 1;
157 }
158 }
159 }
160
161 if total_edges == 0 {
162 return 1.0;
163 }
164
165 internal_edges as f32 / total_edges as f32
166 }
167
168 fn cut_proxy(knn_graph: &[Vec<usize>]) -> f32 {
171 if knn_graph.is_empty() {
172 return 1.0;
173 }
174
175 let n = knn_graph.len();
176 let half = n / 2;
177
178 let mut crossing = 0;
180 let mut total = 0;
181
182 for (i, neighbors) in knn_graph.iter().enumerate() {
183 for &j in neighbors {
184 total += 1;
185 if (i < half) != (j < half) {
186 crossing += 1;
187 }
188 }
189 }
190
191 if total == 0 {
192 return 1.0;
193 }
194
195 1.0 - (crossing as f32 / total as f32)
197 }
198
199 fn disagreement(keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
202 if knn_graph.is_empty() || keys.is_empty() {
203 return 1.0;
204 }
205
206 let mut total_variance = 0.0f32;
207 let mut count = 0;
208
209 for (i, neighbors) in knn_graph.iter().enumerate() {
210 if neighbors.is_empty() {
211 continue;
212 }
213
214 let sims: Vec<f32> = neighbors
216 .iter()
217 .map(|&j| Self::cosine_similarity(keys[i], keys[j]))
218 .collect();
219
220 let mean: f32 = sims.iter().sum::<f32>() / sims.len() as f32;
221 let variance: f32 = sims.iter().map(|s| (s - mean) * (s - mean)).sum::<f32>() / sims.len() as f32;
222
223 total_variance += variance;
224 count += 1;
225 }
226
227 if count == 0 {
228 return 1.0;
229 }
230
231 let avg_variance = total_variance / count as f32;
233 1.0 - avg_variance.min(1.0)
234 }
235
236 fn similarity_variance(keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
238 if knn_graph.is_empty() || keys.is_empty() {
239 return 1.0;
240 }
241
242 let mut all_sims = Vec::new();
244 for (i, neighbors) in knn_graph.iter().enumerate() {
245 for &j in neighbors {
246 all_sims.push(Self::cosine_similarity(keys[i], keys[j]));
247 }
248 }
249
250 if all_sims.is_empty() {
251 return 1.0;
252 }
253
254 let mean: f32 = all_sims.iter().sum::<f32>() / all_sims.len() as f32;
255 let variance: f32 = all_sims.iter().map(|s| (s - mean) * (s - mean)).sum::<f32>() / all_sims.len() as f32;
256
257 let coherence = mean * (1.0 - variance.sqrt().min(1.0));
259 coherence.max(0.0).min(1.0)
260 }
261
262 #[inline]
264 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
265 let dot: f32 = a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum();
266 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
267 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
268
269 if norm_a < 1e-8 || norm_b < 1e-8 {
270 return 0.0;
271 }
272
273 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows owned key vectors in the `&[&[f32]]` shape `compute` expects.
    fn as_slices(keys: &[Vec<f32>]) -> Vec<&[f32]> {
        keys.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_coherence_computation() {
        // Keys drift along the diagonal: every component of key i equals i * 0.1.
        let owned: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.1; 32]).collect();
        let wanted = [CoherenceMetric::BoundaryMass, CoherenceMetric::SimilarityVariance];

        let result = WindowCoherence::compute(&as_slices(&owned), 5, &wanted);

        assert!((0.0..=1.0).contains(&result.score));
        assert_eq!(result.window_size, 20);
    }

    #[test]
    fn test_coherent_window() {
        // Ten identical keys: all neighbour similarities agree.
        let owned = vec![vec![0.5f32; 16]; 10];

        let result =
            WindowCoherence::compute(&as_slices(&owned), 3, &[CoherenceMetric::Disagreement]);

        assert!(result.score > 0.8);
    }

    #[test]
    fn test_stale_tracking() {
        let owned = vec![vec![1.0f32; 8]; 5];
        let mut result =
            WindowCoherence::compute(&as_slices(&owned), 2, &[CoherenceMetric::BoundaryMass]);

        assert!(!result.needs_update(4));

        for _ in 0..4 {
            result.tick();
        }

        assert!(result.needs_update(4));
    }
}