strapdown/linalg.rs
1//! Linear algebra helpers for robust covariance square roots.
2//!
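//! Public API:
//! - pub fn matrix_square_root(matrix: &DMatrix<f64>) -> DMatrix<f64>
//! - pub fn symmetrize(m: &DMatrix<f64>) -> DMatrix<f64>
//! - pub struct SolveOptions (jitter settings for the SPD solver)
//! - pub fn chol_solve_spd(a, b, opt) -> Option<DMatrix<f64>>
//! - pub fn robust_spd_solve(a, b) -> DMatrix<f64>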
5//!
6//! Internal pipeline (each step isolated for testing):
7//! - symmetrize()
8//! - chol_sqrt()
9//! - chol_sqrt_with_jitter()
10//! - evd_symmetric_sqrt_with_floor()
11//!
12//! Strategy:
13//! 1) Symmetrize P ← 0.5 (P + Pᵀ)
14//! 2) Cholesky
15//! 3) Jittered Cholesky (geometric ramp)
16//! 4) Symmetric EVD with eigenvalue floor → S = U * sqrt(Λ⁺) * Uᵀ
17
18use nalgebra::DMatrix;
19use nalgebra::linalg::{Cholesky, SymmetricEigen};
20
21/// Compute a robust symmetric square root `S` such that approximately `matrix ≈ S * Sᵀ`.
22///
23/// Attempts Cholesky decomposition first (yielding L such that matrix = L * L^T).
24/// If Cholesky fails (e.g., matrix is not positive definite), it attempts to compute
25/// the square root using eigenvalue decomposition (S = V * sqrt(D) * V^T).
26///
27/// # Arguments
28/// * `matrix` - The DMatrix<f64> to find the square root of. It's assumed to be symmetric and square.
29///
30/// # Returns
31/// * `Some(DMatrix<f64>)` containing a matrix square root.
32/// The result from Cholesky is lower triangular. The result from eigenvalue decomposition is symmetric.
33/// In both cases, if the result is `M`, then `matrix` approx `M * M.transpose()`.
34/// * `None` if the matrix is not square or another fundamental issue prevents computation (though
35/// this implementation tries to be robust for positive semi-definite cases).
36pub fn matrix_square_root(matrix: &DMatrix<f64>) -> DMatrix<f64> {
37 assert!(
38 matrix.is_square(),
39 "matrix_square_root: matrix must be square"
40 );
41
42 // Tunable guards (conservative defaults for double precision INS scales)
43 const INITIAL_JITTER: f64 = 1e-12;
44 const MAX_JITTER: f64 = 1e-6;
45 const MAX_TRIES: usize = 6;
46 const EIGEN_FLOOR: f64 = 1e-12;
47
48 // 1) Symmetrize to kill round-off asymmetry
49 let p = symmetrize(matrix);
50
51 // 2) Cholesky (fast path)
52 if let Some(s) = chol_sqrt(&p) {
53 return s;
54 }
55
56 // 3) Jittered Cholesky
57 if let Some(s) = chol_sqrt_with_jitter(&p, INITIAL_JITTER, MAX_JITTER, MAX_TRIES) {
58 return s;
59 }
60
61 // 4) EVD fallback with eigenvalue floor — symmetric square root
62 return evd_symmetric_sqrt_with_floor(&p, EIGEN_FLOOR);
63}
64/// Symmetrize a matrix: P ← 0.5 (P + Pᵀ)
65///
66/// Simple matrix symmetrization function that reduces round-off errors associated
67/// with floating point arithmetic.
68///
69/// # Arguments
70/// * `m` - the matrix to symmetrize
71///
72/// # Returns
73/// A symmetrized version of the input matrix.
74#[inline]
75pub fn symmetrize(m: &DMatrix<f64>) -> DMatrix<f64> {
76 0.5 * (m + m.transpose())
77}
78/// Plain Cholesky square root
79///
80/// Cholesky factorization that returns L such that P ≈ L Lᵀ, or None if it fails.
81/// This is a quick way to initially attempt to calculate a matrix square root.
82///
83/// # Arguments
84/// * ``p` - the matrix to factor
85///
86/// # Returns
87/// A lower triangular matrix L such that P ≈ L Lᵀ, or None if it fails.
88fn chol_sqrt(p: &DMatrix<f64>) -> Option<DMatrix<f64>> {
89 Cholesky::new(p.clone()).map(|ch| ch.l().into_owned())
90}
91/// Cholesky with diagonal jitter (geometric ramp). Returns None if all tries fail.
92///
93/// Perform Cholesky decomposition with a jittered diagonal on a geometric ramp up.
94/// Returns None if all tries fail.
95fn chol_sqrt_with_jitter(
96 p: &DMatrix<f64>,
97 initial_jitter: f64,
98 max_jitter: f64,
99 max_tries: usize,
100) -> Option<DMatrix<f64>> {
101 let n = p.nrows();
102 let mut jitter = initial_jitter;
103 for _ in 0..max_tries {
104 let mut pj = p.clone();
105 for i in 0..n {
106 pj[(i, i)] += jitter;
107 }
108 if let Some(ch) = Cholesky::new(pj) {
109 return Some(ch.l().into_owned());
110 }
111 jitter *= 10.0;
112 if jitter > max_jitter {
113 break;
114 }
115 }
116 None
117}
118
119/// Symmetric EVD square root with eigenvalue flooring:
120/// S = U * sqrt(max(λ, floor)) * Uᵀ
121fn evd_symmetric_sqrt_with_floor(p: &DMatrix<f64>, floor: f64) -> DMatrix<f64> {
122 let se = SymmetricEigen::new(p.clone());
123 let mut lambdas = se.eigenvalues;
124 let u = se.eigenvectors;
125
126 for i in 0..lambdas.len() {
127 if lambdas[i] < floor {
128 lambdas[i] = floor;
129 }
130 }
131
132 let sqrt_vals = lambdas.map(|l| l.sqrt());
133 let sigma_half = DMatrix::<f64>::from_diagonal(&sqrt_vals);
134 &u * sigma_half * u.transpose()
135}
136
/// Settings for the jittered Cholesky retries performed by `chol_solve_spd`.
#[derive(Debug, Clone, Copy)]
pub struct SolveOptions {
    /// First diagonal jitter to try (e.g., 1e-12).
    pub initial_jitter: f64,
    /// Cap on the jitter ramp (e.g., 1e-6).
    pub max_jitter: f64,
    /// Maximum number of jittered attempts (e.g., 6).
    pub max_tries: usize,
}

impl Default for SolveOptions {
    /// Defaults: jitter ramp starting at 1e-12, capped at 1e-6, at most 6 tries.
    fn default() -> Self {
        SolveOptions {
            max_tries: 6,
            max_jitter: 1e-6,
            initial_jitter: 1e-12,
        }
    }
}
153/// Solve A X = B for SPD-ish A via Cholesky, with jitter retries.
154/// Returns None if all attempts fail.
155pub fn chol_solve_spd(
156 a: &DMatrix<f64>,
157 b: &DMatrix<f64>,
158 opt: SolveOptions,
159) -> Option<DMatrix<f64>> {
160 assert!(a.is_square(), "chol_solve_spd: A must be square");
161 assert_eq!(a.nrows(), b.nrows(), "chol_solve_spd: A and B incompatible");
162
163 // Symmetrize first (SPD drift is common).
164 let a_sym = symmetrize(a);
165
166 // Try plain Cholesky
167 if let Some(ch) = Cholesky::new(a_sym.clone()) {
168 return Some(ch.solve(b));
169 }
170
171 // Jitter ramp
172 let n = a_sym.nrows();
173 let mut jitter = opt.initial_jitter;
174 for _ in 0..opt.max_tries {
175 let mut a_j = a_sym.clone();
176 for i in 0..n {
177 a_j[(i, i)] += jitter;
178 }
179 if let Some(ch) = Cholesky::new(a_j) {
180 return Some(ch.solve(b));
181 }
182 jitter *= 10.0;
183 if jitter > opt.max_jitter {
184 break;
185 }
186 }
187 None
188}
189
190/// Robust SPD solve with sane defaults:
191/// - Cholesky + jitter (preferred)
192/// - Last resort: explicit inverse
193pub fn robust_spd_solve(a: &DMatrix<f64>, b: &DMatrix<f64>) -> DMatrix<f64> {
194 if let Some(x) = chol_solve_spd(a, b, SolveOptions::default()) {
195 x
196 } else if let Some(inv) = symmetrize(a).try_inverse() {
197 &inv * b
198 } else {
199 panic!("robust_spd_solve: A is not invertible (even after jitter).");
200 }
201}
202
203/* =============================== Tests ==================================== */
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when `a` and `b` have the same shape and every pair of
    /// entries differs by at most `tol`.
    ///
    /// Written as a per-element check so that a NaN in either matrix makes the
    /// comparison fail (a running `f64::max` would silently drop NaNs).
    fn approx_eq(a: &DMatrix<f64>, b: &DMatrix<f64>, tol: f64) -> bool {
        a.shape() == b.shape()
            && a.iter()
                .zip(b.iter())
                .all(|(x, y)| (x - y).abs() <= tol)
    }

    #[test]
    fn t_approx_eq_rejects_nan() {
        let clean = DMatrix::<f64>::zeros(2, 2);
        let mut dirty = clean.clone();
        dirty[(0, 1)] = f64::NAN;
        assert!(approx_eq(&clean, &clean, 0.0));
        assert!(!approx_eq(&clean, &dirty, 1.0));
        assert!(!approx_eq(&clean, &DMatrix::<f64>::zeros(2, 3), 1.0));
    }

    #[test]
    fn t_symmetrize() {
        let m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 0.0, 3.0]);
        let s = symmetrize(&m);
        let s_expected = DMatrix::from_row_slice(2, 2, &[1.0, 1.0, 1.0, 3.0]);
        assert!(approx_eq(&s, &s_expected, 1e-15));
    }

    #[test]
    fn t_chol_sqrt_spd() {
        // P = A Aᵀ is SPD
        let a = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 0.5, 0.0, 1.0, -1.0, 0.0, 0.0, 0.2]);
        let p = &a * a.transpose();
        let s = chol_sqrt(&p).expect("Cholesky should succeed for SPD");
        let back = &s * s.transpose();
        assert!(approx_eq(&back, &p, 1e-12));
    }

    #[test]
    fn t_chol_sqrt_with_jitter() {
        // Shave a hair off one diagonal entry. The perturbation is far too small
        // to break plain Cholesky here; this test only checks that the jittered
        // factor still reproduces P to well within the jitter magnitude.
        let a = DMatrix::from_row_slice(3, 3, &[1.0, 0.2, 0.0, 0.0, 1.0, 0.2, 0.0, 0.0, 1.0]);
        let mut p = &a * a.transpose();
        p[(2, 2)] -= 1e-10;

        let s =
            chol_sqrt_with_jitter(&p, 1e-12, 1e-6, 6).expect("jittered Cholesky should succeed");
        let back = &s * s.transpose();
        let p_sym = symmetrize(&p);
        assert!(approx_eq(&back, &p_sym, 1e-8));
    }

    #[test]
    fn t_chol_sqrt_with_jitter_rescues_singular() {
        // Rank-1 PSD matrix: the second pivot is exactly zero, so plain Cholesky
        // gives up, while a small diagonal jitter makes it positive definite.
        let p = DMatrix::from_row_slice(2, 2, &[1.0, 1.0, 1.0, 1.0]);
        assert!(
            chol_sqrt(&p).is_none(),
            "plain Cholesky should fail on a singular matrix"
        );

        let s = chol_sqrt_with_jitter(&p, 1e-12, 1e-6, 6)
            .expect("jittered Cholesky should rescue a singular PSD matrix");
        let back = &s * s.transpose();
        // The reconstruction is off by at most the largest jitter on the ramp.
        assert!(approx_eq(&back, &p, 1e-6));
    }

    #[test]
    fn t_evd_floor() {
        // Make P symmetric but with a negative eigenvalue, EVD should floor it.
        let p = DMatrix::from_row_slice(2, 2, &[0.0, 1.0, 1.0, 0.0]); // eigenvalues {+1, -1}
        let s = evd_symmetric_sqrt_with_floor(&p, 1e-12);
        let back = &s * s.transpose();
        // sanity: back is symmetric
        assert!(approx_eq(&back, &back.transpose(), 1e-14));
        // With λ = -1 floored to 1e-12, S Sᵀ = U diag(1, 1e-12) Uᵀ, i.e. the
        // projector onto (1, 1)/√2 plus a negligible term.
        let expected = DMatrix::from_row_slice(2, 2, &[0.5, 0.5, 0.5, 0.5]);
        assert!(approx_eq(&back, &expected, 1e-9));
    }

    #[test]
    fn t_public_identity() {
        let i = DMatrix::<f64>::identity(4, 4);
        let s = matrix_square_root(&i);
        assert!(approx_eq(&s, &i, 1e-14));
        let back = &s * s.transpose();
        assert!(approx_eq(&back, &i, 1e-12));
    }

    #[test]
    fn t_public_nearly_spd() {
        let a = DMatrix::from_row_slice(3, 3, &[1.0, 0.1, 0.0, 0.0, 1.0, 0.2, 0.0, 0.0, 1.0]);
        let mut p = &a * a.transpose();
        p[(2, 2)] -= 1e-10;
        p[(0, 2)] += 1e-12; // asymmetry

        let s = matrix_square_root(&p);
        let back = &s * s.transpose();
        let p_sym = symmetrize(&p);
        assert!(approx_eq(&back, &p_sym, 1e-8));
    }

    #[test]
    #[should_panic]
    fn t_public_non_square_panics() {
        let m = DMatrix::<f64>::zeros(3, 2);
        let _ = matrix_square_root(&m);
    }

    #[test]
    fn t_chol_solve_spd() {
        let a = DMatrix::from_row_slice(2, 2, &[4.0, 1.0, 1.0, 3.0]);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 2.0]);
        let x = chol_solve_spd(&a, &b, SolveOptions::default())
            .expect("Cholesky solve should succeed for SPD");
        assert!(approx_eq(&(&a * &x), &b, 1e-12));
    }

    #[test]
    fn t_robust_spd_solve_falls_back_to_inverse() {
        // Symmetric but indefinite: every Cholesky attempt fails, yet the matrix
        // is invertible, so the explicit-inverse fallback must produce X.
        let a = DMatrix::from_row_slice(2, 2, &[0.0, 1.0, 1.0, 0.0]);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 2.0]);
        assert!(chol_solve_spd(&a, &b, SolveOptions::default()).is_none());

        let x = robust_spd_solve(&a, &b);
        let x_expected = DMatrix::from_row_slice(2, 1, &[2.0, 1.0]);
        assert!(approx_eq(&x, &x_expected, 1e-12));
    }

    #[test]
    #[should_panic(expected = "not invertible")]
    fn t_robust_spd_solve_singular_indefinite_panics() {
        // diag(-1, 0): jitter cannot make it positive definite and it is singular.
        let a = DMatrix::from_row_slice(2, 2, &[-1.0, 0.0, 0.0, 0.0]);
        let b = DMatrix::from_row_slice(2, 1, &[1.0, 1.0]);
        let _ = robust_spd_solve(&a, &b);
    }
}
298
299// ============ OLD ====================================
300
301// Calculates a square root of a symmetric matrix.
302//
303// Attempts Cholesky decomposition first (yielding L such that matrix = L * L^T).
304// If Cholesky fails (e.g., matrix is not positive definite), it attempts to compute
305// the square root using eigenvalue decomposition (S = V * sqrt(D) * V^T).
306// For eigenvalue decomposition, eigenvalues are clamped to be non-negative.
307//
308// # Arguments
309// * `matrix` - The DMatrix<f64> to find the square root of. It's assumed to be symmetric and square.
310//
311// # Returns
312// * `Some(DMatrix<f64>)` containing a matrix square root.
313// The result from Cholesky is lower triangular. The result from eigenvalue decomposition is symmetric.
314// In both cases, if the result is `M`, then `matrix` approx `M * M.transpose()`.
315// * `None` if the matrix is not square or another fundamental issue prevents computation (though
316// this implementation tries to be robust for positive semi-definite cases).
317//pub fn matrix_square_root(matrix: &DMatrix<f64>) -> DMatrix<f64> {
318// if !matrix.is_square() {
319// panic!("Error: Matrix must be square to compute square root.");
320// }
321// // Attempt Cholesky decomposition (yields L where matrix = L * L^T)
322// // Cholesky requires the matrix to be symmetric positive definite.
323// match cholesky_pass(matrix) {
324// Some(chol_l) => {
325// return chol_l;
326// }
327// None => {
328// //println!("Cholesky decomposition failed. Attempting eigenvalue decomposition.");
329// }
330// }
331// // If Cholesky failed, we try eigenvalue decomposition.
332// match eigenvalue_pass(matrix) {
333// Some(eigen_sqrt) => eigen_sqrt,
334// None => {
335// panic!(
336// "Cholesky and Eigenvalue decomposition failed. No valid square root found for the covariance matrix: \n {:?}",
337// matrix
338// );
339// }
340// }
341//}
342// Attempts to compute the matrix square root using Cholesky decomposition.
343//
344// This method is only applicable to symmetric positive definite matrices.
345// If successful, it returns the lower triangular matrix `L` such that `matrix = L * L.transpose()`.
346//
347// When the computation _fails_ (e.g., the matrix is not positive definite or not square),
348// a None value is returned instead of panicking, permitting the public API to proceed to the
349// next method.
350//
351// # Arguments
352// * `matrix` - The DMatrix<f64> to find the square root of. Assumed to be symmetric and square.
353//
354// # Returns
355// * `Some(DMatrix<f64>)` containing the lower triangular Cholesky factor `L`.
356// * `None` if the matrix is not positive definite or not square.
357// fn cholesky_pass(matrix: &DMatrix<f64>) -> Option<DMatrix<f64>> {
358// if !matrix.is_square() {
359// eprintln!("Error: Matrix must be square for Cholesky decomposition.");
360// return None;
361// }
362// matrix
363// .clone()
364// .cholesky()
365// .map(|chol: Cholesky<f64, nalgebra::Dyn>| chol.l())
366// }
367// Computes a symmetric matrix square root using eigenvalue decomposition.
368//
369// This method is suitable for symmetric positive semi-definite matrices.
370// It returns a symmetric matrix `S` such that `matrix = S * S`.
371// Eigenvalues are clamped to be non-negative to handle positive semi-definite cases
372// and minor numerical inaccuracies.
373//
374// When the computation _fails_ (e.g., the matrix is not positive definite or not square),
375// a None value is returned instead of panicking, permitting the public API to proceed to the
376// next method.
377//
378// # Arguments
379// * `matrix` - The DMatrix<f64> to find the square root of. Assumed to be symmetric and square.
380//
381// # Returns
382// * `Some(DMatrix<f64>)` containing the symmetric matrix square root `S`.
383// * `None` if the matrix is not square (though this should be checked by the caller for symmetry assumptions).
384// fn eigenvalue_pass(matrix: &DMatrix<f64>) -> Option<DMatrix<f64>> {
385// if !matrix.is_square() {
386// eprintln!("Error: Matrix must be square for eigenvalue decomposition based square root.");
387// return None;
388// }
389// // For eigenvalue decomposition of a symmetric matrix,
390// // we use `symmetric_eigen`. This returns real eigenvalues and orthogonal eigenvectors.
391// let eigen_decomposition: SymmetricEigen<f64, nalgebra::Dyn> = matrix.clone().symmetric_eigen();
392// let eigenvalues = eigen_decomposition.eigenvalues;
393// let eigenvectors = eigen_decomposition.eigenvectors;
394//
395// // Check for significantly negative eigenvalues, indicating non-positive semi-definiteness.
396// // While we clamp them, a warning is useful for diagnosis.
397// if eigenvalues.iter().any(|&val| val < -1e-9) {
398// println!(
399// "Warning: Negative eigenvalues encountered during eigenvalue decomposition. The input matrix was not positive semi-definite."
400// );
401// // println!("{:?}", matrix.data);
402// // // return None;
403// }
404//
405// // Create diagonal matrix of sqrt(eigenvalues), clamping eigenvalues to be non-negative.
406// // `DMatrix::from_diagonal` takes a DVector.
407// let sqrt_eigenvalues_diag_vec = eigenvalues.map(|val| val.max(1e-9).sqrt());
408// let sqrt_eigenvalues_diag = DMatrix::from_diagonal(&sqrt_eigenvalues_diag_vec);
409//
410// // Reconstruct the square root: S = V * sqrt(D) * V^T
411// // This S will be symmetric, and S * S = matrix (or S * S^T = matrix).
412// let sqrt_m = eigenvectors.clone() * sqrt_eigenvalues_diag * eigenvectors.transpose();
413//
414// Some(sqrt_m)
415// }
416//
417// #[cfg(test)]
418// mod tests {
419// use super::*;
420// use nalgebra::DMatrix;
421// use std::sync::LazyLock;
422//
423// static BASIC_SQRT: LazyLock<DMatrix<f64>> = LazyLock::new(|| {
424// DMatrix::from_row_slice(3, 3, &[4.0, 0.0, 0.0, 0.0, 9.0, 0.0, 0.0, 0.0, 16.0])
425// });
426// static POSITIVE_DEFINITE: LazyLock<DMatrix<f64>> = LazyLock::new(|| {
427// DMatrix::from_row_slice(3, 3, &[4.0, 2.0, 0.0, 2.0, 9.0, 3.0, 0.0, 3.0, 16.0])
428// });
429// static POSITIVE_SEMI_DEFINITE: LazyLock<DMatrix<f64>> = LazyLock::new(|| {
430// DMatrix::from_row_slice(3, 3, &[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0])
431// });
432// static NEGATIVE_DEFINITE: LazyLock<DMatrix<f64>> = LazyLock::new(|| {
433// DMatrix::from_row_slice(3, 3, &[-4.0, 0.0, 0.0, 0.0, -9.0, 0.0, 0.0, 0.0, -16.0])
434// });
435// static NEGATIVE_SEMI_DEFINITE: LazyLock<DMatrix<f64>> = LazyLock::new(|| {
436// DMatrix::from_row_slice(3, 3, &[-1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0])
437// });
438// static NON_SQUARE: LazyLock<DMatrix<f64>> =
439// LazyLock::new(|| DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));
440//
441// /// Helper function to verify if a matrix is a valid square root of another matrix.
442// /// Returns true if sqrt_matrix * sqrt_matrix.T ≈ original_matrix within tolerance.
443// fn is_valid_square_root(
444// sqrt_matrix: &DMatrix<f64>,
445// original_matrix: &DMatrix<f64>,
446// tolerance: f64,
447// ) -> bool {
448// let reconstructed = sqrt_matrix * sqrt_matrix.transpose();
449//
450// if reconstructed.nrows() != original_matrix.nrows()
451// || reconstructed.ncols() != original_matrix.ncols()
452// {
453// return false;
454// }
455//
456// for i in 0..original_matrix.nrows() {
457// for j in 0..original_matrix.ncols() {
458// if (reconstructed[(i, j)] - original_matrix[(i, j)]).abs() > tolerance {
459// return false;
460// }
461// }
462// }
463// true
464// }
465// // Test matrix square root calculation
466// #[test]
467// fn cholesky_square_root() {
468// let sqrt_matrix = matrix_square_root(&BASIC_SQRT);
469// assert!(is_valid_square_root(&sqrt_matrix, &BASIC_SQRT, 1e-9));
470// }
471// #[test]
472// fn cholesky_positive_definite() {
473// let sqrt_matrix = matrix_square_root(&POSITIVE_DEFINITE);
474// assert!(is_valid_square_root(&sqrt_matrix, &POSITIVE_DEFINITE, 1e-9));
475// }
476// #[test]
477// #[should_panic]
478// fn cholesky_negative_definite() {
479// // This should panic because the matrix is negative definite.
480// let _sqrt_matrix = matrix_square_root(&NEGATIVE_DEFINITE);
481// }
482// #[test]
483// #[should_panic]
484// fn cholesky_negative_semi_definite() {
485// // This should panic because the matrix is negative semi-definite.
486// let _sqrt_matrix = matrix_square_root(&NEGATIVE_SEMI_DEFINITE);
487// }
488// #[test]
489// #[should_panic]
490// fn cholesky_non_square() {
491// // This should panic because the matrix is not square.
492// let _sqrt_matrix = matrix_square_root(&NON_SQUARE);
493// }
494// #[test]
495// fn eigenvalue_square_root() {
496// let sqrt_matrix = matrix_square_root(&POSITIVE_SEMI_DEFINITE);
497// assert!(is_valid_square_root(
498// &sqrt_matrix,
499// &POSITIVE_SEMI_DEFINITE,
500// 1e-9
501// ));
502// }
503// #[test]
504// fn eigenvalue_positive_definite() {
505// let sqrt_matrix = matrix_square_root(&POSITIVE_DEFINITE);
506// assert!(is_valid_square_root(&sqrt_matrix, &POSITIVE_DEFINITE, 1e-9));
507// }
508// #[test]
509// fn eigenvalue_positive_semi_definite() {
510// let sqrt_matrix = matrix_square_root(&POSITIVE_SEMI_DEFINITE);
511// assert!(is_valid_square_root(
512// &sqrt_matrix,
513// &POSITIVE_SEMI_DEFINITE,
514// 1e-9
515// ));
516// }
517// #[test]
518// #[should_panic]
519// fn eigenvalue_negative_definite() {
520// // This should panic because the matrix is negative definite.
521// let _sqrt_matrix = matrix_square_root(&NEGATIVE_DEFINITE);
522// }
523// #[test]
524// #[should_panic]
525// fn eigenvalue_negative_semi_definite() {
526// // This should panic because the matrix is negative semi-definite.
527// let _sqrt_matrix = matrix_square_root(&NEGATIVE_SEMI_DEFINITE);
528// }
529// #[test]
530// #[should_panic]
531// fn eigenvalue_non_square() {
532// // This should panic because the matrix is not square.
533// let _sqrt_matrix = matrix_square_root(&NON_SQUARE);
534// }
535// #[test]
536// fn public_api_square_root() {
537// let sqrt_matrix = matrix_square_root(&POSITIVE_DEFINITE);
538// assert!(is_valid_square_root(&sqrt_matrix, &POSITIVE_DEFINITE, 1e-9));
539// let sqrt_matrix = matrix_square_root(&POSITIVE_SEMI_DEFINITE);
540// assert!(is_valid_square_root(
541// &sqrt_matrix,
542// &POSITIVE_SEMI_DEFINITE,
543// 1e-9
544// ));
545// let sqrt_matrix = matrix_square_root(&BASIC_SQRT);
546// assert!(is_valid_square_root(&sqrt_matrix, &BASIC_SQRT, 1e-9));
547// }
548// #[test]
549// #[should_panic]
550// fn public_api_negative_definite() {
551// // This should panic because the matrix is negative definite.
552// let _sqrt_matrix = matrix_square_root(&NEGATIVE_DEFINITE);
553// }
554// #[test]
555// #[should_panic]
556// fn public_api_negative_semi_definite() {
557// // This should panic because the matrix is negative semi-definite.
558// let _sqrt_matrix = matrix_square_root(&NEGATIVE_SEMI_DEFINITE);
559// }
560// #[test]
561// #[should_panic]
562// fn public_api_non_square() {
563// // This should panic because the matrix is not square.
564// let _sqrt_matrix = matrix_square_root(&NON_SQUARE);
565// }
566// }
567//