1use super::{BirthDeathPair, PersistenceDiagram};
6
7#[derive(Debug, Clone)]
13pub struct BottleneckDistance;
14
15impl BottleneckDistance {
16 pub fn compute(d1: &PersistenceDiagram, d2: &PersistenceDiagram, dim: usize) -> f64 {
18 let pts1: Vec<(f64, f64)> = d1
19 .pairs_of_dim(dim)
20 .filter(|p| !p.is_essential())
21 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
22 .collect();
23
24 let pts2: Vec<(f64, f64)> = d2
25 .pairs_of_dim(dim)
26 .filter(|p| !p.is_essential())
27 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
28 .collect();
29
30 Self::bottleneck_finite(&pts1, &pts2)
31 }
32
33 fn bottleneck_finite(pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
35 if pts1.is_empty() && pts2.is_empty() {
36 return 0.0;
37 }
38
39 let mut all_distances = Vec::new();
41
42 for &(b1, d1) in pts1 {
44 for &(b2, d2) in pts2 {
45 let dist = Self::l_inf((b1, d1), (b2, d2));
46 all_distances.push(dist);
47 }
48 }
49
50 for &(b, d) in pts1 {
52 let diag_dist = (d - b) / 2.0;
53 all_distances.push(diag_dist);
54 }
55 for &(b, d) in pts2 {
56 let diag_dist = (d - b) / 2.0;
57 all_distances.push(diag_dist);
58 }
59
60 if all_distances.is_empty() {
61 return 0.0;
62 }
63
64 all_distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
66
67 for &threshold in &all_distances {
69 if Self::can_match(pts1, pts2, threshold) {
70 return threshold;
71 }
72 }
73
74 *all_distances.last().unwrap_or(&0.0)
76 }
77
78 fn can_match(pts1: &[(f64, f64)], pts2: &[(f64, f64)], threshold: f64) -> bool {
80 let mut used2 = vec![false; pts2.len()];
82 let mut matched1 = 0;
83
84 for &p1 in pts1 {
85 let mut found = false;
87 for (j, &p2) in pts2.iter().enumerate() {
88 if !used2[j] && Self::l_inf(p1, p2) <= threshold {
89 used2[j] = true;
90 found = true;
91 break;
92 }
93 }
94
95 if !found {
96 if Self::diag_dist(p1) <= threshold {
98 matched1 += 1;
99 continue;
100 }
101 return false;
102 }
103 matched1 += 1;
104 }
105
106 for (j, &p2) in pts2.iter().enumerate() {
108 if !used2[j] && Self::diag_dist(p2) > threshold {
109 return false;
110 }
111 }
112
113 true
114 }
115
116 fn l_inf(p1: (f64, f64), p2: (f64, f64)) -> f64 {
118 (p1.0 - p2.0).abs().max((p1.1 - p2.1).abs())
119 }
120
121 fn diag_dist(p: (f64, f64)) -> f64 {
123 (p.1 - p.0) / 2.0
124 }
125}
126
127#[derive(Debug, Clone)]
131pub struct WassersteinDistance {
132 pub p: f64,
134}
135
136impl WassersteinDistance {
137 pub fn new(p: f64) -> Self {
139 Self { p: p.max(1.0) }
140 }
141
142 pub fn compute(&self, d1: &PersistenceDiagram, d2: &PersistenceDiagram, dim: usize) -> f64 {
144 let pts1: Vec<(f64, f64)> = d1
145 .pairs_of_dim(dim)
146 .filter(|p| !p.is_essential())
147 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
148 .collect();
149
150 let pts2: Vec<(f64, f64)> = d2
151 .pairs_of_dim(dim)
152 .filter(|p| !p.is_essential())
153 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
154 .collect();
155
156 self.wasserstein_finite(&pts1, &pts2)
157 }
158
159 fn wasserstein_finite(&self, pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
161 if pts1.is_empty() && pts2.is_empty() {
162 return 0.0;
163 }
164
165 let mut used2 = vec![false; pts2.len()];
167 let mut total_cost = 0.0;
168
169 for &p1 in pts1 {
170 let diag_cost = Self::diag_dist(p1).powf(self.p);
171
172 let mut best_cost = diag_cost;
174 let mut best_j = None;
175
176 for (j, &p2) in pts2.iter().enumerate() {
177 if !used2[j] {
178 let cost = Self::l_inf(p1, p2).powf(self.p);
179 if cost < best_cost {
180 best_cost = cost;
181 best_j = Some(j);
182 }
183 }
184 }
185
186 total_cost += best_cost;
187 if let Some(j) = best_j {
188 used2[j] = true;
189 }
190 }
191
192 for (j, &p2) in pts2.iter().enumerate() {
194 if !used2[j] {
195 total_cost += Self::diag_dist(p2).powf(self.p);
196 }
197 }
198
199 total_cost.powf(1.0 / self.p)
200 }
201
202 fn l_inf(p1: (f64, f64), p2: (f64, f64)) -> f64 {
203 (p1.0 - p2.0).abs().max((p1.1 - p2.1).abs())
204 }
205
206 fn diag_dist(p: (f64, f64)) -> f64 {
207 (p.1 - p.0) / 2.0
208 }
209}
210
211#[derive(Debug, Clone)]
213pub struct PersistenceLandscape {
214 pub landscapes: Vec<Vec<f64>>,
216 pub grid: Vec<f64>,
218 pub num_landscapes: usize,
220}
221
222impl PersistenceLandscape {
223 pub fn from_diagram(
225 diagram: &PersistenceDiagram,
226 dim: usize,
227 num_landscapes: usize,
228 resolution: usize,
229 ) -> Self {
230 let pairs: Vec<(f64, f64)> = diagram
231 .pairs_of_dim(dim)
232 .filter(|p| !p.is_essential())
233 .map(|p| (p.birth, p.death.unwrap_or(f64::INFINITY)))
234 .filter(|p| p.1.is_finite())
235 .collect();
236
237 if pairs.is_empty() {
238 return Self {
239 landscapes: vec![vec![0.0; resolution]; num_landscapes],
240 grid: (0..resolution)
241 .map(|i| i as f64 / resolution as f64)
242 .collect(),
243 num_landscapes,
244 };
245 }
246
247 let min_t = pairs.iter().map(|p| p.0).fold(f64::INFINITY, f64::min);
249 let max_t = pairs.iter().map(|p| p.1).fold(f64::NEG_INFINITY, f64::max);
250 let range = (max_t - min_t).max(1e-10);
251
252 let grid: Vec<f64> = (0..resolution)
253 .map(|i| min_t + (i as f64 / (resolution - 1).max(1) as f64) * range)
254 .collect();
255
256 let mut landscapes = vec![vec![0.0; resolution]; num_landscapes];
258
259 for (gi, &t) in grid.iter().enumerate() {
260 let mut values: Vec<f64> = pairs
262 .iter()
263 .map(|&(b, d)| {
264 if t < b || t > d {
265 0.0
266 } else if t <= (b + d) / 2.0 {
267 t - b
268 } else {
269 d - t
270 }
271 })
272 .collect();
273
274 values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
276
277 for (k, &v) in values.iter().take(num_landscapes).enumerate() {
279 landscapes[k][gi] = v;
280 }
281 }
282
283 Self {
284 landscapes,
285 grid,
286 num_landscapes,
287 }
288 }
289
290 pub fn l2_distance(&self, other: &Self) -> f64 {
292 if self.grid.len() != other.grid.len() || self.num_landscapes != other.num_landscapes {
293 return f64::INFINITY;
294 }
295
296 let n = self.grid.len();
297 let dt = if n > 1 {
298 (self.grid[n - 1] - self.grid[0]) / (n - 1) as f64
299 } else {
300 1.0
301 };
302
303 let mut total = 0.0;
304 for k in 0..self.num_landscapes {
305 for i in 0..n {
306 let diff = self.landscapes[k][i] - other.landscapes[k][i];
307 total += diff * diff * dt;
308 }
309 }
310
311 total.sqrt()
312 }
313
314 pub fn to_vector(&self) -> Vec<f64> {
316 self.landscapes
317 .iter()
318 .flat_map(|l| l.iter().copied())
319 .collect()
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: two dimension-0 pairs, (0, 1) and (0.5, 1.5), and one dimension-1 pair (0.2, 0.8).
    ///
    /// NOTE(review): assumes the argument order is `BirthDeathPair::finite(dim, birth, death)`.
    fn sample_diagram() -> PersistenceDiagram {
        let mut d = PersistenceDiagram::new();
        d.add(BirthDeathPair::finite(0, 0.0, 1.0));
        d.add(BirthDeathPair::finite(0, 0.5, 1.5));
        d.add(BirthDeathPair::finite(1, 0.2, 0.8));
        d
    }

    #[test]
    fn test_bottleneck_same() {
        let d = sample_diagram();
        let dist = BottleneckDistance::compute(&d, &d, 0);
        assert!(dist < 1e-10);
    }

    #[test]
    fn test_bottleneck_different() {
        let d1 = sample_diagram();
        let mut d2 = PersistenceDiagram::new();
        d2.add(BirthDeathPair::finite(0, 0.0, 2.0));

        // Optimal matching: (0.5, 1.5) <-> (0, 2) at L-inf cost 0.5, and (0, 1)
        // to the diagonal at cost 0.5.
        let dist = BottleneckDistance::compute(&d1, &d2, 0);
        assert!((dist - 0.5).abs() < 1e-10, "got {}", dist);
    }

    #[test]
    fn test_bottleneck_against_empty() {
        let d = sample_diagram();
        let empty = PersistenceDiagram::new();
        // Every dimension-0 point goes to the diagonal; each is at distance 0.5.
        let dist = BottleneckDistance::compute(&d, &empty, 0);
        assert!((dist - 0.5).abs() < 1e-10, "got {}", dist);
    }

    #[test]
    fn test_wasserstein() {
        let d1 = sample_diagram();
        let d2 = sample_diagram();

        let w1 = WassersteinDistance::new(1.0);
        let dist = w1.compute(&d1, &d2, 0);
        assert!(dist < 1e-10);
    }

    #[test]
    fn test_wasserstein_against_empty() {
        let d = sample_diagram();
        let empty = PersistenceDiagram::new();
        // Both dimension-0 points pay their diagonal cost 0.5.
        let w1 = WassersteinDistance::new(1.0).compute(&d, &empty, 0);
        assert!((w1 - 1.0).abs() < 1e-10, "got {}", w1);
        let w2 = WassersteinDistance::new(2.0).compute(&d, &empty, 0);
        assert!((w2 - 0.5f64.sqrt()).abs() < 1e-10, "got {}", w2);
    }

    #[test]
    fn test_persistence_landscape() {
        let d = sample_diagram();
        let landscape = PersistenceLandscape::from_diagram(&d, 0, 3, 20);

        assert_eq!(landscape.landscapes.len(), 3);
        assert_eq!(landscape.grid.len(), 20);
        // Landscapes are pointwise ordered: lambda_1 >= lambda_2 >= lambda_3 >= 0.
        for pair in landscape.landscapes.windows(2) {
            assert!(pair[0].iter().zip(&pair[1]).all(|(hi, lo)| hi >= lo));
        }
        assert!(landscape.landscapes.iter().flatten().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_landscape_distance() {
        let d1 = sample_diagram();
        let l1 = PersistenceLandscape::from_diagram(&d1, 0, 3, 20);
        let l2 = PersistenceLandscape::from_diagram(&d1, 0, 3, 20);

        let dist = l1.l2_distance(&l2);
        assert!(dist < 1e-10);
    }

    #[test]
    fn test_landscape_vector() {
        let d = sample_diagram();
        let landscape = PersistenceLandscape::from_diagram(&d, 0, 2, 10);

        let vec = landscape.to_vector();
        // 2 landscapes x 10 samples each.
        assert_eq!(vec.len(), 2 * 10);
    }
}