1use scirs2_core::ndarray::{Array2, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12use crate::error::{ClusteringError, Result};
13use crate::hierarchy::{LinkageMethod, Metric};
14
15#[derive(Debug, Clone)]
19pub struct SparseDistanceMatrix<F: Float> {
20 rows: Vec<usize>,
22 cols: Vec<usize>,
24 data: Vec<F>,
26 n_samples: usize,
28 default_value: F,
30}
31
32impl<F: Float + FromPrimitive> SparseDistanceMatrix<F> {
33 pub fn new(n_samples: usize, default_value: F) -> Self {
35 Self {
36 rows: Vec::new(),
37 cols: Vec::new(),
38 data: Vec::new(),
39 n_samples,
40 default_value,
41 }
42 }
43
44 pub fn from_dense(dense: ArrayView2<F>, threshold: F) -> Self {
46 let n_samples = dense.shape()[0];
47 let mut rows = Vec::new();
48 let mut cols = Vec::new();
49 let mut data = Vec::new();
50
51 for i in 0..n_samples {
52 for j in (i + 1)..n_samples {
53 let distance = dense[[i, j]];
54 if distance > threshold {
55 rows.push(i);
56 cols.push(j);
57 data.push(distance);
58 }
59 }
60 }
61
62 Self {
63 rows,
64 cols,
65 data,
66 n_samples,
67 default_value: F::zero(),
68 }
69 }
70
71 pub fn add_distance(&mut self, i: usize, j: usize, distance: F) -> Result<()> {
73 if i >= self.n_samples || j >= self.n_samples {
74 return Err(ClusteringError::InvalidInput("Index out of bounds".into()));
75 }
76
77 let (row, col) = if i < j { (i, j) } else { (j, i) };
79
80 for idx in 0..self.rows.len() {
82 if self.rows[idx] == row && self.cols[idx] == col {
83 if distance < self.data[idx] {
85 self.data[idx] = distance;
86 }
87 return Ok(());
88 }
89 }
90
91 self.rows.push(row);
93 self.cols.push(col);
94 self.data.push(distance);
95
96 Ok(())
97 }
98
99 pub fn get_distance(&self, i: usize, j: usize) -> F {
101 if i == j {
102 return F::zero();
103 }
104
105 let (row, col) = if i < j { (i, j) } else { (j, i) };
106
107 for idx in 0..self.rows.len() {
109 if self.rows[idx] == row && self.cols[idx] == col {
110 return self.data[idx];
111 }
112 }
113
114 self.default_value
115 }
116
117 pub fn neighbors_within_distance(&self, point: usize, maxdistance: F) -> Vec<(usize, F)> {
119 let mut neighbors = Vec::new();
120
121 for idx in 0..self.rows.len() {
123 let (neighbor, distance) = if self.rows[idx] == point {
124 (self.cols[idx], self.data[idx])
125 } else if self.cols[idx] == point {
126 (self.rows[idx], self.data[idx])
127 } else {
128 continue;
129 };
130
131 if distance <= maxdistance {
132 neighbors.push((neighbor, distance));
133 }
134 }
135
136 neighbors
137 }
138
139 pub fn k_nearest_neighbors(&self, point: usize, k: usize) -> Vec<(usize, F)> {
141 let mut all_neighbors = Vec::new();
142
143 for idx in 0..self.rows.len() {
145 let (neighbor, distance) = if self.rows[idx] == point {
146 (self.cols[idx], self.data[idx])
147 } else if self.cols[idx] == point {
148 (self.rows[idx], self.data[idx])
149 } else {
150 continue;
151 };
152
153 all_neighbors.push((neighbor, distance));
154 }
155
156 all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
158 all_neighbors.truncate(k);
159
160 all_neighbors
161 }
162
163 pub fn to_dense(&self) -> Array2<F> {
165 let mut dense = Array2::from_elem((self.n_samples, self.n_samples), self.default_value);
166
167 for i in 0..self.n_samples {
169 dense[[i, i]] = F::zero();
170 }
171
172 for idx in 0..self.rows.len() {
174 let i = self.rows[idx];
175 let j = self.cols[idx];
176 let distance = self.data[idx];
177
178 dense[[i, j]] = distance;
179 dense[[j, i]] = distance;
180 }
181
182 dense
183 }
184
185 pub fn nnz(&self) -> usize {
187 self.data.len()
188 }
189
190 pub fn sparsity(&self) -> f64 {
192 let total_entries = self.n_samples * (self.n_samples - 1) / 2;
193 1.0 - (self.nnz() as f64 / total_entries as f64)
194 }
195
196 pub fn n_samples(&self) -> usize {
198 self.n_samples
199 }
200}
201
202pub struct SparseHierarchicalClustering<F: Float> {
207 sparse_matrix: SparseDistanceMatrix<F>,
208 linkage_method: LinkageMethod,
209}
210
211impl<F: Float + FromPrimitive + Debug + PartialOrd> SparseHierarchicalClustering<F> {
212 pub fn new(sparse_matrix: SparseDistanceMatrix<F>, linkage_method: LinkageMethod) -> Self {
214 Self {
215 sparse_matrix,
216 linkage_method,
217 }
218 }
219
220 pub fn fit(&self) -> Result<Array2<F>> {
222 let n_samples = self.sparse_matrix.n_samples();
223
224 if n_samples < 2 {
225 return Err(ClusteringError::InvalidInput(
226 "Need at least 2 samples for clustering".into(),
227 ));
228 }
229
230 let mst_edges = self.minimum_spanning_tree()?;
232
233 self.mst_to_linkage(mst_edges)
235 }
236
237 fn minimum_spanning_tree(&self) -> Result<Vec<(usize, usize, F)>> {
239 let n_samples = self.sparse_matrix.n_samples();
240 let mut mst_edges = Vec::new();
241 let mut visited = vec![false; n_samples];
242 let mut min_edge: HashMap<usize, (usize, F)> = HashMap::new();
243
244 visited[0] = true;
246
247 for neighbor_idx in 0..self.sparse_matrix.rows.len() {
249 let (i, j) = (
250 self.sparse_matrix.rows[neighbor_idx],
251 self.sparse_matrix.cols[neighbor_idx],
252 );
253 let distance = self.sparse_matrix.data[neighbor_idx];
254
255 if i == 0 && !visited[j] {
256 min_edge.insert(j, (i, distance));
257 } else if j == 0 && !visited[i] {
258 min_edge.insert(i, (j, distance));
259 }
260 }
261
262 for _ in 1..n_samples {
264 let mut min_dist = F::infinity();
266 let mut min_vertex = 0;
267 let mut min_parent = 0;
268
269 for (&vertex, &(parent, distance)) in &min_edge {
270 if !visited[vertex] && distance < min_dist {
271 min_dist = distance;
272 min_vertex = vertex;
273 min_parent = parent;
274 }
275 }
276
277 if min_dist == F::infinity() {
278 min_dist = self.sparse_matrix.default_value;
280 }
281
282 mst_edges.push((min_parent, min_vertex, min_dist));
284 visited[min_vertex] = true;
285
286 for neighbor_idx in 0..self.sparse_matrix.rows.len() {
288 let (i, j) = (
289 self.sparse_matrix.rows[neighbor_idx],
290 self.sparse_matrix.cols[neighbor_idx],
291 );
292 let distance = self.sparse_matrix.data[neighbor_idx];
293
294 let (from_vertex, to_vertex) = if i == min_vertex && !visited[j] {
295 (i, j)
296 } else if j == min_vertex && !visited[i] {
297 (j, i)
298 } else {
299 continue;
300 };
301
302 match min_edge.get(&to_vertex) {
304 Some(&(_, current_dist)) if distance < current_dist => {
305 min_edge.insert(to_vertex, (from_vertex, distance));
306 }
307 None => {
308 min_edge.insert(to_vertex, (from_vertex, distance));
309 }
310 _ => {}
311 }
312 }
313 }
314
315 Ok(mst_edges)
316 }
317
318 fn mst_to_linkage(&self, mut mst_edges: Vec<(usize, usize, F)>) -> Result<Array2<F>> {
320 let n_samples = self.sparse_matrix.n_samples();
321
322 match self.linkage_method {
324 LinkageMethod::Single => {
325 mst_edges.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap());
327 }
328 _ => {
329 }
332 }
333
334 let mut linkage_matrix = Array2::zeros((n_samples - 1, 4));
335 let mut cluster_map: HashMap<usize, usize> = HashMap::new();
336 let mut next_cluster_id = n_samples;
337
338 for i in 0..n_samples {
340 cluster_map.insert(i, i);
341 }
342
343 for (step, (i, j, distance)) in mst_edges.iter().enumerate() {
344 let cluster_i = cluster_map[i];
345 let cluster_j = cluster_map[j];
346
347 linkage_matrix[[step, 0]] = F::from(cluster_i).unwrap();
349 linkage_matrix[[step, 1]] = F::from(cluster_j).unwrap();
350 linkage_matrix[[step, 2]] = *distance;
351 linkage_matrix[[step, 3]] = F::from(2).unwrap(); cluster_map.insert(*i, next_cluster_id);
355 cluster_map.insert(*j, next_cluster_id);
356 next_cluster_id += 1;
357 }
358
359 Ok(linkage_matrix)
360 }
361}
362
363#[allow(dead_code)]
365pub fn sparse_knn_graph<F>(
366 data: ArrayView2<F>,
367 k: usize,
368 metric: Metric,
369) -> Result<SparseDistanceMatrix<F>>
370where
371 F: Float + FromPrimitive + Debug,
372{
373 let n_samples = data.shape()[0];
374 let n_features = data.shape()[1];
375
376 if k >= n_samples {
377 return Err(ClusteringError::InvalidInput(
378 "k must be less than number of samples".into(),
379 ));
380 }
381
382 let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
383
384 for i in 0..n_samples {
386 let mut distances: Vec<(usize, F)> = Vec::new();
387
388 for j in 0..n_samples {
390 if i == j {
391 continue;
392 }
393
394 let dist = match metric {
395 Metric::Euclidean => {
396 let mut sum = F::zero();
397 for k in 0..n_features {
398 let diff = data[[i, k]] - data[[j, k]];
399 sum = sum + diff * diff;
400 }
401 sum.sqrt()
402 }
403 Metric::Manhattan => {
404 let mut sum = F::zero();
405 for k in 0..n_features {
406 let diff = (data[[i, k]] - data[[j, k]]).abs();
407 sum = sum + diff;
408 }
409 sum
410 }
411 _ => {
412 return Err(ClusteringError::InvalidInput(
413 "Metric not yet supported for sparse KNN".into(),
414 ));
415 }
416 };
417
418 distances.push((j, dist));
419 }
420
421 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
423 distances.truncate(k);
424
425 for (neighbor, distance) in distances {
427 sparse_matrix.add_distance(i, neighbor, distance)?;
428 }
429 }
430
431 Ok(sparse_matrix)
432}
433
434#[allow(dead_code)]
436pub fn sparse_epsilon_graph<F>(
437 data: ArrayView2<F>,
438 epsilon: F,
439 metric: Metric,
440) -> Result<SparseDistanceMatrix<F>>
441where
442 F: Float + FromPrimitive + Debug,
443{
444 let n_samples = data.shape()[0];
445 let n_features = data.shape()[1];
446
447 let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
448
449 for i in 0..n_samples {
451 for j in (i + 1)..n_samples {
452 let dist = match metric {
453 Metric::Euclidean => {
454 let mut sum = F::zero();
455 for k in 0..n_features {
456 let diff = data[[i, k]] - data[[j, k]];
457 sum = sum + diff * diff;
458 }
459 sum.sqrt()
460 }
461 Metric::Manhattan => {
462 let mut sum = F::zero();
463 for k in 0..n_features {
464 let diff = (data[[i, k]] - data[[j, k]]).abs();
465 sum = sum + diff;
466 }
467 sum
468 }
469 _ => {
470 return Err(ClusteringError::InvalidInput(
471 "Metric not yet supported for sparse epsilon graph".into(),
472 ));
473 }
474 };
475
476 if dist <= epsilon {
477 sparse_matrix.add_distance(i, j, dist)?;
478 }
479 }
480 }
481
482 Ok(sparse_matrix)
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A freshly constructed matrix stores nothing and is fully sparse.
    #[test]
    fn test_sparse_distance_matrix_creation() {
        let empty = SparseDistanceMatrix::<f64>::new(5, 0.0);

        assert_eq!(empty.n_samples(), 5);
        assert_eq!(empty.nnz(), 0);
        assert_eq!(empty.sparsity(), 1.0);
    }

    /// Stored pairs are symmetric; unstored pairs fall back to the default.
    #[test]
    fn test_sparse_distance_matrix_add_distance() {
        let mut matrix = SparseDistanceMatrix::new(3, 0.0);

        matrix.add_distance(0, 1, 2.0).unwrap();
        matrix.add_distance(1, 2, 3.0).unwrap();

        assert_eq!(matrix.get_distance(0, 1), 2.0);
        // Same pair, reversed index order.
        assert_eq!(matrix.get_distance(1, 0), 2.0);
        assert_eq!(matrix.get_distance(1, 2), 3.0);
        // Never stored: default value.
        assert_eq!(matrix.get_distance(0, 2), 0.0);
        assert_eq!(matrix.nnz(), 2);
    }

    /// Only entries above the threshold survive the dense conversion.
    #[test]
    fn test_sparse_from_dense() {
        let values = vec![0.0, 1.0, 5.0, 1.0, 0.0, 2.0, 5.0, 2.0, 0.0];
        let dense = Array2::from_shape_vec((3, 3), values).unwrap();

        let sparse = SparseDistanceMatrix::from_dense(dense.view(), 1.5);

        assert_eq!(sparse.nnz(), 2);
        assert_eq!(sparse.get_distance(0, 2), 5.0);
        assert_eq!(sparse.get_distance(1, 2), 2.0);
        // 1.0 is not above the threshold, so the pair reads back as zero.
        assert_eq!(sparse.get_distance(0, 1), 0.0);
    }

    /// The radius query keeps exactly the neighbors within the bound.
    #[test]
    fn test_neighbors_within_distance() {
        let mut matrix = SparseDistanceMatrix::new(4, f64::INFINITY);

        matrix.add_distance(0, 1, 1.0).unwrap();
        matrix.add_distance(0, 2, 2.5).unwrap();
        matrix.add_distance(0, 3, 0.5).unwrap();

        let within = matrix.neighbors_within_distance(0, 2.0);

        assert_eq!(within.len(), 2);

        let mut found: Vec<f64> = within.iter().map(|&(_, distance)| distance).collect();
        found.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(found, vec![0.5, 1.0]);
    }

    /// The k-nearest query returns the closest neighbors, nearest first.
    #[test]
    fn test_k_nearest_neighbors() {
        let mut matrix = SparseDistanceMatrix::new(5, f64::INFINITY);

        matrix.add_distance(0, 1, 3.0).unwrap();
        matrix.add_distance(0, 2, 1.0).unwrap();
        matrix.add_distance(0, 3, 2.0).unwrap();
        matrix.add_distance(0, 4, 4.0).unwrap();

        let nearest = matrix.k_nearest_neighbors(0, 2);

        assert_eq!(nearest.len(), 2);
        assert_eq!(nearest[0], (2, 1.0));
        assert_eq!(nearest[1], (3, 2.0));
    }

    /// The kNN graph stores some, but not all, of the pairs.
    #[test]
    fn test_sparse_knn_graph() {
        let coordinates = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 5.0, 5.0];
        let data = Array2::from_shape_vec((4, 2), coordinates).unwrap();

        let graph = sparse_knn_graph(data.view(), 2, Metric::Euclidean).unwrap();

        assert!(graph.nnz() > 0);
        assert!(graph.sparsity() > 0.0);
    }

    /// The three mutually close points end up connected within epsilon.
    #[test]
    fn test_sparse_epsilon_graph() {
        let coordinates = vec![0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 5.0];
        let data = Array2::from_shape_vec((4, 2), coordinates).unwrap();

        let graph = sparse_epsilon_graph(data.view(), 1.0, Metric::Euclidean).unwrap();

        assert!(graph.nnz() >= 3);
        assert!(graph.get_distance(0, 1) <= 1.0);
        assert!(graph.get_distance(0, 2) <= 1.0);
    }

    /// Dense expansion is symmetric, zero on the diagonal, default elsewhere.
    #[test]
    fn test_to_dense() {
        let mut matrix = SparseDistanceMatrix::new(3, f64::INFINITY);
        matrix.add_distance(0, 1, 2.0).unwrap();
        matrix.add_distance(1, 2, 3.0).unwrap();

        let dense = matrix.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 1]], 2.0);
        assert_eq!(dense[[1, 0]], 2.0);
        assert_eq!(dense[[1, 2]], 3.0);
        assert_eq!(dense[[2, 1]], 3.0);
        assert_eq!(dense[[0, 0]], 0.0);
        assert_eq!(dense[[0, 2]], f64::INFINITY);
    }
}