Skip to main content

scirs2_ndimage/registration/
deformable.rs

1//! Deformable Image Registration
2//!
3//! Provides diffeomorphic deformable registration algorithms:
4//!
5//! - `DemonsDiffeo`: diffeomorphic demons registration using update + composition
6//! - `FluidRegistration`: viscous fluid model via iterative force smoothing
7//! - `FreeFormDeformation`: B-spline free-form deformation (FFD)
8//! - `DisplacementField`: dense vector displacement field representation
9//! - `JacobianDeterminant`: Jacobian determinant of the deformation field
10//! - `CompositeTransform`: compose rigid + deformable transforms
11//!
12//! # References
13//!
14//! - Thirion (1998), "Image matching as a diffusion process", Medical Image Analysis.
15//! - Vercauteren et al. (2009), "Diffeomorphic Demons: Efficient Non-parametric
16//!   Image Registration", NeuroImage.
17//! - Rueckert et al. (1999), "Nonrigid Registration Using Free-Form Deformations:
18//!   Application to Breast MR Images", IEEE TMI.
19
20use scirs2_core::ndarray::{Array2, Array3};
21use std::f64::consts::PI;
22
23use crate::error::{NdimageError, NdimageResult};
24use crate::registration::{AffineTransform2D, RigidTransform2D};
25
26// ─── DisplacementField ───────────────────────────────────────────────────────
27
28/// A dense displacement (vector) field over a 2D or 3D domain.
29///
30/// For a 2D domain of shape `(rows, cols)` the field has shape `(rows, cols, 2)`
31/// where the last dimension carries `[dy, dx]` displacements in pixels.
32///
33/// For a 3D domain `(nz, ny, nx)` the shape is `(nz, ny, nx, 3)` with `[dz, dy, dx]`.
34#[derive(Debug, Clone)]
35pub struct DisplacementField {
36    /// Displacement vectors; shape `[..dims.., n_components]`.
37    pub field: Vec<f64>,
38    /// Spatial dimensions of the domain (e.g., `[rows, cols]` or `[nz, ny, nx]`).
39    pub dims: Vec<usize>,
40    /// Number of vector components (2 for 2D, 3 for 3D).
41    pub n_components: usize,
42}
43
44impl DisplacementField {
45    /// Create a zero displacement field for a 2D domain.
46    pub fn zeros_2d(rows: usize, cols: usize) -> Self {
47        Self {
48            field: vec![0.0; rows * cols * 2],
49            dims: vec![rows, cols],
50            n_components: 2,
51        }
52    }
53
54    /// Create a zero displacement field for a 3D domain.
55    pub fn zeros_3d(nz: usize, ny: usize, nx: usize) -> Self {
56        Self {
57            field: vec![0.0; nz * ny * nx * 3],
58            dims: vec![nz, ny, nx],
59            n_components: 3,
60        }
61    }
62
63    /// Total number of spatial voxels / pixels.
64    pub fn num_voxels(&self) -> usize {
65        self.dims.iter().product()
66    }
67
68    /// Access the displacement vector at pixel `(r, c)` for a 2D field.
69    ///
70    /// Returns `[dy, dx]`.
71    pub fn get_2d(&self, r: usize, c: usize) -> NdimageResult<[f64; 2]> {
72        if self.dims.len() != 2 || self.n_components != 2 {
73            return Err(NdimageError::InvalidInput(
74                "DisplacementField::get_2d: field is not 2D".to_string(),
75            ));
76        }
77        let cols = self.dims[1];
78        let base = (r * cols + c) * 2;
79        if base + 1 >= self.field.len() {
80            return Err(NdimageError::InvalidInput(
81                "DisplacementField::get_2d: index out of bounds".to_string(),
82            ));
83        }
84        Ok([self.field[base], self.field[base + 1]])
85    }
86
87    /// Set the displacement vector at pixel `(r, c)` for a 2D field.
88    pub fn set_2d(&mut self, r: usize, c: usize, dy: f64, dx: f64) -> NdimageResult<()> {
89        if self.dims.len() != 2 || self.n_components != 2 {
90            return Err(NdimageError::InvalidInput(
91                "DisplacementField::set_2d: field is not 2D".to_string(),
92            ));
93        }
94        let cols = self.dims[1];
95        let base = (r * cols + c) * 2;
96        if base + 1 >= self.field.len() {
97            return Err(NdimageError::InvalidInput(
98                "DisplacementField::set_2d: index out of bounds".to_string(),
99            ));
100        }
101        self.field[base] = dy;
102        self.field[base + 1] = dx;
103        Ok(())
104    }
105
106    /// Access the displacement vector at voxel `(iz, iy, ix)` for a 3D field.
107    ///
108    /// Returns `[dz, dy, dx]`.
109    pub fn get_3d(&self, iz: usize, iy: usize, ix: usize) -> NdimageResult<[f64; 3]> {
110        if self.dims.len() != 3 || self.n_components != 3 {
111            return Err(NdimageError::InvalidInput(
112                "DisplacementField::get_3d: field is not 3D".to_string(),
113            ));
114        }
115        let ny = self.dims[1];
116        let nx = self.dims[2];
117        let base = (iz * ny * nx + iy * nx + ix) * 3;
118        if base + 2 >= self.field.len() {
119            return Err(NdimageError::InvalidInput(
120                "DisplacementField::get_3d: index out of bounds".to_string(),
121            ));
122        }
123        Ok([self.field[base], self.field[base + 1], self.field[base + 2]])
124    }
125
126    /// Compose two displacement fields: `result(x) = self(x) + other(x + self(x))`.
127    ///
128    /// Currently supports only 2D fields.
129    pub fn compose_2d(&self, other: &DisplacementField) -> NdimageResult<DisplacementField> {
130        if self.dims.len() != 2
131            || other.dims.len() != 2
132            || self.dims != other.dims
133            || self.n_components != 2
134            || other.n_components != 2
135        {
136            return Err(NdimageError::DimensionError(
137                "DisplacementField::compose_2d: incompatible fields".to_string(),
138            ));
139        }
140        let rows = self.dims[0];
141        let cols = self.dims[1];
142        let mut result = DisplacementField::zeros_2d(rows, cols);
143        for r in 0..rows {
144            for c in 0..cols {
145                let [dy0, dx0] = self.get_2d(r, c)?;
146                // Interpolate other at (r + dy0, c + dx0)
147                let wr = r as f64 + dy0;
148                let wc = c as f64 + dx0;
149                let [dy1, dx1] = bilinear_sample_2d_field(other, wr, wc);
150                result.set_2d(r, c, dy0 + dy1, dx0 + dx1)?;
151            }
152        }
153        Ok(result)
154    }
155
156    /// Apply a Gaussian smoothing kernel to the displacement field in-place.
157    ///
158    /// Only 2D fields are supported.  `sigma` is in pixels.
159    pub fn gaussian_smooth_2d(&mut self, sigma: f64) -> NdimageResult<()> {
160        if self.dims.len() != 2 || self.n_components != 2 {
161            return Err(NdimageError::InvalidInput(
162                "DisplacementField::gaussian_smooth_2d: only 2D fields supported".to_string(),
163            ));
164        }
165        let rows = self.dims[0];
166        let cols = self.dims[1];
167
168        // Smooth each component independently
169        for comp in 0..2 {
170            let mut component: Vec<f64> =
171                (0..rows * cols).map(|i| self.field[i * 2 + comp]).collect();
172            gaussian_smooth_1d_separable(&mut component, rows, cols, sigma);
173            for i in 0..rows * cols {
174                self.field[i * 2 + comp] = component[i];
175            }
176        }
177        Ok(())
178    }
179
180    /// Root-mean-square magnitude of the displacement vectors.
181    pub fn rms_magnitude(&self) -> f64 {
182        let n = self.num_voxels();
183        if n == 0 {
184            return 0.0;
185        }
186        let sum_sq: f64 = (0..n)
187            .map(|i| {
188                let base = i * self.n_components;
189                (0..self.n_components)
190                    .map(|c| self.field[base + c].powi(2))
191                    .sum::<f64>()
192            })
193            .sum();
194        (sum_sq / n as f64).sqrt()
195    }
196}
197
198// ─── JacobianDeterminant ─────────────────────────────────────────────────────
199
200/// Compute the Jacobian determinant of a 2D displacement field.
201///
202/// For a diffeomorphic transform the determinant should be positive everywhere
203/// (no folding).  Negative values indicate folding artefacts.
204///
205/// Returns an `Array2<f64>` of determinants at each pixel.
206pub struct JacobianDeterminant;
207
208impl JacobianDeterminant {
209    /// Compute det(J) for a 2D displacement field using central finite differences.
210    ///
211    /// `field.dims` must be `[rows, cols]` with `n_components == 2`.
212    pub fn compute_2d(field: &DisplacementField) -> NdimageResult<Array2<f64>> {
213        if field.dims.len() != 2 || field.n_components != 2 {
214            return Err(NdimageError::InvalidInput(
215                "JacobianDeterminant::compute_2d: field must be 2D".to_string(),
216            ));
217        }
218        let rows = field.dims[0];
219        let cols = field.dims[1];
220        if rows < 3 || cols < 3 {
221            return Err(NdimageError::InvalidInput(
222                "JacobianDeterminant::compute_2d: domain must be at least 3×3".to_string(),
223            ));
224        }
225
226        let mut det = Array2::<f64>::zeros((rows, cols));
227        for r in 0..rows {
228            for c in 0..cols {
229                // Clamp neighbours for boundary pixels
230                let rn = r.saturating_sub(1);
231                let rp = (r + 1).min(rows - 1);
232                let cn = c.saturating_sub(1);
233                let cp = (c + 1).min(cols - 1);
234                let hr = (rp - rn) as f64;
235                let hc = (cp - cn) as f64;
236
237                // Deformation map φ(r,c) = (r + dy, c + dx)
238                let [dy_rn, dx_rn] = field.get_2d(rn, c)?;
239                let [dy_rp, dx_rp] = field.get_2d(rp, c)?;
240                let [dy_rcn, dx_rcn] = field.get_2d(r, cn)?;
241                let [dy_rcp, dx_rcp] = field.get_2d(r, cp)?;
242
243                // d(φ_r)/dr, d(φ_r)/dc
244                let dphi_r_dr = 1.0 + (dy_rp - dy_rn) / hr;
245                let dphi_r_dc = (dy_rcp - dy_rcn) / hc;
246                // d(φ_c)/dr, d(φ_c)/dc
247                let dphi_c_dr = (dx_rp - dx_rn) / hr;
248                let dphi_c_dc = 1.0 + (dx_rcp - dx_rcn) / hc;
249
250                det[[r, c]] = dphi_r_dr * dphi_c_dc - dphi_r_dc * dphi_c_dr;
251            }
252        }
253        Ok(det)
254    }
255
256    /// Count folded voxels (det(J) <= 0) and return the fraction.
257    pub fn folding_fraction_2d(field: &DisplacementField) -> NdimageResult<f64> {
258        let det = Self::compute_2d(field)?;
259        let total = det.len();
260        if total == 0 {
261            return Ok(0.0);
262        }
263        let folded = det.iter().filter(|&&v| v <= 0.0).count();
264        Ok(folded as f64 / total as f64)
265    }
266}
267
268// ─── CompositeTransform ──────────────────────────────────────────────────────
269
270/// Composition of a rigid transform followed by a deformable displacement.
271///
272/// The full mapping is:  `x' = deformable(rigid(x))`
273#[derive(Debug, Clone)]
274pub struct CompositeTransform {
275    /// Optional rigid component (rotation + translation).
276    pub rigid: Option<RigidTransform2D>,
277    /// Optional affine component (overrides rigid if both are set; applied first).
278    pub affine: Option<AffineTransform2D>,
279    /// Deformable displacement field applied after the linear transform.
280    pub deformable: Option<DisplacementField>,
281}
282
283impl CompositeTransform {
284    /// Create an identity composite transform.
285    pub fn identity() -> Self {
286        Self {
287            rigid: None,
288            affine: None,
289            deformable: None,
290        }
291    }
292
293    /// Apply the composite transform to a point `(r, c)` (in pixel coordinates)
294    /// and return the transformed point.
295    pub fn apply_to_point(&self, r: f64, c: f64) -> NdimageResult<(f64, f64)> {
296        // Step 1: apply rigid or affine linear transform
297        let (lr, lc) = if let Some(ref aff) = self.affine {
298            apply_affine_point(aff, r, c)
299        } else if let Some(ref rig) = self.rigid {
300            apply_rigid_point(rig, r, c)
301        } else {
302            (r, c)
303        };
304
305        // Step 2: add displacement from the deformable field
306        if let Some(ref field) = self.deformable {
307            if field.dims.len() == 2 && field.n_components == 2 {
308                let [dy, dx] = bilinear_sample_2d_field(field, lr, lc);
309                Ok((lr + dy, lc + dx))
310            } else {
311                Ok((lr, lc))
312            }
313        } else {
314            Ok((lr, lc))
315        }
316    }
317
318    /// Compose two `CompositeTransform`s by merging their displacement fields.
319    ///
320    /// Rigid/affine components from `self` take precedence.
321    /// Deformable fields are composed using `DisplacementField::compose_2d`.
322    pub fn compose(&self, other: &CompositeTransform) -> NdimageResult<CompositeTransform> {
323        let deformable = match (&self.deformable, &other.deformable) {
324            (Some(a), Some(b)) => Some(a.compose_2d(b)?),
325            (Some(a), None) => Some(a.clone()),
326            (None, Some(b)) => Some(b.clone()),
327            (None, None) => None,
328        };
329        Ok(CompositeTransform {
330            rigid: self.rigid.clone().or_else(|| other.rigid.clone()),
331            affine: self.affine.clone().or_else(|| other.affine.clone()),
332            deformable,
333        })
334    }
335}
336
337// ─── DemonsDiffeo ─────────────────────────────────────────────────────────────
338
/// Tunable parameters of the diffeomorphic demons registration.
#[derive(Debug, Clone)]
pub struct DemonsConfig {
    /// Upper bound on the number of registration iterations.
    pub max_iterations: usize,
    /// Iteration stops once the RMS magnitude of the smoothed update field
    /// drops below this value.
    pub convergence_threshold: f64,
    /// Sigma (pixels) of the Gaussian applied to each update field.
    pub fluid_sigma: f64,
    /// Sigma (pixels) of the Gaussian applied to the accumulated displacement.
    pub diffeo_sigma: f64,
    /// Multiplier on the demons update (default 1.0; reduce for stability).
    pub step_size: f64,
}

impl Default for DemonsConfig {
    /// 100 iterations, threshold `1e-3`, sigmas `1.5` / `1.0`, unit step.
    fn default() -> Self {
        DemonsConfig {
            step_size: 1.0,
            diffeo_sigma: 1.0,
            fluid_sigma: 1.5,
            convergence_threshold: 1e-3,
            max_iterations: 100,
        }
    }
}
365
366/// Result of a demons registration run.
367#[derive(Debug, Clone)]
368pub struct DemonsResult {
369    /// Final deformation field.
370    pub field: DisplacementField,
371    /// Number of iterations performed.
372    pub iterations: usize,
373    /// History of RMS displacement updates per iteration.
374    pub rms_history: Vec<f64>,
375    /// Whether the algorithm converged.
376    pub converged: bool,
377}
378
379/// Diffeomorphic demons registration for 2D images.
380///
381/// Registers a moving image to a fixed image by estimating a diffeomorphic
382/// displacement field that maximises image similarity (SSD).
383///
384/// # Algorithm
385///
386/// At each iteration:
387/// 1. Compute the update field `u` from the demons force (image gradient + intensity diff).
388/// 2. Smooth `u` with a Gaussian kernel (`fluid_sigma`).
389/// 3. Compose the current displacement with the exponential of `u`.
390/// 4. Smooth the composition with `diffeo_sigma`.
391pub struct DemonsDiffeo {
392    config: DemonsConfig,
393}
394
395impl DemonsDiffeo {
396    /// Create a new diffeomorphic demons registrar with default configuration.
397    pub fn new() -> Self {
398        Self {
399            config: DemonsConfig::default(),
400        }
401    }
402
403    /// Create with custom configuration.
404    pub fn with_config(config: DemonsConfig) -> Self {
405        Self { config }
406    }
407
408    /// Register `moving` to `fixed`.
409    ///
410    /// Both images must have the same shape `(rows, cols)`.
411    /// Returns the estimated displacement field and iteration statistics.
412    pub fn register(
413        &self,
414        fixed: &Array2<f64>,
415        moving: &Array2<f64>,
416    ) -> NdimageResult<DemonsResult> {
417        let fshape = fixed.shape();
418        let mshape = moving.shape();
419        if fshape != mshape {
420            return Err(NdimageError::DimensionError(format!(
421                "DemonsDiffeo: fixed shape {:?} != moving shape {:?}",
422                fshape, mshape
423            )));
424        }
425        let rows = fshape[0];
426        let cols = fshape[1];
427        if rows < 3 || cols < 3 {
428            return Err(NdimageError::InvalidInput(
429                "DemonsDiffeo: images must be at least 3×3".to_string(),
430            ));
431        }
432
433        let mut disp = DisplacementField::zeros_2d(rows, cols);
434        let mut rms_history = Vec::with_capacity(self.config.max_iterations);
435        let mut converged = false;
436
437        for iter in 0..self.config.max_iterations {
438            // Warp moving image with current displacement
439            let warped = warp_image_2d(moving, &disp);
440
441            // Compute demons force at each pixel
442            let mut update = DisplacementField::zeros_2d(rows, cols);
443            for r in 0..rows {
444                for c in 0..cols {
445                    let f_val = fixed[[r, c]];
446                    let m_val = warped[r * cols + c];
447                    let diff = f_val - m_val;
448
449                    // Gradient of fixed image (central difference)
450                    let f_gn = if r > 0 {
451                        fixed[[r - 1, c]]
452                    } else {
453                        fixed[[r, c]]
454                    };
455                    let f_gp = if r + 1 < rows {
456                        fixed[[r + 1, c]]
457                    } else {
458                        fixed[[r, c]]
459                    };
460                    let f_gcn = if c > 0 {
461                        fixed[[r, c - 1]]
462                    } else {
463                        fixed[[r, c]]
464                    };
465                    let f_gcp = if c + 1 < cols {
466                        fixed[[r, c + 1]]
467                    } else {
468                        fixed[[r, c]]
469                    };
470                    let gx = (f_gp - f_gn) * 0.5;
471                    let gy = (f_gcp - f_gcn) * 0.5;
472                    let denom = gx * gx + gy * gy + diff * diff + 1e-10;
473
474                    let uy = self.config.step_size * diff * gx / denom;
475                    let ux = self.config.step_size * diff * gy / denom;
476                    update.set_2d(r, c, uy, ux)?;
477                }
478            }
479
480            // Smooth the update field
481            update.gaussian_smooth_2d(self.config.fluid_sigma)?;
482
483            // Compute RMS of the update
484            let rms = update.rms_magnitude();
485            rms_history.push(rms);
486
487            // Compose: disp = disp ∘ update  (update exponential ~ identity + update for small steps)
488            disp = disp.compose_2d(&update)?;
489
490            // Smooth the total displacement field
491            disp.gaussian_smooth_2d(self.config.diffeo_sigma)?;
492
493            if rms < self.config.convergence_threshold {
494                converged = true;
495                let final_iter = iter + 1;
496                return Ok(DemonsResult {
497                    field: disp,
498                    iterations: final_iter,
499                    rms_history,
500                    converged,
501                });
502            }
503        }
504
505        Ok(DemonsResult {
506            field: disp,
507            iterations: self.config.max_iterations,
508            rms_history,
509            converged,
510        })
511    }
512}
513
514// ─── FluidRegistration ───────────────────────────────────────────────────────
515
/// Tunable parameters of the viscous fluid-model registration.
#[derive(Debug, Clone)]
pub struct FluidConfig {
    /// Upper bound on the number of outer iterations.
    pub max_iterations: usize,
    /// Viscosity regularisation, expressed as the sigma (pixels) of the
    /// Gaussian applied to the velocity field.
    pub viscosity: f64,
    /// Gradient-descent step applied to the smoothed velocity.
    pub step_size: f64,
    /// Iteration stops once the relative energy change drops below this value.
    pub convergence_threshold: f64,
}

impl Default for FluidConfig {
    /// 100 iterations, viscosity `2.0`, step `0.5`, threshold `1e-4`.
    fn default() -> Self {
        FluidConfig {
            convergence_threshold: 1e-4,
            step_size: 0.5,
            viscosity: 2.0,
            max_iterations: 100,
        }
    }
}
539
540/// Result of fluid registration.
541#[derive(Debug, Clone)]
542pub struct FluidResult {
543    /// Final deformation field.
544    pub field: DisplacementField,
545    /// Energy history per iteration (SSD).
546    pub energy_history: Vec<f64>,
547    /// Number of iterations.
548    pub iterations: usize,
549    /// Converged flag.
550    pub converged: bool,
551}
552
553/// Viscous fluid model registration for 2D images.
554///
555/// Models the moving image as a viscous fluid flowing toward the fixed image
556/// under image-derived body forces.  The velocity field is regularised by
557/// applying a Gaussian smoothing at each step (approximating Stokes' equation).
558pub struct FluidRegistration {
559    config: FluidConfig,
560}
561
562impl FluidRegistration {
563    /// Create with default configuration.
564    pub fn new() -> Self {
565        Self {
566            config: FluidConfig::default(),
567        }
568    }
569
570    /// Create with custom configuration.
571    pub fn with_config(config: FluidConfig) -> Self {
572        Self { config }
573    }
574
575    /// Register `moving` to `fixed`.
576    pub fn register(
577        &self,
578        fixed: &Array2<f64>,
579        moving: &Array2<f64>,
580    ) -> NdimageResult<FluidResult> {
581        let fshape = fixed.shape();
582        if fshape != moving.shape() {
583            return Err(NdimageError::DimensionError(format!(
584                "FluidRegistration: shape mismatch {:?} vs {:?}",
585                fshape,
586                moving.shape()
587            )));
588        }
589        let rows = fshape[0];
590        let cols = fshape[1];
591        if rows < 3 || cols < 3 {
592            return Err(NdimageError::InvalidInput(
593                "FluidRegistration: images must be at least 3×3".to_string(),
594            ));
595        }
596
597        let mut vel = DisplacementField::zeros_2d(rows, cols);
598        let mut disp = DisplacementField::zeros_2d(rows, cols);
599        let mut energy_history = Vec::with_capacity(self.config.max_iterations);
600        let mut prev_energy = f64::INFINITY;
601        let mut converged = false;
602
603        for iter in 0..self.config.max_iterations {
604            let warped = warp_image_2d(moving, &disp);
605            let mut ssd = 0.0_f64;
606
607            // Compute body forces from SSD gradient
608            for r in 0..rows {
609                for c in 0..cols {
610                    let diff = fixed[[r, c]] - warped[r * cols + c];
611                    ssd += diff * diff;
612
613                    // Gradient of warped image
614                    let w_rn = if r > 0 {
615                        warped[(r - 1) * cols + c]
616                    } else {
617                        warped[r * cols + c]
618                    };
619                    let w_rp = if r + 1 < rows {
620                        warped[(r + 1) * cols + c]
621                    } else {
622                        warped[r * cols + c]
623                    };
624                    let w_cn = if c > 0 {
625                        warped[r * cols + c - 1]
626                    } else {
627                        warped[r * cols + c]
628                    };
629                    let w_cp = if c + 1 < cols {
630                        warped[r * cols + c + 1]
631                    } else {
632                        warped[r * cols + c]
633                    };
634                    let gy = (w_rp - w_rn) * 0.5;
635                    let gx = (w_cp - w_cn) * 0.5;
636
637                    let fy = 2.0 * diff * gy;
638                    let fx = 2.0 * diff * gx;
639                    vel.set_2d(r, c, fy, fx)?;
640                }
641            }
642
643            energy_history.push(ssd);
644
645            // Smooth velocity field (viscosity regularisation)
646            vel.gaussian_smooth_2d(self.config.viscosity)?;
647
648            // Update displacement: disp = disp + step * vel
649            for i in 0..rows * cols {
650                disp.field[i * 2] += self.config.step_size * vel.field[i * 2];
651                disp.field[i * 2 + 1] += self.config.step_size * vel.field[i * 2 + 1];
652            }
653
654            let rel_change = (prev_energy - ssd).abs() / (prev_energy.abs() + 1e-12);
655            if rel_change < self.config.convergence_threshold {
656                converged = true;
657                return Ok(FluidResult {
658                    field: disp,
659                    energy_history,
660                    iterations: iter + 1,
661                    converged,
662                });
663            }
664            prev_energy = ssd;
665        }
666
667        Ok(FluidResult {
668            field: disp,
669            energy_history,
670            iterations: self.config.max_iterations,
671            converged,
672        })
673    }
674}
675
676// ─── FreeFormDeformation ─────────────────────────────────────────────────────
677
/// Tunable parameters of the B-spline free-form deformation registration.
#[derive(Debug, Clone)]
pub struct FfdConfig {
    /// Control-point grid size along each axis, `[n_r, n_c]`.
    /// `FreeFormDeformation::with_config` raises entries below 4 to 4
    /// (minimum for cubic B-splines).
    pub grid_size: [usize; 2],
    /// Upper bound on the number of optimisation iterations.
    pub max_iterations: usize,
    /// Step applied along the negative gradient at each iteration.
    pub step_size: f64,
    /// Weight of the bending-energy regularisation term.
    pub regularisation: f64,
    /// Threshold on the energy change used to declare convergence.
    pub convergence_threshold: f64,
}

impl Default for FfdConfig {
    /// 8×8 grid, 100 iterations, step `0.1`, regularisation `0.01`,
    /// threshold `1e-4`.
    fn default() -> Self {
        FfdConfig {
            convergence_threshold: 1e-4,
            regularisation: 0.01,
            step_size: 0.1,
            max_iterations: 100,
            grid_size: [8, 8],
        }
    }
}
705
706/// Result of B-spline FFD registration.
707#[derive(Debug, Clone)]
708pub struct FfdResult {
709    /// Dense displacement field derived from the B-spline control points.
710    pub field: DisplacementField,
711    /// Control-point displacements in `y` direction, shape `[grid_r, grid_c]`.
712    pub ctrl_dy: Array2<f64>,
713    /// Control-point displacements in `x` direction, shape `[grid_r, grid_c]`.
714    pub ctrl_dx: Array2<f64>,
715    /// Energy history.
716    pub energy_history: Vec<f64>,
717    /// Number of iterations.
718    pub iterations: usize,
719    /// Converged flag.
720    pub converged: bool,
721}
722
723/// B-spline free-form deformation (FFD) registration for 2D images.
724///
725/// Uses a regular grid of cubic B-spline control points to parametrise the
726/// deformation field.  The control-point displacements are optimised by
727/// gradient descent to minimise SSD + bending energy regularisation.
728pub struct FreeFormDeformation {
729    config: FfdConfig,
730}
731
732impl FreeFormDeformation {
733    /// Create with default configuration.
734    pub fn new() -> Self {
735        Self {
736            config: FfdConfig::default(),
737        }
738    }
739
740    /// Create with custom configuration.
741    pub fn with_config(config: FfdConfig) -> Self {
742        let mut cfg = config;
743        cfg.grid_size[0] = cfg.grid_size[0].max(4);
744        cfg.grid_size[1] = cfg.grid_size[1].max(4);
745        Self { config: cfg }
746    }
747
748    /// Register `moving` to `fixed`.
749    pub fn register(&self, fixed: &Array2<f64>, moving: &Array2<f64>) -> NdimageResult<FfdResult> {
750        let fshape = fixed.shape();
751        if fshape != moving.shape() {
752            return Err(NdimageError::DimensionError(format!(
753                "FreeFormDeformation: shape mismatch {:?} vs {:?}",
754                fshape,
755                moving.shape()
756            )));
757        }
758        let rows = fshape[0];
759        let cols = fshape[1];
760        if rows < 4 || cols < 4 {
761            return Err(NdimageError::InvalidInput(
762                "FreeFormDeformation: images must be at least 4×4".to_string(),
763            ));
764        }
765
766        let [gr, gc] = self.config.grid_size;
767        // Initialise control-point grids at zero
768        let mut ctrl_dy = Array2::<f64>::zeros((gr, gc));
769        let mut ctrl_dx = Array2::<f64>::zeros((gr, gc));
770
771        let mut energy_history = Vec::with_capacity(self.config.max_iterations);
772        let mut prev_energy = f64::INFINITY;
773        let mut converged = false;
774
775        for iter in 0..self.config.max_iterations {
776            // Evaluate dense displacement from current control points
777            let disp = self.ctrl_to_dense(&ctrl_dy, &ctrl_dx, rows, cols);
778            let warped = warp_image_2d(moving, &disp);
779
780            // Compute gradient of SSD w.r.t. control points
781            let mut grad_dy = Array2::<f64>::zeros((gr, gc));
782            let mut grad_dx = Array2::<f64>::zeros((gr, gc));
783            let mut ssd = 0.0_f64;
784
785            for r in 0..rows {
786                for c in 0..cols {
787                    let diff = fixed[[r, c]] - warped[r * cols + c];
788                    ssd += diff * diff;
789
790                    // Image gradient at warped position
791                    let w_rn = if r > 0 {
792                        warped[(r - 1) * cols + c]
793                    } else {
794                        warped[r * cols + c]
795                    };
796                    let w_rp = if r + 1 < rows {
797                        warped[(r + 1) * cols + c]
798                    } else {
799                        warped[r * cols + c]
800                    };
801                    let w_cn = if c > 0 {
802                        warped[r * cols + c - 1]
803                    } else {
804                        warped[r * cols + c]
805                    };
806                    let w_cp = if c + 1 < cols {
807                        warped[r * cols + c + 1]
808                    } else {
809                        warped[r * cols + c]
810                    };
811                    let gy = (w_rp - w_rn) * 0.5;
812                    let gx = (w_cp - w_cn) * 0.5;
813
814                    // Propagate gradient to control points via B-spline basis
815                    let t_r = r as f64 / rows as f64 * (gr - 1) as f64;
816                    let t_c = c as f64 / cols as f64 * (gc - 1) as f64;
817                    let pr = (t_r.floor() as isize).clamp(0, gr as isize - 1) as usize;
818                    let pc = (t_c.floor() as isize).clamp(0, gc as isize - 1) as usize;
819
820                    // Simple trilinear weight for neighbouring control points
821                    let fr = t_r - pr as f64;
822                    let fc = t_c - pc as f64;
823                    for dr in 0..2_usize {
824                        for dc in 0..2_usize {
825                            let nrr = (pr + dr).min(gr - 1);
826                            let ncc = (pc + dc).min(gc - 1);
827                            let wr = if dr == 0 { 1.0 - fr } else { fr };
828                            let wc = if dc == 0 { 1.0 - fc } else { fc };
829                            let w = wr * wc;
830                            grad_dy[[nrr, ncc]] -= 2.0 * diff * gy * w;
831                            grad_dx[[nrr, ncc]] -= 2.0 * diff * gx * w;
832                        }
833                    }
834                }
835            }
836
837            // Bending energy regularisation: penalise second derivatives of
838            // control-point grid using finite differences
839            let bend = self.bending_energy(&ctrl_dy, &ctrl_dx, gr, gc);
840            let total_energy = ssd + self.config.regularisation * bend;
841            energy_history.push(total_energy);
842
843            // Add regularisation gradient (Laplacian of control points)
844            for r in 1..gr - 1 {
845                for c in 1..gc - 1 {
846                    let lap_dy = ctrl_dy[[r - 1, c]] - 2.0 * ctrl_dy[[r, c]]
847                        + ctrl_dy[[r + 1, c]]
848                        + ctrl_dy[[r, c - 1]]
849                        - 2.0 * ctrl_dy[[r, c]]
850                        + ctrl_dy[[r, c + 1]];
851                    let lap_dx = ctrl_dx[[r - 1, c]] - 2.0 * ctrl_dx[[r, c]]
852                        + ctrl_dx[[r + 1, c]]
853                        + ctrl_dx[[r, c - 1]]
854                        - 2.0 * ctrl_dx[[r, c]]
855                        + ctrl_dx[[r, c + 1]];
856                    grad_dy[[r, c]] -= self.config.regularisation * lap_dy;
857                    grad_dx[[r, c]] -= self.config.regularisation * lap_dx;
858                }
859            }
860
861            // Update control points
862            for r in 0..gr {
863                for c in 0..gc {
864                    ctrl_dy[[r, c]] -= self.config.step_size * grad_dy[[r, c]];
865                    ctrl_dx[[r, c]] -= self.config.step_size * grad_dx[[r, c]];
866                }
867            }
868
869            let rel_change = (prev_energy - total_energy).abs() / (prev_energy.abs() + 1e-12);
870            if rel_change < self.config.convergence_threshold {
871                converged = true;
872                let final_disp = self.ctrl_to_dense(&ctrl_dy, &ctrl_dx, rows, cols);
873                return Ok(FfdResult {
874                    field: final_disp,
875                    ctrl_dy,
876                    ctrl_dx,
877                    energy_history,
878                    iterations: iter + 1,
879                    converged,
880                });
881            }
882            prev_energy = total_energy;
883        }
884
885        let final_disp = self.ctrl_to_dense(&ctrl_dy, &ctrl_dx, rows, cols);
886        Ok(FfdResult {
887            field: final_disp,
888            ctrl_dy,
889            ctrl_dx,
890            energy_history,
891            iterations: self.config.max_iterations,
892            converged,
893        })
894    }
895
896    /// Evaluate the dense displacement field by bilinear interpolation of
897    /// B-spline control points.
898    fn ctrl_to_dense(
899        &self,
900        ctrl_dy: &Array2<f64>,
901        ctrl_dx: &Array2<f64>,
902        rows: usize,
903        cols: usize,
904    ) -> DisplacementField {
905        let [gr, gc] = self.config.grid_size;
906        let mut disp = DisplacementField::zeros_2d(rows, cols);
907        for r in 0..rows {
908            for c in 0..cols {
909                let t_r = r as f64 / rows as f64 * (gr - 1) as f64;
910                let t_c = c as f64 / cols as f64 * (gc - 1) as f64;
911                let pr = (t_r.floor() as isize).clamp(0, gr as isize - 1) as usize;
912                let pc = (t_c.floor() as isize).clamp(0, gc as isize - 1) as usize;
913                let fr = t_r - pr as f64;
914                let fc = t_c - pc as f64;
915
916                let mut dy = 0.0;
917                let mut dx = 0.0;
918                for dr in 0..2_usize {
919                    for dc in 0..2_usize {
920                        let nrr = (pr + dr).min(gr - 1);
921                        let ncc = (pc + dc).min(gc - 1);
922                        let wr = if dr == 0 { 1.0 - fr } else { fr };
923                        let wc = if dc == 0 { 1.0 - fc } else { fc };
924                        dy += wr * wc * ctrl_dy[[nrr, ncc]];
925                        dx += wr * wc * ctrl_dx[[nrr, ncc]];
926                    }
927                }
928                let base = (r * cols + c) * 2;
929                disp.field[base] = dy;
930                disp.field[base + 1] = dx;
931            }
932        }
933        disp
934    }
935
936    /// Compute the bending energy of the control-point grid using second
937    /// finite differences.
938    fn bending_energy(&self, dy: &Array2<f64>, dx: &Array2<f64>, gr: usize, gc: usize) -> f64 {
939        let mut energy = 0.0;
940        for r in 1..gr.saturating_sub(1) {
941            for c in 1..gc.saturating_sub(1) {
942                let d2y_rr = dy[[r - 1, c]] - 2.0 * dy[[r, c]] + dy[[r + 1, c]];
943                let d2y_cc = dy[[r, c - 1]] - 2.0 * dy[[r, c]] + dy[[r, c + 1]];
944                let d2x_rr = dx[[r - 1, c]] - 2.0 * dx[[r, c]] + dx[[r + 1, c]];
945                let d2x_cc = dx[[r, c - 1]] - 2.0 * dx[[r, c]] + dx[[r, c + 1]];
946                energy += d2y_rr * d2y_rr + d2y_cc * d2y_cc + d2x_rr * d2x_rr + d2x_cc * d2x_cc;
947            }
948        }
949        energy
950    }
951}
952
953// ─── Helpers ──────────────────────────────────────────────────────────────────
954
955/// Warp a 2D image by a displacement field using bilinear interpolation.
956///
957/// Returns a flat `Vec<f64>` with shape `rows × cols` (row-major).
958fn warp_image_2d(image: &Array2<f64>, field: &DisplacementField) -> Vec<f64> {
959    let rows = field.dims[0];
960    let cols = field.dims[1];
961    let mut out = vec![0.0_f64; rows * cols];
962    for r in 0..rows {
963        for c in 0..cols {
964            let base = (r * cols + c) * 2;
965            let dy = field.field[base];
966            let dx = field.field[base + 1];
967            let src_r = r as f64 - dy;
968            let src_c = c as f64 - dx;
969            out[r * cols + c] = bilinear_interpolate(image, src_r, src_c);
970        }
971    }
972    out
973}
974
975/// Bilinear interpolation on a 2D array at fractional coordinates `(r, c)`.
976///
977/// Boundary pixels are replicated (clamp-to-edge).
978fn bilinear_interpolate(image: &Array2<f64>, r: f64, c: f64) -> f64 {
979    let rows = image.nrows();
980    let cols = image.ncols();
981    let r0 = r.floor() as isize;
982    let c0 = c.floor() as isize;
983    let fr = r - r.floor();
984    let fc = c - c.floor();
985
986    let clamp_r = |v: isize| v.clamp(0, rows as isize - 1) as usize;
987    let clamp_c = |v: isize| v.clamp(0, cols as isize - 1) as usize;
988
989    let r0u = clamp_r(r0);
990    let r1u = clamp_r(r0 + 1);
991    let c0u = clamp_c(c0);
992    let c1u = clamp_c(c0 + 1);
993
994    let v00 = image[[r0u, c0u]];
995    let v01 = image[[r0u, c1u]];
996    let v10 = image[[r1u, c0u]];
997    let v11 = image[[r1u, c1u]];
998
999    v00 * (1.0 - fr) * (1.0 - fc) + v01 * (1.0 - fr) * fc + v10 * fr * (1.0 - fc) + v11 * fr * fc
1000}
1001
1002/// Bilinearly sample a 2D displacement field at fractional coordinates.
1003///
1004/// Returns `[dy, dx]`; out-of-bounds coordinates are clamped.
1005fn bilinear_sample_2d_field(field: &DisplacementField, r: f64, c: f64) -> [f64; 2] {
1006    let rows = field.dims[0];
1007    let cols = field.dims[1];
1008    let r0 = r.floor() as isize;
1009    let c0 = c.floor() as isize;
1010    let fr = r - r.floor();
1011    let fc = c - c.floor();
1012
1013    let clamp_r = |v: isize| v.clamp(0, rows as isize - 1) as usize;
1014    let clamp_c = |v: isize| v.clamp(0, cols as isize - 1) as usize;
1015
1016    let corners = [
1017        (clamp_r(r0), clamp_c(c0)),
1018        (clamp_r(r0), clamp_c(c0 + 1)),
1019        (clamp_r(r0 + 1), clamp_c(c0)),
1020        (clamp_r(r0 + 1), clamp_c(c0 + 1)),
1021    ];
1022    let weights = [
1023        (1.0 - fr) * (1.0 - fc),
1024        (1.0 - fr) * fc,
1025        fr * (1.0 - fc),
1026        fr * fc,
1027    ];
1028
1029    let mut dy = 0.0_f64;
1030    let mut dx = 0.0_f64;
1031    for (idx, &(cr, cc)) in corners.iter().enumerate() {
1032        let base = (cr * cols + cc) * 2;
1033        dy += weights[idx] * field.field[base];
1034        dx += weights[idx] * field.field[base + 1];
1035    }
1036    [dy, dx]
1037}
1038
1039/// Apply an affine transform to a single point.
1040fn apply_affine_point(aff: &AffineTransform2D, r: f64, c: f64) -> (f64, f64) {
1041    let m = &aff.matrix;
1042    let nr = m[[0, 0]] * r + m[[0, 1]] * c + m[[0, 2]];
1043    let nc = m[[1, 0]] * r + m[[1, 1]] * c + m[[1, 2]];
1044    (nr, nc)
1045}
1046
1047/// Apply a rigid transform to a single point.
1048fn apply_rigid_point(rig: &RigidTransform2D, r: f64, c: f64) -> (f64, f64) {
1049    let cos_a = rig.angle.cos();
1050    let sin_a = rig.angle.sin();
1051    let nr = cos_a * r - sin_a * c + rig.ty;
1052    let nc = sin_a * r + cos_a * c + rig.tx;
1053    (nr, nc)
1054}
1055
/// Separable Gaussian smoothing of a flat row-major 2D buffer, in place.
///
/// A normalised 1D Gaussian kernel of radius `ceil(3 * sigma)` is applied
/// first along each row and then along each column. Samples beyond the
/// buffer edge are taken from the nearest edge element (clamp-to-edge).
///
/// * `buf` - row-major data of at least `rows * cols` elements; overwritten
///   with the smoothed values.
/// * `rows`, `cols` - logical shape of `buf`.
/// * `sigma` - standard deviation in pixels. For `sigma <= 0` (or NaN) the
///   radius is zero and the kernel collapses to a single unit tap, so the
///   values pass through unsmoothed.
///
/// # Panics
///
/// Panics if `buf` holds fewer than `rows * cols` elements.
fn gaussian_smooth_1d_separable(buf: &mut Vec<f64>, rows: usize, cols: usize, sigma: f64) {
    let radius = (3.0 * sigma).ceil() as usize;
    let two_sig2 = 2.0 * sigma * sigma;

    // Unnormalised Gaussian weight at integer distance `offset` from the
    // centre. The centre tap is fixed at exp(0) = 1 so that `sigma == 0`
    // does not evaluate 0 / 0.
    let tap = |offset: usize| -> f64 {
        if offset == 0 {
            1.0
        } else {
            (-(offset as f64).powi(2) / two_sig2).exp()
        }
    };

    // Symmetric kernel over offsets [-radius, ..., 0, ..., +radius].
    let unnormalised: Vec<f64> = (1..=radius).rev().chain(0..=radius).map(tap).collect();
    let norm: f64 = unnormalised.iter().sum();
    let kernel: Vec<f64> = unnormalised.iter().map(|w| w / norm).collect();

    // Source index for kernel tap `ki` centred on `pos`, clamped to the axis.
    let tap_index = |pos: usize, ki: usize, len: usize| -> usize {
        (pos as isize + ki as isize - radius as isize).clamp(0, len as isize - 1) as usize
    };

    // Row-wise pass: buf -> tmp.
    let mut tmp = vec![0.0_f64; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            let mut acc = 0.0;
            for (ki, &kv) in kernel.iter().enumerate() {
                acc += kv * buf[r * cols + tap_index(c, ki, cols)];
            }
            tmp[r * cols + c] = acc;
        }
    }

    // Column-wise pass: tmp -> buf.
    for r in 0..rows {
        for c in 0..cols {
            let mut acc = 0.0;
            for (ki, &kv) in kernel.iter().enumerate() {
                acc += kv * tmp[tap_index(r, ki, rows) * cols + c];
            }
            buf[r * cols + c] = acc;
        }
    }
}
1116
1117// ─── Unit tests ───────────────────────────────────────────────────────────────
1118
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Tolerance used by the exact-value checks below.
    const TOL: f64 = 1e-10;

    /// Smooth synthetic image: a sine along rows (phase-shifted by `offset`)
    /// plus a cosine along columns, scaled into an 8-bit-like range.
    fn make_test_image(rows: usize, cols: usize, offset: f64) -> Array2<f64> {
        Array2::from_shape_fn((rows, cols), |(r, c)| {
            ((r as f64 + offset).sin() + (c as f64).cos()) * 50.0 + 128.0
        })
    }

    /// 16x16 fixed/moving pair differing by a small phase shift along rows.
    fn smoke_pair() -> (Array2<f64>, Array2<f64>) {
        (make_test_image(16, 16, 0.0), make_test_image(16, 16, 0.3))
    }

    #[test]
    fn test_displacement_field_create_and_access() {
        let mut field = DisplacementField::zeros_2d(10, 10);
        field
            .set_2d(3, 4, 1.5, -2.0)
            .expect("set_2d should succeed for valid coordinates");

        let [dy, dx] = field
            .get_2d(3, 4)
            .expect("get_2d should succeed for valid coordinates");

        assert!((dy - 1.5).abs() < TOL);
        assert!((dx + 2.0).abs() < TOL);
    }

    #[test]
    fn test_displacement_field_compose_identity() {
        let first = DisplacementField::zeros_2d(8, 8);
        let second = DisplacementField::zeros_2d(8, 8);

        let result = first
            .compose_2d(&second)
            .expect("compose_2d should succeed with identical-size fields");

        assert!(result.rms_magnitude() < TOL);
    }

    #[test]
    fn test_jacobian_determinant_identity() {
        let identity = DisplacementField::zeros_2d(10, 10);
        let determinants = JacobianDeterminant::compute_2d(&identity)
            .expect("compute_2d should succeed on identity field");

        // A zero displacement field is the identity map, whose Jacobian
        // determinant is 1 everywhere.
        determinants
            .iter()
            .for_each(|v| assert!((v - 1.0).abs() < 1e-8, "Expected det≈1, got {}", v));
    }

    #[test]
    fn test_jacobian_folding_fraction_zero_for_identity() {
        let identity = DisplacementField::zeros_2d(10, 10);
        let folded = JacobianDeterminant::folding_fraction_2d(&identity)
            .expect("folding_fraction_2d should succeed on identity field");
        assert!(folded < TOL);
    }

    #[test]
    fn test_demons_diffeo_smoke() {
        let (fixed, moving) = smoke_pair();
        let outcome = DemonsDiffeo::new()
            .register(&fixed, &moving)
            .expect("DemonsDiffeo register should succeed on valid images");
        assert!(outcome.iterations > 0);
        // Smoke test only: the result carries a displacement field, but its
        // contents are not checked here.
        let _ = outcome.field;
    }

    #[test]
    fn test_fluid_registration_smoke() {
        let (fixed, moving) = smoke_pair();
        let outcome = FluidRegistration::new()
            .register(&fixed, &moving)
            .expect("FluidRegistration register should succeed on valid images");
        assert!(outcome.iterations > 0);
    }

    #[test]
    fn test_ffd_registration_smoke() {
        let (fixed, moving) = smoke_pair();
        let outcome = FreeFormDeformation::new()
            .register(&fixed, &moving)
            .expect("FreeFormDeformation register should succeed on valid images");
        assert!(outcome.iterations > 0);
    }

    #[test]
    fn test_composite_transform_identity() {
        let identity = CompositeTransform::identity();
        let (mapped_r, mapped_c) = identity
            .apply_to_point(5.0, 7.0)
            .expect("apply_to_point should succeed for identity transform");
        assert!((mapped_r - 5.0).abs() < TOL);
        assert!((mapped_c - 7.0).abs() < TOL);
    }

    #[test]
    fn test_gaussian_smooth_does_not_panic() {
        let mut field = DisplacementField::zeros_2d(8, 8);
        field
            .set_2d(4, 4, 10.0, -5.0)
            .expect("set_2d should succeed for valid coordinates");
        field
            .gaussian_smooth_2d(1.0)
            .expect("gaussian_smooth_2d should succeed with sigma=1");

        // Smoothing spreads the isolated spike, so the peak must shrink.
        let [peak_dy, _peak_dx] = field
            .get_2d(4, 4)
            .expect("get_2d should succeed for valid coordinates");
        assert!(peak_dy < 10.0);
    }
}