tinympc_rs/
policy.rs

1use nalgebra::{RealField, SMatrix, Scalar, convert};
2
/// Errors that can occur while setting up a policy.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum Error {
    /// The value of rho must be strictly positive `(rho > 0)`
    RhoNotPositive,
    /// The matrix `R_aug + B^T * P * B` is not invertible
    RpBPBNotInvertible,
    /// The resulting matrices contained non-finite elements (Inf or NaN)
    NonFiniteValues,
}

// Human-readable messages so the error can be reported with `{}` formatting.
// Uses `core::fmt` so it also works in `no_std` builds.
impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Error::RhoNotPositive => "the value of rho must be strictly positive (rho > 0)",
            Error::RpBPBNotInvertible => "the matrix R_aug + B^T * P * B is not invertible",
            Error::NonFiniteValues => {
                "the computed policy matrices contain non-finite values (Inf or NaN)"
            }
        };
        f.write_str(message)
    }
}
13
14pub trait Policy<T, const NX: usize, const NU: usize> {
15    /// Updates which policy is active by evaluating the primal and dual residuals.
16    ///
17    /// Returns: A scalar (`old_rho/new_rho`) to be applied to constraint duals in case the policy changed
18    fn update_active(&mut self, prim_residual: T, dual_residual: T) -> Option<T>;
19
20    /// Get a reference to the currently active policy.
21    fn get_active(&self) -> &FixedPolicy<T, NX, NU>;
22}
23
24/// Contains all pre-computed values for a given problem and value of rho.
25#[derive(Debug)]
26pub struct FixedPolicy<T, const NX: usize, const NU: usize> {
27    /// Penalty-parameter for this policy
28    pub(crate) rho: T,
29
30    /// (Negated) Infinite-time horizon LQR gain
31    pub(crate) nKlqr: SMatrix<T, NU, NX>,
32
33    /// Infinite-time horizon LQR cost-to-go
34    pub(crate) Plqr: SMatrix<T, NX, NX>,
35
36    /// Precomputed `inv(R_aug + B^T * Plqr * B)`
37    pub(crate) RpBPBi: SMatrix<T, NU, NU>,
38
39    /// Precomputed `(A - B * Klqr)^T`
40    pub(crate) AmBKt: SMatrix<T, NX, NX>,
41}
42
43impl<T, const NX: usize, const NU: usize> FixedPolicy<T, NX, NU>
44where
45    T: Scalar + RealField + Copy,
46{
47    /// Create a new `FixedPolicy`.
48    ///
49    /// # Errors
50    ///
51    /// If the calculated LQR gain is not invertible, or any of the calculated values are not normal.
52    pub fn new(
53        rho: T,
54        iters: usize,
55        A: &SMatrix<T, NX, NX>,
56        B: &SMatrix<T, NX, NU>,
57        Q: &SMatrix<T, NX, NX>,
58        R: &SMatrix<T, NU, NU>,
59        S: &SMatrix<T, NX, NU>,
60    ) -> Result<Self, Error> {
61        if !rho.is_positive() {
62            return Err(Error::RhoNotPositive);
63        }
64
65        let Q = Q.symmetric_part();
66        let R = R.symmetric_part();
67
68        // ADMM-augmented cost matrices for LQR problem
69        let Q_aug = Q + SMatrix::from_diagonal_element(rho);
70        let R_aug = R + SMatrix::from_diagonal_element(rho);
71
72        let mut Klqr = SMatrix::zeros();
73        let mut Plqr = Q_aug.clone_owned();
74
75        for _ in 0..iters {
76            Klqr = (R_aug + B.transpose() * Plqr * B)
77                .try_inverse()
78                .ok_or(Error::RpBPBNotInvertible)?
79                * (S.transpose() + B.transpose() * Plqr * A);
80            Plqr = A.transpose() * Plqr * A - A.transpose() * Plqr * B * Klqr + Q_aug;
81        }
82
83        let RpBPBi = (R_aug + B.transpose() * Plqr * B)
84            .try_inverse()
85            .ok_or(Error::RpBPBNotInvertible)?;
86        let AmBKt = (A - B * Klqr).transpose();
87        let nKlqr = -Klqr;
88
89        // If RpBPBi and AmBKt are finite, so are all the other values
90        ([].iter())
91            .chain(RpBPBi.iter())
92            .chain(AmBKt.iter())
93            .all(nalgebra::ComplexField::is_finite)
94            .then_some(FixedPolicy {
95                rho,
96                nKlqr,
97                Plqr,
98                RpBPBi,
99                AmBKt,
100            })
101            .ok_or(Error::NonFiniteValues)
102    }
103}
104
105impl<T, const NX: usize, const NU: usize> Policy<T, NX, NU> for FixedPolicy<T, NX, NU>
106where
107    T: Scalar + RealField + Copy,
108{
109    fn update_active(&mut self, _prim_residual: T, _dual_residual: T) -> Option<T> {
110        None
111    }
112
113    fn get_active(&self) -> &FixedPolicy<T, NX, NU> {
114        self
115    }
116}
117
118/// Contains an array of pre-computed values for a given problem and value of rho
119#[derive(Debug)]
120pub struct ArrayPolicy<T, const NX: usize, const NU: usize, const NUM: usize> {
121    threshold: T,
122    active_index: usize,
123    policies: [FixedPolicy<T, NX, NU>; NUM],
124}
125
126impl<T, const NX: usize, const NU: usize, const NUM: usize> ArrayPolicy<T, NX, NU, NUM>
127where
128    T: Scalar + RealField + Copy,
129{
130    /// Create a new `ArrayPolicy` with a length of `NUM` .
131    ///
132    /// # Errors
133    ///
134    /// If any of calculated LQR gain are not invertible, or any of the calculated values are not normal.
135    #[allow(clippy::too_many_arguments)]
136    pub fn new(
137        central_rho: T,
138        threshold: T,
139        factor: T,
140        iters: usize,
141        A: &SMatrix<T, NX, NX>,
142        B: &SMatrix<T, NX, NU>,
143        Q: &SMatrix<T, NX, NX>,
144        R: &SMatrix<T, NU, NU>,
145        S: &SMatrix<T, NX, NU>,
146    ) -> Result<Self, Error> {
147        let active_index = NUM / 2;
148        let policies = crate::util::try_array_from_fn(|index| {
149            let diff = index as i32 - active_index as i32;
150            let mult = factor.powf(convert(f64::from(diff)));
151            let rho = central_rho * mult;
152            FixedPolicy::new(rho, iters, A, B, Q, R, S)
153        })?;
154
155        Ok(Self {
156            threshold,
157            active_index,
158            policies,
159        })
160    }
161}
162
163impl<T, const NX: usize, const NU: usize, const NUM: usize> Policy<T, NX, NU>
164    for ArrayPolicy<T, NX, NU, NUM>
165where
166    T: Scalar + RealField + Copy,
167{
168    fn update_active(&mut self, prim_residual: T, dual_residual: T) -> Option<T> {
169        let mut policy = &self.policies[self.active_index];
170        let prev_rho = policy.rho;
171
172        // TODO: It seems to work better without this?
173        // Since we are using a scaled dual formulation
174        let dual_residual = dual_residual * prev_rho;
175
176        // For much larger primal residuals, increase rho
177        if prim_residual > dual_residual * self.threshold {
178            if self.active_index < NUM - 1 {
179                self.active_index += 1;
180                policy = &self.policies[self.active_index];
181            }
182        }
183        // For much larger dual residuals, decrease rho
184        else if dual_residual > prim_residual * self.threshold && self.active_index > 0 {
185            self.active_index -= 1;
186            policy = &self.policies[self.active_index];
187        }
188
189        // If the value of rho changed we must also rescale all duals
190        (prev_rho != policy.rho).then(|| prev_rho / policy.rho)
191    }
192
193    fn get_active(&self) -> &FixedPolicy<T, NX, NU> {
194        &self.policies[self.active_index]
195    }
196}