use nalgebra as na;
use num_dual;
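/// Convergence thresholds and damping controls for the Levenberg-Marquardt loop.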
pub struct SolverParameters {
    /// Stop when the largest absolute gradient component falls below this.
    gradient_threshold: f64,
    /// Stop when the step norm falls below this fraction of the parameter norm.
    relative_step_threshold: f64,
    /// Stop when the residual norm falls below this.
    error_threshold: f64,
    /// Initial damping `u` is this factor times the largest diagonal entry of JᵀJ.
    initial_scale_factor: f64,
    /// Upper bound on the number of iterations.
    max_iterations: usize,
}
impl Default for SolverParameters {
    fn default() -> SolverParameters {
SolverParameters {
gradient_threshold: 1e-16,
relative_step_threshold: 1e-16,
error_threshold: 1e-16,
initial_scale_factor: 1e-3,
max_iterations: 100,
}
}
}
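/// Reason the solver stopped iterating.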
#[derive(PartialEq, Debug)]
enum SolverStatus {
Running,
    GradientTooSmall,
    RelativeStepSizeTooSmall,
    ErrorTooSmall,
    HitMaxIterations,
}
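/// Summary of a `solve_inplace` run.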
#[derive(Debug)]
pub struct ProblemResult {
    error_magnitude: f64,
    gradient_magnitude: f64,
    num_failed_linear_solves: usize,
iterations: usize,
status: SolverStatus,
}
impl ProblemResult {
fn new() -> ProblemResult {
ProblemResult {
error_magnitude: 0.0,
gradient_magnitude: 0.0,
num_failed_linear_solves: 0,
iterations: 0,
status: SolverStatus::Running,
}
}
}
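/// Levenberg-Marquardt solver for small, fixed-size nonlinear least-squares
/// problems. Implementors supply `cost_function`; residuals and Jacobians are
/// obtained together via forward-mode automatic differentiation (`num_dual`).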
pub trait TinySolver<const NUM_PARAMETERS: usize, const NUM_RESIDUALS: usize> {
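    /// Evaluates the residual vector at `params`. The signature is over dual
    /// numbers so that `num_dual` can recover the Jacobian by forward-mode AD.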
fn cost_function(
_params: na::SVector<num_dual::DualSVec64<NUM_PARAMETERS>, NUM_PARAMETERS>,
) -> na::SVector<num_dual::DualSVec64<NUM_PARAMETERS>, NUM_RESIDUALS>;
fn solve_inplace(params: &mut na::SVector<f64, NUM_PARAMETERS>) -> ProblemResult {
        let solver_params = SolverParameters::default();
let mut result = ProblemResult::new();
        // LM damping `u` and its growth factor `v` (Nielsen's strategy); both are
        // initialized properly on the first iteration. `v` is an f64 because it is
        // repeatedly doubled and would overflow an integer under many rejections.
        let mut u: f64 = 0.0;
        let mut v: f64 = 2.0;
let mut residual = na::SMatrix::<f64, NUM_RESIDUALS, 1>::zeros();
let mut gradient = na::SMatrix::<f64, NUM_PARAMETERS, 1>::zeros();
let mut jac: na::SMatrix<f64, NUM_RESIDUALS, NUM_PARAMETERS>;
for step in 0..solver_params.max_iterations {
result.iterations = step + 1;
            // One pass of forward-mode AD yields both the residual and the Jacobian.
            (residual, jac) = num_dual::jacobian(Self::cost_function, *params);
            // Note the sign: `gradient` holds -Jᵀr, so the LM step below solves
            // (JᵀJ + uI) dx = -Jᵀr directly.
            gradient = -jac.transpose() * residual;
            let jtj = jac.transpose() * jac;
let max_gradient = gradient.abs().max();
if max_gradient < solver_params.gradient_threshold {
println!("gradient too small. {}", max_gradient);
result.status = SolverStatus::GradientTooSmall;
break;
} else if residual.norm() < solver_params.error_threshold {
result.status = SolverStatus::ErrorTooSmall;
break;
}
if step == 0 {
u = solver_params.initial_scale_factor * jtj.diagonal().max();
                v = 2.0;
}
            // Form the damped normal equations (JᵀJ + uI) dx = -Jᵀr on the stack;
            // no heap-allocated DMatrix is needed for fixed-size problems.
            let mut jtj_augmented = jtj;
            jtj_augmented.set_diagonal(&jtj.diagonal().add_scalar(u));
            // LU returns None when the damped system is singular.
            if let Some(dx) = na::linalg::LU::new(jtj_augmented).solve(&gradient) {
                if dx.norm() < solver_params.relative_step_threshold * params.norm() {
                    result.status = SolverStatus::RelativeStepSizeTooSmall;
                    break;
                }
                let params_new = *params + dx;
                // Re-evaluate the residual at the trial point (real part only).
                let residual_new =
                    Self::cost_function(params_new.map(num_dual::DualSVec64::from_re))
                        .map(|x| x.re);
                // Gain ratio: actual cost reduction over the reduction predicted by
                // the linear model, dxᵀ(u·dx + gradient), with gradient = -Jᵀr.
                let rho: f64 = (residual.norm_squared() - residual_new.norm_squared())
                    / dx.dot(&(u * dx + gradient));
                if rho > 0.0 {
                    // Good step: accept it and shrink the damping (Nielsen's rule).
                    *params = params_new;
                    let tmp: f64 = 2.0 * rho - 1.0;
                    u *= (1.0_f64 / 3.0).max(1.0 - tmp.powi(3));
                    v = 2.0;
                    continue;
                }
            } else {
                result.num_failed_linear_solves += 1;
            }
            // Rejected step or failed solve: increase the damping and retry.
            u *= v;
            v *= 2.0;
}
if result.status == SolverStatus::Running {
result.status = SolverStatus::HitMaxIterations;
}
result.error_magnitude = residual.norm();
result.gradient_magnitude = gradient.norm();
result
}
}
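// A minimal usage sketch, not part of the solver itself: the problem, its sample
// data, and the name `LineFitExample` are illustrative assumptions. It fits
// y = a * x + b to three points taken from the line y = 2x + 1, so the solver
// should recover a = 2, b = 1.
#[cfg(test)]
mod tests {
    use super::*;

    struct LineFitExample;

    impl TinySolver<2, 3> for LineFitExample {
        fn cost_function(
            params: na::SVector<num_dual::DualSVec64<2>, 2>,
        ) -> na::SVector<num_dual::DualSVec64<2>, 3> {
            // Lift a constant into the dual-number domain.
            let c = |x: f64| -> num_dual::DualSVec64<2> { num_dual::DualSVec64::from_re(x) };
            // residual_i = y_i - (a * x_i + b), with a = params[0], b = params[1],
            // for the samples (x, y) = (0, 1), (1, 3), (2, 5).
            na::SVector::from([
                c(1.0) - params[1].clone(),
                c(3.0) - (params[0].clone() + params[1].clone()),
                c(5.0) - (params[0].clone() * c(2.0) + params[1].clone()),
            ])
        }
    }

    #[test]
    fn recovers_line_parameters() {
        let mut params = na::SVector::<f64, 2>::zeros();
        let result = LineFitExample::solve_inplace(&mut params);
        assert!((params[0] - 2.0).abs() < 1e-6);
        assert!((params[1] - 1.0).abs() < 1e-6);
        assert!(result.error_magnitude < 1e-6);
    }
}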