1use std::collections::HashMap;
19
20use scirs2_core::ndarray::Array3;
21
22use crate::error::{NdimageError, NdimageResult};
23
24fn check_shapes(labels: &[Array3<u32>]) -> NdimageResult<(usize, usize, usize)> {
28 if labels.is_empty() {
29 return Err(NdimageError::InvalidInput(
30 "Atlas segmentation: must provide at least one atlas".to_string(),
31 ));
32 }
33 let s = labels[0].shape();
34 let shape = (s[0], s[1], s[2]);
35 for (i, lab) in labels.iter().enumerate().skip(1) {
36 if lab.shape() != labels[0].shape() {
37 return Err(NdimageError::DimensionError(format!(
38 "Atlas segmentation: atlas {} has shape {:?}, expected {:?}",
39 i,
40 lab.shape(),
41 labels[0].shape()
42 )));
43 }
44 }
45 Ok(shape)
46}
47
/// Label fusion by per-voxel majority vote.
///
/// This is a stateless marker type; all functionality is exposed through
/// associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct MajorityVoting;
55
56impl MajorityVoting {
57 pub fn fuse(labels: &[Array3<u32>]) -> NdimageResult<Array3<u32>> {
62 let (nz, ny, nx) = check_shapes(labels)?;
63 let n_atlases = labels.len();
64 let mut result = Array3::<u32>::zeros((nz, ny, nx));
65
66 for iz in 0..nz {
67 for iy in 0..ny {
68 for ix in 0..nx {
69 let mut counts: HashMap<u32, usize> = HashMap::new();
70 for a in 0..n_atlases {
71 let lv = labels[a][[iz, iy, ix]];
72 *counts.entry(lv).or_insert(0) += 1;
73 }
74 let winner = counts
76 .iter()
77 .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
78 .map(|(&lv, _)| lv)
79 .unwrap_or(0);
80 result[[iz, iy, ix]] = winner;
81 }
82 }
83 }
84 Ok(result)
85 }
86
87 pub fn confidence(labels: &[Array3<u32>]) -> NdimageResult<Array3<f64>> {
90 let fused = Self::fuse(labels)?;
91 let (nz, ny, nx) = check_shapes(labels)?;
92 let n_atlases = labels.len() as f64;
93 let mut conf = Array3::<f64>::zeros((nz, ny, nx));
94 for iz in 0..nz {
95 for iy in 0..ny {
96 for ix in 0..nx {
97 let winner = fused[[iz, iy, ix]];
98 let agree = labels.iter().filter(|l| l[[iz, iy, ix]] == winner).count();
99 conf[[iz, iy, ix]] = agree as f64 / n_atlases;
100 }
101 }
102 }
103 Ok(conf)
104 }
105}
106
/// Tunable parameters for the [`STAPLE`] estimator.
#[derive(Debug, Clone)]
pub struct StapleConfig {
    /// Upper bound on the number of EM iterations.
    pub max_iterations: usize,
    /// Iteration stops once both the largest rater-parameter change and the
    /// largest per-voxel probability change fall below this value.
    pub convergence_threshold: f64,
    /// Starting sensitivity assigned to every rater.
    pub init_sensitivity: f64,
    /// Starting specificity assigned to every rater.
    pub init_specificity: f64,
}

impl Default for StapleConfig {
    /// 20 iterations, threshold `1e-5`, initial sensitivity/specificity `0.99`.
    fn default() -> Self {
        StapleConfig {
            max_iterations: 20,
            convergence_threshold: 1e-5,
            init_sensitivity: 0.99,
            init_specificity: 0.99,
        }
    }
}
138
/// Estimated performance of a single rater (atlas) as produced by STAPLE.
#[derive(Debug, Clone)]
pub struct RaterPerformance {
    /// Estimated probability that the rater votes foreground where the
    /// estimated truth is foreground.
    pub sensitivity: f64,
    /// Estimated probability that the rater votes background where the
    /// estimated truth is background.
    pub specificity: f64,
}
147
148#[derive(Debug, Clone)]
150pub struct StapleResult {
151 pub probability: Array3<f64>,
153 pub label: Array3<u32>,
155 pub performance: Vec<RaterPerformance>,
157 pub iterations: usize,
159 pub converged: bool,
161}
162
163pub struct STAPLE {
165 config: StapleConfig,
166}
167
168impl STAPLE {
169 pub fn new() -> Self {
171 Self {
172 config: StapleConfig::default(),
173 }
174 }
175
176 pub fn with_config(config: StapleConfig) -> Self {
178 Self { config }
179 }
180
181 pub fn estimate(&self, labels: &[Array3<u32>]) -> NdimageResult<StapleResult> {
186 let (nz, ny, nx) = check_shapes(labels)?;
187 let n = nz * ny * nx;
188 let r = labels.len();
189
190 let d: Vec<Vec<u8>> = labels
192 .iter()
193 .map(|l| l.iter().map(|&v| if v > 0 { 1u8 } else { 0u8 }).collect())
194 .collect();
195
196 let mut p: Vec<f64> = vec![self.config.init_sensitivity; r]; let mut q: Vec<f64> = vec![self.config.init_specificity; r]; let prior_fg = 0.5_f64;
202
203 let mut w: Vec<f64> = (0..n)
205 .map(|i| d.iter().map(|rater| rater[i] as f64).sum::<f64>() / r as f64)
206 .collect();
207
208 let mut converged = false;
209 let mut n_iter = 0;
210
211 for _iter in 0..self.config.max_iterations {
212 n_iter += 1;
213
214 let sum_w: f64 = w.iter().sum();
216 let sum_w0: f64 = w.iter().map(|&wi| 1.0 - wi).sum();
217
218 let mut new_p = vec![0.0_f64; r];
219 let mut new_q = vec![0.0_f64; r];
220
221 for j in 0..r {
222 let tp: f64 = (0..n).map(|i| d[j][i] as f64 * w[i]).sum();
224 new_p[j] = (tp + 1e-10) / (sum_w + 1e-10);
225
226 let tn: f64 = (0..n).map(|i| (1.0 - d[j][i] as f64) * (1.0 - w[i])).sum();
228 new_q[j] = (tn + 1e-10) / (sum_w0 + 1e-10);
229
230 new_p[j] = new_p[j].clamp(1e-6, 1.0 - 1e-6);
232 new_q[j] = new_q[j].clamp(1e-6, 1.0 - 1e-6);
233 }
234
235 let mut max_change = 0.0_f64;
237 let mut new_w = vec![0.0_f64; n];
238 for i in 0..n {
239 let mut ll1 = prior_fg.ln();
241 let mut ll0 = (1.0 - prior_fg).ln();
242 for j in 0..r {
243 if d[j][i] == 1 {
244 ll1 += new_p[j].ln();
245 ll0 += (1.0 - new_q[j]).ln();
246 } else {
247 ll1 += (1.0 - new_p[j]).ln();
248 ll0 += new_q[j].ln();
249 }
250 }
251 let max_ll = ll1.max(ll0);
252 let p1 = (ll1 - max_ll).exp();
253 let p0 = (ll0 - max_ll).exp();
254 new_w[i] = p1 / (p1 + p0 + 1e-10);
255 max_change = max_change.max((new_w[i] - w[i]).abs());
256 }
257
258 let param_change = (0..r)
260 .map(|j| (new_p[j] - p[j]).abs().max((new_q[j] - q[j]).abs()))
261 .fold(0.0_f64, f64::max);
262
263 p = new_p;
264 q = new_q;
265 w = new_w;
266
267 if param_change < self.config.convergence_threshold
268 && max_change < self.config.convergence_threshold
269 {
270 converged = true;
271 break;
272 }
273 }
274
275 let mut probability = Array3::<f64>::zeros((nz, ny, nx));
277 let mut label = Array3::<u32>::zeros((nz, ny, nx));
278 for iz in 0..nz {
279 for iy in 0..ny {
280 for ix in 0..nx {
281 let idx = iz * ny * nx + iy * nx + ix;
282 let wi = w[idx];
283 probability[[iz, iy, ix]] = wi;
284 label[[iz, iy, ix]] = if wi > 0.5 { 1 } else { 0 };
285 }
286 }
287 }
288
289 let performance: Vec<RaterPerformance> = (0..r)
290 .map(|j| RaterPerformance {
291 sensitivity: p[j],
292 specificity: q[j],
293 })
294 .collect();
295
296 Ok(StapleResult {
297 probability,
298 label,
299 performance,
300 iterations: n_iter,
301 converged,
302 })
303 }
304}
305
/// Tunable parameters for [`JointLabelFusion`].
#[derive(Debug, Clone)]
pub struct JlfConfig {
    /// Half-width of the cubic comparison patch; a patch holds
    /// `(2 * patch_radius + 1)^3` samples.
    pub patch_radius: usize,
    /// Divisor applied to the normalised patch difference inside the
    /// exponential weight; smaller values make the weights more selective.
    pub alpha: f64,
    /// Scale applied to the patch sample count when normalising the summed
    /// squared difference.
    pub beta: f64,
}

impl Default for JlfConfig {
    /// Patch radius `2`, `alpha = 0.1`, `beta = 2.0`.
    fn default() -> Self {
        JlfConfig {
            patch_radius: 2,
            alpha: 0.1,
            beta: 2.0,
        }
    }
}
330
331#[derive(Debug, Clone)]
333pub struct JlfResult {
334 pub label: Array3<u32>,
336 pub weight_sum: Array3<f64>,
338}
339
340pub struct JointLabelFusion {
347 config: JlfConfig,
348}
349
350impl JointLabelFusion {
351 pub fn new() -> Self {
353 Self {
354 config: JlfConfig::default(),
355 }
356 }
357
358 pub fn with_config(config: JlfConfig) -> Self {
360 Self { config }
361 }
362
363 pub fn fuse(
373 &self,
374 target: &Array3<f64>,
375 atlas_images: &[Array3<f64>],
376 atlas_labels: &[Array3<u32>],
377 ) -> NdimageResult<JlfResult> {
378 if atlas_images.len() != atlas_labels.len() {
379 return Err(NdimageError::InvalidInput(
380 "JointLabelFusion: atlas_images and atlas_labels must have equal length"
381 .to_string(),
382 ));
383 }
384 let n_atlases = atlas_images.len();
385 if n_atlases == 0 {
386 return Err(NdimageError::InvalidInput(
387 "JointLabelFusion: must provide at least one atlas".to_string(),
388 ));
389 }
390
391 let ts = target.shape();
392 for (i, ai) in atlas_images.iter().enumerate() {
393 if ai.shape() != ts {
394 return Err(NdimageError::DimensionError(format!(
395 "JointLabelFusion: atlas_images[{}] shape {:?} ≠ target shape {:?}",
396 i,
397 ai.shape(),
398 ts
399 )));
400 }
401 }
402 for (i, al) in atlas_labels.iter().enumerate() {
403 if al.shape() != ts {
404 return Err(NdimageError::DimensionError(format!(
405 "JointLabelFusion: atlas_labels[{}] shape {:?} ≠ target shape {:?}",
406 i,
407 al.shape(),
408 ts
409 )));
410 }
411 }
412
413 let (nz, ny, nx) = (ts[0], ts[1], ts[2]);
414 let pr = self.config.patch_radius as isize;
415
416 let mut label_votes: HashMap<u32, Array3<f64>> = HashMap::new();
419 let mut weight_sum = Array3::<f64>::zeros((nz, ny, nx));
420
421 for iz in 0..nz {
422 for iy in 0..ny {
423 for ix in 0..nx {
424 let t_patch =
426 extract_patch_3d(target, iz as isize, iy as isize, ix as isize, pr);
427
428 let mut weights = Vec::with_capacity(n_atlases);
430 for a in 0..n_atlases {
431 let a_patch = extract_patch_3d(
432 &atlas_images[a],
433 iz as isize,
434 iy as isize,
435 ix as isize,
436 pr,
437 );
438 let w = self.patch_weight(&t_patch, &a_patch);
439 weights.push(w);
440 }
441
442 let w_sum: f64 = weights.iter().sum();
444 let w_norm: Vec<f64> = if w_sum > 1e-12 {
445 weights.iter().map(|&w| w / w_sum).collect()
446 } else {
447 vec![1.0 / n_atlases as f64; n_atlases]
448 };
449
450 let total_w: f64 = w_norm.iter().sum();
452 weight_sum[[iz, iy, ix]] = total_w;
453 for (a, &wn) in w_norm.iter().enumerate() {
454 let lv = atlas_labels[a][[iz, iy, ix]];
455 label_votes
456 .entry(lv)
457 .or_insert_with(|| Array3::<f64>::zeros((nz, ny, nx)))[[iz, iy, ix]] +=
458 wn;
459 }
460 }
461 }
462 }
463
464 let mut label_result = Array3::<u32>::zeros((nz, ny, nx));
466 for iz in 0..nz {
467 for iy in 0..ny {
468 for ix in 0..nx {
469 let winner_label = label_votes
470 .iter()
471 .max_by(|a, b| {
472 a.1[[iz, iy, ix]]
473 .partial_cmp(&b.1[[iz, iy, ix]])
474 .unwrap_or(std::cmp::Ordering::Equal)
475 .then_with(|| b.0.cmp(a.0))
476 })
477 .map(|(&lv, _)| lv)
478 .unwrap_or(0);
479 label_result[[iz, iy, ix]] = winner_label;
480 }
481 }
482 }
483
484 Ok(JlfResult {
485 label: label_result,
486 weight_sum,
487 })
488 }
489
490 fn patch_weight(&self, target_patch: &[f64], atlas_patch: &[f64]) -> f64 {
494 if target_patch.is_empty() {
495 return 1.0;
496 }
497 let n = target_patch.len().min(atlas_patch.len()) as f64;
498 let ssd: f64 = target_patch
499 .iter()
500 .zip(atlas_patch.iter())
501 .map(|(t, a)| (t - a).powi(2))
502 .sum();
503 let normalised_ssd = ssd / (n * self.config.beta + 1e-10);
504 (-normalised_ssd / (self.config.alpha + 1e-10)).exp()
505 }
506}
507
508fn extract_patch_3d(vol: &Array3<f64>, iz: isize, iy: isize, ix: isize, pr: isize) -> Vec<f64> {
512 let shape = vol.shape();
513 let (nz, ny, nx) = (shape[0] as isize, shape[1] as isize, shape[2] as isize);
514 let mut patch = Vec::with_capacity(((2 * pr + 1) as usize).pow(3));
515 for dz in -pr..=pr {
516 for dy in -pr..=pr {
517 for dx in -pr..=pr {
518 let z = (iz + dz).clamp(0, nz - 1) as usize;
519 let y = (iy + dy).clamp(0, ny - 1) as usize;
520 let x = (ix + dx).clamp(0, nx - 1) as usize;
521 patch.push(vol[[z, y, x]]);
522 }
523 }
524 }
525 patch
526}
527
/// Label-fusion strategy selected through [`AtlasConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
    /// Per-voxel majority vote over the atlas labels.
    MajorityVoting,
    /// Binary STAPLE expectation-maximisation.
    Staple,
    /// Patch-similarity weighted voting; needs target and atlas images.
    JointLabelFusion,
}
540
541#[derive(Debug, Clone)]
543pub struct AtlasConfig {
544 pub fusion_method: FusionMethod,
546 pub staple_config: StapleConfig,
548 pub jlf_config: JlfConfig,
550}
551
552impl Default for AtlasConfig {
553 fn default() -> Self {
554 Self {
555 fusion_method: FusionMethod::MajorityVoting,
556 staple_config: StapleConfig::default(),
557 jlf_config: JlfConfig::default(),
558 }
559 }
560}
561
562#[derive(Debug, Clone)]
564pub struct AtlasSegmentationResult {
565 pub label: Array3<u32>,
567 pub n_atlases: usize,
569 pub fusion_method: FusionMethod,
571 pub staple_result: Option<StapleResult>,
573}
574
575pub struct AtlasSegmentation {
581 config: AtlasConfig,
582}
583
584impl AtlasSegmentation {
585 pub fn new() -> Self {
587 Self {
588 config: AtlasConfig::default(),
589 }
590 }
591
592 pub fn with_config(config: AtlasConfig) -> Self {
594 Self { config }
595 }
596
597 pub fn segment(
603 &self,
604 atlas_labels: &[Array3<u32>],
605 target_image: Option<&Array3<f64>>,
606 atlas_images: Option<&[Array3<f64>]>,
607 ) -> NdimageResult<AtlasSegmentationResult> {
608 let n_atlases = atlas_labels.len();
609 match self.config.fusion_method {
610 FusionMethod::MajorityVoting => {
611 let label = MajorityVoting::fuse(atlas_labels)?;
612 Ok(AtlasSegmentationResult {
613 label,
614 n_atlases,
615 fusion_method: FusionMethod::MajorityVoting,
616 staple_result: None,
617 })
618 }
619 FusionMethod::Staple => {
620 let staple = STAPLE::with_config(self.config.staple_config.clone());
621 let sr = staple.estimate(atlas_labels)?;
622 let label = sr.label.clone();
623 Ok(AtlasSegmentationResult {
624 label,
625 n_atlases,
626 fusion_method: FusionMethod::Staple,
627 staple_result: Some(sr),
628 })
629 }
630 FusionMethod::JointLabelFusion => {
631 let target = target_image.ok_or_else(|| {
632 NdimageError::InvalidInput(
633 "AtlasSegmentation: JLF requires target_image".to_string(),
634 )
635 })?;
636 let imgs = atlas_images.ok_or_else(|| {
637 NdimageError::InvalidInput(
638 "AtlasSegmentation: JLF requires atlas_images".to_string(),
639 )
640 })?;
641 let jlf = JointLabelFusion::with_config(self.config.jlf_config.clone());
642 let jr = jlf.fuse(target, imgs, atlas_labels)?;
643 Ok(AtlasSegmentationResult {
644 label: jr.label,
645 n_atlases,
646 fusion_method: FusionMethod::JointLabelFusion,
647 staple_result: None,
648 })
649 }
650 }
651 }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Label volume containing a centred ball of `label` on a zero background.
    fn sphere_labels(nz: usize, ny: usize, nx: usize, label: u32) -> Array3<u32> {
        let mut a = Array3::<u32>::zeros((nz, ny, nx));
        let cz = nz as f64 / 2.0;
        let cy = ny as f64 / 2.0;
        let cx = nx as f64 / 2.0;
        let r2 = ((nz.min(ny).min(nx)) as f64 / 3.0).powi(2);
        for iz in 0..nz {
            for iy in 0..ny {
                for ix in 0..nx {
                    let d2 = (iz as f64 - cz).powi(2)
                        + (iy as f64 - cy).powi(2)
                        + (ix as f64 - cx).powi(2);
                    if d2 < r2 {
                        a[[iz, iy, ix]] = label;
                    }
                }
            }
        }
        a
    }

    #[test]
    fn test_majority_voting_identical_atlases() {
        let a = sphere_labels(8, 8, 8, 1);
        let labels = vec![a.clone(), a.clone(), a.clone()];
        let fused = MajorityVoting::fuse(&labels)
            .expect("MajorityVoting::fuse should succeed with identical atlases");
        assert_eq!(fused, a);
    }

    #[test]
    fn test_majority_voting_confidence_perfect() {
        let a = sphere_labels(6, 6, 6, 1);
        let labels = vec![a.clone(), a.clone()];
        let conf = MajorityVoting::confidence(&labels)
            .expect("MajorityVoting::confidence should succeed with identical atlases");
        for v in conf.iter() {
            assert!((*v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_majority_voting_tie_breaks() {
        // Inside the ball the two atlases disagree (1 vs 2) with one vote
        // each; the tie must resolve to the smaller label, so the fused volume
        // equals the label-1 atlas exactly.
        let a = sphere_labels(4, 4, 4, 1);
        let b = sphere_labels(4, 4, 4, 2);
        let labels = vec![a.clone(), b];
        let fused = MajorityVoting::fuse(&labels)
            .expect("MajorityVoting::fuse should succeed with two atlases");
        assert_eq!(fused, a, "ties must resolve to the lowest label");
    }

    #[test]
    fn test_fusion_rejects_empty_atlas_list() {
        let no_labels: Vec<Array3<u32>> = Vec::new();
        assert!(MajorityVoting::fuse(&no_labels).is_err());
        assert!(MajorityVoting::confidence(&no_labels).is_err());
        assert!(STAPLE::new().estimate(&no_labels).is_err());
    }

    #[test]
    fn test_fusion_rejects_mismatched_shapes() {
        let labels = vec![sphere_labels(4, 4, 4, 1), sphere_labels(4, 4, 5, 1)];
        assert!(MajorityVoting::fuse(&labels).is_err());
        assert!(STAPLE::new().estimate(&labels).is_err());
    }

    #[test]
    fn test_staple_smoke() {
        let a = sphere_labels(4, 4, 4, 1);
        let labels = vec![a.clone(), a.clone()];
        let staple = STAPLE::new();
        let result = staple
            .estimate(&labels)
            .expect("STAPLE::estimate should succeed with valid atlases");
        assert_eq!(result.performance.len(), 2);
        for perf in &result.performance {
            assert!(
                perf.sensitivity > 0.5,
                "Expected high sensitivity, got {}",
                perf.sensitivity
            );
        }
        // Two identical raters: the consensus must reproduce their labels.
        assert_eq!(result.label, a);
        assert!(result.converged);
    }

    #[test]
    fn test_staple_single_atlas() {
        let a = sphere_labels(4, 4, 4, 1);
        let labels = vec![a];
        let result = STAPLE::new()
            .estimate(&labels)
            .expect("STAPLE::estimate should succeed with single atlas");
        assert_eq!(result.performance.len(), 1);
    }

    #[test]
    fn test_jlf_smoke() {
        let n = 6;
        let target = Array3::<f64>::from_elem((n, n, n), 100.0);
        let atlas_img = Array3::<f64>::from_elem((n, n, n), 100.0);
        let atlas_label = sphere_labels(n, n, n, 1);
        let jlf = JointLabelFusion::new();
        let result = jlf
            .fuse(&target, &[atlas_img], std::slice::from_ref(&atlas_label))
            .expect("JLF::fuse should succeed with single identical atlas");
        assert_eq!(result.label, atlas_label);
    }

    #[test]
    fn test_jlf_rejects_length_mismatch() {
        let target = Array3::<f64>::zeros((4, 4, 4));
        let images = vec![target.clone(), target.clone()];
        let labels = vec![sphere_labels(4, 4, 4, 1)];
        assert!(JointLabelFusion::new()
            .fuse(&target, &images, &labels)
            .is_err());
    }

    #[test]
    fn test_atlas_segmentation_majority_voting() {
        let a = sphere_labels(6, 6, 6, 1);
        let labels = vec![a.clone(), a.clone()];
        let seg = AtlasSegmentation::new();
        let result = seg
            .segment(&labels, None, None)
            .expect("AtlasSegmentation::segment should succeed with valid atlases");
        assert_eq!(result.fusion_method, FusionMethod::MajorityVoting);
        assert_eq!(result.n_atlases, 2);
    }

    #[test]
    fn test_atlas_segmentation_staple() {
        let a = sphere_labels(4, 4, 4, 1);
        let labels = vec![a.clone(), a.clone()];
        let config = AtlasConfig {
            fusion_method: FusionMethod::Staple,
            ..Default::default()
        };
        let seg = AtlasSegmentation::with_config(config);
        let result = seg
            .segment(&labels, None, None)
            .expect("AtlasSegmentation STAPLE should succeed with valid atlases");
        assert_eq!(result.fusion_method, FusionMethod::Staple);
        assert!(result.staple_result.is_some());
    }

    #[test]
    fn test_atlas_segmentation_jlf() {
        let labels = vec![sphere_labels(4, 4, 4, 1)];
        let target = Array3::<f64>::zeros((4, 4, 4));
        let images = vec![target.clone()];
        let config = AtlasConfig {
            fusion_method: FusionMethod::JointLabelFusion,
            ..Default::default()
        };
        let seg = AtlasSegmentation::with_config(config);

        // Both intensity inputs are mandatory for joint label fusion.
        assert!(seg.segment(&labels, None, None).is_err());
        assert!(seg.segment(&labels, Some(&target), None).is_err());

        let result = seg
            .segment(&labels, Some(&target), Some(images.as_slice()))
            .expect("AtlasSegmentation JLF should succeed with matching inputs");
        assert_eq!(result.fusion_method, FusionMethod::JointLabelFusion);
        assert!(result.staple_result.is_none());
        assert_eq!(result.label, labels[0]);
    }
}