Skip to main content

scirs2_transform/
tda_vr.rs

//! Topological Data Analysis (TDA) and Persistent Homology
//!
//! This module provides implementations of persistent homology for analyzing
//! topological features of data. Key concepts include:
//!
//! - **Simplicial complexes**: Geometric objects built from vertices, edges, triangles, etc.
//! - **Filtrations**: Nested sequences of simplicial complexes parameterized by a scale
//! - **Persistence diagrams**: Collections of (birth, death) pairs representing topological features
//! - **Barcodes**: Interval representations of persistent homology
//! - **Persistence images**: Stable vectorizations of persistence diagrams
//!
//! ## Algorithms
//!
//! - Vietoris-Rips complex construction via distance-based filtration
//! - Boundary matrix reduction for computing persistent homology
//! - Bottleneck distance between persistence diagrams
//! - Wasserstein distance between persistence diagrams
//! - Persistence image vectorization
//!
//! ## References
//!
//! - Edelsbrunner, H., Letscher, D., & Zomorodian, A. (2002). Topological persistence and simplification.
//! - Zomorodian, A., & Carlsson, G. (2005). Computing persistent homology.
//! - Adams, H., et al. (2017). Persistence images: A stable vector representation of persistent homology.

26use crate::error::{Result, TransformError};
27use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
28use scirs2_core::numeric::{Float, NumCast};
29use std::collections::HashMap;
30use std::fmt;
31
// ─── Core Data Structures ────────────────────────────────────────────────────

/// One feature of a persistence diagram: the filtration values at which a
/// homology class is born and dies, together with its homological dimension.
#[derive(Debug, Clone, PartialEq)]
pub struct PersistencePoint {
    /// Filtration value at which the feature appears.
    pub birth: f64,
    /// Filtration value at which the feature disappears; `f64::INFINITY`
    /// marks an essential feature that never dies.
    pub death: f64,
    /// Homological dimension (0 = components, 1 = loops, 2 = voids, ...).
    pub dimension: usize,
}

impl PersistencePoint {
    /// Build a point from its birth value, death value and homological dimension.
    pub fn new(birth: f64, death: f64, dimension: usize) -> Self {
        PersistencePoint {
            birth,
            death,
            dimension,
        }
    }

    /// Lifetime of the feature: `death - birth`, or `f64::INFINITY` when the
    /// death value is infinite.
    pub fn persistence(&self) -> f64 {
        if self.is_essential() {
            return f64::INFINITY;
        }
        self.death - self.birth
    }

    /// `true` when the death value is infinite, i.e. the feature never dies.
    pub fn is_essential(&self) -> bool {
        self.death.is_infinite()
    }
}

impl fmt::Display for PersistencePoint {
    /// Renders as `H{dimension}({birth}, {death})` with four decimals,
    /// printing `∞` in place of an infinite death value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "H{}({:.4}, ", self.dimension, self.birth)?;
        if self.is_essential() {
            f.write_str("∞)")
        } else {
            write!(f, "{:.4})", self.death)
        }
    }
}

86/// Persistence diagram: a collection of (birth, death) pairs per dimension.
87///
88/// A persistence diagram captures the topological features of a dataset
89/// across all scales, organized by homological dimension.
90#[derive(Debug, Clone)]
91pub struct PersistenceDiagram {
92    /// All persistence points across all dimensions
93    pub points: Vec<PersistencePoint>,
94    /// Maximum homological dimension computed
95    pub max_dimension: usize,
96}
97
98impl PersistenceDiagram {
99    /// Create an empty persistence diagram
100    pub fn new(max_dimension: usize) -> Self {
101        Self {
102            points: Vec::new(),
103            max_dimension,
104        }
105    }
106
107    /// Add a persistence point to the diagram
108    pub fn add_point(&mut self, birth: f64, death: f64, dimension: usize) {
109        self.points
110            .push(PersistencePoint::new(birth, death, dimension));
111    }
112
113    /// Get all points in a specific homological dimension
114    pub fn points_in_dimension(&self, dim: usize) -> Vec<&PersistencePoint> {
115        self.points.iter().filter(|p| p.dimension == dim).collect()
116    }
117
118    /// Get all finite persistence points (non-essential features)
119    pub fn finite_points(&self) -> Vec<&PersistencePoint> {
120        self.points.iter().filter(|p| !p.is_essential()).collect()
121    }
122
123    /// Get all essential points (infinite persistence)
124    pub fn essential_points(&self) -> Vec<&PersistencePoint> {
125        self.points.iter().filter(|p| p.is_essential()).collect()
126    }
127
128    /// Compute the total persistence (sum of all finite persistence values)
129    pub fn total_persistence(&self, p: f64) -> f64 {
130        self.points
131            .iter()
132            .filter(|pt| !pt.is_essential())
133            .map(|pt| pt.persistence().powf(p))
134            .sum::<f64>()
135            .powf(1.0 / p)
136    }
137
138    /// Filter points by minimum persistence threshold
139    pub fn filter_by_persistence(&self, min_persistence: f64) -> PersistenceDiagram {
140        let filtered_points: Vec<PersistencePoint> = self
141            .points
142            .iter()
143            .filter(|p| p.persistence() >= min_persistence)
144            .cloned()
145            .collect();
146
147        PersistenceDiagram {
148            points: filtered_points,
149            max_dimension: self.max_dimension,
150        }
151    }
152
153    /// Number of points in the diagram
154    pub fn len(&self) -> usize {
155        self.points.len()
156    }
157
158    /// Whether the diagram is empty
159    pub fn is_empty(&self) -> bool {
160        self.points.is_empty()
161    }
162
163    /// Convert to barcode representation
164    pub fn to_barcode(&self) -> Barcode {
165        Barcode::from_diagram(self)
166    }
167
168    /// Get the Betti numbers (count of features) at a given filtration value
169    pub fn betti_numbers_at(&self, filtration_value: f64) -> Vec<usize> {
170        let mut betti = vec![0usize; self.max_dimension + 1];
171        for p in &self.points {
172            if p.birth <= filtration_value && (p.death > filtration_value || p.death.is_infinite())
173            {
174                if p.dimension <= self.max_dimension {
175                    betti[p.dimension] += 1;
176                }
177            }
178        }
179        betti
180    }
181}
182
// ─── Barcode ─────────────────────────────────────────────────────────────────

/// A half-open interval `[birth, death)` in a persistence barcode.
#[derive(Debug, Clone, PartialEq)]
pub struct BarcodeInterval {
    /// Left end of the bar (birth filtration value).
    pub birth: f64,
    /// Right end of the bar (death filtration value), possibly `f64::INFINITY`.
    pub death: f64,
    /// Homological dimension of the feature this bar represents.
    pub dimension: usize,
}

impl BarcodeInterval {
    /// Length of the bar: `death - birth`, or `f64::INFINITY` for an unbounded bar.
    pub fn length(&self) -> f64 {
        if !self.death.is_infinite() {
            return self.death - self.birth;
        }
        f64::INFINITY
    }
}

207/// Persistence barcode: a multi-set of intervals representing topological features.
208///
209/// Each interval [birth, death) represents a topological feature that appears
210/// at filtration value `birth` and disappears at filtration value `death`.
211#[derive(Debug, Clone)]
212pub struct Barcode {
213    /// All barcode intervals
214    pub intervals: Vec<BarcodeInterval>,
215    /// Maximum dimension
216    pub max_dimension: usize,
217}
218
219impl Barcode {
220    /// Create a barcode from a persistence diagram
221    pub fn from_diagram(diagram: &PersistenceDiagram) -> Self {
222        let intervals: Vec<BarcodeInterval> = diagram
223            .points
224            .iter()
225            .map(|p| BarcodeInterval {
226                birth: p.birth,
227                death: p.death,
228                dimension: p.dimension,
229            })
230            .collect();
231
232        Barcode {
233            intervals,
234            max_dimension: diagram.max_dimension,
235        }
236    }
237
238    /// Get intervals in a specific dimension, sorted by birth
239    pub fn intervals_in_dimension(&self, dim: usize) -> Vec<&BarcodeInterval> {
240        let mut intervals: Vec<&BarcodeInterval> = self
241            .intervals
242            .iter()
243            .filter(|i| i.dimension == dim)
244            .collect();
245        intervals.sort_by(|a, b| {
246            a.birth
247                .partial_cmp(&b.birth)
248                .unwrap_or(std::cmp::Ordering::Equal)
249        });
250        intervals
251    }
252
253    /// Number of intervals
254    pub fn len(&self) -> usize {
255        self.intervals.len()
256    }
257
258    /// Whether the barcode is empty
259    pub fn is_empty(&self) -> bool {
260        self.intervals.is_empty()
261    }
262}
263
// ─── Simplex and Filtration ───────────────────────────────────────────────────

/// A simplex, stored as its vertex indices in ascending order, tagged with the
/// filtration value at which it enters the complex.
#[derive(Debug, Clone, PartialEq)]
struct FilteredSimplex {
    /// Vertex indices in ascending order.
    vertices: Vec<usize>,
    /// Filtration value at which this simplex appears.
    filtration_value: f64,
}

impl FilteredSimplex {
    /// Build a simplex from a vertex list in any order; the list is sorted
    /// ascending before being stored.
    fn new(mut vertices: Vec<usize>, filtration_value: f64) -> Self {
        vertices.sort_unstable();
        FilteredSimplex {
            vertices,
            filtration_value,
        }
    }

    /// Simplex dimension: number of vertices minus one (0 for an empty vertex list).
    fn dimension(&self) -> usize {
        match self.vertices.len() {
            0 => 0,
            count => count - 1,
        }
    }
}

// ─── Boundary Matrix ─────────────────────────────────────────────────────────

/// Sparse boundary matrix over Z_2, stored column by column.
///
/// Each column is a sorted, duplicate-free list of the row indices holding a 1.
/// The pivot ("low") of a column is its largest row index. It is cached in
/// `pivots` so that reduction can look it up in O(1).
struct BoundaryMatrix {
    /// Columns of the boundary matrix; each is a sorted list of row indices.
    columns: Vec<Vec<usize>>,
    /// Cached pivot (largest row index) of each column, or -1 for an empty column.
    pivots: Vec<i64>,
}

impl BoundaryMatrix {
    /// Create a matrix with `n_cols` columns, all zero.
    fn new(n_cols: usize) -> Self {
        Self {
            columns: vec![Vec::new(); n_cols],
            pivots: vec![-1i64; n_cols],
        }
    }

    /// Pivot of a sorted column: its last (largest) row index, or -1 when empty.
    fn pivot_of(column: &[usize]) -> i64 {
        column.last().map_or(-1, |&row| row as i64)
    }

    /// Replace column `col` with `boundary` and refresh its cached pivot.
    ///
    /// The row list is sorted and deduplicated first.
    fn set_column(&mut self, col: usize, mut boundary: Vec<usize>) {
        boundary.sort_unstable();
        boundary.dedup();
        self.pivots[col] = Self::pivot_of(&boundary);
        self.columns[col] = boundary;
    }

    /// Pivot (largest row index) of column `j`, or -1 if the column is zero.
    fn pivot(&self, j: usize) -> i64 {
        self.pivots[j]
    }

    /// Add column `source` into column `target` over Z_2 and refresh the target's pivot.
    ///
    /// Over Z_2 this is the symmetric difference of the two sorted row lists.
    fn add_column(&mut self, target: usize, source: usize) {
        if target == source {
            // x + x = 0 over Z_2.
            self.columns[target].clear();
            self.pivots[target] = -1;
            return;
        }

        // Move the target column out and only borrow the source. This avoids
        // cloning both operands on every addition in the reduction loop.
        let tgt = std::mem::take(&mut self.columns[target]);
        let src = &self.columns[source];

        // Merge-style symmetric difference of two sorted lists.
        let mut result = Vec::with_capacity(src.len() + tgt.len());
        let (mut i, mut j) = (0, 0);
        while i < src.len() && j < tgt.len() {
            match src[i].cmp(&tgt[j]) {
                std::cmp::Ordering::Less => {
                    result.push(src[i]);
                    i += 1;
                }
                std::cmp::Ordering::Greater => {
                    result.push(tgt[j]);
                    j += 1;
                }
                std::cmp::Ordering::Equal => {
                    // Both columns have this row; 1 + 1 = 0 over Z_2.
                    i += 1;
                    j += 1;
                }
            }
        }
        result.extend_from_slice(&src[i..]);
        result.extend_from_slice(&tgt[j..]);

        self.pivots[target] = Self::pivot_of(&result);
        self.columns[target] = result;
    }

    /// Standard left-to-right column reduction over Z_2.
    ///
    /// Each column is repeatedly added with the earlier column that owns the
    /// same pivot, until its pivot is unique or the column becomes zero.
    /// Afterwards no two non-zero columns share a pivot.
    fn reduce(&mut self) {
        // Pivot row -> the already-reduced column that owns it.
        let mut owner: HashMap<usize, usize> = HashMap::new();

        for j in 0..self.columns.len() {
            while self.pivots[j] >= 0 {
                let row = self.pivots[j] as usize;
                match owner.get(&row).copied() {
                    Some(k) => self.add_column(j, k),
                    None => {
                        owner.insert(row, j);
                        break;
                    }
                }
            }
        }
    }
}

382// ─── Vietoris-Rips Complex ───────────────────────────────────────────────────
383
384/// Vietoris-Rips complex for computing persistent homology
385///
386/// The Vietoris-Rips complex builds a filtration of simplicial complexes from
387/// point cloud data by including a simplex whenever all pairwise distances
388/// between its vertices are at most the filtration parameter ε.
389///
390/// # Example
391///
392/// ```rust
393/// use scirs2_transform::tda::VietorisRips;
394/// use scirs2_core::ndarray::Array2;
395///
396/// let points = Array2::from_shape_vec((4, 2), vec![
397///     0.0, 0.0,   1.0, 0.0,   1.0, 1.0,   0.0, 1.0,
398/// ]).expect("should succeed");
399///
400/// let diagram = VietorisRips::compute(&points, 1, 2.0).expect("should succeed");
401/// assert!(!diagram.is_empty());
402/// ```
403pub struct VietorisRips;
404
405impl VietorisRips {
406    /// Compute the persistent homology of a point cloud using the Vietoris-Rips filtration
407    ///
408    /// # Arguments
409    /// * `points` - Point cloud data (n_points × n_features)
410    /// * `max_dim` - Maximum homological dimension to compute (0 = components, 1 = loops, ...)
411    /// * `max_radius` - Maximum filtration radius; simplices with diameter > 2*max_radius are ignored
412    ///
413    /// # Returns
414    /// * A persistence diagram with (birth, death) pairs for each dimension
415    pub fn compute<S>(
416        points: &ArrayBase<S, Ix2>,
417        max_dim: usize,
418        max_radius: f64,
419    ) -> Result<PersistenceDiagram>
420    where
421        S: Data,
422        S::Elem: Float + NumCast,
423    {
424        let n = points.nrows();
425        if n < 2 {
426            return Err(TransformError::InvalidInput(
427                "VietorisRips requires at least 2 points".to_string(),
428            ));
429        }
430
431        // Compute pairwise Euclidean distances
432        let dist = Self::compute_distance_matrix(points)?;
433
434        // Build Vietoris-Rips filtration up to max_dim + 1 skeleton
435        let mut filtered_simplices = Self::build_filtration(&dist, max_dim, max_radius);
436
437        // Sort by filtration value, then dimension for correct order
438        filtered_simplices.sort_by(|a, b| {
439            a.filtration_value
440                .partial_cmp(&b.filtration_value)
441                .unwrap_or(std::cmp::Ordering::Equal)
442                .then(a.dimension().cmp(&b.dimension()))
443        });
444
445        // Assign indices to simplices
446        let n_simplices = filtered_simplices.len();
447
448        // Build index lookup
449        let mut simplex_to_idx: HashMap<Vec<usize>, usize> = HashMap::new();
450        for (i, s) in filtered_simplices.iter().enumerate() {
451            simplex_to_idx.insert(s.vertices.clone(), i);
452        }
453
454        // Build boundary matrix
455        let mut bm = BoundaryMatrix::new(n_simplices);
456        for (j, simplex) in filtered_simplices.iter().enumerate() {
457            if simplex.dimension() == 0 {
458                // Vertices have empty boundary
459                continue;
460            }
461            // Boundary of a simplex: all faces obtained by removing one vertex
462            let mut boundary_indices = Vec::with_capacity(simplex.vertices.len());
463            for k in 0..simplex.vertices.len() {
464                let face: Vec<usize> = simplex
465                    .vertices
466                    .iter()
467                    .enumerate()
468                    .filter(|(i, _)| *i != k)
469                    .map(|(_, &v)| v)
470                    .collect();
471                if let Some(&idx) = simplex_to_idx.get(&face) {
472                    boundary_indices.push(idx);
473                }
474            }
475            bm.set_column(j, boundary_indices);
476        }
477
478        // Reduce boundary matrix
479        bm.reduce();
480
481        // Extract persistence pairs
482        let mut diagram = PersistenceDiagram::new(max_dim);
483
484        // Track which columns are "positive" (not killed by a later simplex)
485        let mut killed = vec![false; n_simplices];
486
487        for j in 0..n_simplices {
488            let piv = bm.pivot(j);
489            if piv >= 0 {
490                let i = piv as usize;
491                // Column j kills simplex i
492                killed[i] = true;
493                let dim_creator = filtered_simplices[i].dimension();
494                if dim_creator <= max_dim {
495                    let birth = filtered_simplices[i].filtration_value;
496                    let death = filtered_simplices[j].filtration_value;
497                    if (death - birth).abs() > 1e-12 {
498                        diagram.add_point(birth, death, dim_creator);
499                    }
500                }
501            }
502        }
503
504        // Essential features: simplices not killed and with zero column (not reducible)
505        for i in 0..n_simplices {
506            if !killed[i] && bm.pivot(i) < 0 {
507                let dim = filtered_simplices[i].dimension();
508                if dim <= max_dim {
509                    let birth = filtered_simplices[i].filtration_value;
510                    // Mark as essential (death = infinity)
511                    // Skip H_0 essential features except the longest-lived one
512                    // (there's exactly one connected component in a connected point cloud)
513                    diagram.add_point(birth, f64::INFINITY, dim);
514                }
515            }
516        }
517
518        Ok(diagram)
519    }
520
521    /// Compute pairwise Euclidean distance matrix
522    fn compute_distance_matrix<S>(points: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
523    where
524        S: Data,
525        S::Elem: Float + NumCast,
526    {
527        let n = points.nrows();
528        let mut dist = Array2::zeros((n, n));
529
530        for i in 0..n {
531            for j in (i + 1)..n {
532                let mut d_sq = 0.0f64;
533                for k in 0..points.ncols() {
534                    let diff = NumCast::from(points[[i, k]]).unwrap_or(0.0)
535                        - NumCast::from(points[[j, k]]).unwrap_or(0.0);
536                    d_sq += diff * diff;
537                }
538                let d = d_sq.sqrt();
539                dist[[i, j]] = d;
540                dist[[j, i]] = d;
541            }
542        }
543
544        Ok(dist)
545    }
546
547    /// Build all simplices in the Vietoris-Rips filtration up to max_dim+1 skeleton
548    fn build_filtration(
549        dist: &Array2<f64>,
550        max_dim: usize,
551        max_radius: f64,
552    ) -> Vec<FilteredSimplex> {
553        let n = dist.nrows();
554        let max_diam = 2.0 * max_radius;
555        let mut simplices = Vec::new();
556
557        // Add vertices (0-simplices)
558        for i in 0..n {
559            simplices.push(FilteredSimplex::new(vec![i], 0.0));
560        }
561
562        // Iteratively build higher-dimensional simplices
563        // Use clique-finding: a k-simplex [v0,...,vk] is in the filtration iff
564        // all pairwise distances <= max_diam; filtration value = max pairwise distance
565
566        // Start with edges (1-simplices) and build cliques
567        let mut prev_dim_simplices: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
568
569        for dim in 1..=(max_dim + 1) {
570            let mut next_dim_simplices: Vec<Vec<usize>> = Vec::new();
571
572            // Extend each (dim-1)-simplex by adding a vertex with larger index
573            // than all current vertices (to avoid duplicates)
574            for simplex in &prev_dim_simplices {
575                let last_vertex = *simplex.last().unwrap_or(&0);
576
577                for v in (last_vertex + 1)..n {
578                    // Check if v can be added (all distances to current vertices <= max_diam)
579                    let max_dist_to_v =
580                        simplex.iter().map(|&u| dist[[u, v]]).fold(0.0f64, f64::max);
581
582                    if max_dist_to_v <= max_diam {
583                        let mut new_simplex = simplex.clone();
584                        new_simplex.push(v);
585
586                        // Filtration value = diameter of simplex = max pairwise distance
587                        let filtration_val = Self::simplex_diameter(&new_simplex, dist);
588                        simplices.push(FilteredSimplex::new(new_simplex.clone(), filtration_val));
589                        next_dim_simplices.push(new_simplex);
590                    }
591                }
592            }
593
594            if next_dim_simplices.is_empty() {
595                break;
596            }
597            prev_dim_simplices = next_dim_simplices;
598        }
599
600        simplices
601    }
602
603    /// Compute the diameter (max pairwise distance) of a simplex
604    fn simplex_diameter(vertices: &[usize], dist: &Array2<f64>) -> f64 {
605        let mut max_d = 0.0f64;
606        for i in 0..vertices.len() {
607            for j in (i + 1)..vertices.len() {
608                let d = dist[[vertices[i], vertices[j]]];
609                if d > max_d {
610                    max_d = d;
611                }
612            }
613        }
614        max_d
615    }
616}
617
618// ─── Persistence Image ────────────────────────────────────────────────────────
619
620/// Persistence image: a stable vector representation of persistence diagrams.
621///
622/// Maps persistence diagrams to a 2D grid image by placing a Gaussian kernel
623/// centered at each persistence point (birth, persistence) and integrating
624/// over grid cells.
625///
626/// # References
627/// Adams, H., et al. (2017). Persistence images: A stable vector representation
628/// of persistent homology. JMLR, 18(8), 1-35.
629pub struct PersistenceImage {
630    /// Image resolution (resolution × resolution pixels)
631    resolution: usize,
632    /// Birth axis range [min, max]
633    birth_range: (f64, f64),
634    /// Persistence axis range [min, max]
635    persistence_range: (f64, f64),
636    /// Gaussian kernel bandwidth (sigma)
637    sigma: f64,
638    /// Weight function applied to each point
639    weight_type: PersistenceWeight,
640    /// Homological dimension to use
641    dimension: usize,
642}
643
644/// Weight function for persistence image computation
645#[derive(Debug, Clone)]
646pub enum PersistenceWeight {
647    /// Uniform weight (all points weighted equally)
648    Uniform,
649    /// Linear weight: w(b, p) = p (favors high persistence)
650    Linear,
651    /// Arctan weight: w(b, p) = arctan(p) (smooth truncation)
652    Arctan,
653    /// Custom weight based on persistence threshold
654    Threshold(f64),
655}
656
657impl PersistenceImage {
658    /// Create a new PersistenceImage computer
659    ///
660    /// # Arguments
661    /// * `resolution` - Grid resolution (resolution × resolution)
662    /// * `dimension` - Homological dimension to use
663    /// * `sigma` - Gaussian kernel bandwidth
664    /// * `weight_type` - Weight function
665    pub fn new(
666        resolution: usize,
667        dimension: usize,
668        sigma: f64,
669        weight_type: PersistenceWeight,
670    ) -> Result<Self> {
671        if resolution == 0 {
672            return Err(TransformError::InvalidInput(
673                "Resolution must be positive".to_string(),
674            ));
675        }
676        if sigma <= 0.0 {
677            return Err(TransformError::InvalidInput(
678                "Sigma must be positive".to_string(),
679            ));
680        }
681        Ok(Self {
682            resolution,
683            birth_range: (0.0, 1.0),
684            persistence_range: (0.0, 1.0),
685            sigma,
686            weight_type,
687            dimension,
688        })
689    }
690
691    /// Compute a persistence image from a persistence diagram
692    ///
693    /// # Arguments
694    /// * `diagram` - The persistence diagram to vectorize
695    /// * `resolution` - Grid resolution
696    ///
697    /// # Returns
698    /// * A (resolution × resolution) array representing the persistence image
699    pub fn compute(diagram: &PersistenceDiagram, resolution: usize) -> Result<Array2<f64>> {
700        if resolution == 0 {
701            return Err(TransformError::InvalidInput(
702                "Resolution must be positive".to_string(),
703            ));
704        }
705
706        let img = PersistenceImage::new(resolution, 0, 0.1, PersistenceWeight::Linear)?;
707        img.transform(diagram)
708    }
709
710    /// Transform a diagram using this image's configuration
711    pub fn transform(&self, diagram: &PersistenceDiagram) -> Result<Array2<f64>> {
712        // Collect finite points in the target dimension
713        let pts: Vec<(f64, f64)> = diagram
714            .points
715            .iter()
716            .filter(|p| p.dimension == self.dimension && !p.is_essential())
717            .map(|p| (p.birth, p.persistence()))
718            .collect();
719
720        if pts.is_empty() {
721            return Ok(Array2::zeros((self.resolution, self.resolution)));
722        }
723
724        // Determine range from data if not set (auto-range)
725        let b_min = self.birth_range.0;
726        let b_max = self
727            .birth_range
728            .1
729            .max(pts.iter().map(|(b, _)| *b).fold(0.0_f64, f64::max));
730        let p_min = self.persistence_range.0;
731        let p_max = self
732            .persistence_range
733            .1
734            .max(pts.iter().map(|(_, p)| *p).fold(0.0_f64, f64::max));
735
736        let b_range = (b_max - b_min).max(1e-10);
737        let p_range = (p_max - p_min).max(1e-10);
738        let cell_size_b = b_range / self.resolution as f64;
739        let cell_size_p = p_range / self.resolution as f64;
740
741        let mut image = Array2::<f64>::zeros((self.resolution, self.resolution));
742        let norm_factor = 1.0 / (2.0 * std::f64::consts::PI * self.sigma * self.sigma);
743
744        for &(birth, pers) in &pts {
745            // Weight for this point
746            let weight = match &self.weight_type {
747                PersistenceWeight::Uniform => 1.0,
748                PersistenceWeight::Linear => pers,
749                PersistenceWeight::Arctan => pers.atan(),
750                PersistenceWeight::Threshold(t) => {
751                    if pers >= *t {
752                        1.0
753                    } else {
754                        pers / t
755                    }
756                }
757            };
758
759            // Add Gaussian kernel contribution to each grid cell
760            for i in 0..self.resolution {
761                let cell_b = b_min + (i as f64 + 0.5) * cell_size_b;
762                for j in 0..self.resolution {
763                    let cell_p = p_min + (j as f64 + 0.5) * cell_size_p;
764                    let db = (cell_b - birth) / self.sigma;
765                    let dp = (cell_p - pers) / self.sigma;
766                    let gauss = norm_factor * (-0.5 * (db * db + dp * dp)).exp();
767                    image[[i, j]] += weight * gauss * cell_size_b * cell_size_p;
768                }
769            }
770        }
771
772        Ok(image)
773    }
774
775    /// Set the birth range for the image
776    pub fn with_birth_range(mut self, min: f64, max: f64) -> Self {
777        self.birth_range = (min, max);
778        self
779    }
780
781    /// Set the persistence range for the image
782    pub fn with_persistence_range(mut self, min: f64, max: f64) -> Self {
783        self.persistence_range = (min, max);
784        self
785    }
786}
787
788// ─── Distances Between Diagrams ───────────────────────────────────────────────
789
790/// Compute the bottleneck distance between two persistence diagrams.
791///
792/// The bottleneck distance measures the maximum displacement when matching
793/// points between two diagrams optimally (each point can also be matched
794/// to the diagonal).
795///
796/// # Arguments
797/// * `d1` - First persistence diagram
798/// * `d2` - Second persistence diagram
799///
800/// # Returns
801/// * The bottleneck distance (non-negative)
802pub fn bottleneck_distance(d1: &PersistenceDiagram, d2: &PersistenceDiagram) -> f64 {
803    // Get finite points from both diagrams (all dimensions)
804    let pts1: Vec<(f64, f64)> = d1
805        .points
806        .iter()
807        .filter(|p| !p.is_essential())
808        .map(|p| (p.birth, p.death))
809        .collect();
810
811    let pts2: Vec<(f64, f64)> = d2
812        .points
813        .iter()
814        .filter(|p| !p.is_essential())
815        .map(|p| (p.birth, p.death))
816        .collect();
817
818    bottleneck_distance_between(&pts1, &pts2)
819}
820
821/// Compute the bottleneck distance between two persistence diagrams for a specific dimension.
822pub fn bottleneck_distance_dim(
823    d1: &PersistenceDiagram,
824    d2: &PersistenceDiagram,
825    dim: usize,
826) -> f64 {
827    let pts1: Vec<(f64, f64)> = d1
828        .points
829        .iter()
830        .filter(|p| p.dimension == dim && !p.is_essential())
831        .map(|p| (p.birth, p.death))
832        .collect();
833
834    let pts2: Vec<(f64, f64)> = d2
835        .points
836        .iter()
837        .filter(|p| p.dimension == dim && !p.is_essential())
838        .map(|p| (p.birth, p.death))
839        .collect();
840
841    bottleneck_distance_between(&pts1, &pts2)
842}
843
844/// Compute the Wasserstein-p distance between two persistence diagrams.
845///
846/// # Arguments
847/// * `d1` - First persistence diagram
848/// * `d2` - Second persistence diagram
849/// * `p` - Wasserstein power (typically 1 or 2)
850///
851/// # Returns
852/// * The Wasserstein-p distance
853pub fn wasserstein_distance(d1: &PersistenceDiagram, d2: &PersistenceDiagram, p: f64) -> f64 {
854    let pts1: Vec<(f64, f64)> = d1
855        .points
856        .iter()
857        .filter(|p| !p.is_essential())
858        .map(|pt| (pt.birth, pt.death))
859        .collect();
860
861    let pts2: Vec<(f64, f64)> = d2
862        .points
863        .iter()
864        .filter(|p| !p.is_essential())
865        .map(|pt| (pt.birth, pt.death))
866        .collect();
867
868    wasserstein_distance_between(&pts1, &pts2, p)
869}
870
871/// Internal: compute bottleneck distance between two point sets with diagonal projections.
872/// Uses a binary search + bipartite matching approach.
873fn bottleneck_distance_between(pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
874    // Distance from a diagram point (b,d) to the diagonal
875    let diag_dist = |(b, d): (f64, f64)| -> f64 { (d - b) / 2.0 };
876
877    // L∞ distance between two diagram points
878    let point_dist = |(b1, d1): (f64, f64), (b2, d2): (f64, f64)| -> f64 {
879        (b1 - b2).abs().max((d1 - d2).abs())
880    };
881
882    // If both are empty, distance is 0
883    if pts1.is_empty() && pts2.is_empty() {
884        return 0.0;
885    }
886
887    // Collect all candidate distances for binary search
888    let mut candidates = Vec::new();
889
890    for &p1 in pts1 {
891        for &p2 in pts2 {
892            candidates.push(point_dist(p1, p2));
893        }
894        candidates.push(diag_dist(p1));
895    }
896    for &p2 in pts2 {
897        candidates.push(diag_dist(p2));
898    }
899    candidates.push(0.0);
900
901    // Sort and deduplicate candidates
902    candidates.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
903    candidates.dedup_by(|a, b| (*a - *b).abs() < 1e-14);
904
905    // Binary search for minimum bottleneck distance using hopcroft-karp style feasibility
906    let mut lo = 0;
907    let mut hi = candidates.len().saturating_sub(1);
908    let mut result = candidates.last().copied().unwrap_or(0.0);
909
910    while lo <= hi {
911        let mid = (lo + hi) / 2;
912        let delta = candidates[mid];
913
914        if is_feasible_bottleneck(pts1, pts2, delta) {
915            result = delta;
916            if mid == 0 {
917                break;
918            }
919            hi = mid - 1;
920        } else {
921            lo = mid + 1;
922        }
923    }
924
925    result
926}
927
/// Check whether the bottleneck distance between two finite point sets is at most `delta`.
///
/// This holds exactly when there is a partial matching between `pts1` and `pts2`
/// where every matched pair is within `delta` in the L∞ metric and every unmatched
/// point is within `delta` of the diagonal (L∞ distance `(death - birth) / 2`).
///
/// The question is reduced to a perfect-matching problem on a square bipartite graph
/// with `n + m` vertices per side:
/// - left: the `n` points of `pts1`, then `m` diagonal slots that absorb `pts2`;
/// - right: the `m` points of `pts2`, then `n` diagonal slots that absorb `pts1`.
///
/// Diagonal slots are interchangeable, so a real point is adjacent to every slot on
/// the other side when it is close enough to the diagonal, and slot–slot pairs are
/// always adjacent. A perfect matching exists iff the partial matching above exists.
///
/// Augmenting paths are found with breadth-first search (Kuhn's algorithm). Unlike a
/// greedy first-fit assignment, this can re-route earlier matches, so the answer is
/// exact. Runs in `O((n + m)^3)` time.
fn is_feasible_bottleneck(pts1: &[(f64, f64)], pts2: &[(f64, f64)], delta: f64) -> bool {
    let diag_dist = |(b, d): (f64, f64)| -> f64 { (d - b) / 2.0 };
    let point_dist = |(b1, d1): (f64, f64), (b2, d2): (f64, f64)| -> f64 {
        (b1 - b2).abs().max((d1 - d2).abs())
    };

    let n = pts1.len();
    let m = pts2.len();
    let size = n + m;

    // Edge test in the augmented graph described above.
    let adjacent = |left: usize, right: usize| -> bool {
        match (left < n, right < m) {
            (true, true) => point_dist(pts1[left], pts2[right]) <= delta,
            (true, false) => diag_dist(pts1[left]) <= delta,
            (false, true) => diag_dist(pts2[right]) <= delta,
            (false, false) => true,
        }
    };

    let mut left_match: Vec<Option<usize>> = vec![None; size];
    let mut right_match: Vec<Option<usize>> = vec![None; size];
    // For each visited right vertex, the left vertex it was reached from in the BFS.
    let mut reached_from = vec![0usize; size];
    let mut visited = vec![false; size];
    let mut queue: Vec<usize> = Vec::with_capacity(size);

    for start in 0..size {
        visited.fill(false);
        queue.clear();
        queue.push(start);
        let mut head = 0;
        let mut free_right = None;

        // BFS over alternating paths until an unmatched right vertex is found.
        'search: while head < queue.len() {
            let left = queue[head];
            head += 1;
            for right in 0..size {
                if visited[right] || !adjacent(left, right) {
                    continue;
                }
                visited[right] = true;
                reached_from[right] = left;
                match right_match[right] {
                    None => {
                        free_right = Some(right);
                        break 'search;
                    }
                    Some(next_left) => queue.push(next_left),
                }
            }
        }

        // No augmenting path from `start` now means none later either (Kuhn), so the
        // maximum matching cannot be perfect.
        let mut right = match free_right {
            Some(r) => r,
            None => return false,
        };

        // Flip matched/unmatched edges along the path back to `start`, which is the
        // only left vertex on the path without a previous partner.
        loop {
            let left = reached_from[right];
            let previous = left_match[left];
            left_match[left] = Some(right);
            right_match[right] = Some(left);
            match previous {
                Some(prev_right) => right = prev_right,
                None => break,
            }
        }
    }

    true
}
976
/// Internal: compute the Wasserstein-p distance between two finite point sets.
///
/// Each point is either matched to a point of the other set, with cost equal to the
/// L∞ distance raised to `p`, or sent to the diagonal, with cost
/// `((death - birth) / 2)^p`. The result is `(minimum total cost)^(1/p)`.
///
/// The optimal matching is found exactly by reducing to a square `(n + m) × (n + m)`
/// assignment problem solved with the Hungarian algorithm, in `O((n + m)^3)` time.
/// `p` is expected to be ≥ 1 and is not validated here.
///
/// Returns `NaN` if any cost is NaN, and `+∞` if any cost is infinite (e.g. the
/// `p`-th power overflows).
fn wasserstein_distance_between(pts1: &[(f64, f64)], pts2: &[(f64, f64)], p: f64) -> f64 {
    let diag_cost = |(b, d): (f64, f64)| -> f64 { ((d - b) / 2.0).powf(p) };
    let pair_cost = |(b1, d1): (f64, f64), (b2, d2): (f64, f64)| -> f64 {
        // Use L∞ for Wasserstein as is standard in TDA
        (b1 - b2).abs().max((d1 - d2).abs()).powf(p)
    };

    let n = pts1.len();
    let m = pts2.len();
    let size = n + m;

    // Rows: pts1, then m diagonal slots that absorb unmatched pts2.
    // Columns: pts2, then n diagonal slots that absorb unmatched pts1.
    // Diagonal slots are interchangeable, so a real point may take any slot on the
    // other side at its own diagonal cost, and slot-to-slot pairs cost 0.
    let mut cost = vec![vec![0.0f64; size]; size];
    for (i, &a) in pts1.iter().enumerate() {
        for (j, &b) in pts2.iter().enumerate() {
            cost[i][j] = pair_cost(a, b);
        }
        let to_diag = diag_cost(a);
        for slot in cost[i][m..].iter_mut() {
            *slot = to_diag;
        }
    }
    for (j, &b) in pts2.iter().enumerate() {
        let to_diag = diag_cost(b);
        for row in cost[n..].iter_mut() {
            row[j] = to_diag;
        }
    }

    // The Hungarian algorithm below needs finite costs.
    if cost.iter().flatten().any(|c| c.is_nan()) {
        return f64::NAN;
    }
    if cost.iter().flatten().any(|c| c.is_infinite()) {
        return f64::INFINITY;
    }

    hungarian_min_cost(&cost).powf(1.0 / p)
}

/// Solve a square min-cost assignment problem with the Hungarian algorithm
/// (shortest augmenting paths with row/column potentials) and return the optimal
/// total cost. `cost` must be square with all entries finite; an empty matrix costs 0.
fn hungarian_min_cost(cost: &[Vec<f64>]) -> f64 {
    let size = cost.len();
    if size == 0 {
        return 0.0;
    }

    // Indices are 1-based; column 0 is a virtual column used to seed each augmentation.
    let mut u = vec![0.0f64; size + 1]; // row potentials
    let mut v = vec![0.0f64; size + 1]; // column potentials
    let mut row_of_col = vec![0usize; size + 1]; // 0 = column unassigned
    let mut way = vec![0usize; size + 1]; // predecessor column on the shortest path

    for row in 1..=size {
        row_of_col[0] = row;
        let mut col0 = 0usize;
        let mut min_slack = vec![f64::INFINITY; size + 1];
        let mut used = vec![false; size + 1];

        loop {
            used[col0] = true;
            let row0 = row_of_col[col0];
            let mut delta = f64::INFINITY;
            let mut col1 = 0usize;
            for col in 1..=size {
                if used[col] {
                    continue;
                }
                let reduced = cost[row0 - 1][col - 1] - u[row0] - v[col];
                if reduced < min_slack[col] {
                    min_slack[col] = reduced;
                    way[col] = col0;
                }
                if min_slack[col] < delta {
                    delta = min_slack[col];
                    col1 = col;
                }
            }
            for col in 0..=size {
                if used[col] {
                    u[row_of_col[col]] += delta;
                    v[col] -= delta;
                } else {
                    min_slack[col] -= delta;
                }
            }
            col0 = col1;
            if row_of_col[col0] == 0 {
                break;
            }
        }

        // Flip the augmenting path back to the virtual column.
        loop {
            let prev = way[col0];
            row_of_col[col0] = row_of_col[prev];
            col0 = prev;
            if col0 == 0 {
                break;
            }
        }
    }

    (1..=size)
        .map(|col| cost[row_of_col[col] - 1][col - 1])
        .sum()
}
1026
1027// ─── Persistence Landscapes ───────────────────────────────────────────────────
1028
/// Persistence landscape: a functional summary of persistence diagrams.
///
/// The k-th landscape function λ_k(t) is the k-th largest "tent function" value at t,
/// taken over the points of one homological dimension. This type stores the
/// landscapes sampled on a discrete grid, which makes them directly usable as
/// fixed-length feature vectors.
#[derive(Debug, Clone)]
pub struct PersistenceLandscape {
    /// Number of landscape functions stored (the number of rows of `landscapes`)
    n_landscapes: usize,
    /// Homological dimension whose diagram points were used
    dimension: usize,
    /// Sampled landscape values: row k holds λ_{k+1} evaluated at each entry of
    /// `grid` (shape `(n_landscapes, n_grid_points)` when built by `compute`)
    pub landscapes: Array2<f64>,
    /// Grid points (t values) at which the landscapes are sampled; evenly spaced,
    /// with the range chosen by `compute` from the diagram's (birth, death) values
    pub grid: Array1<f64>,
}
1044
1045impl PersistenceLandscape {
1046    /// Compute persistence landscapes from a diagram
1047    ///
1048    /// # Arguments
1049    /// * `diagram` - The persistence diagram
1050    /// * `n_landscapes` - Number of landscape functions (k = 1, ..., n_landscapes)
1051    /// * `n_grid_points` - Number of grid points for sampling
1052    /// * `dimension` - Homological dimension
1053    ///
1054    /// # Returns
1055    /// * A PersistenceLandscape with sampled landscape values
1056    pub fn compute(
1057        diagram: &PersistenceDiagram,
1058        n_landscapes: usize,
1059        n_grid_points: usize,
1060        dimension: usize,
1061    ) -> Result<Self> {
1062        if n_landscapes == 0 {
1063            return Err(TransformError::InvalidInput(
1064                "n_landscapes must be positive".to_string(),
1065            ));
1066        }
1067        if n_grid_points < 2 {
1068            return Err(TransformError::InvalidInput(
1069                "n_grid_points must be at least 2".to_string(),
1070            ));
1071        }
1072
1073        let pts: Vec<(f64, f64)> = diagram
1074            .points
1075            .iter()
1076            .filter(|p| p.dimension == dimension && !p.is_essential())
1077            .map(|p| (p.birth, p.death))
1078            .collect();
1079
1080        if pts.is_empty() {
1081            let grid = Array1::linspace(0.0, 1.0, n_grid_points);
1082            return Ok(Self {
1083                n_landscapes,
1084                dimension,
1085                landscapes: Array2::zeros((n_landscapes, n_grid_points)),
1086                grid,
1087            });
1088        }
1089
1090        // Determine grid range
1091        let t_min = pts.iter().map(|(b, _)| *b).fold(f64::INFINITY, f64::min);
1092        let t_max = pts.iter().map(|(_, d)| *d).fold(0.0_f64, f64::max);
1093        let grid = Array1::linspace(t_min, t_max, n_grid_points);
1094
1095        let mut landscapes = Array2::<f64>::zeros((n_landscapes, n_grid_points));
1096
1097        for (g_idx, &t) in grid.iter().enumerate() {
1098            // Compute tent function values at t for all points
1099            let mut tent_values: Vec<f64> = pts
1100                .iter()
1101                .map(|&(b, d)| {
1102                    if t <= (b + d) / 2.0 {
1103                        (t - b).max(0.0)
1104                    } else {
1105                        (d - t).max(0.0)
1106                    }
1107                })
1108                .collect();
1109
1110            // Sort descending to get the k-th largest values
1111            tent_values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1112
1113            for k in 0..n_landscapes {
1114                landscapes[[k, g_idx]] = tent_values.get(k).copied().unwrap_or(0.0);
1115            }
1116        }
1117
1118        Ok(Self {
1119            n_landscapes,
1120            dimension,
1121            landscapes,
1122            grid,
1123        })
1124    }
1125
1126    /// Compute the L2 norm of the k-th landscape function
1127    pub fn l2_norm(&self, k: usize) -> f64 {
1128        if k >= self.n_landscapes {
1129            return 0.0;
1130        }
1131        let row = self.landscapes.row(k);
1132        row.iter().map(|&v| v * v).sum::<f64>().sqrt()
1133    }
1134
1135    /// Inner product between two landscape functions
1136    pub fn inner_product(&self, other: &Self) -> f64 {
1137        let n = self.landscapes.shape()[1].min(other.landscapes.shape()[1]);
1138        let k = self.n_landscapes.min(other.n_landscapes);
1139        let mut sum = 0.0;
1140        for i in 0..k {
1141            for j in 0..n {
1142                sum += self.landscapes[[i, j]] * other.landscapes[[i, j]];
1143            }
1144        }
1145        sum
1146    }
1147}
1148
1149// ─── Tests ────────────────────────────────────────────────────────────────────
1150
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// The four corners of the unit square, one point per row.
    fn square_points() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0])
            .expect("shape ok")
    }

    #[test]
    fn test_vietoris_rips_h0() {
        let pts = square_points();
        let diagram = VietorisRips::compute(&pts, 0, 2.0).expect("vr compute");
        // At H0 dimension we expect connected components
        let h0_pts = diagram.points_in_dimension(0);
        assert!(!h0_pts.is_empty(), "Should have H0 features");
    }

    #[test]
    fn test_vietoris_rips_h1() {
        let pts = square_points();
        let diagram = VietorisRips::compute(&pts, 1, 2.0).expect("vr compute");
        // The exact H1 content depends on the filtration; only check that the
        // computation produced features (at least H0 features should be present).
        assert!(!diagram.is_empty(), "Diagram should not be empty");
    }

    #[test]
    fn test_persistence_point_persistence() {
        let p = PersistencePoint::new(0.5, 1.5, 0);
        assert!((p.persistence() - 1.0).abs() < 1e-10);
        assert!(!p.is_essential());

        let q = PersistencePoint::new(0.5, f64::INFINITY, 0);
        assert!(q.is_essential());
        assert!(q.persistence().is_infinite());
    }

    #[test]
    fn test_persistence_diagram_filter() {
        let mut diagram = PersistenceDiagram::new(1);
        diagram.add_point(0.0, 0.01, 0); // short-lived (noise)
        diagram.add_point(0.0, 1.0, 0); // long-lived (signal)
        diagram.add_point(0.2, 0.8, 1); // H1 feature

        let filtered = diagram.filter_by_persistence(0.5);
        assert_eq!(filtered.len(), 2); // only the two long-lived features
    }

    #[test]
    fn test_barcode_from_diagram() {
        let mut diagram = PersistenceDiagram::new(1);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.5, 0.9, 1);

        let barcode = diagram.to_barcode();
        assert_eq!(barcode.len(), 2);
        assert_eq!(barcode.intervals_in_dimension(0).len(), 1);
        assert_eq!(barcode.intervals_in_dimension(1).len(), 1);
        assert!((barcode.intervals_in_dimension(0)[0].length() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_persistence_image() {
        let mut diagram = PersistenceDiagram::new(1);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.2, 0.8, 0);

        let image = PersistenceImage::compute(&diagram, 10).expect("pi compute");
        assert_eq!(image.shape(), &[10, 10]);
        // Image should have non-negative values
        assert!(image.iter().all(|&v| v >= 0.0));
        // Image should have some non-zero content
        assert!(image.iter().any(|&v| v > 0.0));
    }

    #[test]
    fn test_bottleneck_distance_same_diagram() {
        let mut diagram = PersistenceDiagram::new(0);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.5, 0.9, 0);

        // Bottleneck distance with itself should be ~0
        let dist = bottleneck_distance(&diagram, &diagram);
        assert!(dist < 1e-10, "Self-distance should be 0, got {}", dist);
    }

    #[test]
    fn test_bottleneck_distance_empty_diagrams() {
        let d1 = PersistenceDiagram::new(0);
        let d2 = PersistenceDiagram::new(0);
        let dist = bottleneck_distance(&d1, &d2);
        assert!(dist < 1e-10);
    }

    #[test]
    fn test_bottleneck_distance_different_diagrams() {
        let mut d1 = PersistenceDiagram::new(0);
        d1.add_point(0.0, 1.0, 0);

        let mut d2 = PersistenceDiagram::new(0);
        d2.add_point(0.0, 0.5, 0);

        let dist = bottleneck_distance(&d1, &d2);
        // Matching (0,1) with (0,0.5) costs max(|0-0|, |1-0.5|) = 0.5.
        // Sending both to the diagonal costs max(0.5, 0.25) = 0.5.
        // Either way the optimum is 0.5.
        assert!(
            (dist - 0.5).abs() < 1e-10,
            "Expected bottleneck distance 0.5, got {}",
            dist
        );
    }

    #[test]
    fn test_persistence_landscape() {
        let mut diagram = PersistenceDiagram::new(0);
        diagram.add_point(0.0, 2.0, 0);
        diagram.add_point(0.5, 1.5, 0);

        let landscape =
            PersistenceLandscape::compute(&diagram, 2, 20, 0).expect("landscape compute");
        assert_eq!(landscape.landscapes.shape(), &[2, 20]);
        // First landscape function should be non-negative
        assert!(landscape.landscapes.row(0).iter().all(|&v| v >= -1e-10));
        assert!(landscape.l2_norm(0) > 0.0);

        // The inner product of a landscape with itself equals the sum of the squared
        // norms of its rows.
        let self_ip = landscape.inner_product(&landscape);
        let norms_sq = landscape.l2_norm(0).powi(2) + landscape.l2_norm(1).powi(2);
        assert!((self_ip - norms_sq).abs() < 1e-9);
    }

    #[test]
    fn test_wasserstein_distance() {
        let mut d1 = PersistenceDiagram::new(0);
        d1.add_point(0.0, 1.0, 0);

        let mut d2 = PersistenceDiagram::new(0);
        d2.add_point(0.0, 1.0, 0);

        // Wasserstein distance of identical diagrams should be 0
        let wd = wasserstein_distance(&d1, &d2, 1.0);
        assert!(wd < 1e-10, "Identical diagrams: W=0, got {}", wd);

        // Against an empty diagram, the only option is the diagonal: (1 - 0) / 2 = 0.5
        let empty = PersistenceDiagram::new(0);
        let wd_empty = wasserstein_distance(&d1, &empty, 1.0);
        assert!(
            (wd_empty - 0.5).abs() < 1e-10,
            "Single point vs empty: W=0.5, got {}",
            wd_empty
        );
    }

    #[test]
    fn test_betti_numbers() {
        let mut diagram = PersistenceDiagram::new(1);
        diagram.add_point(0.0, f64::INFINITY, 0); // one component throughout
        diagram.add_point(0.3, 0.7, 1); // loop from t=0.3 to t=0.7

        let betti = diagram.betti_numbers_at(0.5);
        assert_eq!(betti[0], 1); // one connected component
        assert_eq!(betti[1], 1); // one loop active at t=0.5

        let betti_early = diagram.betti_numbers_at(0.1);
        assert_eq!(betti_early[1], 0); // loop not yet born
    }

    #[test]
    fn test_vietoris_rips_small_radius() {
        let pts = square_points();
        // With very small radius, no edges are formed so all 4 points are separate components
        let diagram = VietorisRips::compute(&pts, 0, 0.1).expect("vr compute");
        let h0_pts = diagram.points_in_dimension(0);
        // All 4 points should appear as H0 features (either finite or essential)
        assert!(!h0_pts.is_empty());
    }

    #[test]
    fn test_total_persistence() {
        let mut diagram = PersistenceDiagram::new(0);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.0, 3.0, 0);

        let tp = diagram.total_persistence(2.0);
        // Should be sqrt(1^2 + 3^2) = sqrt(10) ≈ 3.162
        assert!((tp - (10.0f64).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_persistence_image_custom() {
        let mut diagram = PersistenceDiagram::new(0);
        diagram.add_point(0.0, 1.0, 0);
        diagram.add_point(0.2, 0.8, 0);

        let img_computer = PersistenceImage::new(5, 0, 0.2, PersistenceWeight::Arctan)
            .expect("pi new")
            .with_birth_range(0.0, 1.0)
            .with_persistence_range(0.0, 1.0);

        let image = img_computer.transform(&diagram).expect("pi transform");
        assert_eq!(image.shape(), &[5, 5]);
        assert!(image.iter().all(|&v| v >= 0.0));
    }
}