strapdown/
filter.rs

1//! Inertial Navigation Filters
2//!
3//! This module contains implementations of various inertial navigation filters, including
4//! Kalman filters and particle filters. These filters are used to estimate the state of a
5//! strapdown inertial navigation system based on IMU measurements and other sensor data.
6//! The filters use the strapdown equations (provided by the StrapdownState) to propagate
7//! the state in the local level frame.
8//!
9//! Currently, this module contains an implementation of a full-state Unscented Kalman Filter
10//! (UKF) and a full-state particle filter. For completeness, an Extended Kalman Filter (EKF)
11//! should be included, however, a strapdown EKF INS is typically implemented as an error state
12//! filter, which would require a slightly different architecture.
13//!
//! This module also provides a few simple measurement models (GPS position, GPS velocity,
//! combined GPS position and velocity, and barometric relative altitude) that implement the
//! `MeasurementModel` trait. They map filter states into measurement space so the state can be
//! updated from observations in the local level frame (e.g. a GPS fix).
17use crate::earth::METERS_TO_DEGREES;
18use crate::linalg::{matrix_square_root, robust_spd_solve, symmetrize};
19use crate::{IMUData, StrapdownState, forward, wrap_to_2pi};
20
21use std::fmt::Debug;
22
23use nalgebra::{DMatrix, DVector, Rotation3};
24use rand;
25// ==== Measurement Models ========================================================
26// Below is a set of generic measurement models for the UKF and particle filter.
// These models map each state sigma point (one per matrix column) to its "expected
// measurement". When used in a filter, iterate over the sigma points (or particles) and
// the expected measurements in lockstep to compute the innovation covariance or the
// particle weights.
31// ================================================================================
32/// Generic measurement model trait for all filters
33pub trait MeasurementModel {
34    /// Get the dimensionality of the measurement vector.
35    fn get_dimension(&self) -> usize;
36    /// Get the measurement vector
37    fn get_vector(&self) -> DVector<f64>;
38    /// Get the measurement noise covariance matrix
39    fn get_noise(&self) -> DMatrix<f64>;
40    /// Get the measurement sigma points, performs the mapping between the state space
41    /// and the measurement space.
42    fn get_sigma_points(&self, state_sigma_points: &DMatrix<f64>) -> DMatrix<f64>;
43}
44/// GPS position measurement model
45#[derive(Clone, Debug, Default)]
46pub struct GPSPositionMeasurement {
47    // <-- Check this model for degree/radian consistency
48    /// latitude in degrees
49    pub latitude: f64,
50    /// longitude in degrees
51    pub longitude: f64,
52    /// altitude in meters
53    pub altitude: f64,
54    /// noise standard deviation in meters
55    pub horizontal_noise_std: f64,
56    /// vertical noise standard deviation in meters
57    pub vertical_noise_std: f64,
58}
59impl MeasurementModel for GPSPositionMeasurement {
60    fn get_dimension(&self) -> usize {
61        3 // latitude, longitude, altitude
62    }
63    fn get_vector(&self) -> DVector<f64> {
64        DVector::from_vec(vec![
65            self.latitude.to_radians(),
66            self.longitude.to_radians(),
67            self.altitude,
68        ])
69    }
70    fn get_noise(&self) -> DMatrix<f64> {
71        DMatrix::from_diagonal(&DVector::from_vec(vec![
72            (self.horizontal_noise_std * METERS_TO_DEGREES).powi(2),
73            (self.horizontal_noise_std * METERS_TO_DEGREES).powi(2),
74            self.vertical_noise_std.powi(2),
75        ]))
76    }
77    fn get_sigma_points(&self, state_sigma_points: &DMatrix<f64>) -> DMatrix<f64> {
78        let mut measurement_sigma_points = DMatrix::<f64>::zeros(3, state_sigma_points.ncols());
79        for (i, sigma_point) in state_sigma_points.column_iter().enumerate() {
80            measurement_sigma_points[(0, i)] = sigma_point[0];
81            measurement_sigma_points[(1, i)] = sigma_point[1];
82            measurement_sigma_points[(2, i)] = sigma_point[2];
83        }
84        measurement_sigma_points
85    }
86}
87/// GPS Velocity measurement model
88#[derive(Clone, Debug, Default)]
89pub struct GPSVelocityMeasurement {
90    /// Northward velocity in m/s
91    pub northward_velocity: f64,
92    /// Eastward velocity in m/s
93    pub eastward_velocity: f64,
94    /// Downward velocity in m/s
95    pub downward_velocity: f64,
96    /// noise standard deviation in m/s
97    pub horizontal_noise_std: f64,
98    /// vertical noise standard deviation in m/s
99    pub vertical_noise_std: f64,
100}
101impl MeasurementModel for GPSVelocityMeasurement {
102    fn get_dimension(&self) -> usize {
103        3 // northward, eastward, downward velocity
104    }
105    fn get_vector(&self) -> DVector<f64> {
106        DVector::from_vec(vec![
107            self.northward_velocity,
108            self.eastward_velocity,
109            self.downward_velocity,
110        ])
111    }
112    fn get_noise(&self) -> DMatrix<f64> {
113        DMatrix::from_diagonal(&DVector::from_vec(vec![
114            self.horizontal_noise_std.powi(2),
115            self.horizontal_noise_std.powi(2),
116            self.vertical_noise_std.powi(2),
117        ]))
118    }
119    fn get_sigma_points(&self, state_sigma_points: &DMatrix<f64>) -> DMatrix<f64> {
120        let mut measurement_sigma_points = DMatrix::<f64>::zeros(3, state_sigma_points.ncols());
121        for (i, sigma_point) in state_sigma_points.column_iter().enumerate() {
122            measurement_sigma_points[(0, i)] = sigma_point[3];
123            measurement_sigma_points[(1, i)] = sigma_point[4];
124            measurement_sigma_points[(2, i)] = sigma_point[5];
125        }
126        measurement_sigma_points
127    }
128}
129
130/// GPS Position and Velocity measurement model
131#[derive(Clone, Debug, Default)]
132pub struct GPSPositionAndVelocityMeasurement {
133    /// latitude in degrees
134    pub latitude: f64,
135    /// longitude in degrees
136    pub longitude: f64,
137    /// altitude in meters
138    pub altitude: f64,
139    /// Northward velocity in m/s
140    pub northward_velocity: f64,
141    /// Eastward velocity in m/s
142    pub eastward_velocity: f64,
143    /// Downward velocity in m/s
144    // pub downward_velocity: f64, // GPS speed measurements do not typically provide vertical velocity
145    /// noise standard deviation in meters for position
146    pub horizontal_noise_std: f64,
147    /// vertical noise standard deviation in meters for position
148    pub vertical_noise_std: f64,
149    /// noise standard deviation in m/s for velocity
150    pub velocity_noise_std: f64,
151}
152impl MeasurementModel for GPSPositionAndVelocityMeasurement {
153    fn get_dimension(&self) -> usize {
154        5 // latitude, longitude, altitude, northward velocity, eastward velocity
155    }
156    fn get_vector(&self) -> DVector<f64> {
157        DVector::from_vec(vec![
158            self.latitude.to_radians(),
159            self.longitude.to_radians(),
160            self.altitude,
161            self.northward_velocity,
162            self.eastward_velocity,
163        ])
164    }
165    fn get_noise(&self) -> DMatrix<f64> {
166        DMatrix::from_diagonal(&DVector::from_vec(vec![
167            (self.horizontal_noise_std * METERS_TO_DEGREES).powi(2),
168            (self.horizontal_noise_std * METERS_TO_DEGREES).powi(2),
169            self.vertical_noise_std.powi(2),
170            self.velocity_noise_std.powi(2),
171            self.velocity_noise_std.powi(2),
172        ]))
173    }
174    fn get_sigma_points(&self, state_sigma_points: &DMatrix<f64>) -> DMatrix<f64> {
175        let mut measurement_sigma_points = DMatrix::<f64>::zeros(5, state_sigma_points.ncols());
176        for (i, sigma_point) in state_sigma_points.column_iter().enumerate() {
177            measurement_sigma_points[(0, i)] = sigma_point[0];
178            measurement_sigma_points[(1, i)] = sigma_point[1];
179            measurement_sigma_points[(2, i)] = sigma_point[2];
180            measurement_sigma_points[(3, i)] = sigma_point[3];
181            measurement_sigma_points[(4, i)] = sigma_point[4];
182        }
183        measurement_sigma_points
184    }
185}
186
187/// A relative relative altitude measurement derived from barometric pressure.
188/// Note that this measurement model is an altitude measurement derived from
189/// a barometric altimeter and not a direct calculation of altitude from the
190/// barometric pressure.
191#[derive(Clone, Debug, Default)]
192pub struct RelativeAltitudeMeasurement {
193    /// Measured relative altitude in meters
194    pub relative_altitude: f64,
195    /// Reference pressure in Pa
196    pub reference_altitude: f64,
197}
198impl MeasurementModel for RelativeAltitudeMeasurement {
199    fn get_dimension(&self) -> usize {
200        1 // relative altitude
201    }
202    fn get_vector(&self) -> DVector<f64> {
203        DVector::from_vec(vec![self.relative_altitude + self.reference_altitude])
204    }
205    fn get_noise(&self) -> DMatrix<f64> {
206        DMatrix::from_diagonal(&DVector::from_vec(vec![5.0])) // 1 mm noise
207    }
208    fn get_sigma_points(&self, state_sigma_points: &DMatrix<f64>) -> DMatrix<f64> {
209        let mut measurement_sigma_points =
210            DMatrix::<f64>::zeros(self.get_dimension(), state_sigma_points.ncols());
211        for (i, sigma_point) in state_sigma_points.column_iter().enumerate() {
212            measurement_sigma_points[(0, i)] = sigma_point[2];
213        }
214        measurement_sigma_points
215    }
216}
/// Placeholder for a future gravity anomaly measurement model.
///
/// Currently has no fields.
#[derive(Clone, Debug, Default)]
pub struct GravityAnomalyMeasurement {
    // Placeholder
}

/// Placeholder for a future magnetic anomaly measurement model.
///
/// Currently has no fields.
#[derive(Clone, Debug, Default)]
pub struct MagneticAnomalyMeasurement {}
224
/// Basic strapdown state parameters for the UKF and particle filter initialization.
///
/// The units of the angular fields (latitude, longitude, roll, pitch, yaw) are selected by
/// `in_degrees`.
#[derive(Clone, Debug, Default)]
pub struct StrapdownParams {
    /// Latitude (degrees if `in_degrees`, otherwise radians)
    pub latitude: f64,
    /// Longitude (degrees if `in_degrees`, otherwise radians)
    pub longitude: f64,
    /// Altitude in meters
    pub altitude: f64,
    /// North velocity in m/s
    pub northward_velocity: f64,
    /// East velocity in m/s
    pub eastward_velocity: f64,
    /// Down velocity in m/s
    pub downward_velocity: f64,
    /// Roll angle (see `in_degrees` for units)
    pub roll: f64,
    /// Pitch angle (see `in_degrees` for units)
    pub pitch: f64,
    /// Yaw angle (see `in_degrees` for units)
    pub yaw: f64,
    /// `true` if the angular fields are given in degrees, `false` for radians. Filters store
    /// angles in radians internally and are expected to convert degree inputs.
    /// NOTE(review): confirm every consumer converts all angular fields, not only lat/lon.
    pub in_degrees: bool,
}
/// Strapdown Unscented Kalman Filter Inertial Navigation Filter
///
/// This filter uses the Unscented Kalman Filter (UKF) algorithm to estimate the state of a
/// strapdown inertial navigation system. It uses the strapdown equations to propagate the state
/// in the local level frame based on IMU measurements in the body frame. The filter also uses
/// a generic measurement model to update the state based on measurements in the local level frame.
///
/// Because of the generic nature of both the UKF and this toolbox, the filter requires the user to
/// implement the measurement model(s). The measurement model must calculate the measurement sigma points
/// ($\mathcal{Z} = h(\mathcal{X})$) and the measurement noise matrix ($R$) for the filter. Some basic
/// GNSS-based measurement models are provided in this module (position, velocity, position and velocity,
/// barometric altitude). In a given scenario's implementation, the user should then call these measurement
/// models. Please see the `sim` module for a reference implementation of a full state UKF INS with a
/// position and velocity GPS-based measurement model and barometric altitude measurement model.
///
/// Note that, internally, angles are always stored in radians (both for the attitude and the position).
/// The user may convert them to degrees after retrieving the state vector, and the UKF and underlying
/// strapdown state can be constructed from data in degrees by using the boolean `in_degrees` toggle where
/// applicable. Generally speaking, the design of this crate is such that methods that expect a WGS84
/// coordinate (e.g. latitude or longitude) will expect the value in degrees, whereas trigonometric
/// functions (e.g. sine, cosine, tangent) will expect the value in radians.
pub struct UKF {
    /// State mean `[lat, lon, alt, v_north, v_east, v_down, roll, pitch, yaw, imu_biases..., other_states...]`.
    mean_state: DVector<f64>,
    /// State covariance `P` (`state_size x state_size`).
    covariance: DMatrix<f64>,
    /// Additive process noise covariance `Q`, added to the predicted covariance in `predict`.
    process_noise: DMatrix<f64>,
    /// Sigma-point scaling parameter `alpha^2 (n + kappa) - n`, computed in `new`.
    lambda: f64,
    /// Number of states `n`.
    state_size: usize,
    /// Sigma-point weights used for the mean (length `2n + 1`).
    weights_mean: DVector<f64>,
    /// Sigma-point weights used for covariances (length `2n + 1`).
    weights_cov: DVector<f64>,
}
269impl Debug for UKF {
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        f.debug_struct("UKF")
272            .field("mean_state", &self.mean_state)
273            .field("covariance", &self.covariance)
274            .field("process_noise", &self.process_noise)
275            .field("lambda", &self.lambda)
276            .field("state_size", &self.state_size)
277            .finish()
278    }
279}
280impl UKF {
281    /// Creates a new UKF with the given initial state, biases, covariance, process noise,
282    /// any additional other states, and UKF hyper parameters.
283    ///
284    /// # Arguments
285    /// * `position` - The initial position of the strapdown state.
286    /// * `velocity` - The initial velocity of the strapdown state.
287    /// * `attitude` - The initial attitude of the strapdown state.
288    /// * `imu_biases` - The initial IMU biases.
289    /// * `other_states` - Any additional states the filter is estimating (ex: measurement or sensor bias).
290    /// * `covariance_diagonal` - The initial covariance diagonal.
291    /// * `process_noise_diagonal` - The process noise diagonal.
292    /// * `alpha` - The alpha parameter for the UKF.
293    /// * `beta` - The beta parameter for the UKF.
294    /// * `kappa` - The kappa parameter for the UKF.
295    /// * `in_degrees` - Whether the input vectors are in degrees or radians.
296    ///
297    /// # Returns
298    /// * A new UKF struct.
299    pub fn new(
300        strapdown_state: StrapdownParams,
301        imu_biases: Vec<f64>,
302        other_states: Option<Vec<f64>>,
303        covariance_diagonal: Vec<f64>,
304        process_noise: DMatrix<f64>,
305        alpha: f64,
306        beta: f64,
307        kappa: f64,
308    ) -> UKF {
309        assert!(
310            process_noise.nrows() == process_noise.ncols(),
311            "Process noise matrix must be square"
312        );
313        let mut mean = if strapdown_state.in_degrees {
314            vec![
315                strapdown_state.latitude.to_radians(),
316                strapdown_state.longitude.to_radians(),
317                strapdown_state.altitude,
318                strapdown_state.northward_velocity,
319                strapdown_state.eastward_velocity,
320                strapdown_state.downward_velocity,
321                strapdown_state.roll,
322                strapdown_state.pitch,
323                strapdown_state.yaw,
324            ]
325        } else {
326            vec![
327                strapdown_state.latitude,
328                strapdown_state.longitude,
329                strapdown_state.altitude,
330                strapdown_state.northward_velocity,
331                strapdown_state.eastward_velocity,
332                strapdown_state.downward_velocity,
333                strapdown_state.roll,
334                strapdown_state.pitch,
335                strapdown_state.yaw,
336            ]
337        };
338        mean.extend(imu_biases);
339        if let Some(ref other_states) = other_states {
340            mean.extend(other_states.iter().cloned());
341        }
342        assert!(
343            mean.len() >= 15,
344            "Expected a canonical state vector of at least 15 states (position, velocity, attitude, imu biases)"
345        );
346        assert!(
347            mean.len() == covariance_diagonal.len(),
348            "{}",
349            &format!(
350                "Mean vector and covariance diagonal must be of the same size (mean: {}, covariance_diagonal: {})",
351                mean.len(),
352                covariance_diagonal.len()
353            )
354        );
355        let state_size = mean.len();
356        let mean_state = DVector::from_vec(mean);
357        let covariance = DMatrix::<f64>::from_diagonal(&DVector::from_vec(covariance_diagonal));
358        assert!(
359            covariance.shape() == (state_size, state_size),
360            "Covariance matrix must be square"
361        );
362        assert!(
363            covariance.shape() == process_noise.shape(),
364            "Covariance and process noise must be of the same size"
365        );
366        let lambda = alpha * alpha * (state_size as f64 + kappa) - state_size as f64;
367        let mut weights_mean = DVector::zeros(2 * state_size + 1);
368        let mut weights_cov = DVector::zeros(2 * state_size + 1);
369        weights_mean[0] = lambda / (state_size as f64 + lambda);
370        weights_cov[0] = lambda / (state_size as f64 + lambda) + (1.0 - alpha * alpha + beta);
371        for i in 1..(2 * state_size + 1) {
372            let w = 1.0 / (2.0 * (state_size as f64 + lambda));
373            weights_mean[i] = w;
374            weights_cov[i] = w;
375        }
376        UKF {
377            mean_state,
378            covariance,
379            process_noise,
380            lambda,
381            state_size,
382            weights_mean,
383            weights_cov,
384        }
385    }
386    /// Predicts the state using the strapdown equations and IMU measurements.
387    ///
388    /// The IMU measurements are used to update the strapdown state in the local level frame.
389    /// The IMU measurements are assumed to be in the body frame.
390    ///
391    /// # Arguments
392    /// * `imu_data` - The IMU measurements to propagate the state with (e.g. relative accelerations (m/s^2) and angular rates (rad/s)).
393    /// * `dt` - The time step for the propagation.
394    ///
395    /// # Returns
396    /// * none
397    pub fn predict(&mut self, imu_data: IMUData, dt: f64) {
398        // Propagate the strapdown state using the strapdown equations
399        let mut sigma_points = self.get_sigma_points();
400        for i in 0..sigma_points.ncols() {
401            let mut sigma_point_vec = sigma_points.column(i).clone_owned();
402            let mut state = StrapdownState {
403                latitude: sigma_point_vec[0],
404                longitude: sigma_point_vec[1],
405                altitude: sigma_point_vec[2],
406                velocity_north: sigma_point_vec[3],
407                velocity_east: sigma_point_vec[4],
408                velocity_down: sigma_point_vec[5],
409                attitude: Rotation3::from_euler_angles(
410                    sigma_point_vec[6],
411                    sigma_point_vec[7],
412                    sigma_point_vec[8],
413                ),
414                coordinate_convention: true,
415            };
416            // println!("propagating: lat {}  lon {}", state.latitude.to_degrees(), state.longitude.to_degrees());
417            forward(&mut state, imu_data, dt);
418            // Update the sigma point with the new state
419            sigma_point_vec[0] = state.latitude;
420            sigma_point_vec[1] = state.longitude;
421            sigma_point_vec[2] = state.altitude;
422            sigma_point_vec[3] = state.velocity_north;
423            sigma_point_vec[4] = state.velocity_east;
424            sigma_point_vec[5] = state.velocity_down;
425            sigma_point_vec[6] = state.attitude.euler_angles().0; // Roll
426            sigma_point_vec[7] = state.attitude.euler_angles().1; // Pitch
427            sigma_point_vec[8] = state.attitude.euler_angles().2; // Yaw
428            sigma_points.set_column(i, &sigma_point_vec);
429        }
430        // Update the mean state as mu_bar
431        let mut mu_bar = DVector::<f64>::zeros(self.state_size);
432        for (i, sigma_point) in sigma_points.column_iter().enumerate() {
433            mu_bar += self.weights_mean[i] * sigma_point;
434        }
435        // Update the covariance as P_bar
436        let mut p_bar = DMatrix::<f64>::zeros(self.state_size, self.state_size);
437        for (i, sigma_point) in sigma_points.column_iter().enumerate() {
438            let diff = sigma_point - &mu_bar;
439            p_bar += self.weights_cov[i] * &diff * &diff.transpose();
440        }
441        // Add process noise to the covariance
442        p_bar += &self.process_noise;
443        // Update the mean state and covariance
444        self.mean_state = mu_bar;
445        self.covariance = symmetrize(&p_bar);
446    }
447    /// Get the UKF mean state.
448    pub fn get_mean(&self) -> DVector<f64> {
449        self.mean_state.clone()
450    }
451    /// Get the UKF covariance.
452    pub fn get_covariance(&self) -> DMatrix<f64> {
453        self.covariance.clone()
454    }
455    /// Convert a Vec<SigmaPoint> to a DMatrix<f64>
456    pub fn get_sigma_points(&self) -> DMatrix<f64> {
457        let p = (self.state_size as f64 + self.lambda) * self.covariance.clone();
458        let sqrt_p = matrix_square_root(&p);
459        let mu = self.mean_state.clone();
460        let mut pts = DMatrix::<f64>::zeros(self.state_size, 2 * self.state_size + 1);
461        pts.column_mut(0).copy_from(&mu);
462        for i in 0..sqrt_p.ncols() {
463            pts.column_mut(i + 1).copy_from(&(&mu + sqrt_p.column(i)));
464            pts.column_mut(i + 1 + self.state_size)
465                .copy_from(&(&mu - sqrt_p.column(i)));
466        }
467        pts
468    }
469    /// Perform the Kalman measurement update step.
470    ///
471    /// This method updates the state and covariance based on the measurement and measurement
472    /// sigma points. The measurement model is specific to a given implementation of the UKF
473    /// and must be provided by the user as the model determines the shape and quantities of
474    /// the measurement vector and the measurement sigma points. Measurement models should be
475    /// implemented as traits and applied to the UKF as needed.
476    ///
477    /// This module contains some standard GNSS-aided measurements models (`position_measurement_model`,
478    /// `velocity_measurement_model`, and `position_and_velocity_measurement_model`) that can be
479    /// used. See the `sim` module for a canonical example of a GPS-aided INS implementation
480    /// that uses these models.
481    ///
482    /// **Note**: Canonical INS implementations use a position measurement model. Typically,
483    /// position is reported in _degrees_ for latitude and longitude, and in meters for altitude.
484    /// Internally, the UKF stores the latitude and longitude in _radians_, and the measurement models make no
485    /// assumptions about the units of the position measurements. However, the user should
486    /// ensure that the provided measurement to this function is in the same units as the
487    /// measurement model.
488    ///
489    /// # Arguments
490    /// * `measurement` - The measurement vector to update the state with.
491    /// * `measurement_sigma_points` - The measurement sigma points to use for the update.
492    pub fn update<M: MeasurementModel>(&mut self, measurement: M) {
493        let measurement_sigma_points = measurement.get_sigma_points(&self.get_sigma_points());
494        // Calculate expected measurement
495        let mut z_hat = DVector::<f64>::zeros(measurement.get_dimension());
496        for (i, sigma_point) in measurement_sigma_points.column_iter().enumerate() {
497            z_hat += self.weights_mean[i] * sigma_point;
498        }
499        // Calculate innovation matrix S
500        let mut s = DMatrix::<f64>::zeros(measurement.get_dimension(), measurement.get_dimension());
501        //for i in 0..measurement_sigma_points.len() {
502        for (i, sigma_point) in measurement_sigma_points.column_iter().enumerate() {
503            let diff = sigma_point - &z_hat;
504            s += self.weights_cov[i] * &diff * &diff.transpose();
505        }
506        s += measurement.get_noise();
507        // Calculate the cross-covariance
508        let sigma_points = self.get_sigma_points();
509        let mut cross_covariance =
510            DMatrix::<f64>::zeros(self.state_size, measurement.get_dimension());
511        for (i, measurement_sigma_point) in measurement_sigma_points.column_iter().enumerate() {
512            let measurement_diff = measurement_sigma_point - &z_hat;
513            let state_diff = sigma_points.column(i) - &self.mean_state;
514            cross_covariance += self.weights_cov[i] * state_diff * measurement_diff.transpose();
515        }
516        // // Calculate the Kalman gain
517        // let s_inv = match s.clone().try_inverse() {
518        //     Some(inv) => inv,
519        //     None => panic!("Innovation matrix is singular"),
520        // };
521        // let k = &cross_covariance * &s_inv;
522        // // check that the kalman gain and measurement diff are compatible to multiply
523        // if k.ncols() != measurement.get_dimension() {
524        //     panic!("Kalman gain and measurement differential are not compatible");
525        // }
526        // K = P_xz * S^{-1} without forming S^{-1}
527        let k = self.robust_kalman_gain(&cross_covariance, &s);
528        // Update the mean and covariance
529        self.mean_state += &k * (measurement.get_vector() - &z_hat);
530        // wrap attitude angles to 2pi
531        // TODO: #30 Refactor attitude angles to use a more robust representation
532        self.mean_state[6] = wrap_to_2pi(self.mean_state[6]);
533        self.mean_state[7] = wrap_to_2pi(self.mean_state[7]);
534        self.mean_state[8] = wrap_to_2pi(self.mean_state[8]);
535        self.covariance -= &k * &s * &k.transpose();
536        // Re-symmetrize to fight round-off
537        self.covariance = 0.5 * (&self.covariance + self.covariance.transpose());
538    }
539    fn robust_kalman_gain(
540        &mut self,
541        cross_covariance: &DMatrix<f64>,
542        s: &DMatrix<f64>,
543    ) -> DMatrix<f64> {
544        // Solve S Kᵀ = P_xzᵀ  => K = (S^{-1} P_xz)ᵀ
545        let kt = robust_spd_solve(&symmetrize(s), &cross_covariance.transpose());
546        kt.transpose()
547    }
548}
/// A single particle: one hypothesis of the navigation state together with its weight.
#[derive(Clone, Debug, Default)]
pub struct Particle {
    /// The strapdown state of the particle
    pub nav_state: StrapdownState,
    /// The relative weight of the particle; `ParticleFilter::normalize_weights` rescales the
    /// weights of all particles to sum to 1.
    pub weight: f64,
}
556
557/// Particle filter for strapdown inertial navigation
558///
559/// This filter uses a particle filter algorithm to estimate the state of a strapdown inertial navigation system.
560/// Similarly to the UKF, it uses thin wrappers around the StrapdownState's forward function to propagate the state.
561/// The particle filter is a little more generic in implementation than the UKF, as all it fundamentally is is a set
562/// of particles and several related functions to propagate, update, and resample the particles.
563pub struct ParticleFilter {
564    /// The particles in the particle filter
565    pub particles: Vec<Particle>,
566}
567impl ParticleFilter {
568    /// Create a new particle filter with the given particles
569    ///
570    /// # Arguments
571    /// * `particles` - The particles to use for the particle filter.
572    pub fn new(particles: Vec<Particle>) -> Self {
573        ParticleFilter { particles }
574    }
575    /// Propagate all particles forward using the strapdown equations
576    ///
577    /// # Arguments
578    /// * `imu_data` - The IMU measurements to propagate the particles with.
579    pub fn propagate(&mut self, imu_data: &IMUData, dt: f64) {
580        for particle in &mut self.particles {
581            //particle.forward(*imu_data, dt, None);
582            forward(&mut particle.nav_state, *imu_data, dt);
583        }
584    }
585    /// Update the weights of the particles based on a measurement
586    ///
587    /// Generic measurement update function for the particle filter. This function requires the user to provide
588    /// a measurement vector and a list of expected measurements for each particle. This list of expected measurements
589    /// is the result of a measurement model that is specific to the filter implementation. This model determines
590    /// the shape and quantities of the measurement vector and the expected measurements sigma points. This module
591    /// contains some standard GNSS-aided measurements models (`position_measurement_model`,
592    /// `velocity_measurement_model`, and `position_and_velocity_measurement_model`) that can be used.
593    ///
594    /// **Note**: Canonical INS implementations use a position measurement model. Typically,
595    /// position is reported in _degrees_ for latitude and longitude, and in meters for altitude.
596    /// Internally, the particle filter stores the latitude and longitude in _radians_, and the measurement models
597    /// make no assumptions about the units of the position measurements. However, the user should
598    /// ensure that the provided measurement to this function is in the same units as the
599    /// measurement model.
600    pub fn update(&mut self, measurement: &DVector<f64>, expected_measurements: &[DVector<f64>]) {
601        assert_eq!(self.particles.len(), expected_measurements.len());
602        let mut weights = Vec::with_capacity(self.particles.len());
603        for expected in expected_measurements.iter() {
604            // Calculate the Mahalanobis distance
605            let diff = measurement - expected;
606            let weight = (-0.5 * diff.transpose() * diff).exp().sum(); //TODO: #22 modify this to use any and/or a user specified probability distribution
607            weights.push(weight);
608        }
609        // self.set_weights(weights.as_slice());
610        self.normalize_weights();
611    }
612    /// Set the weights of the particles (e.g., after a measurement update)
613    ///
614    /// # Arguments
615    /// * `weights` - The weights to set for the particles.
616    pub fn set_weights(&mut self, weights: &[f64]) {
617        assert_eq!(weights.len(), self.particles.len());
618        for (particle, &w) in self.particles.iter_mut().zip(weights.iter()) {
619            particle.weight = w;
620        }
621    }
622    /// Normalize the weights of the particles. This is typically done after a measurement update
623    /// to ensure that the weights sum to 1.0 and can be treated like a probability distribution.
624    pub fn normalize_weights(&mut self) {
625        let sum: f64 = self.particles.iter().map(|p| p.weight).sum();
626        if sum > 0.0 {
627            for particle in &mut self.particles {
628                particle.weight /= sum;
629            }
630        }
631    }
632    /// Residual resampling (systematic resampling)
633    pub fn residual_resample(&mut self) {
634        let n = self.particles.len();
635        let mut new_particles = Vec::with_capacity(n);
636        let weights: Vec<f64> = self.particles.iter().map(|p| p.weight).collect();
637        let mut num_copies = vec![0usize; n];
638        let mut residual: Vec<f64> = vec![0.0; n];
639        //let mut total_residual: f64 = 0.0;
640        // Integer part
641        for (i, &w) in weights.iter().enumerate() {
642            let copies = (w * n as f64).floor() as usize;
643            num_copies[i] = copies;
644            residual[i] = w * n as f64 - copies as f64;
645            //total_residual += residual[i];
646        }
647        // Copy integer part
648        for (i, &copies) in num_copies.iter().enumerate() {
649            for _ in 0..copies {
650                new_particles.push(self.particles[i].clone());
651            }
652        }
653        // Residual part
654        let residual_particles = n - new_particles.len();
655        if residual_particles > 0 {
656            // Normalize residuals
657            let sum_residual: f64 = residual.iter().sum();
658            let mut positions = Vec::with_capacity(residual_particles);
659            let step = sum_residual / residual_particles as f64;
660            let mut u = rand::random::<f64>() * step;
661            for _ in 0..residual_particles {
662                positions.push(u);
663                u += step;
664            }
665            let mut i = 0;
666            let mut j = 0;
667            let mut cumsum = residual[0];
668            while j < residual_particles {
669                while positions[j] > cumsum {
670                    i += 1;
671                    cumsum += residual[i];
672                }
673                new_particles.push(self.particles[i].clone());
674                j += 1;
675            }
676        }
677        // Reset weights
678        let uniform_weight = 1.0 / n as f64;
679        for particle in &mut new_particles {
680            particle.weight = uniform_weight;
681        }
682        self.particles = new_particles;
683    }
684}
/// Unit tests for the UKF: construction and weights, sigma-point generation, and a
/// stationary predict/update cycle.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::earth;
    use assert_approx_eq::assert_approx_eq;
    use nalgebra::Vector3;

    const IMU_BIASES: [f64; 6] = [0.0; 6];
    const N: usize = 15;
    const COVARIANCE_DIAGONAL: [f64; N] = [1e-9; N];
    const PROCESS_NOISE_DIAGONAL: [f64; N] = [1e-9; N];

    const ALPHA: f64 = 1e-3;
    const BETA: f64 = 2.0;
    const KAPPA: f64 = 0.0;
    /// Initial conditions shared by the tests: every position, velocity and attitude is zero.
    const UKF_PARAMS: StrapdownParams = StrapdownParams {
        latitude: 0.0,
        longitude: 0.0,
        altitude: 0.0,
        northward_velocity: 0.0,
        eastward_velocity: 0.0,
        downward_velocity: 0.0,
        roll: 0.0,
        pitch: 0.0,
        yaw: 0.0,
        in_degrees: false,
    };

    #[test]
    fn ukf_construction() {
        // 15 states from UKF_PARAMS and IMU_BIASES plus 3 measurement-bias states.
        let dim = 18.0_f64;
        let ukf = UKF::new(
            UKF_PARAMS,
            IMU_BIASES.to_vec(),
            Some(vec![0.0; 3]),
            vec![1e-3; 18],
            DMatrix::from_diagonal(&DVector::from_vec(vec![1e-3; 18])),
            ALPHA,
            BETA,
            KAPPA,
        );
        assert_eq!(ukf.mean_state.len(), 18);

        let sigma_count = 2 * ukf.state_size + 1;
        let wms = ukf.weights_mean;
        let wcs = ukf.weights_cov;
        assert_eq!(wms.len(), sigma_count);
        assert_eq!(wcs.len(), sigma_count);

        // Expected scaled-unscented-transform weights.
        let lambda = ALPHA.powi(2) * (dim + KAPPA) - dim;
        assert_eq!(lambda, ukf.lambda);
        let wm_0 = lambda / (dim + lambda);
        let wc_0 = wm_0 + (1.0 - ALPHA.powi(2)) + BETA;
        let w_i = 1.0 / (2.0 * (dim + lambda));
        assert_approx_eq!(wms[0], wm_0, 1e-6);
        assert_approx_eq!(wcs[0], wc_0, 1e-6);
        for (wm, wc) in wms.iter().zip(wcs.iter()).skip(1) {
            assert_approx_eq!(*wm, w_i, 1e-6);
            assert_approx_eq!(*wc, w_i, 1e-6);
        }
    }

    #[test]
    fn ukf_get_sigma_points() {
        let ukf = UKF::new(
            UKF_PARAMS,
            IMU_BIASES.to_vec(),
            None,
            COVARIANCE_DIAGONAL.to_vec(),
            DMatrix::from_diagonal(&DVector::from_vec(PROCESS_NOISE_DIAGONAL.to_vec())),
            ALPHA,
            BETA,
            KAPPA,
        );
        let sigma_points = ukf.get_sigma_points();
        assert_eq!(sigma_points.ncols(), 2 * ukf.state_size + 1);

        // The weighted mean of the sigma points should be ~0 in its first nine components,
        // matching the all-zero initial conditions in UKF_PARAMS.
        let mu = sigma_points * ukf.weights_mean;
        assert_eq!(mu.nrows(), ukf.state_size);
        assert_eq!(mu.ncols(), 1);
        for component in mu.iter().take(9) {
            assert_approx_eq!(*component, 0.0, 1e-6);
        }
    }

    #[test]
    fn ukf_propagate() {
        // Zero initial covariance (complete certainty) so the process model is what is tested.
        let mut ukf = UKF::new(
            UKF_PARAMS,
            vec![0.0; 6],
            None,
            vec![0.0; N],
            DMatrix::from_diagonal(&DVector::from_vec(PROCESS_NOISE_DIAGONAL.to_vec())),
            ALPHA,
            BETA,
            KAPPA,
        );
        // Stationary IMU: the accelerometer reads only gravity and the gyro reads zero.
        let imu_data = IMUData {
            accel: Vector3::new(0.0, 0.0, earth::gravity(&0.0, &0.0)),
            gyro: Vector3::new(0.0, 0.0, 0.0),
        };
        ukf.predict(imu_data, 1.0);
        assert_eq!(ukf.mean_state.len(), N);

        let measurement = GPSPositionMeasurement {
            latitude: 0.0,
            longitude: 0.0,
            altitude: 0.0,
            horizontal_noise_std: 1e-3,
            vertical_noise_std: 1e-3,
        };
        ukf.update(measurement);

        // The first six state components should remain (approximately) at their initial zeros.
        let tolerances = [1e-3, 1e-3, 0.1, 0.1, 0.1, 0.1];
        for (index, tolerance) in tolerances.iter().enumerate() {
            assert_approx_eq!(ukf.mean_state[index], 0.0, *tolerance);
        }
    }

    // TODO: port the former commented-out tests to the current `predict`/`update` API:
    // - `ukf_debug`: checked that the UKF's Debug output contains "mean_state".
    // - `test_ukf_hover`: ran 60 one-second predict/update cycles at a stationary position
    //   and checked that position and velocity stayed put.
}