Skip to main content

scirs2_ndimage/segmentation/
atlas.rs

1//! Atlas-based segmentation algorithms.
2//!
3//! Provides label fusion methods for multi-atlas segmentation:
4//!
5//! - [`AtlasSegmentation`]: orchestrates multi-atlas label fusion pipelines
6//! - [`MajorityVoting`]: simple voxel-wise majority voting over atlas labels
7//! - [`STAPLE`]: Simultaneous Truth and Performance Level Estimation
8//! - [`JointLabelFusion`]: locally weighted voting using patch similarity (Wang et al. 2013)
9//!
10//! # References
11//!
12//! - Artaechevarria et al. (2009), "Combination Strategies in Multi-Atlas Image Segmentation",
13//!   IEEE TMI 28(8):1266-1277.
14//! - Warfield et al. (2004), "Simultaneous Truth and Performance Level Estimation (STAPLE)",
15//!   IEEE TMI 23(7):903-921.
16//! - Wang et al. (2013), "Multi-Atlas Segmentation with Joint Label Fusion", IEEE TPAMI 35(3).
17
18use std::collections::HashMap;
19
20use scirs2_core::ndarray::Array3;
21
22use crate::error::{NdimageError, NdimageResult};
23
24// ─── Shared helpers ───────────────────────────────────────────────────────────
25
26/// Validate that all atlas label volumes have the same shape.
27fn check_shapes(labels: &[Array3<u32>]) -> NdimageResult<(usize, usize, usize)> {
28    if labels.is_empty() {
29        return Err(NdimageError::InvalidInput(
30            "Atlas segmentation: must provide at least one atlas".to_string(),
31        ));
32    }
33    let s = labels[0].shape();
34    let shape = (s[0], s[1], s[2]);
35    for (i, lab) in labels.iter().enumerate().skip(1) {
36        if lab.shape() != labels[0].shape() {
37            return Err(NdimageError::DimensionError(format!(
38                "Atlas segmentation: atlas {} has shape {:?}, expected {:?}",
39                i,
40                lab.shape(),
41                labels[0].shape()
42            )));
43        }
44    }
45    Ok(shape)
46}
47
48// ─── MajorityVoting ───────────────────────────────────────────────────────────
49
50/// Voxel-wise majority voting over multiple atlas label maps.
51///
52/// At each voxel the label that appears most frequently among the atlases is
53/// selected.  Ties are broken by choosing the smallest label index.
54pub struct MajorityVoting;
55
56impl MajorityVoting {
57    /// Fuse `labels` by majority voting.
58    ///
59    /// Each element of `labels` is a 3-D array of label values with shape
60    /// `(nz, ny, nx)`.  Returns a single label volume of the same shape.
61    pub fn fuse(labels: &[Array3<u32>]) -> NdimageResult<Array3<u32>> {
62        let (nz, ny, nx) = check_shapes(labels)?;
63        let n_atlases = labels.len();
64        let mut result = Array3::<u32>::zeros((nz, ny, nx));
65
66        for iz in 0..nz {
67            for iy in 0..ny {
68                for ix in 0..nx {
69                    let mut counts: HashMap<u32, usize> = HashMap::new();
70                    for a in 0..n_atlases {
71                        let lv = labels[a][[iz, iy, ix]];
72                        *counts.entry(lv).or_insert(0) += 1;
73                    }
74                    // Select label with highest count (ties: smallest label wins)
75                    let winner = counts
76                        .iter()
77                        .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
78                        .map(|(&lv, _)| lv)
79                        .unwrap_or(0);
80                    result[[iz, iy, ix]] = winner;
81                }
82            }
83        }
84        Ok(result)
85    }
86
87    /// Return per-voxel confidence (fraction of atlases agreeing with the
88    /// majority label).
89    pub fn confidence(labels: &[Array3<u32>]) -> NdimageResult<Array3<f64>> {
90        let fused = Self::fuse(labels)?;
91        let (nz, ny, nx) = check_shapes(labels)?;
92        let n_atlases = labels.len() as f64;
93        let mut conf = Array3::<f64>::zeros((nz, ny, nx));
94        for iz in 0..nz {
95            for iy in 0..ny {
96                for ix in 0..nx {
97                    let winner = fused[[iz, iy, ix]];
98                    let agree = labels.iter().filter(|l| l[[iz, iy, ix]] == winner).count();
99                    conf[[iz, iy, ix]] = agree as f64 / n_atlases;
100                }
101            }
102        }
103        Ok(conf)
104    }
105}
106
107// ─── STAPLE ───────────────────────────────────────────────────────────────────
108
/// Configuration for the STAPLE (Simultaneous Truth and Performance Level
/// Estimation) expectation-maximisation estimator.
///
/// STAPLE in this module works on binary segmentations: label values > 0 are
/// treated as foreground.
#[derive(Debug, Clone)]
pub struct StapleConfig {
    /// Maximum number of EM iterations (default 20).
    pub max_iterations: usize,
    /// Convergence threshold (default 1e-5).  Iteration stops once both the
    /// largest absolute change of any rater sensitivity/specificity and the
    /// largest absolute change of the per-voxel foreground probability fall
    /// below this value.
    pub convergence_threshold: f64,
    /// Initial sensitivity estimate for each rater (default 0.99).
    pub init_sensitivity: f64,
    /// Initial specificity estimate for each rater (default 0.99).
    pub init_specificity: f64,
}

impl Default for StapleConfig {
    fn default() -> Self {
        Self {
            max_iterations: 20,
            convergence_threshold: 1e-5,
            init_sensitivity: 0.99,
            init_specificity: 0.99,
        }
    }
}
138
/// Sensitivity/specificity pair that STAPLE estimates for a single rater.
#[derive(Clone, Debug)]
pub struct RaterPerformance {
    /// Estimated true positive rate of the rater.
    pub sensitivity: f64,
    /// Estimated true negative rate of the rater.
    pub specificity: f64,
}
147
148/// Result of a STAPLE estimation run.
149#[derive(Debug, Clone)]
150pub struct StapleResult {
151    /// Probabilistic foreground segmentation `W[z,y,x] ∈ [0, 1]`.
152    pub probability: Array3<f64>,
153    /// Hard binary label map (W > 0.5 → 1, else 0).
154    pub label: Array3<u32>,
155    /// Estimated performance parameters per rater.
156    pub performance: Vec<RaterPerformance>,
157    /// Number of EM iterations performed.
158    pub iterations: usize,
159    /// Whether the EM converged.
160    pub converged: bool,
161}
162
163/// STAPLE algorithm for binary segmentation quality estimation.
164pub struct STAPLE {
165    config: StapleConfig,
166}
167
168impl STAPLE {
169    /// Create a new STAPLE estimator with default configuration.
170    pub fn new() -> Self {
171        Self {
172            config: StapleConfig::default(),
173        }
174    }
175
176    /// Create with custom configuration.
177    pub fn with_config(config: StapleConfig) -> Self {
178        Self { config }
179    }
180
181    /// Run the STAPLE EM algorithm on a set of binary atlas segmentations.
182    ///
183    /// `labels` should contain binary or multi-label volumes; values > 0 are
184    /// treated as foreground.  All volumes must have the same shape.
185    pub fn estimate(&self, labels: &[Array3<u32>]) -> NdimageResult<StapleResult> {
186        let (nz, ny, nx) = check_shapes(labels)?;
187        let n = nz * ny * nx;
188        let r = labels.len();
189
190        // Flatten observations: d[rater][voxel] ∈ {0, 1}
191        let d: Vec<Vec<u8>> = labels
192            .iter()
193            .map(|l| l.iter().map(|&v| if v > 0 { 1u8 } else { 0u8 }).collect())
194            .collect();
195
196        // Initial performance parameters
197        let mut p: Vec<f64> = vec![self.config.init_sensitivity; r]; // sensitivity
198        let mut q: Vec<f64> = vec![self.config.init_specificity; r]; // specificity
199
200        // Prior probability of foreground (uniform 0.5)
201        let prior_fg = 0.5_f64;
202
203        // Initial W: fraction of raters labelling each voxel as foreground
204        let mut w: Vec<f64> = (0..n)
205            .map(|i| d.iter().map(|rater| rater[i] as f64).sum::<f64>() / r as f64)
206            .collect();
207
208        let mut converged = false;
209        let mut n_iter = 0;
210
211        for _iter in 0..self.config.max_iterations {
212            n_iter += 1;
213
214            // M-step: update performance parameters
215            let sum_w: f64 = w.iter().sum();
216            let sum_w0: f64 = w.iter().map(|&wi| 1.0 - wi).sum();
217
218            let mut new_p = vec![0.0_f64; r];
219            let mut new_q = vec![0.0_f64; r];
220
221            for j in 0..r {
222                // Sensitivity: TP / (TP + FN) — foreground agreement
223                let tp: f64 = (0..n).map(|i| d[j][i] as f64 * w[i]).sum();
224                new_p[j] = (tp + 1e-10) / (sum_w + 1e-10);
225
226                // Specificity: TN / (TN + FP)
227                let tn: f64 = (0..n).map(|i| (1.0 - d[j][i] as f64) * (1.0 - w[i])).sum();
228                new_q[j] = (tn + 1e-10) / (sum_w0 + 1e-10);
229
230                // Clamp to valid probability range
231                new_p[j] = new_p[j].clamp(1e-6, 1.0 - 1e-6);
232                new_q[j] = new_q[j].clamp(1e-6, 1.0 - 1e-6);
233            }
234
235            // E-step: update W
236            let mut max_change = 0.0_f64;
237            let mut new_w = vec![0.0_f64; n];
238            for i in 0..n {
239                // Log likelihood of observing d[*][i] given W=1 and W=0
240                let mut ll1 = prior_fg.ln();
241                let mut ll0 = (1.0 - prior_fg).ln();
242                for j in 0..r {
243                    if d[j][i] == 1 {
244                        ll1 += new_p[j].ln();
245                        ll0 += (1.0 - new_q[j]).ln();
246                    } else {
247                        ll1 += (1.0 - new_p[j]).ln();
248                        ll0 += new_q[j].ln();
249                    }
250                }
251                let max_ll = ll1.max(ll0);
252                let p1 = (ll1 - max_ll).exp();
253                let p0 = (ll0 - max_ll).exp();
254                new_w[i] = p1 / (p1 + p0 + 1e-10);
255                max_change = max_change.max((new_w[i] - w[i]).abs());
256            }
257
258            // Check convergence
259            let param_change = (0..r)
260                .map(|j| (new_p[j] - p[j]).abs().max((new_q[j] - q[j]).abs()))
261                .fold(0.0_f64, f64::max);
262
263            p = new_p;
264            q = new_q;
265            w = new_w;
266
267            if param_change < self.config.convergence_threshold
268                && max_change < self.config.convergence_threshold
269            {
270                converged = true;
271                break;
272            }
273        }
274
275        // Build output arrays
276        let mut probability = Array3::<f64>::zeros((nz, ny, nx));
277        let mut label = Array3::<u32>::zeros((nz, ny, nx));
278        for iz in 0..nz {
279            for iy in 0..ny {
280                for ix in 0..nx {
281                    let idx = iz * ny * nx + iy * nx + ix;
282                    let wi = w[idx];
283                    probability[[iz, iy, ix]] = wi;
284                    label[[iz, iy, ix]] = if wi > 0.5 { 1 } else { 0 };
285                }
286            }
287        }
288
289        let performance: Vec<RaterPerformance> = (0..r)
290            .map(|j| RaterPerformance {
291                sensitivity: p[j],
292                specificity: q[j],
293            })
294            .collect();
295
296        Ok(StapleResult {
297            probability,
298            label,
299            performance,
300            iterations: n_iter,
301            converged,
302        })
303    }
304}
305
306// ─── JointLabelFusion ─────────────────────────────────────────────────────────
307
/// Tuning parameters for patch-similarity weighted label fusion.
#[derive(Debug, Clone)]
pub struct JlfConfig {
    /// Patch half-width: a patch spans `2 * patch_radius + 1` voxels along each
    /// of the three axes (default 2, i.e. 5×5×5 patches).
    pub patch_radius: usize,
    /// Decay scale of the similarity weight (default 0.1).  Larger values give
    /// more uniform weights across atlases.
    pub alpha: f64,
    /// Normalisation factor applied to the per-patch sum of squared
    /// differences (default 2.0).
    pub beta: f64,
}

impl Default for JlfConfig {
    fn default() -> Self {
        JlfConfig {
            alpha: 0.1,
            beta: 2.0,
            patch_radius: 2,
        }
    }
}
330
331/// Result of joint label fusion.
332#[derive(Debug, Clone)]
333pub struct JlfResult {
334    /// Final fused label volume.
335    pub label: Array3<u32>,
336    /// Per-voxel weight normalisation factor.
337    pub weight_sum: Array3<f64>,
338}
339
340/// Joint Label Fusion (Wang et al., 2013) for 3D label volumes.
341///
342/// Each voxel in the target image is labelled by a weighted majority vote over
343/// the atlas labels.  The weight for each atlas at each location is derived from
344/// the normalised cross-correlation between patches in the target image and
345/// the corresponding atlas image.
346pub struct JointLabelFusion {
347    config: JlfConfig,
348}
349
350impl JointLabelFusion {
351    /// Create with default configuration.
352    pub fn new() -> Self {
353        Self {
354            config: JlfConfig::default(),
355        }
356    }
357
358    /// Create with custom configuration.
359    pub fn with_config(config: JlfConfig) -> Self {
360        Self { config }
361    }
362
363    /// Perform joint label fusion.
364    ///
365    /// # Arguments
366    ///
367    /// * `target` – intensity volume of the target subject (`f64`, shape `[nz,ny,nx]`).
368    /// * `atlas_images` – intensity volumes of the registered atlas subjects.
369    /// * `atlas_labels` – corresponding label volumes.
370    ///
371    /// All volumes must have the same shape.
372    pub fn fuse(
373        &self,
374        target: &Array3<f64>,
375        atlas_images: &[Array3<f64>],
376        atlas_labels: &[Array3<u32>],
377    ) -> NdimageResult<JlfResult> {
378        if atlas_images.len() != atlas_labels.len() {
379            return Err(NdimageError::InvalidInput(
380                "JointLabelFusion: atlas_images and atlas_labels must have equal length"
381                    .to_string(),
382            ));
383        }
384        let n_atlases = atlas_images.len();
385        if n_atlases == 0 {
386            return Err(NdimageError::InvalidInput(
387                "JointLabelFusion: must provide at least one atlas".to_string(),
388            ));
389        }
390
391        let ts = target.shape();
392        for (i, ai) in atlas_images.iter().enumerate() {
393            if ai.shape() != ts {
394                return Err(NdimageError::DimensionError(format!(
395                    "JointLabelFusion: atlas_images[{}] shape {:?} ≠ target shape {:?}",
396                    i,
397                    ai.shape(),
398                    ts
399                )));
400            }
401        }
402        for (i, al) in atlas_labels.iter().enumerate() {
403            if al.shape() != ts {
404                return Err(NdimageError::DimensionError(format!(
405                    "JointLabelFusion: atlas_labels[{}] shape {:?} ≠ target shape {:?}",
406                    i,
407                    al.shape(),
408                    ts
409                )));
410            }
411        }
412
413        let (nz, ny, nx) = (ts[0], ts[1], ts[2]);
414        let pr = self.config.patch_radius as isize;
415
416        // Accumulate weighted label votes
417        // For efficiency, accumulate per-label probability maps
418        let mut label_votes: HashMap<u32, Array3<f64>> = HashMap::new();
419        let mut weight_sum = Array3::<f64>::zeros((nz, ny, nx));
420
421        for iz in 0..nz {
422            for iy in 0..ny {
423                for ix in 0..nx {
424                    // Extract target patch
425                    let t_patch =
426                        extract_patch_3d(target, iz as isize, iy as isize, ix as isize, pr);
427
428                    // Compute weights for each atlas
429                    let mut weights = Vec::with_capacity(n_atlases);
430                    for a in 0..n_atlases {
431                        let a_patch = extract_patch_3d(
432                            &atlas_images[a],
433                            iz as isize,
434                            iy as isize,
435                            ix as isize,
436                            pr,
437                        );
438                        let w = self.patch_weight(&t_patch, &a_patch);
439                        weights.push(w);
440                    }
441
442                    // Normalise weights
443                    let w_sum: f64 = weights.iter().sum();
444                    let w_norm: Vec<f64> = if w_sum > 1e-12 {
445                        weights.iter().map(|&w| w / w_sum).collect()
446                    } else {
447                        vec![1.0 / n_atlases as f64; n_atlases]
448                    };
449
450                    // Accumulate votes
451                    let total_w: f64 = w_norm.iter().sum();
452                    weight_sum[[iz, iy, ix]] = total_w;
453                    for (a, &wn) in w_norm.iter().enumerate() {
454                        let lv = atlas_labels[a][[iz, iy, ix]];
455                        label_votes
456                            .entry(lv)
457                            .or_insert_with(|| Array3::<f64>::zeros((nz, ny, nx)))[[iz, iy, ix]] +=
458                            wn;
459                    }
460                }
461            }
462        }
463
464        // Winner-take-all label selection
465        let mut label_result = Array3::<u32>::zeros((nz, ny, nx));
466        for iz in 0..nz {
467            for iy in 0..ny {
468                for ix in 0..nx {
469                    let winner_label = label_votes
470                        .iter()
471                        .max_by(|a, b| {
472                            a.1[[iz, iy, ix]]
473                                .partial_cmp(&b.1[[iz, iy, ix]])
474                                .unwrap_or(std::cmp::Ordering::Equal)
475                                .then_with(|| b.0.cmp(a.0))
476                        })
477                        .map(|(&lv, _)| lv)
478                        .unwrap_or(0);
479                    label_result[[iz, iy, ix]] = winner_label;
480                }
481            }
482        }
483
484        Ok(JlfResult {
485            label: label_result,
486            weight_sum,
487        })
488    }
489
490    /// Compute the similarity weight between a target patch and an atlas patch.
491    ///
492    /// Uses a normalised sum of squared differences (NSSD) decay function.
493    fn patch_weight(&self, target_patch: &[f64], atlas_patch: &[f64]) -> f64 {
494        if target_patch.is_empty() {
495            return 1.0;
496        }
497        let n = target_patch.len().min(atlas_patch.len()) as f64;
498        let ssd: f64 = target_patch
499            .iter()
500            .zip(atlas_patch.iter())
501            .map(|(t, a)| (t - a).powi(2))
502            .sum();
503        let normalised_ssd = ssd / (n * self.config.beta + 1e-10);
504        (-normalised_ssd / (self.config.alpha + 1e-10)).exp()
505    }
506}
507
508/// Extract a cubic patch of half-width `pr` centred at `(iz, iy, ix)`.
509///
510/// Out-of-bounds positions are clamped (edge replication).
511fn extract_patch_3d(vol: &Array3<f64>, iz: isize, iy: isize, ix: isize, pr: isize) -> Vec<f64> {
512    let shape = vol.shape();
513    let (nz, ny, nx) = (shape[0] as isize, shape[1] as isize, shape[2] as isize);
514    let mut patch = Vec::with_capacity(((2 * pr + 1) as usize).pow(3));
515    for dz in -pr..=pr {
516        for dy in -pr..=pr {
517            for dx in -pr..=pr {
518                let z = (iz + dz).clamp(0, nz - 1) as usize;
519                let y = (iy + dy).clamp(0, ny - 1) as usize;
520                let x = (ix + dx).clamp(0, nx - 1) as usize;
521                patch.push(vol[[z, y, x]]);
522            }
523        }
524    }
525    patch
526}
527
528// ─── AtlasSegmentation ───────────────────────────────────────────────────────
529
/// Label-fusion strategy used by the multi-atlas segmentation pipeline.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionMethod {
    /// Unweighted per-voxel majority vote.
    MajorityVoting,
    /// Binary probabilistic estimation with STAPLE.
    Staple,
    /// Patch-similarity weighted voting.
    JointLabelFusion,
}
540
541/// Configuration for the multi-atlas segmentation pipeline.
542#[derive(Debug, Clone)]
543pub struct AtlasConfig {
544    /// Label fusion method.
545    pub fusion_method: FusionMethod,
546    /// STAPLE configuration (used when `fusion_method == Staple`).
547    pub staple_config: StapleConfig,
548    /// JLF configuration (used when `fusion_method == JointLabelFusion`).
549    pub jlf_config: JlfConfig,
550}
551
552impl Default for AtlasConfig {
553    fn default() -> Self {
554        Self {
555            fusion_method: FusionMethod::MajorityVoting,
556            staple_config: StapleConfig::default(),
557            jlf_config: JlfConfig::default(),
558        }
559    }
560}
561
562/// Result of multi-atlas segmentation.
563#[derive(Debug, Clone)]
564pub struct AtlasSegmentationResult {
565    /// Final label volume.
566    pub label: Array3<u32>,
567    /// Number of atlases used.
568    pub n_atlases: usize,
569    /// Fusion method employed.
570    pub fusion_method: FusionMethod,
571    /// Optional STAPLE result (populated when `fusion_method == Staple`).
572    pub staple_result: Option<StapleResult>,
573}
574
575/// Multi-atlas label fusion segmentation pipeline.
576///
577/// Accepts pre-registered atlas label volumes and fuses them using the
578/// configured method.  If JLF is used, corresponding atlas intensity images
579/// and the target image must also be provided.
580pub struct AtlasSegmentation {
581    config: AtlasConfig,
582}
583
584impl AtlasSegmentation {
585    /// Create with default configuration (majority voting).
586    pub fn new() -> Self {
587        Self {
588            config: AtlasConfig::default(),
589        }
590    }
591
592    /// Create with custom configuration.
593    pub fn with_config(config: AtlasConfig) -> Self {
594        Self { config }
595    }
596
597    /// Fuse atlas labels.
598    ///
599    /// `atlas_labels` are the registered atlas segmentations.
600    /// `target_image` / `atlas_images` are required only when
601    /// `fusion_method == JointLabelFusion`.
602    pub fn segment(
603        &self,
604        atlas_labels: &[Array3<u32>],
605        target_image: Option<&Array3<f64>>,
606        atlas_images: Option<&[Array3<f64>]>,
607    ) -> NdimageResult<AtlasSegmentationResult> {
608        let n_atlases = atlas_labels.len();
609        match self.config.fusion_method {
610            FusionMethod::MajorityVoting => {
611                let label = MajorityVoting::fuse(atlas_labels)?;
612                Ok(AtlasSegmentationResult {
613                    label,
614                    n_atlases,
615                    fusion_method: FusionMethod::MajorityVoting,
616                    staple_result: None,
617                })
618            }
619            FusionMethod::Staple => {
620                let staple = STAPLE::with_config(self.config.staple_config.clone());
621                let sr = staple.estimate(atlas_labels)?;
622                let label = sr.label.clone();
623                Ok(AtlasSegmentationResult {
624                    label,
625                    n_atlases,
626                    fusion_method: FusionMethod::Staple,
627                    staple_result: Some(sr),
628                })
629            }
630            FusionMethod::JointLabelFusion => {
631                let target = target_image.ok_or_else(|| {
632                    NdimageError::InvalidInput(
633                        "AtlasSegmentation: JLF requires target_image".to_string(),
634                    )
635                })?;
636                let imgs = atlas_images.ok_or_else(|| {
637                    NdimageError::InvalidInput(
638                        "AtlasSegmentation: JLF requires atlas_images".to_string(),
639                    )
640                })?;
641                let jlf = JointLabelFusion::with_config(self.config.jlf_config.clone());
642                let jr = jlf.fuse(target, imgs, atlas_labels)?;
643                Ok(AtlasSegmentationResult {
644                    label: jr.label,
645                    n_atlases,
646                    fusion_method: FusionMethod::JointLabelFusion,
647                    staple_result: None,
648                })
649            }
650        }
651    }
652}
653
654// ─── Unit tests ───────────────────────────────────────────────────────────────
655
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Create a simple test label volume: foreground sphere in the centre.
    fn sphere_labels(nz: usize, ny: usize, nx: usize, label: u32) -> Array3<u32> {
        let mut a = Array3::<u32>::zeros((nz, ny, nx));
        let cz = nz as f64 / 2.0;
        let cy = ny as f64 / 2.0;
        let cx = nx as f64 / 2.0;
        let r2 = ((nz.min(ny).min(nx)) as f64 / 3.0).powi(2);
        for iz in 0..nz {
            for iy in 0..ny {
                for ix in 0..nx {
                    let d2 = (iz as f64 - cz).powi(2)
                        + (iy as f64 - cy).powi(2)
                        + (ix as f64 - cx).powi(2);
                    if d2 < r2 {
                        a[[iz, iy, ix]] = label;
                    }
                }
            }
        }
        a
    }

    #[test]
    fn test_majority_voting_identical_atlases() {
        let a = sphere_labels(8, 8, 8, 1);
        let labels = vec![a.clone(), a.clone(), a.clone()];
        let fused = MajorityVoting::fuse(&labels)
            .expect("MajorityVoting::fuse should succeed with identical atlases");
        // All atlases agree → output must equal input
        for iz in 0..8 {
            for iy in 0..8 {
                for ix in 0..8 {
                    assert_eq!(fused[[iz, iy, ix]], a[[iz, iy, ix]]);
                }
            }
        }
    }

    #[test]
    fn test_majority_voting_confidence_perfect() {
        let a = sphere_labels(6, 6, 6, 1);
        let labels = vec![a.clone(), a.clone()];
        let conf = MajorityVoting::confidence(&labels)
            .expect("MajorityVoting::confidence should succeed with identical atlases");
        for v in conf.iter() {
            assert!((*v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_majority_voting_confidence_partial_agreement() {
        // Two atlases mark the sphere, one is all background: inside the sphere
        // 2 of 3 atlases agree with the winner, outside all 3 do.
        let a = sphere_labels(6, 6, 6, 1);
        let bg = Array3::<u32>::zeros((6, 6, 6));
        let labels = vec![a.clone(), a.clone(), bg];
        let conf = MajorityVoting::confidence(&labels)
            .expect("MajorityVoting::confidence should succeed with valid atlases");
        for (c, &lv) in conf.iter().zip(a.iter()) {
            let expected = if lv == 1 { 2.0 / 3.0 } else { 1.0 };
            assert!((c - expected).abs() < 1e-10, "confidence {} != {}", c, expected);
        }
    }

    #[test]
    fn test_majority_voting_tie_breaks() {
        // The two atlases agree on the background (label 0) and disagree inside
        // the sphere (label 1 vs label 2).
        let a = sphere_labels(4, 4, 4, 1);
        let b = sphere_labels(4, 4, 4, 2);
        assert!(a.iter().any(|&v| v == 1), "fixture must contain a tie region");
        let labels = vec![a.clone(), b];
        let fused = MajorityVoting::fuse(&labels)
            .expect("MajorityVoting::fuse should succeed with two atlases");
        // Tie inside the sphere: the smallest label (1) must win, so the fused
        // result is exactly `a`.
        assert_eq!(fused, a);
    }

    #[test]
    fn test_majority_voting_rejects_bad_input() {
        assert!(MajorityVoting::fuse(&[]).is_err());
        let mismatched = vec![
            Array3::<u32>::zeros((2, 2, 2)),
            Array3::<u32>::zeros((2, 2, 3)),
        ];
        assert!(MajorityVoting::fuse(&mismatched).is_err());
    }

    #[test]
    fn test_staple_smoke() {
        let a = sphere_labels(4, 4, 4, 1);
        let b = sphere_labels(4, 4, 4, 1);
        let labels = vec![a, b];
        let staple = STAPLE::new();
        let result = staple
            .estimate(&labels)
            .expect("STAPLE::estimate should succeed with valid atlases");
        assert_eq!(result.performance.len(), 2);
        // Sensitivities should be high for identical atlases
        for perf in &result.performance {
            assert!(
                perf.sensitivity > 0.5,
                "Expected high sensitivity, got {}",
                perf.sensitivity
            );
        }
    }

    #[test]
    fn test_staple_single_atlas() {
        let a = sphere_labels(4, 4, 4, 1);
        let labels = vec![a];
        let result = STAPLE::new()
            .estimate(&labels)
            .expect("STAPLE::estimate should succeed with single atlas");
        assert_eq!(result.performance.len(), 1);
    }

    #[test]
    fn test_jlf_smoke() {
        let n = 6;
        let target = Array3::<f64>::from_elem((n, n, n), 100.0);
        let atlas_img = Array3::<f64>::from_elem((n, n, n), 100.0);
        let atlas_label = sphere_labels(n, n, n, 1);
        let jlf = JointLabelFusion::new();
        let result = jlf
            .fuse(&target, &[atlas_img], std::slice::from_ref(&atlas_label))
            .expect("JLF::fuse should succeed with single identical atlas");
        // Single identical atlas → output == input labels
        for iz in 0..n {
            for iy in 0..n {
                for ix in 0..n {
                    assert_eq!(result.label[[iz, iy, ix]], atlas_label[[iz, iy, ix]]);
                }
            }
        }
    }

    #[test]
    fn test_jlf_prefers_similar_atlas() {
        // One atlas matches the target exactly, the other is offset by a large
        // constant: the matching atlas's labels must win everywhere.
        let n = 4;
        let target =
            Array3::<f64>::from_shape_fn((n, n, n), |(z, y, x)| (z * 16 + y * 4 + x) as f64);
        let close = target.clone();
        let far = target.mapv(|v| v + 100.0);
        let close_labels = Array3::<u32>::from_elem((n, n, n), 1);
        let far_labels = Array3::<u32>::from_elem((n, n, n), 2);
        let result = JointLabelFusion::new()
            .fuse(&target, &[close, far], &[close_labels.clone(), far_labels])
            .expect("JLF::fuse should succeed with two valid atlases");
        assert_eq!(result.label, close_labels);
    }

    #[test]
    fn test_jlf_rejects_mismatched_inputs() {
        let img = Array3::<f64>::zeros((3, 3, 3));
        let lab = Array3::<u32>::zeros((3, 3, 3));
        let jlf = JointLabelFusion::new();
        // Different number of atlas images and atlas labels.
        assert!(jlf
            .fuse(&img, &[img.clone(), img.clone()], std::slice::from_ref(&lab))
            .is_err());
        // Atlas image shape differs from the target.
        let small = Array3::<f64>::zeros((2, 3, 3));
        assert!(jlf.fuse(&img, &[small], std::slice::from_ref(&lab)).is_err());
    }

    #[test]
    fn test_atlas_segmentation_majority_voting() {
        let a = sphere_labels(6, 6, 6, 1);
        let labels = vec![a.clone(), a.clone()];
        let seg = AtlasSegmentation::new();
        let result = seg
            .segment(&labels, None, None)
            .expect("AtlasSegmentation::segment should succeed with valid atlases");
        assert_eq!(result.fusion_method, FusionMethod::MajorityVoting);
        assert_eq!(result.n_atlases, 2);
    }

    #[test]
    fn test_atlas_segmentation_staple() {
        let a = sphere_labels(4, 4, 4, 1);
        let labels = vec![a.clone(), a.clone()];
        let config = AtlasConfig {
            fusion_method: FusionMethod::Staple,
            ..Default::default()
        };
        let seg = AtlasSegmentation::with_config(config);
        let result = seg
            .segment(&labels, None, None)
            .expect("AtlasSegmentation STAPLE should succeed with valid atlases");
        assert_eq!(result.fusion_method, FusionMethod::Staple);
        assert!(result.staple_result.is_some());
    }

    #[test]
    fn test_atlas_segmentation_jlf_requires_images() {
        let labels = vec![sphere_labels(4, 4, 4, 1)];
        let config = AtlasConfig {
            fusion_method: FusionMethod::JointLabelFusion,
            ..Default::default()
        };
        let seg = AtlasSegmentation::with_config(config);
        assert!(seg.segment(&labels, None, None).is_err());
    }
}