Skip to main content

scirs2_transform/alignment/
procrustes.rs

1//! Procrustes analysis for aligning geometric configurations.
2//!
3//! ## Overview
4//!
5//! Procrustes analysis finds the optimal orthogonal transformation (rotation and
6//! optionally reflection and scaling) that maps one matrix onto another in the
7//! Frobenius-norm sense.
8//!
9//! ### Orthogonal Procrustes Problem
10//!
11//! Given matrices **A** (n × d) and **B** (n × d), find:
12//!
13//! ```text
14//! min_{R: Rᵀ R = I}  ||s · A R + 1 tᵀ − B||_F
15//! ```
16//!
17//! **Solution via SVD** of Bᵀ A = U Σ Vᵀ:
//! - R = V Uᵀ, or R = V D Uᵀ with D = diag(1, …, 1, det(V Uᵀ)) to prevent reflections
//! - Optimal scale s = trace(D Σ) / ||A||_F² (A centered when centering is enabled;
//!   D = I whenever no reflection correction is applied)
20//!
21//! ### Generalized Procrustes Analysis
22//!
23//! Aligns multiple matrices to a common mean (consensus) shape via iterative
24//! pairwise Procrustes alignment, similar to the GPA algorithm of Gower (1975).
25//!
26//! ## References
27//!
28//! - Schönemann (1966): A generalized solution of the orthogonal Procrustes problem
29//! - Gower (1975): Generalized Procrustes analysis
30//! - Golub & Van Loan (1996): Matrix Computations, §12.4
31
32use scirs2_core::ndarray::{Array1, Array2, Axis};
33
34use crate::error::{Result, TransformError};
35
36// ---------------------------------------------------------------------------
37// Configuration
38// ---------------------------------------------------------------------------
39
/// Options controlling [`orthogonal_procrustes`].
///
/// The [`Default`] configuration solves for a proper rotation (no
/// reflections), an isotropic scale factor and a translation, centering both
/// inputs first.
#[derive(Debug, Clone)]
pub struct ProcrustesConfig {
    /// When `true`, the solution may be any orthogonal matrix (group O(d)),
    /// including reflections; when `false` it is restricted to rotations
    /// (SO(d), det(R) = +1). Default: `false`.
    pub allow_reflection: bool,
    /// Solve for the optimal isotropic scale factor; when `false` the scale is
    /// fixed at 1. Default: `true`.
    pub scaling: bool,
    /// Subtract each matrix's column means before solving. Default: `true`.
    pub centering: bool,
}

impl Default for ProcrustesConfig {
    /// Rotation-only alignment with scaling and centering enabled.
    fn default() -> Self {
        let (allow_reflection, scaling, centering) = (false, true, true);
        Self {
            allow_reflection,
            scaling,
            centering,
        }
    }
}
63
64// ---------------------------------------------------------------------------
65// Result
66// ---------------------------------------------------------------------------
67
68/// Result of a Procrustes alignment.
69#[derive(Debug, Clone)]
70pub struct ProcrustesResult {
71    /// Optimal orthogonal rotation matrix R (d × d), with det(R) = +1 unless
72    /// `allow_reflection = true`.
73    pub rotation: Array2<f64>,
74    /// Optimal isotropic scale factor s (1.0 when `scaling = false`).
75    pub scale: f64,
76    /// Translation vector t (d-dimensional) applied *after* rotation.
77    pub translation: Array1<f64>,
78    /// Frobenius-norm residual ‖s·A·R + 1·tᵀ − B‖_F after alignment.
79    pub disparity: f64,
80    /// Aligned version of A: s·(A_centred · R) + centroid_B.
81    pub transformed: Array2<f64>,
82}
83
84// ---------------------------------------------------------------------------
85// Orthogonal Procrustes
86// ---------------------------------------------------------------------------
87
88/// Solve the orthogonal Procrustes problem.
89///
90/// Finds the best-fitting orthogonal transformation (rotation, optional scale,
91/// and translation) mapping **A** onto **B**:
92///
93/// ```text
94/// min_{R: Rᵀ R = I, s > 0, t}  ||s · A R + 1 tᵀ − B||_F
95/// ```
96///
97/// # Arguments
98/// * `a`      – Source matrix (n × d).
99/// * `b`      – Target matrix (n × d).
100/// * `config` – Alignment options.
101///
102/// # Errors
103/// Returns [`TransformError::InvalidInput`] when shapes are incompatible, or
104/// [`TransformError::ComputationError`] on numerical failure.
105///
106/// # Example
107/// ```rust
108/// use scirs2_transform::alignment::procrustes::{orthogonal_procrustes, ProcrustesConfig};
109/// use scirs2_core::ndarray::array;
110///
111/// // A 90° rotation of a simple triangle
112/// let a = array![[1.0_f64, 0.0], [0.0, 1.0], [0.0, 0.0]];
113/// let b = array![[0.0_f64, 1.0], [-1.0, 0.0], [0.0, 0.0]];
114/// let config = ProcrustesConfig { scaling: false, ..Default::default() };
115/// let result = orthogonal_procrustes(&a, &b, &config).expect("should succeed");
116/// assert!(result.disparity < 1e-6);
117/// ```
118pub fn orthogonal_procrustes(
119    a: &Array2<f64>,
120    b: &Array2<f64>,
121    config: &ProcrustesConfig,
122) -> Result<ProcrustesResult> {
123    let (n, d) = a.dim();
124    if b.dim() != (n, d) {
125        return Err(TransformError::InvalidInput(format!(
126            "Shape mismatch: A is ({n}×{d}) but B is ({}×{})",
127            b.nrows(),
128            b.ncols()
129        )));
130    }
131    if n == 0 || d == 0 {
132        return Err(TransformError::InvalidInput(
133            "Matrices must be non-empty".to_string(),
134        ));
135    }
136
137    // ----------------------------------------------------------------
138    // 1. Center both matrices
139    // ----------------------------------------------------------------
140    let centroid_a: Array1<f64> = if config.centering {
141        a.mean_axis(Axis(0)).ok_or_else(|| {
142            TransformError::ComputationError("Failed to compute centroid of A".to_string())
143        })?
144    } else {
145        Array1::zeros(d)
146    };
147
148    let centroid_b: Array1<f64> = if config.centering {
149        b.mean_axis(Axis(0)).ok_or_else(|| {
150            TransformError::ComputationError("Failed to compute centroid of B".to_string())
151        })?
152    } else {
153        Array1::zeros(d)
154    };
155
156    // Centered matrices
157    let a_c: Array2<f64> = a - &centroid_a.view().insert_axis(Axis(0));
158    let b_c: Array2<f64> = b - &centroid_b.view().insert_axis(Axis(0));
159
160    // ----------------------------------------------------------------
161    // 2. Frobenius norm of centered A
162    // ----------------------------------------------------------------
163    let norm_a_sq: f64 = a_c.iter().map(|&x| x * x).sum();
164
165    if norm_a_sq < f64::EPSILON {
166        // A is (approximately) zero — can't define rotation; return identity
167        let rotation = Array2::eye(d);
168        let translation = centroid_b.clone();
169        let zeros_plus_cb: Array2<f64> =
170            Array2::from_shape_fn((n, d), |_| 0.0) + &centroid_b.view().insert_axis(Axis(0));
171        let disparity = b_c.iter().map(|&x| x * x).sum::<f64>().sqrt();
172        return Ok(ProcrustesResult {
173            rotation,
174            scale: 1.0,
175            translation,
176            disparity,
177            transformed: zeros_plus_cb,
178        });
179    }
180
181    // ----------------------------------------------------------------
182    // 3. Compute M = Bᵀ A  (d × d)
183    //    The Procrustes solution uses SVD of M = B_cᵀ A_c
184    // ----------------------------------------------------------------
185    let m = b_c.t().dot(&a_c); // d × d
186
187    // ----------------------------------------------------------------
188    // 4. SVD of M: M = U Σ Vᵀ  using Jacobi SVD
189    // ----------------------------------------------------------------
190    let (u_mat, sigma_vec, vt_mat) = jacobi_svd_square(&m)?;
191    // u_mat : d×d,  sigma_vec: d,  vt_mat: d×d  (rows are right singular vectors)
192    // So M = U diag(σ) Vᵀ
193
194    // ----------------------------------------------------------------
195    // 5. Construct candidate R = V Uᵀ
196    // ----------------------------------------------------------------
197    let v_mat = vt_mat.t().to_owned(); // V: d×d  (columns are right singular vectors)
198    let ut_mat = u_mat.t().to_owned(); // Uᵀ: d×d
199    let mut r = v_mat.dot(&ut_mat); // R = V Uᵀ
200
201    // ----------------------------------------------------------------
202    // 6. Enforce det(R) = +1 if reflections are not allowed
203    // ----------------------------------------------------------------
204    if !config.allow_reflection {
205        let det_r = mat_det(&r);
206        if det_r < 0.0 {
207            // Flip sign of the last column of V (associated with smallest σ)
208            // so that det(R) = +1: R = V diag(1,…,1,−1) Uᵀ
209            let mut v_adj = v_mat.clone();
210            for row in 0..d {
211                v_adj[[row, d - 1]] *= -1.0;
212            }
213            r = v_adj.dot(&ut_mat);
214        }
215    }
216
217    // ----------------------------------------------------------------
218    // 7. Optimal scale  s = trace(Σ_adj) / ‖A_c‖²_F
219    //    Σ_adj accounts for the possible sign-flip of the last singular value.
220    // ----------------------------------------------------------------
221    let (scale, _sigma_trace) = if config.scaling {
222        let sigma_sum_raw: f64 = sigma_vec.iter().sum();
223        // If we flipped the last singular value to fix det:
224        let det_r = mat_det(&r);
225        let sigma_adj = if !config.allow_reflection && det_r > 0.0 {
226            // Check if we needed to flip (by comparing with raw det before flip)
227            // The raw sigma_sum is correct if we flipped, need to subtract 2*sigma_last
228            // But since `r` is already the corrected rotation, we re-check det
229            // The correction happened above: if original det < 0, we flipped.
230            // We always stored corrected `r`, so compare det of corrected r.
231            // If det(r) = +1, no flip was needed OR flip was applied.
232            // Easier: just recompute via V Uᵀ to see if flip happened.
233            let r_uncorrected = v_mat.dot(&ut_mat);
234            let det_uncorrected = mat_det(&r_uncorrected);
235            if det_uncorrected < 0.0 && !config.allow_reflection {
236                // Flip was applied → adjusted sigma
237                sigma_sum_raw - 2.0 * sigma_vec[d - 1]
238            } else {
239                sigma_sum_raw
240            }
241        } else {
242            sigma_sum_raw
243        };
244        let s = (sigma_adj / norm_a_sq).max(0.0);
245        (s, sigma_adj)
246    } else {
247        (1.0, sigma_vec.iter().sum::<f64>())
248    };
249
250    // ----------------------------------------------------------------
251    // 8. Translation: t = centroid_B − s · (centroid_A · R)
252    // ----------------------------------------------------------------
253    let ca_r: Array1<f64> = centroid_a
254        .view()
255        .insert_axis(Axis(0))
256        .dot(&r)
257        .row(0)
258        .to_owned();
259    let translation: Array1<f64> = &centroid_b - &(ca_r * scale);
260
261    // ----------------------------------------------------------------
262    // 9. Apply transformation: T(A) = s · A_c · R + centroid_B
263    // ----------------------------------------------------------------
264    let a_c_r = a_c.dot(&r);
265    let transformed: Array2<f64> = a_c_r * scale + &centroid_b.view().insert_axis(Axis(0));
266
267    // ----------------------------------------------------------------
268    // 10. Disparity = ‖T(A) − B‖_F
269    // ----------------------------------------------------------------
270    let diff = &transformed - b;
271    let disparity: f64 = diff.iter().map(|&x| x * x).sum::<f64>().sqrt();
272
273    Ok(ProcrustesResult {
274        rotation: r,
275        scale,
276        translation,
277        disparity,
278        transformed,
279    })
280}
281
282// ---------------------------------------------------------------------------
283// Generalized Procrustes Analysis
284// ---------------------------------------------------------------------------
285
286/// Generalized Procrustes Analysis (GPA): align multiple matrices to a common mean.
287///
288/// Iteratively aligns each matrix to the current consensus (mean) shape using
289/// [`orthogonal_procrustes`] until convergence or `max_iter` is reached.
290///
291/// # Arguments
292/// * `matrices` – Slice of matrices, each (n × d), representing the same n landmarks.
293/// * `max_iter` – Maximum number of GPA sweeps. Default suggestion: 100.
294/// * `tol`      – Convergence tolerance on the total disparity change. Default: 1e-8.
295///
296/// # Returns
297/// One [`ProcrustesResult`] per input matrix (aligned to consensus).
298///
299/// # Errors
300/// Returns an error if fewer than 2 matrices are provided or shapes differ.
301pub fn generalized_procrustes(
302    matrices: &[Array2<f64>],
303    max_iter: usize,
304    tol: f64,
305) -> Result<Vec<ProcrustesResult>> {
306    let k = matrices.len();
307    if k < 2 {
308        return Err(TransformError::InvalidInput(
309            "Generalized Procrustes requires at least 2 matrices".to_string(),
310        ));
311    }
312
313    let (n, d) = matrices[0].dim();
314    for (idx, m) in matrices.iter().enumerate() {
315        if m.dim() != (n, d) {
316            return Err(TransformError::InvalidInput(format!(
317                "Matrix {idx} has shape ({},{}) but expected ({n},{d})",
318                m.nrows(),
319                m.ncols()
320            )));
321        }
322    }
323
324    let config = ProcrustesConfig {
325        allow_reflection: false,
326        scaling: true,
327        centering: true,
328    };
329
330    // Initialise: copies of original matrices as "aligned" versions
331    let mut aligned: Vec<Array2<f64>> = matrices.to_vec();
332
333    let mut prev_disparity = f64::INFINITY;
334
335    for _iter in 0..max_iter {
336        // Compute consensus (mean shape)
337        let consensus = compute_mean_shape(&aligned);
338
339        // Align each matrix to the consensus
340        let mut total_disparity = 0.0_f64;
341        for m in aligned.iter_mut() {
342            let result = orthogonal_procrustes(m, &consensus, &config)?;
343            total_disparity += result.disparity;
344            *m = result.transformed;
345        }
346
347        // Check convergence
348        let change = (prev_disparity - total_disparity).abs();
349        prev_disparity = total_disparity;
350        if change < tol {
351            break;
352        }
353    }
354
355    // Final pass: compute ProcrustesResult for each original matrix against consensus
356    let consensus = compute_mean_shape(&aligned);
357    let mut results = Vec::with_capacity(k);
358    for orig in matrices.iter() {
359        let result = orthogonal_procrustes(orig, &consensus, &config)?;
360        results.push(result);
361    }
362
363    Ok(results)
364}
365
366// ---------------------------------------------------------------------------
367// Internal helpers
368// ---------------------------------------------------------------------------
369
370/// Compute the element-wise mean of a collection of matrices.
371fn compute_mean_shape(matrices: &[Array2<f64>]) -> Array2<f64> {
372    let k = matrices.len() as f64;
373    let (n, d) = matrices[0].dim();
374    let mut mean = Array2::<f64>::zeros((n, d));
375    for m in matrices {
376        mean = mean + m;
377    }
378    mean / k
379}
380
381/// Jacobi one-sided SVD for a square d×d matrix.
382///
383/// Computes M = U Σ Vᵀ using Givens rotations on Mᵀ M (Golub-Reinsch variant).
384/// Returns (U, σ, Vᵀ) where Vᵀ has rows that are the right singular vectors.
385fn jacobi_svd_square(m: &Array2<f64>) -> Result<(Array2<f64>, Vec<f64>, Array2<f64>)> {
386    let d = m.nrows();
387    if m.ncols() != d {
388        return Err(TransformError::ComputationError(
389            "jacobi_svd_square requires square matrix".to_string(),
390        ));
391    }
392    if d == 0 {
393        return Err(TransformError::ComputationError(
394            "jacobi_svd_square requires non-empty matrix".to_string(),
395        ));
396    }
397
398    // Work on B = Mᵀ M (symmetric PSD), accumulate V
399    let mut b = m.t().dot(m); // d×d
400    let mut v = Array2::<f64>::eye(d);
401
402    let max_sweeps = 200;
403    let eps = 1e-14_f64;
404
405    for _ in 0..max_sweeps {
406        let mut converged = true;
407        for p in 0..d {
408            for q in (p + 1)..d {
409                let bpq = b[[p, q]];
410                if bpq.abs() < eps * (b[[p, p]].abs().max(b[[q, q]].abs()).max(1.0)) {
411                    continue;
412                }
413                converged = false;
414
415                // 2×2 Jacobi rotation to zero b[p,q]
416                let bpp = b[[p, p]];
417                let bqq = b[[q, q]];
418                let tau = (bqq - bpp) / (2.0 * bpq);
419                let t = if tau >= 0.0 {
420                    1.0 / (tau + (1.0 + tau * tau).sqrt())
421                } else {
422                    1.0 / (tau - (1.0 + tau * tau).sqrt())
423                };
424                let c = 1.0 / (1.0 + t * t).sqrt();
425                let s = t * c;
426
427                // Update diagonal first
428                b[[p, p]] = bpp - t * bpq;
429                b[[q, q]] = bqq + t * bpq;
430                b[[p, q]] = 0.0;
431                b[[q, p]] = 0.0;
432
433                // Update off-diagonal elements
434                for i in 0..d {
435                    if i != p && i != q {
436                        let bip = b[[i, p]];
437                        let biq = b[[i, q]];
438                        b[[i, p]] = c * bip - s * biq;
439                        b[[i, q]] = s * bip + c * biq;
440                        b[[p, i]] = b[[i, p]];
441                        b[[q, i]] = b[[i, q]];
442                    }
443                }
444
445                // Accumulate V: V ← V J_{pq}
446                for i in 0..d {
447                    let vip = v[[i, p]];
448                    let viq = v[[i, q]];
449                    v[[i, p]] = c * vip - s * viq;
450                    v[[i, q]] = s * vip + c * viq;
451                }
452            }
453        }
454        if converged {
455            break;
456        }
457    }
458
459    // Singular values = sqrt of diagonal of B (clamped to ≥ 0)
460    let mut sigma: Vec<f64> = (0..d).map(|i| b[[i, i]].max(0.0).sqrt()).collect();
461
462    // Sort singular values in descending order (and permute V accordingly)
463    let mut order: Vec<usize> = (0..d).collect();
464    order.sort_by(|&i, &j| {
465        sigma[j]
466            .partial_cmp(&sigma[i])
467            .unwrap_or(std::cmp::Ordering::Equal)
468    });
469
470    let sigma_sorted: Vec<f64> = order.iter().map(|&i| sigma[i]).collect();
471    let v_sorted: Array2<f64> = {
472        let mut vs = Array2::<f64>::zeros((d, d));
473        for (new_col, &old_col) in order.iter().enumerate() {
474            for row in 0..d {
475                vs[[row, new_col]] = v[[row, old_col]];
476            }
477        }
478        vs
479    };
480    sigma = sigma_sorted;
481
482    // Compute U = M V Σ^{-1}: columns u_i = M v_i / σ_i
483    let mv = m.dot(&v_sorted);
484    let mut u = Array2::<f64>::zeros((d, d));
485    for i in 0..d {
486        let si = sigma[i];
487        if si > eps {
488            for r in 0..d {
489                u[[r, i]] = mv[[r, i]] / si;
490            }
491        } else {
492            // Zero singular value: u_i will be filled by Gram-Schmidt if needed
493            // For Procrustes purposes (d ≤ typically ~100), just leave as zero
494            // and we handle the det-fixing step separately.
495        }
496    }
497
498    // Orthogonalize U columns for zero singular values via Gram-Schmidt
499    orthogonalize_columns(&mut u);
500
501    let vt = v_sorted.t().to_owned(); // Vᵀ: rows are right singular vectors
502    Ok((u, sigma, vt))
503}
504
505/// Gram-Schmidt orthogonalization of matrix columns (in-place).
506/// Only processes columns that are nearly zero.
507fn orthogonalize_columns(m: &mut Array2<f64>) {
508    let (r, c) = m.dim();
509    let eps = 1e-12_f64;
510
511    for j in 0..c {
512        // Check if column j is near-zero
513        let norm_sq: f64 = (0..r).map(|i| m[[i, j]] * m[[i, j]]).sum();
514        if norm_sq > eps {
515            // Normalize it
516            let norm = norm_sq.sqrt();
517            for i in 0..r {
518                m[[i, j]] /= norm;
519            }
520            // Make subsequent columns orthogonal to this one
521            for k in (j + 1)..c {
522                let dot: f64 = (0..r).map(|i| m[[i, j]] * m[[i, k]]).sum();
523                for i in 0..r {
524                    let mij = m[[i, j]];
525                    m[[i, k]] -= dot * mij;
526                }
527            }
528        } else {
529            // Find an arbitrary unit vector orthogonal to all previous columns
530            for candidate in 0..r {
531                let mut v = vec![0.0f64; r];
532                v[candidate] = 1.0;
533                // Orthogonalize against all previous columns
534                for k in 0..j {
535                    let dot: f64 = (0..r).map(|i| m[[i, k]] * v[i]).sum();
536                    for i in 0..r {
537                        let mik = m[[i, k]];
538                        v[i] -= dot * mik;
539                    }
540                }
541                let vnorm_sq: f64 = v.iter().map(|&x| x * x).sum();
542                if vnorm_sq > eps {
543                    let vnorm = vnorm_sq.sqrt();
544                    for i in 0..r {
545                        m[[i, j]] = v[i] / vnorm;
546                    }
547                    break;
548                }
549            }
550        }
551    }
552}
553
554/// Compute the determinant of a square matrix via Gaussian elimination.
555pub(crate) fn mat_det(m: &Array2<f64>) -> f64 {
556    let d = m.nrows();
557    if d == 1 {
558        return m[[0, 0]];
559    }
560    if d == 2 {
561        return m[[0, 0]] * m[[1, 1]] - m[[0, 1]] * m[[1, 0]];
562    }
563    if d == 3 {
564        return m[[0, 0]] * (m[[1, 1]] * m[[2, 2]] - m[[1, 2]] * m[[2, 1]])
565            - m[[0, 1]] * (m[[1, 0]] * m[[2, 2]] - m[[1, 2]] * m[[2, 0]])
566            + m[[0, 2]] * (m[[1, 0]] * m[[2, 1]] - m[[1, 1]] * m[[2, 0]]);
567    }
568
569    // General case: LU with partial pivoting
570    let mut a = m.to_owned();
571    let mut sign = 1.0_f64;
572
573    for col in 0..d {
574        let mut max_val = a[[col, col]].abs();
575        let mut max_row = col;
576        for row in (col + 1)..d {
577            if a[[row, col]].abs() > max_val {
578                max_val = a[[row, col]].abs();
579                max_row = row;
580            }
581        }
582        if max_val < 1e-15 {
583            return 0.0;
584        }
585        if max_row != col {
586            for c in 0..d {
587                let tmp = a[[col, c]];
588                a[[col, c]] = a[[max_row, c]];
589                a[[max_row, c]] = tmp;
590            }
591            sign *= -1.0;
592        }
593        let pivot = a[[col, col]];
594        for row in (col + 1)..d {
595            let factor = a[[row, col]] / pivot;
596            for c in col..d {
597                let v = a[[col, c]];
598                a[[row, c]] -= factor * v;
599            }
600        }
601    }
602
603    let diag_prod: f64 = (0..d).map(|i| a[[i, i]]).product();
604    sign * diag_prod
605}
606
607// ---------------------------------------------------------------------------
608// Tests
609// ---------------------------------------------------------------------------
610
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    const TOL: f64 = 1e-5;

    // Helper: 2D rotation matrix
    fn rot2(angle_rad: f64) -> Array2<f64> {
        let c = angle_rad.cos();
        let s = angle_rad.sin();
        array![[c, -s], [s, c]]
    }

    // ------------------------------------------------------------------
    // Rotation-only alignment
    // ------------------------------------------------------------------

    #[test]
    fn test_procrustes_rotation() {
        // Rotate a 3-point configuration by 45° and recover rotation
        let a = array![[1.0_f64, 0.0], [0.0, 1.0], [-1.0, 0.0]];
        let angle = std::f64::consts::FRAC_PI_4;
        let r_true = rot2(angle);
        let b = a.dot(&r_true);

        let config = ProcrustesConfig {
            allow_reflection: false,
            scaling: false,
            centering: true,
        };
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "residual should be near 0, got {}",
            result.disparity
        );
    }

    #[test]
    fn test_procrustes_no_reflection() {
        // When a reflection is the optimal map and allow_reflection=false,
        // we should get det(R) = +1
        let a = array![[1.0_f64, 0.0], [0.0, 1.0], [0.0, 0.0]];
        // Apply a reflection (det = -1): flip y-axis
        let b: Array2<f64> = array![[1.0_f64, 0.0], [0.0, -1.0], [0.0, 0.0]];

        let config = ProcrustesConfig {
            allow_reflection: false,
            scaling: false,
            centering: false,
        };
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        let det = mat_det(&result.rotation);
        assert!((det - 1.0).abs() < TOL, "det(R) should be +1, got {det}");
    }

    #[test]
    fn test_procrustes_reflection_allowed() {
        // With allow_reflection=true an exact mirror image is reproduced by an
        // improper orthogonal matrix (det(R) = -1) with zero residual.
        let a = array![[1.0_f64, 0.0], [0.0, 1.0], [0.0, 0.0]];
        let b: Array2<f64> = array![[1.0_f64, 0.0], [0.0, -1.0], [0.0, 0.0]];

        let config = ProcrustesConfig {
            allow_reflection: true,
            scaling: false,
            centering: false,
        };
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "mirror image should be matched exactly, got {}",
            result.disparity
        );
        let det = mat_det(&result.rotation);
        assert!((det + 1.0).abs() < TOL, "det(R) should be -1, got {det}");
    }

    #[test]
    fn test_procrustes_scale_translation() {
        // Apply scale 2.0 and translation [3, -1], then recover
        let a = array![[0.0_f64, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]];
        let scale_true = 2.0_f64;
        let translation = array![3.0_f64, -1.0];
        let b: Array2<f64> = &a * scale_true + &translation.view().insert_axis(Axis(0));

        let config = ProcrustesConfig::default();
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "residual should be near 0, got {}",
            result.disparity
        );
        assert!(
            (result.scale - scale_true).abs() < TOL,
            "scale should be {scale_true}, got {}",
            result.scale
        );
    }

    #[test]
    fn test_procrustes_translation_only() {
        // A pure shift must be reported as R = I, s = 1, t = shift.
        let a = array![[0.0_f64, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]];
        let shift = array![3.0_f64, -1.0];
        let b: Array2<f64> = a.clone() + &shift.view().insert_axis(Axis(0));

        let result =
            orthogonal_procrustes(&a, &b, &ProcrustesConfig::default()).expect("procrustes ok");
        assert!(
            (result.scale - 1.0).abs() < TOL,
            "scale should be 1, got {}",
            result.scale
        );
        for (got, want) in result.translation.iter().zip(shift.iter()) {
            assert!(
                (got - want).abs() < TOL,
                "translation component should be {want}, got {got}"
            );
        }
        let eye = Array2::<f64>::eye(2);
        for (got, want) in result.rotation.iter().zip(eye.iter()) {
            assert!(
                (got - want).abs() < TOL,
                "rotation should be identity, entry {got} vs {want}"
            );
        }
    }

    #[test]
    fn test_procrustes_identity() {
        // Aligning A to itself should give identity rotation and zero residual
        let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let config = ProcrustesConfig::default();
        let result = orthogonal_procrustes(&a, &a, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "residual for A→A should be 0, got {}",
            result.disparity
        );
    }

    #[test]
    fn test_procrustes_shape_mismatch_error() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let b = array![[1.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let config = ProcrustesConfig::default();
        let result = orthogonal_procrustes(&a, &b, &config);
        assert!(result.is_err(), "mismatched shapes should produce an error");
    }

    // ------------------------------------------------------------------
    // Generalized Procrustes
    // ------------------------------------------------------------------

    #[test]
    fn test_generalized_procrustes() {
        // Create 4 rotated versions of the same square
        let base = array![[1.0_f64, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let angles = [0.0_f64, 0.3, 0.7, 1.2];
        let matrices: Vec<Array2<f64>> = angles.iter().map(|&a| base.dot(&rot2(a))).collect();

        let results = generalized_procrustes(&matrices, 100, 1e-8).expect("GPA should converge");
        assert_eq!(results.len(), matrices.len());

        // Each result should have reasonably small disparity
        for (i, r) in results.iter().enumerate() {
            assert!(
                r.disparity < 1.0,
                "GPA result {i} disparity {:.4} should be small",
                r.disparity
            );
        }
    }

    #[test]
    fn test_generalized_procrustes_too_few_matrices() {
        let m = array![[1.0_f64, 0.0]];
        let result = generalized_procrustes(&[m], 100, 1e-8);
        assert!(result.is_err(), "single matrix should error");
    }

    #[test]
    fn test_generalized_procrustes_shape_mismatch() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let b = array![[1.0_f64, 0.0, 0.0]]; // different ncols
        let result = generalized_procrustes(&[a, b], 100, 1e-8);
        assert!(result.is_err(), "shape mismatch should error");
    }

    // ------------------------------------------------------------------
    // Determinant helper
    // ------------------------------------------------------------------

    #[test]
    fn test_det_2x2() {
        let m = array![[3.0_f64, 1.0], [5.0, 2.0]];
        let det = mat_det(&m);
        assert!((det - 1.0).abs() < 1e-12, "2x2 det should be 1, got {det}");
    }

    #[test]
    fn test_det_3x3() {
        let m = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]];
        let det = mat_det(&m);
        assert!((det - (-3.0)).abs() < 1e-10, "det should be -3, got {det}");
    }
}