1use scirs2_core::ndarray::{Array2, ArrayView2};
38
39use crate::error::{ClusteringError, Result};
40
/// Core distance of a single input point.
///
/// The core distance is the distance from the point to its `min_pts`-th nearest
/// point, counting the point itself, considering only points within `max_eps`.
#[derive(Debug, Clone, PartialEq)]
pub struct CoreDistance {
    /// Row index of the point in the input data.
    pub point_idx: usize,
    /// Core distance, or `None` when fewer than `min_pts` points (the point
    /// itself included) lie within `max_eps`.
    pub core_dist: Option<f64>,
}
54
/// One entry of an OPTICS cluster ordering.
#[derive(Debug, Clone, PartialEq)]
pub struct ReachabilityPoint {
    /// Row index of the point in the input data.
    pub point_idx: usize,
    /// Reachability distance at which the point was reached, or `None` when the
    /// point started a new expansion (undefined reachability).
    pub reachability_dist: Option<f64>,
}
64
/// Entry in the OPTICS seed priority queue.
///
/// `BinaryHeap` is a max-heap, so [`Ord`] is reversed on `reachability` to make
/// the entry with the *smallest* reachability pop first. Ties are broken on
/// `point_idx` (the larger index pops first), which keeps the expansion order
/// deterministic.
#[derive(Debug, Clone)]
struct SeedEntry {
    /// Index of the point waiting to be expanded.
    point_idx: usize,
    /// Reachability distance of the point at the time this entry was pushed.
    reachability: f64,
}

impl PartialEq for SeedEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` can never disagree
        // (with raw f64 `==`, NaN != NaN while `cmp` reported `Equal`).
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for SeedEntry {}

impl PartialOrd for SeedEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SeedEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` is a genuine total order on f64 (NaN included), as `Ord`
        // and `BinaryHeap` require. Operands are swapped to get a min-heap.
        other
            .reachability
            .total_cmp(&self.reachability)
            .then_with(|| self.point_idx.cmp(&other.point_idx))
    }
}
102
/// Squared Euclidean distance between two coordinate slices.
///
/// Coordinates are paired with `zip`, so if the slices differ in length the
/// surplus trailing coordinates of the longer one are ignored.
#[inline]
fn sq_euclid(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f64>()
}
112
113#[inline]
115fn euclid(a: &[f64], b: &[f64]) -> f64 {
116 sq_euclid(a, b).sqrt()
117}
118
119fn build_distance_matrix(data: ArrayView2<f64>) -> Array2<f64> {
125 let n = data.shape()[0];
126 let mut dm = Array2::<f64>::zeros((n, n));
127 for i in 0..n {
128 let ri = data.row(i).to_vec();
129 for j in (i + 1)..n {
130 let rj = data.row(j).to_vec();
131 let d = euclid(&ri, &rj);
132 dm[[i, j]] = d;
133 dm[[j, i]] = d;
134 }
135 }
136 dm
137}
138
139fn neighbours_within(point_idx: usize, dm: &Array2<f64>, max_eps: f64) -> Vec<usize> {
145 let n = dm.shape()[0];
146 (0..n)
147 .filter(|&j| j != point_idx && dm[[point_idx, j]] <= max_eps)
148 .collect()
149}
150
151fn core_distance(
156 point_idx: usize,
157 neighbours: &[usize],
158 dm: &Array2<f64>,
159 min_pts: usize,
160) -> Option<f64> {
161 if neighbours.len() + 1 < min_pts {
163 return None;
164 }
165 let mut dists: Vec<f64> = neighbours.iter().map(|&j| dm[[point_idx, j]]).collect();
166 dists.sort_by(|a: &f64, b: &f64| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
167 dists.get(min_pts.saturating_sub(2)).cloned()
169}
170
171fn update_seeds(
177 core_pt: usize,
178 core_dist: f64,
179 neighbours: &[usize],
180 dm: &Array2<f64>,
181 processed: &[bool],
182 current_reach: &mut Vec<Option<f64>>,
183 seeds: &mut std::collections::BinaryHeap<SeedEntry>,
184) {
185 for &nb in neighbours {
186 if processed[nb] {
187 continue;
188 }
189 let new_reach = core_dist.max(dm[[core_pt, nb]]);
190 let update = match current_reach[nb] {
191 None => true,
192 Some(old) => new_reach < old,
193 };
194 if update {
195 current_reach[nb] = Some(new_reach);
196 seeds.push(SeedEntry {
197 point_idx: nb,
198 reachability: new_reach,
199 });
200 }
201 }
202}
203
204pub fn optics(
225 data: ArrayView2<f64>,
226 min_pts: usize,
227 max_eps: f64,
228) -> Result<Vec<ReachabilityPoint>> {
229 let n = data.shape()[0];
230
231 if n == 0 {
232 return Err(ClusteringError::InvalidInput("Empty input data".into()));
233 }
234 if min_pts < 2 {
235 return Err(ClusteringError::InvalidInput("min_pts must be >= 2".into()));
236 }
237 if max_eps <= 0.0 {
238 return Err(ClusteringError::InvalidInput("max_eps must be > 0".into()));
239 }
240
241 let dm = build_distance_matrix(data);
242
243 let mut processed = vec![false; n];
245 let mut current_reach: Vec<Option<f64>> = vec![None; n];
247 let mut core_dists: Vec<Option<f64>> = vec![None; n];
249
250 let mut ordering: Vec<ReachabilityPoint> = Vec::with_capacity(n);
251
252 for start in 0..n {
253 if processed[start] {
254 continue;
255 }
256
257 processed[start] = true;
259 let nbrs = neighbours_within(start, &dm, max_eps);
260 let cd = core_distance(start, &nbrs, &dm, min_pts);
261 core_dists[start] = cd;
262 ordering.push(ReachabilityPoint {
263 point_idx: start,
264 reachability_dist: None,
265 });
266
267 if let Some(cd_val) = cd {
268 let mut seeds = std::collections::BinaryHeap::new();
270 update_seeds(
271 start,
272 cd_val,
273 &nbrs,
274 &dm,
275 &processed,
276 &mut current_reach,
277 &mut seeds,
278 );
279
280 while let Some(entry) = seeds.pop() {
281 let pt = entry.point_idx;
282 if processed[pt] {
283 continue;
284 }
285
286 processed[pt] = true;
287 let pt_nbrs = neighbours_within(pt, &dm, max_eps);
288 let pt_cd = core_distance(pt, &pt_nbrs, &dm, min_pts);
289 core_dists[pt] = pt_cd;
290
291 ordering.push(ReachabilityPoint {
292 point_idx: pt,
293 reachability_dist: current_reach[pt],
294 });
295
296 if let Some(pt_cd_val) = pt_cd {
297 update_seeds(
298 pt,
299 pt_cd_val,
300 &pt_nbrs,
301 &dm,
302 &processed,
303 &mut current_reach,
304 &mut seeds,
305 );
306 }
307 }
308 }
309 }
310
311 Ok(ordering)
312}
313
314pub fn extract_dbscan(reachability: &[ReachabilityPoint], eps: f64) -> Vec<i32> {
337 let n = reachability.len();
338 let mut labels = vec![-1i32; n];
339
340 let mut pos_of: Vec<usize> = vec![0; n];
343 for (pos, rp) in reachability.iter().enumerate() {
344 if rp.point_idx < n {
345 pos_of[rp.point_idx] = pos;
346 }
347 }
348
349 let mut cluster_id: i32 = -1;
350
351 for pos in 0..n {
352 let rp = &reachability[pos];
353 let reach_exceeds = match rp.reachability_dist {
354 Some(r) => r > eps,
355 None => true, };
357
358 if reach_exceeds {
359 cluster_id += 1;
370 labels[rp.point_idx] = cluster_id;
371 } else {
372 if pos > 0 {
374 let prev_idx = reachability[pos - 1].point_idx;
375 let prev_label = if prev_idx < n { labels[prev_idx] } else { -1 };
376 if prev_label >= 0 {
377 labels[rp.point_idx] = prev_label;
378 } else {
379 cluster_id += 1;
381 labels[rp.point_idx] = cluster_id;
382 }
383 }
384 }
385 }
386
387 labels
388}
389
390pub fn extract_xi_clusters(reachability: &[ReachabilityPoint], xi: f64) -> Result<Vec<i32>> {
413 if xi <= 0.0 || xi >= 1.0 {
414 return Err(ClusteringError::InvalidInput("xi must be in (0, 1)".into()));
415 }
416
417 let n = reachability.len();
418 if n == 0 {
419 return Ok(Vec::new());
420 }
421
422 let reach: Vec<f64> = reachability
424 .iter()
425 .map(|rp| rp.reachability_dist.unwrap_or(f64::INFINITY))
426 .collect();
427
428 let max_finite = reach
430 .iter()
431 .filter(|r| r.is_finite())
432 .cloned()
433 .fold(f64::NEG_INFINITY, f64::max);
434
435 let fill = if max_finite.is_finite() {
436 max_finite * 1.1 + 1.0
437 } else {
438 1.0
439 };
440
441 let rf: Vec<f64> = reach
442 .iter()
443 .map(|&r| if r.is_finite() { r } else { fill })
444 .collect();
445
446 let is_steep_down = |i: usize| -> bool {
454 if i + 1 >= n {
455 return false;
456 }
457 rf[i] > 0.0 && rf[i].is_finite() && rf[i + 1].is_finite() && rf[i] * (1.0 - xi) >= rf[i + 1]
458 };
459
460 let is_steep_up = |i: usize| -> bool {
461 if i + 1 >= n {
462 return false;
463 }
464 rf[i + 1] > 0.0
465 && rf[i].is_finite()
466 && rf[i + 1].is_finite()
467 && rf[i] * (1.0 - xi) <= rf[i + 1]
468 };
469
470 let mut sd_areas: Vec<(usize, usize, f64)> = Vec::new(); let mut i = 0;
473 while i < n.saturating_sub(1) {
474 if is_steep_down(i) {
475 let s = i;
476 let mut e = i;
477 while e + 1 < n && is_steep_down(e) {
478 e += 1;
479 }
480 sd_areas.push((s, e, rf[s]));
481 i = e + 1;
482 } else {
483 i += 1;
484 }
485 }
486
487 let mut su_areas: Vec<(usize, usize, f64)> = Vec::new(); let mut i = 0;
490 while i < n.saturating_sub(1) {
491 if is_steep_up(i) {
492 let s = i;
493 let mut e = i;
494 while e + 1 < n && is_steep_up(e) {
495 e += 1;
496 }
497 let end_reach_idx = if e + 1 < n { e + 1 } else { e };
498 su_areas.push((s, e, rf[end_reach_idx]));
499 i = e + 1;
500 } else {
501 i += 1;
502 }
503 }
504
505 let mut cluster_ranges: Vec<(usize, usize)> = Vec::new();
509
510 for &(sd_s, sd_e, sd_r) in &sd_areas {
511 for &(su_s, su_e, su_r) in &su_areas {
512 if su_s <= sd_e {
514 continue;
515 }
516 let interior_lo = sd_e + 1;
518 let interior_hi = su_s;
519 if interior_lo >= interior_hi {
520 continue;
521 }
522
523 let r_high = sd_r.max(su_r);
525 let r_low = sd_r.min(su_r);
526 if r_high <= 0.0 || r_low / r_high < (1.0 - xi).powi(2) {
527 continue;
528 }
529
530 let int_min = rf[interior_lo..interior_hi]
532 .iter()
533 .cloned()
534 .filter(|v| v.is_finite())
535 .fold(f64::INFINITY, f64::min);
536
537 if int_min < r_high {
538 let cluster_end = (su_e + 1).min(n - 1);
539 cluster_ranges.push((sd_s, cluster_end));
540 break; }
542 }
543 }
544
545 cluster_ranges.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| (b.1 - b.0).cmp(&(a.1 - a.0))));
547
548 let mut keep = vec![true; cluster_ranges.len()];
549 for outer in 0..cluster_ranges.len() {
550 if !keep[outer] {
551 continue;
552 }
553 for inner in (outer + 1)..cluster_ranges.len() {
554 if !keep[inner] {
555 continue;
556 }
557 let (os, oe) = cluster_ranges[outer];
558 let (is, ie) = cluster_ranges[inner];
559 if is >= os && ie <= oe {
560 keep[inner] = false;
561 }
562 }
563 }
564
565 let valid_clusters: Vec<(usize, usize)> = cluster_ranges
566 .iter()
567 .zip(keep.iter())
568 .filter_map(|(&r, &k)| if k { Some(r) } else { None })
569 .collect();
570
571 let mut labels = vec![-1i32; n];
573 for (cid, &(range_s, range_e)) in valid_clusters.iter().enumerate() {
574 for pos in range_s..=range_e.min(n - 1) {
575 let orig = reachability[pos].point_idx;
576 if orig < n && labels[orig] < 0 {
577 labels[orig] = cid as i32;
578 }
579 }
580 }
581
582 Ok(labels)
583}
584
585pub fn reachability_plot(optics_result: &[ReachabilityPoint]) -> (Vec<f64>, Vec<f64>) {
599 let x: Vec<f64> = (0..optics_result.len()).map(|i| i as f64).collect();
600 let y: Vec<f64> = optics_result
601 .iter()
602 .map(|rp| rp.reachability_dist.unwrap_or(f64::INFINITY))
603 .collect();
604 (x, y)
605}
606
607pub fn compute_core_distances(
621 data: ArrayView2<f64>,
622 min_pts: usize,
623 max_eps: f64,
624) -> Result<Vec<CoreDistance>> {
625 let n = data.shape()[0];
626 if n == 0 {
627 return Ok(Vec::new());
628 }
629 let dm = build_distance_matrix(data);
630 let result = (0..n)
631 .map(|i| {
632 let nbrs = neighbours_within(i, &dm, max_eps);
633 let cd = core_distance(i, &nbrs, &dm, min_pts);
634 CoreDistance {
635 point_idx: i,
636 core_dist: cd,
637 }
638 })
639 .collect();
640 Ok(result)
641}
642
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Two well-separated 2-D blobs of seven points each: indices 0..7 lie
    /// around (1, 2) and indices 7..14 around (8, 8).
    fn two_cluster_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (14, 2),
            vec![
                1.0, 2.0, 1.1, 1.9, 0.9, 2.1, 1.2, 1.8, 0.8, 2.0, 1.0, 2.2, 1.15, 1.85,
                8.0, 8.0, 8.1, 7.9, 7.9, 8.1, 8.2, 7.8, 7.8, 8.0, 8.0, 8.2, 8.15, 7.85,
            ],
        )
        .expect("shape ok")
    }

    // ---- optics ----

    #[test]
    fn test_optics_produces_full_ordering() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        assert_eq!(ord.len(), 14, "every point must appear in ordering");
        let mut seen = vec![false; 14];
        for rp in &ord {
            assert!(!seen[rp.point_idx], "duplicate index {}", rp.point_idx);
            seen[rp.point_idx] = true;
        }
        assert!(seen.iter().all(|&s| s), "missing indices in ordering");
    }

    #[test]
    fn test_optics_first_point_has_no_reachability() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        assert!(
            ord[0].reachability_dist.is_none(),
            "first ordering entry should have reachability = None"
        );
    }

    #[test]
    fn test_optics_within_cluster_reachabilities_small() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        // Collect the reachability of every entry whose predecessor in the
        // ordering belongs to the same blob.
        let mut prev_cluster: Option<usize> = None;
        let mut within_reaches: Vec<f64> = Vec::new();

        for rp in &ord {
            let cluster = if rp.point_idx < 7 { 0 } else { 1 };
            if prev_cluster == Some(cluster) {
                if let Some(r) = rp.reachability_dist {
                    within_reaches.push(r);
                }
            }
            prev_cluster = Some(cluster);
        }

        // With 14 points in two blobs, same-blob neighbours in the ordering must
        // exist. An empty list would make the check below vacuous.
        assert!(!within_reaches.is_empty(), "no within-cluster transitions found");
        let avg: f64 = within_reaches.iter().sum::<f64>() / within_reaches.len() as f64;
        assert!(
            avg < 2.0,
            "expected small within-cluster reach, got {}",
            avg
        );
    }

    #[test]
    fn test_optics_exact_reachabilities_on_a_line() {
        // 1-D points at 0, 1 and 3 with min_pts = 2: core distances are 1, 1, 2.
        // Point 2 is first seeded at max(1, 3) = 3, then improved via point 1 to
        // max(1, 2) = 2.
        let data = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 3.0]).expect("shape");
        let ord = optics(data.view(), 2, f64::INFINITY).expect("optics");
        let got: Vec<(usize, Option<f64>)> = ord
            .iter()
            .map(|rp| (rp.point_idx, rp.reachability_dist))
            .collect();
        assert_eq!(got, vec![(0, None), (1, Some(1.0)), (2, Some(2.0))]);
    }

    #[test]
    fn test_optics_max_eps_restricts_reachability() {
        let data = two_cluster_data();
        // The closest pair in the data is about 0.07 apart, so 0.01 isolates all.
        let ord = optics(data.view(), 2, 0.01).expect("optics");
        assert_eq!(ord.len(), 14);
        let all_none = ord.iter().all(|rp| rp.reachability_dist.is_none());
        assert!(all_none, "with tiny max_eps every point is isolated");
    }

    #[test]
    fn test_optics_single_point() {
        let data = Array2::from_shape_vec((1, 2), vec![3.0, 4.0]).expect("shape");
        let ord = optics(data.view(), 2, f64::INFINITY).expect("optics");
        assert_eq!(ord.len(), 1);
        assert_eq!(ord[0].point_idx, 0);
        assert!(ord[0].reachability_dist.is_none());
    }

    #[test]
    fn test_optics_error_empty() {
        let data = Array2::<f64>::zeros((0, 2));
        assert!(optics(data.view(), 2, f64::INFINITY).is_err());
    }

    #[test]
    fn test_optics_error_min_pts_too_small() {
        let data = two_cluster_data();
        assert!(optics(data.view(), 1, f64::INFINITY).is_err());
    }

    #[test]
    fn test_optics_error_non_positive_max_eps() {
        let data = two_cluster_data();
        assert!(optics(data.view(), 3, 0.0).is_err());
        assert!(optics(data.view(), 3, -1.0).is_err());
    }

    // ---- seed queue ----

    #[test]
    fn test_seed_entry_pops_smallest_reachability_first() {
        let mut heap = std::collections::BinaryHeap::new();
        for &(point_idx, reachability) in &[(0usize, 3.0f64), (1, 1.0), (2, 2.0), (5, 1.0)] {
            heap.push(SeedEntry {
                point_idx,
                reachability,
            });
        }
        let popped: Vec<usize> = std::iter::from_fn(|| heap.pop())
            .map(|e| e.point_idx)
            .collect();
        // Equal reachabilities are broken in favour of the larger point index.
        assert_eq!(popped, vec![5, 1, 2, 0]);
    }

    // ---- extract_dbscan ----

    #[test]
    fn test_extract_dbscan_two_clusters() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        let labels = extract_dbscan(&ord, 0.5);
        assert_eq!(labels.len(), 14);
        let a_labels: Vec<i32> = (0..7).map(|i| labels[i]).collect();
        let b_labels: Vec<i32> = (7..14).map(|i| labels[i]).collect();
        assert!(a_labels.iter().all(|&l| l >= 0));
        assert!(b_labels.iter().all(|&l| l >= 0));
        let a_mode = *a_labels
            .iter()
            .max_by_key(|&&l| a_labels.iter().filter(|&&x| x == l).count())
            .expect("a has labels");
        let b_mode = *b_labels
            .iter()
            .max_by_key(|&&l| b_labels.iter().filter(|&&x| x == l).count())
            .expect("b has labels");
        assert_ne!(a_mode, b_mode, "clusters should receive distinct labels");
    }

    #[test]
    fn test_extract_dbscan_all_noise_small_eps() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        let labels = extract_dbscan(&ord, 1e-10);
        assert_eq!(labels.len(), 14);
        // At this radius no two points may end up in the same cluster. Each point
        // is either noise (-1) or the only member of its label.
        let mut seen = std::collections::HashSet::new();
        for &l in labels.iter().filter(|&&l| l >= 0) {
            assert!(seen.insert(l), "label {} shared by several points at tiny eps", l);
        }
    }

    #[test]
    fn test_extract_dbscan_hand_built_ordering() {
        let ord = vec![
            ReachabilityPoint { point_idx: 2, reachability_dist: None },
            ReachabilityPoint { point_idx: 0, reachability_dist: Some(0.1) },
            ReachabilityPoint { point_idx: 1, reachability_dist: Some(5.0) },
            ReachabilityPoint { point_idx: 3, reachability_dist: Some(0.2) },
        ];
        assert_eq!(extract_dbscan(&ord, 0.5), vec![0, 1, 0, 1]);
    }

    #[test]
    fn test_extract_dbscan_empty_ordering() {
        let labels = extract_dbscan(&[], 0.5);
        assert!(labels.is_empty());
    }

    // ---- extract_xi_clusters ----

    #[test]
    fn test_extract_xi_returns_correct_length() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        let labels = extract_xi_clusters(&ord, 0.05).expect("xi");
        assert_eq!(labels.len(), 14);
    }

    #[test]
    fn test_extract_xi_labels_valid_range() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        let labels = extract_xi_clusters(&ord, 0.1).expect("xi");
        assert!(labels.iter().all(|&l| l >= -1), "labels must be >= -1");
    }

    #[test]
    fn test_extract_xi_error_invalid_xi() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        assert!(extract_xi_clusters(&ord, 0.0).is_err());
        assert!(extract_xi_clusters(&ord, 1.0).is_err());
        assert!(extract_xi_clusters(&ord, -0.1).is_err());
        assert!(extract_xi_clusters(&ord, 1.5).is_err());
    }

    #[test]
    fn test_extract_xi_empty_ordering() {
        let labels = extract_xi_clusters(&[], 0.1).expect("xi empty");
        assert!(labels.is_empty());
    }

    // ---- reachability_plot ----

    #[test]
    fn test_reachability_plot_lengths() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        let (xs, ys) = reachability_plot(&ord);
        assert_eq!(xs.len(), 14);
        assert_eq!(ys.len(), 14);
    }

    #[test]
    fn test_reachability_plot_x_sequential() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        let (xs, _ys) = reachability_plot(&ord);
        for (i, &x) in xs.iter().enumerate() {
            assert!(
                (x - i as f64).abs() < 1e-12,
                "x[{}] should be {}, got {}",
                i,
                i,
                x
            );
        }
    }

    #[test]
    fn test_reachability_plot_none_becomes_infinity() {
        let data = two_cluster_data();
        let ord = optics(data.view(), 3, f64::INFINITY).expect("optics");
        let (_, ys) = reachability_plot(&ord);
        assert!(ys[0].is_infinite(), "component root should be INFINITY");
    }

    #[test]
    fn test_reachability_plot_empty() {
        let (xs, ys) = reachability_plot(&[]);
        assert!(xs.is_empty());
        assert!(ys.is_empty());
    }

    // ---- compute_core_distances ----

    #[test]
    fn test_core_distances_length() {
        let data = two_cluster_data();
        let cds = compute_core_distances(data.view(), 3, f64::INFINITY).expect("cds");
        assert_eq!(cds.len(), 14);
    }

    #[test]
    fn test_core_distances_dense_cluster_are_core() {
        let data = two_cluster_data();
        let cds = compute_core_distances(data.view(), 3, f64::INFINITY).expect("cds");
        let n_core = cds.iter().filter(|cd| cd.core_dist.is_some()).count();
        assert!(
            n_core >= 10,
            "most points should be core points, got {}",
            n_core
        );
    }

    #[test]
    fn test_core_distances_two_points() {
        // (0, 0) and (3, 4) are exactly 5 apart.
        let data = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 3.0, 4.0]).expect("shape");
        let cds = compute_core_distances(data.view(), 2, f64::INFINITY).expect("cds");
        let got: Vec<Option<f64>> = cds.iter().map(|cd| cd.core_dist).collect();
        assert_eq!(got, vec![Some(5.0), Some(5.0)]);
        // Three points (self included) are required but only two exist.
        let cds3 = compute_core_distances(data.view(), 3, f64::INFINITY).expect("cds");
        assert!(cds3.iter().all(|cd| cd.core_dist.is_none()));
    }

    #[test]
    fn test_core_distances_tiny_eps_no_cores() {
        let data = two_cluster_data();
        let cds = compute_core_distances(data.view(), 3, 1e-15).expect("cds");
        let n_core = cds.iter().filter(|cd| cd.core_dist.is_some()).count();
        assert_eq!(n_core, 0, "no cores expected with tiny eps");
    }

    #[test]
    fn test_core_distances_empty_data() {
        let data = Array2::<f64>::zeros((0, 2));
        let cds = compute_core_distances(data.view(), 3, f64::INFINITY).expect("cds empty");
        assert!(cds.is_empty());
    }
}