1use nalgebra::{DMatrix, DVector};
31
/// Relative step used for 2-point finite differences.
///
/// Equal to `2^-26`, i.e. the square root of the `f64` machine epsilon `2^-52`.
pub const FD_REL_STEP_2POINT: f64 = 1.4901161193847656e-8;

/// Default for `SolveOptions::gtol`.
const TRF_DEFAULT_GTOL: f64 = 1e-10;

/// Default for `SolveOptions::ftol`.
const TRF_DEFAULT_FTOL: f64 = 1e-8;

/// Default for `SolveOptions::xtol`.
const TRF_DEFAULT_XTOL: f64 = 1e-8;

/// Default for `SolveOptions::max_nfev`.
const TRF_DEFAULT_MAX_NFEV: usize = 300;

/// Factor applied to `max(largest diagonal entry of JᵀJ, 1.0)` to obtain the
/// initial damping parameter of the solver.
const TRF_INITIAL_DAMPING_SCALE: f64 = 1e-3;
48
49#[derive(Debug, Clone, PartialEq)]
53pub struct FdStep {
54 pub param_index: usize,
56 pub sign_x0: f64,
59 pub h: f64,
61 pub dx: f64,
65 pub x_perturbed: DVector<f64>,
67}
68
69pub fn fd_steps(x0: &DVector<f64>, rel_step: f64) -> Result<Vec<FdStep>, SolveError> {
75 let rel_step = crate::validate::positive_step(rel_step, "rel_step").map_err(map_field_error)?;
76 fd_steps_checked(x0, rel_step)
77}
78
79fn fd_steps_checked(x0: &DVector<f64>, rel_step: f64) -> Result<Vec<FdStep>, SolveError> {
80 validate_nonempty_vector(x0, "parameters")?;
81 validate_vector(x0, "parameters")?;
82 let steps = fd_steps_unchecked(x0, rel_step);
83 for step in &steps {
84 validate_value(step.h, "fd_step")?;
85 validate_value(step.dx, "fd_step")?;
86 if step.dx == 0.0 {
87 return Err(invalid_input("fd_step", "zero"));
88 }
89 validate_vector(&step.x_perturbed, "perturbed parameters")?;
90 }
91 Ok(steps)
92}
93
94fn fd_steps_unchecked(x0: &DVector<f64>, rel_step: f64) -> Vec<FdStep> {
95 (0..x0.len())
96 .map(|i| {
97 let xi = x0[i];
98 let sign_x0 = if xi >= 0.0 { 1.0 } else { -1.0 };
99 let h = rel_step * sign_x0 * xi.abs().max(1.0);
100 let mut x_perturbed = x0.clone();
101 x_perturbed[i] = xi + h;
102 let dx = x_perturbed[i] - xi;
103 FdStep {
104 param_index: i,
105 sign_x0,
106 h,
107 dx,
108 x_perturbed,
109 }
110 })
111 .collect()
112}
113
114pub fn jacobian_2point<F>(
125 residual: F,
126 x0: &DVector<f64>,
127 f0: &DVector<f64>,
128) -> Result<DMatrix<f64>, SolveError>
129where
130 F: Fn(&DVector<f64>) -> DVector<f64>,
131{
132 jacobian_2point_checked(|x| Ok(residual(x)), x0, f0)
133}
134
135fn jacobian_2point_checked<F>(
136 residual: F,
137 x0: &DVector<f64>,
138 f0: &DVector<f64>,
139) -> Result<DMatrix<f64>, SolveError>
140where
141 F: Fn(&DVector<f64>) -> Result<DVector<f64>, SolveError>,
142{
143 validate_nonempty_vector(x0, "parameters")?;
144 validate_vector(x0, "parameters")?;
145 validate_nonempty_vector(f0, "residual")?;
146 validate_vector(f0, "residual")?;
147 let m = f0.len();
148 let n = x0.len();
149 let steps = fd_steps_checked(x0, FD_REL_STEP_2POINT)?;
150 let mut jac = DMatrix::zeros(m, n);
151 for step in &steps {
152 let f1 = residual(&step.x_perturbed)?;
153 validate_nonempty_vector(&f1, "residual")?;
154 validate_vector(&f1, "residual")?;
155 if f1.len() != m {
156 return Err(invalid_input("residual", "length mismatch"));
157 }
158 let i = step.param_index;
159 for row in 0..m {
160 jac[(row, i)] = (f1[row] - f0[row]) / step.dx;
161 }
162 }
163 validate_matrix(&jac, "jacobian")?;
164 Ok(jac)
165}
166
/// Reason the solver stopped, as reported in `LeastSquaresReport::status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The infinity norm of the gradient `Jᵀr` fell below `SolveOptions::gtol`.
    GradientTolerance,
    /// The relative cost reduction of an accepted step fell below
    /// `SolveOptions::ftol`.
    CostTolerance,
    /// The relative size of an accepted step fell below `SolveOptions::xtol`,
    /// or no trial step reduced the cost within the damping attempts of one
    /// iteration.
    StepTolerance,
    /// The evaluation counter reached `SolveOptions::max_nfev`.
    MaxEvaluations,
}
180
/// Termination settings for the least-squares solver.
///
/// The three tolerances must pass `crate::validate::positive_step` and
/// `max_nfev` must be non-zero, otherwise the solver rejects the options.
#[derive(Debug, Clone, Copy)]
pub struct SolveOptions {
    /// Threshold on the infinity norm of the gradient `Jᵀr`.
    pub gtol: f64,
    /// Threshold on the relative cost reduction of an accepted step.
    pub ftol: f64,
    /// Threshold on the step norm relative to the parameter norm.
    pub xtol: f64,
    /// Budget for the solver's evaluation counter.
    pub max_nfev: usize,
}
193
194impl Default for SolveOptions {
195 fn default() -> Self {
196 Self {
198 gtol: TRF_DEFAULT_GTOL,
199 ftol: TRF_DEFAULT_FTOL,
200 xtol: TRF_DEFAULT_XTOL,
201 max_nfev: TRF_DEFAULT_MAX_NFEV,
202 }
203 }
204}
205
206#[derive(Debug, Clone)]
208pub struct LeastSquaresReport {
209 pub x: DVector<f64>,
211 pub residual: DVector<f64>,
213 pub cost: f64,
215 pub jacobian: DMatrix<f64>,
217 pub optimality_inf: f64,
219 pub iterations: usize,
221 pub status: Status,
223}
224
/// Linear-solver backend used for the damped normal equations of each
/// trial step.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TrustRegionSolve {
    /// Solve with nalgebra's LU decomposition. This is the default and the
    /// backend that `solve_trf` selects.
    #[default]
    NalgebraLu,
    /// Solve with the crate's own `astro::math::linear::solve_linear_first_tie`
    /// routine on a row-by-row copy of the system.
    ///
    /// NOTE(review): presumably Gaussian elimination with a first-tie pivot
    /// rule, as the name suggests — confirm against that routine.
    OwnedGaussianFirstTie,
}
260
261#[derive(Debug, Clone, thiserror::Error)]
263pub enum SolveError {
264 #[error("singular or rank-deficient Jacobian: no usable descent direction")]
267 SingularJacobian,
268 #[error("invalid least-squares {field}: {reason}")]
270 InvalidInput {
271 field: &'static str,
272 reason: &'static str,
273 },
274}
275
276pub fn cost(residual: &DVector<f64>) -> Result<f64, SolveError> {
278 validate_nonempty_vector(residual, "residual")?;
279 validate_vector(residual, "residual")?;
280 validate_value(0.5 * residual.dot(residual), "cost")
281}
282
283pub struct LeastSquaresProblem<F> {
288 residual: F,
289 sqrt_weights: Option<DVector<f64>>,
291 x0: DVector<f64>,
292}
293
294impl<F> LeastSquaresProblem<F>
295where
296 F: Fn(&DVector<f64>) -> DVector<f64>,
297{
298 pub fn new(residual: F, x0: DVector<f64>) -> Self {
300 Self {
301 residual,
302 sqrt_weights: None,
303 x0,
304 }
305 }
306
307 pub fn with_weights(residual: F, x0: DVector<f64>, weights: DVector<f64>) -> Self {
310 let sqrt_weights = weights.map(f64::sqrt);
311 Self {
312 residual,
313 sqrt_weights: Some(sqrt_weights),
314 x0,
315 }
316 }
317
318 fn weighted_residual(&self, x: &DVector<f64>) -> Result<DVector<f64>, SolveError> {
320 validate_nonempty_vector(x, "parameters")?;
321 validate_vector(x, "parameters")?;
322 let r = (self.residual)(x);
323 validate_nonempty_vector(&r, "residual")?;
324 validate_vector(&r, "residual")?;
325 match &self.sqrt_weights {
326 Some(sw) => {
327 validate_nonempty_vector(sw, "weights")?;
328 validate_vector(sw, "weights")?;
329 if sw.len() != r.len() {
330 return Err(invalid_input("weights", "length mismatch"));
331 }
332 let weighted = r.component_mul(sw);
333 validate_vector(&weighted, "weighted residual")?;
334 Ok(weighted)
335 }
336 None => Ok(r),
337 }
338 }
339}
340
341pub fn solve_trf<F>(
356 problem: &LeastSquaresProblem<F>,
357 opts: &SolveOptions,
358) -> Result<LeastSquaresReport, SolveError>
359where
360 F: Fn(&DVector<f64>) -> DVector<f64>,
361{
362 solve_trf_with(problem, opts, TrustRegionSolve::NalgebraLu)
363}
364
365fn solve_subproblem(
369 lhs: &DMatrix<f64>,
370 rhs: &DVector<f64>,
371 linear_solve: TrustRegionSolve,
372) -> Option<DVector<f64>> {
373 match linear_solve {
374 TrustRegionSolve::NalgebraLu => lhs.clone().lu().solve(rhs),
375 TrustRegionSolve::OwnedGaussianFirstTie => {
376 let n = rhs.len();
377 let a: Vec<Vec<f64>> = (0..n)
378 .map(|i| (0..n).map(|j| lhs[(i, j)]).collect())
379 .collect();
380 let b: Vec<f64> = rhs.iter().copied().collect();
381 crate::astro::math::linear::solve_linear_first_tie(&a, &b).map(DVector::from_vec)
382 }
383 }
384}
385
386pub fn solve_trf_with<F>(
390 problem: &LeastSquaresProblem<F>,
391 opts: &SolveOptions,
392 linear_solve: TrustRegionSolve,
393) -> Result<LeastSquaresReport, SolveError>
394where
395 F: Fn(&DVector<f64>) -> DVector<f64>,
396{
397 validate_options(opts)?;
398 let n = problem.x0.len();
399
400 let mut x = problem.x0.clone();
401 validate_nonempty_vector(&x, "initial parameters")?;
402 validate_vector(&x, "initial parameters")?;
403 let mut r = problem.weighted_residual(&x)?;
404 let mut f0 = r.clone();
405 let mut jac = jacobian_2point_checked(|p| problem.weighted_residual(p), &x, &f0)?;
406 let mut nfev = 1usize; let mut cur_cost = cost(&r)?;
408
409 let jtj0 = jac.transpose() * &jac;
411 validate_matrix(&jtj0, "normal matrix")?;
412 let mut mu = TRF_INITIAL_DAMPING_SCALE
413 * (0..n)
414 .map(|i| jtj0[(i, i)])
415 .fold(0.0_f64, f64::max)
416 .max(1.0);
417
418 let mut iterations = 0usize;
419
420 loop {
421 let jt = jac.transpose();
422 let grad = &jt * &r;
423 validate_vector(&grad, "gradient")?;
424 let optimality_inf = validate_value(grad.amax(), "optimality")?;
425
426 if optimality_inf < opts.gtol {
427 return finish(x, r, cur_cost, jac, iterations, Status::GradientTolerance);
428 }
429 if nfev >= opts.max_nfev {
430 return finish(x, r, cur_cost, jac, iterations, Status::MaxEvaluations);
431 }
432
433 let jtj = &jt * &jac;
434 validate_matrix(&jtj, "normal matrix")?;
435
436 let mut accepted = false;
438 for _ in 0..30 {
439 let mut lhs = jtj.clone();
440 for i in 0..n {
441 lhs[(i, i)] += mu;
442 }
443 let rhs = -&grad;
444 validate_matrix(&lhs, "subproblem matrix")?;
445 validate_vector(&rhs, "subproblem rhs")?;
446 let step = match solve_subproblem(&lhs, &rhs, linear_solve) {
447 Some(s) => s,
448 None => return Err(SolveError::SingularJacobian),
449 };
450 validate_vector(&step, "step")?;
451
452 let x_trial = &x + &step;
453 let r_trial = problem.weighted_residual(&x_trial)?;
454 nfev += 1;
455 let cost_trial = cost(&r_trial)?;
456
457 if cost_trial < cur_cost {
458 let cost_reduction = (cur_cost - cost_trial) / cur_cost.max(f64::MIN_POSITIVE);
460 let step_norm = step.norm();
461 let x_norm = x.norm();
462 let rel_step = step_norm / x_norm.max(f64::MIN_POSITIVE);
463
464 x = x_trial;
465 r = r_trial;
466 cur_cost = cost_trial;
467 f0 = r.clone();
468 jac = jacobian_2point_checked(|p| problem.weighted_residual(p), &x, &f0)?;
469 nfev += n; iterations += 1;
471 mu *= 0.5;
472 accepted = true;
473
474 if cost_reduction < opts.ftol {
475 return finish(x, r, cur_cost, jac, iterations, Status::CostTolerance);
476 }
477 if rel_step < opts.xtol {
478 return finish(x, r, cur_cost, jac, iterations, Status::StepTolerance);
479 }
480 break;
481 } else {
482 mu *= 2.0;
484 }
485 }
486
487 if !accepted {
488 return finish(x, r, cur_cost, jac, iterations, Status::StepTolerance);
490 }
491 }
492}
493
494fn finish(
495 x: DVector<f64>,
496 residual: DVector<f64>,
497 cost_value: f64,
498 jacobian: DMatrix<f64>,
499 iterations: usize,
500 status: Status,
501) -> Result<LeastSquaresReport, SolveError> {
502 validate_nonempty_vector(&x, "solution")?;
503 validate_vector(&x, "solution")?;
504 validate_nonempty_vector(&residual, "residual")?;
505 validate_vector(&residual, "residual")?;
506 validate_value(cost_value, "cost")?;
507 validate_matrix(&jacobian, "jacobian")?;
508 let optimality_inf = validate_value((jacobian.transpose() * &residual).amax(), "optimality")?;
509 Ok(LeastSquaresReport {
510 x,
511 residual,
512 cost: cost_value,
513 jacobian,
514 optimality_inf,
515 iterations,
516 status,
517 })
518}
519
520fn validate_value(value: f64, field: &'static str) -> Result<f64, SolveError> {
521 crate::validate::finite(value, field).map_err(map_field_error)
522}
523
524fn validate_options(opts: &SolveOptions) -> Result<(), SolveError> {
525 crate::validate::positive_step(opts.gtol, "gtol").map_err(map_field_error)?;
526 crate::validate::positive_step(opts.ftol, "ftol").map_err(map_field_error)?;
527 crate::validate::positive_step(opts.xtol, "xtol").map_err(map_field_error)?;
528 if opts.max_nfev == 0 {
529 return Err(invalid_input("max_nfev", "not positive"));
530 }
531 Ok(())
532}
533
534fn validate_nonempty_vector(vector: &DVector<f64>, field: &'static str) -> Result<(), SolveError> {
535 if vector.is_empty() {
536 Err(invalid_input(field, "empty"))
537 } else {
538 Ok(())
539 }
540}
541
542fn validate_vector(vector: &DVector<f64>, field: &'static str) -> Result<(), SolveError> {
543 crate::validate::finite_slice(vector.as_slice(), field).map_err(map_field_error)
544}
545
546fn validate_matrix(matrix: &DMatrix<f64>, field: &'static str) -> Result<(), SolveError> {
547 crate::validate::finite_slice(matrix.as_slice(), field).map_err(map_field_error)
548}
549
550fn map_field_error(error: crate::validate::FieldError) -> SolveError {
551 invalid_input(error.field(), error.reason())
552}
553
554fn invalid_input(field: &'static str, reason: &'static str) -> SolveError {
555 SolveError::InvalidInput { field, reason }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `error` is `SolveError::InvalidInput` for the field `expected`.
    fn assert_invalid_field(error: SolveError, expected: &'static str) {
        match error {
            SolveError::InvalidInput { field, .. } => assert_eq!(field, expected),
            other => panic!("expected invalid input for {expected}, got {other:?}"),
        }
    }

    /// Shared fixture: fit `a * exp(b * t) + c` to eight samples, starting
    /// from `(5, -2, 2)`.
    fn exp_fit_problem() -> LeastSquaresProblem<impl Fn(&DVector<f64>) -> DVector<f64>> {
        let t = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0];
        let y = [
            3.0123, 2.2083, 1.6889, 1.3713, 1.0903, 0.9302, 0.8104, 0.6303,
        ];
        let residual = move |p: &DVector<f64>| {
            let (a, b, c) = (p[0], p[1], p[2]);
            DVector::from_iterator(
                t.len(),
                t.iter()
                    .zip(&y)
                    .map(|(&tk, &yk)| a * (b * tk).exp() + c - yk),
            )
        };
        LeastSquaresProblem::new(residual, DVector::from_vec(vec![5.0, -2.0, 2.0]))
    }

    #[test]
    fn fd_rel_step_is_sqrt_eps() {
        assert_eq!(FD_REL_STEP_2POINT, (2.0_f64.powi(-52)).sqrt());
        assert_eq!(FD_REL_STEP_2POINT, 2.0_f64.powi(-26));
    }

    #[test]
    fn fd_step_sign_convention() {
        let x0 = DVector::from_vec(vec![5.0, -2.0, 0.0]);
        let steps = fd_steps(&x0, FD_REL_STEP_2POINT).unwrap();
        assert_eq!(steps[0].sign_x0, 1.0);
        assert_eq!(steps[1].sign_x0, -1.0);
        // Zero counts as non-negative.
        assert_eq!(steps[2].sign_x0, 1.0);
    }

    #[test]
    fn fd_steps_rejects_zero_relative_step() {
        let x0 = DVector::from_vec(vec![1.0]);
        assert_invalid_field(fd_steps(&x0, 0.0).unwrap_err(), "rel_step");
    }

    #[test]
    fn fd_steps_rejects_nonfinite_parameters() {
        let x0 = DVector::from_vec(vec![1.0, f64::NAN]);
        assert_invalid_field(
            fd_steps(&x0, FD_REL_STEP_2POINT).unwrap_err(),
            "parameters",
        );
    }

    #[test]
    fn jacobian_rejects_residual_length_mismatch() {
        let x0 = DVector::from_vec(vec![1.0, 2.0]);
        let f0 = DVector::from_vec(vec![1.0, 2.0]);
        let residual = |_: &DVector<f64>| DVector::from_vec(vec![1.0]);
        assert_invalid_field(
            jacobian_2point(residual, &x0, &f0).unwrap_err(),
            "residual",
        );
    }

    #[test]
    fn cost_rejects_nonfinite_residual() {
        assert_invalid_field(
            cost(&DVector::from_vec(vec![1.0, f64::INFINITY])).unwrap_err(),
            "residual",
        );
    }

    #[test]
    fn exp_fit_converges() {
        // Same data and starting point as every other exp-fit test.
        let problem = exp_fit_problem();
        let report = solve_trf(&problem, &SolveOptions::default()).unwrap();
        assert!(report.cost < 1.0, "cost did not reduce: {}", report.cost);
    }

    #[test]
    fn solve_trf_rejects_nonfinite_initial_residual() {
        fn residual(_: &DVector<f64>) -> DVector<f64> {
            DVector::from_element(1, f64::NAN)
        }
        let problem = LeastSquaresProblem::new(residual, DVector::from_element(1, 0.0));
        assert_invalid_field(
            solve_trf(&problem, &SolveOptions::default()).unwrap_err(),
            "residual",
        );
    }

    #[test]
    fn solve_trf_rejects_nonfinite_initial_cost() {
        // The residual itself is finite, but its squared norm overflows.
        fn residual(_: &DVector<f64>) -> DVector<f64> {
            DVector::from_element(1, f64::MAX)
        }
        let problem = LeastSquaresProblem::new(residual, DVector::from_element(1, 0.0));
        assert_invalid_field(
            solve_trf(&problem, &SolveOptions::default()).unwrap_err(),
            "cost",
        );
    }

    #[test]
    fn solve_trf_rejects_nonfinite_trial_residual_instead_of_converging() {
        use std::cell::Cell;

        // Calls 0 and 1 (initial residual and the single Jacobian column) are
        // well-behaved; every later call, i.e. the first trial step, is NaN.
        let calls = Cell::new(0usize);
        let residual = move |p: &DVector<f64>| {
            let call = calls.get();
            calls.set(call + 1);
            if call >= 2 {
                DVector::from_element(1, f64::NAN)
            } else {
                DVector::from_element(1, p[0])
            }
        };
        let problem = LeastSquaresProblem::new(residual, DVector::from_element(1, 1.0));
        assert_invalid_field(
            solve_trf(&problem, &SolveOptions::default()).unwrap_err(),
            "residual",
        );
    }

    #[test]
    fn solve_trf_rejects_invalid_options() {
        fn residual(p: &DVector<f64>) -> DVector<f64> {
            DVector::from_element(1, p[0])
        }
        let problem = LeastSquaresProblem::new(residual, DVector::from_element(1, 1.0));
        let opts = SolveOptions {
            gtol: f64::NAN,
            ..SolveOptions::default()
        };
        assert_invalid_field(solve_trf(&problem, &opts).unwrap_err(), "gtol");

        let opts = SolveOptions {
            max_nfev: 0,
            ..SolveOptions::default()
        };
        assert_invalid_field(solve_trf(&problem, &opts).unwrap_err(), "max_nfev");
    }

    #[test]
    fn solve_trf_rejects_weight_residual_dimension_mismatch() {
        fn residual(_: &DVector<f64>) -> DVector<f64> {
            DVector::from_vec(vec![1.0, 2.0])
        }
        let problem = LeastSquaresProblem::with_weights(
            residual,
            DVector::from_element(1, 0.0),
            DVector::from_vec(vec![1.0]),
        );
        assert_invalid_field(
            solve_trf(&problem, &SolveOptions::default()).unwrap_err(),
            "weights",
        );
    }

    #[test]
    fn owned_trf_converges_to_frozen_bits() {
        let problem = exp_fit_problem();
        let report = solve_trf_with(
            &problem,
            &SolveOptions::default(),
            TrustRegionSolve::OwnedGaussianFirstTie,
        )
        .unwrap();
        assert!(
            report.cost < 1.0,
            "owned cost did not reduce: {}",
            report.cost
        );
        // Exact bit patterns: any change to the floating-point operation order
        // in the solver or the owned linear solve will trip these.
        assert_eq!(report.x[0].to_bits(), 0x4003c3674cdfadef);
        assert_eq!(report.x[1].to_bits(), 0xbfe799e0d1929220);
        assert_eq!(report.x[2].to_bits(), 0x3fe0d5c96d9d3b35);

        // A second solve of the same problem must reproduce the result bit for bit.
        let again = solve_trf_with(
            &problem,
            &SolveOptions::default(),
            TrustRegionSolve::OwnedGaussianFirstTie,
        )
        .unwrap();
        for i in 0..3 {
            assert_eq!(report.x[i].to_bits(), again.x[i].to_bits());
        }
    }
}