tiny_solver/
parameter_block.rs1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4};
5
6use nalgebra as na;
7use num_dual::DualDVec64;
8
9use crate::manifold::Manifold;
10
11#[derive(Clone)]
12pub struct ParameterBlock {
13 pub params: na::DVector<f64>,
14 pub fixed_variables: HashSet<usize>,
15 pub variable_bounds: HashMap<usize, (f64, f64)>,
16 pub manifold: Option<Arc<dyn Manifold + Sync + Send>>,
17}
18
19impl ParameterBlock {
20 pub fn from_vec(params: na::DVector<f64>) -> Self {
21 ParameterBlock {
22 params,
23 fixed_variables: HashSet::new(),
24 variable_bounds: HashMap::new(),
25 manifold: None,
26 }
27 }
28 pub fn set_manifold(&mut self, manifold: Arc<dyn Manifold + Sync + Send>) {
29 self.manifold = Some(manifold);
30 }
31 pub fn ambient_size(&self) -> usize {
32 self.params.shape().0
33 }
34 pub fn tangent_size(&self) -> usize {
35 if let Some(m) = &self.manifold {
36 m.tangent_size().get()
37 } else {
38 self.ambient_size()
39 }
40 }
41 pub fn plus_f64(&self, dx: na::DVectorView<f64>) -> na::DVector<f64> {
42 let mut new_param = na::DVector::zeros(self.ambient_size());
43 if let Some(m) = &self.manifold {
44 new_param = m.plus_f64(self.params.as_view(), dx);
45 } else {
46 self.params.add_to(&dx, &mut new_param);
47 }
48 new_param
49 }
50 pub fn plus_dual(&self, dx: na::DVectorView<DualDVec64>) -> na::DVector<DualDVec64> {
51 let mut new_param = na::DVector::zeros(self.ambient_size());
52 if let Some(m) = &self.manifold {
53 new_param = m.plus_dual(self.params.clone().cast::<DualDVec64>().as_view(), dx);
54 } else {
55 self.params.clone().cast().add_to(&dx, &mut new_param);
56 }
57 new_param
58 }
59 pub fn y_minus_f64(&self, y: na::DVectorView<f64>) -> na::DVector<f64> {
60 let mut delta_x = na::DVector::zeros(self.tangent_size());
61 if let Some(m) = &self.manifold {
62 delta_x = m.minus_f64(y, self.params.as_view());
63 } else {
64 y.sub_to(&self.params, &mut delta_x);
65 }
66 delta_x
67 }
68 pub fn y_minus_dual(&self, y: na::DVectorView<DualDVec64>) -> na::DVector<DualDVec64> {
69 let mut delta_x = na::DVector::zeros(self.tangent_size());
70 if let Some(m) = &self.manifold {
71 delta_x = m.minus_dual(y, self.params.clone().cast().as_view());
72 } else {
73 y.sub_to(&self.params.clone().cast(), &mut delta_x);
74 }
75 delta_x
76 }
77 pub fn update_params(&mut self, mut new_param: na::DVector<f64>) {
78 for (&idx, &(lower, upper)) in &self.variable_bounds {
80 new_param[idx] = new_param[idx].max(lower).min(upper);
81 }
82
83 for &index_to_fix in &self.fixed_variables {
85 new_param[index_to_fix] = self.params[index_to_fix];
86 }
87 self.params = new_param;
88 }
89}