Skip to main content

scirs2_ndimage/array4d/
mod.rs

//! 4D spatiotemporal array support for scirs2-ndimage.
//!
//! Provides `Array4D<T>` with shape `[T×D×H×W]` (time, depth, height, width) and
//! operations including Gaussian filtering, temporal differencing, maximum intensity
//! projection (MIP), 4D connected-component labelling, and region tracking.
6
7use crate::error::NdimageError;
8use std::collections::VecDeque;
9
10// ---------------------------------------------------------------------------
11// Array4D
12// ---------------------------------------------------------------------------
13
14/// A 4-dimensional array with layout [time, depth, height, width].
15#[derive(Debug, Clone)]
16pub struct Array4D<T> {
17    data: Vec<T>,
18    shape: [usize; 4],
19}
20
21impl<T: Clone + Default> Array4D<T> {
22    /// Create a new array filled with `fill`.
23    pub fn new(shape: [usize; 4], fill: T) -> Self {
24        let n = shape[0] * shape[1] * shape[2] * shape[3];
25        Array4D {
26            data: vec![fill; n],
27            shape,
28        }
29    }
30
31    /// Create from flat data vector. Returns error if length mismatches.
32    pub fn from_data(data: Vec<T>, shape: [usize; 4]) -> Result<Self, NdimageError> {
33        let expected = shape[0] * shape[1] * shape[2] * shape[3];
34        if data.len() != expected {
35            return Err(NdimageError::DimensionError(format!(
36                "from_data: data length {} does not match shape {:?} (expected {})",
37                data.len(),
38                shape,
39                expected
40            )));
41        }
42        Ok(Array4D { data, shape })
43    }
44
45    /// Return the shape [T, D, H, W].
46    pub fn shape(&self) -> [usize; 4] {
47        self.shape
48    }
49
50    /// Total number of elements.
51    pub fn n_elements(&self) -> usize {
52        self.shape[0] * self.shape[1] * self.shape[2] * self.shape[3]
53    }
54
55    /// Flat index for (t, d, h, w).
56    pub fn index(&self, t: usize, d: usize, h: usize, w: usize) -> usize {
57        t * (self.shape[1] * self.shape[2] * self.shape[3])
58            + d * (self.shape[2] * self.shape[3])
59            + h * self.shape[3]
60            + w
61    }
62
63    /// Get immutable reference to element, returns None if out of bounds.
64    pub fn get(&self, t: usize, d: usize, h: usize, w: usize) -> Option<&T> {
65        if t < self.shape[0] && d < self.shape[1] && h < self.shape[2] && w < self.shape[3] {
66            let idx = self.index(t, d, h, w);
67            self.data.get(idx)
68        } else {
69            None
70        }
71    }
72
73    /// Get mutable reference to element, returns None if out of bounds.
74    pub fn get_mut(&mut self, t: usize, d: usize, h: usize, w: usize) -> Option<&mut T> {
75        if t < self.shape[0] && d < self.shape[1] && h < self.shape[2] && w < self.shape[3] {
76            let idx = self.index(t, d, h, w);
77            self.data.get_mut(idx)
78        } else {
79            None
80        }
81    }
82
83    /// Set element. Returns error if out of bounds.
84    pub fn set(
85        &mut self,
86        t: usize,
87        d: usize,
88        h: usize,
89        w: usize,
90        value: T,
91    ) -> Result<(), NdimageError> {
92        if t >= self.shape[0] || d >= self.shape[1] || h >= self.shape[2] || w >= self.shape[3] {
93            return Err(NdimageError::InvalidInput(format!(
94                "set: index ({},{},{},{}) out of bounds for shape {:?}",
95                t, d, h, w, self.shape
96            )));
97        }
98        let idx = self.index(t, d, h, w);
99        self.data[idx] = value;
100        Ok(())
101    }
102
103    /// Extract the 3D volume at time t as nested Vec `[D][H][W]`.
104    pub fn slice_time(&self, t: usize) -> Result<Vec<Vec<Vec<T>>>, NdimageError> {
105        if t >= self.shape[0] {
106            return Err(NdimageError::InvalidInput(format!(
107                "slice_time: t={} out of bounds (shape[0]={})",
108                t, self.shape[0]
109            )));
110        }
111        let (nt, nd, nh, nw) = (self.shape[0], self.shape[1], self.shape[2], self.shape[3]);
112        let _ = nt; // suppress unused
113        let mut volume = Vec::with_capacity(nd);
114        for d in 0..nd {
115            let mut plane = Vec::with_capacity(nh);
116            for h in 0..nh {
117                let mut row = Vec::with_capacity(nw);
118                for w in 0..nw {
119                    let idx = self.index(t, d, h, w);
120                    row.push(self.data[idx].clone());
121                }
122                plane.push(row);
123            }
124            volume.push(plane);
125        }
126        Ok(volume)
127    }
128
129    /// Extract the temporal slice at voxel (d, h, w) as Vec length T.
130    pub fn slice_spatial(&self, d: usize, h: usize, w: usize) -> Result<Vec<T>, NdimageError> {
131        if d >= self.shape[1] || h >= self.shape[2] || w >= self.shape[3] {
132            return Err(NdimageError::InvalidInput(format!(
133                "slice_spatial: ({},{},{}) out of bounds for shape {:?}",
134                d, h, w, self.shape
135            )));
136        }
137        let nt = self.shape[0];
138        let mut result = Vec::with_capacity(nt);
139        for t in 0..nt {
140            let idx = self.index(t, d, h, w);
141            result.push(self.data[idx].clone());
142        }
143        Ok(result)
144    }
145
146    /// Immutable access to flat data.
147    pub fn data(&self) -> &[T] {
148        &self.data
149    }
150
151    /// Mutable access to flat data.
152    pub fn data_mut(&mut self) -> &mut Vec<T> {
153        &mut self.data
154    }
155}
156
157// ---------------------------------------------------------------------------
158// 4D label alias
159// ---------------------------------------------------------------------------
160
161/// Label array for 4D connected components (0 = background, ≥1 = labels).
162pub type Label4D = Array4D<usize>;
163
164// ---------------------------------------------------------------------------
165// Gaussian helpers
166// ---------------------------------------------------------------------------
167
/// Sampled, normalised 1D Gaussian with standard deviation `sigma`.
///
/// Taps cover the integer offsets `-r..=r` with `r = ceil(3 * sigma)`, so the kernel
/// always has odd length `2r + 1` and its weights sum to one. A `sigma` of zero or less
/// yields the single-tap identity kernel `[1.0]`.
fn gaussian_kernel_1d(sigma: f64) -> Vec<f64> {
    if sigma <= 0.0 {
        return vec![1.0];
    }
    let radius = (3.0 * sigma).ceil() as usize;
    let taps = 2 * radius + 1;
    let denom = 2.0 * sigma * sigma;

    let weights: Vec<f64> = (0..taps)
        .map(|i| {
            let offset = i as f64 - radius as f64;
            (-offset * offset / denom).exp()
        })
        .collect();
    let total: f64 = weights.iter().sum();
    weights.into_iter().map(|wgt| wgt / total).collect()
}
189
/// Correlate `data` with `kernel` (centred on its middle tap) using reflect padding.
///
/// `out[i] = Σ_j data[reflect(i + j - half)] * kernel[j]` with `half = kernel.len() / 2`.
/// For the symmetric kernels built by `gaussian_kernel_1d`, correlation and convolution
/// coincide. Samples that fall outside `data` are mirrored back in via `reflect_index`,
/// which also handles kernels wider than the signal. The output has the same length as
/// `data` (empty in, empty out).
fn convolve_1d(data: &[f64], kernel: &[f64]) -> Vec<f64> {
    let n = data.len();
    let half = (kernel.len() / 2) as isize;
    (0..n)
        .map(|i| {
            kernel.iter().enumerate().fold(0.0, |acc, (j, &k)| {
                let src = i as isize + j as isize - half;
                acc + data[reflect_index(src, n)] * k
            })
        })
        .collect()
}

/// Map a possibly out-of-range index onto `0..n` using half-sample symmetric
/// ("reflect") padding: `... c b a | a b c | c b a ...`.
///
/// The padded signal is periodic with period `2 * n`, so indices arbitrarily far
/// outside the range are folded back by reflection rather than clamped to the edge.
/// Returns 0 when `n == 0`.
fn reflect_index(i: isize, n: usize) -> usize {
    if n == 0 {
        return 0;
    }
    let n = n as isize;
    let period = 2 * n;
    // Position within one mirror period: [0, n) is the signal, [n, 2n) its reflection.
    let m = i.rem_euclid(period);
    let folded = if m < n { m } else { period - 1 - m };
    folded as usize
}
223
224// ---------------------------------------------------------------------------
225// Gaussian filter 4D
226// ---------------------------------------------------------------------------
227
228/// Apply separable Gaussian filtering to a 4D array.
229///
230/// * `sigma_t` — temporal smoothing (along axis 0)
231/// * `sigma_s` — spatial smoothing (along axes 1, 2, 3)
232pub fn gaussian_filter_4d(arr: &Array4D<f64>, sigma_t: f64, sigma_s: f64) -> Array4D<f64> {
233    let [nt, nd, nh, nw] = arr.shape();
234    let kernel_t = gaussian_kernel_1d(sigma_t);
235    let kernel_s = gaussian_kernel_1d(sigma_s);
236
237    // Work on flat data copy
238    let mut buf: Vec<f64> = arr.data().to_vec();
239
240    // Helper: get element
241    let idx = |t: usize, d: usize, h: usize, w: usize| -> usize {
242        t * (nd * nh * nw) + d * (nh * nw) + h * nw + w
243    };
244
245    // Smooth along time axis (axis 0)
246    if kernel_t.len() > 1 {
247        let src = buf.clone();
248        for d in 0..nd {
249            for h in 0..nh {
250                for w in 0..nw {
251                    let slice: Vec<f64> = (0..nt).map(|t| src[idx(t, d, h, w)]).collect();
252                    let smoothed = convolve_1d(&slice, &kernel_t);
253                    for t in 0..nt {
254                        buf[idx(t, d, h, w)] = smoothed[t];
255                    }
256                }
257            }
258        }
259    }
260
261    // Smooth along depth axis (axis 1)
262    if kernel_s.len() > 1 {
263        let src = buf.clone();
264        for t in 0..nt {
265            for h in 0..nh {
266                for w in 0..nw {
267                    let slice: Vec<f64> = (0..nd).map(|d| src[idx(t, d, h, w)]).collect();
268                    let smoothed = convolve_1d(&slice, &kernel_s);
269                    for d in 0..nd {
270                        buf[idx(t, d, h, w)] = smoothed[d];
271                    }
272                }
273            }
274        }
275    }
276
277    // Smooth along height axis (axis 2)
278    if kernel_s.len() > 1 {
279        let src = buf.clone();
280        for t in 0..nt {
281            for d in 0..nd {
282                for w in 0..nw {
283                    let slice: Vec<f64> = (0..nh).map(|h| src[idx(t, d, h, w)]).collect();
284                    let smoothed = convolve_1d(&slice, &kernel_s);
285                    for h in 0..nh {
286                        buf[idx(t, d, h, w)] = smoothed[h];
287                    }
288                }
289            }
290        }
291    }
292
293    // Smooth along width axis (axis 3)
294    if kernel_s.len() > 1 {
295        let src = buf.clone();
296        for t in 0..nt {
297            for d in 0..nd {
298                for h in 0..nh {
299                    let slice: Vec<f64> = (0..nw).map(|w| src[idx(t, d, h, w)]).collect();
300                    let smoothed = convolve_1d(&slice, &kernel_s);
301                    for w in 0..nw {
302                        buf[idx(t, d, h, w)] = smoothed[w];
303                    }
304                }
305            }
306        }
307    }
308
309    Array4D {
310        data: buf,
311        shape: [nt, nd, nh, nw],
312    }
313}
314
315// ---------------------------------------------------------------------------
316// Temporal finite differences
317// ---------------------------------------------------------------------------
318
319/// Compute temporal finite differences: `output[t] = input[t+1] - input[t]`.
320/// Output shape is `[T-1, D, H, W]`.
321pub fn diff_4d_temporal(arr: &Array4D<f64>) -> Array4D<f64> {
322    let [nt, nd, nh, nw] = arr.shape();
323    if nt < 2 {
324        return Array4D::new([0, nd, nh, nw], 0.0);
325    }
326    let out_t = nt - 1;
327    let mut out = Array4D::new([out_t, nd, nh, nw], 0.0);
328    for t in 0..out_t {
329        for d in 0..nd {
330            for h in 0..nh {
331                for w in 0..nw {
332                    let v1 = arr.get(t + 1, d, h, w).copied().unwrap_or(0.0);
333                    let v0 = arr.get(t, d, h, w).copied().unwrap_or(0.0);
334                    let _ = out.set(t, d, h, w, v1 - v0);
335                }
336            }
337        }
338    }
339    out
340}
341
342// ---------------------------------------------------------------------------
343// Maximum Intensity Projection
344// ---------------------------------------------------------------------------
345
346/// Maximum intensity projection along the specified axis.
347///
348/// * axis 0 → T dimension collapsed → output shape [1, D, H, W]
349/// * axis 1 → D dimension collapsed → output shape [T, 1, H, W]
350/// * axis 2 → H dimension collapsed → output shape [T, D, 1, W]
351/// * axis 3 → W dimension collapsed → output shape [T, D, H, 1]
352pub fn max_intensity_projection_4d(
353    arr: &Array4D<f64>,
354    axis: usize,
355) -> Result<Array4D<f64>, NdimageError> {
356    let [nt, nd, nh, nw] = arr.shape();
357    if axis > 3 {
358        return Err(NdimageError::InvalidInput(format!(
359            "max_intensity_projection_4d: axis {} invalid for 4D array",
360            axis
361        )));
362    }
363
364    let out_shape = match axis {
365        0 => [1, nd, nh, nw],
366        1 => [nt, 1, nh, nw],
367        2 => [nt, nd, 1, nw],
368        3 => [nt, nd, nh, 1],
369        _ => unreachable!(),
370    };
371
372    let mut out = Array4D::new(out_shape, f64::NEG_INFINITY);
373
374    for t in 0..nt {
375        for d in 0..nd {
376            for h in 0..nh {
377                for w in 0..nw {
378                    let v = arr.get(t, d, h, w).copied().unwrap_or(f64::NEG_INFINITY);
379                    let (ot, od, oh, ow) = match axis {
380                        0 => (0, d, h, w),
381                        1 => (t, 0, h, w),
382                        2 => (t, d, 0, w),
383                        3 => (t, d, h, 0),
384                        _ => unreachable!(),
385                    };
386                    if let Some(cur) = out.get_mut(ot, od, oh, ow) {
387                        if v > *cur {
388                            *cur = v;
389                        }
390                    }
391                }
392            }
393        }
394    }
395
396    Ok(out)
397}
398
399// ---------------------------------------------------------------------------
400// Connected components 4D
401// ---------------------------------------------------------------------------
402
/// Face-adjacent neighbours in 4D: the in-bounds positions one step away along exactly
/// one axis.
///
/// There are at most 8 of them; the name echoes the familiar 3D "6-connectivity".
/// Neighbours are listed axis by axis (T, D, H, W), lower neighbour before upper.
fn neighbors_6_4d(
    t: usize,
    d: usize,
    h: usize,
    w: usize,
    shape: [usize; 4],
) -> Vec<(usize, usize, usize, usize)> {
    let here = [t, d, h, w];
    let mut found = Vec::with_capacity(8);
    for (axis, &extent) in shape.iter().enumerate() {
        let c = here[axis];
        let below = if c > 0 { Some(c - 1) } else { None };
        let above = if c + 1 < extent { Some(c + 1) } else { None };
        for &step in [below, above].iter().flatten() {
            let mut p = here;
            p[axis] = step;
            found.push((p[0], p[1], p[2], p[3]));
        }
    }
    found
}
439
/// Full 4D neighbourhood: every in-bounds position whose coordinates each differ by at
/// most one from `(t, d, h, w)`, excluding the position itself.
///
/// This gives up to 3⁴ − 1 = 80 neighbours, i.e. spatial 26-connectivity combined with
/// ±1 in time. Results are ordered lexicographically by `(t, d, h, w)`.
fn neighbors_26_4d(
    t: usize,
    d: usize,
    h: usize,
    w: usize,
    shape: [usize; 4],
) -> Vec<(usize, usize, usize, usize)> {
    // Inclusive window [c-1, c+1] along one axis, clipped to the array extent.
    let window = |c: usize, extent: usize| {
        let lo = c.saturating_sub(1);
        let hi = if c + 1 < extent { c + 1 } else { c };
        lo..=hi
    };
    let [nt, nd, nh, nw] = shape;
    let mut around = Vec::new();
    for t2 in window(t, nt) {
        for d2 in window(d, nd) {
            for h2 in window(h, nh) {
                for w2 in window(w, nw) {
                    if (t2, d2, h2, w2) != (t, d, h, w) {
                        around.push((t2, d2, h2, w2));
                    }
                }
            }
        }
    }
    around
}
478
479/// Compute connected components of a binary 4D array via BFS.
480///
481/// * `connectivity_26` — if true, use 26-connected spatial + temporal; else 6-connected.
482///
483/// Returns a `Label4D` where 0 = background, ≥1 = component labels.
484pub fn connected_components_4d(binary: &Array4D<bool>, connectivity_26: bool) -> Label4D {
485    let shape = binary.shape();
486    let [nt, nd, nh, nw] = shape;
487    let mut labels = Label4D::new(shape, 0usize);
488    let mut current_label = 0usize;
489
490    let flat_idx = |t: usize, d: usize, h: usize, w: usize| -> usize {
491        t * (nd * nh * nw) + d * (nh * nw) + h * nw + w
492    };
493
494    let n_total = nt * nd * nh * nw;
495    // Track visited via label array (0 = unvisited bg, or just check label != 0 for fg visited)
496    let mut visited = vec![false; n_total];
497
498    for t in 0..nt {
499        for d in 0..nd {
500            for h in 0..nh {
501                for w in 0..nw {
502                    let fi = flat_idx(t, d, h, w);
503                    let is_fg = binary.get(t, d, h, w).copied().unwrap_or(false);
504                    if !is_fg || visited[fi] {
505                        continue;
506                    }
507                    // BFS from (t, d, h, w)
508                    current_label += 1;
509                    let lbl = current_label;
510                    let mut queue = VecDeque::new();
511                    queue.push_back((t, d, h, w));
512                    visited[fi] = true;
513                    let _ = labels.set(t, d, h, w, lbl);
514
515                    while let Some((ct, cd, ch, cw)) = queue.pop_front() {
516                        let neighbors = if connectivity_26 {
517                            neighbors_26_4d(ct, cd, ch, cw, shape)
518                        } else {
519                            neighbors_6_4d(ct, cd, ch, cw, shape)
520                        };
521                        for (nt2, nd2, nh2, nw2) in neighbors {
522                            let nfi = flat_idx(nt2, nd2, nh2, nw2);
523                            if visited[nfi] {
524                                continue;
525                            }
526                            let nfg = binary.get(nt2, nd2, nh2, nw2).copied().unwrap_or(false);
527                            if nfg {
528                                visited[nfi] = true;
529                                let _ = labels.set(nt2, nd2, nh2, nw2, lbl);
530                                queue.push_back((nt2, nd2, nh2, nw2));
531                            }
532                        }
533                    }
534                }
535            }
536        }
537    }
538
539    labels
540}
541
542// ---------------------------------------------------------------------------
543// Region tracking
544// ---------------------------------------------------------------------------
545
/// Result of tracking a labeled region across time frames.
#[derive(Debug, Clone, PartialEq)]
pub struct TrackletResult {
    /// Unique tracklet identifier.
    pub id: usize,
    /// Frame at which this tracklet begins.
    pub start_time: usize,
    /// Frame indices (time steps) covered by this tracklet, in the order they were added.
    pub frames: Vec<usize>,
    /// Centroid `[d, h, w]` for each entry of `frames`, at the same position.
    pub centroid_per_frame: Vec<[f64; 3]>,
}
558
559/// Compute the centroid [d, h, w] of a labeled region in a single time frame.
560fn region_centroid(labeled: &Label4D, t: usize, label: usize) -> Option<[f64; 3]> {
561    let [_nt, nd, nh, nw] = labeled.shape();
562    let mut sum_d = 0.0f64;
563    let mut sum_h = 0.0f64;
564    let mut sum_w = 0.0f64;
565    let mut count = 0usize;
566    for d in 0..nd {
567        for h in 0..nh {
568            for w in 0..nw {
569                if labeled.get(t, d, h, w).copied().unwrap_or(0) == label {
570                    sum_d += d as f64;
571                    sum_h += h as f64;
572                    sum_w += w as f64;
573                    count += 1;
574                }
575            }
576        }
577    }
578    if count == 0 {
579        None
580    } else {
581        Some([
582            sum_d / count as f64,
583            sum_h / count as f64,
584            sum_w / count as f64,
585        ])
586    }
587}
588
589/// Track labeled regions across time frames by majority overlap.
590///
591/// For each distinct label in frame t, find the label in frame t+1 with maximum
592/// voxel overlap, creating/extending tracklets accordingly.
593pub fn track_regions_4d(labeled: &Label4D) -> Vec<TrackletResult> {
594    let [nt, nd, nh, nw] = labeled.shape();
595
596    // Collect labels present in each frame
597    let mut frame_labels: Vec<std::collections::HashSet<usize>> = Vec::with_capacity(nt);
598    for t in 0..nt {
599        let mut set = std::collections::HashSet::new();
600        for d in 0..nd {
601            for h in 0..nh {
602                for w in 0..nw {
603                    let lbl = labeled.get(t, d, h, w).copied().unwrap_or(0);
604                    if lbl > 0 {
605                        set.insert(lbl);
606                    }
607                }
608            }
609        }
610        frame_labels.push(set);
611    }
612
613    // Map (frame, label) -> tracklet_id
614    let mut assignment: std::collections::HashMap<(usize, usize), usize> =
615        std::collections::HashMap::new();
616    let mut tracklets: Vec<TrackletResult> = Vec::new();
617    let mut next_id = 1usize;
618
619    // Initialize frame 0
620    for &lbl in &frame_labels[0] {
621        let id = next_id;
622        next_id += 1;
623        assignment.insert((0, lbl), id);
624        let centroid = region_centroid(labeled, 0, lbl).unwrap_or([0.0; 3]);
625        tracklets.push(TrackletResult {
626            id,
627            start_time: 0,
628            frames: vec![0],
629            centroid_per_frame: vec![centroid],
630        });
631    }
632
633    // Link across frames
634    for t in 1..nt {
635        // For each label in frame t, compute overlap with labels in frame t-1
636        for &lbl_t in &frame_labels[t] {
637            // Count overlapping voxels with each label in t-1
638            let mut overlap: std::collections::HashMap<usize, usize> =
639                std::collections::HashMap::new();
640            for d in 0..nd {
641                for h in 0..nh {
642                    for w in 0..nw {
643                        let cur = labeled.get(t, d, h, w).copied().unwrap_or(0);
644                        let prev = labeled.get(t - 1, d, h, w).copied().unwrap_or(0);
645                        if cur == lbl_t && prev > 0 {
646                            *overlap.entry(prev).or_insert(0) += 1;
647                        }
648                    }
649                }
650            }
651
652            // Find best matching previous label
653            let best_prev = overlap.iter().max_by_key(|(_, &cnt)| cnt).map(|(&k, _)| k);
654
655            let centroid = region_centroid(labeled, t, lbl_t).unwrap_or([0.0; 3]);
656
657            if let Some(prev_lbl) = best_prev {
658                if let Some(&tid) = assignment.get(&(t - 1, prev_lbl)) {
659                    // Extend existing tracklet
660                    assignment.insert((t, lbl_t), tid);
661                    if let Some(tk) = tracklets.iter_mut().find(|tk| tk.id == tid) {
662                        tk.frames.push(t);
663                        tk.centroid_per_frame.push(centroid);
664                    }
665                    continue;
666                }
667            }
668
669            // No match found — new tracklet
670            let id = next_id;
671            next_id += 1;
672            assignment.insert((t, lbl_t), id);
673            tracklets.push(TrackletResult {
674                id,
675                start_time: t,
676                frames: vec![t],
677                centroid_per_frame: vec![centroid],
678            });
679        }
680    }
681
682    tracklets
683}
684
685// ---------------------------------------------------------------------------
686// Tests
687// ---------------------------------------------------------------------------
688
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array4d_create_and_get_set() {
        let shape = [2, 3, 4, 5];
        let mut arr: Array4D<f64> = Array4D::new(shape, 0.0);
        assert_eq!(arr.shape(), shape);
        assert_eq!(arr.n_elements(), 2 * 3 * 4 * 5);

        // Set and get roundtrip
        arr.set(1, 2, 3, 4, 42.0).expect("set failed");
        assert_eq!(arr.get(1, 2, 3, 4).copied(), Some(42.0));
        // Out-of-bounds get returns None
        assert!(arr.get(2, 0, 0, 0).is_none());
    }

    #[test]
    fn test_from_data_shape_mismatch() {
        let data = vec![1.0f64; 10];
        let result = Array4D::from_data(data, [2, 3, 4, 5]);
        assert!(result.is_err());
    }

    #[test]
    fn test_slice_time() {
        let shape = [3, 2, 4, 5];
        let mut arr: Array4D<f64> = Array4D::new(shape, 0.0);
        arr.set(1, 1, 2, 3, 7.0).expect("set failed");
        let vol = arr.slice_time(1).expect("slice_time failed");
        assert_eq!(vol.len(), 2);
        assert_eq!(vol[1].len(), 4);
        assert_eq!(vol[1][2].len(), 5);
        assert!((vol[1][2][3] - 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_set_and_slices_out_of_bounds_are_errors() {
        let mut arr: Array4D<f64> = Array4D::new([1, 3, 1, 1], 0.0);
        assert!(arr.set(0, 3, 0, 0, 1.0).is_err());
        assert!(arr.slice_time(1).is_err());
        assert!(arr.slice_spatial(0, 0, 1).is_err());
    }

    #[test]
    fn test_slice_spatial() {
        let mut arr: Array4D<f64> = Array4D::new([3, 1, 2, 2], 0.0);
        for t in 0..3 {
            arr.set(t, 0, 1, 0, t as f64 + 0.5).expect("set failed");
        }
        let series = arr.slice_spatial(0, 1, 0).expect("slice_spatial failed");
        assert_eq!(series, vec![0.5, 1.5, 2.5]);
    }

    #[test]
    fn test_connected_components_4d_two_cubes() {
        // Two separate 2×2×2 cubes in a 4×4×4 spatial volume over 2 time steps
        // Cube 1: t=0, d=0..1, h=0..1, w=0..1
        // Cube 2: t=0, d=2..3, h=2..3, w=2..3
        // They are spatially separated by a gap so they should get separate labels.
        let shape = [1, 4, 4, 4];
        let mut binary: Array4D<bool> = Array4D::new(shape, false);
        // Cube 1
        for d in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    binary.set(0, d, h, w, true).expect("set failed");
                }
            }
        }
        // Cube 2
        for d in 2..4 {
            for h in 2..4 {
                for w in 2..4 {
                    binary.set(0, d, h, w, true).expect("set failed");
                }
            }
        }
        let labels = connected_components_4d(&binary, false);
        // Collect unique non-zero labels
        let mut unique: std::collections::HashSet<usize> = std::collections::HashSet::new();
        for v in labels.data().iter() {
            if *v > 0 {
                unique.insert(*v);
            }
        }
        assert_eq!(unique.len(), 2, "Expected exactly 2 connected components");
    }

    #[test]
    fn test_connected_components_4d_links_across_time() {
        let mut binary: Array4D<bool> = Array4D::new([2, 1, 2, 2], false);
        binary.set(0, 0, 0, 0, true).unwrap();
        binary.set(1, 0, 0, 0, true).unwrap();
        let labels = connected_components_4d(&binary, false);
        assert_eq!(labels.get(0, 0, 0, 0), Some(&1));
        assert_eq!(labels.get(1, 0, 0, 0), Some(&1));
        assert_eq!(labels.get(0, 0, 1, 1), Some(&0));
    }

    #[test]
    fn test_connected_components_4d_diagonal_connectivity() {
        // Two voxels touching only at a corner: separate under face adjacency,
        // merged under the full neighbourhood.
        let mut binary: Array4D<bool> = Array4D::new([1, 2, 2, 2], false);
        binary.set(0, 0, 0, 0, true).unwrap();
        binary.set(0, 1, 1, 1, true).unwrap();
        let face = connected_components_4d(&binary, false);
        assert_eq!(face.data().iter().max().copied(), Some(2));
        let full = connected_components_4d(&binary, true);
        assert_eq!(full.data().iter().max().copied(), Some(1));
    }

    #[test]
    fn test_diff_4d_temporal() {
        let shape = [3, 2, 2, 2];
        let mut arr: Array4D<f64> = Array4D::new(shape, 0.0);
        // Set time 0 to 1.0, time 1 to 3.0, time 2 to 6.0
        for d in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    arr.set(0, d, h, w, 1.0).unwrap();
                    arr.set(1, d, h, w, 3.0).unwrap();
                    arr.set(2, d, h, w, 6.0).unwrap();
                }
            }
        }
        let diff = diff_4d_temporal(&arr);
        assert_eq!(diff.shape()[0], 2);
        assert!((diff.get(0, 0, 0, 0).copied().unwrap() - 2.0).abs() < 1e-12);
        assert!((diff.get(1, 0, 0, 0).copied().unwrap() - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_diff_4d_temporal_single_frame() {
        let arr: Array4D<f64> = Array4D::new([1, 2, 2, 2], 1.0);
        let diff = diff_4d_temporal(&arr);
        assert_eq!(diff.shape(), [0, 2, 2, 2]);
        assert_eq!(diff.n_elements(), 0);
    }

    #[test]
    fn test_mip_4d() {
        let shape = [2, 2, 2, 2];
        let mut arr: Array4D<f64> = Array4D::new(shape, 0.0);
        arr.set(0, 0, 0, 0, 5.0).unwrap();
        arr.set(1, 0, 0, 0, 10.0).unwrap();
        // MIP along axis 0 (collapse time)
        let mip = max_intensity_projection_4d(&arr, 0).expect("mip failed");
        assert_eq!(mip.shape()[0], 1);
        assert!((mip.get(0, 0, 0, 0).copied().unwrap() - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_mip_4d_last_axis() {
        let mut arr: Array4D<f64> = Array4D::new([1, 1, 2, 3], 0.0);
        arr.set(0, 0, 0, 2, 4.0).unwrap();
        arr.set(0, 0, 1, 0, -1.0).unwrap();
        arr.set(0, 0, 1, 1, -3.0).unwrap();
        arr.set(0, 0, 1, 2, -2.0).unwrap();
        let mip = max_intensity_projection_4d(&arr, 3).expect("mip failed");
        assert_eq!(mip.shape(), [1, 1, 2, 1]);
        assert_eq!(mip.data().to_vec(), vec![4.0, -1.0]);
    }

    #[test]
    fn test_mip_4d_invalid_axis() {
        let arr: Array4D<f64> = Array4D::new([1, 1, 1, 1], 0.0);
        assert!(max_intensity_projection_4d(&arr, 4).is_err());
    }

    #[test]
    fn test_gaussian_filter_4d_identity_sigma_zero() {
        let shape = [2, 2, 2, 2];
        let mut arr: Array4D<f64> = Array4D::new(shape, 1.0);
        arr.set(0, 0, 0, 0, 5.0).unwrap();
        // sigma=0 → kernel is [1.0], no change
        let out = gaussian_filter_4d(&arr, 0.0, 0.0);
        assert!((out.get(0, 0, 0, 0).copied().unwrap() - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_gaussian_filter_4d_preserves_constant() {
        // A normalised kernel with reflect padding must leave a constant field unchanged.
        let arr: Array4D<f64> = Array4D::new([3, 3, 3, 3], 2.0);
        let out = gaussian_filter_4d(&arr, 1.0, 1.0);
        assert_eq!(out.shape(), arr.shape());
        for v in out.data() {
            assert!((v - 2.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_track_regions_4d() {
        // Simple 1-region tracklet across 2 time frames
        let shape = [2, 3, 3, 3];
        let mut binary: Array4D<bool> = Array4D::new(shape, false);
        // Same cube at t=0 and t=1
        for d in 0..2 {
            for h in 0..2 {
                for w in 0..2 {
                    binary.set(0, d, h, w, true).unwrap();
                    binary.set(1, d, h, w, true).unwrap();
                }
            }
        }
        let labeled = connected_components_4d(&binary, false);
        let tracklets = track_regions_4d(&labeled);
        // Should have at least 1 tracklet spanning both frames
        assert!(!tracklets.is_empty());
        let multi_frame = tracklets.iter().any(|tk| tk.frames.len() >= 2);
        assert!(
            multi_frame,
            "Expected at least one tracklet spanning 2 frames"
        );
    }

    #[test]
    fn test_track_regions_4d_two_regions() {
        let shape = [2, 1, 4, 4];
        let mut binary: Array4D<bool> = Array4D::new(shape, false);
        for t in 0..2 {
            binary.set(t, 0, 0, 0, true).unwrap();
            binary.set(t, 0, 3, 3, true).unwrap();
        }
        let labeled = connected_components_4d(&binary, false);
        let tracklets = track_regions_4d(&labeled);
        assert_eq!(tracklets.len(), 2);
        for tk in &tracklets {
            assert_eq!(tk.start_time, 0);
            assert_eq!(tk.frames, vec![0, 1]);
            assert_eq!(tk.centroid_per_frame.len(), 2);
        }
    }
}