1use nalgebra::{DMatrix, DVector};
57
/// Relative step for 2-point forward differences: `sqrt(f64::EPSILON)`, i.e. exactly `2^-26`.
pub const FD_REL_STEP_2POINT: f64 = 1.490_116_119_384_765_6e-8;

/// Default gradient tolerance: stop when `max |J^T r| < gtol`.
const TRF_DEFAULT_GTOL: f64 = 1.0e-10;

/// Default relative cost-reduction tolerance for an accepted step.
const TRF_DEFAULT_FTOL: f64 = 1.0e-8;

/// Default relative step-size tolerance for an accepted step.
const TRF_DEFAULT_XTOL: f64 = 1.0e-8;

/// Default budget for the solver's function-evaluation counter.
const TRF_DEFAULT_MAX_NFEV: usize = 300;

/// Initial damping is this factor times the largest diagonal entry of `J^T J` (floored at 1).
const TRF_INITIAL_DAMPING_SCALE: f64 = 1.0e-3;
74
/// One forward-difference perturbation of a single parameter, as built by `fd_steps`.
#[derive(Debug, Clone, PartialEq)]
pub struct FdStep {
    /// Index of the perturbed parameter within `x0`.
    pub param_index: usize,
    /// `1.0` when `x0[param_index] >= 0.0` (zero counts as non-negative), otherwise `-1.0`.
    pub sign_x0: f64,
    /// Requested step: `rel_step * sign_x0 * max(|x0[param_index]|, 1)`.
    pub h: f64,
    /// Step actually realised in floating point, `x_perturbed[param_index] - x0[param_index]`;
    /// this (not `h`) is the divisor of the difference quotient.
    pub dx: f64,
    /// Copy of `x0` with only entry `param_index` replaced by `x0[param_index] + h`.
    pub x_perturbed: DVector<f64>,
}
94
95pub fn fd_steps(x0: &DVector<f64>, rel_step: f64) -> Result<Vec<FdStep>, SolveError> {
101 let rel_step = crate::validate::positive_step(rel_step, "rel_step").map_err(map_field_error)?;
102 fd_steps_checked(x0, rel_step)
103}
104
105fn fd_steps_checked(x0: &DVector<f64>, rel_step: f64) -> Result<Vec<FdStep>, SolveError> {
106 validate_nonempty_vector(x0, "parameters")?;
107 validate_vector(x0, "parameters")?;
108 let steps = fd_steps_unchecked(x0, rel_step);
109 for step in &steps {
110 validate_value(step.h, "fd_step")?;
111 validate_value(step.dx, "fd_step")?;
112 if step.dx == 0.0 {
113 return Err(invalid_input("fd_step", "zero"));
114 }
115 validate_vector(&step.x_perturbed, "perturbed parameters")?;
116 }
117 Ok(steps)
118}
119
120fn fd_steps_unchecked(x0: &DVector<f64>, rel_step: f64) -> Vec<FdStep> {
121 (0..x0.len())
122 .map(|i| {
123 let xi = x0[i];
124 let sign_x0 = if xi >= 0.0 { 1.0 } else { -1.0 };
125 let h = rel_step * sign_x0 * xi.abs().max(1.0);
126 let mut x_perturbed = x0.clone();
127 x_perturbed[i] = xi + h;
128 let dx = x_perturbed[i] - xi;
129 FdStep {
130 param_index: i,
131 sign_x0,
132 h,
133 dx,
134 x_perturbed,
135 }
136 })
137 .collect()
138}
139
140pub fn jacobian_2point<F>(
151 residual: F,
152 x0: &DVector<f64>,
153 f0: &DVector<f64>,
154) -> Result<DMatrix<f64>, SolveError>
155where
156 F: Fn(&DVector<f64>) -> DVector<f64>,
157{
158 jacobian_2point_checked(|x| Ok(residual(x)), x0, f0)
159}
160
161fn jacobian_2point_checked<F>(
162 residual: F,
163 x0: &DVector<f64>,
164 f0: &DVector<f64>,
165) -> Result<DMatrix<f64>, SolveError>
166where
167 F: Fn(&DVector<f64>) -> Result<DVector<f64>, SolveError>,
168{
169 validate_nonempty_vector(x0, "parameters")?;
170 validate_vector(x0, "parameters")?;
171 validate_nonempty_vector(f0, "residual")?;
172 validate_vector(f0, "residual")?;
173 let m = f0.len();
174 let n = x0.len();
175 let steps = fd_steps_checked(x0, FD_REL_STEP_2POINT)?;
176 let mut jac = DMatrix::zeros(m, n);
177 for step in &steps {
178 let f1 = residual(&step.x_perturbed)?;
179 validate_nonempty_vector(&f1, "residual")?;
180 validate_vector(&f1, "residual")?;
181 if f1.len() != m {
182 return Err(invalid_input("residual", "length mismatch"));
183 }
184 let i = step.param_index;
185 for row in 0..m {
186 jac[(row, i)] = (f1[row] - f0[row]) / step.dx;
187 }
188 }
189 validate_matrix(&jac, "jacobian")?;
190 Ok(jac)
191}
192
/// Reason the solver stopped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The infinity norm of the gradient `J^T r` fell below `gtol`.
    GradientTolerance,
    /// An accepted step reduced the cost by a relative amount below `ftol`.
    CostTolerance,
    /// An accepted step was small relative to the pre-step `||x||` (below `xtol`),
    /// or the solver gave up after failing to find a cost-reducing trial step.
    StepTolerance,
    /// The function-evaluation counter reached `max_nfev`.
    MaxEvaluations,
}
206
/// Stopping criteria for `solve_trf` / `solve_trf_with`.
///
/// Before solving, each tolerance is checked with `crate::validate::positive_step`
/// (a NaN `gtol` is rejected, for example), and `max_nfev` must be non-zero.
#[derive(Debug, Clone, Copy)]
pub struct SolveOptions {
    /// Stop with `Status::GradientTolerance` when `max |(J^T r)_i| < gtol`.
    pub gtol: f64,
    /// Stop with `Status::CostTolerance` when an accepted step's relative cost
    /// reduction `(cost_old - cost_new) / cost_old` is below `ftol`.
    pub ftol: f64,
    /// Stop with `Status::StepTolerance` when an accepted step satisfies
    /// `||step|| / ||x_old|| < xtol`.
    pub xtol: f64,
    /// Budget for the solver's evaluation counter; reaching it ends the solve
    /// with `Status::MaxEvaluations`.
    pub max_nfev: usize,
}
219
220impl Default for SolveOptions {
221 fn default() -> Self {
222 Self {
224 gtol: TRF_DEFAULT_GTOL,
225 ftol: TRF_DEFAULT_FTOL,
226 xtol: TRF_DEFAULT_XTOL,
227 max_nfev: TRF_DEFAULT_MAX_NFEV,
228 }
229 }
230}
231
/// Final state returned by a successful solve.
#[derive(Debug, Clone)]
pub struct LeastSquaresReport {
    /// Final parameter vector.
    pub x: DVector<f64>,
    /// Residual at `x`, with the square-root weights applied when the problem is weighted.
    pub residual: DVector<f64>,
    /// `0.5 * residual.dot(residual)` at `x`.
    pub cost: f64,
    /// Forward-difference Jacobian of the (weighted) residual at `x`.
    pub jacobian: DMatrix<f64>,
    /// `max |(J^T r)_i|`, recomputed from `jacobian` and `residual`.
    pub optimality_inf: f64,
    /// Number of accepted steps.
    pub iterations: usize,
    /// Why the solver stopped.
    pub status: Status,
}
250
/// Linear solver used for the damped subproblem `(J^T J + mu I) step = -J^T r`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TrustRegionSolve {
    /// nalgebra's LU decomposition (`lu().solve(..)`). This is the default.
    #[default]
    NalgebraLu,
    /// The crate's own `crate::astro::math::linear::solve_linear_first_tie`.
    /// NOTE(review): the name suggests Gaussian elimination with a first-wins pivot
    /// tie-break; confirm against that module. Exercised by the frozen-bits test.
    OwnedGaussianFirstTie,
}
286
/// Errors returned by the least-squares routines in this module.
#[derive(Debug, Clone, thiserror::Error)]
pub enum SolveError {
    /// The damped subproblem solver returned no solution.
    #[error("singular or rank-deficient Jacobian: no usable descent direction")]
    SingularJacobian,
    /// An input, intermediate, or output quantity failed validation.
    ///
    /// `field` names the quantity (e.g. `"residual"`, `"gtol"`, `"weights"`) and
    /// `reason` is a short static description (e.g. `"empty"`, `"length mismatch"`).
    #[error("invalid least-squares {field}: {reason}")]
    InvalidInput {
        field: &'static str,
        reason: &'static str,
    },
}
301
302pub fn cost(residual: &DVector<f64>) -> Result<f64, SolveError> {
304 validate_nonempty_vector(residual, "residual")?;
305 validate_vector(residual, "residual")?;
306 validate_value(0.5 * residual.dot(residual), "cost")
307}
308
/// A nonlinear least-squares problem: a residual function and a starting point,
/// optionally with per-residual weights.
pub struct LeastSquaresProblem<F> {
    /// Residual function `r(x)`.
    residual: F,
    /// Element-wise square roots of the user weights, multiplied into `r(x)`;
    /// `None` for an unweighted problem.
    sqrt_weights: Option<DVector<f64>>,
    /// Initial parameter guess.
    x0: DVector<f64>,
}
319
320impl<F> LeastSquaresProblem<F>
321where
322 F: Fn(&DVector<f64>) -> DVector<f64>,
323{
324 pub fn new(residual: F, x0: DVector<f64>) -> Self {
326 Self {
327 residual,
328 sqrt_weights: None,
329 x0,
330 }
331 }
332
333 pub fn with_weights(residual: F, x0: DVector<f64>, weights: DVector<f64>) -> Self {
336 let sqrt_weights = weights.map(f64::sqrt);
337 Self {
338 residual,
339 sqrt_weights: Some(sqrt_weights),
340 x0,
341 }
342 }
343
344 fn weighted_residual(&self, x: &DVector<f64>) -> Result<DVector<f64>, SolveError> {
346 validate_nonempty_vector(x, "parameters")?;
347 validate_vector(x, "parameters")?;
348 let r = (self.residual)(x);
349 validate_nonempty_vector(&r, "residual")?;
350 validate_vector(&r, "residual")?;
351 match &self.sqrt_weights {
352 Some(sw) => {
353 validate_nonempty_vector(sw, "weights")?;
354 validate_vector(sw, "weights")?;
355 if sw.len() != r.len() {
356 return Err(invalid_input("weights", "length mismatch"));
357 }
358 let weighted = r.component_mul(sw);
359 validate_vector(&weighted, "weighted residual")?;
360 Ok(weighted)
361 }
362 None => Ok(r),
363 }
364 }
365}
366
367pub fn solve_trf<F>(
382 problem: &LeastSquaresProblem<F>,
383 opts: &SolveOptions,
384) -> Result<LeastSquaresReport, SolveError>
385where
386 F: Fn(&DVector<f64>) -> DVector<f64>,
387{
388 solve_trf_with(problem, opts, TrustRegionSolve::NalgebraLu)
389}
390
391fn solve_subproblem(
395 lhs: &DMatrix<f64>,
396 rhs: &DVector<f64>,
397 linear_solve: TrustRegionSolve,
398) -> Option<DVector<f64>> {
399 match linear_solve {
400 TrustRegionSolve::NalgebraLu => lhs.clone().lu().solve(rhs),
401 TrustRegionSolve::OwnedGaussianFirstTie => {
402 let n = rhs.len();
403 let a: Vec<Vec<f64>> = (0..n)
404 .map(|i| (0..n).map(|j| lhs[(i, j)]).collect())
405 .collect();
406 let b: Vec<f64> = rhs.iter().copied().collect();
407 crate::astro::math::linear::solve_linear_first_tie(&a, &b).map(DVector::from_vec)
408 }
409 }
410}
411
412pub fn solve_trf_with<F>(
416 problem: &LeastSquaresProblem<F>,
417 opts: &SolveOptions,
418 linear_solve: TrustRegionSolve,
419) -> Result<LeastSquaresReport, SolveError>
420where
421 F: Fn(&DVector<f64>) -> DVector<f64>,
422{
423 validate_options(opts)?;
424 let n = problem.x0.len();
425
426 let mut x = problem.x0.clone();
427 validate_nonempty_vector(&x, "initial parameters")?;
428 validate_vector(&x, "initial parameters")?;
429 let mut r = problem.weighted_residual(&x)?;
430 let mut f0 = r.clone();
431 let mut jac = jacobian_2point_checked(|p| problem.weighted_residual(p), &x, &f0)?;
432 let mut nfev = 1usize; let mut cur_cost = cost(&r)?;
434
435 let jtj0 = jac.transpose() * &jac;
437 validate_matrix(&jtj0, "normal matrix")?;
438 let mut mu = TRF_INITIAL_DAMPING_SCALE
439 * (0..n)
440 .map(|i| jtj0[(i, i)])
441 .fold(0.0_f64, f64::max)
442 .max(1.0);
443
444 let mut iterations = 0usize;
445
446 loop {
447 let jt = jac.transpose();
448 let grad = &jt * &r;
449 validate_vector(&grad, "gradient")?;
450 let optimality_inf = validate_value(grad.amax(), "optimality")?;
451
452 if optimality_inf < opts.gtol {
453 return finish(x, r, cur_cost, jac, iterations, Status::GradientTolerance);
454 }
455 if nfev >= opts.max_nfev {
456 return finish(x, r, cur_cost, jac, iterations, Status::MaxEvaluations);
457 }
458
459 let jtj = &jt * &jac;
460 validate_matrix(&jtj, "normal matrix")?;
461
462 let mut accepted = false;
464 for _ in 0..30 {
465 let mut lhs = jtj.clone();
466 for i in 0..n {
467 lhs[(i, i)] += mu;
468 }
469 let rhs = -&grad;
470 validate_matrix(&lhs, "subproblem matrix")?;
471 validate_vector(&rhs, "subproblem rhs")?;
472 let step = match solve_subproblem(&lhs, &rhs, linear_solve) {
473 Some(s) => s,
474 None => return Err(SolveError::SingularJacobian),
475 };
476 validate_vector(&step, "step")?;
477
478 let x_trial = &x + &step;
479 let r_trial = problem.weighted_residual(&x_trial)?;
480 nfev += 1;
481 let cost_trial = cost(&r_trial)?;
482
483 if cost_trial < cur_cost {
484 let cost_reduction = (cur_cost - cost_trial) / cur_cost.max(f64::MIN_POSITIVE);
486 let step_norm = step.norm();
487 let x_norm = x.norm();
488 let rel_step = step_norm / x_norm.max(f64::MIN_POSITIVE);
489
490 x = x_trial;
491 r = r_trial;
492 cur_cost = cost_trial;
493 f0 = r.clone();
494 jac = jacobian_2point_checked(|p| problem.weighted_residual(p), &x, &f0)?;
495 nfev += n; iterations += 1;
497 mu *= 0.5;
498 accepted = true;
499
500 if cost_reduction < opts.ftol {
501 return finish(x, r, cur_cost, jac, iterations, Status::CostTolerance);
502 }
503 if rel_step < opts.xtol {
504 return finish(x, r, cur_cost, jac, iterations, Status::StepTolerance);
505 }
506 break;
507 } else {
508 mu *= 2.0;
510 }
511 }
512
513 if !accepted {
514 return finish(x, r, cur_cost, jac, iterations, Status::StepTolerance);
516 }
517 }
518}
519
520fn finish(
521 x: DVector<f64>,
522 residual: DVector<f64>,
523 cost_value: f64,
524 jacobian: DMatrix<f64>,
525 iterations: usize,
526 status: Status,
527) -> Result<LeastSquaresReport, SolveError> {
528 validate_nonempty_vector(&x, "solution")?;
529 validate_vector(&x, "solution")?;
530 validate_nonempty_vector(&residual, "residual")?;
531 validate_vector(&residual, "residual")?;
532 validate_value(cost_value, "cost")?;
533 validate_matrix(&jacobian, "jacobian")?;
534 let optimality_inf = validate_value((jacobian.transpose() * &residual).amax(), "optimality")?;
535 Ok(LeastSquaresReport {
536 x,
537 residual,
538 cost: cost_value,
539 jacobian,
540 optimality_inf,
541 iterations,
542 status,
543 })
544}
545
546fn validate_value(value: f64, field: &'static str) -> Result<f64, SolveError> {
547 crate::validate::finite(value, field).map_err(map_field_error)
548}
549
550fn validate_options(opts: &SolveOptions) -> Result<(), SolveError> {
551 crate::validate::positive_step(opts.gtol, "gtol").map_err(map_field_error)?;
552 crate::validate::positive_step(opts.ftol, "ftol").map_err(map_field_error)?;
553 crate::validate::positive_step(opts.xtol, "xtol").map_err(map_field_error)?;
554 if opts.max_nfev == 0 {
555 return Err(invalid_input("max_nfev", "not positive"));
556 }
557 Ok(())
558}
559
560fn validate_nonempty_vector(vector: &DVector<f64>, field: &'static str) -> Result<(), SolveError> {
561 if vector.is_empty() {
562 Err(invalid_input(field, "empty"))
563 } else {
564 Ok(())
565 }
566}
567
568fn validate_vector(vector: &DVector<f64>, field: &'static str) -> Result<(), SolveError> {
569 crate::validate::finite_slice(vector.as_slice(), field).map_err(map_field_error)
570}
571
572fn validate_matrix(matrix: &DMatrix<f64>, field: &'static str) -> Result<(), SolveError> {
573 crate::validate::finite_slice(matrix.as_slice(), field).map_err(map_field_error)
574}
575
576fn map_field_error(error: crate::validate::FieldError) -> SolveError {
577 invalid_input(error.field(), error.reason())
578}
579
580fn invalid_input(field: &'static str, reason: &'static str) -> SolveError {
581 SolveError::InvalidInput { field, reason }
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// The 2-point relative step must be exactly sqrt(2^-52) = 2^-26.
    #[test]
    fn fd_rel_step_is_sqrt_eps() {
        assert_eq!(FD_REL_STEP_2POINT, (2.0_f64.powi(-52)).sqrt());
        assert_eq!(FD_REL_STEP_2POINT, 2.0_f64.powi(-26));
    }

    /// Positive, negative and zero parameters get signs +1, -1 and +1 respectively.
    #[test]
    fn fd_step_sign_convention() {
        let x0 = DVector::from_vec(vec![5.0, -2.0, 0.0]);
        let steps = fd_steps(&x0, FD_REL_STEP_2POINT).unwrap();
        assert_eq!(steps[0].sign_x0, 1.0);
        assert_eq!(steps[1].sign_x0, -1.0);
        // Zero is treated as non-negative.
        assert_eq!(steps[2].sign_x0, 1.0);
    }

    #[test]
    fn fd_steps_rejects_zero_relative_step() {
        let x0 = DVector::from_vec(vec![1.0]);
        assert_invalid_field(fd_steps(&x0, 0.0).unwrap_err(), "rel_step");
    }

    #[test]
    fn fd_steps_rejects_nonfinite_parameters() {
        let x0 = DVector::from_vec(vec![1.0, f64::NAN]);
        assert_invalid_field(fd_steps(&x0, FD_REL_STEP_2POINT).unwrap_err(), "parameters");
    }

    /// A perturbed residual whose length differs from `f0` must be rejected.
    #[test]
    fn jacobian_rejects_residual_length_mismatch() {
        let x0 = DVector::from_vec(vec![1.0, 2.0]);
        let f0 = DVector::from_vec(vec![1.0, 2.0]);
        let residual = |_: &DVector<f64>| DVector::from_vec(vec![1.0]);
        assert_invalid_field(jacobian_2point(residual, &x0, &f0).unwrap_err(), "residual");
    }

    #[test]
    fn cost_rejects_nonfinite_residual() {
        assert_invalid_field(
            cost(&DVector::from_vec(vec![1.0, f64::INFINITY])).unwrap_err(),
            "residual",
        );
    }

    /// Fits `y = a * exp(b * t) + c` from the start (5, -2, 2) with the default solver.
    #[test]
    fn exp_fit_converges() {
        let t = vec![0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0];
        let y = vec![
            3.0123, 2.2083, 1.6889, 1.3713, 1.0903, 0.9302, 0.8104, 0.6303,
        ];
        let tt = t.clone();
        let yy = y.clone();
        let residual = move |p: &DVector<f64>| {
            let (a, b, c) = (p[0], p[1], p[2]);
            DVector::from_iterator(
                tt.len(),
                tt.iter()
                    .zip(&yy)
                    .map(|(&tk, &yk)| a * (b * tk).exp() + c - yk),
            )
        };
        let problem = LeastSquaresProblem::new(residual, DVector::from_vec(vec![5.0, -2.0, 2.0]));
        let report = solve_trf(&problem, &SolveOptions::default()).unwrap();
        assert!(report.cost < 1.0, "cost did not reduce: {}", report.cost);
    }

    #[test]
    fn solve_trf_rejects_nonfinite_initial_residual() {
        fn residual(_: &DVector<f64>) -> DVector<f64> {
            DVector::from_element(1, f64::NAN)
        }
        let problem = LeastSquaresProblem::new(residual, DVector::from_element(1, 0.0));
        assert_invalid_field(
            solve_trf(&problem, &SolveOptions::default()).unwrap_err(),
            "residual",
        );
    }

    /// `f64::MAX` is finite, but squaring it overflows, so the cost check must fire.
    #[test]
    fn solve_trf_rejects_nonfinite_initial_cost() {
        fn residual(_: &DVector<f64>) -> DVector<f64> {
            DVector::from_element(1, f64::MAX)
        }
        let problem = LeastSquaresProblem::new(residual, DVector::from_element(1, 0.0));
        assert_invalid_field(
            solve_trf(&problem, &SolveOptions::default()).unwrap_err(),
            "cost",
        );
    }

    /// Call 0 is the initial residual and call 1 the single Jacobian perturbation.
    /// Call 2, the first trial step, returns NaN, which must surface as an error
    /// rather than as a converged result.
    #[test]
    fn solve_trf_rejects_nonfinite_trial_residual_instead_of_converging() {
        use std::cell::Cell;

        let calls = Cell::new(0usize);
        let residual = move |p: &DVector<f64>| {
            let call = calls.get();
            calls.set(call + 1);
            if call >= 2 {
                DVector::from_element(1, f64::NAN)
            } else {
                DVector::from_element(1, p[0])
            }
        };
        let problem = LeastSquaresProblem::new(residual, DVector::from_element(1, 1.0));
        assert_invalid_field(
            solve_trf(&problem, &SolveOptions::default()).unwrap_err(),
            "residual",
        );
    }

    #[test]
    fn solve_trf_rejects_invalid_options() {
        fn residual(p: &DVector<f64>) -> DVector<f64> {
            DVector::from_element(1, p[0])
        }
        let problem = LeastSquaresProblem::new(residual, DVector::from_element(1, 1.0));
        let opts = SolveOptions {
            gtol: f64::NAN,
            ..SolveOptions::default()
        };
        assert_invalid_field(solve_trf(&problem, &opts).unwrap_err(), "gtol");

        let opts = SolveOptions {
            max_nfev: 0,
            ..SolveOptions::default()
        };
        assert_invalid_field(solve_trf(&problem, &opts).unwrap_err(), "max_nfev");
    }

    /// Two residual entries with one weight must be rejected on the `weights` field.
    #[test]
    fn solve_trf_rejects_weight_residual_dimension_mismatch() {
        fn residual(_: &DVector<f64>) -> DVector<f64> {
            DVector::from_vec(vec![1.0, 2.0])
        }
        let problem = LeastSquaresProblem::with_weights(
            residual,
            DVector::from_element(1, 0.0),
            DVector::from_vec(vec![1.0]),
        );
        assert_invalid_field(
            solve_trf(&problem, &SolveOptions::default()).unwrap_err(),
            "weights",
        );
    }

    /// Asserts that `error` is `InvalidInput` for exactly the `expected` field.
    fn assert_invalid_field(error: SolveError, expected: &'static str) {
        match error {
            SolveError::InvalidInput { field, .. } => assert_eq!(field, expected),
            other => panic!("expected invalid input for {expected}, got {other:?}"),
        }
    }

    /// Same exponential-fit problem as `exp_fit_converges`, built from fixed-size arrays.
    fn exp_fit_problem() -> LeastSquaresProblem<impl Fn(&DVector<f64>) -> DVector<f64>> {
        let t = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0];
        let y = [
            3.0123, 2.2083, 1.6889, 1.3713, 1.0903, 0.9302, 0.8104, 0.6303,
        ];
        let residual = move |p: &DVector<f64>| {
            let (a, b, c) = (p[0], p[1], p[2]);
            DVector::from_iterator(
                t.len(),
                t.iter()
                    .zip(&y)
                    .map(|(&tk, &yk)| a * (b * tk).exp() + c - yk),
            )
        };
        LeastSquaresProblem::new(residual, DVector::from_vec(vec![5.0, -2.0, 2.0]))
    }

    /// Pins the exact bit patterns of the `OwnedGaussianFirstTie` solution and checks
    /// that a second run reproduces them bit-for-bit.
    #[test]
    fn owned_trf_converges_to_frozen_bits() {
        let problem = exp_fit_problem();
        let report = solve_trf_with(
            &problem,
            &SolveOptions::default(),
            TrustRegionSolve::OwnedGaussianFirstTie,
        )
        .unwrap();
        assert!(
            report.cost < 1.0,
            "owned cost did not reduce: {}",
            report.cost
        );
        assert_eq!(report.x[0].to_bits(), 0x4003c3674cdfadef);
        assert_eq!(report.x[1].to_bits(), 0xbfe799e0d1929220);
        assert_eq!(report.x[2].to_bits(), 0x3fe0d5c96d9d3b35);

        let again = solve_trf_with(
            &problem,
            &SolveOptions::default(),
            TrustRegionSolve::OwnedGaussianFirstTie,
        )
        .unwrap();
        for i in 0..3 {
            assert_eq!(report.x[i].to_bits(), again.x[i].to_bits());
        }
    }
}