ruvector_math/homology/
distance.rs

//! Distances between Persistence Diagrams
//!
//! Bottleneck and Wasserstein distances for comparing topological signatures,
//! plus a persistence landscape that turns a diagram into a fixed-length vector.
4
5use super::{BirthDeathPair, PersistenceDiagram};
6
7/// Bottleneck distance between persistence diagrams
8///
9/// d_∞(D1, D2) = inf_γ sup_p ||p - γ(p)||_∞
10///
11/// where γ ranges over bijections between diagrams (with diagonal).
12#[derive(Debug, Clone)]
13pub struct BottleneckDistance;
14
15impl BottleneckDistance {
16    /// Compute bottleneck distance for dimension d
17    pub fn compute(d1: &PersistenceDiagram, d2: &PersistenceDiagram, dim: usize) -> f64 {
18        let pts1: Vec<(f64, f64)> = d1
19            .pairs_of_dim(dim)
20            .filter(|p| !p.is_essential())
21            .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
22            .collect();
23
24        let pts2: Vec<(f64, f64)> = d2
25            .pairs_of_dim(dim)
26            .filter(|p| !p.is_essential())
27            .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
28            .collect();
29
30        Self::bottleneck_finite(&pts1, &pts2)
31    }
32
33    /// Bottleneck distance for finite points
34    fn bottleneck_finite(pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
35        if pts1.is_empty() && pts2.is_empty() {
36            return 0.0;
37        }
38
39        // Include diagonal projections
40        let mut all_distances = Vec::new();
41
42        // Distances between points
43        for &(b1, d1) in pts1 {
44            for &(b2, d2) in pts2 {
45                let dist = Self::l_inf((b1, d1), (b2, d2));
46                all_distances.push(dist);
47            }
48        }
49
50        // Distances to diagonal
51        for &(b, d) in pts1 {
52            let diag_dist = (d - b) / 2.0;
53            all_distances.push(diag_dist);
54        }
55        for &(b, d) in pts2 {
56            let diag_dist = (d - b) / 2.0;
57            all_distances.push(diag_dist);
58        }
59
60        if all_distances.is_empty() {
61            return 0.0;
62        }
63
64        // Sort and binary search for bottleneck
65        all_distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
66
67        // For small instances, use greedy matching at each threshold
68        for &threshold in &all_distances {
69            if Self::can_match(pts1, pts2, threshold) {
70                return threshold;
71            }
72        }
73
74        // Fallback
75        *all_distances.last().unwrap_or(&0.0)
76    }
77
78    /// Check if perfect matching exists at threshold
79    fn can_match(pts1: &[(f64, f64)], pts2: &[(f64, f64)], threshold: f64) -> bool {
80        // Simple greedy matching (not optimal but fast)
81        let mut used2 = vec![false; pts2.len()];
82        let mut matched1 = 0;
83
84        for &p1 in pts1 {
85            // Try to match to a point in pts2
86            let mut found = false;
87            for (j, &p2) in pts2.iter().enumerate() {
88                if !used2[j] && Self::l_inf(p1, p2) <= threshold {
89                    used2[j] = true;
90                    found = true;
91                    break;
92                }
93            }
94
95            if !found {
96                // Try to match to diagonal
97                if Self::diag_dist(p1) <= threshold {
98                    matched1 += 1;
99                    continue;
100                }
101                return false;
102            }
103            matched1 += 1;
104        }
105
106        // Check unmatched pts2 can go to diagonal
107        for (j, &p2) in pts2.iter().enumerate() {
108            if !used2[j] && Self::diag_dist(p2) > threshold {
109                return false;
110            }
111        }
112
113        true
114    }
115
116    /// L-infinity distance between points
117    fn l_inf(p1: (f64, f64), p2: (f64, f64)) -> f64 {
118        (p1.0 - p2.0).abs().max((p1.1 - p2.1).abs())
119    }
120
121    /// Distance to diagonal
122    fn diag_dist(p: (f64, f64)) -> f64 {
123        (p.1 - p.0) / 2.0
124    }
125}
126
127/// Wasserstein distance between persistence diagrams
128///
129/// W_p(D1, D2) = (inf_γ Σ ||p - γ(p)||_∞^p)^{1/p}
130#[derive(Debug, Clone)]
131pub struct WassersteinDistance {
132    /// Power p (usually 1 or 2)
133    pub p: f64,
134}
135
136impl WassersteinDistance {
137    /// Create with power p
138    pub fn new(p: f64) -> Self {
139        Self { p: p.max(1.0) }
140    }
141
142    /// Compute W_p distance for dimension d
143    pub fn compute(&self, d1: &PersistenceDiagram, d2: &PersistenceDiagram, dim: usize) -> f64 {
144        let pts1: Vec<(f64, f64)> = d1
145            .pairs_of_dim(dim)
146            .filter(|p| !p.is_essential())
147            .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
148            .collect();
149
150        let pts2: Vec<(f64, f64)> = d2
151            .pairs_of_dim(dim)
152            .filter(|p| !p.is_essential())
153            .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
154            .collect();
155
156        self.wasserstein_finite(&pts1, &pts2)
157    }
158
159    /// Wasserstein distance for finite points (greedy approximation)
160    fn wasserstein_finite(&self, pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
161        if pts1.is_empty() && pts2.is_empty() {
162            return 0.0;
163        }
164
165        // Greedy matching (approximation)
166        let mut used2 = vec![false; pts2.len()];
167        let mut total_cost = 0.0;
168
169        for &p1 in pts1 {
170            let diag_cost = Self::diag_dist(p1).powf(self.p);
171
172            // Find best match
173            let mut best_cost = diag_cost;
174            let mut best_j = None;
175
176            for (j, &p2) in pts2.iter().enumerate() {
177                if !used2[j] {
178                    let cost = Self::l_inf(p1, p2).powf(self.p);
179                    if cost < best_cost {
180                        best_cost = cost;
181                        best_j = Some(j);
182                    }
183                }
184            }
185
186            total_cost += best_cost;
187            if let Some(j) = best_j {
188                used2[j] = true;
189            }
190        }
191
192        // Unmatched pts2 go to diagonal
193        for (j, &p2) in pts2.iter().enumerate() {
194            if !used2[j] {
195                total_cost += Self::diag_dist(p2).powf(self.p);
196            }
197        }
198
199        total_cost.powf(1.0 / self.p)
200    }
201
202    fn l_inf(p1: (f64, f64), p2: (f64, f64)) -> f64 {
203        (p1.0 - p2.0).abs().max((p1.1 - p2.1).abs())
204    }
205
206    fn diag_dist(p: (f64, f64)) -> f64 {
207        (p.1 - p.0) / 2.0
208    }
209}
210
211/// Persistence landscape for machine learning
212#[derive(Debug, Clone)]
213pub struct PersistenceLandscape {
214    /// Landscape functions λ_k(t)
215    pub landscapes: Vec<Vec<f64>>,
216    /// Grid points
217    pub grid: Vec<f64>,
218    /// Number of landscape functions
219    pub num_landscapes: usize,
220}
221
222impl PersistenceLandscape {
223    /// Compute landscape from persistence diagram
224    pub fn from_diagram(
225        diagram: &PersistenceDiagram,
226        dim: usize,
227        num_landscapes: usize,
228        resolution: usize,
229    ) -> Self {
230        let pairs: Vec<(f64, f64)> = diagram
231            .pairs_of_dim(dim)
232            .filter(|p| !p.is_essential())
233            .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
234            .filter(|p| p.1.is_finite())
235            .collect();
236
237        if pairs.is_empty() {
238            return Self {
239                landscapes: vec![vec![0.0; resolution]; num_landscapes],
240                grid: (0..resolution).map(|i| i as f64 / resolution as f64).collect(),
241                num_landscapes,
242            };
243        }
244
245        // Determine grid
246        let min_t = pairs.iter().map(|p| p.0).fold(f64::INFINITY, f64::min);
247        let max_t = pairs.iter().map(|p| p.1).fold(f64::NEG_INFINITY, f64::max);
248        let range = (max_t - min_t).max(1e-10);
249
250        let grid: Vec<f64> = (0..resolution)
251            .map(|i| min_t + (i as f64 / (resolution - 1).max(1) as f64) * range)
252            .collect();
253
254        // Compute tent functions at each grid point
255        let mut landscapes = vec![vec![0.0; resolution]; num_landscapes];
256
257        for (gi, &t) in grid.iter().enumerate() {
258            // Evaluate all tent functions at t
259            let mut values: Vec<f64> = pairs
260                .iter()
261                .map(|&(b, d)| {
262                    if t < b || t > d {
263                        0.0
264                    } else if t <= (b + d) / 2.0 {
265                        t - b
266                    } else {
267                        d - t
268                    }
269                })
270                .collect();
271
272            // Sort descending
273            values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
274
275            // Take top k
276            for (k, &v) in values.iter().take(num_landscapes).enumerate() {
277                landscapes[k][gi] = v;
278            }
279        }
280
281        Self {
282            landscapes,
283            grid,
284            num_landscapes,
285        }
286    }
287
288    /// L2 distance between landscapes
289    pub fn l2_distance(&self, other: &Self) -> f64 {
290        if self.grid.len() != other.grid.len() || self.num_landscapes != other.num_landscapes {
291            return f64::INFINITY;
292        }
293
294        let n = self.grid.len();
295        let dt = if n > 1 {
296            (self.grid[n - 1] - self.grid[0]) / (n - 1) as f64
297        } else {
298            1.0
299        };
300
301        let mut total = 0.0;
302        for k in 0..self.num_landscapes {
303            for i in 0..n {
304                let diff = self.landscapes[k][i] - other.landscapes[k][i];
305                total += diff * diff * dt;
306            }
307        }
308
309        total.sqrt()
310    }
311
312    /// Get feature vector (flattened landscape)
313    pub fn to_vector(&self) -> Vec<f64> {
314        self.landscapes.iter().flat_map(|l| l.iter().copied()).collect()
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: three finite pairs, two tagged `0` and one tagged `1`
    /// (the first argument of `finite` is presumably the homology dimension).
    fn sample_diagram() -> PersistenceDiagram {
        let mut diagram = PersistenceDiagram::new();
        let pairs = vec![
            BirthDeathPair::finite(0, 0.0, 1.0),
            BirthDeathPair::finite(0, 0.5, 1.5),
            BirthDeathPair::finite(1, 0.2, 0.8),
        ];
        for pair in pairs {
            diagram.add(pair);
        }
        diagram
    }

    // A diagram compared with itself must be at bottleneck distance zero.
    #[test]
    fn test_bottleneck_same() {
        let diagram = sample_diagram();
        assert!(BottleneckDistance::compute(&diagram, &diagram, 0) < 1e-10);
    }

    // Distinct diagrams must be strictly separated.
    #[test]
    fn test_bottleneck_different() {
        let reference = sample_diagram();

        let mut single = PersistenceDiagram::new();
        single.add(BirthDeathPair::finite(0, 0.0, 2.0));

        assert!(BottleneckDistance::compute(&reference, &single, 0) > 0.0);
    }

    // Two independently built copies of the fixture are at W_1 distance zero.
    #[test]
    fn test_wasserstein() {
        let (left, right) = (sample_diagram(), sample_diagram());
        let metric = WassersteinDistance::new(1.0);

        assert!(metric.compute(&left, &right, 0) < 1e-10);
    }

    // The landscape has the requested number of functions and grid points.
    #[test]
    fn test_persistence_landscape() {
        let result = PersistenceLandscape::from_diagram(&sample_diagram(), 0, 3, 20);

        assert_eq!(result.landscapes.len(), 3);
        assert_eq!(result.grid.len(), 20);
    }

    // Landscapes built twice from the same diagram coincide.
    #[test]
    fn test_landscape_distance() {
        let source = sample_diagram();
        let first = PersistenceLandscape::from_diagram(&source, 0, 3, 20);
        let second = PersistenceLandscape::from_diagram(&source, 0, 3, 20);

        assert!(first.l2_distance(&second) < 1e-10);
    }

    // Flattening yields num_landscapes * resolution entries.
    #[test]
    fn test_landscape_vector() {
        let features = PersistenceLandscape::from_diagram(&sample_diagram(), 0, 2, 10).to_vector();

        assert_eq!(features.len(), 20); // 2 landscapes × 10 points
    }
}