// scirs2_transform/tda_ext.rs

//! Extended Topological Data Analysis utilities
//!
//! This module provides additional TDA constructs and free functions that
//! complement the core `tda` module:
//!
//! - [`VietorisRipsComplex`]: explicit simplex-list representation with Euler
//!   characteristic and simplex counting.
//! - [`compute_persistence`]: distance-matrix-based persistent homology.
//! - [`persistence_landscape_fn`]: evaluate persistence landscapes over a grid.
//! - [`persistence_image_fn`]: Gaussian-kernel persistence image from a diagram.
//! - `wasserstein_distance_p`: p-Wasserstein distance between diagrams.
//! - `bottleneck_distance_fn`: bottleneck distance as a free function.
//!
//! ## References
//!
//! - Edelsbrunner & Harer (2010). Computational Topology: An Introduction.
//! - Adams et al. (2017). Persistence Images: A Stable Vector Representation.
//! - Munch (2017). A User's Guide to Topological Data Analysis.

use crate::error::{Result, TransformError};
use crate::tda::{PersistenceDiagram, VietorisRips};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

// ─── VietorisRipsComplex ─────────────────────────────────────────────────────

/// A Vietoris-Rips simplicial complex at a fixed scale `epsilon`.
///
/// Unlike [`VietorisRips`] (which computes the full persistent homology
/// across all scales), this struct stores the explicit list of simplices
/// formed when all pairwise distances ≤ `epsilon`.
///
/// ## Example
///
/// ```rust
/// use scirs2_transform::tda_ext::VietorisRipsComplex;
///
/// let pts = vec![
///     vec![0.0, 0.0],
///     vec![1.0, 0.0],
///     vec![1.0, 1.0],
///     vec![0.0, 1.0],
/// ];
/// let vrc = VietorisRipsComplex::new(&pts, 1.5).expect("should succeed");
/// assert_eq!(vrc.n_simplices(0), 4); // four 0-simplices (vertices)
/// // Both diagonals (≈ 1.414) fit under 1.5, so the complex has
/// // 4 vertices, 6 edges and 4 triangles: χ = 4 − 6 + 4 = 2.
/// assert_eq!(vrc.euler_characteristic(), 2);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct VietorisRipsComplex {
    /// Input point cloud
    pub points: Vec<Vec<f64>>,
    /// Scale parameter (all edges with length ≤ epsilon are included)
    pub epsilon: f64,
    /// All simplices (sorted by dimension then by vertex tuple)
    pub simplices: Vec<Vec<usize>>,
}
57impl VietorisRipsComplex {
58    /// Construct the Vietoris-Rips complex for the given point cloud at scale `epsilon`.
59    ///
60    /// Only simplices up to dimension 2 (triangles) are computed for tractability.
61    ///
62    /// # Arguments
63    /// * `points`  — slice of point vectors (all of equal length)
64    /// * `epsilon` — maximum edge length to include
65    pub fn new(points: &[Vec<f64>], epsilon: f64) -> Result<Self> {
66        if points.is_empty() {
67            return Ok(Self {
68                points: Vec::new(),
69                epsilon,
70                simplices: Vec::new(),
71            });
72        }
73        if epsilon < 0.0 {
74            return Err(TransformError::InvalidInput(
75                "epsilon must be non-negative".to_string(),
76            ));
77        }
78        let n = points.len();
79        let dim = points[0].len();
80
81        // Pairwise distances
82        let dist = pairwise_distances(points, dim);
83
84        let mut simplices: Vec<Vec<usize>> = Vec::new();
85
86        // 0-simplices: all vertices
87        for i in 0..n {
88            simplices.push(vec![i]);
89        }
90
91        // 1-simplices: all edges with dist <= epsilon
92        for i in 0..n {
93            for j in (i + 1)..n {
94                if dist[i][j] <= epsilon {
95                    simplices.push(vec![i, j]);
96                }
97            }
98        }
99
100        // 2-simplices: triangles (all edges present)
101        for i in 0..n {
102            for j in (i + 1)..n {
103                if dist[i][j] > epsilon {
104                    continue;
105                }
106                for k in (j + 1)..n {
107                    if dist[i][k] <= epsilon && dist[j][k] <= epsilon {
108                        simplices.push(vec![i, j, k]);
109                    }
110                }
111            }
112        }
113
114        // Sort: by dimension first, then lexicographically within each dimension
115        simplices.sort_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)));
116
117        Ok(Self {
118            points: points.to_vec(),
119            epsilon,
120            simplices,
121        })
122    }
123
124    /// Count the number of simplices of a given dimension.
125    ///
126    /// Dimension 0 = vertices, 1 = edges, 2 = triangles.
127    pub fn n_simplices(&self, dim: usize) -> usize {
128        self.simplices.iter().filter(|s| s.len() == dim + 1).count()
129    }
130
131    /// Compute the Euler characteristic χ = Σ_k (-1)^k * C_k,
132    /// where C_k is the number of k-simplices.
133    pub fn euler_characteristic(&self) -> i64 {
134        let mut chi = 0_i64;
135        for simplex in &self.simplices {
136            let k = simplex.len() as i64 - 1;
137            if k % 2 == 0 {
138                chi += 1;
139            } else {
140                chi -= 1;
141            }
142        }
143        chi
144    }
145
146    /// Check whether two vertices are connected by an edge in the complex.
147    pub fn are_connected(&self, u: usize, v: usize) -> bool {
148        let edge = if u < v { vec![u, v] } else { vec![v, u] };
149        self.simplices.contains(&edge)
150    }
151
152    /// List all edges (1-simplices) as pairs (u, v).
153    pub fn edges(&self) -> Vec<(usize, usize)> {
154        self.simplices
155            .iter()
156            .filter(|s| s.len() == 2)
157            .map(|s| (s[0], s[1]))
158            .collect()
159    }
160}

// ─── compute_persistence (distance matrix API) ────────────────────────────────

164/// Compute persistent homology from a precomputed distance matrix.
165///
166/// Returns one [`PersistenceDiagram`] per homological dimension (H0, H1, …,
167/// up to `max_dim`).  The filtration is the Vietoris-Rips filtration: a
168/// simplex enters at the maximum pairwise distance among its vertices.
169///
170/// This is equivalent to constructing a nested family of Vietoris-Rips
171/// complexes parameterised by `epsilon ∈ [0, max_epsilon]`.
172///
173/// # Arguments
174/// * `distance_matrix` — symmetric n×n matrix of pairwise distances
175/// * `max_dim`         — maximum homological dimension to compute (typically 1 or 2)
176/// * `max_epsilon`     — upper bound on the filtration parameter
177///
178/// # Example
179///
180/// ```rust
181/// use scirs2_transform::tda_ext::compute_persistence;
182///
183/// let dist = vec![
184///     vec![0.0, 1.0, 1.4, 1.0],
185///     vec![1.0, 0.0, 1.0, 1.4],
186///     vec![1.4, 1.0, 0.0, 1.0],
187///     vec![1.0, 1.4, 1.0, 0.0],
188/// ];
189/// let diagrams = compute_persistence(&dist, 1, 2.0).expect("should succeed");
190/// assert_eq!(diagrams.len(), 2); // H0 and H1
191/// assert!(!diagrams[0].is_empty()); // at least one H0 feature
192/// ```
193pub fn compute_persistence(
194    distance_matrix: &[Vec<f64>],
195    max_dim: usize,
196    max_epsilon: f64,
197) -> Result<Vec<PersistenceDiagram>> {
198    let n = distance_matrix.len();
199    if n == 0 {
200        // Return empty diagrams
201        return Ok((0..=max_dim).map(|d| PersistenceDiagram::new(d)).collect());
202    }
203
204    // Validate distance matrix
205    for row in distance_matrix {
206        if row.len() != n {
207            return Err(TransformError::InvalidInput(
208                "distance_matrix must be square".to_string(),
209            ));
210        }
211    }
212    if max_epsilon < 0.0 {
213        return Err(TransformError::InvalidInput(
214            "max_epsilon must be non-negative".to_string(),
215        ));
216    }
217
218    // Build points from distance matrix for the VietorisRips struct
219    // We use the ndarray interface that VietorisRips::compute expects.
220    // Convert the distance matrix to an Array2 of positions via MDS-like embedding
221    // (actually VietorisRips::compute accepts the data matrix, not distance).
222    // We use the approach of lifting points into n-dimensional space via
223    // the distance matrix's row vectors (approximate MDS, sufficient for filtration).
224    //
225    // However since VietorisRips::compute takes a data matrix and recomputes
226    // Euclidean distances, we need to provide point coordinates such that
227    // ||p_i - p_j|| = distance_matrix[i][j].
228    //
229    // For the general case we use the distance matrix directly to construct
230    // the filtration manually (boundary matrix reduction).
231
232    // Collect all unique pairwise distances (filtration values)
233    let mut filt_values: Vec<f64> = Vec::new();
234    for i in 0..n {
235        for j in (i + 1)..n {
236            let d = distance_matrix[i][j];
237            if d <= max_epsilon && d >= 0.0 {
238                filt_values.push(d);
239            }
240        }
241    }
242    filt_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
243    filt_values.dedup_by(|a, b| (*a - *b).abs() < 1e-15);
244
245    // Enumerate simplices with their filtration values
246    // 0-simplices: all vertices, filtration 0.0
247    // 1-simplices: edges, filtration = dist(i,j)
248    // 2-simplices: triangles, filtration = max edge length
249
250    #[derive(Clone)]
251    struct FiltSimplex {
252        vertices: Vec<usize>,
253        filtration: f64,
254    }
255
256    let mut simplices: Vec<FiltSimplex> = Vec::new();
257
258    // Vertices
259    for i in 0..n {
260        simplices.push(FiltSimplex {
261            vertices: vec![i],
262            filtration: 0.0,
263        });
264    }
265
266    // Edges
267    for i in 0..n {
268        for j in (i + 1)..n {
269            let d = distance_matrix[i][j];
270            if d <= max_epsilon {
271                simplices.push(FiltSimplex {
272                    vertices: vec![i, j],
273                    filtration: d,
274                });
275            }
276        }
277    }
278
279    // Triangles (for H1 / H2)
280    if max_dim >= 1 {
281        for i in 0..n {
282            for j in (i + 1)..n {
283                let d_ij = distance_matrix[i][j];
284                if d_ij > max_epsilon {
285                    continue;
286                }
287                for k in (j + 1)..n {
288                    let d_ik = distance_matrix[i][k];
289                    let d_jk = distance_matrix[j][k];
290                    if d_ik > max_epsilon || d_jk > max_epsilon {
291                        continue;
292                    }
293                    let max_d = d_ij.max(d_ik).max(d_jk);
294                    simplices.push(FiltSimplex {
295                        vertices: vec![i, j, k],
296                        filtration: max_d,
297                    });
298                }
299            }
300        }
301    }
302
303    // Tetrahedra (for H2 / H3)
304    if max_dim >= 2 {
305        for i in 0..n {
306            for j in (i + 1)..n {
307                let d_ij = distance_matrix[i][j];
308                if d_ij > max_epsilon {
309                    continue;
310                }
311                for k in (j + 1)..n {
312                    let d_ik = distance_matrix[i][k];
313                    let d_jk = distance_matrix[j][k];
314                    if d_ik > max_epsilon || d_jk > max_epsilon {
315                        continue;
316                    }
317                    for l in (k + 1)..n {
318                        let d_il = distance_matrix[i][l];
319                        let d_jl = distance_matrix[j][l];
320                        let d_kl = distance_matrix[k][l];
321                        if d_il > max_epsilon || d_jl > max_epsilon || d_kl > max_epsilon {
322                            continue;
323                        }
324                        let max_d = d_ij.max(d_ik).max(d_jk).max(d_il).max(d_jl).max(d_kl);
325                        simplices.push(FiltSimplex {
326                            vertices: vec![i, j, k, l],
327                            filtration: max_d,
328                        });
329                    }
330                }
331            }
332        }
333    }
334
335    // Sort by filtration value, then by dimension
336    simplices.sort_by(|a, b| {
337        a.filtration
338            .partial_cmp(&b.filtration)
339            .unwrap_or(std::cmp::Ordering::Equal)
340            .then_with(|| a.vertices.len().cmp(&b.vertices.len()))
341    });
342
343    // Index simplices for boundary matrix
344    let total = simplices.len();
345    let simplex_idx: std::collections::HashMap<Vec<usize>, usize> = simplices
346        .iter()
347        .enumerate()
348        .map(|(i, s)| (s.vertices.clone(), i))
349        .collect();
350
351    // Build boundary matrix columns as sorted lists of row indices (mod 2)
352    // col j = list of (index of) (dim-1)-faces of simplex j
353    let mut boundary: Vec<Vec<usize>> = vec![Vec::new(); total];
354    for (j, simp) in simplices.iter().enumerate() {
355        let d = simp.vertices.len();
356        if d <= 1 {
357            continue; // 0-simplex has empty boundary
358        }
359        // Faces: remove one vertex at a time
360        for omit in 0..d {
361            let face: Vec<usize> = simp
362                .vertices
363                .iter()
364                .enumerate()
365                .filter(|&(i, _)| i != omit)
366                .map(|(_, &v)| v)
367                .collect();
368            if let Some(&row_idx) = simplex_idx.get(&face) {
369                boundary[j].push(row_idx);
370            }
371        }
372        boundary[j].sort_unstable();
373    }
374
375    // Persistence pairing via standard reduction (column reduction over F_2)
376    // low[j] = lowest row index in column j (None if zero column)
377    let mut low: Vec<Option<usize>> = vec![None; total];
378    // pivot_col[r] = column that has lowest row index r
379    let mut pivot_col: Vec<Option<usize>> = vec![None; total];
380
381    for j in 0..total {
382        loop {
383            let lo = boundary[j].last().copied();
384            match lo {
385                None => break,
386                Some(r) => {
387                    if let Some(k) = pivot_col[r] {
388                        // Add column k to column j (mod 2)
389                        let bk = boundary[k].clone();
390                        sym_diff_inplace(&mut boundary[j], &bk);
391                    } else {
392                        low[j] = Some(r);
393                        pivot_col[r] = Some(j);
394                        break;
395                    }
396                }
397            }
398        }
399    }
400
401    // Extract persistence pairs
402    let mut diagrams: Vec<PersistenceDiagram> =
403        (0..=max_dim).map(|d| PersistenceDiagram::new(d)).collect();
404
405    let mut paired: Vec<bool> = vec![false; total];
406
407    for j in 0..total {
408        if let Some(r) = low[j] {
409            let birth = simplices[r].filtration;
410            let death = simplices[j].filtration;
411            // The dimension of the feature is dim(simplex r) = r's vertex count - 1
412            let feature_dim = simplices[r].vertices.len() - 1;
413            if feature_dim <= max_dim && (death - birth).abs() > 1e-15 {
414                diagrams[feature_dim].add_point(birth, death, feature_dim);
415            }
416            paired[r] = true;
417            paired[j] = true;
418        }
419    }
420
421    // Unpaired simplices → essential features
422    for i in 0..total {
423        if !paired[i] {
424            let dim = simplices[i].vertices.len() - 1;
425            if dim <= max_dim {
426                diagrams[dim].add_point(simplices[i].filtration, f64::INFINITY, dim);
427            }
428        }
429    }
430
431    Ok(diagrams)
432}

// ─── persistence_landscape free function ─────────────────────────────────────

436/// Compute the persistence landscape of a diagram, evaluated at points `x`.
437///
438/// The persistence landscape λ_k(t) is defined as:
439///   λ_k(t) = kth largest value of  min(t - b, d - t)⁺  over all (b, d) pairs.
440///
441/// # Arguments
442/// * `dgm`      — persistence diagram (only finite points are used)
443/// * `n_layers` — number of landscape layers to compute
444/// * `x`        — evaluation points (must be sorted in ascending order)
445///
446/// # Returns
447/// Matrix of shape (n_layers × len(x)), where entry \[k, i\] = λ_{k+1}(x\[i\]).
448///
449/// # Example
450///
451/// ```rust
452/// use scirs2_transform::tda_ext::{compute_persistence, persistence_landscape_fn};
453///
454/// let dist = vec![
455///     vec![0.0, 1.0, 1.4, 1.0],
456///     vec![1.0, 0.0, 1.0, 1.4],
457///     vec![1.4, 1.0, 0.0, 1.0],
458///     vec![1.0, 1.4, 1.0, 0.0],
459/// ];
460/// let diagrams = compute_persistence(&dist, 1, 2.0).expect("should succeed");
461/// let x: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
462/// let landscape = persistence_landscape_fn(&diagrams[0], 2, &x);
463/// assert_eq!(landscape.len(), 2);       // n_layers
464/// assert_eq!(landscape[0].len(), 20);   // len(x)
465/// ```
466pub fn persistence_landscape_fn(
467    dgm: &PersistenceDiagram,
468    n_layers: usize,
469    x: &[f64],
470) -> Vec<Vec<f64>> {
471    if n_layers == 0 || x.is_empty() {
472        return vec![vec![0.0; x.len()]; n_layers];
473    }
474
475    // Collect finite (b, d) pairs
476    let finite_pts: Vec<(f64, f64)> = dgm
477        .points
478        .iter()
479        .filter(|p| p.death.is_finite())
480        .map(|p| (p.birth, p.death))
481        .collect();
482
483    let nx = x.len();
484    // For each evaluation point t, compute tent values and keep top n_layers
485    let mut landscape = vec![vec![0.0_f64; nx]; n_layers];
486
487    for (ix, &t) in x.iter().enumerate() {
488        // Tent function value for each pair at t
489        let mut tents: Vec<f64> = finite_pts
490            .iter()
491            .map(|&(b, d)| {
492                let v = (t - b).min(d - t);
493                if v < 0.0 {
494                    0.0
495                } else {
496                    v
497                }
498            })
499            .collect();
500        // Sort descending and take top n_layers
501        tents.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
502        for k in 0..n_layers {
503            landscape[k][ix] = tents.get(k).copied().unwrap_or(0.0);
504        }
505    }
506
507    landscape
508}

// ─── persistence_image free function ─────────────────────────────────────────

512/// Compute a persistence image from a persistence diagram.
513///
514/// Maps each point (b, p) (where p = d − b is persistence) in the diagram to
515/// a 2D Gaussian kernel, then evaluates on a regular grid over
516/// [0, max_birth] × [0, max_persistence].
517///
518/// Points are weighted by their persistence (linear weighting function).
519///
520/// # Arguments
521/// * `dgm`             — source persistence diagram
522/// * `bandwidth`       — Gaussian kernel bandwidth (σ)
523/// * `grid`            — (n_rows, n_cols) resolution of the output image
524/// * `max_birth`       — upper bound of the birth axis
525/// * `max_persistence` — upper bound of the persistence axis
526///
527/// # Returns
528/// 2D grid of shape (n_rows × n_cols) as a Vec of rows.
529///
530/// # Example
531///
532/// ```rust
533/// use scirs2_transform::tda_ext::{compute_persistence, persistence_image_fn};
534///
535/// let dist = vec![
536///     vec![0.0, 1.0, 1.4, 1.0],
537///     vec![1.0, 0.0, 1.0, 1.4],
538///     vec![1.4, 1.0, 0.0, 1.0],
539///     vec![1.0, 1.4, 1.0, 0.0],
540/// ];
541/// let diagrams = compute_persistence(&dist, 0, 2.0).expect("should succeed");
542/// let img = persistence_image_fn(&diagrams[0], 0.1, (5, 5), 2.0, 2.0);
543/// assert_eq!(img.len(), 5);
544/// assert_eq!(img[0].len(), 5);
545/// ```
546pub fn persistence_image_fn(
547    dgm: &PersistenceDiagram,
548    bandwidth: f64,
549    grid: (usize, usize),
550    max_birth: f64,
551    max_persistence: f64,
552) -> Vec<Vec<f64>> {
553    let (n_rows, n_cols) = grid;
554    if n_rows == 0 || n_cols == 0 {
555        return vec![];
556    }
557
558    let bw = bandwidth.max(1e-10);
559    let two_bw_sq = 2.0 * bw * bw;
560    let norm_factor = 1.0 / (std::f64::consts::TAU * bw * bw);
561
562    // Grid cell centres
563    let row_centers: Vec<f64> = if n_rows == 1 {
564        vec![max_persistence * 0.5]
565    } else {
566        (0..n_rows)
567            .map(|i| max_persistence * i as f64 / (n_rows - 1) as f64)
568            .collect()
569    };
570    let col_centers: Vec<f64> = if n_cols == 1 {
571        vec![max_birth * 0.5]
572    } else {
573        (0..n_cols)
574            .map(|j| max_birth * j as f64 / (n_cols - 1) as f64)
575            .collect()
576    };
577
578    // Collect finite (birth, persistence) pairs with weight = persistence
579    let pts: Vec<(f64, f64, f64)> = dgm
580        .points
581        .iter()
582        .filter(|p| p.death.is_finite() && p.death > p.birth)
583        .map(|p| (p.birth, p.death - p.birth, p.death - p.birth)) // (b, pers, weight)
584        .collect();
585
586    let mut image = vec![vec![0.0_f64; n_cols]; n_rows];
587
588    for (r, &p_center) in row_centers.iter().enumerate() {
589        for (c, &b_center) in col_centers.iter().enumerate() {
590            let mut val = 0.0_f64;
591            for &(b, pers, weight) in &pts {
592                let db = b_center - b;
593                let dp = p_center - pers;
594                let exponent = -(db * db + dp * dp) / two_bw_sq;
595                val += weight * norm_factor * exponent.exp();
596            }
597            image[r][c] = val;
598        }
599    }
600
601    image
602}

// ─── Internal helpers ─────────────────────────────────────────────────────────

/// Compute n×n pairwise Euclidean distances.
///
/// Only the first `dim` coordinates of each point are used; if a point has
/// fewer coordinates than that, the comparison stops at the shorter of the
/// two points.  The result is symmetric with a zero diagonal.
fn pairwise_distances(points: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    let n = points.len();
    let mut dist = vec![vec![0.0_f64; n]; n];
    for (i, p) in points.iter().enumerate() {
        for (j, q) in points.iter().enumerate().skip(i + 1) {
            // `zip` ends at the shorter point and `take` caps at `dim`.
            let sq = p.iter().zip(q).take(dim).fold(0.0_f64, |acc, (a, b)| {
                let diff = a - b;
                acc + diff * diff
            });
            let d = sq.sqrt();
            dist[i][j] = d;
            dist[j][i] = d;
        }
    }
    dist
}
/// Symmetric difference of two sorted Vec<usize> in-place (mod-2 addition of boundary chains).
///
/// Both `a` and `b` are expected to be sorted ascending; elements present in
/// both cancel, and the remaining elements are written back to `a` in
/// ascending order.
fn sym_diff_inplace(a: &mut Vec<usize>, b: &[usize]) {
    use std::cmp::Ordering;

    let mut merged = Vec::with_capacity(a.len() + b.len());
    let (mut ai, mut bi) = (0_usize, 0_usize);
    while ai < a.len() && bi < b.len() {
        match a[ai].cmp(&b[bi]) {
            Ordering::Less => {
                merged.push(a[ai]);
                ai += 1;
            }
            Ordering::Greater => {
                merged.push(b[bi]);
                bi += 1;
            }
            Ordering::Equal => {
                // Cancel (mod 2): skip both.
                ai += 1;
                bi += 1;
            }
        }
    }
    // At most one of the two tails is non-empty; it has nothing left to cancel with.
    merged.extend_from_slice(&a[ai..]);
    merged.extend_from_slice(&b[bi..]);
    *a = merged;
}

// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tda::PersistenceDiagram;

    /// Distance matrix of the unit square (0,0), (1,0), (1,1), (0,1),
    /// with the diagonal √2 rounded to 1.414.
    fn square_dist() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 1.0, 1.414, 1.0],
            vec![1.0, 0.0, 1.0, 1.414],
            vec![1.414, 1.0, 0.0, 1.0],
            vec![1.0, 1.414, 1.0, 0.0],
        ]
    }

    /// Corners of the unit square, in the same order as `square_dist`.
    fn square_points() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![1.0, 1.0],
            vec![0.0, 1.0],
        ]
    }

    // ── VietorisRipsComplex ───────────────────────────────────────────────────

    #[test]
    fn test_vrc_vertices() {
        let pts = square_points();
        let vrc = VietorisRipsComplex::new(&pts, 1.5).expect("new");
        assert_eq!(vrc.n_simplices(0), 4, "Should have 4 vertices");
    }

    #[test]
    fn test_vrc_edges_unit_square() {
        let pts = square_points();
        // At epsilon = 1.0: only 4 edges (sides), no diagonals (sqrt(2) ≈ 1.414 > 1.0)
        let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
        assert_eq!(vrc.n_simplices(1), 4, "Unit square at eps=1 has 4 edges");
        assert_eq!(vrc.n_simplices(2), 0, "No triangles at eps=1");
    }

    #[test]
    fn test_vrc_complete_graph() {
        let pts = square_points();
        // At epsilon = 2.0: all 6 edges present → 4 triangles
        let vrc = VietorisRipsComplex::new(&pts, 2.0).expect("new");
        assert_eq!(
            vrc.n_simplices(1),
            6,
            "Complete graph on 4 vertices has 6 edges"
        );
        assert_eq!(vrc.n_simplices(2), 4, "4 triangles in K4");
    }

    #[test]
    fn test_vrc_euler_characteristic() {
        let pts = square_points();
        // At eps = 1.0: 4 vertices, 4 edges, 0 triangles → χ = 4 - 4 + 0 = 0
        let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
        assert_eq!(vrc.euler_characteristic(), 0);
    }

    #[test]
    fn test_vrc_empty_input() {
        let vrc = VietorisRipsComplex::new(&[], 1.0).expect("empty ok");
        assert_eq!(vrc.n_simplices(0), 0);
        assert_eq!(vrc.euler_characteristic(), 0);
    }

    #[test]
    fn test_vrc_negative_epsilon_error() {
        let pts = square_points();
        assert!(VietorisRipsComplex::new(&pts, -0.1).is_err());
    }

    #[test]
    fn test_vrc_are_connected() {
        let pts = square_points();
        let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
        // Edges of the square: (0,1), (1,2), (2,3), (0,3)
        assert!(vrc.are_connected(0, 1));
        assert!(vrc.are_connected(1, 2));
        // Diagonal (0,2) has length sqrt(2) > 1.0
        assert!(!vrc.are_connected(0, 2));
    }

    // ── compute_persistence ───────────────────────────────────────────────────

    #[test]
    fn test_compute_persistence_h0_square() {
        let dist = square_dist();
        let diagrams = compute_persistence(&dist, 1, 2.0).expect("persistence");
        assert_eq!(diagrams.len(), 2);
        let h0 = &diagrams[0];
        // Should have H0 features (at least one)
        assert!(!h0.is_empty(), "H0 should not be empty");
    }

    #[test]
    fn test_compute_persistence_square_pairs() {
        let dist = square_dist();
        let diagrams = compute_persistence(&dist, 1, 2.0).expect("persistence");

        // H0: four components are born at 0; three merge when the sides
        // appear at 1.0 and one survives forever.
        let h0 = &diagrams[0];
        let essential = h0.points.iter().filter(|p| p.death.is_infinite()).count();
        let merged = h0
            .points
            .iter()
            .filter(|p| p.death.is_finite())
            .map(|p| (p.birth, p.death))
            .collect::<Vec<_>>();
        assert_eq!(essential, 1, "exactly one component never dies");
        assert_eq!(merged.len(), 3, "three components merge");
        for (birth, death) in merged {
            assert!(birth.abs() < 1e-12);
            assert!((death - 1.0).abs() < 1e-12);
        }

        // H1: the square's loop closes at 1.0 and is filled once the
        // diagonals (and with them the triangles) appear at 1.414.
        let h1: Vec<(f64, f64)> = diagrams[1]
            .points
            .iter()
            .map(|p| (p.birth, p.death))
            .collect();
        assert_eq!(h1.len(), 1, "exactly one persistent loop");
        assert!((h1[0].0 - 1.0).abs() < 1e-12);
        assert!((h1[0].1 - 1.414).abs() < 1e-12);
    }

    #[test]
    fn test_compute_persistence_empty() {
        let diagrams = compute_persistence(&[], 1, 1.0).expect("empty");
        assert_eq!(diagrams.len(), 2); // H0 and H1, both empty
        assert!(diagrams[0].is_empty());
        assert!(diagrams[1].is_empty());
    }

    #[test]
    fn test_compute_persistence_non_square_error() {
        let dist = vec![vec![0.0, 1.0], vec![1.0, 0.0, 2.0]];
        assert!(compute_persistence(&dist, 1, 2.0).is_err());
    }

    #[test]
    fn test_compute_persistence_returns_finite_pairs() {
        let dist = square_dist();
        let diagrams = compute_persistence(&dist, 1, 2.0).expect("persistence");
        for dgm in &diagrams {
            for pt in &dgm.points {
                assert!(pt.birth.is_finite());
                assert!(pt.birth >= 0.0);
                if pt.death.is_finite() {
                    assert!(pt.death >= pt.birth);
                }
            }
        }
    }

    // ── persistence_landscape_fn ──────────────────────────────────────────────

    #[test]
    fn test_landscape_fn_shape() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 2.0, 0);
        dgm.add_point(0.5, 1.5, 0);
        let x: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
        let l = persistence_landscape_fn(&dgm, 3, &x);
        assert_eq!(l.len(), 3);
        assert_eq!(l[0].len(), 20);
    }

    #[test]
    fn test_landscape_fn_non_negative() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        let x: Vec<f64> = (0..10).map(|i| i as f64 * 0.15).collect();
        let l = persistence_landscape_fn(&dgm, 2, &x);
        for row in &l {
            for &v in row {
                assert!(v >= 0.0, "landscape must be non-negative, got {v}");
            }
        }
    }

    #[test]
    fn test_landscape_fn_tent_shape() {
        let mut dgm = PersistenceDiagram::new(0);
        // Single (0, 1) pair → tent peaked at t=0.5 with height 0.5
        dgm.add_point(0.0, 1.0, 0);
        let x = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let l = persistence_landscape_fn(&dgm, 1, &x);
        // λ_1(0.5) = min(0.5 - 0, 1.0 - 0.5) = 0.5
        assert!((l[0][2] - 0.5).abs() < 1e-10, "peak should be 0.5");
        // λ_1(0) = 0, λ_1(1) = 0
        assert!(l[0][0] < 1e-10);
        assert!(l[0][4] < 1e-10);
    }

    #[test]
    fn test_landscape_fn_empty_diagram() {
        let dgm = PersistenceDiagram::new(0);
        let x = vec![0.0, 1.0, 2.0];
        let l = persistence_landscape_fn(&dgm, 2, &x);
        assert_eq!(l.len(), 2);
        for row in &l {
            assert!(row.iter().all(|&v| v == 0.0));
        }
    }

    // ── persistence_image_fn ──────────────────────────────────────────────────

    #[test]
    fn test_persistence_image_fn_shape() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        dgm.add_point(0.2, 0.8, 0);
        let img = persistence_image_fn(&dgm, 0.1, (5, 5), 1.0, 1.0);
        assert_eq!(img.len(), 5);
        assert_eq!(img[0].len(), 5);
    }

    #[test]
    fn test_persistence_image_fn_non_negative() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        let img = persistence_image_fn(&dgm, 0.1, (4, 4), 1.0, 1.0);
        for row in &img {
            for &v in row {
                assert!(v >= 0.0, "image pixel must be non-negative, got {v}");
            }
        }
    }

    #[test]
    fn test_persistence_image_fn_has_signal() {
        let mut dgm = PersistenceDiagram::new(0);
        dgm.add_point(0.0, 1.0, 0);
        let img = persistence_image_fn(&dgm, 0.15, (6, 6), 1.5, 1.5);
        let has_positive = img.iter().flat_map(|row| row.iter()).any(|&v| v > 0.0);
        assert!(
            has_positive,
            "image should have nonzero pixels for a nonempty diagram"
        );
    }

    #[test]
    fn test_persistence_image_fn_empty_diagram() {
        let dgm = PersistenceDiagram::new(0);
        let img = persistence_image_fn(&dgm, 0.1, (4, 4), 1.0, 1.0);
        assert_eq!(img.len(), 4);
        for row in &img {
            assert!(row.iter().all(|&v| v == 0.0));
        }
    }

    // ── Internal helpers ──────────────────────────────────────────────────────

    #[test]
    fn test_sym_diff_inplace() {
        let mut a = vec![1_usize, 3, 5];
        let b = vec![2, 3, 4];
        sym_diff_inplace(&mut a, &b);
        // Expected: 1, 2, 4, 5 (3 cancelled)
        assert_eq!(a, vec![1, 2, 4, 5]);
    }
}
893}