tinympc_rs/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![allow(non_snake_case)]
3
4use nalgebra::{RealField, SMatrix, SVector, SVectorView, SVectorViewMut, Scalar, convert};
5
6pub mod constraint;
7pub mod policy;
8pub mod project;
9
10pub use constraint::Constraint;
11pub use policy::{Error, Policy};
12
13pub use project::*;
14
15mod util;
16
17pub type LtiFn<T, const NX: usize, const NU: usize> =
18    fn(SVectorViewMut<T, NX>, SVectorView<T, NX>, SVectorView<T, NU>);
19
/// Why a call to `Solver::solve` stopped iterating.
///
/// The enum carries no data, so it also derives `Eq` and `Hash` and can be
/// used as a map/set key or compared with `Eq`-bounded APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TerminationReason {
    /// The solver converged to within the defined tolerances
    Converged,
    /// The solver reached the maximum number of iterations allowed
    MaxIters,
}
27
28#[derive(Debug)]
29pub struct Solver<
30    T,
31    POLICY: Policy<T, NX, NU>,
32    const NX: usize,
33    const NU: usize,
34    const HX: usize,
35    const HU: usize,
36> {
37    policy: POLICY,
38    state: State<T, NX, NU, HX, HU>,
39    pub config: Config<T>,
40}
41
/// Tuning parameters for the solver, exposed through `Solver::config`.
#[derive(Debug, Clone, Copy)]
pub struct Config<T> {
    /// The convergence tolerance for the primal residual (default 0.01)
    pub prim_tol: T,

    /// The convergence tolerance for the dual residual (default 0.01).
    ///
    /// It is compared against the dual residual scaled by the active `rho`.
    pub dual_tol: T,

    /// Maximum iterations without converging before terminating (default 50)
    pub max_iter: usize,

    /// Number of iterations between evaluating convergence (default 5).
    ///
    /// Residuals are evaluated on iterations whose index is a multiple of this
    /// value. With `0` they are only evaluated on the very first iteration.
    pub do_check: usize,

    /// Relaxation, values `1.5-1.8` may improve convergence (default 1.0)
    pub relaxation: T,
}
59
60#[derive(Debug)]
61pub struct State<T, const NX: usize, const NU: usize, const HX: usize, const HU: usize> {
62    // Linear state space model
63    A: SMatrix<T, NX, NX>,
64    B: SMatrix<T, NX, NU>,
65
66    // For sparse system dynamics
67    sys: Option<LtiFn<T, NX, NU>>,
68
69    // State and input tracking error predictions
70    ex: SMatrix<T, NX, HX>,
71    eu: SMatrix<T, NU, HU>,
72
73    // State tracking dynamics mismatch
74    cx: SMatrix<T, NX, HX>,
75    cp: SMatrix<T, NX, HX>,
76
77    // Linear cost matrices
78    q: SMatrix<T, NX, HX>,
79    r: SMatrix<T, NU, HU>,
80
81    // Riccati backward pass terms
82    p: SMatrix<T, NX, HX>,
83    d: SMatrix<T, NU, HU>,
84
85    // Number of iterations for latest solve
86    iter: usize,
87}
88
89pub struct Problem<
90    'a,
91    T,
92    C,
93    const NX: usize,
94    const NU: usize,
95    const HX: usize,
96    const HU: usize,
97    XProj = (),
98    UProj = (),
99> where
100    T: Scalar + RealField + Copy,
101    C: Policy<T, NX, NU>,
102    XProj: ProjectMulti<T, NX, HX>,
103    UProj: ProjectMulti<T, NU, HU>,
104{
105    mpc: &'a mut Solver<T, C, NX, NU, HX, HU>,
106    x_now: SVector<T, NX>,
107    x_ref: Option<&'a SMatrix<T, NX, HX>>,
108    u_ref: Option<&'a SMatrix<T, NU, HU>>,
109    x_con: Option<&'a mut [Constraint<T, XProj, NX, HX>]>,
110    u_con: Option<&'a mut [Constraint<T, UProj, NU, HU>]>,
111}
112
113impl<'a, T, C, XProj, UProj, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
114    Problem<'a, T, C, NX, NU, HX, HU, XProj, UProj>
115where
116    T: Scalar + RealField + Copy,
117    C: Policy<T, NX, NU>,
118    XProj: ProjectMulti<T, NX, HX>,
119    UProj: ProjectMulti<T, NU, HU>,
120{
121    /// Set the reference for state variables
122    #[must_use]
123    pub fn x_reference(mut self, x_ref: &'a SMatrix<T, NX, HX>) -> Self {
124        self.x_ref = Some(x_ref);
125        self
126    }
127
128    /// Set the reference for input variables
129    #[must_use]
130    pub fn u_reference(mut self, u_ref: &'a SMatrix<T, NU, HU>) -> Self {
131        self.u_ref = Some(u_ref);
132        self
133    }
134
135    /// Set constraints on the state variables
136    #[must_use]
137    pub fn x_constraints<Proj: ProjectMulti<T, NX, HX>>(
138        self,
139        x_con: &'a mut [Constraint<T, Proj, NX, HX>],
140    ) -> Problem<'a, T, C, NX, NU, HX, HU, Proj, UProj> {
141        Problem {
142            mpc: self.mpc,
143            x_now: self.x_now,
144            x_ref: self.x_ref,
145            u_ref: self.u_ref,
146            x_con: Some(x_con),
147            u_con: self.u_con,
148        }
149    }
150
151    /// Set constraints on the input variables
152    #[must_use]
153    pub fn u_constraints<Proj: ProjectMulti<T, NU, HU>>(
154        self,
155        u_con: &'a mut [Constraint<T, Proj, NU, HU>],
156    ) -> Problem<'a, T, C, NX, NU, HX, HU, XProj, Proj> {
157        Problem {
158            mpc: self.mpc,
159            x_now: self.x_now,
160            x_ref: self.x_ref,
161            u_ref: self.u_ref,
162            x_con: self.x_con,
163            u_con: Some(u_con),
164        }
165    }
166
167    /// Run the solver
168    #[must_use]
169    #[inline(never)]
170    pub fn solve(self) -> Solution<'a, T, NX, NU, HX, HU> {
171        self.mpc
172            .solve(self.x_now, self.x_ref, self.u_ref, self.x_con, self.u_con)
173    }
174}
175
176impl<T, C: Policy<T, NX, NU>, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
177    Solver<T, C, NX, NU, HX, HU>
178where
179    T: Scalar + RealField + Copy,
180{
181    #[must_use]
182    pub fn new(A: SMatrix<T, NX, NX>, B: SMatrix<T, NX, NU>, policy: C) -> Self {
183        // Compile-time guard against invalid horizon lengths
184        const {
185            assert!(HX > HU, "`HX` must be larger than `HU`");
186            assert!(HU > 0, "`HU` must be non-zero");
187        }
188
189        Self {
190            config: Config {
191                prim_tol: convert(1e-2),
192                dual_tol: convert(1e-2),
193                max_iter: 50,
194                do_check: 5,
195                relaxation: T::one(),
196            },
197            policy,
198            state: State {
199                A,
200                B,
201                sys: None,
202                cx: SMatrix::zeros(),
203                cp: SMatrix::zeros(),
204                q: SMatrix::zeros(),
205                r: SMatrix::zeros(),
206                p: SMatrix::zeros(),
207                d: SMatrix::zeros(),
208                ex: SMatrix::zeros(),
209                eu: SMatrix::zeros(),
210                iter: 0,
211            },
212        }
213    }
214
215    #[must_use]
216    pub fn with_sys(mut self, sys: LtiFn<T, NX, NU>) -> Self {
217        self.state.sys = Some(sys);
218        self
219    }
220
221    #[must_use]
222    pub fn initial_condition(
223        &mut self,
224        x_now: SVector<T, NX>,
225    ) -> Problem<'_, T, C, NX, NU, HX, HU> {
226        Problem {
227            mpc: self,
228            x_now,
229            x_ref: None,
230            u_ref: None,
231            x_con: None,
232            u_con: None,
233        }
234    }
235
236    #[must_use]
237    #[inline(never)]
238    pub fn solve<'a>(
239        &'a mut self,
240        x_now: SVector<T, NX>,
241        x_ref: Option<&'a SMatrix<T, NX, HX>>,
242        u_ref: Option<&'a SMatrix<T, NU, HU>>,
243        x_con: Option<&mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>]>,
244        u_con: Option<&mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>]>,
245    ) -> Solution<'a, T, NX, NU, HX, HU> {
246        let mut reason = TerminationReason::MaxIters;
247
248        // We flatten the None variant into an empty slice
249        let x_con = x_con.unwrap_or(&mut [][..]);
250        let u_con = u_con.unwrap_or(&mut [][..]);
251
252        // Set initial error state and warm-start constraints
253        self.set_initial_conditions(x_now, x_ref, u_ref);
254        self.warm_start_constraints(x_con, u_con);
255
256        let mut prim_residual = T::zero();
257        let mut dual_residual = T::zero();
258
259        self.state.iter = 0;
260        while self.state.iter < self.config.max_iter {
261            profiling::scope!("solve loop", format!("iter: {}", self.state.iter));
262
263            self.update_cost(x_con, u_con);
264
265            self.backward_pass();
266
267            self.forward_pass();
268
269            self.update_constraints(x_ref, u_ref, x_con, u_con);
270
271            if self.check_termination(&mut prim_residual, &mut dual_residual, x_con, u_con) {
272                reason = TerminationReason::Converged;
273                self.state.iter += 1;
274                break;
275            }
276
277            self.state.iter += 1;
278        }
279
280        Solution {
281            x_ref,
282            u_ref,
283            x: &self.state.ex,
284            u: &self.state.eu,
285            reason,
286            iterations: self.state.iter,
287            prim_residual,
288            dual_residual: dual_residual * self.policy.get_active().rho,
289        }
290    }
291
292    fn should_compute_residuals(&self) -> bool {
293        self.state.iter.is_multiple_of(self.config.do_check)
294    }
295
296    #[profiling::function]
297    fn set_initial_conditions(
298        &mut self,
299        x_now: SVector<T, NX>,
300        x_ref: Option<&SMatrix<T, NX, HX>>,
301        u_ref: Option<&SMatrix<T, NU, HU>>,
302    ) {
303        if let Some(x_ref) = x_ref {
304            profiling::scope!("affine state reference term");
305            x_now.sub_to(&x_ref.column(0), &mut self.state.ex.column_mut(0));
306            self.state.A.mul_to(x_ref, &mut self.state.cx);
307            for i in 0..HX - 1 {
308                let mut cx_col = self.state.cx.column_mut(i);
309                cx_col.axpy(-T::one(), &x_ref.column(i + 1), T::one());
310            }
311        } else {
312            self.state.ex.set_column(0, &x_now);
313        }
314
315        if let Some(u_ref) = u_ref {
316            profiling::scope!("affine input reference term");
317            for i in 0..HX - 1 {
318                let mut cx_col = self.state.cx.column_mut(i);
319                let u_ref_col = u_ref.column(i.min(HU - 1));
320                cx_col.gemv(-T::one(), &self.state.B, &u_ref_col, T::one());
321            }
322        }
323
324        self.update_tracking_mismatch_plqr();
325    }
326
327    fn update_tracking_mismatch_plqr(&mut self) {
328        // Note: using `sygemv` to exploit the symmetry of Plqr is actually
329        // slower than just doing a regular matrix-vector multiplication,
330        // since sygemv adds additional indexing overhead.
331        let policy = self.policy.get_active();
332        policy.Plqr.mul_to(&self.state.cx, &mut self.state.cp);
333    }
334
335    /// Shift the dual variables by one time step for more accurate hot starting
336    #[profiling::function]
337    fn warm_start_constraints(
338        &mut self,
339        x_con: &mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>],
340        u_con: &mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>],
341    ) {
342        for con in x_con {
343            util::shift_columns_left(&mut con.dual);
344            util::shift_columns_left(&mut con.slac);
345        }
346
347        for con in u_con {
348            util::shift_columns_left(&mut con.dual);
349            util::shift_columns_left(&mut con.slac);
350        }
351    }
352
353    /// Update linear control cost terms based on constraint violations
354    #[profiling::function]
355    fn update_cost(
356        &mut self,
357        x_con: &mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>],
358        u_con: &mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>],
359    ) {
360        let s = &mut self.state;
361        let c = self.policy.get_active();
362
363        // Add cost contribution for state constraint violations
364        let mut x_con_iter = x_con.iter_mut();
365        if let Some(x_con_first) = x_con_iter.next() {
366            profiling::scope!("update state cost");
367            x_con_first.set_cost(&mut s.q);
368            for x_con_next in x_con_iter {
369                x_con_next.add_cost(&mut s.q);
370            }
371            s.q.scale_mut(c.rho);
372        } else {
373            s.q = SMatrix::<T, NX, HX>::zeros();
374        }
375
376        // Add cost contribution for input constraint violations
377        let mut u_con_iter = u_con.iter_mut();
378        if let Some(u_con_first) = u_con_iter.next() {
379            profiling::scope!("update input cost");
380            u_con_first.set_cost(&mut s.r);
381            for u_con_next in u_con_iter {
382                u_con_next.add_cost(&mut s.r);
383            }
384            s.r.scale_mut(c.rho);
385        } else {
386            s.r = SMatrix::<T, NU, HU>::zeros();
387        }
388
389        // Extract ADMM cost term for Riccati terminal condition
390        s.p.set_column(HX - 1, &(s.q.column(HX - 1)));
391    }
392
393    /// Backward pass to update Ricatti variables
394    #[profiling::function]
395    fn backward_pass(&mut self) {
396        let s = &mut self.state;
397        let c = self.policy.get_active();
398
399        for i in (0..HX - 1).rev() {
400            let (mut p_now, mut p_fut) = util::column_pair_mut(&mut s.p, i, i + 1);
401            let mut r_col = s.r.column_mut(i.min(HU - 1));
402
403            // Reused calculation: [[[i+1]]] <- (p[i+1] + Plqr * w[i])
404            p_fut.axpy(T::one(), &s.cp.column(i), T::one());
405
406            // Calc: p[i] = AmBKt * [[[i+1]]] - Klqr' * r[:,u_index] + q[i]
407            p_now.gemv(T::one(), &c.AmBKt, &p_fut, T::zero());
408            p_now.gemv_tr(T::one(), &c.nKlqr, &r_col, T::one());
409            p_now.axpy(T::one(), &s.q.column(i), T::one());
410
411            if i < HU {
412                let mut d_col = s.d.column_mut(i);
413
414                // Calc: d[i] = RpBPBi * (B' * [[[i+1]]] + r[i])
415                r_col.gemv_tr(T::one(), &s.B, &p_fut, T::one());
416                d_col.gemv(T::one(), &c.RpBPBi, &r_col, T::zero());
417            }
418        }
419    }
420
421    /// Use LQR feedback policy to roll out trajectory
422    #[profiling::function]
423    fn forward_pass(&mut self) {
424        let s = &mut self.state;
425        let c = self.policy.get_active();
426
427        if let Some(system) = s.sys {
428            // Roll out trajectory up to the control horizon (HU)
429            for i in 0..HU {
430                let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
431                let mut u_col = s.eu.column_mut(i);
432
433                u_col.gemv(T::one(), &c.nKlqr, &ex_now, T::zero());
434                u_col.axpy(-T::one(), &s.d.column(i), T::one());
435
436                system(ex_fut.as_view_mut(), ex_now.as_view(), u_col.as_view());
437                ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
438            }
439
440            // Roll out rest of trajectory keeping u constant
441            for i in HU..HX - 1 {
442                let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
443                let u_col = s.eu.column(HU - 1);
444
445                system(ex_fut.as_view_mut(), ex_now.as_view(), u_col.as_view());
446                ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
447            }
448        } else {
449            // Roll out trajectory up to the control horizon (HU)
450            for i in 0..HU {
451                let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
452                let mut u_col = s.eu.column_mut(i);
453
454                // Calc: u[i] = -Klqr * ex[i] + d[i]
455                u_col.gemv(T::one(), &c.nKlqr, &ex_now, T::zero());
456                u_col.axpy(-T::one(), &s.d.column(i), T::one());
457
458                // Calc x[i+1] = A * x[i] + B * u[i] + w[i]
459                ex_fut.gemv(T::one(), &s.A, &ex_now, T::zero());
460                ex_fut.gemv(T::one(), &s.B, &u_col, T::one());
461                ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
462            }
463
464            // Roll out rest of trajectory keeping u constant
465            for i in HU..HX - 1 {
466                let (ex_now, mut ex_fut) = util::column_pair_mut(&mut s.ex, i, i + 1);
467                let u_col = s.eu.column(HU - 1);
468
469                // Calc x[i+1] = A * x[i] + B * u[i] + w[i]
470                ex_fut.gemv(T::one(), &s.A, &ex_now, T::zero());
471                ex_fut.gemv(T::one(), &s.B, &u_col, T::one());
472                ex_fut.axpy(T::one(), &s.cx.column(i), T::one());
473            }
474        }
475    }
476
477    /// Project slack variables into their feasible domain and update dual variables
478    #[profiling::function]
479    fn update_constraints(
480        &mut self,
481        x_ref: Option<&SMatrix<T, NX, HX>>,
482        u_ref: Option<&SMatrix<T, NU, HU>>,
483        x_con: &mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>],
484        u_con: &mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>],
485    ) {
486        let compute_residuals = self.should_compute_residuals();
487        let s = &mut self.state;
488
489        let (x_points, u_points) = if self.config.relaxation == T::one() {
490            // Just use original predictions
491            (&s.ex, &s.eu)
492        } else {
493            profiling::scope!("apply relaxation to state and input");
494
495            // Use Riccati matrices to store relaxed x and u matrices.
496            s.q.copy_from(&s.ex);
497            s.r.copy_from(&s.eu);
498
499            let alpha = self.config.relaxation;
500
501            s.q.scale_mut(alpha);
502            s.r.scale_mut(alpha);
503
504            for con in x_con.iter() {
505                for (mut prim, slac) in s.q.column_iter_mut().zip(con.slac.column_iter()) {
506                    prim.axpy(T::one() - alpha, &slac, T::one());
507                }
508            }
509
510            for con in u_con.iter() {
511                for (mut prim, slac) in s.r.column_iter_mut().zip(con.slac.column_iter()) {
512                    prim.axpy(T::one() - alpha, &slac, T::one());
513                }
514            }
515
516            // Buffers now contain: x' = alpha * x + (1 - alpha) * z
517            (&s.q, &s.r)
518        };
519
520        // Use cost matrices as scratch buffers
521        let u_scratch = &mut s.d;
522        let x_scratch = &mut s.p;
523
524        for con in x_con {
525            con.constrain(compute_residuals, x_points, x_ref, x_scratch);
526        }
527
528        for con in u_con {
529            con.constrain(compute_residuals, u_points, u_ref, u_scratch);
530        }
531    }
532
533    /// Check for termination condition by evaluating residuals
534    #[profiling::function]
535    fn check_termination(
536        &mut self,
537        max_prim_residual: &mut T,
538        max_dual_residual: &mut T,
539        x_con: &mut [Constraint<T, impl ProjectMulti<T, NX, HX>, NX, HX>],
540        u_con: &mut [Constraint<T, impl ProjectMulti<T, NU, HU>, NU, HU>],
541    ) -> bool {
542        let c = self.policy.get_active();
543        let cfg = &self.config;
544
545        if !self.should_compute_residuals() {
546            return false;
547        }
548
549        *max_prim_residual = T::zero();
550        *max_dual_residual = T::zero();
551
552        for con in x_con.iter() {
553            *max_prim_residual = (*max_prim_residual).max(con.max_prim_residual);
554            *max_dual_residual = (*max_dual_residual).max(con.max_dual_residual);
555        }
556
557        for con in u_con.iter() {
558            *max_prim_residual = (*max_prim_residual).max(con.max_prim_residual);
559            *max_dual_residual = (*max_dual_residual).max(con.max_dual_residual);
560        }
561
562        let terminate =
563            *max_prim_residual < cfg.prim_tol && *max_dual_residual * c.rho < cfg.dual_tol;
564
565        // Try to adapt rho
566        if !terminate
567            && let Some(scalar) = self
568                .policy
569                .update_active(*max_prim_residual, *max_dual_residual)
570        {
571            profiling::scope!("policy updated, rescale all dual variables");
572
573            self.update_tracking_mismatch_plqr();
574
575            for con in x_con.iter_mut() {
576                con.rescale_dual(scalar);
577            }
578
579            for con in u_con.iter_mut() {
580                con.rescale_dual(scalar);
581            }
582        }
583
584        terminate
585    }
586}
587
588pub struct Solution<'a, T, const NX: usize, const NU: usize, const HX: usize, const HU: usize> {
589    x_ref: Option<&'a SMatrix<T, NX, HX>>,
590    u_ref: Option<&'a SMatrix<T, NU, HU>>,
591    x: &'a SMatrix<T, NX, HX>,
592    u: &'a SMatrix<T, NU, HU>,
593    pub reason: TerminationReason,
594    pub iterations: usize,
595    pub prim_residual: T,
596    pub dual_residual: T,
597}
598
599impl<T: RealField + Copy, const NX: usize, const NU: usize, const HX: usize, const HU: usize>
600    Solution<'_, T, NX, NU, HX, HU>
601{
602    /// Get the predicticted states for the `at` index
603    pub fn x_prediction(&self, at: usize) -> SVector<T, NX> {
604        if let Some(x_ref) = self.x_ref.as_ref() {
605            self.x.column(at) + x_ref.column(at)
606        } else {
607            self.x.column(at).clone_owned()
608        }
609    }
610
611    /// Get the predicticted input for the `at` index
612    pub fn u_prediction(&self, at: usize) -> SVector<T, NU> {
613        if let Some(u_ref) = self.u_ref.as_ref() {
614            self.u.column(at) + u_ref.column(at)
615        } else {
616            self.u.column(at).clone_owned()
617        }
618    }
619
620    /// Get the predictiction of states for this solution
621    pub fn x_prediction_full(&self) -> SMatrix<T, NX, HX> {
622        if let Some(x_ref) = self.x_ref.as_ref() {
623            self.x + *x_ref
624        } else {
625            self.x.clone_owned()
626        }
627    }
628
629    /// Get the predictiction of inputs for this solution
630    pub fn u_prediction_full(&self) -> SMatrix<T, NU, HU> {
631        if let Some(u_ref) = self.u_ref.as_ref() {
632            self.u + *u_ref
633        } else {
634            self.u.clone_owned()
635        }
636    }
637
638    /// Get the current contron input to be applied
639    pub fn u_now(&self) -> SVector<T, NU> {
640        self.u_prediction(0)
641    }
642}