1use crate::error::{Result, TransformError};
21use crate::tda_vr::{PersistenceDiagram, VietorisRips};
22use scirs2_core::ndarray::Array2;
23
/// A Vietoris–Rips complex of a point cloud at a single scale `epsilon`.
///
/// The complex is stored explicitly as a list of simplices, each simplex being the
/// list of indices (into `points`) of its vertices.
#[derive(Debug, Clone, PartialEq)]
pub struct VietorisRipsComplex {
    /// The input point cloud, one coordinate vector per point.
    pub points: Vec<Vec<f64>>,
    /// Scale parameter: two points are joined by an edge when their distance is at most this.
    pub epsilon: f64,
    /// All simplices of the complex, each given as a list of point indices.
    pub simplices: Vec<Vec<usize>>,
}
56
57impl VietorisRipsComplex {
58 pub fn new(points: &[Vec<f64>], epsilon: f64) -> Result<Self> {
66 if points.is_empty() {
67 return Ok(Self {
68 points: Vec::new(),
69 epsilon,
70 simplices: Vec::new(),
71 });
72 }
73 if epsilon < 0.0 {
74 return Err(TransformError::InvalidInput(
75 "epsilon must be non-negative".to_string(),
76 ));
77 }
78 let n = points.len();
79 let dim = points[0].len();
80
81 let dist = pairwise_distances(points, dim);
83
84 let mut simplices: Vec<Vec<usize>> = Vec::new();
85
86 for i in 0..n {
88 simplices.push(vec![i]);
89 }
90
91 for i in 0..n {
93 for j in (i + 1)..n {
94 if dist[i][j] <= epsilon {
95 simplices.push(vec![i, j]);
96 }
97 }
98 }
99
100 for i in 0..n {
102 for j in (i + 1)..n {
103 if dist[i][j] > epsilon {
104 continue;
105 }
106 for k in (j + 1)..n {
107 if dist[i][k] <= epsilon && dist[j][k] <= epsilon {
108 simplices.push(vec![i, j, k]);
109 }
110 }
111 }
112 }
113
114 simplices.sort_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)));
116
117 Ok(Self {
118 points: points.to_vec(),
119 epsilon,
120 simplices,
121 })
122 }
123
124 pub fn n_simplices(&self, dim: usize) -> usize {
128 self.simplices.iter().filter(|s| s.len() == dim + 1).count()
129 }
130
131 pub fn euler_characteristic(&self) -> i64 {
134 let mut chi = 0_i64;
135 for simplex in &self.simplices {
136 let k = simplex.len() as i64 - 1;
137 if k % 2 == 0 {
138 chi += 1;
139 } else {
140 chi -= 1;
141 }
142 }
143 chi
144 }
145
146 pub fn are_connected(&self, u: usize, v: usize) -> bool {
148 let edge = if u < v { vec![u, v] } else { vec![v, u] };
149 self.simplices.contains(&edge)
150 }
151
152 pub fn edges(&self) -> Vec<(usize, usize)> {
154 self.simplices
155 .iter()
156 .filter(|s| s.len() == 2)
157 .map(|s| (s[0], s[1]))
158 .collect()
159 }
160}
161
162pub fn compute_persistence(
194 distance_matrix: &[Vec<f64>],
195 max_dim: usize,
196 max_epsilon: f64,
197) -> Result<Vec<PersistenceDiagram>> {
198 let n = distance_matrix.len();
199 if n == 0 {
200 return Ok((0..=max_dim).map(|d| PersistenceDiagram::new(d)).collect());
202 }
203
204 for row in distance_matrix {
206 if row.len() != n {
207 return Err(TransformError::InvalidInput(
208 "distance_matrix must be square".to_string(),
209 ));
210 }
211 }
212 if max_epsilon < 0.0 {
213 return Err(TransformError::InvalidInput(
214 "max_epsilon must be non-negative".to_string(),
215 ));
216 }
217
218 let mut filt_values: Vec<f64> = Vec::new();
234 for i in 0..n {
235 for j in (i + 1)..n {
236 let d = distance_matrix[i][j];
237 if d <= max_epsilon && d >= 0.0 {
238 filt_values.push(d);
239 }
240 }
241 }
242 filt_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
243 filt_values.dedup_by(|a, b| (*a - *b).abs() < 1e-15);
244
245 #[derive(Clone)]
251 struct FiltSimplex {
252 vertices: Vec<usize>,
253 filtration: f64,
254 }
255
256 let mut simplices: Vec<FiltSimplex> = Vec::new();
257
258 for i in 0..n {
260 simplices.push(FiltSimplex {
261 vertices: vec![i],
262 filtration: 0.0,
263 });
264 }
265
266 for i in 0..n {
268 for j in (i + 1)..n {
269 let d = distance_matrix[i][j];
270 if d <= max_epsilon {
271 simplices.push(FiltSimplex {
272 vertices: vec![i, j],
273 filtration: d,
274 });
275 }
276 }
277 }
278
279 if max_dim >= 1 {
281 for i in 0..n {
282 for j in (i + 1)..n {
283 let d_ij = distance_matrix[i][j];
284 if d_ij > max_epsilon {
285 continue;
286 }
287 for k in (j + 1)..n {
288 let d_ik = distance_matrix[i][k];
289 let d_jk = distance_matrix[j][k];
290 if d_ik > max_epsilon || d_jk > max_epsilon {
291 continue;
292 }
293 let max_d = d_ij.max(d_ik).max(d_jk);
294 simplices.push(FiltSimplex {
295 vertices: vec![i, j, k],
296 filtration: max_d,
297 });
298 }
299 }
300 }
301 }
302
303 if max_dim >= 2 {
305 for i in 0..n {
306 for j in (i + 1)..n {
307 let d_ij = distance_matrix[i][j];
308 if d_ij > max_epsilon {
309 continue;
310 }
311 for k in (j + 1)..n {
312 let d_ik = distance_matrix[i][k];
313 let d_jk = distance_matrix[j][k];
314 if d_ik > max_epsilon || d_jk > max_epsilon {
315 continue;
316 }
317 for l in (k + 1)..n {
318 let d_il = distance_matrix[i][l];
319 let d_jl = distance_matrix[j][l];
320 let d_kl = distance_matrix[k][l];
321 if d_il > max_epsilon || d_jl > max_epsilon || d_kl > max_epsilon {
322 continue;
323 }
324 let max_d = d_ij.max(d_ik).max(d_jk).max(d_il).max(d_jl).max(d_kl);
325 simplices.push(FiltSimplex {
326 vertices: vec![i, j, k, l],
327 filtration: max_d,
328 });
329 }
330 }
331 }
332 }
333 }
334
335 simplices.sort_by(|a, b| {
337 a.filtration
338 .partial_cmp(&b.filtration)
339 .unwrap_or(std::cmp::Ordering::Equal)
340 .then_with(|| a.vertices.len().cmp(&b.vertices.len()))
341 });
342
343 let total = simplices.len();
345 let simplex_idx: std::collections::HashMap<Vec<usize>, usize> = simplices
346 .iter()
347 .enumerate()
348 .map(|(i, s)| (s.vertices.clone(), i))
349 .collect();
350
351 let mut boundary: Vec<Vec<usize>> = vec![Vec::new(); total];
354 for (j, simp) in simplices.iter().enumerate() {
355 let d = simp.vertices.len();
356 if d <= 1 {
357 continue; }
359 for omit in 0..d {
361 let face: Vec<usize> = simp
362 .vertices
363 .iter()
364 .enumerate()
365 .filter(|&(i, _)| i != omit)
366 .map(|(_, &v)| v)
367 .collect();
368 if let Some(&row_idx) = simplex_idx.get(&face) {
369 boundary[j].push(row_idx);
370 }
371 }
372 boundary[j].sort_unstable();
373 }
374
375 let mut low: Vec<Option<usize>> = vec![None; total];
378 let mut pivot_col: Vec<Option<usize>> = vec![None; total];
380
381 for j in 0..total {
382 loop {
383 let lo = boundary[j].last().copied();
384 match lo {
385 None => break,
386 Some(r) => {
387 if let Some(k) = pivot_col[r] {
388 let bk = boundary[k].clone();
390 sym_diff_inplace(&mut boundary[j], &bk);
391 } else {
392 low[j] = Some(r);
393 pivot_col[r] = Some(j);
394 break;
395 }
396 }
397 }
398 }
399 }
400
401 let mut diagrams: Vec<PersistenceDiagram> =
403 (0..=max_dim).map(|d| PersistenceDiagram::new(d)).collect();
404
405 let mut paired: Vec<bool> = vec![false; total];
406
407 for j in 0..total {
408 if let Some(r) = low[j] {
409 let birth = simplices[r].filtration;
410 let death = simplices[j].filtration;
411 let feature_dim = simplices[r].vertices.len() - 1;
413 if feature_dim <= max_dim && (death - birth).abs() > 1e-15 {
414 diagrams[feature_dim].add_point(birth, death, feature_dim);
415 }
416 paired[r] = true;
417 paired[j] = true;
418 }
419 }
420
421 for i in 0..total {
423 if !paired[i] {
424 let dim = simplices[i].vertices.len() - 1;
425 if dim <= max_dim {
426 diagrams[dim].add_point(simplices[i].filtration, f64::INFINITY, dim);
427 }
428 }
429 }
430
431 Ok(diagrams)
432}
433
434pub fn persistence_landscape_fn(
467 dgm: &PersistenceDiagram,
468 n_layers: usize,
469 x: &[f64],
470) -> Vec<Vec<f64>> {
471 if n_layers == 0 || x.is_empty() {
472 return vec![vec![0.0; x.len()]; n_layers];
473 }
474
475 let finite_pts: Vec<(f64, f64)> = dgm
477 .points
478 .iter()
479 .filter(|p| p.death.is_finite())
480 .map(|p| (p.birth, p.death))
481 .collect();
482
483 let nx = x.len();
484 let mut landscape = vec![vec![0.0_f64; nx]; n_layers];
486
487 for (ix, &t) in x.iter().enumerate() {
488 let mut tents: Vec<f64> = finite_pts
490 .iter()
491 .map(|&(b, d)| {
492 let v = (t - b).min(d - t);
493 if v < 0.0 {
494 0.0
495 } else {
496 v
497 }
498 })
499 .collect();
500 tents.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
502 for k in 0..n_layers {
503 landscape[k][ix] = tents.get(k).copied().unwrap_or(0.0);
504 }
505 }
506
507 landscape
508}
509
510pub fn persistence_image_fn(
547 dgm: &PersistenceDiagram,
548 bandwidth: f64,
549 grid: (usize, usize),
550 max_birth: f64,
551 max_persistence: f64,
552) -> Vec<Vec<f64>> {
553 let (n_rows, n_cols) = grid;
554 if n_rows == 0 || n_cols == 0 {
555 return vec![];
556 }
557
558 let bw = bandwidth.max(1e-10);
559 let two_bw_sq = 2.0 * bw * bw;
560 let norm_factor = 1.0 / (std::f64::consts::TAU * bw * bw);
561
562 let row_centers: Vec<f64> = if n_rows == 1 {
564 vec![max_persistence * 0.5]
565 } else {
566 (0..n_rows)
567 .map(|i| max_persistence * i as f64 / (n_rows - 1) as f64)
568 .collect()
569 };
570 let col_centers: Vec<f64> = if n_cols == 1 {
571 vec![max_birth * 0.5]
572 } else {
573 (0..n_cols)
574 .map(|j| max_birth * j as f64 / (n_cols - 1) as f64)
575 .collect()
576 };
577
578 let pts: Vec<(f64, f64, f64)> = dgm
580 .points
581 .iter()
582 .filter(|p| p.death.is_finite() && p.death > p.birth)
583 .map(|p| (p.birth, p.death - p.birth, p.death - p.birth)) .collect();
585
586 let mut image = vec![vec![0.0_f64; n_cols]; n_rows];
587
588 for (r, &p_center) in row_centers.iter().enumerate() {
589 for (c, &b_center) in col_centers.iter().enumerate() {
590 let mut val = 0.0_f64;
591 for &(b, pers, weight) in &pts {
592 let db = b_center - b;
593 let dp = p_center - pers;
594 let exponent = -(db * db + dp * dp) / two_bw_sq;
595 val += weight * norm_factor * exponent.exp();
596 }
597 image[r][c] = val;
598 }
599 }
600
601 image
602}
603
/// Returns the symmetric `n × n` matrix of Euclidean distances between `points`.
///
/// For each pair, at most the first `dim` coordinates are used. Fewer are used when
/// either vector of the pair is shorter than `dim`. The diagonal is zero.
fn pairwise_distances(points: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    let n = points.len();
    let mut matrix = vec![vec![0.0_f64; n]; n];
    for (i, p) in points.iter().enumerate() {
        for (j, q) in points.iter().enumerate().skip(i + 1) {
            let squared = p.iter().zip(q).take(dim).fold(0.0_f64, |acc, (a, b)| {
                let delta = a - b;
                acc + delta * delta
            });
            let d = squared.sqrt();
            matrix[i][j] = d;
            matrix[j][i] = d;
        }
    }
    matrix
}
624
/// Replaces `a` with the symmetric difference of the index lists `a` and `b`.
///
/// This is column addition over GF(2): indices present in both lists cancel.
///
/// Both inputs are expected to be sorted ascending without duplicates. The merge relies
/// on that ordering, and the result then keeps the same property.
fn sym_diff_inplace(a: &mut Vec<usize>, b: &[usize]) {
    let mut merged = Vec::with_capacity(a.len() + b.len());
    let (mut x, mut y) = (0_usize, 0_usize);
    loop {
        match (a.get(x), b.get(y)) {
            (Some(&u), Some(&v)) if u == v => {
                x += 1;
                y += 1;
            }
            (Some(&u), Some(&v)) if u < v => {
                merged.push(u);
                x += 1;
            }
            (Some(_), Some(&v)) => {
                merged.push(v);
                y += 1;
            }
            (Some(&u), None) => {
                merged.push(u);
                x += 1;
            }
            (None, Some(&v)) => {
                merged.push(v);
                y += 1;
            }
            (None, None) => break,
        }
    }
    *a = merged;
}
657
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tda_vr::PersistenceDiagram;

    /// Distance matrix of the unit square's corners (diagonals rounded to 1.414).
    fn square_dist() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 1.0, 1.414, 1.0],
            vec![1.0, 0.0, 1.0, 1.414],
            vec![1.414, 1.0, 0.0, 1.0],
            vec![1.0, 1.414, 1.0, 0.0],
        ]
    }

    /// Corners of the unit square, listed counter-clockwise from the origin.
    fn square_points() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![1.0, 1.0],
            vec![0.0, 1.0],
        ]
    }

    #[test]
    fn test_vrc_vertices() {
        let pts = square_points();
        let vrc = VietorisRipsComplex::new(&pts, 1.5).expect("new");
        assert_eq!(vrc.n_simplices(0), 4, "Should have 4 vertices");
    }

    #[test]
    fn test_vrc_edges_unit_square() {
        let pts = square_points();
        let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
        // At eps = 1 only the four sides are edges; the diagonals (sqrt 2) are not.
        assert_eq!(vrc.n_simplices(1), 4, "Unit square at eps=1 has 4 edges");
        assert_eq!(vrc.n_simplices(2), 0, "No triangles at eps=1");
    }

    #[test]
    fn test_vrc_edges_listing() {
        let pts = square_points();
        let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
        assert_eq!(vrc.edges(), vec![(0, 1), (0, 3), (1, 2), (2, 3)]);
    }

    #[test]
    fn test_vrc_complete_graph() {
        let pts = square_points();
        let vrc = VietorisRipsComplex::new(&pts, 2.0).expect("new");
        // eps = 2 exceeds every pairwise distance, so every pair is an edge.
        assert_eq!(
            vrc.n_simplices(1),
            6,
            "Complete graph on 4 vertices has 6 edges"
        );
        assert_eq!(vrc.n_simplices(2), 4, "4 triangles in K4");
    }

    #[test]
    fn test_vrc_euler_characteristic() {
        let pts = square_points();
        let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
        // A 4-cycle: 4 vertices - 4 edges = 0.
        assert_eq!(vrc.euler_characteristic(), 0);
    }

    #[test]
    fn test_vrc_empty_input() {
        let vrc = VietorisRipsComplex::new(&[], 1.0).expect("empty ok");
        assert_eq!(vrc.n_simplices(0), 0);
        assert_eq!(vrc.euler_characteristic(), 0);
    }

    #[test]
    fn test_vrc_negative_epsilon_error() {
        let pts = square_points();
        assert!(VietorisRipsComplex::new(&pts, -0.1).is_err());
    }

    #[test]
    fn test_vrc_are_connected() {
        let pts = square_points();
        let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
        assert!(vrc.are_connected(0, 1));
        // Argument order must not matter.
        assert!(vrc.are_connected(1, 2));
        assert!(vrc.are_connected(2, 1));
        // Diagonal pair is farther than eps.
        assert!(!vrc.are_connected(0, 2));
    }

    #[test]
    fn test_compute_persistence_h0_square() {
        let dist = square_dist();
        let diagrams = compute_persistence(&dist, 1, 2.0).expect("persistence");
        assert_eq!(diagrams.len(), 2);
        let h0 = &diagrams[0];
        assert!(!h0.is_empty(), "H0 should not be empty");

        // The square is connected at scale 2, so exactly one component never dies.
        let essential = h0.points.iter().filter(|p| p.death.is_infinite()).count();
        assert_eq!(essential, 1, "one essential H0 class expected");
        // All other components merge when the unit-length sides appear.
        for p in h0.points.iter().filter(|p| p.death.is_finite()) {
            assert!(
                (p.death - 1.0).abs() < 1e-12,
                "H0 deaths should be 1.0, got {}",
                p.death
            );
        }
    }

    #[test]
    fn test_compute_persistence_h1_square_loop() {
        let dist = square_dist();
        let diagrams = compute_persistence(&dist, 1, 2.0).expect("persistence");
        let h1 = &diagrams[1];
        // The 4-cycle appears with the sides (1.0) and is filled once the
        // diagonals and triangles appear (1.414).
        assert!(
            h1.points
                .iter()
                .any(|p| (p.birth - 1.0).abs() < 1e-12 && (p.death - 1.414).abs() < 1e-12),
            "expected the square loop (1.0, 1.414) in H1"
        );
        assert!(h1.points.iter().all(|p| p.death.is_finite()));
    }

    #[test]
    fn test_compute_persistence_max_dim_zero() {
        let dist = square_dist();
        let diagrams = compute_persistence(&dist, 0, 2.0).expect("persistence");
        assert_eq!(diagrams.len(), 1);
        let essential = diagrams[0]
            .points
            .iter()
            .filter(|p| p.death.is_infinite())
            .count();
        assert_eq!(essential, 1);
    }

    #[test]
    fn test_compute_persistence_empty() {
        let diagrams = compute_persistence(&[], 1, 1.0).expect("empty");
        assert_eq!(diagrams.len(), 2);
        assert!(diagrams[0].is_empty());
        assert!(diagrams[1].is_empty());
    }

    #[test]
    fn test_compute_persistence_non_square_error() {
        let dist = vec![vec![0.0, 1.0], vec![1.0, 0.0, 2.0]];
        assert!(compute_persistence(&dist, 1, 2.0).is_err());
    }

    #[test]
    fn test_compute_persistence_returns_finite_pairs() {
        let dist = square_dist();
        let diagrams = compute_persistence(&dist, 1, 2.0).expect("persistence");
        for dgm in &diagrams {
            for pt in &dgm.points {
                assert!(pt.birth.is_finite());
                assert!(pt.birth >= 0.0);
                if pt.death.is_finite() {
                    assert!(pt.death >= pt.birth);
                }
            }
        }
    }

    #[test]
    fn test_landscape_fn_shape() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 2.0, 0);
        dgm.add_point(0.5, 1.5, 0);
        let x: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
        let l = persistence_landscape_fn(&dgm, 3, &x);
        assert_eq!(l.len(), 3);
        assert_eq!(l[0].len(), 20);
    }

    #[test]
    fn test_landscape_fn_non_negative() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        let x: Vec<f64> = (0..10).map(|i| i as f64 * 0.15).collect();
        let l = persistence_landscape_fn(&dgm, 2, &x);
        for row in &l {
            for &v in row {
                assert!(v >= 0.0, "landscape must be non-negative, got {v}");
            }
        }
    }

    #[test]
    fn test_landscape_fn_tent_shape() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        // The tent for (0, 1) peaks at its midpoint t = 0.5 with height 0.5.
        let x = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let l = persistence_landscape_fn(&dgm, 1, &x);
        assert!((l[0][2] - 0.5).abs() < 1e-10, "peak should be 0.5");
        // The tent vanishes at birth and at death.
        assert!(l[0][0] < 1e-10);
        assert!(l[0][4] < 1e-10);
    }

    #[test]
    fn test_landscape_fn_empty_diagram() {
        let dgm = PersistenceDiagram::new(0);
        let x = vec![0.0, 1.0, 2.0];
        let l = persistence_landscape_fn(&dgm, 2, &x);
        assert_eq!(l.len(), 2);
        for row in &l {
            assert!(row.iter().all(|&v| v == 0.0));
        }
    }

    #[test]
    fn test_persistence_image_fn_shape() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        dgm.add_point(0.2, 0.8, 0);
        let img = persistence_image_fn(&dgm, 0.1, (5, 5), 1.0, 1.0);
        assert_eq!(img.len(), 5);
        assert_eq!(img[0].len(), 5);
    }

    #[test]
    fn test_persistence_image_fn_non_negative() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        let img = persistence_image_fn(&dgm, 0.1, (4, 4), 1.0, 1.0);
        for row in &img {
            for &v in row {
                assert!(v >= 0.0, "image pixel must be non-negative, got {v}");
            }
        }
    }

    #[test]
    fn test_persistence_image_fn_has_signal() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        let img = persistence_image_fn(&dgm, 0.15, (6, 6), 1.5, 1.5);
        let has_positive = img.iter().flat_map(|row| row.iter()).any(|&v| v > 0.0);
        assert!(
            has_positive,
            "image should have nonzero pixels for a nonempty diagram"
        );
    }

    #[test]
    fn test_persistence_image_fn_empty_diagram() {
        let dgm = PersistenceDiagram::new(0);
        let img = persistence_image_fn(&dgm, 0.1, (4, 4), 1.0, 1.0);
        assert_eq!(img.len(), 4);
        for row in &img {
            assert!(row.iter().all(|&v| v == 0.0));
        }
    }

    #[test]
    fn test_sym_diff_inplace() {
        let mut a = vec![1_usize, 3, 5];
        let b = vec![2, 3, 4];
        sym_diff_inplace(&mut a, &b);
        // The shared index 3 cancels out.
        assert_eq!(a, vec![1, 2, 4, 5]);
    }

    #[test]
    fn test_sym_diff_inplace_cancels_completely() {
        let mut a = vec![2_usize, 7];
        sym_diff_inplace(&mut a, &[2, 7]);
        assert!(a.is_empty());
    }
}