1#![cfg_attr(not(feature = "std"), no_std)]
2#![allow(non_snake_case)]
3
4use nalgebra::{RealField, SMatrix, SVector, SVectorView, SVectorViewMut, Scalar, convert};
5
6pub mod constraint;
7pub mod policy;
8pub mod project;
9
10pub use constraint::Constraint;
11pub use policy::{Error, Policy};
12
13pub use project::*;
14
15mod util;
16
17pub type LtiFn<T, const NX: usize, const NU: usize> =
18 fn(SVectorViewMut<T, NX>, SVectorView<T, NX>, SVectorView<T, NU>);
19
/// Why a call to the solver's `solve` stopped iterating.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum TerminationReason {
    /// Primal and (rho-scaled) dual residuals both dropped below their tolerances.
    Converged,
    /// `max_iter` iterations ran without the convergence check passing.
    MaxIters,
}
27
28#[derive(Debug)]
29pub struct Solver<
30 T,
31 POLICY: Policy<T, NX, NU>,
32 const NX: usize,
33 const NU: usize,
34 const HX: usize,
35 const HU: usize,
36> {
37 policy: POLICY,
38 state: State<T, NX, NU, HX, HU>,
39 pub config: Config<T>,
40}
41
42#[derive(Debug)]
43pub struct Config<T> {
44 pub prim_tol: T,
46
47 pub dual_tol: T,
49
50 pub max_iter: usize,
52
53 pub do_check: usize,
55
56 pub relaxation: T,
58}
59
60#[derive(Debug)]
61pub struct State<T, const NX: usize, const NU: usize, const HX: usize, const HU: usize> {
62 A: SMatrix<T, NX, NX>,
64 B: SMatrix<T, NX, NU>,
65
66 sys: Option<LtiFn<T, NX, NU>>,
68
69 ex: SMatrix<T, NX, HX>,
71 eu: SMatrix<T, NU, HU>,
72
73 cx: SMatrix<T, NX, HX>,
75 cp: SMatrix<T, NX, HX>,
76
77 q: SMatrix<T, NX, HX>,
79 r: SMatrix<T, NU, HU>,
80
81 p: SMatrix<T, NX, HX>,
83 d: SMatrix<T, NU, HU>,
84
85 iter: usize,
87}
88
89pub struct Problem<
90 'a,
91 T,
92 C,
93 const NX: usize,
94 const NU: usize,
95 const HX: usize,
96 const HU: usize,
97 XProj = (),
98 UProj = (),
99> where
100 T: Scalar + RealField + Copy,
101 C: Policy<T, NX, NU>,
102 XProj: ProjectMulti<T, NX, HX>,
103 UProj: ProjectMulti<T, NU, HU>,
104{
105 mpc: &'a mut Solver<T, C, NX, NU, HX, HU>,
106 x_now: SVector<T, NX>,
107 x_ref: Option<&'a SMatrix<T, NX, HX>>,
108 u_ref: Option<&'a SMatrix<T, NU, HU>>,
109 x_con: Option<&'a mut [Constraint<T, XProj, NX, HX>]>,
110 u_con: Option<&'a mut [Constraint<T, UProj, NU, HU>]>,
111}
112
113impl<'a, T, C, XProj, UProj, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
114 Problem<'a, T, C, NX, NU, HX, HU, XProj, UProj>
115where
116 T: Scalar + RealField + Copy,
117 C: Policy<T, NX, NU>,
118 XProj: ProjectMulti<T, NX, HX>,
119 UProj: ProjectMulti<T, NU, HU>,
120{
121 #[must_use]
123 pub fn x_reference(mut self, x_ref: &'a SMatrix<T, NX, HX>) -> Self {
124 self.x_ref = Some(x_ref);
125 self
126 }
127
128 #[must_use]
130 pub fn u_reference(mut self, u_ref: &'a SMatrix<T, NU, HU>) -> Self {
131 self.u_ref = Some(u_ref);
132 self
133 }
134
135 #[must_use]
137 pub fn x_constraints<Proj: ProjectMulti<T, NX, HX>>(
138 self,
139 x_con: &'a mut [Constraint<T, Proj, NX, HX>],
140 ) -> Problem<'a, T, C, NX, NU, HX, HU, Proj, UProj> {
141 Problem {
142 mpc: self.mpc,
143 x_now: self.x_now,
144 x_ref: self.x_ref,
145 u_ref: self.u_ref,
146 x_con: Some(x_con),
147 u_con: self.u_con,
148 }
149 }
150
151 #[must_use]
153 pub fn u_constraints<Proj: ProjectMulti<T, NU, HU>>(
154 self,
155 u_con: &'a mut [Constraint<T, Proj, NU, HU>],
156 ) -> Problem<'a, T, C, NX, NU, HX, HU, XProj, Proj> {
157 Problem {
158 mpc: self.mpc,
159 x_now: self.x_now,
160 x_ref: self.x_ref,
161 u_ref: self.u_ref,
162 x_con: self.x_con,
163 u_con: Some(u_con),
164 }
165 }
166
167 #[must_use]
169 #[inline(never)]
170 pub fn solve(self) -> Solution<'a, T, NX, NU, HX, HU> {
171 self.mpc
172 .solve(self.x_now, self.x_ref, self.u_ref, self.x_con, self.u_con)
173 }
174}
175
176impl<T, C: Policy<T, NX, NU>, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
177 Solver<T, C, NX, NU, HX, HU>
178where
179 T: Scalar + RealField + Copy,
180{
181 #[must_use]
182 pub fn new(A: SMatrix<T, NX, NX>, B: SMatrix<T, NX, NU>, policy: C) -> Self {
183 const {
185 assert!(HX > HU, "`HX` must be larger than `HU`");
186 assert!(HU > 0, "`HU` must be non-zero");
187 }
188
189 Self {
190 config: Config {
191 prim_tol: convert(1e-2),
192 dual_tol: convert(1e-2),
193 max_iter: 50,
194 do_check: 5,
195 relaxation: T::one(),
196 },
197 policy,
198 state: State {
199 A,
200 B,
201 sys: None,
202 cx: SMatrix::zeros(),
203 cp: SMatrix::zeros(),
204 q: SMatrix::zeros(),
205 r: SMatrix::zeros(),
206 p: SMatrix::zeros(),
207 d: SMatrix::zeros(),
208 ex: SMatrix::zeros(),
209 eu: SMatrix::zeros(),
210 iter: 0,
211 },
212 }
213 }
214
215 #[must_use]
216 pub fn with_sys(mut self, sys: LtiFn<T, NX, NU>) -> Self {
217 self.state.sys = Some(sys);
218 self
219 }
220
221 #[must_use]
222 pub fn initial_condition(
223 &mut self,
224 x_now: SVector<T, NX>,
225 ) -> Problem<'_, T, C, NX, NU, HX, HU> {
226 Problem {
227 mpc: self,
228 x_now,
229 x_ref: None,
230 u_ref: None,
231 x_con: None,
232 u_con: None,
233 }
234 }
235
236 #[must_use]
237 #[inline(never)]
238 pub fn solve<'a>(
239 &'a mut self,
240 x_now: SVector<T, NX>,
241 x_ref: Option<&'a SMatrix<T, NX, HX>>,
242 u_ref: Option<&'a SMatrix<T, NU, HU>>,
243 x_con: Option<&mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>]>,
244 u_con: Option<&mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>]>,
245 ) -> Solution<'a, T, NX, NU, HX, HU> {
246 let mut reason = TerminationReason::MaxIters;
247
248 let x_con = x_con.unwrap_or(&mut [][..]);
250 let u_con = u_con.unwrap_or(&mut [][..]);
251
252 self.set_initial_conditions(x_now, x_ref, u_ref);
254 self.warm_start_constraints(x_con, u_con);
255
256 let mut prim_residual = T::zero();
257 let mut dual_residual = T::zero();
258
259 self.state.iter = 0;
260 while self.state.iter < self.config.max_iter {
261 profiling::scope!("solve loop", format!("iter: {}", self.state.iter));
262
263 self.update_cost(x_con, u_con);
264
265 self.backward_pass();
266
267 self.forward_pass();
268
269 self.update_constraints(x_ref, u_ref, x_con, u_con);
270
271 if self.check_termination(&mut prim_residual, &mut dual_residual, x_con, u_con) {
272 reason = TerminationReason::Converged;
273 self.state.iter += 1;
274 break;
275 }
276
277 self.state.iter += 1;
278 }
279
280 Solution {
281 x_ref,
282 u_ref,
283 x: &self.state.ex,
284 u: &self.state.eu,
285 reason,
286 iterations: self.state.iter,
287 prim_residual,
288 dual_residual: dual_residual * self.policy.get_active().rho,
289 }
290 }
291
292 fn should_compute_residuals(&self) -> bool {
293 self.state.iter.is_multiple_of(self.config.do_check)
294 }
295
296 #[profiling::function]
297 fn set_initial_conditions(
298 &mut self,
299 x_now: SVector<T, NX>,
300 x_ref: Option<&SMatrix<T, NX, HX>>,
301 u_ref: Option<&SMatrix<T, NU, HU>>,
302 ) {
303 if let Some(x_ref) = x_ref {
304 profiling::scope!("affine state reference term");
305 x_now.sub_to(&x_ref.column(0), &mut self.state.ex.column_mut(0));
306 self.state.A.mul_to(x_ref, &mut self.state.cx);
307 for i in 0..HX - 1 {
308 let mut cx_col = self.state.cx.column_mut(i);
309 cx_col.axpy(-T::one(), &x_ref.column(i + 1), T::one());
310 }
311 } else {
312 self.state.ex.set_column(0, &x_now);
313 }
314
315 if let Some(u_ref) = u_ref {
316 profiling::scope!("affine input reference term");
317 for i in 0..HX - 1 {
318 let mut cx_col = self.state.cx.column_mut(i);
319 let u_ref_col = u_ref.column(i.min(HU - 1));
320 cx_col.gemv(-T::one(), &self.state.B, &u_ref_col, T::one());
321 }
322 }
323
324 self.update_tracking_mismatch_plqr();
325 }
326
327 fn update_tracking_mismatch_plqr(&mut self) {
328 let policy = self.policy.get_active();
332 policy.Plqr.mul_to(&self.state.cx, &mut self.state.cp);
333 }
334
335 #[profiling::function]
337 fn warm_start_constraints(
338 &mut self,
339 x_con: &mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>],
340 u_con: &mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>],
341 ) {
342 for con in x_con {
343 util::shift_columns_left(&mut con.dual);
344 util::shift_columns_left(&mut con.slac);
345 }
346
347 for con in u_con {
348 util::shift_columns_left(&mut con.dual);
349 util::shift_columns_left(&mut con.slac);
350 }
351 }
352
353 #[profiling::function]
355 fn update_cost(
356 &mut self,
357 x_con: &mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>],
358 u_con: &mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>],
359 ) {
360 let s = &mut self.state;
361 let c = self.policy.get_active();
362
363 let mut x_con_iter = x_con.iter_mut();
365 if let Some(x_con_first) = x_con_iter.next() {
366 profiling::scope!("update state cost");
367 x_con_first.set_cost(&mut s.q);
368 for x_con_next in x_con_iter {
369 x_con_next.add_cost(&mut s.q);
370 }
371 s.q.scale_mut(c.rho);
372 } else {
373 s.q = SMatrix::<T, NX, HX>::zeros();
374 }
375
376 let mut u_con_iter = u_con.iter_mut();
378 if let Some(u_con_first) = u_con_iter.next() {
379 profiling::scope!("update input cost");
380 u_con_first.set_cost(&mut s.r);
381 for u_con_next in u_con_iter {
382 u_con_next.add_cost(&mut s.r);
383 }
384 s.r.scale_mut(c.rho);
385 } else {
386 s.r = SMatrix::<T, NU, HU>::zeros();
387 }
388
389 s.p.set_column(HX - 1, &(s.q.column(HX - 1)));
391 }
392
393 #[profiling::function]
395 fn backward_pass(&mut self) {
396 let s = &mut self.state;
397 let c = self.policy.get_active();
398
399 for i in (0..HX - 1).rev() {
400 let (mut p_now, mut p_fut) = util::column_pair_mut(&mut s.p, i, i + 1);
401 let mut r_col = s.r.column_mut(i.min(HU - 1));
402
403 p_fut.axpy(T::one(), &s.cp.column(i), T::one());
405
406 p_now.gemv(T::one(), &c.AmBKt, &p_fut, T::zero());
408 p_now.gemv_tr(T::one(), &c.nKlqr, &r_col, T::one());
409 p_now.axpy(T::one(), &s.q.column(i), T::one());
410
411 if i < HU {
412 let mut d_col = s.d.column_mut(i);
413
414 r_col.gemv_tr(T::one(), &s.B, &p_fut, T::one());
416 d_col.gemv(T::one(), &c.RpBPBi, &r_col, T::zero());
417 }
418 }
419 }
420
421 #[profiling::function]
423 fn forward_pass(&mut self) {
424 let s = &mut self.state;
425 let c = self.policy.get_active();
426
427 if let Some(system) = s.sys {
428 for i in 0..HU {
430 let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
431 let mut u_col = s.eu.column_mut(i);
432
433 u_col.gemv(T::one(), &c.nKlqr, &ex_now, T::zero());
434 u_col.axpy(-T::one(), &s.d.column(i), T::one());
435
436 system(ex_fut.as_view_mut(), ex_now.as_view(), u_col.as_view());
437 ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
438 }
439
440 for i in HU..HX - 1 {
442 let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
443 let u_col = s.eu.column(HU - 1);
444
445 system(ex_fut.as_view_mut(), ex_now.as_view(), u_col.as_view());
446 ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
447 }
448 } else {
449 for i in 0..HU {
451 let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
452 let mut u_col = s.eu.column_mut(i);
453
454 u_col.gemv(T::one(), &c.nKlqr, &ex_now, T::zero());
456 u_col.axpy(-T::one(), &s.d.column(i), T::one());
457
458 ex_fut.gemv(T::one(), &s.A, &ex_now, T::zero());
460 ex_fut.gemv(T::one(), &s.B, &u_col, T::one());
461 ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
462 }
463
464 for i in HU..HX - 1 {
466 let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
467 let u_col = s.eu.column(HU - 1);
468
469 ex_fut.gemv(T::one(), &s.A, &ex_now, T::zero());
471 ex_fut.gemv(T::one(), &s.B, &u_col, T::one());
472 ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
473 }
474 }
475 }
476
477 #[profiling::function]
479 fn update_constraints(
480 &mut self,
481 x_ref: Option<&SMatrix<T, NX, HX>>,
482 u_ref: Option<&SMatrix<T, NU, HU>>,
483 x_con: &mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>],
484 u_con: &mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>],
485 ) {
486 let compute_residuals = self.should_compute_residuals();
487 let s = &mut self.state;
488
489 let (x_points, u_points) = if self.config.relaxation == T::one() {
490 (&s.ex, &s.eu)
492 } else {
493 profiling::scope!("apply relaxation to state and input");
494
495 s.q.copy_from(&s.ex);
497 s.r.copy_from(&s.eu);
498
499 let alpha = self.config.relaxation;
500
501 s.q.scale_mut(alpha);
502 s.r.scale_mut(alpha);
503
504 for con in x_con.iter() {
505 for (mut prim, slac) in s.q.column_iter_mut().zip(con.slac.column_iter()) {
506 prim.axpy(T::one() - alpha, &slac, T::one());
507 }
508 }
509
510 for con in u_con.iter() {
511 for (mut prim, slac) in s.r.column_iter_mut().zip(con.slac.column_iter()) {
512 prim.axpy(T::one() - alpha, &slac, T::one());
513 }
514 }
515
516 (&s.q, &s.r)
518 };
519
520 let u_scratch = &mut s.d;
522 let x_scratch = &mut s.p;
523
524 for con in x_con {
525 con.constrain(compute_residuals, x_points, x_ref, x_scratch);
526 }
527
528 for con in u_con {
529 con.constrain(compute_residuals, u_points, u_ref, u_scratch);
530 }
531 }
532
533 #[profiling::function]
535 fn check_termination(
536 &mut self,
537 max_prim_residual: &mut T,
538 max_dual_residual: &mut T,
539 x_con: &mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>],
540 u_con: &mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>],
541 ) -> bool {
542 let c = self.policy.get_active();
543 let cfg = &self.config;
544
545 if !self.should_compute_residuals() {
546 return false;
547 }
548
549 *max_prim_residual = T::zero();
550 *max_dual_residual = T::zero();
551
552 for con in x_con.iter() {
553 *max_prim_residual = (*max_prim_residual).max(con.max_prim_residual);
554 *max_dual_residual = (*max_dual_residual).max(con.max_dual_residual);
555 }
556
557 for con in u_con.iter() {
558 *max_prim_residual = (*max_prim_residual).max(con.max_prim_residual);
559 *max_dual_residual = (*max_dual_residual).max(con.max_dual_residual);
560 }
561
562 let terminate =
563 *max_prim_residual < cfg.prim_tol && *max_dual_residual * c.rho < cfg.dual_tol;
564
565 if !terminate
567 && let Some(scalar) = self
568 .policy
569 .update_active(*max_prim_residual, *max_dual_residual)
570 {
571 profiling::scope!("policy updated, rescale all dual variables");
572
573 self.update_tracking_mismatch_plqr();
574
575 for con in x_con.iter_mut() {
576 con.rescale_dual(scalar);
577 }
578
579 for con in u_con.iter_mut() {
580 con.rescale_dual(scalar);
581 }
582 }
583
584 terminate
585 }
586}
587
588pub struct Solution<'a, T, const NX: usize, const NU: usize, const HX: usize, const HU: usize> {
589 x_ref: Option<&'a SMatrix<T, NX, HX>>,
590 u_ref: Option<&'a SMatrix<T, NU, HU>>,
591 x: &'a SMatrix<T, NX, HX>,
592 u: &'a SMatrix<T, NU, HU>,
593 pub reason: TerminationReason,
594 pub iterations: usize,
595 pub prim_residual: T,
596 pub dual_residual: T,
597}
598
599impl<T: RealField + Copy, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
600 Solution<'_, T, NX, NU, HX, HU>
601{
602 pub fn x_prediction(&self, at: usize) -> SVector<T, NX> {
604 if let Some(x_ref) = self.x_ref.as_ref() {
605 self.x.column(at) + x_ref.column(at)
606 } else {
607 self.x.column(at).clone_owned()
608 }
609 }
610
611 pub fn u_prediction(&self, at: usize) -> SVector<T, NU> {
613 if let Some(u_ref) = self.u_ref.as_ref() {
614 self.u.column(at) + u_ref.column(at)
615 } else {
616 self.u.column(at).clone_owned()
617 }
618 }
619
620 pub fn x_prediction_full(&self) -> SMatrix<T, NX, HX> {
622 if let Some(x_ref) = self.x_ref.as_ref() {
623 self.x + *x_ref
624 } else {
625 self.x.clone_owned()
626 }
627 }
628
629 pub fn u_prediction_full(&self) -> SMatrix<T, NU, HU> {
631 if let Some(u_ref) = self.u_ref.as_ref() {
632 self.u + *u_ref
633 } else {
634 self.u.clone_owned()
635 }
636 }
637
638 pub fn u_now(&self) -> SVector<T, NU> {
640 self.u_prediction(0)
641 }
642}