Skip to main content

sidereon_core/astro/math/
linear.rs

//! Deterministic small linear-algebra kernels.
//!
//! These routines keep the scalar operation order explicit for parity-sensitive
//! GNSS callers. Where pivot tie-breaking or accumulation order matters, the
//! function name states the policy (for example `_first_tie`, which keeps the
//! earliest row among equal pivot magnitudes, versus `_last_tie`, which takes the
//! latest) instead of hiding it in a local copy.
6
7use crate::astro::tolerances::PIVOT_EPSILON;
8use crate::validate;
9
/// Reusable scratch buffers for the flat (row-major) first-tie solvers
/// [`invert_flat_first_tie_into`] and [`solve_matrix_flat_first_tie_into`].
///
/// Keeping one instance across calls lets repeated solves of the same order
/// reuse these allocations.
#[derive(Debug, Default, Clone)]
pub struct FlatLinearScratch {
    /// Augmented system `[A | rhs]`, `n` rows of stride `n + 1`, rebuilt for
    /// every right-hand side (the elimination overwrites it).
    rows: Vec<f64>,
    /// Length-`n` solution of the most recent column solve.
    x: Vec<f64>,
}
15
/// Reusable buffers for [`solve_flat_normal_first_tie_into`], held across calls
/// so steady-state solves of the same order do not reallocate.
#[derive(Debug, Default, Clone)]
pub struct FlatNormalSolveScratch {
    /// Working copy of the row-major `n x n` matrix; reduced in place.
    a: Vec<f64>,
    /// Working copy of the length-`n` right-hand side; reduced in place.
    b: Vec<f64>,
    /// Solution vector; the solver returns a borrow of this buffer.
    x: Vec<f64>,
}
22
/// Error type for the `Result`-returning kernels in this module (for example
/// [`normal_matrix_4_weighted_column_outer`]). Most other routines here signal
/// failure with `None` instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum LinearError {
    /// An input failed validation. `field` names the offending argument and
    /// `reason` describes the violated requirement.
    #[error("invalid linear algebra {field}: {reason}")]
    InvalidInput {
        field: &'static str,
        reason: &'static str,
    },
}
31
/// Solve the dense square system `A x = b` by Gaussian elimination with
/// partial pivoting. Ties in pivot magnitude go to the **first**
/// (lowest-index) candidate row.
///
/// `a` must be an `n x n` matrix (one `Vec` per row) and `b` a length-`n`
/// vector, all entries finite, with `n > 0`. The inputs are copied into an
/// augmented working matrix `[A | b]` and are not modified.
///
/// Returns `None` in any of these cases:
/// - the inputs fail validation;
/// - the selected pivot of some column is non-finite or has magnitude
///   `<= PIVOT_EPSILON`;
/// - the computed solution contains a non-finite value.
#[allow(clippy::needless_range_loop)]
pub fn solve_linear_first_tie(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
    let n = validate_dense_system(a, b)?;
    // Augmented matrix: each row is `a[i]` followed by `b[i]` (length n + 1).
    let mut rows: Vec<Vec<f64>> = a
        .iter()
        .zip(b)
        .map(|(row, &bi)| {
            let mut r = row.clone();
            r.push(bi);
            r
        })
        .collect();

    for col in 0..n {
        // Partial pivoting. The strict `>` only replaces the candidate on a
        // strictly larger magnitude, so the earliest row wins ties.
        let mut pivot_row = col;
        let mut pivot_abs = rows[col][col].abs();
        for idx in (col + 1)..n {
            let v = rows[idx][col].abs();
            if v > pivot_abs {
                pivot_abs = v;
                pivot_row = idx;
            }
        }
        // Reject a non-finite or (near-)zero pivot. `!is_finite()` also catches
        // a NaN starting candidate, which no `>` comparison can displace.
        if !pivot_abs.is_finite() || pivot_abs <= PIVOT_EPSILON {
            return None;
        }
        rows.swap(col, pivot_row);

        // Snapshot the pivot row so it can be read while the rows below it are
        // updated.
        let pivot = rows[col].clone();
        let pivot_value = pivot[col];
        for idx in (col + 1)..n {
            let factor = rows[idx][col] / pivot_value;
            // Updates every column `0..=n`, including those left of `col`.
            // Kept as written: the module fixes the scalar op order for parity.
            for j in 0..=n {
                rows[idx][j] -= factor * pivot[j];
            }
        }
    }

    // Back substitution on the upper-triangular system. The already-solved
    // terms are accumulated left to right from 0.0.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let mut known = 0.0;
        for j in (i + 1)..n {
            known += rows[i][j] * x[j];
        }
        x[i] = (rows[i][n] - known) / rows[i][i];
    }
    validate::finite_slice(&x, "solution").ok()?;
    Some(x)
}
81
/// Solve the dense square system `A x = b` by Gaussian elimination with
/// partial pivoting. Ties in pivot magnitude go to the **last**
/// (highest-index) candidate row.
///
/// Takes `a` and `b` by value. After validation, `b` is appended to the rows
/// of `a`, which then serves as the augmented working matrix.
///
/// Validation and failure conditions match [`solve_linear_first_tie`]. It
/// returns `None` for bad shapes, non-finite inputs, a non-finite or
/// `<= PIVOT_EPSILON` pivot, or a non-finite solution. However, elimination
/// only touches columns `col..=n` and back substitution uses `Iterator::sum`,
/// so results can differ from the first-tie variant in the last bits.
#[allow(clippy::needless_range_loop)]
pub fn solve_linear_last_tie(mut a: Vec<Vec<f64>>, b: Vec<f64>) -> Option<Vec<f64>> {
    let n = validate_dense_system(&a, &b)?;
    // Turn `a` into the augmented matrix [A | b] in place.
    for (row, bi) in a.iter_mut().zip(b) {
        row.push(bi);
    }
    for col in 0..n {
        // `max_by` returns the last of several equal maxima, which makes this
        // the "last tie" variant. `total_cmp` ranks positive NaN (what `abs()`
        // yields for NaN) above infinity, so a NaN candidate would be picked and
        // then rejected by the finiteness check below. `col < n`, so the range is
        // non-empty and `unwrap` cannot fail.
        let (pivot_row, pivot_abs) = (col..n)
            .map(|idx| (idx, a[idx][col].abs()))
            .max_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1))
            .unwrap();
        if !pivot_abs.is_finite() || pivot_abs <= PIVOT_EPSILON {
            return None;
        }
        a.swap(col, pivot_row);
        // Snapshot the pivot row so it can be read while the rows below it are
        // mutated.
        let pivot = a[col].clone();
        let pivot_value = pivot[col];
        for row in a.iter_mut().take(n).skip(col + 1) {
            let factor = row[col] / pivot_value;
            // Only columns `col..=n` are updated (the first-tie variant updates
            // from column 0); this op order is part of the variant's contract.
            for j in col..=n {
                row[j] -= factor * pivot[j];
            }
        }
    }
    // Back substitution. The already-solved terms are summed with
    // `Iterator::sum` in ascending `j` order.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let tail_sum: f64 = ((i + 1)..n).map(|j| a[i][j] * x[j]).sum();
        x[i] = (a[i][n] - tail_sum) / a[i][i];
    }
    validate::finite_slice(&x, "solution").ok()?;
    Some(x)
}
114
115pub fn invert_matrix_first_tie(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
116    let n = a.len();
117    if n == 0 {
118        return None;
119    }
120    let mut columns: Vec<Vec<f64>> = Vec::with_capacity(n);
121    for col in 0..n {
122        let mut e = vec![0.0; n];
123        e[col] = 1.0;
124        columns.push(solve_linear_first_tie(a, &e)?);
125    }
126    Some(
127        (0..n)
128            .map(|i| (0..n).map(|j| columns[j][i]).collect())
129            .collect(),
130    )
131}
132
133pub fn invert_matrix_last_tie(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
134    let n = a.len();
135    let mut columns = Vec::with_capacity(n);
136    for col in 0..n {
137        let unit = (0..n)
138            .map(|idx| if idx == col { 1.0 } else { 0.0 })
139            .collect();
140        columns.push(solve_linear_last_tie(a.to_vec(), unit)?);
141    }
142    transpose(&columns)
143}
144
145pub fn solve_matrix_last_tie(a: &[Vec<f64>], b: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
146    let columns = transpose(b)?;
147    let mut solved_columns = Vec::with_capacity(columns.len());
148    for col in columns {
149        solved_columns.push(solve_linear_last_tie(a.to_vec(), col)?);
150    }
151    transpose(&solved_columns)
152}
153
154pub fn normal_equations_weighted<'a, I>(rows: I, n: usize) -> Option<(Vec<Vec<f64>>, Vec<f64>)>
155where
156    I: IntoIterator<Item = (&'a [f64], f64, f64)>,
157{
158    if n == 0 {
159        return None;
160    }
161    let mut ata = vec![vec![0.0; n]; n];
162    let mut aty = vec![0.0; n];
163    for (row_h, row_y, row_weight) in rows {
164        if row_h.len() != n {
165            return None;
166        }
167        validate::finite_slice(row_h, "normal row").ok()?;
168        validate::finite(row_y, "normal residual").ok()?;
169        validate::finite(row_weight, "normal weight").ok()?;
170        let h: Vec<f64> = row_h.iter().map(|v| v * row_weight).collect();
171        let y = row_y * row_weight;
172        for i in 0..n {
173            aty[i] += h[i] * y;
174            for j in 0..n {
175                ata[i][j] += h[i] * h[j];
176            }
177        }
178    }
179    for row in &ata {
180        validate::finite_slice(row, "normal matrix").ok()?;
181    }
182    validate::finite_slice(&aty, "normal rhs").ok()?;
183    Some((ata, aty))
184}
185
186pub fn matrix_sub(a: &[Vec<f64>], b: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
187    let (rows, cols) = validate_same_shape(a, b)?;
188    let out: Vec<Vec<f64>> = a
189        .iter()
190        .zip(b)
191        .map(|(row_a, row_b)| row_a.iter().zip(row_b).map(|(x, y)| x - y).collect())
192        .collect();
193    debug_assert_eq!(out.len(), rows);
194    debug_assert!(out.iter().all(|row| row.len() == cols));
195    for row in &out {
196        validate::finite_slice(row, "matrix difference").ok()?;
197    }
198    Some(out)
199}
200
201pub fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
202    let b_t = transpose(b)?;
203    let rows = a.len();
204    let shared = b_t.first()?.len();
205    if rows == 0 || shared == 0 {
206        return None;
207    }
208    for row in a {
209        if row.len() != shared {
210            return None;
211        }
212        validate::finite_slice(row, "matrix").ok()?;
213    }
214    let out: Vec<Vec<f64>> = a
215        .iter()
216        .map(|row| {
217            b_t.iter()
218                .map(|col| row.iter().zip(col).fold(0.0, |acc, (x, y)| acc + x * y))
219                .collect()
220        })
221        .collect();
222    for row in &out {
223        validate::finite_slice(row, "matrix product").ok()?;
224    }
225    Some(out)
226}
227
228pub fn transpose(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
229    let cols = matrix.first()?.len();
230    if cols == 0 {
231        return None;
232    }
233    for row in matrix {
234        if row.len() != cols {
235            return None;
236        }
237        validate::finite_slice(row, "matrix").ok()?;
238    }
239    Some(
240        (0..cols)
241            .map(|col| matrix.iter().map(|row| row[col]).collect())
242            .collect(),
243    )
244}
245
246pub fn invert_flat_first_tie_into(
247    a: &[f64],
248    n: usize,
249    out: &mut Vec<f64>,
250    scratch: &mut FlatLinearScratch,
251) -> Option<()> {
252    validate_flat_square(a, n, "matrix")?;
253    out.resize(n * n, 0.0);
254    scratch.rows.resize(n * (n + 1), 0.0);
255    scratch.x.resize(n, 0.0);
256
257    for col in 0..n {
258        for i in 0..n {
259            let src = i * n;
260            let dst = i * (n + 1);
261            scratch.rows[dst..(dst + n)].copy_from_slice(&a[src..(src + n)]);
262            scratch.rows[dst + n] = if i == col { 1.0 } else { 0.0 };
263        }
264        solve_augmented_flat_first_tie_in_place(&mut scratch.rows, n, &mut scratch.x)?;
265        for i in 0..n {
266            out[i * n + col] = scratch.x[i];
267        }
268    }
269
270    Some(())
271}
272
273pub fn solve_matrix_flat_first_tie_into(
274    a: &[f64],
275    n: usize,
276    b: &[f64],
277    cols: usize,
278    out: &mut Vec<f64>,
279    scratch: &mut FlatLinearScratch,
280) -> Option<()> {
281    validate_flat_square(a, n, "matrix")?;
282    if cols == 0 || b.len() != n.checked_mul(cols)? {
283        return None;
284    }
285    validate::finite_slice(b, "rhs").ok()?;
286    out.resize(n.checked_mul(cols)?, 0.0);
287    scratch.rows.resize(n * (n + 1), 0.0);
288    scratch.x.resize(n, 0.0);
289
290    for col in 0..cols {
291        for i in 0..n {
292            let src = i * n;
293            let dst = i * (n + 1);
294            scratch.rows[dst..(dst + n)].copy_from_slice(&a[src..(src + n)]);
295            scratch.rows[dst + n] = b[i * cols + col];
296        }
297        solve_augmented_flat_first_tie_in_place(&mut scratch.rows, n, &mut scratch.x)?;
298        for i in 0..n {
299            out[i * cols + col] = scratch.x[i];
300        }
301    }
302    Some(())
303}
304
305#[allow(clippy::needless_range_loop)]
306pub fn solve_augmented_flat_first_tie_in_place(
307    rows: &mut [f64],
308    n: usize,
309    x: &mut [f64],
310) -> Option<()> {
311    let stride = n + 1;
312    if n == 0 || rows.len() != n.checked_mul(stride)? || x.len() != n {
313        return None;
314    }
315    validate::finite_slice(rows, "augmented matrix").ok()?;
316
317    for col in 0..n {
318        let mut pivot_row = col;
319        let mut pivot_abs = rows[col * stride + col].abs();
320        for idx in (col + 1)..n {
321            let v = rows[idx * stride + col].abs();
322            if v > pivot_abs {
323                pivot_abs = v;
324                pivot_row = idx;
325            }
326        }
327        if !pivot_abs.is_finite() || pivot_abs <= PIVOT_EPSILON {
328            return None;
329        }
330        if pivot_row != col {
331            for j in 0..=n {
332                rows.swap(col * stride + j, pivot_row * stride + j);
333            }
334        }
335
336        let pivot_value = rows[col * stride + col];
337        for idx in (col + 1)..n {
338            let factor = rows[idx * stride + col] / pivot_value;
339            for j in 0..=n {
340                rows[idx * stride + j] -= factor * rows[col * stride + j];
341            }
342        }
343    }
344
345    for i in (0..n).rev() {
346        let mut known = 0.0;
347        for j in (i + 1)..n {
348            known += rows[i * stride + j] * x[j];
349        }
350        x[i] = (rows[i * stride + n] - known) / rows[i * stride + i];
351    }
352
353    validate::finite_slice(x, "solution").ok()?;
354    Some(())
355}
356
357pub fn solve_flat_normal_first_tie(lambda: &[f64], eta: &[f64]) -> Option<Vec<f64>> {
358    let mut scratch = FlatNormalSolveScratch::default();
359    solve_flat_normal_first_tie_into(lambda, eta, &mut scratch).map(<[f64]>::to_vec)
360}
361
/// Solve the dense system `Λ x = η` by Gaussian elimination with first-tie
/// partial pivoting. `lambda` is the row-major `n x n` matrix and `eta` the
/// length-`n` right-hand side. The caller's `scratch` buffers are used, so
/// steady-state calls do not allocate.
///
/// Differences from [`solve_augmented_flat_first_tie_in_place`]:
/// - the matrix and the right-hand side are kept in separate buffers;
/// - eliminated entries are set to exactly `0.0`;
/// - each row update only touches columns right of the pivot.
///
/// The inputs are copied and left unmodified. Symmetry is not checked.
///
/// Returns the solution as a slice borrowed from `scratch`. Returns `None` if
/// `n == 0`, the lengths disagree, any input is non-finite, a pivot is
/// non-finite or `<= PIVOT_EPSILON`, or the solution is non-finite.
#[allow(clippy::needless_range_loop)]
pub fn solve_flat_normal_first_tie_into<'a>(
    lambda: &[f64],
    eta: &[f64],
    scratch: &'a mut FlatNormalSolveScratch,
) -> Option<&'a [f64]> {
    let n = eta.len();
    if n == 0 || lambda.len() != n.checked_mul(n)? {
        return None;
    }
    validate::finite_slice(lambda, "normal matrix").ok()?;
    validate::finite_slice(eta, "normal rhs").ok()?;

    // Working copies; the elimination below reduces them in place.
    scratch.a.resize(n * n, 0.0);
    scratch.a.copy_from_slice(lambda);
    scratch.b.resize(n, 0.0);
    scratch.b.copy_from_slice(eta);

    for k in 0..n {
        // First-tie partial pivoting: strict `>` keeps the earliest row on ties.
        let mut pivot = k;
        let mut pivot_abs = scratch.a[k * n + k].abs();
        for i in (k + 1)..n {
            let candidate = scratch.a[i * n + k].abs();
            if candidate > pivot_abs {
                pivot = i;
                pivot_abs = candidate;
            }
        }
        if !pivot_abs.is_finite() || pivot_abs <= PIVOT_EPSILON {
            return None;
        }
        if pivot != k {
            for j in 0..n {
                scratch.a.swap(k * n + j, pivot * n + j);
            }
            scratch.b.swap(k, pivot);
        }

        let diag = scratch.a[k * n + k];
        for i in (k + 1)..n {
            let factor = scratch.a[i * n + k] / diag;
            // Store an exact zero below the pivot instead of computing it.
            scratch.a[i * n + k] = 0.0;
            for j in (k + 1)..n {
                scratch.a[i * n + j] -= factor * scratch.a[k * n + j];
            }
            scratch.b[i] -= factor * scratch.b[k];
        }
    }

    // Back substitution. Only `x[j]` for `j > i` is read, and those entries were
    // already written in this call, so stale values from a previous call never
    // leak in.
    scratch.x.resize(n, 0.0);
    for i in (0..n).rev() {
        let mut known = 0.0;
        for j in (i + 1)..n {
            known += scratch.a[i * n + j] * scratch.x[j];
        }
        scratch.x[i] = (scratch.b[i] - known) / scratch.a[i * n + i];
    }
    validate::finite_slice(&scratch.x, "solution").ok()?;
    Some(&scratch.x)
}
422
/// Reusable buffers for the owned Cholesky (square-root) solve
/// ([`solve_flat_normal_square_root_into`]).
///
/// They are held across solves so a steady-state iteration does not allocate.
#[derive(Debug, Default, Clone)]
pub struct FlatCholeskySolveScratch {
    /// Lower-triangular Cholesky factor `L`, row-major `n x n`. The upper
    /// triangle is zeroed on every solve.
    l: Vec<f64>,
    /// Forward-substitution result of `L z = η`.
    z: Vec<f64>,
    /// Solution of `Lᵀ x = z`; the solver returns a borrow of this buffer.
    x: Vec<f64>,
}
433
/// Solve the symmetric positive-definite information system `Λ x = η` with an
/// owned, deterministic Cholesky (square-root) factorization `Λ = L Lᵀ`. This is
/// followed by forward substitution `L z = η` and back substitution `Lᵀ x = z`.
/// `lambda` is the row-major `n x n` information matrix and `eta` the length-`n`
/// information vector.
///
/// The Cholesky factor `L` is the information-matrix square root, so this is
/// the square-root-information solve. Unlike the general first-tie Gaussian
/// elimination ([`solve_flat_normal_first_tie_into`]), it needs no pivoting.
/// Because the system is SPD, the fixed `i`/`j`/`k` reduction order (identical
/// to [`invert_symmetric_pd`]) is the entire op order. The result is therefore
/// bit-reproducible, with no pivot-dependent branching.
///
/// Only the lower triangle of `lambda` (`j <= i`) feeds the factorization. The
/// upper triangle is read only by the symmetry check.
///
/// Returns `None` in any of these cases:
/// - `n == 0`, or the lengths disagree;
/// - any input is non-finite;
/// - `lambda` is not symmetric within a scale-relative tolerance;
/// - `Λ` is not positive definite (a non-positive or non-finite pivot). For a
///   weighted least-squares normal matrix this means rank-deficient geometry;
/// - the intermediate `z` or the solution `x` is non-finite.
#[allow(clippy::needless_range_loop)]
pub fn solve_flat_normal_square_root_into<'a>(
    lambda: &[f64],
    eta: &[f64],
    scratch: &'a mut FlatCholeskySolveScratch,
) -> Option<&'a [f64]> {
    let n = eta.len();
    if n == 0 || lambda.len() != n.checked_mul(n)? {
        return None;
    }
    validate::finite_slice(lambda, "normal matrix").ok()?;
    validate::finite_slice(eta, "normal rhs").ok()?;
    // The factorization reads only the lower triangle, so asymmetric input is
    // rejected here rather than silently half-used.
    validate_flat_symmetric(lambda, n)?;
    // Zero the whole buffer so the never-written upper triangle does not carry
    // stale values from a previous solve.
    scratch.l.resize(n * n, 0.0);
    scratch.l.fill(0.0);

    // Cholesky Λ = L Lᵀ, the same factorization order as `invert_symmetric_pd`.
    for i in 0..n {
        for j in 0..=i {
            let mut s = lambda[i * n + j];
            for k in 0..j {
                s -= scratch.l[i * n + k] * scratch.l[j * n + k];
            }
            if i == j {
                // `!(s > 0.0)` is true for zero, negatives and NaN alike.
                #[allow(clippy::neg_cmp_op_on_partial_ord)]
                let nonpositive_or_nan = !(s > 0.0);
                if nonpositive_or_nan || !s.is_finite() {
                    return None;
                }
                scratch.l[i * n + j] = s.sqrt();
            } else {
                scratch.l[i * n + j] = s / scratch.l[j * n + j];
            }
        }
    }

    // Forward substitution L z = η.
    scratch.z.resize(n, 0.0);
    for i in 0..n {
        let mut s = eta[i];
        for k in 0..i {
            s -= scratch.l[i * n + k] * scratch.z[k];
        }
        scratch.z[i] = s / scratch.l[i * n + i];
    }
    validate::finite_slice(&scratch.z, "solution work vector").ok()?;

    // Back substitution Lᵀ x = z. `Lᵀ[i][k]` is read as `L[k][i]`.
    scratch.x.resize(n, 0.0);
    for i in (0..n).rev() {
        let mut s = scratch.z[i];
        for k in (i + 1)..n {
            s -= scratch.l[k * n + i] * scratch.x[k];
        }
        scratch.x[i] = s / scratch.l[i * n + i];
    }
    validate::finite_slice(&scratch.x, "solution").ok()?;
    Some(scratch.x.as_slice())
}
506
507fn validate_flat_symmetric(matrix: &[f64], n: usize) -> Option<()> {
508    let mut scale = 1.0_f64;
509    for value in matrix {
510        scale = scale.max(value.abs());
511    }
512    let tol = symmetry_tolerance(n, scale);
513    for i in 0..n {
514        for j in (i + 1)..n {
515            if (matrix[i * n + j] - matrix[j * n + i]).abs() > tol {
516                return None;
517            }
518        }
519    }
520    Some(())
521}
522
523#[allow(clippy::needless_range_loop)]
524fn validate_rows_symmetric(matrix: &[Vec<f64>]) -> Option<()> {
525    let n = matrix.len();
526    let mut scale = 1.0_f64;
527    for row in matrix {
528        for value in row {
529            scale = scale.max(value.abs());
530        }
531    }
532    let tol = symmetry_tolerance(n, scale);
533    for i in 0..n {
534        for j in (i + 1)..n {
535            if (matrix[i][j] - matrix[j][i]).abs() > tol {
536                return None;
537            }
538        }
539    }
540    Some(())
541}
542
/// Absolute tolerance used by the symmetry checks:
/// `128 · f64::EPSILON · max(n, 1) · max(scale, 1)`.
///
/// It grows linearly with both the matrix order and its largest magnitude.
fn symmetry_tolerance(n: usize, scale: f64) -> f64 {
    let unit = 128.0 * f64::EPSILON;
    let dimension = n.max(1) as f64;
    let magnitude = scale.max(1.0);
    // Same left-to-right product as `128 * EPS * dim * mag`.
    unit * dimension * magnitude
}
546
547fn validate_dense_system(a: &[Vec<f64>], b: &[f64]) -> Option<usize> {
548    let n = b.len();
549    if n == 0 || a.len() != n {
550        return None;
551    }
552    validate::finite_slice(b, "rhs").ok()?;
553    for row in a {
554        if row.len() != n {
555            return None;
556        }
557        validate::finite_slice(row, "matrix").ok()?;
558    }
559    Some(n)
560}
561
562fn validate_same_shape(a: &[Vec<f64>], b: &[Vec<f64>]) -> Option<(usize, usize)> {
563    let rows = a.len();
564    if rows == 0 || b.len() != rows {
565        return None;
566    }
567    let cols = a.first()?.len();
568    if cols == 0 {
569        return None;
570    }
571    for row in a {
572        if row.len() != cols {
573            return None;
574        }
575        validate::finite_slice(row, "matrix").ok()?;
576    }
577    for row in b {
578        if row.len() != cols {
579            return None;
580        }
581        validate::finite_slice(row, "matrix").ok()?;
582    }
583    Some((rows, cols))
584}
585
586fn validate_flat_square(a: &[f64], n: usize, field: &'static str) -> Option<()> {
587    if n == 0 || a.len() != n.checked_mul(n)? {
588        return None;
589    }
590    validate::finite_slice(a, field).ok()
591}
592
593fn map_linear_field_error(error: validate::FieldError) -> LinearError {
594    linear_invalid_input(error.field(), error.reason())
595}
596
597fn linear_invalid_input(field: &'static str, reason: &'static str) -> LinearError {
598    LinearError::InvalidInput { field, reason }
599}
600
/// Weighted `4 x 4` normal matrix `Σ_k w_k r_k r_kᵀ`, computed entry by entry.
///
/// The outer loops run over the output indices `(i, j)`; the innermost loop
/// runs over the rows. Each term is formed as
/// `(rows[k][i] * weights[k]) * rows[k][j]`, so the weight enters once per
/// term. This differs from [`normal_equations_weighted`], which scales both
/// factors by the weight. Terms are summed in row order, starting from `0.0`.
///
/// # Errors
///
/// Returns [`LinearError::InvalidInput`] in any of these cases:
/// - `weights.len() != rows.len()`;
/// - any weight or row entry is non-finite;
/// - the resulting matrix contains a non-finite value.
#[allow(clippy::needless_range_loop)]
pub fn normal_matrix_4_weighted_column_outer(
    rows: &[[f64; 4]],
    weights: &[f64],
) -> Result<[[f64; 4]; 4], LinearError> {
    if weights.len() != rows.len() {
        return Err(linear_invalid_input("weights", "length must match rows"));
    }
    validate::finite_slice(weights, "weights").map_err(map_linear_field_error)?;
    for row in rows {
        validate::finite_slice(row, "rows").map_err(map_linear_field_error)?;
    }

    let mut a = [[0.0_f64; 4]; 4];
    for i in 0..4 {
        for j in 0..4 {
            // Accumulation order over `k` is part of the parity contract.
            let mut s = 0.0_f64;
            for k in 0..rows.len() {
                s += rows[k][i] * weights[k] * rows[k][j];
            }
            a[i][j] = s;
        }
    }
    // Finite inputs can still overflow during accumulation.
    for row in &a {
        validate::finite_slice(row, "normal matrix").map_err(map_linear_field_error)?;
    }
    Ok(a)
}
629
/// Unweighted `4 x 4` normal matrix `Σ_k r_k r_kᵀ` over `rows`.
///
/// Rows are visited in slice order (the outermost loop). Each row adds
/// `r[i] * r[j]` to every entry `(i, j)`, starting from `0.0`. No finiteness
/// validation is performed; an empty `rows` yields the zero matrix.
pub fn normal_matrix_4_unweighted_row_outer(rows: &[[f64; 4]]) -> [[f64; 4]; 4] {
    let mut normal = [[0.0_f64; 4]; 4];
    for r in rows {
        for (out_row, &ri) in normal.iter_mut().zip(r) {
            for (entry, &rj) in out_row.iter_mut().zip(r) {
                *entry += ri * rj;
            }
        }
    }
    normal
}
642
/// Multiply a `4 x 4` matrix by a 4-vector.
///
/// Output entry `i` is the dot product of row `i` with `v`, summed strictly
/// left to right (the same expression as [`dot4`]).
pub fn mat4_vec4(m: &[[f64; 4]; 4], v: &[f64; 4]) -> [f64; 4] {
    let dot = |row: &[f64; 4]| row[0] * v[0] + row[1] * v[1] + row[2] * v[2] + row[3] * v[3];
    [dot(&m[0]), dot(&m[1]), dot(&m[2]), dot(&m[3])]
}
651
/// Dot product of two 4-vectors, summed strictly left to right:
/// `((r0·v0 + r1·v1) + r2·v2) + r3·v3`.
pub fn dot4(row: &[f64; 4], v: &[f64; 4]) -> f64 {
    let [r0, r1, r2, r3] = *row;
    let [v0, v1, v2, v3] = *v;
    r0 * v0 + r1 * v1 + r2 * v2 + r3 * v3
}
655
/// Determinant of a `4 x 4` matrix by Laplace (cofactor) expansion along row 0.
///
/// The expansion proceeds in three fixed steps:
/// 1. Form the six `2 x 2` minors of rows 2–3; `mIJ` uses columns `I` and `J`.
/// 2. Combine them with row 1 into the four `3 x 3` cofactors `c0..c3`; `cK`
///    omits column `K`.
/// 3. Combine the cofactors with row 0 using alternating signs.
///
/// No input validation is done; non-finite entries propagate into the result.
pub fn det4_cofactor(a: &[[f64; 4]; 4]) -> f64 {
    let m01 = a[2][0] * a[3][1] - a[2][1] * a[3][0];
    let m02 = a[2][0] * a[3][2] - a[2][2] * a[3][0];
    let m03 = a[2][0] * a[3][3] - a[2][3] * a[3][0];
    let m12 = a[2][1] * a[3][2] - a[2][2] * a[3][1];
    let m13 = a[2][1] * a[3][3] - a[2][3] * a[3][1];
    let m23 = a[2][2] * a[3][3] - a[2][3] * a[3][2];

    let c0 = a[1][1] * m23 - a[1][2] * m13 + a[1][3] * m12;
    let c1 = a[1][0] * m23 - a[1][2] * m03 + a[1][3] * m02;
    let c2 = a[1][0] * m13 - a[1][1] * m03 + a[1][3] * m01;
    let c3 = a[1][0] * m12 - a[1][1] * m02 + a[1][2] * m01;

    a[0][0] * c0 - a[0][1] * c1 + a[0][2] * c2 - a[0][3] * c3
}
671
/// Determinant of the `3 x 3` minor of `a` obtained by deleting row `skip_r`
/// and column `skip_c`.
///
/// The remaining rows and columns keep their original order. The `3 x 3`
/// determinant is expanded along its first row with a fixed op order.
///
/// # Panics
///
/// Panics if `skip_r` or `skip_c` is not in `0..4`.
pub fn minor3_of_4(a: &[[f64; 4]; 4], skip_r: usize, skip_c: usize) -> f64 {
    // An out-of-range skip would otherwise surface as an opaque slice-index
    // panic deep in the index collection below.
    assert!(
        skip_r < 4 && skip_c < 4,
        "minor3_of_4: skip indices must be < 4 (got row {skip_r}, col {skip_c})"
    );

    // The three indices of `0..4` other than `skip`, in ascending order.
    fn kept(skip: usize) -> [usize; 3] {
        let mut out = [0_usize; 3];
        let mut next = 0;
        for idx in 0..4 {
            if idx != skip {
                out[next] = idx;
                next += 1;
            }
        }
        out
    }
    let rows = kept(skip_r);
    let cols = kept(skip_c);

    let b00 = a[rows[0]][cols[0]];
    let b01 = a[rows[0]][cols[1]];
    let b02 = a[rows[0]][cols[2]];
    let b10 = a[rows[1]][cols[0]];
    let b11 = a[rows[1]][cols[1]];
    let b12 = a[rows[1]][cols[2]];
    let b20 = a[rows[2]][cols[0]];
    let b21 = a[rows[2]][cols[1]];
    let b22 = a[rows[2]][cols[2]];

    b00 * (b11 * b22 - b12 * b21) - b01 * (b10 * b22 - b12 * b20) + b02 * (b10 * b21 - b11 * b20)
}
702
/// Invert a `4 x 4` matrix via the adjugate:
/// `inv[j][i] = (-1)^(i+j) · M_ij / det`.
///
/// Here `M_ij` is `minor3_of_4(a, i, j)` and `det` is `det4_cofactor(a)`.
/// Returns `None` if `det` is exactly `0.0` or non-finite, or if any inverse
/// entry is non-finite.
///
/// NOTE(review): unlike [`invert_3x3_adjugate`], which rejects
/// `|det| <= PIVOT_EPSILON`, this only rejects an exactly zero determinant. A
/// nearly singular matrix can therefore yield a huge but finite inverse.
/// Confirm that the difference is intended before changing either threshold.
#[allow(clippy::needless_range_loop)]
pub fn invert_4x4_cofactor(a: &[[f64; 4]; 4]) -> Option<[[f64; 4]; 4]> {
    let det = det4_cofactor(a);
    if det == 0.0 || !det.is_finite() {
        return None;
    }

    let mut inv = [[0.0_f64; 4]; 4];
    for j in 0..4 {
        for i in 0..4 {
            // Transposed placement (`inv[j][i]`) turns the cofactor matrix into
            // the adjugate. The sign is applied before dividing by `det`.
            let sign = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
            inv[j][i] = sign * minor3_of_4(a, i, j) / det;
        }
    }
    if inv.iter().flatten().any(|value| !value.is_finite()) {
        return None;
    }
    Some(inv)
}
722
/// Invert a `3 x 3` matrix via the adjugate (closed form): `adj(M) / det(M)`.
///
/// The determinant is expanded along the first row. Each adjugate entry is
/// multiplied by the precomputed `1 / det`, not divided by `det`.
///
/// Returns `None` if `|det| <= PIVOT_EPSILON`, if `det` is non-finite (which
/// includes non-finite inputs that reach it), or if any inverse entry is
/// non-finite.
pub fn invert_3x3_adjugate(m: &[[f64; 3]; 3]) -> Option<[[f64; 3]; 3]> {
    let [[a, b, c], [d, e, f], [g, h, i]] = *m;
    let det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
    if det.abs() <= PIVOT_EPSILON || !det.is_finite() {
        return None;
    }
    let inv_det = 1.0 / det;
    // Row `r` of the inverse holds the cofactors of column `r` of `m`
    // (transposed cofactor matrix), each scaled by `1 / det`.
    let inverse = [
        [
            (e * i - f * h) * inv_det,
            (c * h - b * i) * inv_det,
            (b * f - c * e) * inv_det,
        ],
        [
            (f * g - d * i) * inv_det,
            (a * i - c * g) * inv_det,
            (c * d - a * f) * inv_det,
        ],
        [
            (d * h - e * g) * inv_det,
            (b * g - a * h) * inv_det,
            (a * e - b * d) * inv_det,
        ],
    ];
    if inverse.iter().flatten().any(|value| !value.is_finite()) {
        return None;
    }
    Some(inverse)
}
752
/// Invert a symmetric positive-definite matrix `n` (given as rows) via its
/// Cholesky factor.
///
/// The computation has three steps:
/// 1. Factorize `N = L Lᵀ`, using the same `i`/`j`/`k` order as
///    [`solve_flat_normal_square_root_into`].
/// 2. Invert the lower-triangular `L` by forward substitution.
/// 3. Form `N⁻¹ = (L⁻¹)ᵀ L⁻¹`.
///
/// Returns `None` in any of these cases:
/// - the matrix is empty, ragged, or non-square;
/// - it contains a non-finite entry;
/// - it is not symmetric within a scale-relative tolerance;
/// - it is not positive definite (a non-positive or non-finite Cholesky pivot);
/// - the inverse overflows to a non-finite value.
#[allow(clippy::needless_range_loop)]
pub fn invert_symmetric_pd(n: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let p = n.len();
    if p == 0 {
        return None;
    }
    for row in n {
        if row.len() != p {
            return None;
        }
        validate::finite_slice(row, "matrix").ok()?;
    }
    validate_rows_symmetric(n)?;
    // Cholesky factor L (lower triangle only; the upper triangle stays 0.0).
    let mut l = vec![vec![0.0_f64; p]; p];
    for i in 0..p {
        for j in 0..=i {
            let mut s = n[i][j];
            for k in 0..j {
                s -= l[i][k] * l[j][k];
            }
            if i == j {
                // `!(s > 0.0)` is true for zero, negatives and NaN alike.
                #[allow(clippy::neg_cmp_op_on_partial_ord)]
                let nonpositive_or_nan = !(s > 0.0);
                if nonpositive_or_nan || !s.is_finite() {
                    return None;
                }
                l[i][j] = s.sqrt();
            } else {
                l[i][j] = s / l[j][j];
            }
        }
    }

    // L⁻¹ (lower-triangular) by forward substitution. `li[k][j]` is zero for
    // `k < j`, so the sum only runs over `k in j..i`.
    let mut li = vec![vec![0.0_f64; p]; p];
    for i in 0..p {
        li[i][i] = 1.0 / l[i][i];
        for j in 0..i {
            let mut s = 0.0_f64;
            for k in j..i {
                s -= l[i][k] * li[k][j];
            }
            li[i][j] = s / l[i][i];
        }
    }

    // N⁻¹ = (L⁻¹)ᵀ L⁻¹, accumulated over `k` in ascending order.
    let mut inv = vec![vec![0.0_f64; p]; p];
    for i in 0..p {
        for j in 0..p {
            let mut s = 0.0_f64;
            for k in 0..p {
                s += li[k][i] * li[k][j];
            }
            inv[i][j] = s;
        }
    }
    // A tiny but valid pivot can still make the inverse overflow.
    for row in &inv {
        validate::finite_slice(row, "inverse").ok()?;
    }
    Some(inv)
}
813
// Unit tests covering four areas:
// - exact-bit goldens;
// - shape and finiteness rejection;
// - singular and non-SPD rejection;
// - agreement between the solver variants.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn first_tie_solver_inverts_known_matrix() {
        // Exact inverse is [[0.6, -0.7], [-0.2, 0.4]]; the goldens pin the
        // rounding produced by the first-tie op order.
        let a = vec![vec![4.0, 7.0], vec![2.0, 6.0]];
        let inv = invert_matrix_first_tie(&a).unwrap();
        assert_eq!(inv[0][0].to_bits(), 0.6000000000000001f64.to_bits());
        assert_eq!(inv[0][1].to_bits(), (-0.7000000000000001f64).to_bits());
        assert_eq!(inv[1][0].to_bits(), (-0.2f64).to_bits());
        assert_eq!(inv[1][1].to_bits(), 0.4f64.to_bits());
    }

    #[test]
    fn dense_solvers_reject_nonfinite_and_bad_shapes() {
        let good_rhs = [1.0, 2.0];
        let ragged = vec![vec![1.0], vec![0.0, 1.0]];
        assert!(solve_linear_first_tie(&ragged, &good_rhs).is_none());
        assert!(solve_linear_last_tie(ragged, good_rhs.to_vec()).is_none());

        let nonfinite_matrix = vec![vec![1.0, f64::NAN], vec![0.0, 1.0]];
        assert!(solve_linear_first_tie(&nonfinite_matrix, &good_rhs).is_none());
        assert!(solve_linear_last_tie(nonfinite_matrix, good_rhs.to_vec()).is_none());

        let good_matrix = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        assert!(solve_linear_first_tie(&good_matrix, &[1.0, f64::INFINITY]).is_none());
        assert!(solve_linear_first_tie(&[], &[]).is_none());
        assert!(invert_matrix_first_tie(&[]).is_none());
    }

    #[test]
    fn weighted_column_outer_rejects_short_weights() {
        let rows = [[1.0, 2.0, 3.0, 4.0], [2.0, 0.0, -1.0, 1.0]];
        assert_eq!(
            normal_matrix_4_weighted_column_outer(&rows, &[0.5]),
            Err(LinearError::InvalidInput {
                field: "weights",
                reason: "length must match rows"
            })
        );
    }

    #[test]
    fn weighted_column_outer_accumulates_valid_inputs() {
        // Expected entries are Σ_k w_k r_k[i] r_k[j], all exactly representable.
        let rows = [[1.0, 2.0, 3.0, 4.0], [2.0, 0.0, -1.0, 1.0]];
        let weights = [0.5, 2.0];
        assert_eq!(
            normal_matrix_4_weighted_column_outer(&rows, &weights).unwrap(),
            [
                [8.5, 1.0, -2.5, 6.0],
                [1.0, 2.0, 3.0, 4.0],
                [-2.5, 3.0, 6.5, 4.0],
                [6.0, 4.0, 4.0, 10.0],
            ]
        );
    }

    #[test]
    fn transpose_rejects_empty_ragged_and_nonfinite_matrices() {
        assert!(transpose(&[]).is_none());
        assert!(transpose(&[vec![1.0], vec![]]).is_none());
        assert!(transpose(&[vec![f64::INFINITY]]).is_none());
    }

    #[test]
    fn normal_equations_reject_malformed_or_nonfinite_rows() {
        let short = [1.0];
        assert!(normal_equations_weighted([(short.as_slice(), 1.0, 1.0)], 2).is_none());

        let nonfinite_row = [1.0, f64::NAN];
        assert!(normal_equations_weighted([(nonfinite_row.as_slice(), 1.0, 1.0)], 2).is_none());

        let good_row = [1.0, 2.0];
        assert!(normal_equations_weighted([(good_row.as_slice(), f64::NAN, 1.0)], 2).is_none());
        assert!(
            normal_equations_weighted([(good_row.as_slice(), 1.0, f64::INFINITY)], 2).is_none()
        );
    }

    #[test]
    fn flat_solvers_reject_nonfinite_inputs() {
        let mut out = Vec::new();
        let mut scratch = FlatLinearScratch::default();
        assert!(invert_flat_first_tie_into(&[f64::NAN], 1, &mut out, &mut scratch).is_none());

        assert!(solve_flat_normal_first_tie(&[f64::NAN], &[1.0]).is_none());
        assert!(solve_flat_normal_first_tie(&[1.0], &[f64::INFINITY]).is_none());

        let mut cholesky = FlatCholeskySolveScratch::default();
        assert!(solve_flat_normal_square_root_into(&[1.0], &[f64::NAN], &mut cholesky).is_none());
    }

    #[test]
    fn flat_normal_solver_reports_singular() {
        // Second row is twice the first, so the second pivot vanishes.
        assert!(solve_flat_normal_first_tie(&[1.0, 2.0, 2.0, 4.0], &[1.0, 2.0]).is_none());
    }

    #[test]
    fn cofactor_inverse_rejects_singular_4x4() {
        let a = [[0.0; 4]; 4];
        assert!(invert_4x4_cofactor(&a).is_none());
    }

    #[test]
    fn cholesky_square_root_solves_spd_system() {
        // Λ = [[4, 12, -16], [12, 37, -43], [-16, -43, 98]] (the classic SPD
        // Cholesky example), η chosen so the exact solution is [1, 2, 3].
        let lambda = [
            4.0, 12.0, -16.0, //
            12.0, 37.0, -43.0, //
            -16.0, -43.0, 98.0,
        ];
        let eta = [
            4.0 * 1.0 + 12.0 * 2.0 - 16.0 * 3.0,
            12.0 * 1.0 + 37.0 * 2.0 - 43.0 * 3.0,
            -16.0 * 1.0 - 43.0 * 2.0 + 98.0 * 3.0,
        ];
        let mut scratch = FlatCholeskySolveScratch::default();
        let x = solve_flat_normal_square_root_into(&lambda, &eta, &mut scratch).unwrap();
        for (got, want) in x.iter().zip([1.0_f64, 2.0, 3.0]) {
            assert!((got - want).abs() < 1.0e-12, "got {got}, want {want}");
        }
    }

    #[test]
    fn cholesky_square_root_agrees_with_first_tie_to_roundoff() {
        // The square-root solve and the first-tie Gaussian solve of the same SPD
        // system must agree to roundoff: they differ only in factorization order.
        let lambda = [
            6.0, 2.0, 1.0, //
            2.0, 5.0, 2.0, //
            1.0, 2.0, 4.0,
        ];
        let eta = [9.0, 9.0, 7.0];
        let mut sqrt_scratch = FlatCholeskySolveScratch::default();
        let sqrt_x = solve_flat_normal_square_root_into(&lambda, &eta, &mut sqrt_scratch)
            .unwrap()
            .to_vec();
        let first_tie_x = solve_flat_normal_first_tie(&lambda, &eta).unwrap();
        for (s, f) in sqrt_x.iter().zip(&first_tie_x) {
            assert!((s - f).abs() < 1.0e-12, "square-root {s} vs first-tie {f}");
        }
    }

    #[test]
    fn cholesky_square_root_frozen_bits() {
        // Frozen-bits golden on an exactly-representable SPD system
        // (Λ = L Lᵀ with L = [[2,0,0],[1,2,0],[0,0,1]]), so every factor and
        // substitution step is exact in f64 and the solution bits are a portable
        // constant: f64 sqrt is IEEE-754 correctly rounded, so these bits hold
        // across platforms, not merely run-to-run on one build.
        let lambda = [
            4.0, 2.0, 0.0, //
            2.0, 5.0, 0.0, //
            0.0, 0.0, 1.0,
        ];
        // η = Λ·[2, 0.5, 3].
        let eta = [9.0, 6.5, 3.0];
        let mut scratch = FlatCholeskySolveScratch::default();
        let x = solve_flat_normal_square_root_into(&lambda, &eta, &mut scratch).unwrap();
        assert_eq!(x[0].to_bits(), 2.0f64.to_bits());
        assert_eq!(x[1].to_bits(), 0.5f64.to_bits());
        assert_eq!(x[2].to_bits(), 3.0f64.to_bits());
    }

    #[test]
    fn cholesky_square_root_rejects_non_pd() {
        // A singular (rank-deficient) matrix has a non-positive Cholesky pivot.
        assert!(solve_flat_normal_square_root_into(
            &[1.0, 2.0, 2.0, 4.0],
            &[1.0, 2.0],
            &mut Default::default()
        )
        .is_none());
    }

    #[test]
    fn cholesky_square_root_rejects_invalid_information_geometry() {
        let eta = [1.0, 2.0];
        let mut scratch = FlatCholeskySolveScratch::default();

        let negative_variance = [-1.0, 0.0, 0.0, 1.0];
        assert!(
            solve_flat_normal_square_root_into(&negative_variance, &eta, &mut scratch).is_none()
        );

        // Rejected by the symmetry check before factorization.
        let asymmetric = [1.0, 0.5, 0.0, 1.0];
        assert!(solve_flat_normal_square_root_into(&asymmetric, &eta, &mut scratch).is_none());

        // Symmetric but indefinite (eigenvalues 3 and -1).
        let indefinite = [1.0, 2.0, 2.0, 1.0];
        assert!(solve_flat_normal_square_root_into(&indefinite, &eta, &mut scratch).is_none());
    }

    #[test]
    fn symmetric_pd_inverse_rejects_invalid_matrix_geometry() {
        let negative_variance = vec![vec![-1.0, 0.0], vec![0.0, 1.0]];
        assert!(invert_symmetric_pd(&negative_variance).is_none());

        let asymmetric = vec![vec![1.0, 0.5], vec![0.0, 1.0]];
        assert!(invert_symmetric_pd(&asymmetric).is_none());

        let indefinite = vec![vec![1.0, 2.0], vec![2.0, 1.0]];
        assert!(invert_symmetric_pd(&indefinite).is_none());

        // Smallest subnormal: valid PD pivot, but its inverse overflows to inf.
        let overflow_inverse = vec![vec![f64::from_bits(1)]];
        assert!(invert_symmetric_pd(&overflow_inverse).is_none());
    }
}
1022}