Skip to main content

scirs2_cluster/
sparse.rs

1//! Sparse distance matrix support for large datasets
2//!
3//! This module provides efficient representations and algorithms for working
4//! with sparse distance matrices, particularly useful for high-dimensional
5//! data where most pairwise distances are zero or very large.
6
7use scirs2_core::ndarray::{Array2, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12use crate::error::{ClusteringError, Result};
13use crate::hierarchy::{LinkageMethod, Metric};
14
/// Sparse distance matrix using coordinate format (COO)
///
/// Stores only explicitly provided distances to save memory for sparse datasets.
/// Each unordered pair is stored at most once, normalised so that `row <= col`.
/// Off-diagonal pairs that are not stored read back as `default_value`.
#[derive(Debug, Clone)]
pub struct SparseDistanceMatrix<F: Float> {
    /// Row index of each stored entry (never greater than the matching column index)
    rows: Vec<usize>,
    /// Column index of each stored entry
    cols: Vec<usize>,
    /// Stored distance of each entry (any value, not necessarily non-zero)
    data: Vec<F>,
    /// Number of samples; the logical matrix is `n_samples x n_samples`
    n_samples: usize,
    /// Value reported for unstored off-diagonal pairs (e.g. `0.0` from `from_dense`,
    /// infinity for the graph builders in this module)
    default_value: F,
}
31
32impl<F: Float + FromPrimitive> SparseDistanceMatrix<F> {
33    /// Create a new sparse distance matrix
34    pub fn new(n_samples: usize, default_value: F) -> Self {
35        Self {
36            rows: Vec::new(),
37            cols: Vec::new(),
38            data: Vec::new(),
39            n_samples,
40            default_value,
41        }
42    }
43
44    /// Create a sparse distance matrix from a dense matrix, keeping only values above threshold
45    pub fn from_dense(dense: ArrayView2<F>, threshold: F) -> Self {
46        let n_samples = dense.shape()[0];
47        let mut rows = Vec::new();
48        let mut cols = Vec::new();
49        let mut data = Vec::new();
50
51        for i in 0..n_samples {
52            for j in (i + 1)..n_samples {
53                let distance = dense[[i, j]];
54                if distance > threshold {
55                    rows.push(i);
56                    cols.push(j);
57                    data.push(distance);
58                }
59            }
60        }
61
62        Self {
63            rows,
64            cols,
65            data,
66            n_samples,
67            default_value: F::zero(),
68        }
69    }
70
71    /// Add a distance entry to the sparse matrix
72    pub fn add_distance(&mut self, i: usize, j: usize, distance: F) -> Result<()> {
73        if i >= self.n_samples || j >= self.n_samples {
74            return Err(ClusteringError::InvalidInput("Index out of bounds".into()));
75        }
76
77        // Ensure i < j for upper triangular storage
78        let (row, col) = if i < j { (i, j) } else { (j, i) };
79
80        // Check if this edge already exists
81        for idx in 0..self.rows.len() {
82            if self.rows[idx] == row && self.cols[idx] == col {
83                // Update existing entry with the shorter distance
84                if distance < self.data[idx] {
85                    self.data[idx] = distance;
86                }
87                return Ok(());
88            }
89        }
90
91        // Add new entry
92        self.rows.push(row);
93        self.cols.push(col);
94        self.data.push(distance);
95
96        Ok(())
97    }
98
99    /// Get the distance between two points
100    pub fn get_distance(&self, i: usize, j: usize) -> F {
101        if i == j {
102            return F::zero();
103        }
104
105        let (row, col) = if i < j { (i, j) } else { (j, i) };
106
107        // Linear search through stored values (could be optimized with sorted storage)
108        for idx in 0..self.rows.len() {
109            if self.rows[idx] == row && self.cols[idx] == col {
110                return self.data[idx];
111            }
112        }
113
114        self.default_value
115    }
116
117    /// Get all neighbors within a given distance threshold
118    pub fn neighbors_within_distance(&self, point: usize, maxdistance: F) -> Vec<(usize, F)> {
119        let mut neighbors = Vec::new();
120
121        // Check all stored distances involving this point
122        for idx in 0..self.rows.len() {
123            let (neighbor, distance) = if self.rows[idx] == point {
124                (self.cols[idx], self.data[idx])
125            } else if self.cols[idx] == point {
126                (self.rows[idx], self.data[idx])
127            } else {
128                continue;
129            };
130
131            if distance <= maxdistance {
132                neighbors.push((neighbor, distance));
133            }
134        }
135
136        neighbors
137    }
138
139    /// Get the k nearest neighbors for a point
140    pub fn k_nearest_neighbors(&self, point: usize, k: usize) -> Vec<(usize, F)> {
141        let mut all_neighbors = Vec::new();
142
143        // Collect all neighbors
144        for idx in 0..self.rows.len() {
145            let (neighbor, distance) = if self.rows[idx] == point {
146                (self.cols[idx], self.data[idx])
147            } else if self.cols[idx] == point {
148                (self.rows[idx], self.data[idx])
149            } else {
150                continue;
151            };
152
153            all_neighbors.push((neighbor, distance));
154        }
155
156        // Sort by distance and take k nearest
157        all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
158        all_neighbors.truncate(k);
159
160        all_neighbors
161    }
162
163    /// Convert to a dense distance matrix (use with caution for large matrices)
164    pub fn to_dense(&self) -> Array2<F> {
165        let mut dense = Array2::from_elem((self.n_samples, self.n_samples), self.default_value);
166
167        // Set diagonal to zero
168        for i in 0..self.n_samples {
169            dense[[i, i]] = F::zero();
170        }
171
172        // Fill in stored distances (both upper and lower triangular)
173        for idx in 0..self.rows.len() {
174            let i = self.rows[idx];
175            let j = self.cols[idx];
176            let distance = self.data[idx];
177
178            dense[[i, j]] = distance;
179            dense[[j, i]] = distance;
180        }
181
182        dense
183    }
184
185    /// Get the number of non-zero entries
186    pub fn nnz(&self) -> usize {
187        self.data.len()
188    }
189
190    /// Get the sparsity ratio (fraction of zeros)
191    pub fn sparsity(&self) -> f64 {
192        let total_entries = self.n_samples * (self.n_samples - 1) / 2;
193        1.0 - (self.nnz() as f64 / total_entries as f64)
194    }
195
196    /// Get the number of samples
197    pub fn n_samples(&self) -> usize {
198        self.n_samples
199    }
200}
201
/// Sparse hierarchical clustering using a minimum spanning tree approach
///
/// This algorithm is particularly efficient for sparse distance matrices
/// where most distances are infinite or very large.
pub struct SparseHierarchicalClustering<F: Float> {
    /// Pairwise distances; only stored entries are used as tree edges, and the matrix's
    /// default value is used when the stored entries leave the graph disconnected
    sparse_matrix: SparseDistanceMatrix<F>,
    /// Requested linkage criterion. NOTE(review): only `LinkageMethod::Single` is handled
    /// explicitly; other methods go through the same MST-based path.
    linkage_method: LinkageMethod,
}
210
211impl<F: Float + FromPrimitive + Debug + PartialOrd> SparseHierarchicalClustering<F> {
212    /// Create a new sparse hierarchical clustering instance
213    pub fn new(sparse_matrix: SparseDistanceMatrix<F>, linkage_method: LinkageMethod) -> Self {
214        Self {
215            sparse_matrix,
216            linkage_method,
217        }
218    }
219
220    /// Perform hierarchical clustering using Prim's algorithm for MST
221    pub fn fit(&self) -> Result<Array2<F>> {
222        let n_samples = self.sparse_matrix.n_samples();
223
224        if n_samples < 2 {
225            return Err(ClusteringError::InvalidInput(
226                "Need at least 2 samples for clustering".into(),
227            ));
228        }
229
230        // Build a minimal spanning tree
231        let mst_edges = self.minimum_spanning_tree()?;
232
233        // Convert MST to linkage matrix based on the chosen method
234        self.mst_to_linkage(mst_edges)
235    }
236
237    /// Build minimal spanning tree using Prim's algorithm
238    fn minimum_spanning_tree(&self) -> Result<Vec<(usize, usize, F)>> {
239        let n_samples = self.sparse_matrix.n_samples();
240        let mut mst_edges = Vec::new();
241        let mut visited = vec![false; n_samples];
242        let mut min_edge: HashMap<usize, (usize, F)> = HashMap::new();
243
244        // Start with vertex 0
245        visited[0] = true;
246
247        // Initialize edges from vertex 0
248        for neighbor_idx in 0..self.sparse_matrix.rows.len() {
249            let (i, j) = (
250                self.sparse_matrix.rows[neighbor_idx],
251                self.sparse_matrix.cols[neighbor_idx],
252            );
253            let distance = self.sparse_matrix.data[neighbor_idx];
254
255            if i == 0 && !visited[j] {
256                min_edge.insert(j, (i, distance));
257            } else if j == 0 && !visited[i] {
258                min_edge.insert(i, (j, distance));
259            }
260        }
261
262        // Build MST one edge at a time
263        for _ in 1..n_samples {
264            // Find minimum edge to unvisited vertex
265            let mut min_dist = F::infinity();
266            let mut min_vertex = 0;
267            let mut min_parent = 0;
268
269            for (&vertex, &(parent, distance)) in &min_edge {
270                if !visited[vertex] && distance < min_dist {
271                    min_dist = distance;
272                    min_vertex = vertex;
273                    min_parent = parent;
274                }
275            }
276
277            if min_dist == F::infinity() {
278                // Disconnected graph - use default distance
279                min_dist = self.sparse_matrix.default_value;
280            }
281
282            // Add edge to MST
283            mst_edges.push((min_parent, min_vertex, min_dist));
284            visited[min_vertex] = true;
285
286            // Update edges from the newly added vertex
287            for neighbor_idx in 0..self.sparse_matrix.rows.len() {
288                let (i, j) = (
289                    self.sparse_matrix.rows[neighbor_idx],
290                    self.sparse_matrix.cols[neighbor_idx],
291                );
292                let distance = self.sparse_matrix.data[neighbor_idx];
293
294                let (from_vertex, to_vertex) = if i == min_vertex && !visited[j] {
295                    (i, j)
296                } else if j == min_vertex && !visited[i] {
297                    (j, i)
298                } else {
299                    continue;
300                };
301
302                // Update minimum edge to to_vertex if this is better
303                match min_edge.get(&to_vertex) {
304                    Some(&(_, current_dist)) if distance < current_dist => {
305                        min_edge.insert(to_vertex, (from_vertex, distance));
306                    }
307                    None => {
308                        min_edge.insert(to_vertex, (from_vertex, distance));
309                    }
310                    _ => {}
311                }
312            }
313        }
314
315        Ok(mst_edges)
316    }
317
318    /// Convert MST edges to linkage matrix format
319    fn mst_to_linkage(&self, mut mst_edges: Vec<(usize, usize, F)>) -> Result<Array2<F>> {
320        let n_samples = self.sparse_matrix.n_samples();
321
322        // Sort _edges by distance for single linkage, or process in MST order
323        match self.linkage_method {
324            LinkageMethod::Single => {
325                // For single linkage, MST _edges directly give the dendrogram
326                mst_edges.sort_by(|a, b| a.2.partial_cmp(&b.2).expect("Operation failed"));
327            }
328            _ => {
329                // For other methods, we would need more complex processing
330                // For now, treat as single linkage
331            }
332        }
333
334        let mut linkage_matrix = Array2::zeros((n_samples - 1, 4));
335        let mut cluster_map: HashMap<usize, usize> = HashMap::new();
336        let mut next_cluster_id = n_samples;
337
338        // Initialize cluster map (each point is its own cluster)
339        for i in 0..n_samples {
340            cluster_map.insert(i, i);
341        }
342
343        for (step, (i, j, distance)) in mst_edges.iter().enumerate() {
344            let cluster_i = cluster_map[i];
345            let cluster_j = cluster_map[j];
346
347            // Record merge in linkage matrix
348            linkage_matrix[[step, 0]] = F::from(cluster_i).expect("Failed to convert to float");
349            linkage_matrix[[step, 1]] = F::from(cluster_j).expect("Failed to convert to float");
350            linkage_matrix[[step, 2]] = *distance;
351            linkage_matrix[[step, 3]] = F::from(2).expect("Failed to convert constant to float"); // Size starts at 2, would need to track actual sizes
352
353            // Update cluster mapping
354            cluster_map.insert(*i, next_cluster_id);
355            cluster_map.insert(*j, next_cluster_id);
356            next_cluster_id += 1;
357        }
358
359        Ok(linkage_matrix)
360    }
361}
362
363/// Build a sparse k-nearest neighbor graph from dense data
364#[allow(dead_code)]
365pub fn sparse_knn_graph<F>(
366    data: ArrayView2<F>,
367    k: usize,
368    metric: Metric,
369) -> Result<SparseDistanceMatrix<F>>
370where
371    F: Float + FromPrimitive + Debug,
372{
373    let n_samples = data.shape()[0];
374    let n_features = data.shape()[1];
375
376    if k >= n_samples {
377        return Err(ClusteringError::InvalidInput(
378            "k must be less than number of samples".into(),
379        ));
380    }
381
382    let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
383
384    // For each point, find its k nearest neighbors
385    for i in 0..n_samples {
386        let mut distances: Vec<(usize, F)> = Vec::new();
387
388        // Calculate distances to all other points
389        for j in 0..n_samples {
390            if i == j {
391                continue;
392            }
393
394            let dist = match metric {
395                Metric::Euclidean => {
396                    let mut sum = F::zero();
397                    for k in 0..n_features {
398                        let diff = data[[i, k]] - data[[j, k]];
399                        sum = sum + diff * diff;
400                    }
401                    sum.sqrt()
402                }
403                Metric::Manhattan => {
404                    let mut sum = F::zero();
405                    for k in 0..n_features {
406                        let diff = (data[[i, k]] - data[[j, k]]).abs();
407                        sum = sum + diff;
408                    }
409                    sum
410                }
411                Metric::Chebyshev => {
412                    let mut max_diff = F::zero();
413                    for k in 0..n_features {
414                        let diff = (data[[i, k]] - data[[j, k]]).abs();
415                        if diff > max_diff {
416                            max_diff = diff;
417                        }
418                    }
419                    max_diff
420                }
421                Metric::Cosine => {
422                    let mut dot = F::zero();
423                    let mut norm_i = F::zero();
424                    let mut norm_j = F::zero();
425                    for k in 0..n_features {
426                        let vi = data[[i, k]];
427                        let vj = data[[j, k]];
428                        dot = dot + vi * vj;
429                        norm_i = norm_i + vi * vi;
430                        norm_j = norm_j + vj * vj;
431                    }
432                    let norm_prod = (norm_i * norm_j).sqrt();
433                    if norm_prod
434                        < F::from_f64(1e-10).ok_or_else(|| {
435                            ClusteringError::InvalidInput("float conversion failed".into())
436                        })?
437                    {
438                        F::one()
439                    } else {
440                        F::one() - dot / norm_prod
441                    }
442                }
443                Metric::Correlation => {
444                    let n_f = F::from_usize(n_features).ok_or_else(|| {
445                        ClusteringError::InvalidInput("float conversion failed".into())
446                    })?;
447                    let mut mean_i = F::zero();
448                    let mut mean_j = F::zero();
449                    for k in 0..n_features {
450                        mean_i = mean_i + data[[i, k]];
451                        mean_j = mean_j + data[[j, k]];
452                    }
453                    mean_i = mean_i / n_f;
454                    mean_j = mean_j / n_f;
455
456                    let mut numerator = F::zero();
457                    let mut denom_i = F::zero();
458                    let mut denom_j = F::zero();
459                    for k in 0..n_features {
460                        let di = data[[i, k]] - mean_i;
461                        let dj = data[[j, k]] - mean_j;
462                        numerator = numerator + di * dj;
463                        denom_i = denom_i + di * di;
464                        denom_j = denom_j + dj * dj;
465                    }
466                    let denom = (denom_i * denom_j).sqrt();
467                    if denom
468                        < F::from_f64(1e-10).ok_or_else(|| {
469                            ClusteringError::InvalidInput("float conversion failed".into())
470                        })?
471                    {
472                        F::zero()
473                    } else {
474                        F::one() - numerator / denom
475                    }
476                }
477            };
478
479            distances.push((j, dist));
480        }
481
482        // Sort by distance and keep k nearest
483        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
484        distances.truncate(k);
485
486        // Add to sparse matrix
487        for (neighbor, distance) in distances {
488            sparse_matrix.add_distance(i, neighbor, distance)?;
489        }
490    }
491
492    Ok(sparse_matrix)
493}
494
495/// Build a sparse epsilon-neighborhood graph from dense data
496#[allow(dead_code)]
497pub fn sparse_epsilon_graph<F>(
498    data: ArrayView2<F>,
499    epsilon: F,
500    metric: Metric,
501) -> Result<SparseDistanceMatrix<F>>
502where
503    F: Float + FromPrimitive + Debug,
504{
505    let n_samples = data.shape()[0];
506    let n_features = data.shape()[1];
507
508    let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
509
510    // For each pair of points, check if within epsilon distance
511    for i in 0..n_samples {
512        for j in (i + 1)..n_samples {
513            let dist = match metric {
514                Metric::Euclidean => {
515                    let mut sum = F::zero();
516                    for k in 0..n_features {
517                        let diff = data[[i, k]] - data[[j, k]];
518                        sum = sum + diff * diff;
519                    }
520                    sum.sqrt()
521                }
522                Metric::Manhattan => {
523                    let mut sum = F::zero();
524                    for k in 0..n_features {
525                        let diff = (data[[i, k]] - data[[j, k]]).abs();
526                        sum = sum + diff;
527                    }
528                    sum
529                }
530                Metric::Chebyshev => {
531                    let mut max_diff = F::zero();
532                    for k in 0..n_features {
533                        let diff = (data[[i, k]] - data[[j, k]]).abs();
534                        if diff > max_diff {
535                            max_diff = diff;
536                        }
537                    }
538                    max_diff
539                }
540                Metric::Cosine => {
541                    let mut dot = F::zero();
542                    let mut norm_i = F::zero();
543                    let mut norm_j = F::zero();
544                    for k in 0..n_features {
545                        let vi = data[[i, k]];
546                        let vj = data[[j, k]];
547                        dot = dot + vi * vj;
548                        norm_i = norm_i + vi * vi;
549                        norm_j = norm_j + vj * vj;
550                    }
551                    let norm_prod = (norm_i * norm_j).sqrt();
552                    if norm_prod
553                        < F::from_f64(1e-10).ok_or_else(|| {
554                            ClusteringError::InvalidInput("float conversion failed".into())
555                        })?
556                    {
557                        F::one()
558                    } else {
559                        F::one() - dot / norm_prod
560                    }
561                }
562                Metric::Correlation => {
563                    let n_f = F::from_usize(n_features).ok_or_else(|| {
564                        ClusteringError::InvalidInput("float conversion failed".into())
565                    })?;
566                    let mut mean_i = F::zero();
567                    let mut mean_j = F::zero();
568                    for k in 0..n_features {
569                        mean_i = mean_i + data[[i, k]];
570                        mean_j = mean_j + data[[j, k]];
571                    }
572                    mean_i = mean_i / n_f;
573                    mean_j = mean_j / n_f;
574
575                    let mut numerator = F::zero();
576                    let mut denom_i = F::zero();
577                    let mut denom_j = F::zero();
578                    for k in 0..n_features {
579                        let di = data[[i, k]] - mean_i;
580                        let dj = data[[j, k]] - mean_j;
581                        numerator = numerator + di * dj;
582                        denom_i = denom_i + di * di;
583                        denom_j = denom_j + dj * dj;
584                    }
585                    let denom = (denom_i * denom_j).sqrt();
586                    if denom
587                        < F::from_f64(1e-10).ok_or_else(|| {
588                            ClusteringError::InvalidInput("float conversion failed".into())
589                        })?
590                    {
591                        F::zero()
592                    } else {
593                        F::one() - numerator / denom
594                    }
595                }
596            };
597
598            if dist <= epsilon {
599                sparse_matrix.add_distance(i, j, dist)?;
600            }
601        }
602    }
603
604    Ok(sparse_matrix)
605}
606
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_sparse_distance_matrix_creation() {
        let sparse_matrix = SparseDistanceMatrix::<f64>::new(5, 0.0);
        assert_eq!(sparse_matrix.n_samples(), 5);
        assert_eq!(sparse_matrix.nnz(), 0);
        assert_eq!(sparse_matrix.sparsity(), 1.0);
    }

    #[test]
    fn test_sparse_distance_matrix_add_distance() {
        let mut sparse_matrix = SparseDistanceMatrix::new(3, 0.0);

        sparse_matrix
            .add_distance(0, 1, 2.0)
            .expect("Operation failed");
        sparse_matrix
            .add_distance(1, 2, 3.0)
            .expect("Operation failed");

        assert_eq!(sparse_matrix.get_distance(0, 1), 2.0);
        assert_eq!(sparse_matrix.get_distance(1, 0), 2.0); // Symmetric
        assert_eq!(sparse_matrix.get_distance(1, 2), 3.0);
        assert_eq!(sparse_matrix.get_distance(0, 2), 0.0); // Default value
        assert_eq!(sparse_matrix.nnz(), 2);
    }

    #[test]
    fn test_sparse_from_dense() {
        let dense =
            Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 5.0, 1.0, 0.0, 2.0, 5.0, 2.0, 0.0])
                .expect("Operation failed");

        let sparse = SparseDistanceMatrix::from_dense(dense.view(), 1.5);

        // Should include distances > 1.5: (0,2)=5.0 and (1,2)=2.0
        assert_eq!(sparse.nnz(), 2);
        assert_eq!(sparse.get_distance(0, 2), 5.0);
        assert_eq!(sparse.get_distance(1, 2), 2.0);
        assert_eq!(sparse.get_distance(0, 1), 0.0); // Below threshold
    }

    #[test]
    fn test_neighbors_within_distance() {
        let mut sparse_matrix = SparseDistanceMatrix::new(4, f64::INFINITY);

        sparse_matrix
            .add_distance(0, 1, 1.0)
            .expect("Operation failed");
        sparse_matrix
            .add_distance(0, 2, 2.5)
            .expect("Operation failed");
        sparse_matrix
            .add_distance(0, 3, 0.5)
            .expect("Operation failed");

        let neighbors = sparse_matrix.neighbors_within_distance(0, 2.0);

        // Should find neighbors at distances 1.0 and 0.5 (both <= 2.0)
        assert_eq!(neighbors.len(), 2);

        let mut neighbor_distances: Vec<f64> = neighbors.iter().map(|(_, d)| *d).collect();
        neighbor_distances.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
        assert_eq!(neighbor_distances, vec![0.5, 1.0]);
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let mut sparse_matrix = SparseDistanceMatrix::new(5, f64::INFINITY);

        sparse_matrix
            .add_distance(0, 1, 3.0)
            .expect("Operation failed");
        sparse_matrix
            .add_distance(0, 2, 1.0)
            .expect("Operation failed");
        sparse_matrix
            .add_distance(0, 3, 2.0)
            .expect("Operation failed");
        sparse_matrix
            .add_distance(0, 4, 4.0)
            .expect("Operation failed");

        let knn = sparse_matrix.k_nearest_neighbors(0, 2);

        // Should get 2 nearest neighbors: points 2 (dist=1.0) and 3 (dist=2.0)
        assert_eq!(knn.len(), 2);
        assert_eq!(knn[0], (2, 1.0)); // Nearest
        assert_eq!(knn[1], (3, 2.0)); // Second nearest
    }

    #[test]
    fn test_sparse_knn_graph() {
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 5.0, 5.0])
            .expect("Operation failed");

        let sparse_graph =
            sparse_knn_graph(data.view(), 2, Metric::Euclidean).expect("Operation failed");

        // 4 points * 2 neighbours = 8 directed picks. After symmetrisation the unique edges
        // are (0,1), (0,2), (1,2) from the cluster near the origin and (1,3), (2,3) from the
        // outlier, whose two nearest points are 1 and 2 (both at sqrt(41)).
        assert_eq!(sparse_graph.nnz(), 5);
        assert_eq!(sparse_graph.get_distance(0, 1), 1.0);
        assert_eq!(sparse_graph.get_distance(0, 2), 1.0);
        assert!((sparse_graph.get_distance(1, 2) - 2.0_f64.sqrt()).abs() < 1e-12);
        assert!((sparse_graph.get_distance(1, 3) - 41.0_f64.sqrt()).abs() < 1e-12);
        assert!((sparse_graph.get_distance(2, 3) - 41.0_f64.sqrt()).abs() < 1e-12);
        // Point 0 is nobody's nearest-2 neighbour from the outlier's side.
        assert_eq!(sparse_graph.get_distance(0, 3), f64::INFINITY);
        assert!((sparse_graph.sparsity() - 1.0 / 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_sparse_epsilon_graph() {
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 5.0])
            .expect("Operation failed");

        let sparse_graph =
            sparse_epsilon_graph(data.view(), 1.0, Metric::Euclidean).expect("Operation failed");

        // Points (0,0), (0.5,0), and (0,0.5) are pairwise within 1.0; (5,5) is isolated.
        assert_eq!(sparse_graph.nnz(), 3);

        // Check specific connections
        assert_eq!(sparse_graph.get_distance(0, 1), 0.5);
        assert_eq!(sparse_graph.get_distance(0, 2), 0.5);
        assert!((sparse_graph.get_distance(1, 2) - 0.5_f64.sqrt()).abs() < 1e-12);
        for near in 0..3 {
            assert_eq!(sparse_graph.get_distance(near, 3), f64::INFINITY);
        }
    }

    #[test]
    fn test_to_dense() {
        let mut sparse_matrix = SparseDistanceMatrix::new(3, f64::INFINITY);
        sparse_matrix
            .add_distance(0, 1, 2.0)
            .expect("Operation failed");
        sparse_matrix
            .add_distance(1, 2, 3.0)
            .expect("Operation failed");

        let dense = sparse_matrix.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 1]], 2.0);
        assert_eq!(dense[[1, 0]], 2.0); // Symmetric
        assert_eq!(dense[[1, 2]], 3.0);
        assert_eq!(dense[[2, 1]], 3.0); // Symmetric
        assert_eq!(dense[[0, 0]], 0.0); // Diagonal
        assert_eq!(dense[[0, 2]], f64::INFINITY); // Unconnected
    }

    #[test]
    fn test_sparse_knn_graph_chebyshev() {
        // data: 4 points in 2D
        let data = Array2::from_shape_vec((4, 2), vec![0.0_f64, 0.0, 1.0, 0.0, 0.0, 1.0, 5.0, 5.0])
            .expect("Shape error");

        let graph = sparse_knn_graph(data.view(), 2, Metric::Chebyshev)
            .expect("sparse_knn_graph Chebyshev failed");
        assert!(graph.nnz() > 0, "Chebyshev KNN graph should have edges");
        // Chebyshev distance between (0,0) and (1,0) is 1.0; should be connected.
        assert!(
            graph.get_distance(0, 1) > 0.0,
            "points 0 and 1 should be neighbours"
        );
    }

    #[test]
    fn test_sparse_knn_graph_cosine() {
        // Identical direction vectors → cosine distance 0.
        let data = Array2::from_shape_vec((4, 2), vec![1.0_f64, 0.0, 2.0, 0.0, 0.0, 1.0, 0.0, 3.0])
            .expect("Shape error");

        let graph = sparse_knn_graph(data.view(), 2, Metric::Cosine)
            .expect("sparse_knn_graph Cosine failed");
        assert!(graph.nnz() > 0, "Cosine KNN graph should have edges");
        // (1,0) and (2,0) have cosine distance 0 → nearest neighbour.
        assert_eq!(
            graph.get_distance(0, 1),
            0.0,
            "parallel vectors have cosine distance 0"
        );
    }

    #[test]
    fn test_sparse_knn_graph_correlation() {
        // Perfect positive correlation → correlation distance ≈ 0.
        let data = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0_f64, 2.0, 3.0, 4.0, // row 0
                2.0, 4.0, 6.0, 8.0, // row 1 — perfect positive correlation with row 0
                4.0, 3.0, 2.0, 1.0, // row 2 — perfect negative correlation with row 0
            ],
        )
        .expect("Shape error");

        let graph = sparse_knn_graph(data.view(), 2, Metric::Correlation)
            .expect("sparse_knn_graph Correlation failed");
        assert!(graph.nnz() > 0, "Correlation KNN graph should have edges");
        // Rows 0 and 1 are perfectly correlated → distance ≈ 0.
        let d01 = graph.get_distance(0, 1);
        assert!(
            d01 < 1e-9,
            "perfectly correlated rows have correlation distance ≈ 0, got {d01}"
        );
    }

    #[test]
    fn test_sparse_epsilon_graph_chebyshev() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0_f64, 0.0, 0.8, 0.0, 5.0, 5.0])
            .expect("Shape error");

        // epsilon = 1.0 under Chebyshev: distance(0,1) = max(0.8, 0.0) = 0.8 < 1.0 → connected.
        let graph = sparse_epsilon_graph(data.view(), 1.0, Metric::Chebyshev)
            .expect("sparse_epsilon_graph Chebyshev failed");
        assert!(
            graph.get_distance(0, 1) > 0.0,
            "points 0 and 1 should be connected under Chebyshev"
        );
        // Point 2 is far away → not connected to 0.
        assert_eq!(
            graph.get_distance(0, 2),
            f64::INFINITY,
            "distant point should be disconnected"
        );
    }

    #[test]
    fn test_sparse_epsilon_graph_cosine() {
        // Parallel vectors have cosine distance 0 → within any epsilon > 0.
        let data = Array2::from_shape_vec((3, 2), vec![1.0_f64, 0.0, 2.0, 0.0, 0.0, 1.0])
            .expect("Shape error");

        let graph = sparse_epsilon_graph(data.view(), 0.5, Metric::Cosine)
            .expect("sparse_epsilon_graph Cosine failed");
        // Rows 0 and 1 are parallel (distance 0) → connected.
        assert!(
            graph.get_distance(0, 1) < 0.5,
            "parallel vectors connected under cosine epsilon graph"
        );
        // Row 2 is orthogonal to rows 0/1 (distance 1.0) → not connected for epsilon 0.5.
        assert_eq!(
            graph.get_distance(0, 2),
            f64::INFINITY,
            "orthogonal vector should be disconnected"
        );
    }

    #[test]
    fn test_sparse_epsilon_graph_correlation() {
        // Perfectly correlated rows: distance ≈ 0 → within epsilon 0.01.
        let data = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0_f64, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 4.0, 3.0, 2.0, 1.0,
            ],
        )
        .expect("Shape error");

        let graph = sparse_epsilon_graph(data.view(), 0.01, Metric::Correlation)
            .expect("sparse_epsilon_graph Correlation failed");
        // Rows 0 and 1 are perfectly correlated → distance ≈ 0 < 0.01.
        let d01 = graph.get_distance(0, 1);
        assert!(
            d01 < 0.01,
            "perfectly correlated rows connected under correlation epsilon graph, got {d01}"
        );
    }
}