1use crate::astro::tolerances::PIVOT_EPSILON;
8use crate::validate;
9
/// Reusable work buffers for the flat (row-major) first-tie solvers.
///
/// Passed to [`invert_flat_first_tie_into`] and
/// [`solve_matrix_flat_first_tie_into`], which resize both buffers on every
/// call, so a `Default` instance can be shared across systems of any size.
#[derive(Debug, Default, Clone)]
pub struct FlatLinearScratch {
    /// Augmented system `[A | b]`, stored row-major with `n + 1` entries per row.
    rows: Vec<f64>,
    /// Solution of the most recently solved right-hand side (`n` entries).
    x: Vec<f64>,
}
15
/// Reusable work buffers for [`solve_flat_normal_first_tie_into`].
#[derive(Debug, Default, Clone)]
pub struct FlatNormalSolveScratch {
    /// Working copy of the `n x n` row-major matrix; eliminated in place.
    a: Vec<f64>,
    /// Working copy of the right-hand side; updated during elimination.
    b: Vec<f64>,
    /// Solution vector; the solver returns a borrow of this buffer.
    x: Vec<f64>,
}
22
/// Error type returned by the `Result`-based linear algebra helpers in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum LinearError {
    /// An input failed validation.
    #[error("invalid linear algebra {field}: {reason}")]
    InvalidInput {
        /// Name of the offending input.
        field: &'static str,
        /// Why the input was rejected.
        reason: &'static str,
    },
}
31
32#[allow(clippy::needless_range_loop)]
33pub fn solve_linear_first_tie(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
34 let n = validate_dense_system(a, b)?;
35 let mut rows: Vec<Vec<f64>> = a
36 .iter()
37 .zip(b)
38 .map(|(row, &bi)| {
39 let mut r = row.clone();
40 r.push(bi);
41 r
42 })
43 .collect();
44
45 for col in 0..n {
46 let mut pivot_row = col;
47 let mut pivot_abs = rows[col][col].abs();
48 for idx in (col + 1)..n {
49 let v = rows[idx][col].abs();
50 if v > pivot_abs {
51 pivot_abs = v;
52 pivot_row = idx;
53 }
54 }
55 if !pivot_abs.is_finite() || pivot_abs <= PIVOT_EPSILON {
56 return None;
57 }
58 rows.swap(col, pivot_row);
59
60 let pivot = rows[col].clone();
61 let pivot_value = pivot[col];
62 for idx in (col + 1)..n {
63 let factor = rows[idx][col] / pivot_value;
64 for j in 0..=n {
65 rows[idx][j] -= factor * pivot[j];
66 }
67 }
68 }
69
70 let mut x = vec![0.0; n];
71 for i in (0..n).rev() {
72 let mut known = 0.0;
73 for j in (i + 1)..n {
74 known += rows[i][j] * x[j];
75 }
76 x[i] = (rows[i][n] - known) / rows[i][i];
77 }
78 validate::finite_slice(&x, "solution").ok()?;
79 Some(x)
80}
81
82#[allow(clippy::needless_range_loop)]
83pub fn solve_linear_last_tie(mut a: Vec<Vec<f64>>, b: Vec<f64>) -> Option<Vec<f64>> {
84 let n = validate_dense_system(&a, &b)?;
85 for (row, bi) in a.iter_mut().zip(b) {
86 row.push(bi);
87 }
88 for col in 0..n {
89 let (pivot_row, pivot_abs) = (col..n)
90 .map(|idx| (idx, a[idx][col].abs()))
91 .max_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1))
92 .unwrap();
93 if !pivot_abs.is_finite() || pivot_abs <= PIVOT_EPSILON {
94 return None;
95 }
96 a.swap(col, pivot_row);
97 let pivot = a[col].clone();
98 let pivot_value = pivot[col];
99 for row in a.iter_mut().take(n).skip(col + 1) {
100 let factor = row[col] / pivot_value;
101 for j in col..=n {
102 row[j] -= factor * pivot[j];
103 }
104 }
105 }
106 let mut x = vec![0.0; n];
107 for i in (0..n).rev() {
108 let tail_sum: f64 = ((i + 1)..n).map(|j| a[i][j] * x[j]).sum();
109 x[i] = (a[i][n] - tail_sum) / a[i][i];
110 }
111 validate::finite_slice(&x, "solution").ok()?;
112 Some(x)
113}
114
115pub fn invert_matrix_first_tie(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
116 let n = a.len();
117 if n == 0 {
118 return None;
119 }
120 let mut columns: Vec<Vec<f64>> = Vec::with_capacity(n);
121 for col in 0..n {
122 let mut e = vec![0.0; n];
123 e[col] = 1.0;
124 columns.push(solve_linear_first_tie(a, &e)?);
125 }
126 Some(
127 (0..n)
128 .map(|i| (0..n).map(|j| columns[j][i]).collect())
129 .collect(),
130 )
131}
132
133pub fn invert_matrix_last_tie(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
134 let n = a.len();
135 let mut columns = Vec::with_capacity(n);
136 for col in 0..n {
137 let unit = (0..n)
138 .map(|idx| if idx == col { 1.0 } else { 0.0 })
139 .collect();
140 columns.push(solve_linear_last_tie(a.to_vec(), unit)?);
141 }
142 transpose(&columns)
143}
144
145pub fn solve_matrix_last_tie(a: &[Vec<f64>], b: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
146 let columns = transpose(b)?;
147 let mut solved_columns = Vec::with_capacity(columns.len());
148 for col in columns {
149 solved_columns.push(solve_linear_last_tie(a.to_vec(), col)?);
150 }
151 transpose(&solved_columns)
152}
153
154pub fn normal_equations_weighted<'a, I>(rows: I, n: usize) -> Option<(Vec<Vec<f64>>, Vec<f64>)>
155where
156 I: IntoIterator<Item = (&'a [f64], f64, f64)>,
157{
158 if n == 0 {
159 return None;
160 }
161 let mut ata = vec![vec![0.0; n]; n];
162 let mut aty = vec![0.0; n];
163 for (row_h, row_y, row_weight) in rows {
164 if row_h.len() != n {
165 return None;
166 }
167 validate::finite_slice(row_h, "normal row").ok()?;
168 validate::finite(row_y, "normal residual").ok()?;
169 validate::finite(row_weight, "normal weight").ok()?;
170 let h: Vec<f64> = row_h.iter().map(|v| v * row_weight).collect();
171 let y = row_y * row_weight;
172 for i in 0..n {
173 aty[i] += h[i] * y;
174 for j in 0..n {
175 ata[i][j] += h[i] * h[j];
176 }
177 }
178 }
179 for row in &ata {
180 validate::finite_slice(row, "normal matrix").ok()?;
181 }
182 validate::finite_slice(&aty, "normal rhs").ok()?;
183 Some((ata, aty))
184}
185
186pub fn matrix_sub(a: &[Vec<f64>], b: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
187 let (rows, cols) = validate_same_shape(a, b)?;
188 let out: Vec<Vec<f64>> = a
189 .iter()
190 .zip(b)
191 .map(|(row_a, row_b)| row_a.iter().zip(row_b).map(|(x, y)| x - y).collect())
192 .collect();
193 debug_assert_eq!(out.len(), rows);
194 debug_assert!(out.iter().all(|row| row.len() == cols));
195 for row in &out {
196 validate::finite_slice(row, "matrix difference").ok()?;
197 }
198 Some(out)
199}
200
201pub fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
202 let b_t = transpose(b)?;
203 let rows = a.len();
204 let shared = b_t.first()?.len();
205 if rows == 0 || shared == 0 {
206 return None;
207 }
208 for row in a {
209 if row.len() != shared {
210 return None;
211 }
212 validate::finite_slice(row, "matrix").ok()?;
213 }
214 let out: Vec<Vec<f64>> = a
215 .iter()
216 .map(|row| {
217 b_t.iter()
218 .map(|col| row.iter().zip(col).fold(0.0, |acc, (x, y)| acc + x * y))
219 .collect()
220 })
221 .collect();
222 for row in &out {
223 validate::finite_slice(row, "matrix product").ok()?;
224 }
225 Some(out)
226}
227
228pub fn transpose(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
229 let cols = matrix.first()?.len();
230 if cols == 0 {
231 return None;
232 }
233 for row in matrix {
234 if row.len() != cols {
235 return None;
236 }
237 validate::finite_slice(row, "matrix").ok()?;
238 }
239 Some(
240 (0..cols)
241 .map(|col| matrix.iter().map(|row| row[col]).collect())
242 .collect(),
243 )
244}
245
246pub fn invert_flat_first_tie_into(
247 a: &[f64],
248 n: usize,
249 out: &mut Vec<f64>,
250 scratch: &mut FlatLinearScratch,
251) -> Option<()> {
252 validate_flat_square(a, n, "matrix")?;
253 out.resize(n * n, 0.0);
254 scratch.rows.resize(n * (n + 1), 0.0);
255 scratch.x.resize(n, 0.0);
256
257 for col in 0..n {
258 for i in 0..n {
259 let src = i * n;
260 let dst = i * (n + 1);
261 scratch.rows[dst..(dst + n)].copy_from_slice(&a[src..(src + n)]);
262 scratch.rows[dst + n] = if i == col { 1.0 } else { 0.0 };
263 }
264 solve_augmented_flat_first_tie_in_place(&mut scratch.rows, n, &mut scratch.x)?;
265 for i in 0..n {
266 out[i * n + col] = scratch.x[i];
267 }
268 }
269
270 Some(())
271}
272
273pub fn solve_matrix_flat_first_tie_into(
274 a: &[f64],
275 n: usize,
276 b: &[f64],
277 cols: usize,
278 out: &mut Vec<f64>,
279 scratch: &mut FlatLinearScratch,
280) -> Option<()> {
281 validate_flat_square(a, n, "matrix")?;
282 if cols == 0 || b.len() != n.checked_mul(cols)? {
283 return None;
284 }
285 validate::finite_slice(b, "rhs").ok()?;
286 out.resize(n.checked_mul(cols)?, 0.0);
287 scratch.rows.resize(n * (n + 1), 0.0);
288 scratch.x.resize(n, 0.0);
289
290 for col in 0..cols {
291 for i in 0..n {
292 let src = i * n;
293 let dst = i * (n + 1);
294 scratch.rows[dst..(dst + n)].copy_from_slice(&a[src..(src + n)]);
295 scratch.rows[dst + n] = b[i * cols + col];
296 }
297 solve_augmented_flat_first_tie_in_place(&mut scratch.rows, n, &mut scratch.x)?;
298 for i in 0..n {
299 out[i * cols + col] = scratch.x[i];
300 }
301 }
302 Some(())
303}
304
305#[allow(clippy::needless_range_loop)]
306pub fn solve_augmented_flat_first_tie_in_place(
307 rows: &mut [f64],
308 n: usize,
309 x: &mut [f64],
310) -> Option<()> {
311 let stride = n + 1;
312 if n == 0 || rows.len() != n.checked_mul(stride)? || x.len() != n {
313 return None;
314 }
315 validate::finite_slice(rows, "augmented matrix").ok()?;
316
317 for col in 0..n {
318 let mut pivot_row = col;
319 let mut pivot_abs = rows[col * stride + col].abs();
320 for idx in (col + 1)..n {
321 let v = rows[idx * stride + col].abs();
322 if v > pivot_abs {
323 pivot_abs = v;
324 pivot_row = idx;
325 }
326 }
327 if !pivot_abs.is_finite() || pivot_abs <= PIVOT_EPSILON {
328 return None;
329 }
330 if pivot_row != col {
331 for j in 0..=n {
332 rows.swap(col * stride + j, pivot_row * stride + j);
333 }
334 }
335
336 let pivot_value = rows[col * stride + col];
337 for idx in (col + 1)..n {
338 let factor = rows[idx * stride + col] / pivot_value;
339 for j in 0..=n {
340 rows[idx * stride + j] -= factor * rows[col * stride + j];
341 }
342 }
343 }
344
345 for i in (0..n).rev() {
346 let mut known = 0.0;
347 for j in (i + 1)..n {
348 known += rows[i * stride + j] * x[j];
349 }
350 x[i] = (rows[i * stride + n] - known) / rows[i * stride + i];
351 }
352
353 validate::finite_slice(x, "solution").ok()?;
354 Some(())
355}
356
357pub fn solve_flat_normal_first_tie(lambda: &[f64], eta: &[f64]) -> Option<Vec<f64>> {
358 let mut scratch = FlatNormalSolveScratch::default();
359 solve_flat_normal_first_tie_into(lambda, eta, &mut scratch).map(<[f64]>::to_vec)
360}
361
/// Solves the row-major `n x n` system `lambda * x = eta` by Gaussian
/// elimination with partial pivoting, using `scratch` for all working storage.
///
/// `n` is taken from `eta.len()`. When several candidate rows share the
/// largest pivot magnitude, the first (lowest-index) one is used. The inputs
/// are copied into `scratch`, so `lambda` and `eta` are left untouched.
///
/// Returns a borrow of the solution stored in `scratch`, or `None` when
/// `eta` is empty, when `lambda.len() != n * n`, when either input fails the
/// finite check, when a pivot magnitude is non-finite or not greater than
/// `PIVOT_EPSILON`, or when the solution fails the finite check.
#[allow(clippy::needless_range_loop)]
pub fn solve_flat_normal_first_tie_into<'a>(
    lambda: &[f64],
    eta: &[f64],
    scratch: &'a mut FlatNormalSolveScratch,
) -> Option<&'a [f64]> {
    let n = eta.len();
    if n == 0 || lambda.len() != n.checked_mul(n)? {
        return None;
    }
    validate::finite_slice(lambda, "normal matrix").ok()?;
    validate::finite_slice(eta, "normal rhs").ok()?;

    // Work on copies so the caller's matrix and right-hand side stay intact.
    scratch.a.resize(n * n, 0.0);
    scratch.a.copy_from_slice(lambda);
    scratch.b.resize(n, 0.0);
    scratch.b.copy_from_slice(eta);

    for k in 0..n {
        // Partial pivoting; strict `>` keeps the earliest row among ties.
        let mut pivot = k;
        let mut pivot_abs = scratch.a[k * n + k].abs();
        for i in (k + 1)..n {
            let candidate = scratch.a[i * n + k].abs();
            if candidate > pivot_abs {
                pivot = i;
                pivot_abs = candidate;
            }
        }
        if !pivot_abs.is_finite() || pivot_abs <= PIVOT_EPSILON {
            return None;
        }
        if pivot != k {
            // Swap the full matrix rows and the matching right-hand-side entries.
            for j in 0..n {
                scratch.a.swap(k * n + j, pivot * n + j);
            }
            scratch.b.swap(k, pivot);
        }

        let diag = scratch.a[k * n + k];
        for i in (k + 1)..n {
            let factor = scratch.a[i * n + k] / diag;
            // The eliminated entry is set to exactly zero rather than computed.
            scratch.a[i * n + k] = 0.0;
            for j in (k + 1)..n {
                scratch.a[i * n + j] -= factor * scratch.a[k * n + j];
            }
            scratch.b[i] -= factor * scratch.b[k];
        }
    }

    // Back substitution on the upper-triangular system.
    scratch.x.resize(n, 0.0);
    for i in (0..n).rev() {
        let mut known = 0.0;
        for j in (i + 1)..n {
            known += scratch.a[i * n + j] * scratch.x[j];
        }
        scratch.x[i] = (scratch.b[i] - known) / scratch.a[i * n + i];
    }
    validate::finite_slice(&scratch.x, "solution").ok()?;
    Some(&scratch.x)
}
422
/// Reusable work buffers for [`solve_flat_normal_square_root_into`].
#[derive(Debug, Default, Clone)]
pub struct FlatCholeskySolveScratch {
    /// Lower-triangular Cholesky factor, stored row-major as a full `n x n` block.
    l: Vec<f64>,
    /// Intermediate vector produced by the forward substitution.
    z: Vec<f64>,
    /// Solution vector; the solver returns a borrow of this buffer.
    x: Vec<f64>,
}
433
/// Solves the row-major `n x n` system `lambda * x = eta` through a Cholesky
/// factorization `lambda = L * L^T`, using `scratch` for all working storage.
///
/// `n` is taken from `eta.len()`. The factor `L` is built row by row, then
/// `L * z = eta` is solved by forward substitution and `L^T * x = z` by back
/// substitution.
///
/// Returns a borrow of the solution stored in `scratch`, or `None` when
/// `eta` is empty, when `lambda.len() != n * n`, when either input fails the
/// finite check, when `lambda` is not symmetric within the tolerance used by
/// [`validate_flat_symmetric`], when a diagonal term of the factorization is
/// not a finite positive number, or when `z` or `x` fails the finite check.
#[allow(clippy::needless_range_loop)]
pub fn solve_flat_normal_square_root_into<'a>(
    lambda: &[f64],
    eta: &[f64],
    scratch: &'a mut FlatCholeskySolveScratch,
) -> Option<&'a [f64]> {
    let n = eta.len();
    if n == 0 || lambda.len() != n.checked_mul(n)? {
        return None;
    }
    validate::finite_slice(lambda, "normal matrix").ok()?;
    validate::finite_slice(eta, "normal rhs").ok()?;
    validate_flat_symmetric(lambda, n)?;
    // Clear the factor: the strict upper triangle is never written below.
    scratch.l.resize(n * n, 0.0);
    scratch.l.fill(0.0);

    // Cholesky factorization; only the lower triangle of `lambda` is read.
    for i in 0..n {
        for j in 0..=i {
            let mut s = lambda[i * n + j];
            for k in 0..j {
                s -= scratch.l[i * n + k] * scratch.l[j * n + k];
            }
            if i == j {
                // `!(s > 0.0)` is also true for NaN, which `s <= 0.0` would miss.
                #[allow(clippy::neg_cmp_op_on_partial_ord)]
                let nonpositive_or_nan = !(s > 0.0);
                if nonpositive_or_nan || !s.is_finite() {
                    return None;
                }
                scratch.l[i * n + j] = s.sqrt();
            } else {
                scratch.l[i * n + j] = s / scratch.l[j * n + j];
            }
        }
    }

    // Forward substitution: L * z = eta.
    scratch.z.resize(n, 0.0);
    for i in 0..n {
        let mut s = eta[i];
        for k in 0..i {
            s -= scratch.l[i * n + k] * scratch.z[k];
        }
        scratch.z[i] = s / scratch.l[i * n + i];
    }
    validate::finite_slice(&scratch.z, "solution work vector").ok()?;

    // Back substitution: L^T * x = z (column `i` of `L` is row `i` of `L^T`).
    scratch.x.resize(n, 0.0);
    for i in (0..n).rev() {
        let mut s = scratch.z[i];
        for k in (i + 1)..n {
            s -= scratch.l[k * n + i] * scratch.x[k];
        }
        scratch.x[i] = s / scratch.l[i * n + i];
    }
    validate::finite_slice(&scratch.x, "solution").ok()?;
    Some(scratch.x.as_slice())
}
506
507fn validate_flat_symmetric(matrix: &[f64], n: usize) -> Option<()> {
508 let mut scale = 1.0_f64;
509 for value in matrix {
510 scale = scale.max(value.abs());
511 }
512 let tol = symmetry_tolerance(n, scale);
513 for i in 0..n {
514 for j in (i + 1)..n {
515 if (matrix[i * n + j] - matrix[j * n + i]).abs() > tol {
516 return None;
517 }
518 }
519 }
520 Some(())
521}
522
523#[allow(clippy::needless_range_loop)]
524fn validate_rows_symmetric(matrix: &[Vec<f64>]) -> Option<()> {
525 let n = matrix.len();
526 let mut scale = 1.0_f64;
527 for row in matrix {
528 for value in row {
529 scale = scale.max(value.abs());
530 }
531 }
532 let tol = symmetry_tolerance(n, scale);
533 for i in 0..n {
534 for j in (i + 1)..n {
535 if (matrix[i][j] - matrix[j][i]).abs() > tol {
536 return None;
537 }
538 }
539 }
540 Some(())
541}
542
/// Absolute tolerance used by the symmetry checks.
///
/// Computed as `128 * f64::EPSILON`, multiplied by the dimension (at least 1)
/// and then by the magnitude scale (at least 1).
fn symmetry_tolerance(n: usize, scale: f64) -> f64 {
    let unit_slack = 128.0 * f64::EPSILON;
    let dimension = n.max(1) as f64;
    let magnitude = scale.max(1.0);
    // Multiplication order is kept left to right: slack, dimension, magnitude.
    unit_slack * dimension * magnitude
}
546
547fn validate_dense_system(a: &[Vec<f64>], b: &[f64]) -> Option<usize> {
548 let n = b.len();
549 if n == 0 || a.len() != n {
550 return None;
551 }
552 validate::finite_slice(b, "rhs").ok()?;
553 for row in a {
554 if row.len() != n {
555 return None;
556 }
557 validate::finite_slice(row, "matrix").ok()?;
558 }
559 Some(n)
560}
561
562fn validate_same_shape(a: &[Vec<f64>], b: &[Vec<f64>]) -> Option<(usize, usize)> {
563 let rows = a.len();
564 if rows == 0 || b.len() != rows {
565 return None;
566 }
567 let cols = a.first()?.len();
568 if cols == 0 {
569 return None;
570 }
571 for row in a {
572 if row.len() != cols {
573 return None;
574 }
575 validate::finite_slice(row, "matrix").ok()?;
576 }
577 for row in b {
578 if row.len() != cols {
579 return None;
580 }
581 validate::finite_slice(row, "matrix").ok()?;
582 }
583 Some((rows, cols))
584}
585
586fn validate_flat_square(a: &[f64], n: usize, field: &'static str) -> Option<()> {
587 if n == 0 || a.len() != n.checked_mul(n)? {
588 return None;
589 }
590 validate::finite_slice(a, field).ok()
591}
592
593fn map_linear_field_error(error: validate::FieldError) -> LinearError {
594 linear_invalid_input(error.field(), error.reason())
595}
596
/// Builds a [`LinearError::InvalidInput`] from a field name and a reason.
fn linear_invalid_input(field: &'static str, reason: &'static str) -> LinearError {
    LinearError::InvalidInput { field, reason }
}
600
601#[allow(clippy::needless_range_loop)]
602pub fn normal_matrix_4_weighted_column_outer(
603 rows: &[[f64; 4]],
604 weights: &[f64],
605) -> Result<[[f64; 4]; 4], LinearError> {
606 if weights.len() != rows.len() {
607 return Err(linear_invalid_input("weights", "length must match rows"));
608 }
609 validate::finite_slice(weights, "weights").map_err(map_linear_field_error)?;
610 for row in rows {
611 validate::finite_slice(row, "rows").map_err(map_linear_field_error)?;
612 }
613
614 let mut a = [[0.0_f64; 4]; 4];
615 for i in 0..4 {
616 for j in 0..4 {
617 let mut s = 0.0_f64;
618 for k in 0..rows.len() {
619 s += rows[k][i] * weights[k] * rows[k][j];
620 }
621 a[i][j] = s;
622 }
623 }
624 for row in &a {
625 validate::finite_slice(row, "normal matrix").map_err(map_linear_field_error)?;
626 }
627 Ok(a)
628}
629
/// Builds the unweighted 4x4 normal matrix `sum_k rows[k]^T * rows[k]`.
///
/// Rows are processed in order, each adding its full outer product to the
/// accumulator. No validation is performed; an empty input yields the zero
/// matrix.
pub fn normal_matrix_4_unweighted_row_outer(rows: &[[f64; 4]]) -> [[f64; 4]; 4] {
    let mut gram = [[0.0_f64; 4]; 4];
    for sample in rows {
        for (gram_row, &left) in gram.iter_mut().zip(sample) {
            for (cell, &right) in gram_row.iter_mut().zip(sample) {
                *cell += left * right;
            }
        }
    }
    gram
}
642
643pub fn mat4_vec4(m: &[[f64; 4]; 4], v: &[f64; 4]) -> [f64; 4] {
644 [
645 dot4(&m[0], v),
646 dot4(&m[1], v),
647 dot4(&m[2], v),
648 dot4(&m[3], v),
649 ]
650}
651
/// Dot product of two 4-vectors, summed left to right.
pub fn dot4(row: &[f64; 4], v: &[f64; 4]) -> f64 {
    let [r0, r1, r2, r3] = *row;
    let [v0, v1, v2, v3] = *v;
    r0 * v0 + r1 * v1 + r2 * v2 + r3 * v3
}
655
/// Determinant of a 4x4 matrix by cofactor expansion along the first row.
///
/// The 2x2 minors of the last two rows are computed once and shared by the
/// four 3x3 cofactors.
pub fn det4_cofactor(a: &[[f64; 4]; 4]) -> f64 {
    let [r0, r1, r2, r3] = a;

    // 2x2 minors of rows 2 and 3; `sXY` uses columns X and Y.
    let s01 = r2[0] * r3[1] - r2[1] * r3[0];
    let s02 = r2[0] * r3[2] - r2[2] * r3[0];
    let s03 = r2[0] * r3[3] - r2[3] * r3[0];
    let s12 = r2[1] * r3[2] - r2[2] * r3[1];
    let s13 = r2[1] * r3[3] - r2[3] * r3[1];
    let s23 = r2[2] * r3[3] - r2[3] * r3[2];

    // 3x3 minors of rows 1..=3 with column 0, 1, 2, 3 removed respectively.
    let k0 = r1[1] * s23 - r1[2] * s13 + r1[3] * s12;
    let k1 = r1[0] * s23 - r1[2] * s03 + r1[3] * s02;
    let k2 = r1[0] * s13 - r1[1] * s03 + r1[3] * s01;
    let k3 = r1[0] * s12 - r1[1] * s02 + r1[2] * s01;

    r0[0] * k0 - r0[1] * k1 + r0[2] * k2 - r0[3] * k3
}
671
/// Determinant of the 3x3 submatrix of `a` left after removing row `skip_r`
/// and column `skip_c`.
///
/// # Panics
///
/// Panics when `skip_r` or `skip_c` is not in `0..4`, because all four
/// indices then survive and overflow the three-entry index table.
pub fn minor3_of_4(a: &[[f64; 4]; 4], skip_r: usize, skip_c: usize) -> f64 {
    // Indices `0..4` with `skip` removed, in ascending order.
    fn surviving(skip: usize) -> [usize; 3] {
        let mut kept = [0_usize; 3];
        let mut filled = 0;
        for index in 0..4 {
            if index != skip {
                kept[filled] = index;
                filled += 1;
            }
        }
        kept
    }

    let [r0, r1, r2] = surviving(skip_r);
    let [c0, c1, c2] = surviving(skip_c);

    let [b00, b01, b02] = [a[r0][c0], a[r0][c1], a[r0][c2]];
    let [b10, b11, b12] = [a[r1][c0], a[r1][c1], a[r1][c2]];
    let [b20, b21, b22] = [a[r2][c0], a[r2][c1], a[r2][c2]];

    // Expansion along the first row of the submatrix.
    b00 * (b11 * b22 - b12 * b21) - b01 * (b10 * b22 - b12 * b20) + b02 * (b10 * b21 - b11 * b20)
}
702
703#[allow(clippy::needless_range_loop)]
704pub fn invert_4x4_cofactor(a: &[[f64; 4]; 4]) -> Option<[[f64; 4]; 4]> {
705 let det = det4_cofactor(a);
706 if det == 0.0 || !det.is_finite() {
707 return None;
708 }
709
710 let mut inv = [[0.0_f64; 4]; 4];
711 for j in 0..4 {
712 for i in 0..4 {
713 let sign = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
714 inv[j][i] = sign * minor3_of_4(a, i, j) / det;
715 }
716 }
717 if inv.iter().flatten().any(|value| !value.is_finite()) {
718 return None;
719 }
720 Some(inv)
721}
722
723pub fn invert_3x3_adjugate(m: &[[f64; 3]; 3]) -> Option<[[f64; 3]; 3]> {
724 let [[a, b, c], [d, e, f], [g, h, i]] = *m;
725 let det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
726 if det.abs() <= PIVOT_EPSILON || !det.is_finite() {
727 return None;
728 }
729 let inv_det = 1.0 / det;
730 let inverse = [
731 [
732 (e * i - f * h) * inv_det,
733 (c * h - b * i) * inv_det,
734 (b * f - c * e) * inv_det,
735 ],
736 [
737 (f * g - d * i) * inv_det,
738 (a * i - c * g) * inv_det,
739 (c * d - a * f) * inv_det,
740 ],
741 [
742 (d * h - e * g) * inv_det,
743 (b * g - a * h) * inv_det,
744 (a * e - b * d) * inv_det,
745 ],
746 ];
747 if inverse.iter().flatten().any(|value| !value.is_finite()) {
748 return None;
749 }
750 Some(inverse)
751}
752
/// Inverts a symmetric positive-definite matrix through its Cholesky factor.
///
/// The matrix is factored as `n = L * L^T`, the lower-triangular `L` is
/// inverted by forward substitution, and the result is assembled as
/// `(L^-1)^T * L^-1`.
///
/// Returns `None` when `n` is empty or not square, when any row fails the
/// finite check, when `n` is not symmetric within the tolerance used by
/// [`validate_rows_symmetric`], when a diagonal term of the factorization is
/// not a finite positive number, or when any row of the inverse fails the
/// finite check.
#[allow(clippy::needless_range_loop)]
pub fn invert_symmetric_pd(n: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let p = n.len();
    if p == 0 {
        return None;
    }
    for row in n {
        if row.len() != p {
            return None;
        }
        validate::finite_slice(row, "matrix").ok()?;
    }
    validate_rows_symmetric(n)?;
    // Cholesky factorization; only the lower triangle of `n` is read.
    let mut l = vec![vec![0.0_f64; p]; p];
    for i in 0..p {
        for j in 0..=i {
            let mut s = n[i][j];
            for k in 0..j {
                s -= l[i][k] * l[j][k];
            }
            if i == j {
                // `!(s > 0.0)` is also true for NaN, which `s <= 0.0` would miss.
                #[allow(clippy::neg_cmp_op_on_partial_ord)]
                let nonpositive_or_nan = !(s > 0.0);
                if nonpositive_or_nan || !s.is_finite() {
                    return None;
                }
                l[i][j] = s.sqrt();
            } else {
                l[i][j] = s / l[j][j];
            }
        }
    }

    // Invert the lower-triangular factor row by row: `li = L^-1`.
    let mut li = vec![vec![0.0_f64; p]; p];
    for i in 0..p {
        li[i][i] = 1.0 / l[i][i];
        for j in 0..i {
            let mut s = 0.0_f64;
            for k in j..i {
                s -= l[i][k] * li[k][j];
            }
            li[i][j] = s / l[i][i];
        }
    }

    // Assemble the inverse as `li^T * li`.
    let mut inv = vec![vec![0.0_f64; p]; p];
    for i in 0..p {
        for j in 0..p {
            let mut s = 0.0_f64;
            for k in 0..p {
                s += li[k][i] * li[k][j];
            }
            inv[i][j] = s;
        }
    }
    // A tiny positive pivot can overflow the reciprocal, so check the result.
    for row in &inv {
        validate::finite_slice(row, "inverse").ok()?;
    }
    Some(inv)
}
813
/// Unit tests for the dense, flat and fixed-size linear algebra helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Bit-exact check of a 2x2 inverse computed column by column.
    #[test]
    fn first_tie_solver_inverts_known_matrix() {
        let a = vec![vec![4.0, 7.0], vec![2.0, 6.0]];
        let inv = invert_matrix_first_tie(&a).unwrap();
        assert_eq!(inv[0][0].to_bits(), 0.6000000000000001f64.to_bits());
        assert_eq!(inv[0][1].to_bits(), (-0.7000000000000001f64).to_bits());
        assert_eq!(inv[1][0].to_bits(), (-0.2f64).to_bits());
        assert_eq!(inv[1][1].to_bits(), 0.4f64.to_bits());
    }

    // Ragged, non-finite and empty inputs must all be rejected.
    #[test]
    fn dense_solvers_reject_nonfinite_and_bad_shapes() {
        let good_rhs = [1.0, 2.0];
        let ragged = vec![vec![1.0], vec![0.0, 1.0]];
        assert!(solve_linear_first_tie(&ragged, &good_rhs).is_none());
        assert!(solve_linear_last_tie(ragged, good_rhs.to_vec()).is_none());

        let nonfinite_matrix = vec![vec![1.0, f64::NAN], vec![0.0, 1.0]];
        assert!(solve_linear_first_tie(&nonfinite_matrix, &good_rhs).is_none());
        assert!(solve_linear_last_tie(nonfinite_matrix, good_rhs.to_vec()).is_none());

        let good_matrix = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        assert!(solve_linear_first_tie(&good_matrix, &[1.0, f64::INFINITY]).is_none());
        assert!(solve_linear_first_tie(&[], &[]).is_none());
        assert!(invert_matrix_first_tie(&[]).is_none());
    }

    // A weights slice shorter than the rows is reported as an invalid input.
    #[test]
    fn weighted_column_outer_rejects_short_weights() {
        let rows = [[1.0, 2.0, 3.0, 4.0], [2.0, 0.0, -1.0, 1.0]];
        assert_eq!(
            normal_matrix_4_weighted_column_outer(&rows, &[0.5]),
            Err(LinearError::InvalidInput {
                field: "weights",
                reason: "length must match rows"
            })
        );
    }

    // Expected entries are `sum_k rows[k][i] * weights[k] * rows[k][j]`.
    #[test]
    fn weighted_column_outer_accumulates_valid_inputs() {
        let rows = [[1.0, 2.0, 3.0, 4.0], [2.0, 0.0, -1.0, 1.0]];
        let weights = [0.5, 2.0];
        assert_eq!(
            normal_matrix_4_weighted_column_outer(&rows, &weights).unwrap(),
            [
                [8.5, 1.0, -2.5, 6.0],
                [1.0, 2.0, 3.0, 4.0],
                [-2.5, 3.0, 6.5, 4.0],
                [6.0, 4.0, 4.0, 10.0],
            ]
        );
    }

    #[test]
    fn transpose_rejects_empty_ragged_and_nonfinite_matrices() {
        assert!(transpose(&[]).is_none());
        assert!(transpose(&[vec![1.0], vec![]]).is_none());
        assert!(transpose(&[vec![f64::INFINITY]]).is_none());
    }

    // Covers a short row, a non-finite row, a non-finite residual and a non-finite weight.
    #[test]
    fn normal_equations_reject_malformed_or_nonfinite_rows() {
        let short = [1.0];
        assert!(normal_equations_weighted([(short.as_slice(), 1.0, 1.0)], 2).is_none());

        let nonfinite_row = [1.0, f64::NAN];
        assert!(normal_equations_weighted([(nonfinite_row.as_slice(), 1.0, 1.0)], 2).is_none());

        let good_row = [1.0, 2.0];
        assert!(normal_equations_weighted([(good_row.as_slice(), f64::NAN, 1.0)], 2).is_none());
        assert!(
            normal_equations_weighted([(good_row.as_slice(), 1.0, f64::INFINITY)], 2).is_none()
        );
    }

    #[test]
    fn flat_solvers_reject_nonfinite_inputs() {
        let mut out = Vec::new();
        let mut scratch = FlatLinearScratch::default();
        assert!(invert_flat_first_tie_into(&[f64::NAN], 1, &mut out, &mut scratch).is_none());

        assert!(solve_flat_normal_first_tie(&[f64::NAN], &[1.0]).is_none());
        assert!(solve_flat_normal_first_tie(&[1.0], &[f64::INFINITY]).is_none());

        let mut cholesky = FlatCholeskySolveScratch::default();
        assert!(solve_flat_normal_square_root_into(&[1.0], &[f64::NAN], &mut cholesky).is_none());
    }

    // The second row is twice the first, so the system is singular.
    #[test]
    fn flat_normal_solver_reports_singular() {
        assert!(solve_flat_normal_first_tie(&[1.0, 2.0, 2.0, 4.0], &[1.0, 2.0]).is_none());
    }

    #[test]
    fn cofactor_inverse_rejects_singular_4x4() {
        let a = [[0.0; 4]; 4];
        assert!(invert_4x4_cofactor(&a).is_none());
    }

    // `eta` is built as `lambda * [1, 2, 3]`, so the solver must recover `[1, 2, 3]`.
    #[test]
    fn cholesky_square_root_solves_spd_system() {
        let lambda = [
            4.0, 12.0, -16.0, 12.0, 37.0, -43.0, -16.0, -43.0, 98.0,
        ];
        let eta = [
            4.0 * 1.0 + 12.0 * 2.0 - 16.0 * 3.0,
            12.0 * 1.0 + 37.0 * 2.0 - 43.0 * 3.0,
            -16.0 * 1.0 - 43.0 * 2.0 + 98.0 * 3.0,
        ];
        let mut scratch = FlatCholeskySolveScratch::default();
        let x = solve_flat_normal_square_root_into(&lambda, &eta, &mut scratch).unwrap();
        for (got, want) in x.iter().zip([1.0_f64, 2.0, 3.0]) {
            assert!((got - want).abs() < 1.0e-12, "got {got}, want {want}");
        }
    }

    // The Cholesky and Gaussian solvers must agree on the same system up to round-off.
    #[test]
    fn cholesky_square_root_agrees_with_first_tie_to_roundoff() {
        let lambda = [
            6.0, 2.0, 1.0, 2.0, 5.0, 2.0, 1.0, 2.0, 4.0,
        ];
        let eta = [9.0, 9.0, 7.0];
        let mut sqrt_scratch = FlatCholeskySolveScratch::default();
        let sqrt_x = solve_flat_normal_square_root_into(&lambda, &eta, &mut sqrt_scratch)
            .unwrap()
            .to_vec();
        let first_tie_x = solve_flat_normal_first_tie(&lambda, &eta).unwrap();
        for (s, f) in sqrt_x.iter().zip(&first_tie_x) {
            assert!((s - f).abs() < 1.0e-12, "square-root {s} vs first-tie {f}");
        }
    }

    // Pins the exact output bits of the Cholesky solver for one system.
    #[test]
    fn cholesky_square_root_frozen_bits() {
        let lambda = [
            4.0, 2.0, 0.0, 2.0, 5.0, 0.0, 0.0, 0.0, 1.0,
        ];
        let eta = [9.0, 6.5, 3.0];
        let mut scratch = FlatCholeskySolveScratch::default();
        let x = solve_flat_normal_square_root_into(&lambda, &eta, &mut scratch).unwrap();
        assert_eq!(x[0].to_bits(), 2.0f64.to_bits());
        assert_eq!(x[1].to_bits(), 0.5f64.to_bits());
        assert_eq!(x[2].to_bits(), 3.0f64.to_bits());
    }

    // A rank-deficient symmetric matrix is not positive definite.
    #[test]
    fn cholesky_square_root_rejects_non_pd() {
        assert!(solve_flat_normal_square_root_into(
            &[1.0, 2.0, 2.0, 4.0],
            &[1.0, 2.0],
            &mut Default::default()
        )
        .is_none());
    }

    // Negative diagonal, asymmetric and indefinite matrices are all rejected.
    #[test]
    fn cholesky_square_root_rejects_invalid_information_geometry() {
        let eta = [1.0, 2.0];
        let mut scratch = FlatCholeskySolveScratch::default();

        let negative_variance = [-1.0, 0.0, 0.0, 1.0];
        assert!(
            solve_flat_normal_square_root_into(&negative_variance, &eta, &mut scratch).is_none()
        );

        let asymmetric = [1.0, 0.5, 0.0, 1.0];
        assert!(solve_flat_normal_square_root_into(&asymmetric, &eta, &mut scratch).is_none());

        let indefinite = [1.0, 2.0, 2.0, 1.0];
        assert!(solve_flat_normal_square_root_into(&indefinite, &eta, &mut scratch).is_none());
    }

    #[test]
    fn symmetric_pd_inverse_rejects_invalid_matrix_geometry() {
        let negative_variance = vec![vec![-1.0, 0.0], vec![0.0, 1.0]];
        assert!(invert_symmetric_pd(&negative_variance).is_none());

        let asymmetric = vec![vec![1.0, 0.5], vec![0.0, 1.0]];
        assert!(invert_symmetric_pd(&asymmetric).is_none());

        let indefinite = vec![vec![1.0, 2.0], vec![2.0, 1.0]];
        assert!(invert_symmetric_pd(&indefinite).is_none());

        // The smallest positive subnormal has a reciprocal that overflows to infinity.
        let overflow_inverse = vec![vec![f64::from_bits(1)]];
        assert!(invert_symmetric_pd(&overflow_inverse).is_none());
    }
}