1use crate::error::{Result, TransformError};
18use crate::tda::alpha_complex::sym_diff_sorted;
19use std::collections::HashMap;
20
/// Configuration for building cubical complexes.
///
/// The type currently has no fields; it only provides `Default`, `Clone` and `Debug`.
#[derive(Debug, Clone, Default)]
pub struct CubicalConfig {}
28
/// One elementary cube of a cubical complex, addressed in doubled coordinates.
///
/// Along each axis an even coordinate `2k` denotes the single grid point `k`, and an odd
/// coordinate `2k + 1` denotes the unit interval between points `k` and `k + 1`. The
/// cell's dimension is therefore the number of odd coordinates.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CubicalCell {
    /// Doubled coordinates, one entry per spatial axis.
    pub coordinates: Vec<usize>,
    /// Number of odd entries in `coordinates` (0 = vertex, 1 = edge, 2 = square, ...).
    pub dimension: usize,
}

impl CubicalCell {
    /// Creates a cell from its doubled coordinates and derives its dimension from them.
    pub fn new(coordinates: Vec<usize>) -> Self {
        // Each odd coordinate contributes 1 (c % 2 == 1), each even one contributes 0.
        let dimension: usize = coordinates.iter().map(|c| c % 2).sum();
        Self {
            coordinates,
            dimension,
        }
    }

    /// Returns the `(dimension, coordinates)` pair used as an ordering key.
    ///
    /// Compared as a tuple, it orders by dimension first and then lexicographically by
    /// coordinates.
    pub fn sort_key(&self) -> (usize, &[usize]) {
        let dim = self.dimension;
        let coords = self.coordinates.as_slice();
        (dim, coords)
    }
}
64
65#[derive(Debug, Clone)]
71pub struct CubicalComplex {
72 pub cells: Vec<(CubicalCell, f64)>,
74 pub spatial_dim: usize,
76}
77
78impl CubicalComplex {
79 pub fn from_image_2d(image: &[Vec<f64>]) -> Result<Self> {
87 if image.is_empty() || image[0].is_empty() {
88 return Err(TransformError::InvalidInput(
89 "Image must be non-empty".to_string(),
90 ));
91 }
92 let rows = image.len();
93 let cols = image[0].len();
94 for row in image.iter() {
95 if row.len() != cols {
96 return Err(TransformError::InvalidInput(
97 "All image rows must have the same length".to_string(),
98 ));
99 }
100 }
101
102 let vertex_value = |r: usize, c: usize| -> f64 {
107 let pr = r.min(rows - 1);
109 let pc = c.min(cols - 1);
110 image[pr][pc]
111 };
112
113 let mut cells: Vec<(CubicalCell, f64)> = Vec::new();
114
115 for r in 0..=rows {
117 for c in 0..=cols {
118 let fv = vertex_value(r, c);
119 cells.push((CubicalCell::new(vec![2 * r, 2 * c]), fv));
120 }
121 }
122
123 for r in 0..=rows {
125 for c in 0..cols {
126 let fv = vertex_value(r, c).max(vertex_value(r, c + 1));
127 cells.push((CubicalCell::new(vec![2 * r, 2 * c + 1]), fv));
128 }
129 }
130
131 for r in 0..rows {
133 for c in 0..=cols {
134 let fv = vertex_value(r, c).max(vertex_value(r + 1, c));
135 cells.push((CubicalCell::new(vec![2 * r + 1, 2 * c]), fv));
136 }
137 }
138
139 for r in 0..rows {
141 for c in 0..cols {
142 let fv = vertex_value(r, c)
143 .max(vertex_value(r, c + 1))
144 .max(vertex_value(r + 1, c))
145 .max(vertex_value(r + 1, c + 1));
146 cells.push((CubicalCell::new(vec![2 * r + 1, 2 * c + 1]), fv));
147 }
148 }
149
150 cells.sort_by(|(ca, fa), (cb, fb)| {
152 fa.partial_cmp(fb)
153 .unwrap_or(std::cmp::Ordering::Equal)
154 .then(ca.dimension.cmp(&cb.dimension))
155 .then(ca.coordinates.cmp(&cb.coordinates))
156 });
157
158 Ok(Self {
159 cells,
160 spatial_dim: 2,
161 })
162 }
163
164 pub fn from_image_3d(image: &[Vec<Vec<f64>>]) -> Result<Self> {
166 if image.is_empty() || image[0].is_empty() || image[0][0].is_empty() {
167 return Err(TransformError::InvalidInput(
168 "3D image must be non-empty".to_string(),
169 ));
170 }
171 let slices = image.len();
172 let rows = image[0].len();
173 let cols = image[0][0].len();
174
175 let vertex_value = |s: usize, r: usize, c: usize| -> f64 {
176 let ps = s.min(slices - 1);
177 let pr = r.min(rows - 1);
178 let pc = c.min(cols - 1);
179 image[ps][pr][pc]
180 };
181
182 let mut cells: Vec<(CubicalCell, f64)> = Vec::new();
183
184 for s in 0..=slices {
186 for r in 0..=rows {
187 for c in 0..=cols {
188 let fv = vertex_value(s, r, c);
189 cells.push((CubicalCell::new(vec![2 * s, 2 * r, 2 * c]), fv));
190 }
191 }
192 }
193
194 for s in 0..slices {
197 for r in 0..=rows {
198 for c in 0..=cols {
199 let fv = vertex_value(s, r, c).max(vertex_value(s + 1, r, c));
200 cells.push((CubicalCell::new(vec![2 * s + 1, 2 * r, 2 * c]), fv));
201 }
202 }
203 }
204 for s in 0..=slices {
206 for r in 0..rows {
207 for c in 0..=cols {
208 let fv = vertex_value(s, r, c).max(vertex_value(s, r + 1, c));
209 cells.push((CubicalCell::new(vec![2 * s, 2 * r + 1, 2 * c]), fv));
210 }
211 }
212 }
213 for s in 0..=slices {
215 for r in 0..=rows {
216 for c in 0..cols {
217 let fv = vertex_value(s, r, c).max(vertex_value(s, r, c + 1));
218 cells.push((CubicalCell::new(vec![2 * s, 2 * r, 2 * c + 1]), fv));
219 }
220 }
221 }
222
223 for s in 0..slices {
226 for r in 0..rows {
227 for c in 0..=cols {
228 let fv = vertex_value(s, r, c)
229 .max(vertex_value(s + 1, r, c))
230 .max(vertex_value(s, r + 1, c))
231 .max(vertex_value(s + 1, r + 1, c));
232 cells.push((CubicalCell::new(vec![2 * s + 1, 2 * r + 1, 2 * c]), fv));
233 }
234 }
235 }
236 for s in 0..slices {
238 for r in 0..=rows {
239 for c in 0..cols {
240 let fv = vertex_value(s, r, c)
241 .max(vertex_value(s + 1, r, c))
242 .max(vertex_value(s, r, c + 1))
243 .max(vertex_value(s + 1, r, c + 1));
244 cells.push((CubicalCell::new(vec![2 * s + 1, 2 * r, 2 * c + 1]), fv));
245 }
246 }
247 }
248 for s in 0..=slices {
250 for r in 0..rows {
251 for c in 0..cols {
252 let fv = vertex_value(s, r, c)
253 .max(vertex_value(s, r + 1, c))
254 .max(vertex_value(s, r, c + 1))
255 .max(vertex_value(s, r + 1, c + 1));
256 cells.push((CubicalCell::new(vec![2 * s, 2 * r + 1, 2 * c + 1]), fv));
257 }
258 }
259 }
260
261 for s in 0..slices {
263 for r in 0..rows {
264 for c in 0..cols {
265 let fv = vertex_value(s, r, c)
266 .max(vertex_value(s + 1, r, c))
267 .max(vertex_value(s, r + 1, c))
268 .max(vertex_value(s, r, c + 1))
269 .max(vertex_value(s + 1, r + 1, c))
270 .max(vertex_value(s + 1, r, c + 1))
271 .max(vertex_value(s, r + 1, c + 1))
272 .max(vertex_value(s + 1, r + 1, c + 1));
273 cells.push((CubicalCell::new(vec![2 * s + 1, 2 * r + 1, 2 * c + 1]), fv));
274 }
275 }
276 }
277
278 cells.sort_by(|(ca, fa), (cb, fb)| {
279 fa.partial_cmp(fb)
280 .unwrap_or(std::cmp::Ordering::Equal)
281 .then(ca.dimension.cmp(&cb.dimension))
282 .then(ca.coordinates.cmp(&cb.coordinates))
283 });
284
285 Ok(Self {
286 cells,
287 spatial_dim: 3,
288 })
289 }
290
291 pub fn boundary(&self, cell: &CubicalCell) -> Vec<CubicalCell> {
296 if cell.dimension == 0 {
297 return Vec::new();
298 }
299 let mut faces = Vec::new();
300 for (i, &coord) in cell.coordinates.iter().enumerate() {
301 if coord % 2 == 1 {
302 let mut coords_low = cell.coordinates.clone();
305 coords_low[i] = coord - 1;
306 faces.push(CubicalCell::new(coords_low));
307 let mut coords_high = cell.coordinates.clone();
309 coords_high[i] = coord + 1;
310 faces.push(CubicalCell::new(coords_high));
311 }
312 }
313 faces
314 }
315
316 pub fn persistence_diagram(&self) -> Vec<(f64, f64, usize)> {
320 let n = self.cells.len();
321
322 let cell_index: HashMap<&CubicalCell, usize> = self
324 .cells
325 .iter()
326 .enumerate()
327 .map(|(i, (c, _))| (c, i))
328 .collect();
329
330 let mut columns: Vec<Vec<usize>> = self
332 .cells
333 .iter()
334 .map(|(cell, _)| {
335 let mut col: Vec<usize> = self
336 .boundary(cell)
337 .iter()
338 .filter_map(|face| cell_index.get(face).copied())
339 .collect();
340 col.sort_unstable();
341 col
342 })
343 .collect();
344
345 let mut pivot_col: HashMap<usize, usize> = HashMap::new();
347 let mut pairs: Vec<(f64, f64, usize)> = Vec::new();
348
349 for j in 0..n {
350 while let Some(&pivot) = columns[j].last() {
351 if let Some(&k) = pivot_col.get(&pivot) {
352 let col_k = columns[k].clone();
353 sym_diff_sorted(&mut columns[j], &col_k);
354 } else {
355 break;
356 }
357 }
358
359 if let Some(&pivot) = columns[j].last() {
360 pivot_col.insert(pivot, j);
361 let birth_idx = pivot;
362 let death_idx = j;
363 let (birth_cell, birth_fv) = &self.cells[birth_idx];
364 let (_, death_fv) = &self.cells[death_idx];
365 let dim = birth_cell.dimension;
366 pairs.push((*birth_fv, *death_fv, dim));
367 }
368 }
369
370 pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
371 pairs
372 }
373
374 pub fn cell_count(&self, dim: usize) -> usize {
376 self.cells
377 .iter()
378 .filter(|(c, _)| c.dimension == dim)
379 .count()
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3x3 "plus"-shaped image: value 1.0 on the middle row and column, 0.0 in corners.
    fn simple_3x3() -> Vec<Vec<f64>> {
        let (lo, hi) = (0.0, 1.0);
        vec![vec![lo, hi, lo], vec![hi, hi, hi], vec![lo, hi, lo]]
    }

    #[test]
    fn test_cell_counts_3x3() {
        let complex = CubicalComplex::from_image_2d(&simple_3x3()).expect("should build");
        // A 3x3 image yields (3 + 1)^2 vertices, 2 * 3 * (3 + 1) edges and 3^2 squares.
        let expectations = [
            (0, 16, "Expected 16 vertices"),
            (1, 24, "Expected 24 edges"),
            (2, 9, "Expected 9 faces"),
        ];
        for (dim, count, msg) in expectations {
            assert_eq!(complex.cell_count(dim), count, "{}", msg);
        }
        assert_eq!(complex.cells.len(), 49, "Expected 49 total cells");
    }

    #[test]
    fn test_boundary_of_edge() {
        // Doubled coordinates (0, 1): the edge between grid points (0, 0) and (0, 1).
        let edge = CubicalCell::new(vec![0, 1]);
        assert_eq!(edge.dimension, 1);
        let complex = CubicalComplex::from_image_2d(&[vec![1.0, 2.0], vec![3.0, 4.0]])
            .expect("should build");
        let faces = complex.boundary(&edge);
        assert_eq!(faces.len(), 2);
        assert!(faces.iter().all(|face| face.dimension == 0));
    }

    #[test]
    fn test_boundary_of_vertex_empty() {
        let complex = CubicalComplex::from_image_2d(&[vec![1.0]]).expect("should build");
        let vertex = CubicalCell::new(vec![0, 0]);
        assert!(complex.boundary(&vertex).is_empty());
    }

    #[test]
    fn test_persistence_diagram_non_trivial() {
        let complex = CubicalComplex::from_image_2d(&simple_3x3()).expect("should build");
        let diag = complex.persistence_diagram();
        assert!(!diag.is_empty(), "Expected non-empty persistence diagram");
        for &(birth, death, _) in diag.iter() {
            assert!(birth <= death, "birth={birth} > death={death}");
        }
    }

    #[test]
    fn test_from_image_3d_basic() {
        let vol = vec![
            vec![vec![0.0, 1.0], vec![1.0, 0.0]],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        ];
        let complex = CubicalComplex::from_image_3d(&vol).expect("3D build should succeed");
        // A 2x2x2 volume yields (2 + 1)^3 vertices.
        assert_eq!(complex.cell_count(0), 27, "Expected 27 vertices");
        assert!(complex.cell_count(3) > 0, "Expected 3-cells");
    }

    #[test]
    fn test_invalid_empty_image() {
        assert!(CubicalComplex::from_image_2d(&[]).is_err());
    }
}
463}