1use crate::error::NdimageError;
8use std::collections::VecDeque;
9
/// Dense 4-D array addressed by `(t, d, h, w)` coordinates.
///
/// Elements live in one flat buffer in row-major order: `w` varies fastest, then `h`,
/// then `d`, then `t` (see `Array4D::index`).
#[derive(Clone, Debug)]
pub struct Array4D<T> {
    /// Flat element storage; its length is expected to equal the product of `shape`.
    data: Vec<T>,
    /// Extents along the `(t, d, h, w)` axes.
    shape: [usize; 4],
}
20
21impl<T: Clone + Default> Array4D<T> {
22 pub fn new(shape: [usize; 4], fill: T) -> Self {
24 let n = shape[0] * shape[1] * shape[2] * shape[3];
25 Array4D {
26 data: vec![fill; n],
27 shape,
28 }
29 }
30
31 pub fn from_data(data: Vec<T>, shape: [usize; 4]) -> Result<Self, NdimageError> {
33 let expected = shape[0] * shape[1] * shape[2] * shape[3];
34 if data.len() != expected {
35 return Err(NdimageError::DimensionError(format!(
36 "from_data: data length {} does not match shape {:?} (expected {})",
37 data.len(),
38 shape,
39 expected
40 )));
41 }
42 Ok(Array4D { data, shape })
43 }
44
45 pub fn shape(&self) -> [usize; 4] {
47 self.shape
48 }
49
50 pub fn n_elements(&self) -> usize {
52 self.shape[0] * self.shape[1] * self.shape[2] * self.shape[3]
53 }
54
55 pub fn index(&self, t: usize, d: usize, h: usize, w: usize) -> usize {
57 t * (self.shape[1] * self.shape[2] * self.shape[3])
58 + d * (self.shape[2] * self.shape[3])
59 + h * self.shape[3]
60 + w
61 }
62
63 pub fn get(&self, t: usize, d: usize, h: usize, w: usize) -> Option<&T> {
65 if t < self.shape[0] && d < self.shape[1] && h < self.shape[2] && w < self.shape[3] {
66 let idx = self.index(t, d, h, w);
67 self.data.get(idx)
68 } else {
69 None
70 }
71 }
72
73 pub fn get_mut(&mut self, t: usize, d: usize, h: usize, w: usize) -> Option<&mut T> {
75 if t < self.shape[0] && d < self.shape[1] && h < self.shape[2] && w < self.shape[3] {
76 let idx = self.index(t, d, h, w);
77 self.data.get_mut(idx)
78 } else {
79 None
80 }
81 }
82
83 pub fn set(
85 &mut self,
86 t: usize,
87 d: usize,
88 h: usize,
89 w: usize,
90 value: T,
91 ) -> Result<(), NdimageError> {
92 if t >= self.shape[0] || d >= self.shape[1] || h >= self.shape[2] || w >= self.shape[3] {
93 return Err(NdimageError::InvalidInput(format!(
94 "set: index ({},{},{},{}) out of bounds for shape {:?}",
95 t, d, h, w, self.shape
96 )));
97 }
98 let idx = self.index(t, d, h, w);
99 self.data[idx] = value;
100 Ok(())
101 }
102
103 pub fn slice_time(&self, t: usize) -> Result<Vec<Vec<Vec<T>>>, NdimageError> {
105 if t >= self.shape[0] {
106 return Err(NdimageError::InvalidInput(format!(
107 "slice_time: t={} out of bounds (shape[0]={})",
108 t, self.shape[0]
109 )));
110 }
111 let (nt, nd, nh, nw) = (self.shape[0], self.shape[1], self.shape[2], self.shape[3]);
112 let _ = nt; let mut volume = Vec::with_capacity(nd);
114 for d in 0..nd {
115 let mut plane = Vec::with_capacity(nh);
116 for h in 0..nh {
117 let mut row = Vec::with_capacity(nw);
118 for w in 0..nw {
119 let idx = self.index(t, d, h, w);
120 row.push(self.data[idx].clone());
121 }
122 plane.push(row);
123 }
124 volume.push(plane);
125 }
126 Ok(volume)
127 }
128
129 pub fn slice_spatial(&self, d: usize, h: usize, w: usize) -> Result<Vec<T>, NdimageError> {
131 if d >= self.shape[1] || h >= self.shape[2] || w >= self.shape[3] {
132 return Err(NdimageError::InvalidInput(format!(
133 "slice_spatial: ({},{},{}) out of bounds for shape {:?}",
134 d, h, w, self.shape
135 )));
136 }
137 let nt = self.shape[0];
138 let mut result = Vec::with_capacity(nt);
139 for t in 0..nt {
140 let idx = self.index(t, d, h, w);
141 result.push(self.data[idx].clone());
142 }
143 Ok(result)
144 }
145
146 pub fn data(&self) -> &[T] {
148 &self.data
149 }
150
151 pub fn data_mut(&mut self) -> &mut Vec<T> {
153 &mut self.data
154 }
155}
156
/// Integer label volume. As produced by `connected_components_4d` in this module, `0`
/// marks background and each connected foreground component gets a distinct id starting
/// at `1`; `track_regions_4d` treats every label `> 0` as a region.
pub type Label4D = Array4D<usize>;
163
/// Builds a normalised, symmetric 1-D Gaussian kernel of standard deviation `sigma`.
///
/// The kernel is truncated at `ceil(3 * sigma)` taps on each side of the centre, giving
/// `2 * ceil(3 * sigma) + 1` weights that sum to one. A `sigma <= 0.0` returns the
/// single-tap identity kernel `[1.0]`. A NaN `sigma` fails that guard and yields the
/// one-tap kernel `[NaN]`.
fn gaussian_kernel_1d(sigma: f64) -> Vec<f64> {
    if sigma <= 0.0 {
        return vec![1.0];
    }
    let radius = (3.0 * sigma).ceil() as usize;
    let denom = 2.0 * sigma * sigma;

    let mut weights: Vec<f64> = (0..2 * radius + 1)
        .map(|i| {
            let x = i as f64 - radius as f64;
            (-x * x / denom).exp()
        })
        .collect();

    // Sequential left-to-right sum, then rescale so the weights sum to one.
    let total = weights.iter().fold(0.0, |acc, &v| acc + v);
    weights.iter_mut().for_each(|v| *v /= total);
    weights
}
189
190fn convolve_1d(data: &[f64], kernel: &[f64]) -> Vec<f64> {
192 let n = data.len();
193 let k = kernel.len();
194 let half = k / 2;
195 let mut out = vec![0.0f64; n];
196 for i in 0..n {
197 let mut val = 0.0;
198 for j in 0..k {
199 let src = i as isize + j as isize - half as isize;
200 let idx = reflect_index(src, n);
202 val += data[idx] * kernel[j];
203 }
204 out[i] = val;
205 }
206 out
207}
208
/// Maps an arbitrary (possibly negative or past-the-end) index onto `0..n` using
/// half-sample symmetric reflection: `... c b a | a b c | c b a | a b c ...`.
///
/// The extended signal is periodic with period `2n`, so arbitrarily distant indices are
/// handled correctly. This matters when a kernel radius exceeds the axis length, e.g. a
/// temporal Gaussian over only a few frames. The previous implementation reflected only
/// once and then clamped, which mapped such far indices to the edge sample instead.
///
/// Returns `0` when `n == 0`. The caller must not index an empty slice with the result.
fn reflect_index(i: isize, n: usize) -> usize {
    if n == 0 {
        return 0;
    }
    // `n` is a slice length, so it fits in `isize`.
    let n = n as isize;
    let period = 2 * n;
    // Position within one period, always in `0..2n`.
    let m = i.rem_euclid(period);
    // The second half of the period is the mirrored copy.
    let folded = if m >= n { period - 1 - m } else { m };
    folded as usize
}
223
224pub fn gaussian_filter_4d(arr: &Array4D<f64>, sigma_t: f64, sigma_s: f64) -> Array4D<f64> {
233 let [nt, nd, nh, nw] = arr.shape();
234 let kernel_t = gaussian_kernel_1d(sigma_t);
235 let kernel_s = gaussian_kernel_1d(sigma_s);
236
237 let mut buf: Vec<f64> = arr.data().to_vec();
239
240 let idx = |t: usize, d: usize, h: usize, w: usize| -> usize {
242 t * (nd * nh * nw) + d * (nh * nw) + h * nw + w
243 };
244
245 if kernel_t.len() > 1 {
247 let src = buf.clone();
248 for d in 0..nd {
249 for h in 0..nh {
250 for w in 0..nw {
251 let slice: Vec<f64> = (0..nt).map(|t| src[idx(t, d, h, w)]).collect();
252 let smoothed = convolve_1d(&slice, &kernel_t);
253 for t in 0..nt {
254 buf[idx(t, d, h, w)] = smoothed[t];
255 }
256 }
257 }
258 }
259 }
260
261 if kernel_s.len() > 1 {
263 let src = buf.clone();
264 for t in 0..nt {
265 for h in 0..nh {
266 for w in 0..nw {
267 let slice: Vec<f64> = (0..nd).map(|d| src[idx(t, d, h, w)]).collect();
268 let smoothed = convolve_1d(&slice, &kernel_s);
269 for d in 0..nd {
270 buf[idx(t, d, h, w)] = smoothed[d];
271 }
272 }
273 }
274 }
275 }
276
277 if kernel_s.len() > 1 {
279 let src = buf.clone();
280 for t in 0..nt {
281 for d in 0..nd {
282 for w in 0..nw {
283 let slice: Vec<f64> = (0..nh).map(|h| src[idx(t, d, h, w)]).collect();
284 let smoothed = convolve_1d(&slice, &kernel_s);
285 for h in 0..nh {
286 buf[idx(t, d, h, w)] = smoothed[h];
287 }
288 }
289 }
290 }
291 }
292
293 if kernel_s.len() > 1 {
295 let src = buf.clone();
296 for t in 0..nt {
297 for d in 0..nd {
298 for h in 0..nh {
299 let slice: Vec<f64> = (0..nw).map(|w| src[idx(t, d, h, w)]).collect();
300 let smoothed = convolve_1d(&slice, &kernel_s);
301 for w in 0..nw {
302 buf[idx(t, d, h, w)] = smoothed[w];
303 }
304 }
305 }
306 }
307 }
308
309 Array4D {
310 data: buf,
311 shape: [nt, nd, nh, nw],
312 }
313}
314
315pub fn diff_4d_temporal(arr: &Array4D<f64>) -> Array4D<f64> {
322 let [nt, nd, nh, nw] = arr.shape();
323 if nt < 2 {
324 return Array4D::new([0, nd, nh, nw], 0.0);
325 }
326 let out_t = nt - 1;
327 let mut out = Array4D::new([out_t, nd, nh, nw], 0.0);
328 for t in 0..out_t {
329 for d in 0..nd {
330 for h in 0..nh {
331 for w in 0..nw {
332 let v1 = arr.get(t + 1, d, h, w).copied().unwrap_or(0.0);
333 let v0 = arr.get(t, d, h, w).copied().unwrap_or(0.0);
334 let _ = out.set(t, d, h, w, v1 - v0);
335 }
336 }
337 }
338 }
339 out
340}
341
342pub fn max_intensity_projection_4d(
353 arr: &Array4D<f64>,
354 axis: usize,
355) -> Result<Array4D<f64>, NdimageError> {
356 let [nt, nd, nh, nw] = arr.shape();
357 if axis > 3 {
358 return Err(NdimageError::InvalidInput(format!(
359 "max_intensity_projection_4d: axis {} invalid for 4D array",
360 axis
361 )));
362 }
363
364 let out_shape = match axis {
365 0 => [1, nd, nh, nw],
366 1 => [nt, 1, nh, nw],
367 2 => [nt, nd, 1, nw],
368 3 => [nt, nd, nh, 1],
369 _ => unreachable!(),
370 };
371
372 let mut out = Array4D::new(out_shape, f64::NEG_INFINITY);
373
374 for t in 0..nt {
375 for d in 0..nd {
376 for h in 0..nh {
377 for w in 0..nw {
378 let v = arr.get(t, d, h, w).copied().unwrap_or(f64::NEG_INFINITY);
379 let (ot, od, oh, ow) = match axis {
380 0 => (0, d, h, w),
381 1 => (t, 0, h, w),
382 2 => (t, d, 0, w),
383 3 => (t, d, h, 0),
384 _ => unreachable!(),
385 };
386 if let Some(cur) = out.get_mut(ot, od, oh, ow) {
387 if v > *cur {
388 *cur = v;
389 }
390 }
391 }
392 }
393 }
394 }
395
396 Ok(out)
397}
398
/// Axis-aligned neighbours of `(t, d, h, w)` that lie inside `shape`.
///
/// The `6` in the name comes from 3-D face connectivity; in 4-D this yields up to eight
/// neighbours. It takes one step backwards and one step forwards along each of the t, d,
/// h and w axes, emitted in that axis order with the backward step first.
fn neighbors_6_4d(
    t: usize,
    d: usize,
    h: usize,
    w: usize,
    shape: [usize; 4],
) -> Vec<(usize, usize, usize, usize)> {
    let centre = [t, d, h, w];
    let as_tuple = |p: [usize; 4]| (p[0], p[1], p[2], p[3]);
    let mut found = Vec::with_capacity(8);
    for axis in 0..4 {
        if centre[axis] > 0 {
            let mut below = centre;
            below[axis] -= 1;
            found.push(as_tuple(below));
        }
        if centre[axis] + 1 < shape[axis] {
            let mut above = centre;
            above[axis] += 1;
            found.push(as_tuple(above));
        }
    }
    found
}
439
/// All neighbours of `(t, d, h, w)` within the surrounding 3x3x3x3 block, clipped to `shape`.
///
/// The `26` in the name comes from 3-D full connectivity; in 4-D this yields up to
/// `3^4 - 1 = 80` neighbours. Results are ordered lexicographically by `(t, d, h, w)`, and
/// the centre itself is excluded.
fn neighbors_26_4d(
    t: usize,
    d: usize,
    h: usize,
    w: usize,
    shape: [usize; 4],
) -> Vec<(usize, usize, usize, usize)> {
    let [nt, nd, nh, nw] = shape;
    // Inclusive window [c - 1, c + 1] along one axis, clipped to the array.
    let window = |c: usize, n: usize| {
        let lo = c.saturating_sub(1);
        let hi = if c + 1 < n { c + 1 } else { c };
        lo..=hi
    };

    let mut found = Vec::with_capacity(80);
    for tt in window(t, nt) {
        for dd in window(d, nd) {
            for hh in window(h, nh) {
                for ww in window(w, nw) {
                    if (tt, dd, hh, ww) != (t, d, h, w) {
                        found.push((tt, dd, hh, ww));
                    }
                }
            }
        }
    }
    found
}
478
479pub fn connected_components_4d(binary: &Array4D<bool>, connectivity_26: bool) -> Label4D {
485 let shape = binary.shape();
486 let [nt, nd, nh, nw] = shape;
487 let mut labels = Label4D::new(shape, 0usize);
488 let mut current_label = 0usize;
489
490 let flat_idx = |t: usize, d: usize, h: usize, w: usize| -> usize {
491 t * (nd * nh * nw) + d * (nh * nw) + h * nw + w
492 };
493
494 let n_total = nt * nd * nh * nw;
495 let mut visited = vec![false; n_total];
497
498 for t in 0..nt {
499 for d in 0..nd {
500 for h in 0..nh {
501 for w in 0..nw {
502 let fi = flat_idx(t, d, h, w);
503 let is_fg = binary.get(t, d, h, w).copied().unwrap_or(false);
504 if !is_fg || visited[fi] {
505 continue;
506 }
507 current_label += 1;
509 let lbl = current_label;
510 let mut queue = VecDeque::new();
511 queue.push_back((t, d, h, w));
512 visited[fi] = true;
513 let _ = labels.set(t, d, h, w, lbl);
514
515 while let Some((ct, cd, ch, cw)) = queue.pop_front() {
516 let neighbors = if connectivity_26 {
517 neighbors_26_4d(ct, cd, ch, cw, shape)
518 } else {
519 neighbors_6_4d(ct, cd, ch, cw, shape)
520 };
521 for (nt2, nd2, nh2, nw2) in neighbors {
522 let nfi = flat_idx(nt2, nd2, nh2, nw2);
523 if visited[nfi] {
524 continue;
525 }
526 let nfg = binary.get(nt2, nd2, nh2, nw2).copied().unwrap_or(false);
527 if nfg {
528 visited[nfi] = true;
529 let _ = labels.set(nt2, nd2, nh2, nw2, lbl);
530 queue.push_back((nt2, nd2, nh2, nw2));
531 }
532 }
533 }
534 }
535 }
536 }
537 }
538
539 labels
540}
541
/// One region followed through consecutive time frames by `track_regions_4d`.
#[derive(Clone, Debug)]
pub struct TrackletResult {
    /// Tracklet identifier, assigned from `1` upwards.
    pub id: usize,
    /// Frame index at which the tracklet was first seen.
    pub start_time: usize,
    /// Frame indices the tracklet covers, in the order they were appended.
    pub frames: Vec<usize>,
    /// Mean `(d, h, w)` voxel coordinate of the region, one entry per element of `frames`.
    pub centroid_per_frame: Vec<[f64; 3]>,
}
558
559fn region_centroid(labeled: &Label4D, t: usize, label: usize) -> Option<[f64; 3]> {
561 let [_nt, nd, nh, nw] = labeled.shape();
562 let mut sum_d = 0.0f64;
563 let mut sum_h = 0.0f64;
564 let mut sum_w = 0.0f64;
565 let mut count = 0usize;
566 for d in 0..nd {
567 for h in 0..nh {
568 for w in 0..nw {
569 if labeled.get(t, d, h, w).copied().unwrap_or(0) == label {
570 sum_d += d as f64;
571 sum_h += h as f64;
572 sum_w += w as f64;
573 count += 1;
574 }
575 }
576 }
577 }
578 if count == 0 {
579 None
580 } else {
581 Some([
582 sum_d / count as f64,
583 sum_h / count as f64,
584 sum_w / count as f64,
585 ])
586 }
587}
588
589pub fn track_regions_4d(labeled: &Label4D) -> Vec<TrackletResult> {
594 let [nt, nd, nh, nw] = labeled.shape();
595
596 let mut frame_labels: Vec<std::collections::HashSet<usize>> = Vec::with_capacity(nt);
598 for t in 0..nt {
599 let mut set = std::collections::HashSet::new();
600 for d in 0..nd {
601 for h in 0..nh {
602 for w in 0..nw {
603 let lbl = labeled.get(t, d, h, w).copied().unwrap_or(0);
604 if lbl > 0 {
605 set.insert(lbl);
606 }
607 }
608 }
609 }
610 frame_labels.push(set);
611 }
612
613 let mut assignment: std::collections::HashMap<(usize, usize), usize> =
615 std::collections::HashMap::new();
616 let mut tracklets: Vec<TrackletResult> = Vec::new();
617 let mut next_id = 1usize;
618
619 for &lbl in &frame_labels[0] {
621 let id = next_id;
622 next_id += 1;
623 assignment.insert((0, lbl), id);
624 let centroid = region_centroid(labeled, 0, lbl).unwrap_or([0.0; 3]);
625 tracklets.push(TrackletResult {
626 id,
627 start_time: 0,
628 frames: vec![0],
629 centroid_per_frame: vec![centroid],
630 });
631 }
632
633 for t in 1..nt {
635 for &lbl_t in &frame_labels[t] {
637 let mut overlap: std::collections::HashMap<usize, usize> =
639 std::collections::HashMap::new();
640 for d in 0..nd {
641 for h in 0..nh {
642 for w in 0..nw {
643 let cur = labeled.get(t, d, h, w).copied().unwrap_or(0);
644 let prev = labeled.get(t - 1, d, h, w).copied().unwrap_or(0);
645 if cur == lbl_t && prev > 0 {
646 *overlap.entry(prev).or_insert(0) += 1;
647 }
648 }
649 }
650 }
651
652 let best_prev = overlap.iter().max_by_key(|(_, &cnt)| cnt).map(|(&k, _)| k);
654
655 let centroid = region_centroid(labeled, t, lbl_t).unwrap_or([0.0; 3]);
656
657 if let Some(prev_lbl) = best_prev {
658 if let Some(&tid) = assignment.get(&(t - 1, prev_lbl)) {
659 assignment.insert((t, lbl_t), tid);
661 if let Some(tk) = tracklets.iter_mut().find(|tk| tk.id == tid) {
662 tk.frames.push(t);
663 tk.centroid_per_frame.push(centroid);
664 }
665 continue;
666 }
667 }
668
669 let id = next_id;
671 next_id += 1;
672 assignment.insert((t, lbl_t), id);
673 tracklets.push(TrackletResult {
674 id,
675 start_time: t,
676 frames: vec![t],
677 centroid_per_frame: vec![centroid],
678 });
679 }
680 }
681
682 tracklets
683}
684
// Unit tests for the 4-D array container and the free functions of this module.
#[cfg(test)]
mod tests {
    use super::*;

    // Construction, shape / element-count queries, and a checked set/get round trip.
    #[test]
    fn test_array4d_create_and_get_set() {
        let shape = [2, 3, 4, 5];
        let mut arr: Array4D<f64> = Array4D::new(shape, 0.0);
        assert_eq!(arr.shape(), shape);
        assert_eq!(arr.n_elements(), 2 * 3 * 4 * 5);

        arr.set(1, 2, 3, 4, 42.0).expect("set failed");
        assert_eq!(arr.get(1, 2, 3, 4).copied(), Some(42.0));
        // t = 2 is one past the last frame, so `get` must report out of bounds.
        assert!(arr.get(2, 0, 0, 0).is_none());
    }

    // `from_data` must reject a buffer whose length (10) differs from the shape product (120).
    #[test]
    fn test_from_data_shape_mismatch() {
        let data = vec![1.0f64; 10];
        let result = Array4D::from_data(data, [2, 3, 4, 5]);
        assert!(result.is_err());
    }

    // `slice_time` returns a nested [d][h][w] volume for a single frame.
    #[test]
    fn test_slice_time() {
        let shape = [3, 2, 4, 5];
        let mut arr: Array4D<f64> = Array4D::new(shape, 0.0);
        arr.set(1, 1, 2, 3, 7.0).expect("set failed");
        let vol = arr.slice_time(1).expect("slice_time failed");
        assert_eq!(vol.len(), 2);
        assert_eq!(vol[1].len(), 4);
        assert_eq!(vol[1][2].len(), 5);
        assert!((vol[1][2][3] - 7.0).abs() < 1e-12);
    }

    // Two 2x2x2 cubes that touch only diagonally, at voxels (1,1,1) and (2,2,2), must be
    // separate components under axis-aligned connectivity (`connectivity_26 == false`).
    #[test]
    fn test_connected_components_4d_two_cubes() {
        let shape = [1, 4, 4, 4];
        let mut binary: Array4D<bool> = Array4D::new(shape, false);
        for d in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    binary.set(0, d, h, w, true).expect("set failed");
                }
            }
        }
        for d in 2..4 {
            for h in 2..4 {
                for w in 2..4 {
                    binary.set(0, d, h, w, true).expect("set failed");
                }
            }
        }
        let labels = connected_components_4d(&binary, false);
        // Count distinct non-background labels.
        let mut unique: std::collections::HashSet<usize> = std::collections::HashSet::new();
        for v in labels.data().iter() {
            if *v > 0 {
                unique.insert(*v);
            }
        }
        assert_eq!(unique.len(), 2, "Expected exactly 2 connected components");
    }

    // Frames hold 1, 3, 6 everywhere, so the forward differences are 2 and 3.
    #[test]
    fn test_diff_4d_temporal() {
        let shape = [3, 2, 2, 2];
        let mut arr: Array4D<f64> = Array4D::new(shape, 0.0);
        for d in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    arr.set(0, d, h, w, 1.0).unwrap();
                    arr.set(1, d, h, w, 3.0).unwrap();
                    arr.set(2, d, h, w, 6.0).unwrap();
                }
            }
        }
        let diff = diff_4d_temporal(&arr);
        assert_eq!(diff.shape()[0], 2);
        assert!((diff.get(0, 0, 0, 0).copied().unwrap() - 2.0).abs() < 1e-12);
        assert!((diff.get(1, 0, 0, 0).copied().unwrap() - 3.0).abs() < 1e-12);
    }

    // Projecting along t (axis 0) keeps the larger of the two frame values (10 over 5).
    #[test]
    fn test_mip_4d() {
        let shape = [2, 2, 2, 2];
        let mut arr: Array4D<f64> = Array4D::new(shape, 0.0);
        arr.set(0, 0, 0, 0, 5.0).unwrap();
        arr.set(1, 0, 0, 0, 10.0).unwrap();
        let mip = max_intensity_projection_4d(&arr, 0).expect("mip failed");
        assert_eq!(mip.shape()[0], 1);
        assert!((mip.get(0, 0, 0, 0).copied().unwrap() - 10.0).abs() < 1e-12);
    }

    // With both sigmas zero every kernel is the single tap [1.0], so the input is unchanged.
    #[test]
    fn test_gaussian_filter_4d_identity_sigma_zero() {
        let shape = [2, 2, 2, 2];
        let mut arr: Array4D<f64> = Array4D::new(shape, 1.0);
        arr.set(0, 0, 0, 0, 5.0).unwrap();
        let out = gaussian_filter_4d(&arr, 0.0, 0.0);
        assert!((out.get(0, 0, 0, 0).copied().unwrap() - 5.0).abs() < 1e-12);
    }

    // The same cube occupies frames 0 and 1. It forms one 4-D component (time is connected),
    // so tracking must produce a tracklet covering both frames.
    #[test]
    fn test_track_regions_4d() {
        let shape = [2, 3, 3, 3];
        let mut binary: Array4D<bool> = Array4D::new(shape, false);
        for d in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    binary.set(0, d, h, w, true).unwrap();
                    binary.set(1, d, h, w, true).unwrap();
                }
            }
        }
        let labeled = connected_components_4d(&binary, false);
        let tracklets = track_regions_4d(&labeled);
        assert!(!tracklets.is_empty());
        let multi_frame = tracklets.iter().any(|tk| tk.frames.len() >= 2);
        assert!(
            multi_frame,
            "Expected at least one tracklet spanning 2 frames"
        );
    }
}