scirs2_transform/alignment/procrustes.rs
1//! Procrustes analysis for aligning geometric configurations.
2//!
3//! ## Overview
4//!
5//! Procrustes analysis finds the optimal orthogonal transformation (rotation and
6//! optionally reflection and scaling) that maps one matrix onto another in the
7//! Frobenius-norm sense.
8//!
9//! ### Orthogonal Procrustes Problem
10//!
11//! Given matrices **A** (n × d) and **B** (n × d), find:
12//!
13//! ```text
14//! min_{R: Rᵀ R = I} ||s · A R + 1 tᵀ − B||_F
15//! ```
16//!
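//! **Solution via SVD** of the cross-covariance M = Bᵀ A = U Σ Vᵀ (formed from the
//! centred matrices when centering is enabled):
//! - R = V Uᵀ (or V diag(1,…,1,det(VUᵀ)) Uᵀ to prevent reflections)
//! - Optimal scale s = trace(Σ′) / ||A||_F² (whenever scaling is enabled), where Σ′ is Σ with
//!   its last (smallest) entry negated if the reflection correction was applied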
20//!
21//! ### Generalized Procrustes Analysis
22//!
23//! Aligns multiple matrices to a common mean (consensus) shape via iterative
24//! pairwise Procrustes alignment, similar to the GPA algorithm of Gower (1975).
25//!
26//! ## References
27//!
28//! - Schönemann (1966): A generalized solution of the orthogonal Procrustes problem
29//! - Gower (1975): Generalized Procrustes analysis
30//! - Golub & Van Loan (1996): Matrix Computations, §12.4
31
32use scirs2_core::ndarray::{Array1, Array2, Axis};
33
34use crate::error::{Result, TransformError};
35
36// ---------------------------------------------------------------------------
37// Configuration
38// ---------------------------------------------------------------------------
39
/// Options controlling [`orthogonal_procrustes`].
///
/// The [`Default`] configuration solves the similarity problem without reflections:
/// proper rotations only (`det(R) = +1`), an isotropic scale factor, and centring of
/// both inputs.
#[derive(Debug, Clone)]
pub struct ProcrustesConfig {
    /// Permit improper orthogonal matrices (`det(R) = -1`, i.e. reflections) in addition
    /// to proper rotations. Default: `false`.
    pub allow_reflection: bool,
    /// Estimate the optimal isotropic scale factor instead of fixing it at 1.
    /// Default: `true`.
    pub scaling: bool,
    /// Subtract each matrix's column means (its centroid) before solving.
    /// Default: `true`.
    pub centering: bool,
}

impl Default for ProcrustesConfig {
    /// Rotation-only alignment with scaling and centring switched on.
    fn default() -> Self {
        let allow_reflection = false;
        let scaling = true;
        let centering = true;
        Self {
            allow_reflection,
            scaling,
            centering,
        }
    }
}
63
64// ---------------------------------------------------------------------------
65// Result
66// ---------------------------------------------------------------------------
67
68/// Result of a Procrustes alignment.
69#[derive(Debug, Clone)]
70pub struct ProcrustesResult {
71 /// Optimal orthogonal rotation matrix R (d × d), with det(R) = +1 unless
72 /// `allow_reflection = true`.
73 pub rotation: Array2<f64>,
74 /// Optimal isotropic scale factor s (1.0 when `scaling = false`).
75 pub scale: f64,
76 /// Translation vector t (d-dimensional) applied *after* rotation.
77 pub translation: Array1<f64>,
78 /// Frobenius-norm residual ‖s·A·R + 1·tᵀ − B‖_F after alignment.
79 pub disparity: f64,
80 /// Aligned version of A: s·(A_centred · R) + centroid_B.
81 pub transformed: Array2<f64>,
82}
83
84// ---------------------------------------------------------------------------
85// Orthogonal Procrustes
86// ---------------------------------------------------------------------------
87
88/// Solve the orthogonal Procrustes problem.
89///
90/// Finds the best-fitting orthogonal transformation (rotation, optional scale,
91/// and translation) mapping **A** onto **B**:
92///
93/// ```text
94/// min_{R: Rᵀ R = I, s > 0, t} ||s · A R + 1 tᵀ − B||_F
95/// ```
96///
97/// # Arguments
98/// * `a` – Source matrix (n × d).
99/// * `b` – Target matrix (n × d).
100/// * `config` – Alignment options.
101///
102/// # Errors
103/// Returns [`TransformError::InvalidInput`] when shapes are incompatible, or
104/// [`TransformError::ComputationError`] on numerical failure.
105///
106/// # Example
107/// ```rust
108/// use scirs2_transform::alignment::procrustes::{orthogonal_procrustes, ProcrustesConfig};
109/// use scirs2_core::ndarray::array;
110///
111/// // A 90° rotation of a simple triangle
112/// let a = array![[1.0_f64, 0.0], [0.0, 1.0], [0.0, 0.0]];
113/// let b = array![[0.0_f64, 1.0], [-1.0, 0.0], [0.0, 0.0]];
114/// let config = ProcrustesConfig { scaling: false, ..Default::default() };
115/// let result = orthogonal_procrustes(&a, &b, &config).expect("should succeed");
116/// assert!(result.disparity < 1e-6);
117/// ```
118pub fn orthogonal_procrustes(
119 a: &Array2<f64>,
120 b: &Array2<f64>,
121 config: &ProcrustesConfig,
122) -> Result<ProcrustesResult> {
123 let (n, d) = a.dim();
124 if b.dim() != (n, d) {
125 return Err(TransformError::InvalidInput(format!(
126 "Shape mismatch: A is ({n}×{d}) but B is ({}×{})",
127 b.nrows(),
128 b.ncols()
129 )));
130 }
131 if n == 0 || d == 0 {
132 return Err(TransformError::InvalidInput(
133 "Matrices must be non-empty".to_string(),
134 ));
135 }
136
137 // ----------------------------------------------------------------
138 // 1. Center both matrices
139 // ----------------------------------------------------------------
140 let centroid_a: Array1<f64> = if config.centering {
141 a.mean_axis(Axis(0)).ok_or_else(|| {
142 TransformError::ComputationError("Failed to compute centroid of A".to_string())
143 })?
144 } else {
145 Array1::zeros(d)
146 };
147
148 let centroid_b: Array1<f64> = if config.centering {
149 b.mean_axis(Axis(0)).ok_or_else(|| {
150 TransformError::ComputationError("Failed to compute centroid of B".to_string())
151 })?
152 } else {
153 Array1::zeros(d)
154 };
155
156 // Centered matrices
157 let a_c: Array2<f64> = a - ¢roid_a.view().insert_axis(Axis(0));
158 let b_c: Array2<f64> = b - ¢roid_b.view().insert_axis(Axis(0));
159
160 // ----------------------------------------------------------------
161 // 2. Frobenius norm of centered A
162 // ----------------------------------------------------------------
163 let norm_a_sq: f64 = a_c.iter().map(|&x| x * x).sum();
164
165 if norm_a_sq < f64::EPSILON {
166 // A is (approximately) zero — can't define rotation; return identity
167 let rotation = Array2::eye(d);
168 let translation = centroid_b.clone();
169 let zeros_plus_cb: Array2<f64> =
170 Array2::from_shape_fn((n, d), |_| 0.0) + ¢roid_b.view().insert_axis(Axis(0));
171 let disparity = b_c.iter().map(|&x| x * x).sum::<f64>().sqrt();
172 return Ok(ProcrustesResult {
173 rotation,
174 scale: 1.0,
175 translation,
176 disparity,
177 transformed: zeros_plus_cb,
178 });
179 }
180
181 // ----------------------------------------------------------------
182 // 3. Compute M = Bᵀ A (d × d)
183 // The Procrustes solution uses SVD of M = B_cᵀ A_c
184 // ----------------------------------------------------------------
185 let m = b_c.t().dot(&a_c); // d × d
186
187 // ----------------------------------------------------------------
188 // 4. SVD of M: M = U Σ Vᵀ using Jacobi SVD
189 // ----------------------------------------------------------------
190 let (u_mat, sigma_vec, vt_mat) = jacobi_svd_square(&m)?;
191 // u_mat : d×d, sigma_vec: d, vt_mat: d×d (rows are right singular vectors)
192 // So M = U diag(σ) Vᵀ
193
194 // ----------------------------------------------------------------
195 // 5. Construct candidate R = V Uᵀ
196 // ----------------------------------------------------------------
197 let v_mat = vt_mat.t().to_owned(); // V: d×d (columns are right singular vectors)
198 let ut_mat = u_mat.t().to_owned(); // Uᵀ: d×d
199 let mut r = v_mat.dot(&ut_mat); // R = V Uᵀ
200
201 // ----------------------------------------------------------------
202 // 6. Enforce det(R) = +1 if reflections are not allowed
203 // ----------------------------------------------------------------
204 if !config.allow_reflection {
205 let det_r = mat_det(&r);
206 if det_r < 0.0 {
207 // Flip sign of the last column of V (associated with smallest σ)
208 // so that det(R) = +1: R = V diag(1,…,1,−1) Uᵀ
209 let mut v_adj = v_mat.clone();
210 for row in 0..d {
211 v_adj[[row, d - 1]] *= -1.0;
212 }
213 r = v_adj.dot(&ut_mat);
214 }
215 }
216
217 // ----------------------------------------------------------------
218 // 7. Optimal scale s = trace(Σ_adj) / ‖A_c‖²_F
219 // Σ_adj accounts for the possible sign-flip of the last singular value.
220 // ----------------------------------------------------------------
221 let (scale, _sigma_trace) = if config.scaling {
222 let sigma_sum_raw: f64 = sigma_vec.iter().sum();
223 // If we flipped the last singular value to fix det:
224 let det_r = mat_det(&r);
225 let sigma_adj = if !config.allow_reflection && det_r > 0.0 {
226 // Check if we needed to flip (by comparing with raw det before flip)
227 // The raw sigma_sum is correct if we flipped, need to subtract 2*sigma_last
228 // But since `r` is already the corrected rotation, we re-check det
229 // The correction happened above: if original det < 0, we flipped.
230 // We always stored corrected `r`, so compare det of corrected r.
231 // If det(r) = +1, no flip was needed OR flip was applied.
232 // Easier: just recompute via V Uᵀ to see if flip happened.
233 let r_uncorrected = v_mat.dot(&ut_mat);
234 let det_uncorrected = mat_det(&r_uncorrected);
235 if det_uncorrected < 0.0 && !config.allow_reflection {
236 // Flip was applied → adjusted sigma
237 sigma_sum_raw - 2.0 * sigma_vec[d - 1]
238 } else {
239 sigma_sum_raw
240 }
241 } else {
242 sigma_sum_raw
243 };
244 let s = (sigma_adj / norm_a_sq).max(0.0);
245 (s, sigma_adj)
246 } else {
247 (1.0, sigma_vec.iter().sum::<f64>())
248 };
249
250 // ----------------------------------------------------------------
251 // 8. Translation: t = centroid_B − s · (centroid_A · R)
252 // ----------------------------------------------------------------
253 let ca_r: Array1<f64> = centroid_a
254 .view()
255 .insert_axis(Axis(0))
256 .dot(&r)
257 .row(0)
258 .to_owned();
259 let translation: Array1<f64> = ¢roid_b - &(ca_r * scale);
260
261 // ----------------------------------------------------------------
262 // 9. Apply transformation: T(A) = s · A_c · R + centroid_B
263 // ----------------------------------------------------------------
264 let a_c_r = a_c.dot(&r);
265 let transformed: Array2<f64> = a_c_r * scale + ¢roid_b.view().insert_axis(Axis(0));
266
267 // ----------------------------------------------------------------
268 // 10. Disparity = ‖T(A) − B‖_F
269 // ----------------------------------------------------------------
270 let diff = &transformed - b;
271 let disparity: f64 = diff.iter().map(|&x| x * x).sum::<f64>().sqrt();
272
273 Ok(ProcrustesResult {
274 rotation: r,
275 scale,
276 translation,
277 disparity,
278 transformed,
279 })
280}
281
282// ---------------------------------------------------------------------------
283// Generalized Procrustes Analysis
284// ---------------------------------------------------------------------------
285
286/// Generalized Procrustes Analysis (GPA): align multiple matrices to a common mean.
287///
288/// Iteratively aligns each matrix to the current consensus (mean) shape using
289/// [`orthogonal_procrustes`] until convergence or `max_iter` is reached.
290///
291/// # Arguments
292/// * `matrices` – Slice of matrices, each (n × d), representing the same n landmarks.
293/// * `max_iter` – Maximum number of GPA sweeps. Default suggestion: 100.
294/// * `tol` – Convergence tolerance on the total disparity change. Default: 1e-8.
295///
296/// # Returns
297/// One [`ProcrustesResult`] per input matrix (aligned to consensus).
298///
299/// # Errors
300/// Returns an error if fewer than 2 matrices are provided or shapes differ.
301pub fn generalized_procrustes(
302 matrices: &[Array2<f64>],
303 max_iter: usize,
304 tol: f64,
305) -> Result<Vec<ProcrustesResult>> {
306 let k = matrices.len();
307 if k < 2 {
308 return Err(TransformError::InvalidInput(
309 "Generalized Procrustes requires at least 2 matrices".to_string(),
310 ));
311 }
312
313 let (n, d) = matrices[0].dim();
314 for (idx, m) in matrices.iter().enumerate() {
315 if m.dim() != (n, d) {
316 return Err(TransformError::InvalidInput(format!(
317 "Matrix {idx} has shape ({},{}) but expected ({n},{d})",
318 m.nrows(),
319 m.ncols()
320 )));
321 }
322 }
323
324 let config = ProcrustesConfig {
325 allow_reflection: false,
326 scaling: true,
327 centering: true,
328 };
329
330 // Initialise: copies of original matrices as "aligned" versions
331 let mut aligned: Vec<Array2<f64>> = matrices.to_vec();
332
333 let mut prev_disparity = f64::INFINITY;
334
335 for _iter in 0..max_iter {
336 // Compute consensus (mean shape)
337 let consensus = compute_mean_shape(&aligned);
338
339 // Align each matrix to the consensus
340 let mut total_disparity = 0.0_f64;
341 for m in aligned.iter_mut() {
342 let result = orthogonal_procrustes(m, &consensus, &config)?;
343 total_disparity += result.disparity;
344 *m = result.transformed;
345 }
346
347 // Check convergence
348 let change = (prev_disparity - total_disparity).abs();
349 prev_disparity = total_disparity;
350 if change < tol {
351 break;
352 }
353 }
354
355 // Final pass: compute ProcrustesResult for each original matrix against consensus
356 let consensus = compute_mean_shape(&aligned);
357 let mut results = Vec::with_capacity(k);
358 for orig in matrices.iter() {
359 let result = orthogonal_procrustes(orig, &consensus, &config)?;
360 results.push(result);
361 }
362
363 Ok(results)
364}
365
366// ---------------------------------------------------------------------------
367// Internal helpers
368// ---------------------------------------------------------------------------
369
370/// Compute the element-wise mean of a collection of matrices.
371fn compute_mean_shape(matrices: &[Array2<f64>]) -> Array2<f64> {
372 let k = matrices.len() as f64;
373 let (n, d) = matrices[0].dim();
374 let mut mean = Array2::<f64>::zeros((n, d));
375 for m in matrices {
376 mean = mean + m;
377 }
378 mean / k
379}
380
/// Singular value decomposition of a square matrix via the Jacobi eigenvalue method.
///
/// Forms the symmetric positive semi-definite matrix `B = Mᵀ M` and diagonalises it with
/// cyclic two-sided Jacobi rotations, accumulating the rotations into `V`. This is not a
/// one-sided (Hestenes) Jacobi SVD, and not Golub–Reinsch bidiagonalisation.
///
/// The singular values are `σ_i = sqrt(max(B_ii, 0))`. They are sorted in descending
/// order, and the columns of `V` are permuted to match. The left singular vectors are
/// recovered as `u_i = M v_i / σ_i`. Columns of `U` whose `σ_i` is at most `1e-14` are
/// left at zero and then completed by [`orthogonalize_columns`], which also normalises
/// the other columns and orthogonalises them against earlier ones.
///
/// At most 200 sweeps are performed. Iteration stops early after a sweep in which every
/// off-diagonal entry already satisfies `|b_pq| < 1e-14 · max(|b_pp|, |b_qq|, 1)`.
///
/// Returns `(U, σ, Vᵀ)` with `M ≈ U diag(σ) Vᵀ`; the rows of `Vᵀ` are the right singular
/// vectors.
///
/// # Errors
/// [`TransformError::ComputationError`] if `m` is not square or is empty.
///
/// NOTE(review): the `max(…, 1)` floor in the off-diagonal test and the `σ_i > 1e-14` test
/// are absolute for matrices with entries well below 1. Inputs far from unit scale may
/// therefore be left undiagonalised; callers may want to normalise `m` first. Because
/// σ comes from square roots of eigenvalues of `Mᵀ M`, small singular values carry only
/// about half the working precision.
fn jacobi_svd_square(m: &Array2<f64>) -> Result<(Array2<f64>, Vec<f64>, Array2<f64>)> {
    let d = m.nrows();
    if m.ncols() != d {
        return Err(TransformError::ComputationError(
            "jacobi_svd_square requires square matrix".to_string(),
        ));
    }
    if d == 0 {
        return Err(TransformError::ComputationError(
            "jacobi_svd_square requires non-empty matrix".to_string(),
        ));
    }

    // Work on B = Mᵀ M (symmetric PSD); its eigenvectors are the right singular vectors
    // of M, accumulated into V.
    let mut b = m.t().dot(m); // d×d
    let mut v = Array2::<f64>::eye(d);

    let max_sweeps = 200;
    let eps = 1e-14_f64;

    for _ in 0..max_sweeps {
        // Stays true only if a full sweep applies no rotation.
        let mut converged = true;
        for p in 0..d {
            for q in (p + 1)..d {
                let bpq = b[[p, q]];
                // Skip pairs whose off-diagonal entry is already negligible.
                if bpq.abs() < eps * (b[[p, p]].abs().max(b[[q, q]].abs()).max(1.0)) {
                    continue;
                }
                converged = false;

                // 2×2 Jacobi rotation to zero b[p,q]. `t` is the smaller-magnitude root
                // of t² + 2τt − 1 = 0 (rotation angle |θ| ≤ π/4), the numerically
                // stable choice.
                let bpp = b[[p, p]];
                let bqq = b[[q, q]];
                let tau = (bqq - bpp) / (2.0 * bpq);
                let t = if tau >= 0.0 {
                    1.0 / (tau + (1.0 + tau * tau).sqrt())
                } else {
                    1.0 / (tau - (1.0 + tau * tau).sqrt())
                };
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = t * c;

                // Update the rotated 2×2 diagonal block directly.
                b[[p, p]] = bpp - t * bpq;
                b[[q, q]] = bqq + t * bpq;
                b[[p, q]] = 0.0;
                b[[q, p]] = 0.0;

                // Rotate rows/columns p and q for every other index, keeping B symmetric.
                for i in 0..d {
                    if i != p && i != q {
                        let bip = b[[i, p]];
                        let biq = b[[i, q]];
                        b[[i, p]] = c * bip - s * biq;
                        b[[i, q]] = s * bip + c * biq;
                        b[[p, i]] = b[[i, p]];
                        b[[q, i]] = b[[i, q]];
                    }
                }

                // Accumulate V: V ← V J_{pq}
                for i in 0..d {
                    let vip = v[[i, p]];
                    let viq = v[[i, q]];
                    v[[i, p]] = c * vip - s * viq;
                    v[[i, q]] = s * vip + c * viq;
                }
            }
        }
        if converged {
            break;
        }
    }

    // Singular values = sqrt of diagonal of B (clamped to ≥ 0 against rounding)
    let mut sigma: Vec<f64> = (0..d).map(|i| b[[i, i]].max(0.0).sqrt()).collect();

    // Sort singular values in descending order (and permute V accordingly); NaN
    // comparisons are treated as equal.
    let mut order: Vec<usize> = (0..d).collect();
    order.sort_by(|&i, &j| {
        sigma[j]
            .partial_cmp(&sigma[i])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let sigma_sorted: Vec<f64> = order.iter().map(|&i| sigma[i]).collect();
    let v_sorted: Array2<f64> = {
        let mut vs = Array2::<f64>::zeros((d, d));
        for (new_col, &old_col) in order.iter().enumerate() {
            for row in 0..d {
                vs[[row, new_col]] = v[[row, old_col]];
            }
        }
        vs
    };
    sigma = sigma_sorted;

    // Compute U = M V Σ^{-1}: columns u_i = M v_i / σ_i
    let mv = m.dot(&v_sorted);
    let mut u = Array2::<f64>::zeros((d, d));
    for i in 0..d {
        let si = sigma[i];
        if si > eps {
            for r in 0..d {
                u[[r, i]] = mv[[r, i]] / si;
            }
        } else {
            // σ_i is numerically zero, so u_i cannot be obtained by division. It stays a
            // zero column here and orthogonalize_columns below replaces it with a unit
            // vector orthogonal to the preceding columns.
        }
    }

    // Orthonormalise U, filling in columns left at zero above.
    orthogonalize_columns(&mut u);

    let vt = v_sorted.t().to_owned(); // Vᵀ: rows are right singular vectors
    Ok((u, sigma, vt))
}
504
505/// Gram-Schmidt orthogonalization of matrix columns (in-place).
506/// Only processes columns that are nearly zero.
507fn orthogonalize_columns(m: &mut Array2<f64>) {
508 let (r, c) = m.dim();
509 let eps = 1e-12_f64;
510
511 for j in 0..c {
512 // Check if column j is near-zero
513 let norm_sq: f64 = (0..r).map(|i| m[[i, j]] * m[[i, j]]).sum();
514 if norm_sq > eps {
515 // Normalize it
516 let norm = norm_sq.sqrt();
517 for i in 0..r {
518 m[[i, j]] /= norm;
519 }
520 // Make subsequent columns orthogonal to this one
521 for k in (j + 1)..c {
522 let dot: f64 = (0..r).map(|i| m[[i, j]] * m[[i, k]]).sum();
523 for i in 0..r {
524 let mij = m[[i, j]];
525 m[[i, k]] -= dot * mij;
526 }
527 }
528 } else {
529 // Find an arbitrary unit vector orthogonal to all previous columns
530 for candidate in 0..r {
531 let mut v = vec![0.0f64; r];
532 v[candidate] = 1.0;
533 // Orthogonalize against all previous columns
534 for k in 0..j {
535 let dot: f64 = (0..r).map(|i| m[[i, k]] * v[i]).sum();
536 for i in 0..r {
537 let mik = m[[i, k]];
538 v[i] -= dot * mik;
539 }
540 }
541 let vnorm_sq: f64 = v.iter().map(|&x| x * x).sum();
542 if vnorm_sq > eps {
543 let vnorm = vnorm_sq.sqrt();
544 for i in 0..r {
545 m[[i, j]] = v[i] / vnorm;
546 }
547 break;
548 }
549 }
550 }
551 }
552}
553
554/// Compute the determinant of a square matrix via Gaussian elimination.
555pub(crate) fn mat_det(m: &Array2<f64>) -> f64 {
556 let d = m.nrows();
557 if d == 1 {
558 return m[[0, 0]];
559 }
560 if d == 2 {
561 return m[[0, 0]] * m[[1, 1]] - m[[0, 1]] * m[[1, 0]];
562 }
563 if d == 3 {
564 return m[[0, 0]] * (m[[1, 1]] * m[[2, 2]] - m[[1, 2]] * m[[2, 1]])
565 - m[[0, 1]] * (m[[1, 0]] * m[[2, 2]] - m[[1, 2]] * m[[2, 0]])
566 + m[[0, 2]] * (m[[1, 0]] * m[[2, 1]] - m[[1, 1]] * m[[2, 0]]);
567 }
568
569 // General case: LU with partial pivoting
570 let mut a = m.to_owned();
571 let mut sign = 1.0_f64;
572
573 for col in 0..d {
574 let mut max_val = a[[col, col]].abs();
575 let mut max_row = col;
576 for row in (col + 1)..d {
577 if a[[row, col]].abs() > max_val {
578 max_val = a[[row, col]].abs();
579 max_row = row;
580 }
581 }
582 if max_val < 1e-15 {
583 return 0.0;
584 }
585 if max_row != col {
586 for c in 0..d {
587 let tmp = a[[col, c]];
588 a[[col, c]] = a[[max_row, c]];
589 a[[max_row, c]] = tmp;
590 }
591 sign *= -1.0;
592 }
593 let pivot = a[[col, col]];
594 for row in (col + 1)..d {
595 let factor = a[[row, col]] / pivot;
596 for c in col..d {
597 let v = a[[col, c]];
598 a[[row, c]] -= factor * v;
599 }
600 }
601 }
602
603 let diag_prod: f64 = (0..d).map(|i| a[[i, i]]).product();
604 sign * diag_prod
605}
606
607// ---------------------------------------------------------------------------
608// Tests
609// ---------------------------------------------------------------------------
610
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    const TOL: f64 = 1e-5;

    // Helper: 2D rotation matrix
    fn rot2(angle_rad: f64) -> Array2<f64> {
        let c = angle_rad.cos();
        let s = angle_rad.sin();
        array![[c, -s], [s, c]]
    }

    // Helper: 3D rotation about the z axis
    fn rot3_z(angle_rad: f64) -> Array2<f64> {
        let c = angle_rad.cos();
        let s = angle_rad.sin();
        array![[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
    }

    // ------------------------------------------------------------------
    // Orthogonal Procrustes
    // ------------------------------------------------------------------

    #[test]
    fn test_procrustes_rotation() {
        // Rotate a 3-point configuration by 45° and recover rotation
        let a = array![[1.0_f64, 0.0], [0.0, 1.0], [-1.0, 0.0]];
        let angle = std::f64::consts::FRAC_PI_4;
        let r_true = rot2(angle);
        let b = a.dot(&r_true);

        let config = ProcrustesConfig {
            allow_reflection: false,
            scaling: false,
            centering: true,
        };
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "residual should be near 0, got {}",
            result.disparity
        );
    }

    #[test]
    fn test_procrustes_no_reflection() {
        // When a reflection is the optimal map and allow_reflection=false,
        // we should get det(R) = +1
        let a = array![[1.0_f64, 0.0], [0.0, 1.0], [0.0, 0.0]];
        // Apply a reflection (det = -1): flip y-axis
        let b: Array2<f64> = array![[1.0_f64, 0.0], [0.0, -1.0], [0.0, 0.0]];

        let config = ProcrustesConfig {
            allow_reflection: false,
            scaling: false,
            centering: false,
        };
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        let det = mat_det(&result.rotation);
        assert!((det - 1.0).abs() < TOL, "det(R) should be +1, got {det}");
    }

    #[test]
    fn test_procrustes_reflection_allowed() {
        // With reflections allowed, the same y-flip is recovered exactly (det(R) = -1).
        let a = array![[1.0_f64, 0.0], [0.0, 1.0], [0.0, 0.0]];
        let b: Array2<f64> = array![[1.0_f64, 0.0], [0.0, -1.0], [0.0, 0.0]];

        let config = ProcrustesConfig {
            allow_reflection: true,
            scaling: false,
            centering: false,
        };
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "reflection should be recovered exactly, residual {}",
            result.disparity
        );
        let det = mat_det(&result.rotation);
        assert!((det + 1.0).abs() < TOL, "det(R) should be -1, got {det}");
    }

    #[test]
    fn test_procrustes_scale_translation() {
        // Apply scale 2.0 and translation [3, -1], then recover
        let a = array![[0.0_f64, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]];
        let scale_true = 2.0_f64;
        let translation = array![3.0_f64, -1.0];
        let b: Array2<f64> = &a * scale_true + &translation.view().insert_axis(Axis(0));

        let config = ProcrustesConfig::default();
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "residual should be near 0, got {}",
            result.disparity
        );
        assert!(
            (result.scale - scale_true).abs() < TOL,
            "scale should be {scale_true}, got {}",
            result.scale
        );
    }

    #[test]
    fn test_procrustes_3d_similarity() {
        // Full similarity in 3-D: rotation about z, scale 1.5 and a shift.
        // The four points are not coplanar, so R is uniquely determined.
        let a = array![
            [1.0_f64, 0.0, 0.0],
            [0.0, 2.0, 0.0],
            [0.0, 0.0, 3.0],
            [1.0, 1.0, 1.0]
        ];
        let r_true = rot3_z(0.5);
        let scale_true = 1.5_f64;
        let shift = array![0.5_f64, -2.0, 1.0];
        let b: Array2<f64> = a.dot(&r_true) * scale_true + &shift.view().insert_axis(Axis(0));

        let config = ProcrustesConfig::default();
        let result = orthogonal_procrustes(&a, &b, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "residual should be near 0, got {}",
            result.disparity
        );
        assert!(
            (result.scale - scale_true).abs() < TOL,
            "scale should be {scale_true}, got {}",
            result.scale
        );
        for (est, truth) in result.rotation.iter().zip(r_true.iter()) {
            assert!(
                (est - truth).abs() < TOL,
                "rotation entry {est} should be {truth}"
            );
        }
    }

    #[test]
    fn test_procrustes_identity() {
        // Aligning A to itself should give identity rotation and zero residual
        let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let config = ProcrustesConfig::default();
        let result = orthogonal_procrustes(&a, &a, &config).expect("procrustes ok");
        assert!(
            result.disparity < TOL,
            "residual for A→A should be 0, got {}",
            result.disparity
        );
    }

    #[test]
    fn test_procrustes_shape_mismatch_error() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let b = array![[1.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let config = ProcrustesConfig::default();
        let result = orthogonal_procrustes(&a, &b, &config);
        assert!(result.is_err(), "mismatched shapes should produce an error");
    }

    // ------------------------------------------------------------------
    // Generalized Procrustes
    // ------------------------------------------------------------------

    #[test]
    fn test_generalized_procrustes() {
        // Create 4 rotated versions of the same square
        let base = array![[1.0_f64, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let angles = [0.0_f64, 0.3, 0.7, 1.2];
        let matrices: Vec<Array2<f64>> = angles.iter().map(|&a| base.dot(&rot2(a))).collect();

        let results = generalized_procrustes(&matrices, 100, 1e-8).expect("GPA should converge");
        assert_eq!(results.len(), matrices.len());

        // Each result should have reasonably small disparity
        for (i, r) in results.iter().enumerate() {
            assert!(
                r.disparity < 1.0,
                "GPA result {i} disparity {:.4} should be small",
                r.disparity
            );
        }
    }

    #[test]
    fn test_generalized_procrustes_too_few_matrices() {
        let m = array![[1.0_f64, 0.0]];
        let result = generalized_procrustes(&[m], 100, 1e-8);
        assert!(result.is_err(), "single matrix should error");
    }

    #[test]
    fn test_generalized_procrustes_shape_mismatch() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let b = array![[1.0_f64, 0.0, 0.0]]; // different ncols
        let result = generalized_procrustes(&[a, b], 100, 1e-8);
        assert!(result.is_err(), "shape mismatch should error");
    }

    // ------------------------------------------------------------------
    // Determinant helper
    // ------------------------------------------------------------------

    #[test]
    fn test_det_2x2() {
        let m = array![[3.0_f64, 1.0], [5.0, 2.0]];
        let det = mat_det(&m);
        assert!((det - 1.0).abs() < 1e-12, "2x2 det should be 1, got {det}");
    }

    #[test]
    fn test_det_3x3() {
        let m = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]];
        let det = mat_det(&m);
        assert!((det - (-3.0)).abs() < 1e-10, "det should be -3, got {det}");
    }

    #[test]
    fn test_det_4x4() {
        // Exercises the LU path: one row swap (odd permutation) flips the sign.
        let m = array![
            [2.0_f64, 0.0, 0.0, 0.0],
            [0.0, 0.0, 3.0, 0.0],
            [0.0, 4.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 5.0]
        ];
        let det = mat_det(&m);
        assert!((det - (-120.0)).abs() < 1e-10, "4x4 det should be -120, got {det}");
    }
}