// ruvector_attention/topology/coherence.rs
use serde::{Deserialize, Serialize};

8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum CoherenceMetric {
11 BoundaryMass,
13 CutProxy,
15 Disagreement,
17 SimilarityVariance,
19}
20
21#[derive(Debug, Clone)]
23pub struct WindowCoherence {
24 pub score: f32,
26 pub metric_scores: Vec<f32>,
28 pub metrics: Vec<CoherenceMetric>,
30 pub window_size: usize,
32 pub is_stale: bool,
34 pub tokens_since_update: usize,
36}
37
38impl WindowCoherence {
39 pub fn compute(keys: &[&[f32]], k_neighbors: usize, metrics: &[CoherenceMetric]) -> Self {
41 let n = keys.len();
42 if n < 2 {
43 return Self {
44 score: 1.0,
45 metric_scores: vec![1.0],
46 metrics: metrics.to_vec(),
47 window_size: n,
48 is_stale: false,
49 tokens_since_update: 0,
50 };
51 }
52
53 let knn_graph = Self::build_knn_graph(keys, k_neighbors);
55
56 let metric_scores: Vec<f32> = metrics
58 .iter()
59 .map(|m| Self::compute_metric(*m, keys, &knn_graph))
60 .collect();
61
62 let score = metric_scores.iter().sum::<f32>() / metric_scores.len() as f32;
64
65 Self {
66 score,
67 metric_scores,
68 metrics: metrics.to_vec(),
69 window_size: n,
70 is_stale: false,
71 tokens_since_update: 0,
72 }
73 }
74
75 pub fn mark_stale(&mut self) {
77 self.is_stale = true;
78 }
79
80 pub fn tick(&mut self) {
82 self.tokens_since_update += 1;
83 }
84
85 pub fn needs_update(&self, update_period: usize) -> bool {
87 self.is_stale || self.tokens_since_update >= update_period
88 }
89
90 fn build_knn_graph(keys: &[&[f32]], k: usize) -> Vec<Vec<usize>> {
93 let n = keys.len();
94 let k = k.min(n - 1);
95
96 keys.iter()
97 .enumerate()
98 .map(|(i, key)| {
99 let mut distances: Vec<(usize, f32)> = keys
100 .iter()
101 .enumerate()
102 .filter(|(j, _)| *j != i)
103 .map(|(j, k2)| (j, Self::squared_distance(key, k2)))
104 .collect();
105
106 distances.sort_unstable_by(|a, b| {
107 a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
108 });
109
110 distances.iter().take(k).map(|(j, _)| *j).collect()
111 })
112 .collect()
113 }
114
115 #[inline]
117 fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
118 a.iter()
119 .zip(b.iter())
120 .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
121 .sum()
122 }
123
124 fn compute_metric(metric: CoherenceMetric, keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
126 match metric {
127 CoherenceMetric::BoundaryMass => Self::boundary_mass(knn_graph),
128 CoherenceMetric::CutProxy => Self::cut_proxy(knn_graph),
129 CoherenceMetric::Disagreement => Self::disagreement(keys, knn_graph),
130 CoherenceMetric::SimilarityVariance => Self::similarity_variance(keys, knn_graph),
131 }
132 }
133
134 fn boundary_mass(knn_graph: &[Vec<usize>]) -> f32 {
137 if knn_graph.is_empty() {
138 return 1.0;
139 }
140
141 let n = knn_graph.len();
142 let mut internal_edges = 0;
143 let mut total_edges = 0;
144
145 for (i, neighbors) in knn_graph.iter().enumerate() {
146 for &j in neighbors {
147 total_edges += 1;
148 if (i as i32 - j as i32).unsigned_abs() as usize <= n / 4 {
150 internal_edges += 1;
151 }
152 }
153 }
154
155 if total_edges == 0 {
156 return 1.0;
157 }
158
159 internal_edges as f32 / total_edges as f32
160 }
161
162 fn cut_proxy(knn_graph: &[Vec<usize>]) -> f32 {
165 if knn_graph.is_empty() {
166 return 1.0;
167 }
168
169 let n = knn_graph.len();
170 let half = n / 2;
171
172 let mut crossing = 0;
174 let mut total = 0;
175
176 for (i, neighbors) in knn_graph.iter().enumerate() {
177 for &j in neighbors {
178 total += 1;
179 if (i < half) != (j < half) {
180 crossing += 1;
181 }
182 }
183 }
184
185 if total == 0 {
186 return 1.0;
187 }
188
189 1.0 - (crossing as f32 / total as f32)
191 }
192
193 fn disagreement(keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
196 if knn_graph.is_empty() || keys.is_empty() {
197 return 1.0;
198 }
199
200 let mut total_variance = 0.0f32;
201 let mut count = 0;
202
203 for (i, neighbors) in knn_graph.iter().enumerate() {
204 if neighbors.is_empty() {
205 continue;
206 }
207
208 let sims: Vec<f32> = neighbors
210 .iter()
211 .map(|&j| Self::cosine_similarity(keys[i], keys[j]))
212 .collect();
213
214 let mean: f32 = sims.iter().sum::<f32>() / sims.len() as f32;
215 let variance: f32 =
216 sims.iter().map(|s| (s - mean) * (s - mean)).sum::<f32>() / sims.len() as f32;
217
218 total_variance += variance;
219 count += 1;
220 }
221
222 if count == 0 {
223 return 1.0;
224 }
225
226 let avg_variance = total_variance / count as f32;
228 1.0 - avg_variance.min(1.0)
229 }
230
231 fn similarity_variance(keys: &[&[f32]], knn_graph: &[Vec<usize>]) -> f32 {
233 if knn_graph.is_empty() || keys.is_empty() {
234 return 1.0;
235 }
236
237 let mut all_sims = Vec::new();
239 for (i, neighbors) in knn_graph.iter().enumerate() {
240 for &j in neighbors {
241 all_sims.push(Self::cosine_similarity(keys[i], keys[j]));
242 }
243 }
244
245 if all_sims.is_empty() {
246 return 1.0;
247 }
248
249 let mean: f32 = all_sims.iter().sum::<f32>() / all_sims.len() as f32;
250 let variance: f32 = all_sims
251 .iter()
252 .map(|s| (s - mean) * (s - mean))
253 .sum::<f32>()
254 / all_sims.len() as f32;
255
256 let coherence = mean * (1.0 - variance.sqrt().min(1.0));
258 coherence.max(0.0).min(1.0)
259 }
260
261 #[inline]
263 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
264 let dot: f32 = a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum();
265 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
266 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
267
268 if norm_a < 1e-8 || norm_b < 1e-8 {
269 return 0.0;
270 }
271
272 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows every owned key as a slice, the form `WindowCoherence::compute` expects.
    fn as_slices(keys: &[Vec<f32>]) -> Vec<&[f32]> {
        keys.iter().map(Vec::as_slice).collect()
    }

    /// Scores over a ramp of 20 keys stay within `[0, 1]` and record the window size.
    #[test]
    fn test_coherence_computation() {
        let keys: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.1; 32]).collect();
        let keys_refs = as_slices(&keys);

        let coherence = WindowCoherence::compute(
            &keys_refs,
            5,
            &[
                CoherenceMetric::BoundaryMass,
                CoherenceMetric::SimilarityVariance,
            ],
        );

        assert!((0.0..=1.0).contains(&coherence.score));
        assert_eq!(coherence.window_size, 20);
        // One score per requested metric.
        assert_eq!(coherence.metric_scores.len(), 2);
    }

    /// Identical keys have no neighbour disagreement, so the score is high.
    #[test]
    fn test_coherent_window() {
        let keys: Vec<Vec<f32>> = vec![vec![0.5f32; 16]; 10];
        let keys_refs = as_slices(&keys);

        let coherence = WindowCoherence::compute(&keys_refs, 3, &[CoherenceMetric::Disagreement]);

        assert!(coherence.score > 0.8);
    }

    /// A fresh measurement needs no update until enough ticks have elapsed.
    #[test]
    fn test_stale_tracking() {
        let keys: Vec<Vec<f32>> = vec![vec![1.0; 8]; 5];
        let keys_refs = as_slices(&keys);

        let mut coherence =
            WindowCoherence::compute(&keys_refs, 2, &[CoherenceMetric::BoundaryMass]);

        assert!(!coherence.needs_update(4));

        for _ in 0..4 {
            coherence.tick();
        }

        assert!(coherence.needs_update(4));
    }

    /// Marking a measurement stale forces an update regardless of the period.
    #[test]
    fn test_mark_stale_forces_update() {
        let keys: Vec<Vec<f32>> = vec![vec![1.0; 8]; 5];
        let keys_refs = as_slices(&keys);

        let mut coherence =
            WindowCoherence::compute(&keys_refs, 2, &[CoherenceMetric::BoundaryMass]);
        assert!(!coherence.needs_update(usize::MAX));

        coherence.mark_stale();

        assert!(coherence.needs_update(usize::MAX));
    }
}