1use nalgebra::{RealField, SMatrix, Scalar, convert};
2
/// Reasons why building a policy cache can fail.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Error {
    /// The penalty parameter `rho` was rejected by the positivity check.
    RhoNotPositive,
    /// The matrix `R + ρI + BᵀPB` could not be inverted.
    RpBPBNotInvertible,
    /// At least one cached matrix contained a non-finite entry (NaN or infinity).
    NonFiniteValues,
}
13
14pub trait Policy<T, const NX: usize, const NU: usize> {
15 fn update_active(&mut self, prim_residual: T, dual_residual: T) -> Option<T>;
19
20 fn get_active(&self) -> &FixedPolicy<T, NX, NU>;
22}
23
24#[derive(Debug)]
26pub struct FixedPolicy<T, const NX: usize, const NU: usize> {
27 pub(crate) rho: T,
29
30 pub(crate) nKlqr: SMatrix<T, NU, NX>,
32
33 pub(crate) Plqr: SMatrix<T, NX, NX>,
35
36 pub(crate) RpBPBi: SMatrix<T, NU, NU>,
38
39 pub(crate) AmBKt: SMatrix<T, NX, NX>,
41}
42
43impl<T, const NX: usize, const NU: usize> FixedPolicy<T, NX, NU>
44where
45 T: Scalar + RealField + Copy,
46{
47 pub fn new(
53 rho: T,
54 iters: usize,
55 A: &SMatrix<T, NX, NX>,
56 B: &SMatrix<T, NX, NU>,
57 Q: &SMatrix<T, NX, NX>,
58 R: &SMatrix<T, NU, NU>,
59 S: &SMatrix<T, NX, NU>,
60 ) -> Result<Self, Error> {
61 if !rho.is_positive() {
62 return Err(Error::RhoNotPositive);
63 }
64
65 let Q = Q.symmetric_part();
66 let R = R.symmetric_part();
67
68 let Q_aug = Q + SMatrix::from_diagonal_element(rho);
70 let R_aug = R + SMatrix::from_diagonal_element(rho);
71
72 let mut Klqr = SMatrix::zeros();
73 let mut Plqr = Q_aug.clone_owned();
74
75 for _ in 0..iters {
76 Klqr = (R_aug + B.transpose() * Plqr * B)
77 .try_inverse()
78 .ok_or(Error::RpBPBNotInvertible)?
79 * (S.transpose() + B.transpose() * Plqr * A);
80 Plqr = A.transpose() * Plqr * A - A.transpose() * Plqr * B * Klqr + Q_aug;
81 }
82
83 let RpBPBi = (R_aug + B.transpose() * Plqr * B)
84 .try_inverse()
85 .ok_or(Error::RpBPBNotInvertible)?;
86 let AmBKt = (A - B * Klqr).transpose();
87 let nKlqr = -Klqr;
88
89 ([].iter())
91 .chain(RpBPBi.iter())
92 .chain(AmBKt.iter())
93 .all(nalgebra::ComplexField::is_finite)
94 .then_some(FixedPolicy {
95 rho,
96 nKlqr,
97 Plqr,
98 RpBPBi,
99 AmBKt,
100 })
101 .ok_or(Error::NonFiniteValues)
102 }
103}
104
105impl<T, const NX: usize, const NU: usize> Policy<T, NX, NU> for FixedPolicy<T, NX, NU>
106where
107 T: Scalar + RealField + Copy,
108{
109 fn update_active(&mut self, _prim_residual: T, _dual_residual: T) -> Option<T> {
110 None
111 }
112
113 fn get_active(&self) -> &FixedPolicy<T, NX, NU> {
114 self
115 }
116}
117
118#[derive(Debug)]
120pub struct ArrayPolicy<T, const NX: usize, const NU: usize, const NUM: usize> {
121 threshold: T,
122 active_index: usize,
123 policies: [FixedPolicy<T, NX, NU>; NUM],
124}
125
126impl<T, const NX: usize, const NU: usize, const NUM: usize> ArrayPolicy<T, NX, NU, NUM>
127where
128 T: Scalar + RealField + Copy,
129{
130 #[allow(clippy::too_many_arguments)]
136 pub fn new(
137 central_rho: T,
138 threshold: T,
139 factor: T,
140 iters: usize,
141 A: &SMatrix<T, NX, NX>,
142 B: &SMatrix<T, NX, NU>,
143 Q: &SMatrix<T, NX, NX>,
144 R: &SMatrix<T, NU, NU>,
145 S: &SMatrix<T, NX, NU>,
146 ) -> Result<Self, Error> {
147 let active_index = NUM / 2;
148 let policies = crate::util::try_array_from_fn(|index| {
149 let diff = index as i32 - active_index as i32;
150 let mult = factor.powf(convert(f64::from(diff)));
151 let rho = central_rho * mult;
152 FixedPolicy::new(rho, iters, A, B, Q, R, S)
153 })?;
154
155 Ok(Self {
156 threshold,
157 active_index,
158 policies,
159 })
160 }
161}
162
163impl<T, const NX: usize, const NU: usize, const NUM: usize> Policy<T, NX, NU>
164 for ArrayPolicy<T, NX, NU, NUM>
165where
166 T: Scalar + RealField + Copy,
167{
168 fn update_active(&mut self, prim_residual: T, dual_residual: T) -> Option<T> {
169 let mut policy = &self.policies[self.active_index];
170 let prev_rho = policy.rho;
171
172 let dual_residual = dual_residual * prev_rho;
175
176 if prim_residual > dual_residual * self.threshold {
178 if self.active_index < NUM - 1 {
179 self.active_index += 1;
180 policy = &self.policies[self.active_index];
181 }
182 }
183 else if dual_residual > prim_residual * self.threshold && self.active_index > 0 {
185 self.active_index -= 1;
186 policy = &self.policies[self.active_index];
187 }
188
189 (prev_rho != policy.rho).then(|| prev_rho / policy.rho)
191 }
192
193 fn get_active(&self) -> &FixedPolicy<T, NX, NU> {
194 &self.policies[self.active_index]
195 }
196}