1use crate::error::{Result, TransformError};
27use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
28use scirs2_core::numeric::{Float, NumCast};
29use std::collections::HashMap;
30use std::fmt;
31
/// One feature of a persistence diagram: a homology class of dimension `dimension` that is
/// alive on the filtration interval starting at `birth` and ending at `death`.
#[derive(Debug, Clone, PartialEq)]
pub struct PersistencePoint {
    /// Filtration value at which the class appears.
    pub birth: f64,
    /// Filtration value at which the class disappears; infinite for classes that never die.
    pub death: f64,
    /// Homological dimension of the class (the `k` printed in `H<k>`).
    pub dimension: usize,
}

impl PersistencePoint {
    /// Builds a point from its birth value, death value and homological dimension.
    pub fn new(birth: f64, death: f64, dimension: usize) -> Self {
        PersistencePoint {
            birth,
            death,
            dimension,
        }
    }

    /// Lifetime `death - birth`, or `f64::INFINITY` for an essential (never-dying) class.
    pub fn persistence(&self) -> f64 {
        if self.is_essential() {
            return f64::INFINITY;
        }
        self.death - self.birth
    }

    /// `true` when the death value is infinite, i.e. the class never dies.
    pub fn is_essential(&self) -> bool {
        self.death.is_infinite()
    }
}

impl fmt::Display for PersistencePoint {
    /// Renders as `H<dim>(<birth>, <death>)` with four decimals, using `∞` for essential classes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (dim, birth, death) = (self.dimension, self.birth, self.death);
        if !death.is_infinite() {
            return write!(f, "H{}({:.4}, {:.4})", dim, birth, death);
        }
        write!(f, "H{}({:.4}, ∞)", dim, birth)
    }
}
85
86#[derive(Debug, Clone)]
91pub struct PersistenceDiagram {
92 pub points: Vec<PersistencePoint>,
94 pub max_dimension: usize,
96}
97
98impl PersistenceDiagram {
99 pub fn new(max_dimension: usize) -> Self {
101 Self {
102 points: Vec::new(),
103 max_dimension,
104 }
105 }
106
107 pub fn add_point(&mut self, birth: f64, death: f64, dimension: usize) {
109 self.points
110 .push(PersistencePoint::new(birth, death, dimension));
111 }
112
113 pub fn points_in_dimension(&self, dim: usize) -> Vec<&PersistencePoint> {
115 self.points.iter().filter(|p| p.dimension == dim).collect()
116 }
117
118 pub fn finite_points(&self) -> Vec<&PersistencePoint> {
120 self.points.iter().filter(|p| !p.is_essential()).collect()
121 }
122
123 pub fn essential_points(&self) -> Vec<&PersistencePoint> {
125 self.points.iter().filter(|p| p.is_essential()).collect()
126 }
127
128 pub fn total_persistence(&self, p: f64) -> f64 {
130 self.points
131 .iter()
132 .filter(|pt| !pt.is_essential())
133 .map(|pt| pt.persistence().powf(p))
134 .sum::<f64>()
135 .powf(1.0 / p)
136 }
137
138 pub fn filter_by_persistence(&self, min_persistence: f64) -> PersistenceDiagram {
140 let filtered_points: Vec<PersistencePoint> = self
141 .points
142 .iter()
143 .filter(|p| p.persistence() >= min_persistence)
144 .cloned()
145 .collect();
146
147 PersistenceDiagram {
148 points: filtered_points,
149 max_dimension: self.max_dimension,
150 }
151 }
152
153 pub fn len(&self) -> usize {
155 self.points.len()
156 }
157
158 pub fn is_empty(&self) -> bool {
160 self.points.is_empty()
161 }
162
163 pub fn to_barcode(&self) -> Barcode {
165 Barcode::from_diagram(self)
166 }
167
168 pub fn betti_numbers_at(&self, filtration_value: f64) -> Vec<usize> {
170 let mut betti = vec![0usize; self.max_dimension + 1];
171 for p in &self.points {
172 if p.birth <= filtration_value && (p.death > filtration_value || p.death.is_infinite())
173 {
174 if p.dimension <= self.max_dimension {
175 betti[p.dimension] += 1;
176 }
177 }
178 }
179 betti
180 }
181}
182
/// One bar of a barcode: a homology class of dimension `dimension` that starts at `birth` and
/// ends at `death`.
#[derive(Debug, Clone, PartialEq)]
pub struct BarcodeInterval {
    /// Start of the bar.
    pub birth: f64,
    /// End of the bar; infinite for classes that never die.
    pub death: f64,
    /// Homological dimension of the class the bar represents.
    pub dimension: usize,
}

impl BarcodeInterval {
    /// Length of the bar, or `f64::INFINITY` when the bar never ends.
    pub fn length(&self) -> f64 {
        match self.death {
            end if end.is_infinite() => f64::INFINITY,
            end => end - self.birth,
        }
    }
}
206
207#[derive(Debug, Clone)]
212pub struct Barcode {
213 pub intervals: Vec<BarcodeInterval>,
215 pub max_dimension: usize,
217}
218
219impl Barcode {
220 pub fn from_diagram(diagram: &PersistenceDiagram) -> Self {
222 let intervals: Vec<BarcodeInterval> = diagram
223 .points
224 .iter()
225 .map(|p| BarcodeInterval {
226 birth: p.birth,
227 death: p.death,
228 dimension: p.dimension,
229 })
230 .collect();
231
232 Barcode {
233 intervals,
234 max_dimension: diagram.max_dimension,
235 }
236 }
237
238 pub fn intervals_in_dimension(&self, dim: usize) -> Vec<&BarcodeInterval> {
240 let mut intervals: Vec<&BarcodeInterval> = self
241 .intervals
242 .iter()
243 .filter(|i| i.dimension == dim)
244 .collect();
245 intervals.sort_by(|a, b| {
246 a.birth
247 .partial_cmp(&b.birth)
248 .unwrap_or(std::cmp::Ordering::Equal)
249 });
250 intervals
251 }
252
253 pub fn len(&self) -> usize {
255 self.intervals.len()
256 }
257
258 pub fn is_empty(&self) -> bool {
260 self.intervals.is_empty()
261 }
262}
263
/// A simplex (vertex indices) tagged with the filtration value at which it enters.
#[derive(Debug, Clone, PartialEq)]
struct FilteredSimplex {
    /// Vertex indices in ascending order (sorted by `new`).
    vertices: Vec<usize>,
    /// Filtration value at which the simplex appears.
    filtration_value: f64,
}

impl FilteredSimplex {
    /// Creates a simplex, sorting its vertices so each simplex has one canonical vertex order.
    fn new(mut vertices: Vec<usize>, filtration_value: f64) -> Self {
        vertices.sort_unstable();
        FilteredSimplex {
            vertices,
            filtration_value,
        }
    }

    /// Number of vertices minus one; an empty vertex list is reported as dimension 0.
    fn dimension(&self) -> usize {
        match self.vertices.len() {
            0 => 0,
            count => count - 1,
        }
    }
}
289
/// Sparse boundary matrix over Z/2 used by the standard persistence reduction.
///
/// Column `j` lists, in ascending order, the row indices of its non-zero entries. `pivots[j]`
/// caches the lowest non-zero row of column `j` (its largest row index), or `-1` when the column
/// is zero.
struct BoundaryMatrix {
    /// Sorted row indices of the non-zero entries of each column.
    columns: Vec<Vec<usize>>,
    /// Cached pivot per column; `-1` marks an empty column.
    pivots: Vec<i64>,
}

impl BoundaryMatrix {
    /// Creates a matrix with `n_cols` empty columns.
    fn new(n_cols: usize) -> Self {
        Self {
            columns: vec![Vec::new(); n_cols],
            pivots: vec![-1i64; n_cols],
        }
    }

    /// Pivot of a sorted column: its last (largest) row index, or `-1` when the column is empty.
    fn pivot_of(column: &[usize]) -> i64 {
        column.last().map_or(-1, |&row| row as i64)
    }

    /// Replaces column `col` with `boundary` (sorted and deduplicated) and refreshes its pivot.
    fn set_column(&mut self, col: usize, boundary: Vec<usize>) {
        let mut b = boundary;
        b.sort_unstable();
        b.dedup();
        self.pivots[col] = Self::pivot_of(&b);
        self.columns[col] = b;
    }

    /// Pivot (lowest non-zero row) of column `j`, or `-1` if the column is zero.
    fn pivot(&self, j: usize) -> i64 {
        self.pivots[j]
    }

    /// Adds column `source` to column `target` over Z/2 (symmetric difference of the sorted row
    /// lists), leaving `source` unchanged.
    fn add_column(&mut self, target: usize, source: usize) {
        if target == source {
            // x + x = 0 over Z/2.
            self.columns[target].clear();
            self.pivots[target] = -1;
            return;
        }

        // The target is about to be overwritten, so move it out instead of cloning it; the
        // source is only read.
        let tgt = std::mem::take(&mut self.columns[target]);
        let src = &self.columns[source];

        let mut result = Vec::with_capacity(src.len() + tgt.len());
        let (mut i, mut j) = (0, 0);
        while i < src.len() && j < tgt.len() {
            match src[i].cmp(&tgt[j]) {
                std::cmp::Ordering::Less => {
                    result.push(src[i]);
                    i += 1;
                }
                std::cmp::Ordering::Greater => {
                    result.push(tgt[j]);
                    j += 1;
                }
                std::cmp::Ordering::Equal => {
                    // 1 + 1 = 0: the shared entry cancels.
                    i += 1;
                    j += 1;
                }
            }
        }
        result.extend_from_slice(&src[i..]);
        result.extend_from_slice(&tgt[j..]);

        self.pivots[target] = Self::pivot_of(&result);
        self.columns[target] = result;
    }

    /// Standard left-to-right column reduction.
    ///
    /// Each column is repeatedly combined with the earlier column that owns its current pivot,
    /// until its pivot is unique or the column becomes zero. Afterwards every non-zero column
    /// has a distinct pivot.
    fn reduce(&mut self) {
        let n = self.columns.len();
        // Maps a pivot row to the (already reduced) column that owns it.
        let mut pivot_to_col: HashMap<usize, usize> = HashMap::new();

        for j in 0..n {
            while self.pivot(j) >= 0 {
                let piv_row = self.pivot(j) as usize;
                let owner = pivot_to_col.get(&piv_row).copied();
                match owner {
                    Some(k) => self.add_column(j, k),
                    None => {
                        pivot_to_col.insert(piv_row, j);
                        break;
                    }
                }
            }
        }
    }
}
381
382pub struct VietorisRips;
404
405impl VietorisRips {
406 pub fn compute<S>(
416 points: &ArrayBase<S, Ix2>,
417 max_dim: usize,
418 max_radius: f64,
419 ) -> Result<PersistenceDiagram>
420 where
421 S: Data,
422 S::Elem: Float + NumCast,
423 {
424 let n = points.nrows();
425 if n < 2 {
426 return Err(TransformError::InvalidInput(
427 "VietorisRips requires at least 2 points".to_string(),
428 ));
429 }
430
431 let dist = Self::compute_distance_matrix(points)?;
433
434 let mut filtered_simplices = Self::build_filtration(&dist, max_dim, max_radius);
436
437 filtered_simplices.sort_by(|a, b| {
439 a.filtration_value
440 .partial_cmp(&b.filtration_value)
441 .unwrap_or(std::cmp::Ordering::Equal)
442 .then(a.dimension().cmp(&b.dimension()))
443 });
444
445 let n_simplices = filtered_simplices.len();
447
448 let mut simplex_to_idx: HashMap<Vec<usize>, usize> = HashMap::new();
450 for (i, s) in filtered_simplices.iter().enumerate() {
451 simplex_to_idx.insert(s.vertices.clone(), i);
452 }
453
454 let mut bm = BoundaryMatrix::new(n_simplices);
456 for (j, simplex) in filtered_simplices.iter().enumerate() {
457 if simplex.dimension() == 0 {
458 continue;
460 }
461 let mut boundary_indices = Vec::with_capacity(simplex.vertices.len());
463 for k in 0..simplex.vertices.len() {
464 let face: Vec<usize> = simplex
465 .vertices
466 .iter()
467 .enumerate()
468 .filter(|(i, _)| *i != k)
469 .map(|(_, &v)| v)
470 .collect();
471 if let Some(&idx) = simplex_to_idx.get(&face) {
472 boundary_indices.push(idx);
473 }
474 }
475 bm.set_column(j, boundary_indices);
476 }
477
478 bm.reduce();
480
481 let mut diagram = PersistenceDiagram::new(max_dim);
483
484 let mut killed = vec![false; n_simplices];
486
487 for j in 0..n_simplices {
488 let piv = bm.pivot(j);
489 if piv >= 0 {
490 let i = piv as usize;
491 killed[i] = true;
493 let dim_creator = filtered_simplices[i].dimension();
494 if dim_creator <= max_dim {
495 let birth = filtered_simplices[i].filtration_value;
496 let death = filtered_simplices[j].filtration_value;
497 if (death - birth).abs() > 1e-12 {
498 diagram.add_point(birth, death, dim_creator);
499 }
500 }
501 }
502 }
503
504 for i in 0..n_simplices {
506 if !killed[i] && bm.pivot(i) < 0 {
507 let dim = filtered_simplices[i].dimension();
508 if dim <= max_dim {
509 let birth = filtered_simplices[i].filtration_value;
510 diagram.add_point(birth, f64::INFINITY, dim);
514 }
515 }
516 }
517
518 Ok(diagram)
519 }
520
521 fn compute_distance_matrix<S>(points: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
523 where
524 S: Data,
525 S::Elem: Float + NumCast,
526 {
527 let n = points.nrows();
528 let mut dist = Array2::zeros((n, n));
529
530 for i in 0..n {
531 for j in (i + 1)..n {
532 let mut d_sq = 0.0f64;
533 for k in 0..points.ncols() {
534 let diff = NumCast::from(points[[i, k]]).unwrap_or(0.0)
535 - NumCast::from(points[[j, k]]).unwrap_or(0.0);
536 d_sq += diff * diff;
537 }
538 let d = d_sq.sqrt();
539 dist[[i, j]] = d;
540 dist[[j, i]] = d;
541 }
542 }
543
544 Ok(dist)
545 }
546
547 fn build_filtration(
549 dist: &Array2<f64>,
550 max_dim: usize,
551 max_radius: f64,
552 ) -> Vec<FilteredSimplex> {
553 let n = dist.nrows();
554 let max_diam = 2.0 * max_radius;
555 let mut simplices = Vec::new();
556
557 for i in 0..n {
559 simplices.push(FilteredSimplex::new(vec![i], 0.0));
560 }
561
562 let mut prev_dim_simplices: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
568
569 for dim in 1..=(max_dim + 1) {
570 let mut next_dim_simplices: Vec<Vec<usize>> = Vec::new();
571
572 for simplex in &prev_dim_simplices {
575 let last_vertex = *simplex.last().unwrap_or(&0);
576
577 for v in (last_vertex + 1)..n {
578 let max_dist_to_v =
580 simplex.iter().map(|&u| dist[[u, v]]).fold(0.0f64, f64::max);
581
582 if max_dist_to_v <= max_diam {
583 let mut new_simplex = simplex.clone();
584 new_simplex.push(v);
585
586 let filtration_val = Self::simplex_diameter(&new_simplex, dist);
588 simplices.push(FilteredSimplex::new(new_simplex.clone(), filtration_val));
589 next_dim_simplices.push(new_simplex);
590 }
591 }
592 }
593
594 if next_dim_simplices.is_empty() {
595 break;
596 }
597 prev_dim_simplices = next_dim_simplices;
598 }
599
600 simplices
601 }
602
603 fn simplex_diameter(vertices: &[usize], dist: &Array2<f64>) -> f64 {
605 let mut max_d = 0.0f64;
606 for i in 0..vertices.len() {
607 for j in (i + 1)..vertices.len() {
608 let d = dist[[vertices[i], vertices[j]]];
609 if d > max_d {
610 max_d = d;
611 }
612 }
613 }
614 max_d
615 }
616}
617
618pub struct PersistenceImage {
630 resolution: usize,
632 birth_range: (f64, f64),
634 persistence_range: (f64, f64),
636 sigma: f64,
638 weight_type: PersistenceWeight,
640 dimension: usize,
642}
643
644#[derive(Debug, Clone)]
646pub enum PersistenceWeight {
647 Uniform,
649 Linear,
651 Arctan,
653 Threshold(f64),
655}
656
657impl PersistenceImage {
658 pub fn new(
666 resolution: usize,
667 dimension: usize,
668 sigma: f64,
669 weight_type: PersistenceWeight,
670 ) -> Result<Self> {
671 if resolution == 0 {
672 return Err(TransformError::InvalidInput(
673 "Resolution must be positive".to_string(),
674 ));
675 }
676 if sigma <= 0.0 {
677 return Err(TransformError::InvalidInput(
678 "Sigma must be positive".to_string(),
679 ));
680 }
681 Ok(Self {
682 resolution,
683 birth_range: (0.0, 1.0),
684 persistence_range: (0.0, 1.0),
685 sigma,
686 weight_type,
687 dimension,
688 })
689 }
690
691 pub fn compute(diagram: &PersistenceDiagram, resolution: usize) -> Result<Array2<f64>> {
700 if resolution == 0 {
701 return Err(TransformError::InvalidInput(
702 "Resolution must be positive".to_string(),
703 ));
704 }
705
706 let img = PersistenceImage::new(resolution, 0, 0.1, PersistenceWeight::Linear)?;
707 img.transform(diagram)
708 }
709
710 pub fn transform(&self, diagram: &PersistenceDiagram) -> Result<Array2<f64>> {
712 let pts: Vec<(f64, f64)> = diagram
714 .points
715 .iter()
716 .filter(|p| p.dimension == self.dimension && !p.is_essential())
717 .map(|p| (p.birth, p.persistence()))
718 .collect();
719
720 if pts.is_empty() {
721 return Ok(Array2::zeros((self.resolution, self.resolution)));
722 }
723
724 let b_min = self.birth_range.0;
726 let b_max = self
727 .birth_range
728 .1
729 .max(pts.iter().map(|(b, _)| *b).fold(0.0_f64, f64::max));
730 let p_min = self.persistence_range.0;
731 let p_max = self
732 .persistence_range
733 .1
734 .max(pts.iter().map(|(_, p)| *p).fold(0.0_f64, f64::max));
735
736 let b_range = (b_max - b_min).max(1e-10);
737 let p_range = (p_max - p_min).max(1e-10);
738 let cell_size_b = b_range / self.resolution as f64;
739 let cell_size_p = p_range / self.resolution as f64;
740
741 let mut image = Array2::<f64>::zeros((self.resolution, self.resolution));
742 let norm_factor = 1.0 / (2.0 * std::f64::consts::PI * self.sigma * self.sigma);
743
744 for &(birth, pers) in &pts {
745 let weight = match &self.weight_type {
747 PersistenceWeight::Uniform => 1.0,
748 PersistenceWeight::Linear => pers,
749 PersistenceWeight::Arctan => pers.atan(),
750 PersistenceWeight::Threshold(t) => {
751 if pers >= *t {
752 1.0
753 } else {
754 pers / t
755 }
756 }
757 };
758
759 for i in 0..self.resolution {
761 let cell_b = b_min + (i as f64 + 0.5) * cell_size_b;
762 for j in 0..self.resolution {
763 let cell_p = p_min + (j as f64 + 0.5) * cell_size_p;
764 let db = (cell_b - birth) / self.sigma;
765 let dp = (cell_p - pers) / self.sigma;
766 let gauss = norm_factor * (-0.5 * (db * db + dp * dp)).exp();
767 image[[i, j]] += weight * gauss * cell_size_b * cell_size_p;
768 }
769 }
770 }
771
772 Ok(image)
773 }
774
775 pub fn with_birth_range(mut self, min: f64, max: f64) -> Self {
777 self.birth_range = (min, max);
778 self
779 }
780
781 pub fn with_persistence_range(mut self, min: f64, max: f64) -> Self {
783 self.persistence_range = (min, max);
784 self
785 }
786}
787
788pub fn bottleneck_distance(d1: &PersistenceDiagram, d2: &PersistenceDiagram) -> f64 {
803 let pts1: Vec<(f64, f64)> = d1
805 .points
806 .iter()
807 .filter(|p| !p.is_essential())
808 .map(|p| (p.birth, p.death))
809 .collect();
810
811 let pts2: Vec<(f64, f64)> = d2
812 .points
813 .iter()
814 .filter(|p| !p.is_essential())
815 .map(|p| (p.birth, p.death))
816 .collect();
817
818 bottleneck_distance_between(&pts1, &pts2)
819}
820
821pub fn bottleneck_distance_dim(
823 d1: &PersistenceDiagram,
824 d2: &PersistenceDiagram,
825 dim: usize,
826) -> f64 {
827 let pts1: Vec<(f64, f64)> = d1
828 .points
829 .iter()
830 .filter(|p| p.dimension == dim && !p.is_essential())
831 .map(|p| (p.birth, p.death))
832 .collect();
833
834 let pts2: Vec<(f64, f64)> = d2
835 .points
836 .iter()
837 .filter(|p| p.dimension == dim && !p.is_essential())
838 .map(|p| (p.birth, p.death))
839 .collect();
840
841 bottleneck_distance_between(&pts1, &pts2)
842}
843
844pub fn wasserstein_distance(d1: &PersistenceDiagram, d2: &PersistenceDiagram, p: f64) -> f64 {
854 let pts1: Vec<(f64, f64)> = d1
855 .points
856 .iter()
857 .filter(|p| !p.is_essential())
858 .map(|pt| (pt.birth, pt.death))
859 .collect();
860
861 let pts2: Vec<(f64, f64)> = d2
862 .points
863 .iter()
864 .filter(|p| !p.is_essential())
865 .map(|pt| (pt.birth, pt.death))
866 .collect();
867
868 wasserstein_distance_between(&pts1, &pts2, p)
869}
870
/// Bottleneck distance between two diagrams given as finite `(birth, death)` pairs.
///
/// The ground metric is the L∞ distance between points. The distance from a point to the
/// diagonal is `(death - birth) / 2`. The optimum is always one of these pairwise or diagonal
/// costs (or 0), so the candidates are sorted and searched with an exact feasibility test.
fn bottleneck_distance_between(pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
    if pts1.is_empty() && pts2.is_empty() {
        return 0.0;
    }

    let diag_dist = |(b, d): (f64, f64)| -> f64 { (d - b) / 2.0 };
    let point_dist = |(b1, d1): (f64, f64), (b2, d2): (f64, f64)| -> f64 {
        (b1 - b2).abs().max((d1 - d2).abs())
    };

    let mut candidates: Vec<f64> = Vec::new();
    for &p1 in pts1 {
        candidates.extend(pts2.iter().map(|&p2| point_dist(p1, p2)));
        candidates.push(diag_dist(p1));
    }
    candidates.extend(pts2.iter().map(|&p2| diag_dist(p2)));
    candidates.push(0.0);

    candidates.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    // Exact dedup. A tolerance-based dedup keeps the smaller of two near-equal values and could
    // discard the exact optimum in favour of an infeasible neighbour.
    candidates.dedup();

    // Feasibility is monotone in delta: once a matching within delta exists, it exists for
    // every larger delta. So the first feasible candidate is the answer.
    let first_feasible =
        candidates.partition_point(|&delta| !is_feasible_bottleneck(pts1, pts2, delta));
    candidates
        .get(first_feasible)
        .or_else(|| candidates.last())
        .copied()
        .unwrap_or(0.0)
}

/// Returns `true` if the two diagrams admit a partial matching in which every matched pair,
/// and every point sent to the diagonal, is within `delta`.
///
/// This is decided exactly as a perfect-matching problem on the standard augmented bipartite
/// graph:
/// - Left nodes are the points of `pts1`, followed by one diagonal slot per point of `pts2`.
/// - Right nodes are the points of `pts2`, followed by one diagonal slot per point of `pts1`.
/// - A real point may use its own diagonal slot when its diagonal distance is within `delta`.
/// - Diagonal slots can always be matched to each other.
///
/// Augmenting paths (Kuhn's algorithm) are used instead of a greedy first-fit, which can
/// wrongly report infeasibility.
fn is_feasible_bottleneck(pts1: &[(f64, f64)], pts2: &[(f64, f64)], delta: f64) -> bool {
    let diag_dist = |(b, d): (f64, f64)| -> f64 { (d - b) / 2.0 };
    let point_dist = |(b1, d1): (f64, f64), (b2, d2): (f64, f64)| -> f64 {
        (b1 - b2).abs().max((d1 - d2).abs())
    };

    let n = pts1.len();
    let m = pts2.len();
    let size = n + m;

    // Right-node numbering: 0..m are the points of pts2, m + i is the diagonal slot of pts1[i].
    let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); size];
    for (i, &p1) in pts1.iter().enumerate() {
        let edges = &mut adjacency[i];
        edges.extend((0..m).filter(|&j| point_dist(p1, pts2[j]) <= delta));
        if diag_dist(p1) <= delta {
            edges.push(m + i);
        }
    }
    for (j, &p2) in pts2.iter().enumerate() {
        let edges = &mut adjacency[n + j];
        if diag_dist(p2) <= delta {
            edges.push(j);
        }
        // Diagonal-to-diagonal pairings cost nothing.
        edges.extend(m..size);
    }

    let mut owner: Vec<Option<usize>> = vec![None; size];
    let mut visited = vec![false; size];
    for left in 0..size {
        visited.fill(false);
        // A left node that cannot be augmented now can never be matched later, so the
        // matching cannot be perfect.
        if !try_augment_bottleneck(left, &adjacency, &mut owner, &mut visited) {
            return false;
        }
    }
    true
}

/// One augmenting-path search of Kuhn's algorithm.
///
/// Tries to match `left`, re-routing earlier matches through `owner` (right node → left node)
/// when necessary.
fn try_augment_bottleneck(
    left: usize,
    adjacency: &[Vec<usize>],
    owner: &mut [Option<usize>],
    visited: &mut [bool],
) -> bool {
    for &right in &adjacency[left] {
        if visited[right] {
            continue;
        }
        visited[right] = true;
        let available = match owner[right] {
            None => true,
            Some(previous) => try_augment_bottleneck(previous, adjacency, owner, visited),
        };
        if available {
            owner[right] = Some(left);
            return true;
        }
    }
    false
}
976
/// p-Wasserstein distance between two diagrams given as finite `(birth, death)` pairs.
///
/// The ground metric is the L∞ distance between points. The distance to the diagonal is
/// `(death - birth) / 2`.
///
/// The optimal partial matching is computed exactly as the standard augmented assignment
/// problem:
/// - Rows are the points of `pts1` followed by `m` diagonal slots.
/// - Columns are the points of `pts2` followed by `n` diagonal slots.
/// - A point matched to any diagonal slot pays its diagonal cost; slot-to-slot costs 0.
///
/// The previous greedy matching only produced an upper bound. Intended for `p >= 1`; `p` is
/// not validated.
fn wasserstein_distance_between(pts1: &[(f64, f64)], pts2: &[(f64, f64)], p: f64) -> f64 {
    let diag_cost = |(b, d): (f64, f64)| -> f64 { ((d - b) / 2.0).powf(p) };
    let pair_cost = |(b1, d1): (f64, f64), (b2, d2): (f64, f64)| -> f64 {
        (b1 - b2).abs().max((d1 - d2).abs()).powf(p)
    };

    let n = pts1.len();
    let m = pts2.len();
    let size = n + m;

    // Diagonal-to-diagonal entries (rows n.., columns m..) stay 0.
    let mut cost = vec![vec![0.0f64; size]; size];
    for (i, &a) in pts1.iter().enumerate() {
        for (j, &b) in pts2.iter().enumerate() {
            cost[i][j] = pair_cost(a, b);
        }
        let to_diag = diag_cost(a);
        for slot in cost[i][m..].iter_mut() {
            *slot = to_diag;
        }
    }
    for (j, &b) in pts2.iter().enumerate() {
        let to_diag = diag_cost(b);
        for row in cost[n..].iter_mut() {
            row[j] = to_diag;
        }
    }

    min_cost_assignment(&cost).powf(1.0 / p)
}

/// Minimum total cost of a perfect assignment for a square cost matrix.
///
/// Uses the Hungarian algorithm with row/column potentials, O(size^3). Returns `0.0` for an
/// empty matrix and NaN if non-finite costs prevent progress.
fn min_cost_assignment(cost: &[Vec<f64>]) -> f64 {
    let size = cost.len();
    if size == 0 {
        return 0.0;
    }

    // All arrays are 1-based; index 0 is a sentinel column used to start each augmentation.
    let mut row_pot = vec![0.0f64; size + 1];
    let mut col_pot = vec![0.0f64; size + 1];
    // col_match[j] = 1-based row currently assigned to column j (0 = free).
    let mut col_match = vec![0usize; size + 1];
    // way[j] = previous column on the alternating path that reached column j.
    let mut way = vec![0usize; size + 1];

    for row in 1..=size {
        col_match[0] = row;
        let mut j0 = 0usize;
        let mut min_slack = vec![f64::INFINITY; size + 1];
        let mut used = vec![false; size + 1];

        loop {
            used[j0] = true;
            let i0 = col_match[j0];
            let mut delta = f64::INFINITY;
            let mut j1 = 0usize;
            for j in 1..=size {
                if used[j] {
                    continue;
                }
                let reduced = cost[i0 - 1][j - 1] - row_pot[i0] - col_pot[j];
                if reduced < min_slack[j] {
                    min_slack[j] = reduced;
                    way[j] = j0;
                }
                if min_slack[j] < delta {
                    delta = min_slack[j];
                    j1 = j;
                }
            }
            if j1 == 0 {
                // Only reachable with NaN/infinite costs; bail out instead of looping forever.
                return f64::NAN;
            }
            for j in 0..=size {
                if used[j] {
                    row_pot[col_match[j]] += delta;
                    col_pot[j] -= delta;
                } else {
                    min_slack[j] -= delta;
                }
            }
            j0 = j1;
            if col_match[j0] == 0 {
                break;
            }
        }

        // Flip the alternating path back to the sentinel column.
        loop {
            let j1 = way[j0];
            col_match[j0] = col_match[j1];
            j0 = j1;
            if j0 == 0 {
                break;
            }
        }
    }

    (1..=size).map(|j| cost[col_match[j] - 1][j - 1]).sum()
}
1026
1027#[derive(Debug, Clone)]
1034pub struct PersistenceLandscape {
1035 n_landscapes: usize,
1037 dimension: usize,
1039 pub landscapes: Array2<f64>,
1041 pub grid: Array1<f64>,
1043}
1044
1045impl PersistenceLandscape {
1046 pub fn compute(
1057 diagram: &PersistenceDiagram,
1058 n_landscapes: usize,
1059 n_grid_points: usize,
1060 dimension: usize,
1061 ) -> Result<Self> {
1062 if n_landscapes == 0 {
1063 return Err(TransformError::InvalidInput(
1064 "n_landscapes must be positive".to_string(),
1065 ));
1066 }
1067 if n_grid_points < 2 {
1068 return Err(TransformError::InvalidInput(
1069 "n_grid_points must be at least 2".to_string(),
1070 ));
1071 }
1072
1073 let pts: Vec<(f64, f64)> = diagram
1074 .points
1075 .iter()
1076 .filter(|p| p.dimension == dimension && !p.is_essential())
1077 .map(|p| (p.birth, p.death))
1078 .collect();
1079
1080 if pts.is_empty() {
1081 let grid = Array1::linspace(0.0, 1.0, n_grid_points);
1082 return Ok(Self {
1083 n_landscapes,
1084 dimension,
1085 landscapes: Array2::zeros((n_landscapes, n_grid_points)),
1086 grid,
1087 });
1088 }
1089
1090 let t_min = pts.iter().map(|(b, _)| *b).fold(f64::INFINITY, f64::min);
1092 let t_max = pts.iter().map(|(_, d)| *d).fold(0.0_f64, f64::max);
1093 let grid = Array1::linspace(t_min, t_max, n_grid_points);
1094
1095 let mut landscapes = Array2::<f64>::zeros((n_landscapes, n_grid_points));
1096
1097 for (g_idx, &t) in grid.iter().enumerate() {
1098 let mut tent_values: Vec<f64> = pts
1100 .iter()
1101 .map(|&(b, d)| {
1102 if t <= (b + d) / 2.0 {
1103 (t - b).max(0.0)
1104 } else {
1105 (d - t).max(0.0)
1106 }
1107 })
1108 .collect();
1109
1110 tent_values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1112
1113 for k in 0..n_landscapes {
1114 landscapes[[k, g_idx]] = tent_values.get(k).copied().unwrap_or(0.0);
1115 }
1116 }
1117
1118 Ok(Self {
1119 n_landscapes,
1120 dimension,
1121 landscapes,
1122 grid,
1123 })
1124 }
1125
1126 pub fn l2_norm(&self, k: usize) -> f64 {
1128 if k >= self.n_landscapes {
1129 return 0.0;
1130 }
1131 let row = self.landscapes.row(k);
1132 row.iter().map(|&v| v * v).sum::<f64>().sqrt()
1133 }
1134
1135 pub fn inner_product(&self, other: &Self) -> f64 {
1137 let n = self.landscapes.shape()[1].min(other.landscapes.shape()[1]);
1138 let k = self.n_landscapes.min(other.n_landscapes);
1139 let mut sum = 0.0;
1140 for i in 0..k {
1141 for j in 0..n {
1142 sum += self.landscapes[[i, j]] * other.landscapes[[i, j]];
1143 }
1144 }
1145 sum
1146 }
1147}
1148
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Corners of the unit square: sides have length 1, diagonals length sqrt(2).
    fn square_points() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0])
            .expect("shape ok")
    }

    #[test]
    fn test_vietoris_rips_h0() {
        let pts = square_points();
        let diagram = VietorisRips::compute(&pts, 0, 2.0).expect("vr compute");
        let h0_pts = diagram.points_in_dimension(0);
        assert!(!h0_pts.is_empty(), "Should have H0 features");

        // The four corners merge into one component once the unit-length sides appear.
        let finite: Vec<_> = h0_pts.iter().filter(|p| !p.is_essential()).collect();
        assert_eq!(finite.len(), 3, "unexpected H0 points: {:?}", h0_pts);
        for p in &finite {
            assert!(p.birth.abs() < 1e-12);
            assert!((p.death - 1.0).abs() < 1e-10);
        }
        assert_eq!(h0_pts.iter().filter(|p| p.is_essential()).count(), 1);
    }

    #[test]
    fn test_vietoris_rips_h1() {
        let pts = square_points();
        let diagram = VietorisRips::compute(&pts, 1, 2.0).expect("vr compute");
        assert!(!diagram.is_empty(), "Diagram should not be empty");

        // The square's loop is born with the last unit side and filled when the diagonals
        // (length sqrt(2)) bring in the triangles.
        let h1_pts = diagram.points_in_dimension(1);
        assert_eq!(h1_pts.len(), 1, "unexpected H1 points: {:?}", h1_pts);
        assert!((h1_pts[0].birth - 1.0).abs() < 1e-10);
        assert!((h1_pts[0].death - 2.0f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_vietoris_rips_requires_two_points() {
        let pts = Array2::<f64>::zeros((1, 2));
        assert!(VietorisRips::compute(&pts, 1, 1.0).is_err());
    }

    #[test]
    fn test_persistence_point_persistence() {
        let p = PersistencePoint::new(0.5, 1.5, 0);
        assert!((p.persistence() - 1.0).abs() < 1e-10);
        assert!(!p.is_essential());

        let q = PersistencePoint::new(0.5, f64::INFINITY, 0);
        assert!(q.is_essential());
        assert!(q.persistence().is_infinite());
        assert_eq!(format!("{}", q), "H0(0.5000, ∞)");
    }

    #[test]
    fn test_persistence_diagram_filter() {
        let mut diagram = PersistenceDiagram::new(1);
        diagram.add_point(0.0, 0.01, 0);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.2, 0.8, 1);

        let filtered = diagram.filter_by_persistence(0.5);
        // (0, 0.01) is dropped; (0, 1) and (0.2, 0.8) survive.
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_barcode_from_diagram() {
        let mut diagram = PersistenceDiagram::new(1);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.5, 0.9, 1);

        let barcode = diagram.to_barcode();
        assert_eq!(barcode.len(), 2);
        assert_eq!(barcode.intervals_in_dimension(0).len(), 1);
        assert_eq!(barcode.intervals_in_dimension(1).len(), 1);
        assert!((barcode.intervals_in_dimension(0)[0].length() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_persistence_image() {
        let mut diagram = PersistenceDiagram::new(1);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.2, 0.8, 0);

        let image = PersistenceImage::compute(&diagram, 10).expect("pi compute");
        assert_eq!(image.shape(), &[10, 10]);
        assert!(image.iter().all(|&v| v >= 0.0));
        assert!(image.iter().any(|&v| v > 0.0));
    }

    #[test]
    fn test_bottleneck_distance_same_diagram() {
        let mut diagram = PersistenceDiagram::new(0);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.5, 0.9, 0);

        let dist = bottleneck_distance(&diagram, &diagram);
        assert!(dist < 1e-10, "Self-distance should be 0, got {}", dist);
    }

    #[test]
    fn test_bottleneck_distance_empty_diagrams() {
        let d1 = PersistenceDiagram::new(0);
        let d2 = PersistenceDiagram::new(0);
        let dist = bottleneck_distance(&d1, &d2);
        assert!(dist < 1e-10);
    }

    #[test]
    fn test_bottleneck_distance_different_diagrams() {
        let mut d1 = PersistenceDiagram::new(0);
        d1.add_point(0.0, 1.0, 0);

        let mut d2 = PersistenceDiagram::new(0);
        d2.add_point(0.0, 0.5, 0);

        // Matching the two points costs 0.5 (L-infinity); sending (0, 1) to the diagonal also
        // costs 0.5.
        let dist = bottleneck_distance(&d1, &d2);
        assert!(
            (dist - 0.5).abs() < 1e-10,
            "expected bottleneck distance 0.5, got {}",
            dist
        );
    }

    #[test]
    fn test_persistence_landscape() {
        let mut diagram = PersistenceDiagram::new(0);
        diagram.add_point(0.0, 2.0, 0);
        diagram.add_point(0.5, 1.5, 0);

        let landscape =
            PersistenceLandscape::compute(&diagram, 2, 20, 0).expect("landscape compute");
        assert_eq!(landscape.landscapes.shape(), &[2, 20]);
        assert!(landscape.landscapes.row(0).iter().all(|&v| v >= -1e-10));
        // Tent heights are bounded by half the lifetimes: 1.0 and 0.5.
        assert!(landscape.landscapes.row(0).iter().all(|&v| v <= 1.0 + 1e-12));
        assert!(landscape.landscapes.row(1).iter().all(|&v| v <= 0.5 + 1e-12));
        assert!(landscape.l2_norm(0) > 0.0);
    }

    #[test]
    fn test_wasserstein_distance() {
        let mut d1 = PersistenceDiagram::new(0);
        d1.add_point(0.0, 1.0, 0);

        let mut d2 = PersistenceDiagram::new(0);
        d2.add_point(0.0, 1.0, 0);

        let wd = wasserstein_distance(&d1, &d2, 1.0);
        assert!(wd < 1e-10, "Identical diagrams: W=0, got {}", wd);

        // Against an empty diagram the point goes to the diagonal at cost (1 - 0) / 2.
        let empty = PersistenceDiagram::new(0);
        assert!((wasserstein_distance(&d1, &empty, 1.0) - 0.5).abs() < 1e-10);
        assert!((wasserstein_distance(&d1, &empty, 2.0) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_betti_numbers() {
        let mut diagram = PersistenceDiagram::new(1);
        diagram.add_point(0.0, f64::INFINITY, 0);
        diagram.add_point(0.3, 0.7, 1);

        let betti = diagram.betti_numbers_at(0.5);
        assert_eq!(betti[0], 1);
        assert_eq!(betti[1], 1);

        let betti_early = diagram.betti_numbers_at(0.1);
        assert_eq!(betti_early[1], 0);
    }

    #[test]
    fn test_vietoris_rips_small_radius() {
        let pts = square_points();
        let diagram = VietorisRips::compute(&pts, 0, 0.1).expect("vr compute");
        let h0_pts = diagram.points_in_dimension(0);
        assert!(!h0_pts.is_empty());
        // Radius 0.1 allows no edge (edges need length <= 0.2), so every corner stays its own
        // essential component.
        assert_eq!(h0_pts.len(), 4);
        assert!(h0_pts.iter().all(|p| p.is_essential() && p.birth == 0.0));
    }

    #[test]
    fn test_total_persistence() {
        let mut diagram = PersistenceDiagram::new(0);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.0, 3.0, 0);

        let tp = diagram.total_persistence(2.0);
        assert!((tp - (10.0f64).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_persistence_image_custom() {
        let mut diagram = PersistenceDiagram::new(0);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.2, 0.8, 0);

        let img_computer = PersistenceImage::new(5, 0, 0.2, PersistenceWeight::Arctan)
            .expect("pi new")
            .with_birth_range(0.0, 1.0)
            .with_persistence_range(0.0, 1.0);

        let image = img_computer.transform(&diagram).expect("pi transform");
        assert_eq!(image.shape(), &[5, 5]);
        assert!(image.iter().all(|&v| v >= 0.0));
    }
}