1use super::{BirthDeathPair, PersistenceDiagram};
6
7#[derive(Debug, Clone)]
13pub struct BottleneckDistance;
14
15impl BottleneckDistance {
16 pub fn compute(d1: &PersistenceDiagram, d2: &PersistenceDiagram, dim: usize) -> f64 {
18 let pts1: Vec<(f64, f64)> = d1
19 .pairs_of_dim(dim)
20 .filter(|p| !p.is_essential())
21 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
22 .collect();
23
24 let pts2: Vec<(f64, f64)> = d2
25 .pairs_of_dim(dim)
26 .filter(|p| !p.is_essential())
27 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
28 .collect();
29
30 Self::bottleneck_finite(&pts1, &pts2)
31 }
32
33 fn bottleneck_finite(pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
35 if pts1.is_empty() && pts2.is_empty() {
36 return 0.0;
37 }
38
39 let mut all_distances = Vec::new();
41
42 for &(b1, d1) in pts1 {
44 for &(b2, d2) in pts2 {
45 let dist = Self::l_inf((b1, d1), (b2, d2));
46 all_distances.push(dist);
47 }
48 }
49
50 for &(b, d) in pts1 {
52 let diag_dist = (d - b) / 2.0;
53 all_distances.push(diag_dist);
54 }
55 for &(b, d) in pts2 {
56 let diag_dist = (d - b) / 2.0;
57 all_distances.push(diag_dist);
58 }
59
60 if all_distances.is_empty() {
61 return 0.0;
62 }
63
64 all_distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
66
67 for &threshold in &all_distances {
69 if Self::can_match(pts1, pts2, threshold) {
70 return threshold;
71 }
72 }
73
74 *all_distances.last().unwrap_or(&0.0)
76 }
77
78 fn can_match(pts1: &[(f64, f64)], pts2: &[(f64, f64)], threshold: f64) -> bool {
80 let mut used2 = vec![false; pts2.len()];
82 let mut matched1 = 0;
83
84 for &p1 in pts1 {
85 let mut found = false;
87 for (j, &p2) in pts2.iter().enumerate() {
88 if !used2[j] && Self::l_inf(p1, p2) <= threshold {
89 used2[j] = true;
90 found = true;
91 break;
92 }
93 }
94
95 if !found {
96 if Self::diag_dist(p1) <= threshold {
98 matched1 += 1;
99 continue;
100 }
101 return false;
102 }
103 matched1 += 1;
104 }
105
106 for (j, &p2) in pts2.iter().enumerate() {
108 if !used2[j] && Self::diag_dist(p2) > threshold {
109 return false;
110 }
111 }
112
113 true
114 }
115
116 fn l_inf(p1: (f64, f64), p2: (f64, f64)) -> f64 {
118 (p1.0 - p2.0).abs().max((p1.1 - p2.1).abs())
119 }
120
121 fn diag_dist(p: (f64, f64)) -> f64 {
123 (p.1 - p.0) / 2.0
124 }
125}
126
127#[derive(Debug, Clone)]
131pub struct WassersteinDistance {
132 pub p: f64,
134}
135
136impl WassersteinDistance {
137 pub fn new(p: f64) -> Self {
139 Self { p: p.max(1.0) }
140 }
141
142 pub fn compute(&self, d1: &PersistenceDiagram, d2: &PersistenceDiagram, dim: usize) -> f64 {
144 let pts1: Vec<(f64, f64)> = d1
145 .pairs_of_dim(dim)
146 .filter(|p| !p.is_essential())
147 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
148 .collect();
149
150 let pts2: Vec<(f64, f64)> = d2
151 .pairs_of_dim(dim)
152 .filter(|p| !p.is_essential())
153 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
154 .collect();
155
156 self.wasserstein_finite(&pts1, &pts2)
157 }
158
159 fn wasserstein_finite(&self, pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
161 if pts1.is_empty() && pts2.is_empty() {
162 return 0.0;
163 }
164
165 let mut used2 = vec![false; pts2.len()];
167 let mut total_cost = 0.0;
168
169 for &p1 in pts1 {
170 let diag_cost = Self::diag_dist(p1).powf(self.p);
171
172 let mut best_cost = diag_cost;
174 let mut best_j = None;
175
176 for (j, &p2) in pts2.iter().enumerate() {
177 if !used2[j] {
178 let cost = Self::l_inf(p1, p2).powf(self.p);
179 if cost < best_cost {
180 best_cost = cost;
181 best_j = Some(j);
182 }
183 }
184 }
185
186 total_cost += best_cost;
187 if let Some(j) = best_j {
188 used2[j] = true;
189 }
190 }
191
192 for (j, &p2) in pts2.iter().enumerate() {
194 if !used2[j] {
195 total_cost += Self::diag_dist(p2).powf(self.p);
196 }
197 }
198
199 total_cost.powf(1.0 / self.p)
200 }
201
202 fn l_inf(p1: (f64, f64), p2: (f64, f64)) -> f64 {
203 (p1.0 - p2.0).abs().max((p1.1 - p2.1).abs())
204 }
205
206 fn diag_dist(p: (f64, f64)) -> f64 {
207 (p.1 - p.0) / 2.0
208 }
209}
210
211#[derive(Debug, Clone)]
213pub struct PersistenceLandscape {
214 pub landscapes: Vec<Vec<f64>>,
216 pub grid: Vec<f64>,
218 pub num_landscapes: usize,
220}
221
222impl PersistenceLandscape {
223 pub fn from_diagram(
225 diagram: &PersistenceDiagram,
226 dim: usize,
227 num_landscapes: usize,
228 resolution: usize,
229 ) -> Self {
230 let pairs: Vec<(f64, f64)> = diagram
231 .pairs_of_dim(dim)
232 .filter(|p| !p.is_essential())
233 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
234 .filter(|p| p.1.is_finite())
235 .collect();
236
237 if pairs.is_empty() {
238 return Self {
239 landscapes: vec![vec![0.0; resolution]; num_landscapes],
240 grid: (0..resolution).map(|i| i as f64 / resolution as f64).collect(),
241 num_landscapes,
242 };
243 }
244
245 let min_t = pairs.iter().map(|p| p.0).fold(f64::INFINITY, f64::min);
247 let max_t = pairs.iter().map(|p| p.1).fold(f64::NEG_INFINITY, f64::max);
248 let range = (max_t - min_t).max(1e-10);
249
250 let grid: Vec<f64> = (0..resolution)
251 .map(|i| min_t + (i as f64 / (resolution - 1).max(1) as f64) * range)
252 .collect();
253
254 let mut landscapes = vec![vec![0.0; resolution]; num_landscapes];
256
257 for (gi, &t) in grid.iter().enumerate() {
258 let mut values: Vec<f64> = pairs
260 .iter()
261 .map(|&(b, d)| {
262 if t < b || t > d {
263 0.0
264 } else if t <= (b + d) / 2.0 {
265 t - b
266 } else {
267 d - t
268 }
269 })
270 .collect();
271
272 values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
274
275 for (k, &v) in values.iter().take(num_landscapes).enumerate() {
277 landscapes[k][gi] = v;
278 }
279 }
280
281 Self {
282 landscapes,
283 grid,
284 num_landscapes,
285 }
286 }
287
288 pub fn l2_distance(&self, other: &Self) -> f64 {
290 if self.grid.len() != other.grid.len() || self.num_landscapes != other.num_landscapes {
291 return f64::INFINITY;
292 }
293
294 let n = self.grid.len();
295 let dt = if n > 1 {
296 (self.grid[n - 1] - self.grid[0]) / (n - 1) as f64
297 } else {
298 1.0
299 };
300
301 let mut total = 0.0;
302 for k in 0..self.num_landscapes {
303 for i in 0..n {
304 let diff = self.landscapes[k][i] - other.landscapes[k][i];
305 total += diff * diff * dt;
306 }
307 }
308
309 total.sqrt()
310 }
311
312 pub fn to_vector(&self) -> Vec<f64> {
314 self.landscapes.iter().flat_map(|l| l.iter().copied()).collect()
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the three-pair fixture diagram shared by the tests below.
    fn sample_diagram() -> PersistenceDiagram {
        let mut diagram = PersistenceDiagram::new();
        diagram.add(BirthDeathPair::finite(0, 0.0, 1.0));
        diagram.add(BirthDeathPair::finite(0, 0.5, 1.5));
        diagram.add(BirthDeathPair::finite(1, 0.2, 0.8));
        diagram
    }

    /// A diagram is at bottleneck distance (numerically) zero from itself.
    #[test]
    fn test_bottleneck_same() {
        let diagram = sample_diagram();
        assert!(BottleneckDistance::compute(&diagram, &diagram, 0) < 1e-10);
    }

    /// Distinct diagrams have a strictly positive bottleneck distance.
    #[test]
    fn test_bottleneck_different() {
        let reference = sample_diagram();

        let mut stretched = PersistenceDiagram::new();
        stretched.add(BirthDeathPair::finite(0, 0.0, 2.0));

        assert!(BottleneckDistance::compute(&reference, &stretched, 0) > 0.0);
    }

    /// Two equal diagrams are at 1-Wasserstein distance (numerically) zero.
    #[test]
    fn test_wasserstein() {
        let left = sample_diagram();
        let right = sample_diagram();

        let metric = WassersteinDistance::new(1.0);
        assert!(metric.compute(&left, &right, 0) < 1e-10);
    }

    /// The landscape has the requested number of levels and grid points.
    #[test]
    fn test_persistence_landscape() {
        let landscape = PersistenceLandscape::from_diagram(&sample_diagram(), 0, 3, 20);

        assert_eq!(landscape.landscapes.len(), 3);
        assert_eq!(landscape.grid.len(), 20);
    }

    /// Landscapes built from the same diagram are at L2 distance zero.
    #[test]
    fn test_landscape_distance() {
        let diagram = sample_diagram();
        let first = PersistenceLandscape::from_diagram(&diagram, 0, 3, 20);
        let second = PersistenceLandscape::from_diagram(&diagram, 0, 3, 20);

        assert!(first.l2_distance(&second) < 1e-10);
    }

    /// Flattening yields `num_landscapes * resolution` entries.
    #[test]
    fn test_landscape_vector() {
        let landscape = PersistenceLandscape::from_diagram(&sample_diagram(), 0, 2, 10);

        let flattened = landscape.to_vector();
        assert_eq!(flattened.len(), 20);
    }
}