1use nalgebra::{RealField, SMatrix, Scalar, convert};
2
/// Reasons why building a cache can fail.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Error {
    /// The penalty parameter `rho` was rejected by the positivity check.
    RhoNotPositive,

    /// The matrix `R_aug + B^T * P * B` could not be inverted.
    RpBPBNotInvertible,

    /// At least one of the precomputed matrix entries was not finite.
    NonFiniteValues,
}
13
14pub trait Cache<T, const NX: usize, const NU: usize> {
15 fn update_active(&mut self, prim_residual: T, dual_residual: T) -> Option<T>;
19
20 fn get_active(&self) -> &SingleCache<T, NX, NU>;
22}
23
24#[derive(Debug)]
26pub struct SingleCache<T, const NX: usize, const NU: usize> {
27 pub(crate) rho: T,
29
30 pub(crate) nKlqr: SMatrix<T, NU, NX>,
32
33 pub(crate) Plqr: SMatrix<T, NX, NX>,
35
36 pub(crate) RpBPBi: SMatrix<T, NU, NU>,
38
39 pub(crate) AmBKt: SMatrix<T, NX, NX>,
41}
42
43impl<T, const NX: usize, const NU: usize> SingleCache<T, NX, NU>
44where
45 T: Scalar + RealField + Copy,
46{
47 pub fn new(
48 rho: T,
49 iters: usize,
50 A: &SMatrix<T, NX, NX>,
51 B: &SMatrix<T, NX, NU>,
52 Q: &SMatrix<T, NX, NX>,
53 R: &SMatrix<T, NU, NU>,
54 S: &SMatrix<T, NX, NU>,
55 ) -> Result<Self, Error> {
56 if !rho.is_positive() {
57 return Err(Error::RhoNotPositive);
58 }
59
60 let Q = Q.symmetric_part();
61 let R = R.symmetric_part();
62
63 let Q_aug = Q + SMatrix::from_diagonal_element(rho);
65 let R_aug = R + SMatrix::from_diagonal_element(rho);
66
67 let mut Klqr = SMatrix::zeros();
68 let mut Plqr = Q_aug.clone_owned();
69
70 for _ in 0..iters {
71 Klqr = (R_aug + B.transpose() * Plqr * B)
72 .try_inverse()
73 .ok_or(Error::RpBPBNotInvertible)?
74 * (S.transpose() + B.transpose() * Plqr * A);
75 Plqr = A.transpose() * Plqr * A - A.transpose() * Plqr * B * Klqr + Q_aug;
76 }
77
78 let RpBPBi = (R_aug + B.transpose() * Plqr * B)
79 .try_inverse()
80 .ok_or(Error::RpBPBNotInvertible)?;
81 let AmBKt = (A - B * Klqr).transpose();
82 let nKlqr = -Klqr;
83
84 ([].iter())
86 .chain(RpBPBi.iter())
87 .chain(AmBKt.iter())
88 .all(|x| x.is_finite())
89 .then_some(SingleCache {
90 rho,
91 nKlqr,
92 Plqr,
93 RpBPBi,
94 AmBKt,
95 })
96 .ok_or(Error::NonFiniteValues)
97 }
98}
99
100impl<T, const NX: usize, const NU: usize> Cache<T, NX, NU> for SingleCache<T, NX, NU>
101where
102 T: Scalar + RealField + Copy,
103{
104 fn update_active(&mut self, _prim_residual: T, _dual_residual: T) -> Option<T> {
105 None
106 }
107
108 fn get_active(&self) -> &SingleCache<T, NX, NU> {
109 self
110 }
111}
112
113#[derive(Debug)]
115pub struct ArrayCache<T, const NX: usize, const NU: usize, const NUM: usize> {
116 threshold: T,
117 active_index: usize,
118 caches: [SingleCache<T, NX, NU>; NUM],
119}
120
121impl<T, const NX: usize, const NU: usize, const NUM: usize> ArrayCache<T, NX, NU, NUM>
122where
123 T: Scalar + RealField + Copy,
124{
125 pub fn new(
126 central_rho: T,
127 threshold: T,
128 factor: T,
129 iters: usize,
130 A: &SMatrix<T, NX, NX>,
131 B: &SMatrix<T, NX, NU>,
132 Q: &SMatrix<T, NX, NX>,
133 R: &SMatrix<T, NU, NU>,
134 S: &SMatrix<T, NX, NU>,
135 ) -> Result<Self, Error> {
136 let active_index = NUM / 2;
137 let caches = crate::util::try_array_from_fn(|index| {
138 let diff = index as i32 - active_index as i32;
139 let mult = factor.powf(convert(diff as f64));
140 let rho = central_rho * mult;
141 SingleCache::new(rho, iters, A, B, Q, R, S)
142 })?;
143
144 Ok(Self {
145 threshold,
146 active_index,
147 caches,
148 })
149 }
150}
151
152impl<T, const NX: usize, const NU: usize, const NUM: usize> Cache<T, NX, NU>
153 for ArrayCache<T, NX, NU, NUM>
154where
155 T: Scalar + RealField + Copy,
156{
157 #[inline(always)]
158 fn update_active(&mut self, prim_residual: T, dual_residual: T) -> Option<T> {
159 let mut cache = &self.caches[self.active_index];
160 let prev_rho = cache.rho;
161
162 let dual_residual = dual_residual * prev_rho;
165
166 if prim_residual > dual_residual * self.threshold {
168 if self.active_index < NUM - 1 {
169 self.active_index += 1;
170 cache = &self.caches[self.active_index];
171 }
172 }
173 else if dual_residual > prim_residual * self.threshold {
175 if self.active_index > 0 {
176 self.active_index -= 1;
177 cache = &self.caches[self.active_index];
178 }
179 }
180
181 (prev_rho != cache.rho).then(|| prev_rho / cache.rho)
183 }
184
185 #[inline(always)]
186 fn get_active(&self) -> &SingleCache<T, NX, NU> {
187 &self.caches[self.active_index]
188 }
189}