sklears_clustering/
sparse_matrix.rs

1//! Sparse matrix representations for large-scale clustering
2//!
3//! This module provides sparse matrix data structures and algorithms optimized
4//! for clustering applications where many distances are zero or above a threshold.
5//! This is particularly useful for high-dimensional data or when using distance
6//! thresholds to create neighborhood graphs.
7
8use std::collections::{HashMap, VecDeque};
9
10use sklears_core::{
11    error::{Result, SklearsError},
12    types::{Array2, Float},
13};
14
15use crate::simd_distances::{simd_distance, SimdDistanceMetric};
16
17/// Sparse matrix entry with row, column, and value
18#[derive(Debug, Clone, Copy, PartialEq)]
19pub struct SparseEntry {
20    pub row: usize,
21    pub col: usize,
22    pub value: Float,
23}
24
25/// Compressed Sparse Row (CSR) matrix for efficient sparse distance storage
26#[derive(Debug, Clone)]
27pub struct SparseDistanceMatrix {
28    /// Values of non-zero entries
29    values: Vec<Float>,
30    /// Column indices for each value
31    col_indices: Vec<usize>,
32    /// Row pointers (cumulative count of entries per row)
33    row_pointers: Vec<usize>,
34    /// Matrix dimensions
35    n_rows: usize,
36    n_cols: usize,
37    /// Number of non-zero entries
38    nnz: usize,
39    /// Whether the matrix is symmetric
40    symmetric: bool,
41}
42
43/// Configuration for sparse matrix creation
44#[derive(Debug, Clone)]
45pub struct SparseMatrixConfig {
46    /// Distance threshold - distances above this are not stored
47    pub distance_threshold: Float,
48    /// Distance metric to use
49    pub metric: SimdDistanceMetric,
50    /// Initial capacity hint for sparse entries
51    pub initial_capacity: Option<usize>,
52    /// Whether to enforce symmetry
53    pub symmetric: bool,
54    /// Sparsity threshold (fraction of entries that must be zero to use sparse)
55    pub sparsity_threshold: Float,
56}
57
58impl Default for SparseMatrixConfig {
59    fn default() -> Self {
60        Self {
61            distance_threshold: Float::INFINITY,
62            metric: SimdDistanceMetric::Euclidean,
63            initial_capacity: None,
64            symmetric: true,
65            sparsity_threshold: 0.5, // 50% sparse to justify sparse storage
66        }
67    }
68}
69
70impl SparseDistanceMatrix {
71    /// Create a new empty sparse distance matrix
72    ///
73    /// # Arguments
74    /// * `n_rows` - Number of rows
75    /// * `n_cols` - Number of columns
76    /// * `symmetric` - Whether the matrix should be symmetric
77    pub fn new(n_rows: usize, n_cols: usize, symmetric: bool) -> Self {
78        let row_pointers = vec![0; n_rows + 1];
79
80        Self {
81            values: Vec::new(),
82            col_indices: Vec::new(),
83            row_pointers,
84            n_rows,
85            n_cols,
86            nnz: 0,
87            symmetric,
88        }
89    }
90
91    /// Create sparse distance matrix from dense data
92    ///
93    /// # Arguments
94    /// * `data` - Input data matrix (n_samples × n_features)
95    /// * `config` - Configuration for sparse matrix creation
96    pub fn from_data(data: &Array2<Float>, config: SparseMatrixConfig) -> Result<Self> {
97        let n_samples = data.nrows();
98
99        // First pass: count non-zero entries to estimate sparsity
100        let entry_count = 0;
101        let total_entries = if config.symmetric {
102            (n_samples * (n_samples - 1)) / 2
103        } else {
104            n_samples * n_samples
105        };
106
107        // Sample a subset to estimate sparsity
108        let sample_size = (n_samples / 10).max(10).min(100);
109        let mut sampled_entries = 0;
110        let mut sampled_nonzero = 0;
111
112        let step_size = (n_samples / sample_size).max(1);
113        for i in (0..n_samples).step_by(step_size) {
114            let j_start = if config.symmetric { i + 1 } else { 0 };
115            for j in (j_start..n_samples).step_by(step_size) {
116                if i != j {
117                    let row_i = data.row(i);
118                    let row_j = data.row(j);
119                    let distance = simd_distance(&row_i, &row_j, config.metric).map_err(|e| {
120                        SklearsError::NumericalError(format!(
121                            "SIMD distance computation failed: {}",
122                            e
123                        ))
124                    })?;
125
126                    sampled_entries += 1;
127                    if distance <= config.distance_threshold && distance > 0.0 {
128                        sampled_nonzero += 1;
129                    }
130                }
131            }
132        }
133
134        // Estimate sparsity
135        let estimated_sparsity = if sampled_entries > 0 {
136            1.0 - (sampled_nonzero as Float / sampled_entries as Float)
137        } else {
138            0.0
139        };
140
141        // Check if sparse representation is beneficial
142        if estimated_sparsity < config.sparsity_threshold {
143            return Err(SklearsError::InvalidInput(format!(
144                "Data is not sparse enough ({:.2}% sparse) for sparse representation. Consider using dense matrix.",
145                estimated_sparsity * 100.0
146            )));
147        }
148
149        // Create sparse matrix
150        let mut sparse_matrix = Self::new(n_samples, n_samples, config.symmetric);
151
152        // Pre-allocate based on estimated non-zero count
153        let estimated_nnz = (total_entries as Float * (1.0 - estimated_sparsity)) as usize;
154        let capacity = config.initial_capacity.unwrap_or(estimated_nnz);
155        sparse_matrix.values.reserve(capacity);
156        sparse_matrix.col_indices.reserve(capacity);
157
158        // Build sparse matrix row by row
159        let mut entries_buffer = Vec::new();
160
161        for i in 0..n_samples {
162            let j_start = if config.symmetric { i + 1 } else { 0 };
163
164            entries_buffer.clear();
165
166            for j in j_start..n_samples {
167                if i != j {
168                    let row_i = data.row(i);
169                    let row_j = data.row(j);
170                    let distance = simd_distance(&row_i, &row_j, config.metric).map_err(|e| {
171                        SklearsError::NumericalError(format!(
172                            "SIMD distance computation failed: {}",
173                            e
174                        ))
175                    })?;
176
177                    if distance <= config.distance_threshold && distance > 0.0 {
178                        entries_buffer.push((j, distance));
179                    }
180                }
181            }
182
183            // Sort entries by column index for CSR format
184            entries_buffer.sort_by_key(|&(col, _)| col);
185
186            // Add entries to sparse matrix
187            for (col, value) in entries_buffer.iter() {
188                sparse_matrix.values.push(*value);
189                sparse_matrix.col_indices.push(*col);
190                sparse_matrix.nnz += 1;
191            }
192
193            sparse_matrix.row_pointers[i + 1] = sparse_matrix.nnz;
194
195            // Log progress for large datasets
196            if (i + 1) % 1000 == 0 || i == n_samples - 1 {
197                let progress = (i + 1) as f64 / n_samples as f64 * 100.0;
198                eprintln!("Building sparse matrix: {:.1}% complete", progress);
199            }
200        }
201
202        eprintln!(
203            "Created sparse matrix: {} non-zero entries out of {} total ({:.2}% sparse)",
204            sparse_matrix.nnz,
205            total_entries,
206            (1.0 - sparse_matrix.nnz as f64 / total_entries as f64) * 100.0
207        );
208
209        Ok(sparse_matrix)
210    }
211
212    /// Get the value at position (row, col)
213    pub fn get(&self, row: usize, col: usize) -> Float {
214        if row >= self.n_rows || col >= self.n_cols {
215            return 0.0;
216        }
217
218        if row == col {
219            return 0.0; // Diagonal is always zero for distance matrices
220        }
221
222        // Handle symmetry
223        let (search_row, search_col) = if self.symmetric && row > col {
224            (col, row)
225        } else {
226            (row, col)
227        };
228
229        // If symmetric and searching in lower triangle, return 0
230        if self.symmetric && search_row > search_col {
231            return 0.0;
232        }
233
234        // Binary search in the row
235        let start = self.row_pointers[search_row];
236        let end = self.row_pointers[search_row + 1];
237
238        match self.col_indices[start..end].binary_search(&search_col) {
239            Ok(idx) => self.values[start + idx],
240            Err(_) => 0.0,
241        }
242    }
243
244    /// Set the value at position (row, col)
245    ///
246    /// Note: This is an expensive operation for CSR matrices as it may require
247    /// rebuilding the entire structure. Use during construction only.
248    pub fn set(&mut self, row: usize, col: usize, value: Float) -> Result<()> {
249        if row >= self.n_rows || col >= self.n_cols {
250            return Err(SklearsError::InvalidInput(format!(
251                "Index ({}, {}) out of bounds for matrix {}×{}",
252                row, col, self.n_rows, self.n_cols
253            )));
254        }
255
256        if row == col && value != 0.0 {
257            return Err(SklearsError::InvalidInput(
258                "Cannot set non-zero diagonal element in distance matrix".to_string(),
259            ));
260        }
261
262        // For now, we'll rebuild the matrix with the new value
263        // This is inefficient but correct
264        self.set_and_rebuild(row, col, value)
265    }
266
267    /// Set value and rebuild matrix structure (expensive operation)
268    fn set_and_rebuild(&mut self, row: usize, col: usize, value: Float) -> Result<()> {
269        // Convert to COO format, modify, and convert back
270        let mut entries = Vec::new();
271
272        // Extract existing entries
273        for r in 0..self.n_rows {
274            let start = self.row_pointers[r];
275            let end = self.row_pointers[r + 1];
276
277            for idx in start..end {
278                let c = self.col_indices[idx];
279                let v = self.values[idx];
280
281                if !(r == row && c == col) {
282                    entries.push((r, c, v));
283                }
284            }
285        }
286
287        // Add new entry if non-zero
288        if value != 0.0 {
289            entries.push((row, col, value));
290        }
291
292        // Handle symmetry
293        if self.symmetric && value != 0.0 && row != col {
294            entries.push((col, row, value));
295        }
296
297        // Sort entries by (row, col)
298        entries.sort_by_key(|&(r, c, _)| (r, c));
299
300        // Rebuild CSR structure
301        self.values.clear();
302        self.col_indices.clear();
303        self.row_pointers.fill(0);
304        self.nnz = 0;
305
306        let mut current_row = 0;
307        for (r, c, v) in entries {
308            // Update row pointers
309            while current_row <= r {
310                self.row_pointers[current_row] = self.nnz;
311                current_row += 1;
312            }
313
314            self.values.push(v);
315            self.col_indices.push(c);
316            self.nnz += 1;
317        }
318
319        // Fill remaining row pointers
320        while current_row <= self.n_rows {
321            self.row_pointers[current_row] = self.nnz;
322            current_row += 1;
323        }
324
325        Ok(())
326    }
327
328    /// Get all non-zero entries in a row
329    pub fn row_entries(&self, row: usize) -> Vec<(usize, Float)> {
330        if row >= self.n_rows {
331            return Vec::new();
332        }
333
334        let start = self.row_pointers[row];
335        let end = self.row_pointers[row + 1];
336
337        let mut entries = Vec::new();
338        for idx in start..end {
339            entries.push((self.col_indices[idx], self.values[idx]));
340        }
341
342        entries
343    }
344
345    /// Get k-nearest neighbors for a specific row
346    ///
347    /// # Arguments
348    /// * `row` - Row index to find neighbors for
349    /// * `k` - Number of nearest neighbors to find
350    ///
351    /// # Returns
352    /// Vector of (column_index, distance) pairs sorted by distance
353    pub fn k_nearest_neighbors(&self, row: usize, k: usize) -> Vec<(usize, Float)> {
354        if row >= self.n_rows {
355            return Vec::new();
356        }
357
358        let mut neighbors = self.row_entries(row);
359
360        // For symmetric matrices, also check the column (avoiding duplicates)
361        if self.symmetric {
362            for other_row in 0..self.n_rows {
363                if other_row != row {
364                    let value = self.get(other_row, row);
365                    if value > 0.0 && !neighbors.iter().any(|(idx, _)| *idx == other_row) {
366                        neighbors.push((other_row, value));
367                    }
368                }
369            }
370        }
371
372        // Sort by distance and take k nearest
373        neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
374        neighbors.truncate(k);
375
376        neighbors
377    }
378
379    /// Get all neighbors within a specific radius
380    pub fn neighbors_within_radius(&self, row: usize, radius: Float) -> Vec<(usize, Float)> {
381        if row >= self.n_rows {
382            return Vec::new();
383        }
384
385        let mut neighbors = Vec::new();
386
387        // Get entries from the row
388        for (col, distance) in self.row_entries(row) {
389            if distance <= radius {
390                neighbors.push((col, distance));
391            }
392        }
393
394        // For symmetric matrices, also check the column (avoiding duplicates)
395        if self.symmetric {
396            for other_row in 0..self.n_rows {
397                if other_row != row {
398                    let distance = self.get(other_row, row);
399                    if distance > 0.0
400                        && distance <= radius
401                        && !neighbors.iter().any(|(idx, _)| *idx == other_row)
402                    {
403                        neighbors.push((other_row, distance));
404                    }
405                }
406            }
407        }
408
409        // Sort by distance
410        neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
411
412        neighbors
413    }
414
415    /// Convert to dense matrix (for small matrices or debugging)
416    pub fn to_dense(&self) -> Array2<Float> {
417        let mut dense = Array2::zeros((self.n_rows, self.n_cols));
418
419        for row in 0..self.n_rows {
420            let start = self.row_pointers[row];
421            let end = self.row_pointers[row + 1];
422
423            for idx in start..end {
424                let col = self.col_indices[idx];
425                let value = self.values[idx];
426                dense[[row, col]] = value;
427
428                // Handle symmetry
429                if self.symmetric && row != col {
430                    dense[[col, row]] = value;
431                }
432            }
433        }
434
435        dense
436    }
437
438    /// Get matrix statistics
439    pub fn stats(&self) -> SparseMatrixStats {
440        let total_entries = self.n_rows * self.n_cols;
441        let sparsity = 1.0 - (self.nnz as f64 / total_entries as f64);
442
443        let dense_memory = total_entries * std::mem::size_of::<Float>();
444        let sparse_memory = self.values.len() * std::mem::size_of::<Float>()
445            + self.col_indices.len() * std::mem::size_of::<usize>()
446            + self.row_pointers.len() * std::mem::size_of::<usize>();
447
448        let memory_savings = 1.0 - (sparse_memory as f64 / dense_memory as f64);
449
450        SparseMatrixStats {
451            n_rows: self.n_rows,
452            n_cols: self.n_cols,
453            nnz: self.nnz,
454            total_entries,
455            sparsity,
456            dense_memory_bytes: dense_memory,
457            sparse_memory_bytes: sparse_memory,
458            memory_savings,
459            symmetric: self.symmetric,
460        }
461    }
462
463    /// Matrix dimensions
464    pub fn shape(&self) -> (usize, usize) {
465        (self.n_rows, self.n_cols)
466    }
467
468    /// Number of non-zero entries
469    pub fn nnz(&self) -> usize {
470        self.nnz
471    }
472
473    /// Check if matrix is symmetric
474    pub fn is_symmetric(&self) -> bool {
475        self.symmetric
476    }
477
478    /// Advanced: Get connected components using union-find algorithm
479    /// Returns vector where each element is the component ID for that vertex
480    pub fn connected_components(&self) -> Vec<usize> {
481        let mut parent: Vec<usize> = (0..self.n_rows).collect();
482        let mut rank = vec![0; self.n_rows];
483
484        // Union-Find find with path compression
485        fn find(parent: &mut [usize], x: usize) -> usize {
486            if parent[x] != x {
487                parent[x] = find(parent, parent[x]);
488            }
489            parent[x]
490        }
491
492        // Union-Find union by rank
493        fn union(parent: &mut [usize], rank: &mut [usize], x: usize, y: usize) {
494            let root_x = find(parent, x);
495            let root_y = find(parent, y);
496
497            if root_x != root_y {
498                if rank[root_x] < rank[root_y] {
499                    parent[root_x] = root_y;
500                } else if rank[root_x] > rank[root_y] {
501                    parent[root_y] = root_x;
502                } else {
503                    parent[root_y] = root_x;
504                    rank[root_x] += 1;
505                }
506            }
507        }
508
509        // Process all edges
510        for i in 0..self.n_rows {
511            let start = self.row_pointers[i];
512            let end = self.row_pointers[i + 1];
513
514            for idx in start..end {
515                let j = self.col_indices[idx];
516                if i < j {
517                    // Process each edge only once
518                    union(&mut parent, &mut rank, i, j);
519                }
520            }
521        }
522
523        // Assign component IDs
524        let mut component_id = HashMap::new();
525        let mut next_id = 0;
526        let mut result = vec![0; self.n_rows];
527
528        for i in 0..self.n_rows {
529            let root = find(&mut parent, i);
530            let id = *component_id.entry(root).or_insert_with(|| {
531                let id = next_id;
532                next_id += 1;
533                id
534            });
535            result[i] = id;
536        }
537
538        result
539    }
540
541    /// Advanced: Compute graph diameter (longest shortest path)
542    /// Uses BFS from multiple starting points for efficiency
543    pub fn graph_diameter(&self) -> Option<usize> {
544        if self.n_rows == 0 {
545            return None;
546        }
547
548        let mut max_distance = 0;
549        let n_samples = (self.n_rows as f64).sqrt() as usize + 1;
550
551        // Sample starting vertices for diameter approximation
552        for start in (0..self.n_rows).step_by(self.n_rows / n_samples.max(1)) {
553            let distances = self.bfs_distances(start);
554            // Find max distance that is not usize::MAX (unreachable)
555            if let Some(max_dist) = distances.iter().filter(|&&d| d != usize::MAX).max() {
556                max_distance = max_distance.max(*max_dist);
557            }
558        }
559
560        Some(max_distance)
561    }
562
563    /// BFS to compute distances from a source vertex
564    fn bfs_distances(&self, source: usize) -> Vec<usize> {
565        let mut distances = vec![usize::MAX; self.n_rows];
566        let mut queue = VecDeque::new();
567
568        distances[source] = 0;
569        queue.push_back(source);
570
571        while let Some(vertex) = queue.pop_front() {
572            let start = self.row_pointers[vertex];
573            let end = self.row_pointers[vertex + 1];
574
575            for idx in start..end {
576                let neighbor = self.col_indices[idx];
577                if distances[neighbor] == usize::MAX {
578                    distances[neighbor] = distances[vertex] + 1;
579                    queue.push_back(neighbor);
580                }
581            }
582        }
583
584        distances
585    }
586
587    /// Advanced: Compute clustering coefficient for a vertex
588    /// Measures how connected a vertex's neighbors are to each other
589    pub fn clustering_coefficient(&self, vertex: usize) -> Float {
590        if vertex >= self.n_rows {
591            return 0.0;
592        }
593
594        let neighbors = self.row_entries(vertex);
595        let degree = neighbors.len();
596
597        if degree < 2 {
598            return 0.0;
599        }
600
601        let mut triangles = 0;
602
603        // Count triangles: edges between neighbors
604        for i in 0..neighbors.len() {
605            for j in (i + 1)..neighbors.len() {
606                let neighbor1 = neighbors[i].0;
607                let neighbor2 = neighbors[j].0;
608
609                // Check if neighbor1 and neighbor2 are connected
610                if self.get(neighbor1, neighbor2) > 0.0 {
611                    triangles += 1;
612                }
613            }
614        }
615
616        // Clustering coefficient = 2 * triangles / (degree * (degree - 1))
617        2.0 * triangles as Float / (degree * (degree - 1)) as Float
618    }
619
620    /// Advanced: Compute average clustering coefficient for the entire graph
621    pub fn average_clustering_coefficient(&self) -> Float {
622        let mut total_coefficient = 0.0;
623        let mut valid_vertices = 0;
624
625        for vertex in 0..self.n_rows {
626            let coefficient = self.clustering_coefficient(vertex);
627            if coefficient.is_finite() {
628                total_coefficient += coefficient;
629                valid_vertices += 1;
630            }
631        }
632
633        if valid_vertices > 0 {
634            total_coefficient / valid_vertices as Float
635        } else {
636            0.0
637        }
638    }
639
640    /// Advanced: Approximate betweenness centrality using sampling
641    /// Measures how often a vertex lies on shortest paths between other vertices
642    pub fn approximate_betweenness_centrality(&self, sample_size: usize) -> Vec<Float> {
643        let mut centrality = vec![0.0; self.n_rows];
644        let sample_vertices: Vec<usize> = (0..self.n_rows)
645            .step_by((self.n_rows / sample_size.max(1)).max(1))
646            .collect();
647
648        for &source in &sample_vertices {
649            let (distances, predecessors) = self.single_source_shortest_paths(source);
650
651            // Count paths through each vertex
652            let mut path_counts = vec![0.0; self.n_rows];
653            let mut dependency = vec![0.0; self.n_rows];
654
655            // Initialize path counts
656            for i in 0..self.n_rows {
657                if distances[i] != usize::MAX {
658                    path_counts[i] = 1.0;
659                }
660            }
661
662            // Process vertices in order of decreasing distance
663            let mut vertices_by_distance: Vec<usize> = (0..self.n_rows).collect();
664            vertices_by_distance.sort_by_key(|&v| std::cmp::Reverse(distances[v]));
665
666            for &vertex in &vertices_by_distance {
667                if distances[vertex] == usize::MAX {
668                    continue;
669                }
670
671                for &pred in &predecessors[vertex] {
672                    let contrib =
673                        path_counts[pred] * (1.0 + dependency[vertex]) / path_counts[vertex];
674                    dependency[pred] += contrib;
675                }
676
677                if vertex != source {
678                    centrality[vertex] += dependency[vertex];
679                }
680            }
681        }
682
683        // Normalize by sample size
684        let scale = sample_vertices.len() as Float;
685        for centrality_val in &mut centrality {
686            *centrality_val /= scale;
687        }
688
689        centrality
690    }
691
692    /// Single-source shortest paths with predecessor tracking
693    fn single_source_shortest_paths(&self, source: usize) -> (Vec<usize>, Vec<Vec<usize>>) {
694        let mut distances = vec![usize::MAX; self.n_rows];
695        let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); self.n_rows];
696        let mut queue = VecDeque::new();
697
698        distances[source] = 0;
699        queue.push_back(source);
700
701        while let Some(vertex) = queue.pop_front() {
702            let start = self.row_pointers[vertex];
703            let end = self.row_pointers[vertex + 1];
704
705            for idx in start..end {
706                let neighbor = self.col_indices[idx];
707                let new_dist = distances[vertex] + 1;
708
709                if new_dist < distances[neighbor] {
710                    distances[neighbor] = new_dist;
711                    predecessors[neighbor].clear();
712                    predecessors[neighbor].push(vertex);
713                    queue.push_back(neighbor);
714                } else if new_dist == distances[neighbor] {
715                    predecessors[neighbor].push(vertex);
716                }
717            }
718        }
719
720        (distances, predecessors)
721    }
722}
723
/// Size and memory summary of a sparse distance matrix, as returned by
/// `SparseDistanceMatrix::stats`.
#[derive(Clone, Debug)]
pub struct SparseMatrixStats {
    /// Number of rows.
    pub n_rows: usize,
    /// Number of columns.
    pub n_cols: usize,
    /// Number of stored entries.
    pub nnz: usize,
    /// `n_rows * n_cols`.
    pub total_entries: usize,
    /// Fraction of `total_entries` that is not stored.
    pub sparsity: f64,
    /// Bytes a dense `n_rows × n_cols` matrix of `Float` would occupy.
    pub dense_memory_bytes: usize,
    /// Bytes held by the CSR arrays (values, column indices, row pointers).
    pub sparse_memory_bytes: usize,
    /// `1 - sparse_memory_bytes / dense_memory_bytes`.
    pub memory_savings: f64,
    /// Whether the source matrix is symmetric.
    pub symmetric: bool,
}
737
738impl std::fmt::Display for SparseMatrixStats {
739    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
740        write!(
741            f,
742            "SparseMatrix {{ {}×{}, {:.2}% sparse, {} nnz, {:.1}% memory savings }}",
743            self.n_rows,
744            self.n_cols,
745            self.sparsity * 100.0,
746            self.nnz,
747            self.memory_savings * 100.0
748        )
749    }
750}
751
752/// Sparse neighborhood graph for clustering algorithms
753pub struct SparseNeighborhoodGraph {
754    /// Sparse adjacency matrix
755    adjacency: SparseDistanceMatrix,
756    /// Vertex degrees
757    degrees: Vec<usize>,
758}
759
760impl SparseNeighborhoodGraph {
761    /// Create neighborhood graph from sparse distance matrix
762    pub fn from_sparse_matrix(matrix: SparseDistanceMatrix) -> Self {
763        let n_vertices = matrix.n_rows;
764        let mut degrees = vec![0; n_vertices];
765
766        // Calculate degrees
767        for i in 0..n_vertices {
768            degrees[i] = matrix.row_entries(i).len();
769        }
770
771        Self {
772            adjacency: matrix,
773            degrees,
774        }
775    }
776
777    /// Get neighbors of a vertex
778    pub fn neighbors(&self, vertex: usize) -> Vec<(usize, Float)> {
779        self.adjacency.row_entries(vertex)
780    }
781
782    /// Get degree of a vertex
783    pub fn degree(&self, vertex: usize) -> usize {
784        self.degrees.get(vertex).copied().unwrap_or(0)
785    }
786
787    /// Get all vertices
788    pub fn vertices(&self) -> std::ops::Range<usize> {
789        0..self.adjacency.n_rows
790    }
791
792    /// Number of vertices
793    pub fn n_vertices(&self) -> usize {
794        self.adjacency.n_rows
795    }
796
797    /// Number of edges
798    pub fn n_edges(&self) -> usize {
799        if self.adjacency.symmetric {
800            self.adjacency.nnz / 2
801        } else {
802            self.adjacency.nnz
803        }
804    }
805
806    /// Graph statistics
807    pub fn graph_stats(&self) -> GraphStats {
808        let total_degree: usize = self.degrees.iter().sum();
809        let avg_degree = total_degree as f64 / self.n_vertices() as f64;
810        let max_degree = *self.degrees.iter().max().unwrap_or(&0);
811        let min_degree = *self.degrees.iter().min().unwrap_or(&0);
812
813        GraphStats {
814            n_vertices: self.n_vertices(),
815            n_edges: self.n_edges(),
816            avg_degree,
817            max_degree,
818            min_degree,
819            matrix_stats: self.adjacency.stats(),
820        }
821    }
822}
823
824/// Statistics for sparse neighborhood graph
825#[derive(Debug, Clone)]
826pub struct GraphStats {
827    pub n_vertices: usize,
828    pub n_edges: usize,
829    pub avg_degree: f64,
830    pub max_degree: usize,
831    pub min_degree: usize,
832    pub matrix_stats: SparseMatrixStats,
833}
834
835impl std::fmt::Display for GraphStats {
836    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
837        write!(
838            f,
839            "Graph {{ {} vertices, {} edges, avg degree: {:.1}, degree range: {}-{} }}",
840            self.n_vertices, self.n_edges, self.avg_degree, self.min_degree, self.max_degree
841        )
842    }
843}
844
845#[allow(non_snake_case)]
846#[cfg(test)]
847mod tests {
848    use super::*;
849    use scirs2_core::ndarray::array;
850
851    #[test]
852    fn test_sparse_matrix_creation() {
853        let matrix = SparseDistanceMatrix::new(3, 3, true);
854
855        assert_eq!(matrix.shape(), (3, 3));
856        assert_eq!(matrix.nnz(), 0);
857        assert!(matrix.is_symmetric());
858
859        // All entries should be zero initially
860        for i in 0..3 {
861            for j in 0..3 {
862                assert_eq!(matrix.get(i, j), 0.0);
863            }
864        }
865    }
866
867    #[test]
868    fn test_sparse_matrix_set_get() {
869        let mut matrix = SparseDistanceMatrix::new(3, 3, true);
870
871        // Set some values
872        matrix.set(0, 1, 2.5).unwrap();
873        matrix.set(1, 2, 3.0).unwrap();
874
875        // Check values
876        assert_eq!(matrix.get(0, 1), 2.5);
877        assert_eq!(matrix.get(1, 0), 2.5); // Symmetric
878        assert_eq!(matrix.get(1, 2), 3.0);
879        assert_eq!(matrix.get(2, 1), 3.0); // Symmetric
880        assert_eq!(matrix.get(0, 2), 0.0); // Not set
881
882        // Diagonal should remain zero
883        for i in 0..3 {
884            assert_eq!(matrix.get(i, i), 0.0);
885        }
886    }
887
888    #[test]
889    fn test_sparse_matrix_from_data() {
890        // Create data with clear clusters (sparse distances)
891        let data = array![
892            [0.0, 0.0],
893            [0.1, 0.0],  // Close to first point
894            [10.0, 0.0], // Far from others
895            [10.1, 0.0], // Close to third point
896        ];
897
898        let config = SparseMatrixConfig {
899            distance_threshold: 1.0, // Only store distances <= 1.0
900            sparsity_threshold: 0.3, // Allow 30% sparse
901            ..Default::default()
902        };
903
904        let sparse_matrix = SparseDistanceMatrix::from_data(&data, config).unwrap();
905
906        // Should have captured close pairs only
907        assert!(sparse_matrix.get(0, 1) > 0.0); // Points 0 and 1 are close
908        assert!(sparse_matrix.get(2, 3) > 0.0); // Points 2 and 3 are close
909        assert_eq!(sparse_matrix.get(0, 2), 0.0); // Points 0 and 2 are far
910        assert_eq!(sparse_matrix.get(1, 3), 0.0); // Points 1 and 3 are far
911    }
912
913    #[test]
914    fn test_k_nearest_neighbors() {
915        let mut matrix = SparseDistanceMatrix::new(4, 4, true);
916
917        // Set up distances: point 0 connected to 1 (distance 1.0) and 2 (distance 2.0)
918        matrix.set(0, 1, 1.0).unwrap();
919        matrix.set(0, 2, 2.0).unwrap();
920        matrix.set(0, 3, 5.0).unwrap();
921
922        let neighbors = matrix.k_nearest_neighbors(0, 2);
923
924        assert_eq!(neighbors.len(), 2);
925        assert_eq!(neighbors[0], (1, 1.0)); // Closest
926        assert_eq!(neighbors[1], (2, 2.0)); // Second closest
927    }
928
929    #[test]
930    fn test_neighbors_within_radius() {
931        let mut matrix = SparseDistanceMatrix::new(4, 4, true);
932
933        matrix.set(0, 1, 1.0).unwrap();
934        matrix.set(0, 2, 2.0).unwrap();
935        matrix.set(0, 3, 5.0).unwrap();
936
937        let neighbors = matrix.neighbors_within_radius(0, 2.5);
938
939        assert_eq!(neighbors.len(), 2);
940        assert!(neighbors.iter().any(|&(idx, _)| idx == 1));
941        assert!(neighbors.iter().any(|&(idx, _)| idx == 2));
942        assert!(!neighbors.iter().any(|&(idx, _)| idx == 3)); // Distance 5.0 > 2.5
943    }
944
945    #[test]
946    fn test_sparse_matrix_stats() {
947        let mut matrix = SparseDistanceMatrix::new(4, 4, true);
948
949        matrix.set(0, 1, 1.0).unwrap();
950        matrix.set(1, 2, 2.0).unwrap();
951
952        let stats = matrix.stats();
953
954        assert_eq!(stats.n_rows, 4);
955        assert_eq!(stats.n_cols, 4);
956        assert_eq!(stats.nnz, 4); // 2 entries + 2 symmetric entries
957        assert!(stats.sparsity > 0.5); // Should be quite sparse
958        assert!(stats.memory_savings > 0.0);
959    }
960
961    #[test]
962    fn test_to_dense_conversion() {
963        let mut matrix = SparseDistanceMatrix::new(3, 3, true);
964
965        matrix.set(0, 1, 1.5).unwrap();
966        matrix.set(1, 2, 2.5).unwrap();
967
968        let dense = matrix.to_dense();
969
970        assert_eq!(dense.shape(), &[3, 3]);
971        assert_eq!(dense[[0, 1]], 1.5);
972        assert_eq!(dense[[1, 0]], 1.5); // Symmetric
973        assert_eq!(dense[[1, 2]], 2.5);
974        assert_eq!(dense[[2, 1]], 2.5); // Symmetric
975        assert_eq!(dense[[0, 2]], 0.0); // Not connected
976
977        // Diagonal should be zero
978        for i in 0..3 {
979            assert_eq!(dense[[i, i]], 0.0);
980        }
981    }
982
983    #[test]
984    fn test_neighborhood_graph() {
985        let mut matrix = SparseDistanceMatrix::new(4, 4, true);
986
987        matrix.set(0, 1, 1.0).unwrap();
988        matrix.set(1, 2, 1.5).unwrap();
989        matrix.set(2, 3, 2.0).unwrap();
990
991        let graph = SparseNeighborhoodGraph::from_sparse_matrix(matrix);
992
993        assert_eq!(graph.n_vertices(), 4);
994        assert_eq!(graph.n_edges(), 3); // 3 undirected edges
995
996        // Check degrees
997        assert_eq!(graph.degree(0), 1); // Connected to 1
998        assert_eq!(graph.degree(1), 2); // Connected to 0 and 2
999        assert_eq!(graph.degree(2), 2); // Connected to 1 and 3
1000        assert_eq!(graph.degree(3), 1); // Connected to 2
1001
1002        // Check neighbors
1003        let neighbors_1 = graph.neighbors(1);
1004        assert_eq!(neighbors_1.len(), 2);
1005    }
1006
1007    #[test]
1008    fn test_graph_stats() {
1009        let mut matrix = SparseDistanceMatrix::new(3, 3, true);
1010
1011        matrix.set(0, 1, 1.0).unwrap();
1012        matrix.set(1, 2, 1.5).unwrap();
1013
1014        let graph = SparseNeighborhoodGraph::from_sparse_matrix(matrix);
1015        let stats = graph.graph_stats();
1016
1017        assert_eq!(stats.n_vertices, 3);
1018        assert_eq!(stats.n_edges, 2);
1019        assert!((stats.avg_degree - 4.0 / 3.0).abs() < 1e-10); // Total degree 4, 3 vertices
1020        assert_eq!(stats.max_degree, 2); // Vertex 1 has degree 2
1021        assert_eq!(stats.min_degree, 1); // Vertices 0 and 2 have degree 1
1022    }
1023
1024    #[test]
1025    fn test_connected_components() {
1026        let mut matrix = SparseDistanceMatrix::new(5, 5, true);
1027
1028        // Create two disconnected components: {0,1,2} and {3,4}
1029        matrix.set(0, 1, 1.0).unwrap();
1030        matrix.set(1, 2, 1.0).unwrap();
1031        matrix.set(3, 4, 1.0).unwrap();
1032
1033        let components = matrix.connected_components();
1034
1035        // Vertices 0, 1, 2 should be in the same component
1036        assert_eq!(components[0], components[1]);
1037        assert_eq!(components[1], components[2]);
1038
1039        // Vertices 3, 4 should be in the same component (different from above)
1040        assert_eq!(components[3], components[4]);
1041        assert_ne!(components[0], components[3]);
1042
1043        // Each component should have at least one member
1044        let unique_components: std::collections::HashSet<_> = components.iter().collect();
1045        assert!(unique_components.len() >= 2);
1046    }
1047
1048    #[test]
1049    fn test_graph_diameter() {
1050        let mut matrix = SparseDistanceMatrix::new(4, 4, true);
1051
1052        // Create a linear graph: 0-1-2-3
1053        matrix.set(0, 1, 1.0).unwrap();
1054        matrix.set(1, 2, 1.0).unwrap();
1055        matrix.set(2, 3, 1.0).unwrap();
1056
1057        let diameter = matrix.graph_diameter();
1058        assert_eq!(diameter, Some(3)); // Distance from 0 to 3
1059    }
1060
1061    #[test]
1062    fn test_clustering_coefficient() {
1063        let mut matrix = SparseDistanceMatrix::new(4, 4, true);
1064
1065        // Create a triangle plus one: 0-1-2-0, and 1-3
1066        matrix.set(0, 1, 1.0).unwrap();
1067        matrix.set(1, 2, 1.0).unwrap();
1068        matrix.set(2, 0, 1.0).unwrap(); // Triangle complete
1069        matrix.set(1, 3, 1.0).unwrap();
1070
1071        // Vertex 0: neighbors are {1, 2}, they are connected → coefficient = 1.0
1072        let coeff_0 = matrix.clustering_coefficient(0);
1073        assert!((coeff_0 - 1.0).abs() < 1e-10);
1074
1075        // Vertex 1: neighbors are {0, 2, 3}, 0-2 connected, others not → coefficient = 1/3
1076        let coeff_1 = matrix.clustering_coefficient(1);
1077        assert!((coeff_1 - 1.0 / 3.0).abs() < 1e-10);
1078
1079        // Vertex 3: only one neighbor → coefficient = 0.0
1080        let coeff_3 = matrix.clustering_coefficient(3);
1081        assert_eq!(coeff_3, 0.0);
1082    }
1083
1084    #[test]
1085    fn test_average_clustering_coefficient() {
1086        let mut matrix = SparseDistanceMatrix::new(3, 3, true);
1087
1088        // Create a complete triangle
1089        matrix.set(0, 1, 1.0).unwrap();
1090        matrix.set(1, 2, 1.0).unwrap();
1091        matrix.set(2, 0, 1.0).unwrap();
1092
1093        let avg_coeff = matrix.average_clustering_coefficient();
1094        // In a complete triangle, all vertices have clustering coefficient 1.0
1095        assert!((avg_coeff - 1.0).abs() < 1e-10);
1096    }
1097
1098    #[test]
1099    fn test_betweenness_centrality() {
1100        let mut matrix = SparseDistanceMatrix::new(4, 4, true);
1101
1102        // Create a linear graph: 0-1-2-3
1103        // Vertices 1 and 2 should have higher betweenness centrality
1104        matrix.set(0, 1, 1.0).unwrap();
1105        matrix.set(1, 2, 1.0).unwrap();
1106        matrix.set(2, 3, 1.0).unwrap();
1107
1108        let centrality = matrix.approximate_betweenness_centrality(4);
1109
1110        // End vertices (0, 3) should have lower centrality than middle vertices (1, 2)
1111        assert!(centrality[1] > centrality[0]);
1112        assert!(centrality[2] > centrality[3]);
1113        assert!((centrality[1] - centrality[2]).abs() < 0.1); // Should be similar
1114    }
1115
1116    #[test]
1117    fn test_bfs_distances() {
1118        let mut matrix = SparseDistanceMatrix::new(4, 4, true);
1119
1120        // Create a simple path: 0-1-2-3
1121        matrix.set(0, 1, 1.0).unwrap();
1122        matrix.set(1, 2, 1.0).unwrap();
1123        matrix.set(2, 3, 1.0).unwrap();
1124
1125        let distances = matrix.bfs_distances(0);
1126
1127        assert_eq!(distances[0], 0);
1128        assert_eq!(distances[1], 1);
1129        assert_eq!(distances[2], 2);
1130        assert_eq!(distances[3], 3);
1131    }
1132
1133    #[test]
1134    fn test_advanced_algorithms_empty_graph() {
1135        let matrix = SparseDistanceMatrix::new(3, 3, true);
1136
1137        // Test empty graph behavior
1138        let components = matrix.connected_components();
1139        assert_eq!(components.len(), 3);
1140        // Each vertex should be its own component
1141        assert_ne!(components[0], components[1]);
1142        assert_ne!(components[1], components[2]);
1143
1144        let diameter = matrix.graph_diameter();
1145        assert_eq!(diameter, Some(0));
1146
1147        let avg_coeff = matrix.average_clustering_coefficient();
1148        assert_eq!(avg_coeff, 0.0);
1149    }
1150}