Skip to main content

sensor_fusion/
sensor_fusion_math.rs

1#![allow(clippy::many_single_char_names)]
2
3#[cfg(feature = "simd")]
4use core::simd::{f32x4, simd_swizzle};
5
6use vqm::{Quaternion, Vector3d};
7
8pub trait SensorFusionMath: Sized {
9    fn estimate_gravity(q: Quaternion<Self>) -> Vector3d<Self>;
10    fn derivative(q: Quaternion<Self>, gyro: Vector3d<Self>) -> Quaternion<Self>;
11    fn madgwick_step_acc(q: Quaternion<Self>, acc: Vector3d<Self>, max_acc_magnitude_squared: Self)
12    -> Quaternion<Self>;
13    fn madgwick_step_acc_mag(
14        q: Quaternion<Self>,
15        acc: Vector3d<Self>,
16        mag: Vector3d<Self>,
17        max_acc_magnitude_squared: Self,
18    ) -> Quaternion<Self>;
19}
20
21impl SensorFusionMath for f32 {
22    #[inline(always)]
23    fn estimate_gravity(q: Quaternion<Self>) -> Vector3d<Self> {
24        Vector3d {
25            x: 2.0 * (q.x * q.z - q.w * q.y),
26            y: 2.0 * (q.y * q.z + q.w * q.x),
27            z: q.w * q.w - q.x * q.x - q.y * q.y + q.z * q.z,
28        }
29    }
30    #[inline(always)]
31    fn derivative(q: Quaternion<Self>, gyro: Vector3d<Self>) -> Quaternion<Self> {
32        #[cfg(feature = "simd")]
33        {
34            /*// Load q: [w, x, y, z]
35            let q_v: f32x4 = unsafe { core::mem::transmute(q) };
36
37            // Load gyro: [x, y, z, padding] -> Swizzle to [0, x, y, z]
38            let g_raw: f32x4 = unsafe { core::mem::transmute(gyro) };
39            //let g_v = simd_swizzle!(g_raw, [f32x4::splat(0.0)], [4, 0, 1, 2]);
40
41            // Efficiently shift: [x, y, z, 0] -> [0, x, y, z]
42            // We rotate right and then zero out the 'w' (index 0)
43            let g_v = simd_swizzle!(g_raw, [3, 0, 1, 2]) * f32x4::from_array([0.0, 1.0, 1.0, 1.0]);
44
45            // Hamilton Product (q * g) for [w, x, y, z] layout:
46            // w_out = -x1*gx - y1*gy - z1*gz
47            // x_out =  w1*gx + y1*gz - z1*gy
48            // y_out =  w1*gy - x1*gz + z1*gx
49            // z_out =  w1*gz + x1*gy - y1*gx
50
51            // Row A: [w, w, w, w] * [0, gx, gy, gz]
52            let res = simd_swizzle!(q_v, [0, 0, 0, 0]) * g_v;
53
54            // Row B: [x, -x, -x, -x] * [gx, 0, gz, gy]
55            // (Using signs and swizzles to match the Hamilton rows)
56            let x_part = simd_swizzle!(q_v, [1, 1, 1, 1])
57                * simd_swizzle!(g_v, [1, 0, 3, 2])
58                * f32x4::from_array([-1.0, 1.0, -1.0, 1.0]);
59
60            // Row C: [y, y, -y, y] * [gy, gz, 0, gx]
61            let y_part = simd_swizzle!(q_v, [2, 2, 2, 2])
62                * simd_swizzle!(g_v, [2, 3, 0, 1])
63                * f32x4::from_array([-1.0, 1.0, 1.0, -1.0]);
64
65            // Row D: [z, z, z, -z] * [gz, gy, gx, 0]
66            let z_part = simd_swizzle!(q_v, [3, 3, 3, 3])
67                * simd_swizzle!(g_v, [3, 2, 1, 0])
68                * f32x4::from_array([-1.0, -1.0, 1.0, 1.0]);
69
70            let q_dot = (res + x_part + y_part + z_part) * f32x4::splat(0.5);
71
72            unsafe { core::mem::transmute(q_dot) }*/
73            let q_v = f32x4::from(q);
74            let g_raw = f32x4::from(gyro);
75
76            // Shift [x, y, z, pad] to [0, x, y, z] and zero the w lane
77            let g_v = simd_swizzle!(g_raw, [3, 0, 1, 2]) * f32x4::from_array([0.0, 1.0, 1.0, 1.0]);
78
79            // Parallel Hamilton Calculation
80            let w1 = simd_swizzle!(q_v, [0, 0, 0, 0]);
81            let x1 = simd_swizzle!(q_v, [1, 1, 1, 1]);
82            let y1 = simd_swizzle!(q_v, [2, 2, 2, 2]);
83            let z1 = simd_swizzle!(q_v, [3, 3, 3, 3]);
84
85            let g_w = g_v; // [0, gx, gy, gz]
86            let g_x = simd_swizzle!(g_v, [1, 0, 3, 2]); // [gx, 0, gz, gy]
87            let g_y = simd_swizzle!(g_v, [2, 3, 0, 1]); // [gy, gz, 0, gx]
88            let g_z = simd_swizzle!(g_v, [3, 2, 1, 0]); // [gz, gy, gx, 0]
89
90            let res = (w1 * g_w)
91                + (x1 * g_x * f32x4::from_array([-1.0, 1.0, -1.0, 1.0]))
92                + (y1 * g_y * f32x4::from_array([-1.0, 1.0, 1.0, -1.0]))
93                + (z1 * g_z * f32x4::from_array([-1.0, -1.0, 1.0, 1.0]));
94
95            let q_dot = res * f32x4::splat(0.5);
96            q_dot.into()
97        }
98        #[cfg(not(feature = "simd"))]
99        {
100            Quaternion {
101                w: (-q.x * gyro.x - q.y * gyro.y - q.z * gyro.z) * 0.5,
102                x: (q.w * gyro.x - q.z * gyro.y + q.y * gyro.z) * 0.5,
103                y: (q.z * gyro.x + q.w * gyro.y - q.x * gyro.z) * 0.5,
104                z: (-q.y * gyro.x + q.x * gyro.y + q.w * gyro.z) * 0.5,
105            }
106        }
107    }
108
109    /// Features of this implementation:
110    ///
111    /// 1. Parallel Throughput: Instead of 12 separate floating-point multiplications and 8 additions,
112    ///    the SIMD unit performs 3 vector multiplications and 2 vector additions.
113    ///
114    /// 2. Instruction Density: The `simd_swizzle!` maps directly to the VREV or VMOV instructions on the M33.
115    ///
116    /// 3. Register Reuse: `q_v` stays in its SIMD register the entire time.
117    ///    The compiler will likely use VFMA (Vector Fused Multiply-Add) to combine the terms,
118    ///    meaning this whole function could resolve in under 15 clock cycles.
119    ///
120    #[inline(always)]
121    fn madgwick_step_acc(
122        q: Quaternion<Self>,
123        acc: Vector3d<Self>,
124        max_acc_magnitude_squared: Self,
125    ) -> Quaternion<Self> {
126        use num_traits::Zero;
127        use vqm::SqrtMethods;
128
129        let acc_magnitude_squared = acc.norm_squared();
130        // Acceleration is an unreliable indicator of orientation when in high-g maneuvers,
131        // so exclude it from the calculation in these cases
132        let mut a = acc;
133        if acc_magnitude_squared > max_acc_magnitude_squared || acc_magnitude_squared == 0.0 {
134            a = Vector3d::zero();
135        } else {
136            a *= acc_magnitude_squared.sqrt_reciprocal();
137        }
138        #[cfg(feature = "simd")]
139        {
140            let q_v = f32x4::from(q);
141
142            // 1. Calculate wz_common
143            let wz_common = 2.0 * (q.x * q.x + q.y * q.y);
144
145            // 2. Calculate xy_common (w*w + z*z - 1.0 + 2.0*wz_common + a.z)
146            let xy_common = 2.0 * (q.w * q.w + q.z * q.z - 1.0 + 2.0 * wz_common + a.z);
147
148            // 3. Calculate common term: [w, x, y, z] * [qq, xy_common, xy_common, qq]
149            let common_scalars = f32x4::from_array([wz_common, xy_common, xy_common, wz_common]);
150            let common = q_v * common_scalars;
151
152            // 4. Calculate ax term: [y, -z, w, -x] * ax
153            // Term 2: [y, -z, w, -x] * ax
154            // Indices needed: y=2, z=3, w=0, x=1
155            let ax_q_swiz = simd_swizzle!(q_v, [2, 3, 0, 1]);
156            let ax_signs = f32x4::from_array([1.0, -1.0, 1.0, -1.0]);
157            let ax = (ax_q_swiz * ax_signs) * f32x4::splat(a.x);
158
159            // 5. Calculate ay term: [-x, -w, -z, -y] * ay
160            // Term 3: [-x, -w, -z, -y] * ay
161            // Indices needed: x=1, w=0, z=3, y=2
162            let ay_q_swiz = simd_swizzle!(q_v, [1, 0, 3, 2]);
163            let ay_signs = f32x4::splat(-1.0);
164            let ay = (ay_q_swiz * ay_signs) * f32x4::splat(a.y);
165
166            // 6. Combine: step = common + ax + ay
167            let ret_v = common + ax + ay;
168
169            ret_v.into()
170        }
171        #[cfg(not(feature = "simd"))]
172        {
173            let wz_common = 2.0 * (q.x * q.x + q.y * q.y);
174            let xy_common = 2.0 * (q.w * q.w + q.z * q.z - 1.0 + 2.0 * wz_common + a.z);
175            Quaternion {
176                w: q.w * wz_common + q.y * a.x - q.x * a.y,
177                x: q.x * xy_common - q.z * a.x - q.w * a.y,
178                y: q.y * xy_common + q.w * a.x - q.z * a.y,
179                z: q.z * wz_common - q.x * a.x - q.y * a.y,
180            }
181        }
182    }
183
184    #[inline(always)]
185    fn madgwick_step_acc_mag(
186        q: Quaternion<Self>,
187        acc: Vector3d<Self>,
188        mag: Vector3d<Self>,
189        max_acc_magnitude_squared: Self,
190    ) -> Quaternion<Self> {
191        use num_traits::Zero;
192        use vqm::SqrtMethods;
193
194        let acc_magnitude_squared = acc.norm_squared();
195        // Acceleration is an unreliable indicator of orientation when in high-g maneuvers,
196        // so exclude it from the calculation in these cases
197        let mut a = acc;
198        if acc_magnitude_squared > max_acc_magnitude_squared || acc_magnitude_squared == 0.0 {
199            a = Vector3d::zero();
200        } else {
201            a *= acc_magnitude_squared.sqrt_reciprocal();
202        }
203
204        let m = mag.normalized();
205
206        // make copies of the components of q to simplify the algebraic expressions
207        let q0 = q.w;
208        let q1 = q.x;
209        let q2 = q.y;
210        let q3 = q.z;
211        // Auxiliary variables to avoid repeated arithmetic
212        let q0q0 = q0 * q0;
213        let q0q1 = q0 * q1;
214        let q0q2 = q0 * q2;
215        let q0q3 = q0 * q3;
216        let q1q1 = q1 * q1;
217        let q1q2 = q1 * q2;
218        let q1q3 = q1 * q3;
219        let q2q2 = q2 * q2;
220        let q2q3 = q2 * q3;
221        let q3q3 = q3 * q3;
222
223        let q1q1_plus_q2q2 = q1q1 + q2q2;
224        let q2q2_plus_q3q3 = q2q2 + q3q3;
225
226        // Reference direction of Earth's magnetic field
227        let h = Vector3d {
228            x: m.x * (q0q0 + q1q1 - q2q2_plus_q3q3) + 2.0 * (m.y * (q1q2 - q0q3) + m.z * (q0q2 + q1q3)),
229            y: (m.x * (q0q3 + q1q2) + m.y * (q0q0 - q1q1 + q2q2 - q3q3) + m.z * (q2q3 - q0q1)) * 2.0,
230            z: 0.0,
231        };
232
233        let bx_bx = h.x * h.x + h.y * h.y;
234        let b = Vector3d {
235            x: bx_bx.sqrt(),
236            y: 0.0,
237            z: 2.0 * (m.x * (q1q3 - q0q2) + m.y * (q0q1 + q2q3)) + m.z * (q0q0 - q1q1_plus_q2q2 + q3q3),
238        };
239
240        let a_dash = Vector3d { x: a.x + m.x * b.z, y: a.y + m.y * b.z, z: 0.0 };
241        let bz_bz = b.z * b.z;
242        let _4bx_bz = 4.0 * b.x * b.z;
243
244        let m_bx = m * b.x;
245        let mz_bz = m.z * b.z;
246
247        let sum_squares_minus_one = q0q0 + q1q1_plus_q2q2 + q3q3 - 1.0;
248        let xy_common = sum_squares_minus_one + q1q1_plus_q2q2;
249        let yz_common = sum_squares_minus_one + q2q2_plus_q3q3;
250        let wz_common = q1q1_plus_q2q2 * (1.0 + bz_bz) + bx_bx;
251
252        // Gradient decent algorithm corrective step
253        #[allow(clippy::used_underscore_binding)]
254        Quaternion {
255            w: q0 * 2.0 * (wz_common * q2q2_plus_q3q3) - q1 * a_dash.y
256                + q2 * (a_dash.x - m_bx.z)
257                + q3 * (m_bx.y - _4bx_bz * q0q1),
258
259            x: -q0 * a_dash.y + q1 * 2.0 * (xy_common * (1.0 + 2.0 * bz_bz) + mz_bz + bx_bx * q2q2_plus_q3q3 + a.z)
260                - q2 * m_bx.y
261                - q3 * (a_dash.x + m_bx.z + _4bx_bz * (0.5 * sum_squares_minus_one + q1q1)),
262
263            y: q0 * (a_dash.x - m_bx.z) - q1 * m_bx.y
264                + q2 * 2.0 * (xy_common * (1.0 + bz_bz) + mz_bz + m_bx.x + bx_bx * yz_common + a.z)
265                - q3 * (a_dash.y + _4bx_bz * q1q2),
266
267            z: q0 * m_bx.y - q1 * (a_dash.x + m_bx.z + _4bx_bz * (0.5 * sum_squares_minus_one + q3q3)) - q2 * a_dash.y
268                + q3 * 2.0 * (wz_common + m_bx.x * yz_common),
269        }
270        .normalized()
271    }
272}
273
274impl SensorFusionMath for f64 {
275    #[inline(always)]
276    fn estimate_gravity(q: Quaternion<Self>) -> Vector3d<Self> {
277        Vector3d {
278            x: 2.0 * (q.x * q.z - q.w * q.y),
279            y: 2.0 * (q.y * q.z + q.w * q.x),
280            z: q.w * q.w - q.x * q.x - q.y * q.y + q.z * q.z,
281        }
282    }
283    #[inline(always)]
284    fn derivative(q: Quaternion<Self>, gyro_rps: Vector3d<Self>) -> Quaternion<Self> {
285        Quaternion {
286            w: (-q.x * gyro_rps.x - q.y * gyro_rps.y - q.z * gyro_rps.z) * 0.5,
287            x: (q.w * gyro_rps.x + q.y * gyro_rps.z - q.z * gyro_rps.y) * 0.5,
288            y: (q.w * gyro_rps.y - q.x * gyro_rps.z + q.z * gyro_rps.x) * 0.5,
289            z: (q.w * gyro_rps.z + q.x * gyro_rps.y - q.y * gyro_rps.x) * 0.5,
290        }
291    }
292    #[inline(always)]
293    fn madgwick_step_acc(
294        q: Quaternion<Self>,
295        acc: Vector3d<Self>,
296        max_acc_magnitude_squared: Self,
297    ) -> Quaternion<Self> {
298        use num_traits::Zero;
299        use vqm::SqrtMethods;
300
301        let acc_magnitude_squared = acc.norm_squared();
302        // Acceleration is an unreliable indicator of orientation when in high-g maneuvers,
303        // so exclude it from the calculation in these cases
304        let mut a = acc;
305        if acc_magnitude_squared > max_acc_magnitude_squared || acc_magnitude_squared == 0.0 {
306            a = Vector3d::zero();
307        } else {
308            a *= acc_magnitude_squared.sqrt_reciprocal();
309        }
310        let wz_common = 2.0 * (q.x * q.x + q.y * q.y);
311        let xy_common = 2.0 * (q.w * q.w + q.z * q.z - 1.0 + 2.0 * wz_common + a.z);
312        Quaternion {
313            w: q.w * wz_common + q.y * a.x - q.x * a.y,
314            x: q.x * xy_common - q.z * a.x - q.w * a.y,
315            y: q.y * xy_common + q.w * a.x - q.z * a.y,
316            z: q.z * wz_common - q.x * a.x - q.y * a.y,
317        }
318    }
319
320    #[inline(always)]
321    fn madgwick_step_acc_mag(
322        q: Quaternion<Self>,
323        _acc: Vector3d<Self>,
324        _mag: Vector3d<Self>,
325        _max_acc_magnitude_squared: Self,
326    ) -> Quaternion<Self> {
327        q
328    }
329}