ruvector_core/advanced/
tda.rs

1//! # Topological Data Analysis (TDA)
2//!
3//! Basic topological analysis for embedding quality assessment.
4//! Detects mode collapse, degeneracy, and topological structure.
5
6use crate::error::{Result, RuvectorError};
7use ndarray::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// Topological analyzer for embeddings.
///
/// Holds the two parameters that shape the neighbor graph every graph-based
/// metric is computed on: how many nearest neighbors each vector may link to,
/// and how far away such a neighbor may be.
#[derive(Debug, Clone, PartialEq)]
pub struct TopologicalAnalyzer {
    /// Maximum number of nearest neighbors considered per vector (the `k` of
    /// the k-nearest-neighbors graph).
    k_neighbors: usize,
    /// Distance threshold for edge creation: a nearest neighbor is linked only
    /// when its Euclidean distance is at most this value.
    epsilon: f32,
}
18
19impl TopologicalAnalyzer {
20    /// Create a new topological analyzer
21    pub fn new(k_neighbors: usize, epsilon: f32) -> Self {
22        Self {
23            k_neighbors,
24            epsilon,
25        }
26    }
27
28    /// Analyze embedding quality
29    pub fn analyze(&self, embeddings: &[Vec<f32>]) -> Result<EmbeddingQuality> {
30        if embeddings.is_empty() {
31            return Err(RuvectorError::InvalidInput("Empty embeddings".into()));
32        }
33
34        let n = embeddings.len();
35        let dim = embeddings[0].len();
36
37        // Build k-NN graph
38        let graph = self.build_knn_graph(embeddings);
39
40        // Compute topological features
41        let connected_components = self.count_connected_components(&graph, n);
42        let clustering_coefficient = self.compute_clustering_coefficient(&graph);
43        let degree_stats = self.compute_degree_statistics(&graph, n);
44
45        // Detect mode collapse
46        let mode_collapse_score = self.detect_mode_collapse(embeddings);
47
48        // Compute embedding spread
49        let spread = self.compute_spread(embeddings);
50
51        // Detect degeneracy (vectors collapsing to a lower-dimensional manifold)
52        let degeneracy_score = self.detect_degeneracy(embeddings);
53
54        // Compute persistence features (simplified)
55        let persistence_score = self.compute_persistence(&graph, embeddings);
56
57        // Overall quality score (0-1, higher is better)
58        let quality_score = self.compute_quality_score(
59            mode_collapse_score,
60            degeneracy_score,
61            connected_components,
62            clustering_coefficient,
63            spread,
64        );
65
66        Ok(EmbeddingQuality {
67            dimensions: dim,
68            num_vectors: n,
69            connected_components,
70            clustering_coefficient,
71            avg_degree: degree_stats.0,
72            degree_std: degree_stats.1,
73            mode_collapse_score,
74            degeneracy_score,
75            spread,
76            persistence_score,
77            quality_score,
78        })
79    }
80
81    fn build_knn_graph(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<usize>> {
82        let n = embeddings.len();
83        let mut graph = vec![Vec::new(); n];
84
85        for i in 0..n {
86            let mut distances: Vec<(usize, f32)> = (0..n)
87                .filter(|&j| i != j)
88                .map(|j| {
89                    let dist = euclidean_distance(&embeddings[i], &embeddings[j]);
90                    (j, dist)
91                })
92                .collect();
93
94            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
95
96            // Add k nearest neighbors
97            for (j, dist) in distances.iter().take(self.k_neighbors) {
98                if *dist <= self.epsilon {
99                    graph[i].push(*j);
100                }
101            }
102        }
103
104        graph
105    }
106
107    fn count_connected_components(&self, graph: &[Vec<usize>], n: usize) -> usize {
108        let mut visited = vec![false; n];
109        let mut components = 0;
110
111        for i in 0..n {
112            if !visited[i] {
113                components += 1;
114                self.dfs(i, graph, &mut visited);
115            }
116        }
117
118        components
119    }
120
121    fn dfs(&self, node: usize, graph: &[Vec<usize>], visited: &mut [bool]) {
122        visited[node] = true;
123        for &neighbor in &graph[node] {
124            if !visited[neighbor] {
125                self.dfs(neighbor, graph, visited);
126            }
127        }
128    }
129
130    fn compute_clustering_coefficient(&self, graph: &[Vec<usize>]) -> f32 {
131        let mut total_coeff = 0.0;
132        let mut count = 0;
133
134        for neighbors in graph {
135            if neighbors.len() < 2 {
136                continue;
137            }
138
139            let k = neighbors.len();
140            let mut triangles = 0;
141
142            // Count triangles
143            for i in 0..k {
144                for j in i + 1..k {
145                    let ni = neighbors[i];
146                    let nj = neighbors[j];
147
148                    if graph[ni].contains(&nj) {
149                        triangles += 1;
150                    }
151                }
152            }
153
154            let possible_triangles = k * (k - 1) / 2;
155            if possible_triangles > 0 {
156                total_coeff += triangles as f32 / possible_triangles as f32;
157                count += 1;
158            }
159        }
160
161        if count > 0 {
162            total_coeff / count as f32
163        } else {
164            0.0
165        }
166    }
167
168    fn compute_degree_statistics(&self, graph: &[Vec<usize>], n: usize) -> (f32, f32) {
169        let degrees: Vec<f32> = graph
170            .iter()
171            .map(|neighbors| neighbors.len() as f32)
172            .collect();
173
174        let avg = degrees.iter().sum::<f32>() / n as f32;
175        let variance = degrees.iter().map(|&d| (d - avg).powi(2)).sum::<f32>() / n as f32;
176        let std = variance.sqrt();
177
178        (avg, std)
179    }
180
181    fn detect_mode_collapse(&self, embeddings: &[Vec<f32>]) -> f32 {
182        // Compute pairwise distances
183        let n = embeddings.len();
184        let mut distances = Vec::new();
185
186        for i in 0..n {
187            for j in i + 1..n {
188                let dist = euclidean_distance(&embeddings[i], &embeddings[j]);
189                distances.push(dist);
190            }
191        }
192
193        if distances.is_empty() {
194            return 0.0;
195        }
196
197        // Compute coefficient of variation
198        let mean = distances.iter().sum::<f32>() / distances.len() as f32;
199        let variance =
200            distances.iter().map(|&d| (d - mean).powi(2)).sum::<f32>() / distances.len() as f32;
201        let std = variance.sqrt();
202
203        // High CV indicates good separation, low CV indicates collapse
204        let cv = if mean > 0.0 { std / mean } else { 0.0 };
205
206        // Normalize to 0-1, where 0 is collapsed, 1 is good
207        (cv * 2.0).min(1.0)
208    }
209
210    fn compute_spread(&self, embeddings: &[Vec<f32>]) -> f32 {
211        if embeddings.is_empty() {
212            return 0.0;
213        }
214
215        let dim = embeddings[0].len();
216
217        // Compute mean
218        let mut mean = vec![0.0; dim];
219        for emb in embeddings {
220            for (i, &val) in emb.iter().enumerate() {
221                mean[i] += val;
222            }
223        }
224        for val in mean.iter_mut() {
225            *val /= embeddings.len() as f32;
226        }
227
228        // Compute average distance from mean
229        let mut total_dist = 0.0;
230        for emb in embeddings {
231            let dist = euclidean_distance(emb, &mean);
232            total_dist += dist;
233        }
234
235        total_dist / embeddings.len() as f32
236    }
237
238    fn detect_degeneracy(&self, embeddings: &[Vec<f32>]) -> f32 {
239        if embeddings.is_empty() || embeddings[0].is_empty() {
240            return 1.0; // Fully degenerate
241        }
242
243        let n = embeddings.len();
244        let dim = embeddings[0].len();
245
246        if n < dim {
247            return 0.0; // Cannot determine
248        }
249
250        // Compute covariance matrix
251        let cov = self.compute_covariance_matrix(embeddings);
252
253        // Estimate rank by counting significant singular values
254        let singular_values = self.approximate_singular_values(&cov);
255
256        let significant = singular_values.iter().filter(|&&sv| sv > 1e-6).count();
257
258        // Degeneracy score: 0 = full rank, 1 = rank-1 (collapsed)
259        1.0 - (significant as f32 / dim as f32)
260    }
261
262    fn compute_covariance_matrix(&self, embeddings: &[Vec<f32>]) -> Array2<f32> {
263        let n = embeddings.len();
264        let dim = embeddings[0].len();
265
266        // Compute mean
267        let mut mean = vec![0.0; dim];
268        for emb in embeddings {
269            for (i, &val) in emb.iter().enumerate() {
270                mean[i] += val;
271            }
272        }
273        for val in mean.iter_mut() {
274            *val /= n as f32;
275        }
276
277        // Compute covariance
278        let mut cov = Array2::zeros((dim, dim));
279        for emb in embeddings {
280            for i in 0..dim {
281                for j in 0..dim {
282                    cov[[i, j]] += (emb[i] - mean[i]) * (emb[j] - mean[j]);
283                }
284            }
285        }
286
287        cov.mapv(|x| x / (n - 1) as f32);
288        cov
289    }
290
291    fn approximate_singular_values(&self, matrix: &Array2<f32>) -> Vec<f32> {
292        // Power iteration for largest singular values (simplified)
293        let dim = matrix.nrows();
294        let mut values = Vec::new();
295
296        // Just return diagonal for approximation
297        for i in 0..dim {
298            values.push(matrix[[i, i]].abs());
299        }
300
301        values.sort_by(|a, b| b.partial_cmp(a).unwrap());
302        values
303    }
304
305    fn compute_persistence(&self, _graph: &[Vec<usize>], embeddings: &[Vec<f32>]) -> f32 {
306        // Simplified persistence: measure how graph structure changes with distance threshold
307        let scales = vec![0.1, 0.5, 1.0, 2.0, 5.0];
308        let mut component_counts = Vec::new();
309
310        for &scale in &scales {
311            let scaled_analyzer = TopologicalAnalyzer::new(self.k_neighbors, scale);
312            let scaled_graph = scaled_analyzer.build_knn_graph(embeddings);
313            let components =
314                scaled_analyzer.count_connected_components(&scaled_graph, embeddings.len());
315            component_counts.push(components);
316        }
317
318        // Persistence is the variation in component count across scales
319        let max_components = *component_counts.iter().max().unwrap_or(&1);
320        let min_components = *component_counts.iter().min().unwrap_or(&1);
321
322        (max_components - min_components) as f32 / max_components as f32
323    }
324
325    fn compute_quality_score(
326        &self,
327        mode_collapse: f32,
328        degeneracy: f32,
329        components: usize,
330        clustering: f32,
331        spread: f32,
332    ) -> f32 {
333        // Weighted combination of metrics
334        let collapse_score = mode_collapse; // Higher is better
335        let degeneracy_score = 1.0 - degeneracy; // Lower degeneracy is better
336        let component_score = if components == 1 { 1.0 } else { 0.5 }; // Single component is good
337        let clustering_score = clustering; // Higher clustering is good
338        let spread_score = (spread / 10.0).min(1.0); // Reasonable spread
339
340        (collapse_score * 0.3
341            + degeneracy_score * 0.3
342            + component_score * 0.2
343            + clustering_score * 0.1
344            + spread_score * 0.1)
345            .clamp(0.0, 1.0)
346    }
347}
348
349/// Embedding quality metrics
350#[derive(Debug, Clone, Serialize, Deserialize)]
351pub struct EmbeddingQuality {
352    /// Embedding dimensions
353    pub dimensions: usize,
354    /// Number of vectors
355    pub num_vectors: usize,
356    /// Number of connected components
357    pub connected_components: usize,
358    /// Clustering coefficient (0-1)
359    pub clustering_coefficient: f32,
360    /// Average node degree
361    pub avg_degree: f32,
362    /// Degree standard deviation
363    pub degree_std: f32,
364    /// Mode collapse score (0=collapsed, 1=good)
365    pub mode_collapse_score: f32,
366    /// Degeneracy score (0=full rank, 1=degenerate)
367    pub degeneracy_score: f32,
368    /// Average spread from centroid
369    pub spread: f32,
370    /// Topological persistence score
371    pub persistence_score: f32,
372    /// Overall quality (0-1, higher is better)
373    pub quality_score: f32,
374}
375
376impl EmbeddingQuality {
377    /// Check if embeddings show signs of mode collapse
378    pub fn has_mode_collapse(&self) -> bool {
379        self.mode_collapse_score < 0.3
380    }
381
382    /// Check if embeddings are degenerate
383    pub fn is_degenerate(&self) -> bool {
384        self.degeneracy_score > 0.7
385    }
386
387    /// Check if embeddings are well-structured
388    pub fn is_good_quality(&self) -> bool {
389        self.quality_score > 0.7
390    }
391
392    /// Get quality assessment
393    pub fn assessment(&self) -> &str {
394        if self.quality_score > 0.8 {
395            "Excellent"
396        } else if self.quality_score > 0.6 {
397            "Good"
398        } else if self.quality_score > 0.4 {
399            "Fair"
400        } else {
401            "Poor"
402        }
403    }
404}
405
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; when the slices differ in length the
/// surplus components of the longer one are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_gaps = a.iter().zip(b).map(|(lhs, rhs)| {
        let gap = lhs - rhs;
        gap.powi(2)
    });

    let sum_of_squares: f32 = squared_gaps.sum();
    sum_of_squares.sqrt()
}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: the report echoes the input shape and yields a
    /// positive quality score.
    #[test]
    fn test_embedding_analysis() {
        // Two tight groups that sit far apart from each other
        let points: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.2],
            vec![5.0, 5.0],
            vec![5.1, 5.1],
        ];

        let report = TopologicalAnalyzer::new(3, 5.0)
            .analyze(&points)
            .unwrap();

        assert_eq!(report.dimensions, 2);
        assert_eq!(report.num_vectors, 5);
        assert!(report.quality_score > 0.0);
    }

    /// Identical vectors must score exactly zero; spread-out vectors must score higher.
    #[test]
    fn test_mode_collapse_detection() {
        let analyzer = TopologicalAnalyzer::new(2, 10.0);

        // All vectors identical: every pairwise distance is 0, so the CV is 0
        let identical = vec![vec![1.0, 1.0]; 3];
        let collapsed_score = analyzer.detect_mode_collapse(&identical);

        // Vectors with differing pairwise distances
        let separated = vec![vec![0.0, 0.0], vec![5.0, 5.0], vec![10.0, 10.0]];
        let separated_score = analyzer.detect_mode_collapse(&separated);

        assert_eq!(collapsed_score, 0.0);
        assert!(separated_score > collapsed_score);
    }

    /// Two distant pairs of points must not merge into a single component.
    #[test]
    fn test_connected_components() {
        let analyzer = TopologicalAnalyzer::new(1, 1.0);

        // Two separate clusters of two points each
        let points = vec![
            vec![0.0, 0.0],
            vec![0.5, 0.5],
            vec![10.0, 10.0],
            vec![10.5, 10.5],
        ];

        let knn = analyzer.build_knn_graph(&points);
        let component_count = analyzer.count_connected_components(&knn, points.len());

        // Should have at least 2 components
        assert!(component_count >= 2);
    }

    /// The threshold helpers and the label agree with a hand-built report.
    #[test]
    fn test_quality_assessment() {
        let report = EmbeddingQuality {
            dimensions: 128,
            num_vectors: 1000,
            connected_components: 1,
            clustering_coefficient: 0.6,
            avg_degree: 5.0,
            degree_std: 1.2,
            mode_collapse_score: 0.8,
            degeneracy_score: 0.2,
            spread: 3.5,
            persistence_score: 0.4,
            quality_score: 0.75,
        };

        assert_eq!(report.assessment(), "Good");
        assert!(report.is_good_quality());
        assert!(!report.is_degenerate());
        assert!(!report.has_mode_collapse());
    }
}