1use scirs2_core::ndarray::{Array2, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12use crate::error::{ClusteringError, Result};
13use crate::hierarchy::{LinkageMethod, Metric};
14
15#[derive(Debug, Clone)]
19pub struct SparseDistanceMatrix<F: Float> {
20    rows: Vec<usize>,
22    cols: Vec<usize>,
24    data: Vec<F>,
26    n_samples: usize,
28    default_value: F,
30}
31
32impl<F: Float + FromPrimitive> SparseDistanceMatrix<F> {
33    pub fn new(n_samples: usize, default_value: F) -> Self {
35        Self {
36            rows: Vec::new(),
37            cols: Vec::new(),
38            data: Vec::new(),
39            n_samples,
40            default_value,
41        }
42    }
43
44    pub fn from_dense(dense: ArrayView2<F>, threshold: F) -> Self {
46        let n_samples = dense.shape()[0];
47        let mut rows = Vec::new();
48        let mut cols = Vec::new();
49        let mut data = Vec::new();
50
51        for i in 0..n_samples {
52            for j in (i + 1)..n_samples {
53                let distance = dense[[i, j]];
54                if distance > threshold {
55                    rows.push(i);
56                    cols.push(j);
57                    data.push(distance);
58                }
59            }
60        }
61
62        Self {
63            rows,
64            cols,
65            data,
66            n_samples,
67            default_value: F::zero(),
68        }
69    }
70
71    pub fn add_distance(&mut self, i: usize, j: usize, distance: F) -> Result<()> {
73        if i >= self.n_samples || j >= self.n_samples {
74            return Err(ClusteringError::InvalidInput("Index out of bounds".into()));
75        }
76
77        let (row, col) = if i < j { (i, j) } else { (j, i) };
79
80        for idx in 0..self.rows.len() {
82            if self.rows[idx] == row && self.cols[idx] == col {
83                if distance < self.data[idx] {
85                    self.data[idx] = distance;
86                }
87                return Ok(());
88            }
89        }
90
91        self.rows.push(row);
93        self.cols.push(col);
94        self.data.push(distance);
95
96        Ok(())
97    }
98
99    pub fn get_distance(&self, i: usize, j: usize) -> F {
101        if i == j {
102            return F::zero();
103        }
104
105        let (row, col) = if i < j { (i, j) } else { (j, i) };
106
107        for idx in 0..self.rows.len() {
109            if self.rows[idx] == row && self.cols[idx] == col {
110                return self.data[idx];
111            }
112        }
113
114        self.default_value
115    }
116
117    pub fn neighbors_within_distance(&self, point: usize, maxdistance: F) -> Vec<(usize, F)> {
119        let mut neighbors = Vec::new();
120
121        for idx in 0..self.rows.len() {
123            let (neighbor, distance) = if self.rows[idx] == point {
124                (self.cols[idx], self.data[idx])
125            } else if self.cols[idx] == point {
126                (self.rows[idx], self.data[idx])
127            } else {
128                continue;
129            };
130
131            if distance <= maxdistance {
132                neighbors.push((neighbor, distance));
133            }
134        }
135
136        neighbors
137    }
138
139    pub fn k_nearest_neighbors(&self, point: usize, k: usize) -> Vec<(usize, F)> {
141        let mut all_neighbors = Vec::new();
142
143        for idx in 0..self.rows.len() {
145            let (neighbor, distance) = if self.rows[idx] == point {
146                (self.cols[idx], self.data[idx])
147            } else if self.cols[idx] == point {
148                (self.rows[idx], self.data[idx])
149            } else {
150                continue;
151            };
152
153            all_neighbors.push((neighbor, distance));
154        }
155
156        all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
158        all_neighbors.truncate(k);
159
160        all_neighbors
161    }
162
163    pub fn to_dense(&self) -> Array2<F> {
165        let mut dense = Array2::from_elem((self.n_samples, self.n_samples), self.default_value);
166
167        for i in 0..self.n_samples {
169            dense[[i, i]] = F::zero();
170        }
171
172        for idx in 0..self.rows.len() {
174            let i = self.rows[idx];
175            let j = self.cols[idx];
176            let distance = self.data[idx];
177
178            dense[[i, j]] = distance;
179            dense[[j, i]] = distance;
180        }
181
182        dense
183    }
184
185    pub fn nnz(&self) -> usize {
187        self.data.len()
188    }
189
190    pub fn sparsity(&self) -> f64 {
192        let total_entries = self.n_samples * (self.n_samples - 1) / 2;
193        1.0 - (self.nnz() as f64 / total_entries as f64)
194    }
195
196    pub fn n_samples(&self) -> usize {
198        self.n_samples
199    }
200}
201
202pub struct SparseHierarchicalClustering<F: Float> {
207    sparse_matrix: SparseDistanceMatrix<F>,
208    linkage_method: LinkageMethod,
209}
210
211impl<F: Float + FromPrimitive + Debug + PartialOrd> SparseHierarchicalClustering<F> {
212    pub fn new(sparse_matrix: SparseDistanceMatrix<F>, linkage_method: LinkageMethod) -> Self {
214        Self {
215            sparse_matrix,
216            linkage_method,
217        }
218    }
219
220    pub fn fit(&self) -> Result<Array2<F>> {
222        let n_samples = self.sparse_matrix.n_samples();
223
224        if n_samples < 2 {
225            return Err(ClusteringError::InvalidInput(
226                "Need at least 2 samples for clustering".into(),
227            ));
228        }
229
230        let mst_edges = self.minimum_spanning_tree()?;
232
233        self.mst_to_linkage(mst_edges)
235    }
236
237    fn minimum_spanning_tree(&self) -> Result<Vec<(usize, usize, F)>> {
239        let n_samples = self.sparse_matrix.n_samples();
240        let mut mst_edges = Vec::new();
241        let mut visited = vec![false; n_samples];
242        let mut min_edge: HashMap<usize, (usize, F)> = HashMap::new();
243
244        visited[0] = true;
246
247        for neighbor_idx in 0..self.sparse_matrix.rows.len() {
249            let (i, j) = (
250                self.sparse_matrix.rows[neighbor_idx],
251                self.sparse_matrix.cols[neighbor_idx],
252            );
253            let distance = self.sparse_matrix.data[neighbor_idx];
254
255            if i == 0 && !visited[j] {
256                min_edge.insert(j, (i, distance));
257            } else if j == 0 && !visited[i] {
258                min_edge.insert(i, (j, distance));
259            }
260        }
261
262        for _ in 1..n_samples {
264            let mut min_dist = F::infinity();
266            let mut min_vertex = 0;
267            let mut min_parent = 0;
268
269            for (&vertex, &(parent, distance)) in &min_edge {
270                if !visited[vertex] && distance < min_dist {
271                    min_dist = distance;
272                    min_vertex = vertex;
273                    min_parent = parent;
274                }
275            }
276
277            if min_dist == F::infinity() {
278                min_dist = self.sparse_matrix.default_value;
280            }
281
282            mst_edges.push((min_parent, min_vertex, min_dist));
284            visited[min_vertex] = true;
285
286            for neighbor_idx in 0..self.sparse_matrix.rows.len() {
288                let (i, j) = (
289                    self.sparse_matrix.rows[neighbor_idx],
290                    self.sparse_matrix.cols[neighbor_idx],
291                );
292                let distance = self.sparse_matrix.data[neighbor_idx];
293
294                let (from_vertex, to_vertex) = if i == min_vertex && !visited[j] {
295                    (i, j)
296                } else if j == min_vertex && !visited[i] {
297                    (j, i)
298                } else {
299                    continue;
300                };
301
302                match min_edge.get(&to_vertex) {
304                    Some(&(_, current_dist)) if distance < current_dist => {
305                        min_edge.insert(to_vertex, (from_vertex, distance));
306                    }
307                    None => {
308                        min_edge.insert(to_vertex, (from_vertex, distance));
309                    }
310                    _ => {}
311                }
312            }
313        }
314
315        Ok(mst_edges)
316    }
317
318    fn mst_to_linkage(&self, mut mst_edges: Vec<(usize, usize, F)>) -> Result<Array2<F>> {
320        let n_samples = self.sparse_matrix.n_samples();
321
322        match self.linkage_method {
324            LinkageMethod::Single => {
325                mst_edges.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap());
327            }
328            _ => {
329                }
332        }
333
334        let mut linkage_matrix = Array2::zeros((n_samples - 1, 4));
335        let mut cluster_map: HashMap<usize, usize> = HashMap::new();
336        let mut next_cluster_id = n_samples;
337
338        for i in 0..n_samples {
340            cluster_map.insert(i, i);
341        }
342
343        for (step, (i, j, distance)) in mst_edges.iter().enumerate() {
344            let cluster_i = cluster_map[i];
345            let cluster_j = cluster_map[j];
346
347            linkage_matrix[[step, 0]] = F::from(cluster_i).unwrap();
349            linkage_matrix[[step, 1]] = F::from(cluster_j).unwrap();
350            linkage_matrix[[step, 2]] = *distance;
351            linkage_matrix[[step, 3]] = F::from(2).unwrap(); cluster_map.insert(*i, next_cluster_id);
355            cluster_map.insert(*j, next_cluster_id);
356            next_cluster_id += 1;
357        }
358
359        Ok(linkage_matrix)
360    }
361}
362
363#[allow(dead_code)]
365pub fn sparse_knn_graph<F>(
366    data: ArrayView2<F>,
367    k: usize,
368    metric: Metric,
369) -> Result<SparseDistanceMatrix<F>>
370where
371    F: Float + FromPrimitive + Debug,
372{
373    let n_samples = data.shape()[0];
374    let n_features = data.shape()[1];
375
376    if k >= n_samples {
377        return Err(ClusteringError::InvalidInput(
378            "k must be less than number of samples".into(),
379        ));
380    }
381
382    let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
383
384    for i in 0..n_samples {
386        let mut distances: Vec<(usize, F)> = Vec::new();
387
388        for j in 0..n_samples {
390            if i == j {
391                continue;
392            }
393
394            let dist = match metric {
395                Metric::Euclidean => {
396                    let mut sum = F::zero();
397                    for k in 0..n_features {
398                        let diff = data[[i, k]] - data[[j, k]];
399                        sum = sum + diff * diff;
400                    }
401                    sum.sqrt()
402                }
403                Metric::Manhattan => {
404                    let mut sum = F::zero();
405                    for k in 0..n_features {
406                        let diff = (data[[i, k]] - data[[j, k]]).abs();
407                        sum = sum + diff;
408                    }
409                    sum
410                }
411                _ => {
412                    return Err(ClusteringError::InvalidInput(
413                        "Metric not yet supported for sparse KNN".into(),
414                    ));
415                }
416            };
417
418            distances.push((j, dist));
419        }
420
421        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
423        distances.truncate(k);
424
425        for (neighbor, distance) in distances {
427            sparse_matrix.add_distance(i, neighbor, distance)?;
428        }
429    }
430
431    Ok(sparse_matrix)
432}
433
434#[allow(dead_code)]
436pub fn sparse_epsilon_graph<F>(
437    data: ArrayView2<F>,
438    epsilon: F,
439    metric: Metric,
440) -> Result<SparseDistanceMatrix<F>>
441where
442    F: Float + FromPrimitive + Debug,
443{
444    let n_samples = data.shape()[0];
445    let n_features = data.shape()[1];
446
447    let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
448
449    for i in 0..n_samples {
451        for j in (i + 1)..n_samples {
452            let dist = match metric {
453                Metric::Euclidean => {
454                    let mut sum = F::zero();
455                    for k in 0..n_features {
456                        let diff = data[[i, k]] - data[[j, k]];
457                        sum = sum + diff * diff;
458                    }
459                    sum.sqrt()
460                }
461                Metric::Manhattan => {
462                    let mut sum = F::zero();
463                    for k in 0..n_features {
464                        let diff = (data[[i, k]] - data[[j, k]]).abs();
465                        sum = sum + diff;
466                    }
467                    sum
468                }
469                _ => {
470                    return Err(ClusteringError::InvalidInput(
471                        "Metric not yet supported for sparse epsilon graph".into(),
472                    ));
473                }
474            };
475
476            if dist <= epsilon {
477                sparse_matrix.add_distance(i, j, dist)?;
478            }
479        }
480    }
481
482    Ok(sparse_matrix)
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A freshly created matrix stores nothing and is fully sparse.
    #[test]
    fn test_sparse_distance_matrix_creation() {
        let matrix = SparseDistanceMatrix::<f64>::new(5, 0.0);

        assert_eq!(matrix.n_samples(), 5);
        assert_eq!(matrix.nnz(), 0);
        assert_eq!(matrix.sparsity(), 1.0);
    }

    /// Stored pairs are symmetric; missing pairs fall back to the default.
    #[test]
    fn test_sparse_distance_matrix_add_distance() {
        let mut matrix = SparseDistanceMatrix::new(3, 0.0);
        matrix.add_distance(0, 1, 2.0).unwrap();
        matrix.add_distance(1, 2, 3.0).unwrap();

        assert_eq!(matrix.get_distance(0, 1), 2.0);
        // Lookup is independent of argument order.
        assert_eq!(matrix.get_distance(1, 0), 2.0);
        assert_eq!(matrix.get_distance(1, 2), 3.0);
        // Never stored, so the default value is returned.
        assert_eq!(matrix.get_distance(0, 2), 0.0);
        assert_eq!(matrix.nnz(), 2);
    }

    /// Only entries above the threshold survive the dense-to-sparse conversion.
    #[test]
    fn test_sparse_from_dense() {
        let values = vec![0.0, 1.0, 5.0, 1.0, 0.0, 2.0, 5.0, 2.0, 0.0];
        let dense = Array2::from_shape_vec((3, 3), values).unwrap();

        let sparse = SparseDistanceMatrix::from_dense(dense.view(), 1.5);

        assert_eq!(sparse.nnz(), 2);
        assert_eq!(sparse.get_distance(0, 2), 5.0);
        assert_eq!(sparse.get_distance(1, 2), 2.0);
        // 1.0 is below the threshold and therefore reads back as the default.
        assert_eq!(sparse.get_distance(0, 1), 0.0);
    }

    /// The radius query returns exactly the stored pairs within range.
    #[test]
    fn test_neighbors_within_distance() {
        let mut matrix = SparseDistanceMatrix::new(4, f64::INFINITY);
        for &(other, distance) in &[(1, 1.0), (2, 2.5), (3, 0.5)] {
            matrix.add_distance(0, other, distance).unwrap();
        }

        let within = matrix.neighbors_within_distance(0, 2.0);
        assert_eq!(within.len(), 2);

        let mut found: Vec<f64> = within.iter().map(|(_, d)| *d).collect();
        found.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(found, vec![0.5, 1.0]);
    }

    /// The k-nearest query returns the closest stored pairs, nearest first.
    #[test]
    fn test_k_nearest_neighbors() {
        let mut matrix = SparseDistanceMatrix::new(5, f64::INFINITY);
        for &(other, distance) in &[(1, 3.0), (2, 1.0), (3, 2.0), (4, 4.0)] {
            matrix.add_distance(0, other, distance).unwrap();
        }

        let nearest = matrix.k_nearest_neighbors(0, 2);

        assert_eq!(nearest.len(), 2);
        assert_eq!(nearest[0], (2, 1.0));
        assert_eq!(nearest[1], (3, 2.0));
    }

    /// A k-NN graph stores some, but not all, of the possible pairs.
    #[test]
    fn test_sparse_knn_graph() {
        let points = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 5.0, 5.0];
        let data = Array2::from_shape_vec((4, 2), points).unwrap();

        let graph = sparse_knn_graph(data.view(), 2, Metric::Euclidean).unwrap();

        assert!(graph.nnz() > 0);
        assert!(graph.sparsity() > 0.0);
    }

    /// All pairs within the radius are connected in the epsilon graph.
    #[test]
    fn test_sparse_epsilon_graph() {
        let points = vec![0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 5.0];
        let data = Array2::from_shape_vec((4, 2), points).unwrap();

        let graph = sparse_epsilon_graph(data.view(), 1.0, Metric::Euclidean).unwrap();

        assert!(graph.nnz() >= 3);
        assert!(graph.get_distance(0, 1) <= 1.0);
        assert!(graph.get_distance(0, 2) <= 1.0);
    }

    /// The dense expansion is symmetric, zero on the diagonal, default elsewhere.
    #[test]
    fn test_to_dense() {
        let mut matrix = SparseDistanceMatrix::new(3, f64::INFINITY);
        matrix.add_distance(0, 1, 2.0).unwrap();
        matrix.add_distance(1, 2, 3.0).unwrap();

        let dense = matrix.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 1]], 2.0);
        assert_eq!(dense[[1, 0]], 2.0);
        assert_eq!(dense[[1, 2]], 3.0);
        assert_eq!(dense[[2, 1]], 3.0);
        assert_eq!(dense[[0, 0]], 0.0);
        assert_eq!(dense[[0, 2]], f64::INFINITY);
    }
}