tiny_solver/
parameter_block.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use nalgebra as na;
7use num_dual::DualDVec64;
8
9use crate::manifold::Manifold;
10
11#[derive(Clone)]
12pub struct ParameterBlock {
13    pub params: na::DVector<f64>,
14    pub fixed_variables: HashSet<usize>,
15    pub variable_bounds: HashMap<usize, (f64, f64)>,
16    pub manifold: Option<Arc<dyn Manifold + Sync + Send>>,
17}
18
19impl ParameterBlock {
20    pub fn from_vec(params: na::DVector<f64>) -> Self {
21        ParameterBlock {
22            params,
23            fixed_variables: HashSet::new(),
24            variable_bounds: HashMap::new(),
25            manifold: None,
26        }
27    }
28    pub fn set_manifold(&mut self, manifold: Arc<dyn Manifold + Sync + Send>) {
29        self.manifold = Some(manifold);
30    }
31    pub fn ambient_size(&self) -> usize {
32        self.params.shape().0
33    }
34    pub fn tangent_size(&self) -> usize {
35        if let Some(m) = &self.manifold {
36            m.tangent_size().get()
37        } else {
38            self.ambient_size()
39        }
40    }
41    pub fn plus_f64(&self, dx: na::DVectorView<f64>) -> na::DVector<f64> {
42        let mut new_param = na::DVector::zeros(self.ambient_size());
43        if let Some(m) = &self.manifold {
44            new_param = m.plus_f64(self.params.as_view(), dx);
45        } else {
46            self.params.add_to(&dx, &mut new_param);
47        }
48        new_param
49    }
50    pub fn plus_dual(&self, dx: na::DVectorView<DualDVec64>) -> na::DVector<DualDVec64> {
51        let mut new_param = na::DVector::zeros(self.ambient_size());
52        if let Some(m) = &self.manifold {
53            new_param = m.plus_dual(self.params.clone().cast::<DualDVec64>().as_view(), dx);
54        } else {
55            self.params.clone().cast().add_to(&dx, &mut new_param);
56        }
57        new_param
58    }
59    pub fn y_minus_f64(&self, y: na::DVectorView<f64>) -> na::DVector<f64> {
60        let mut delta_x = na::DVector::zeros(self.tangent_size());
61        if let Some(m) = &self.manifold {
62            delta_x = m.minus_f64(y, self.params.as_view());
63        } else {
64            y.sub_to(&self.params, &mut delta_x);
65        }
66        delta_x
67    }
68    pub fn y_minus_dual(&self, y: na::DVectorView<DualDVec64>) -> na::DVector<DualDVec64> {
69        let mut delta_x = na::DVector::zeros(self.tangent_size());
70        if let Some(m) = &self.manifold {
71            delta_x = m.minus_dual(y, self.params.clone().cast().as_view());
72        } else {
73            y.sub_to(&self.params.clone().cast(), &mut delta_x);
74        }
75        delta_x
76    }
77    pub fn update_params(&mut self, mut new_param: na::DVector<f64>) {
78        // bound
79        for (&idx, &(lower, upper)) in &self.variable_bounds {
80            new_param[idx] = new_param[idx].max(lower).min(upper);
81        }
82
83        // fix
84        for &index_to_fix in &self.fixed_variables {
85            new_param[index_to_fix] = self.params[index_to_fix];
86        }
87        self.params = new_param;
88    }
89}