Skip to main content

sensor_fusion/
sensor_fusion_math.rs

1#![allow(clippy::inline_always)]
2#![allow(clippy::many_single_char_names)]
3
4#[cfg(feature = "simd")]
5use core::simd::{f32x4, simd_swizzle};
6
7use vqm::{Quaternion, Vector3d};
8
9/// Math functions for `SensorFusion`, using **SIMD** accelerations for `f32`.
10pub trait SensorFusionMath: Sized {
11    fn estimate_gravity(q: Quaternion<Self>) -> Vector3d<Self>;
12    fn derivative(q: Quaternion<Self>, gyro: Vector3d<Self>) -> Quaternion<Self>;
13    fn madgwick_step_acc(q: Quaternion<Self>, acc: Vector3d<Self>, max_acc_magnitude_squared: Self)
14    -> Quaternion<Self>;
15    fn madgwick_step_acc_mag(
16        q: Quaternion<Self>,
17        acc: Vector3d<Self>,
18        mag: Vector3d<Self>,
19        max_acc_magnitude_squared: Self,
20    ) -> Quaternion<Self>;
21}
22
23impl SensorFusionMath for f32 {
24    #[inline(always)]
25    fn estimate_gravity(q: Quaternion<Self>) -> Vector3d<Self> {
26        Vector3d {
27            x: 2.0 * (q.x * q.z - q.w * q.y),
28            y: 2.0 * (q.y * q.z + q.w * q.x),
29            z: q.w * q.w - q.x * q.x - q.y * q.y + q.z * q.z,
30        }
31    }
32    #[inline(always)]
33    fn derivative(q: Quaternion<Self>, gyro: Vector3d<Self>) -> Quaternion<Self> {
34        #[cfg(feature = "simd")]
35        {
36            /*// Load q: [w, x, y, z]
37            let q_v: f32x4 = unsafe { core::mem::transmute(q) };
38
39            // Load gyro: [x, y, z, padding] -> Swizzle to [0, x, y, z]
40            let g_raw: f32x4 = unsafe { core::mem::transmute(gyro) };
41            //let g_v = simd_swizzle!(g_raw, [f32x4::splat(0.0)], [4, 0, 1, 2]);
42
43            // Efficiently shift: [x, y, z, 0] -> [0, x, y, z]
44            // We rotate right and then zero out the 'w' (index 0)
45            let g_v = simd_swizzle!(g_raw, [3, 0, 1, 2]) * f32x4::from_array([0.0, 1.0, 1.0, 1.0]);
46
47            // Hamilton Product (q * g) for [w, x, y, z] layout:
48            // w_out = -x1*gx - y1*gy - z1*gz
49            // x_out =  w1*gx + y1*gz - z1*gy
50            // y_out =  w1*gy - x1*gz + z1*gx
51            // z_out =  w1*gz + x1*gy - y1*gx
52
53            // Row A: [w, w, w, w] * [0, gx, gy, gz]
54            let res = simd_swizzle!(q_v, [0, 0, 0, 0]) * g_v;
55
56            // Row B: [x, -x, -x, -x] * [gx, 0, gz, gy]
57            // (Using signs and swizzles to match the Hamilton rows)
58            let x_part = simd_swizzle!(q_v, [1, 1, 1, 1])
59                * simd_swizzle!(g_v, [1, 0, 3, 2])
60                * f32x4::from_array([-1.0, 1.0, -1.0, 1.0]);
61
62            // Row C: [y, y, -y, y] * [gy, gz, 0, gx]
63            let y_part = simd_swizzle!(q_v, [2, 2, 2, 2])
64                * simd_swizzle!(g_v, [2, 3, 0, 1])
65                * f32x4::from_array([-1.0, 1.0, 1.0, -1.0]);
66
67            // Row D: [z, z, z, -z] * [gz, gy, gx, 0]
68            let z_part = simd_swizzle!(q_v, [3, 3, 3, 3])
69                * simd_swizzle!(g_v, [3, 2, 1, 0])
70                * f32x4::from_array([-1.0, -1.0, 1.0, 1.0]);
71
72            let q_dot = (res + x_part + y_part + z_part) * f32x4::splat(0.5);
73
74            unsafe { core::mem::transmute(q_dot) }*/
75            let q_v = f32x4::from(q);
76            let g_raw = f32x4::from(gyro);
77
78            // Shift [x, y, z, pad] to [0, x, y, z] and zero the w lane
79            let g_v = simd_swizzle!(g_raw, [3, 0, 1, 2]) * f32x4::from_array([0.0, 1.0, 1.0, 1.0]);
80
81            // Parallel Hamilton Calculation
82            let w1 = simd_swizzle!(q_v, [0, 0, 0, 0]);
83            let x1 = simd_swizzle!(q_v, [1, 1, 1, 1]);
84            let y1 = simd_swizzle!(q_v, [2, 2, 2, 2]);
85            let z1 = simd_swizzle!(q_v, [3, 3, 3, 3]);
86
87            let g_w = g_v; // [0, gx, gy, gz]
88            let g_x = simd_swizzle!(g_v, [1, 0, 3, 2]); // [gx, 0, gz, gy]
89            let g_y = simd_swizzle!(g_v, [2, 3, 0, 1]); // [gy, gz, 0, gx]
90            let g_z = simd_swizzle!(g_v, [3, 2, 1, 0]); // [gz, gy, gx, 0]
91
92            let res = (w1 * g_w)
93                + (x1 * g_x * f32x4::from_array([-1.0, 1.0, -1.0, 1.0]))
94                + (y1 * g_y * f32x4::from_array([-1.0, 1.0, 1.0, -1.0]))
95                + (z1 * g_z * f32x4::from_array([-1.0, -1.0, 1.0, 1.0]));
96
97            let q_dot = res * f32x4::splat(0.5);
98            q_dot.into()
99        }
100        #[cfg(not(feature = "simd"))]
101        {
102            Quaternion {
103                w: (-q.x * gyro.x - q.y * gyro.y - q.z * gyro.z) * 0.5,
104                x: (q.w * gyro.x - q.z * gyro.y + q.y * gyro.z) * 0.5,
105                y: (q.z * gyro.x + q.w * gyro.y - q.x * gyro.z) * 0.5,
106                z: (-q.y * gyro.x + q.x * gyro.y + q.w * gyro.z) * 0.5,
107            }
108        }
109    }
110
111    /// Features of this implementation:
112    ///
113    /// 1. Parallel Throughput: Instead of 12 separate floating-point multiplications and 8 additions,
114    ///    the SIMD unit performs 3 vector multiplications and 2 vector additions.
115    ///
116    /// 2. Instruction Density: The `simd_swizzle!` maps directly to the VREV or VMOV instructions on the M33.
117    ///
118    /// 3. Register Reuse: `q_v` stays in its SIMD register the entire time.
119    ///    The compiler will likely use VFMA (Vector Fused Multiply-Add) to combine the terms,
120    ///    meaning this whole function could resolve in under 15 clock cycles.
121    ///
122    #[inline(always)]
123    fn madgwick_step_acc(
124        q: Quaternion<Self>,
125        acc: Vector3d<Self>,
126        max_acc_magnitude_squared: Self,
127    ) -> Quaternion<Self> {
128        use num_traits::Zero;
129        use vqm::SqrtMethods;
130
131        let acc_magnitude_squared = acc.norm_squared();
132        // Acceleration is an unreliable indicator of orientation when in high-g maneuvers,
133        // so exclude it from the calculation in these cases
134        let mut a = acc;
135        if acc_magnitude_squared > max_acc_magnitude_squared || acc_magnitude_squared == 0.0 {
136            a = Vector3d::zero();
137        } else {
138            a *= acc_magnitude_squared.sqrt_reciprocal();
139        }
140        #[cfg(feature = "simd")]
141        {
142            let q_v = f32x4::from(q);
143
144            // 1. Calculate wz_common
145            let wz_common = 2.0 * (q.x * q.x + q.y * q.y);
146
147            // 2. Calculate xy_common (w*w + z*z - 1.0 + 2.0*wz_common + a.z)
148            let xy_common = 2.0 * (q.w * q.w + q.z * q.z - 1.0 + 2.0 * wz_common + a.z);
149
150            // 3. Calculate common term: [w, x, y, z] * [qq, xy_common, xy_common, qq]
151            let common_scalars = f32x4::from_array([wz_common, xy_common, xy_common, wz_common]);
152            let common = q_v * common_scalars;
153
154            // 4. Calculate ax term: [y, -z, w, -x] * ax
155            // Term 2: [y, -z, w, -x] * ax
156            // Indices needed: y=2, z=3, w=0, x=1
157            let ax_q_swiz = simd_swizzle!(q_v, [2, 3, 0, 1]);
158            let ax_signs = f32x4::from_array([1.0, -1.0, 1.0, -1.0]);
159            let ax = (ax_q_swiz * ax_signs) * f32x4::splat(a.x);
160
161            // 5. Calculate ay term: [-x, -w, -z, -y] * ay
162            // Term 3: [-x, -w, -z, -y] * ay
163            // Indices needed: x=1, w=0, z=3, y=2
164            let ay_q_swiz = simd_swizzle!(q_v, [1, 0, 3, 2]);
165            let ay_signs = f32x4::splat(-1.0);
166            let ay = (ay_q_swiz * ay_signs) * f32x4::splat(a.y);
167
168            // 6. Combine: step = common + ax + ay
169            let ret_v = common + ax + ay;
170
171            ret_v.into()
172        }
173        #[cfg(not(feature = "simd"))]
174        {
175            let wz_common = 2.0 * (q.x * q.x + q.y * q.y);
176            let xy_common = 2.0 * (q.w * q.w + q.z * q.z - 1.0 + 2.0 * wz_common + a.z);
177            Quaternion {
178                w: q.w * wz_common + q.y * a.x - q.x * a.y,
179                x: q.x * xy_common - q.z * a.x - q.w * a.y,
180                y: q.y * xy_common + q.w * a.x - q.z * a.y,
181                z: q.z * wz_common - q.x * a.x - q.y * a.y,
182            }
183        }
184    }
185
186    #[inline(always)]
187    fn madgwick_step_acc_mag(
188        q: Quaternion<Self>,
189        acc: Vector3d<Self>,
190        mag: Vector3d<Self>,
191        max_acc_magnitude_squared: Self,
192    ) -> Quaternion<Self> {
193        use num_traits::Zero;
194        use vqm::SqrtMethods;
195
196        let acc_magnitude_squared = acc.norm_squared();
197        // Acceleration is an unreliable indicator of orientation when in high-g maneuvers,
198        // so exclude it from the calculation in these cases
199        let mut a = acc;
200        if acc_magnitude_squared > max_acc_magnitude_squared || acc_magnitude_squared == 0.0 {
201            a = Vector3d::zero();
202        } else {
203            a *= acc_magnitude_squared.sqrt_reciprocal();
204        }
205
206        let m = mag.normalize();
207
208        // make copies of the components of q to simplify the algebraic expressions
209        let q0 = q.w;
210        let q1 = q.x;
211        let q2 = q.y;
212        let q3 = q.z;
213        // Auxiliary variables to avoid repeated arithmetic
214        let q0q0 = q0 * q0;
215        let q0q1 = q0 * q1;
216        let q0q2 = q0 * q2;
217        let q0q3 = q0 * q3;
218        let q1q1 = q1 * q1;
219        let q1q2 = q1 * q2;
220        let q1q3 = q1 * q3;
221        let q2q2 = q2 * q2;
222        let q2q3 = q2 * q3;
223        let q3q3 = q3 * q3;
224
225        let q1q1_plus_q2q2 = q1q1 + q2q2;
226        let q2q2_plus_q3q3 = q2q2 + q3q3;
227
228        // Reference direction of Earth's magnetic field
229        let h = Vector3d {
230            x: m.x * (q0q0 + q1q1 - q2q2_plus_q3q3) + 2.0 * (m.y * (q1q2 - q0q3) + m.z * (q0q2 + q1q3)),
231            y: (m.x * (q0q3 + q1q2) + m.y * (q0q0 - q1q1 + q2q2 - q3q3) + m.z * (q2q3 - q0q1)) * 2.0,
232            z: 0.0,
233        };
234
235        let bx_bx = h.x * h.x + h.y * h.y;
236        let b = Vector3d {
237            x: bx_bx.sqrt(),
238            y: 0.0,
239            z: 2.0 * (m.x * (q1q3 - q0q2) + m.y * (q0q1 + q2q3)) + m.z * (q0q0 - q1q1_plus_q2q2 + q3q3),
240        };
241
242        let a_dash = Vector3d { x: a.x + m.x * b.z, y: a.y + m.y * b.z, z: 0.0 };
243        let bz_bz = b.z * b.z;
244        let _4bx_bz = 4.0 * b.x * b.z;
245
246        let m_bx = m * b.x;
247        let mz_bz = m.z * b.z;
248
249        let sum_squares_minus_one = q0q0 + q1q1_plus_q2q2 + q3q3 - 1.0;
250        let xy_common = sum_squares_minus_one + q1q1_plus_q2q2;
251        let yz_common = sum_squares_minus_one + q2q2_plus_q3q3;
252        let wz_common = q1q1_plus_q2q2 * (1.0 + bz_bz) + bx_bx;
253
254        // Gradient decent algorithm corrective step
255        #[allow(clippy::used_underscore_binding)]
256        Quaternion {
257            w: q0 * 2.0 * (wz_common * q2q2_plus_q3q3) - q1 * a_dash.y
258                + q2 * (a_dash.x - m_bx.z)
259                + q3 * (m_bx.y - _4bx_bz * q0q1),
260
261            x: -q0 * a_dash.y + q1 * 2.0 * (xy_common * (1.0 + 2.0 * bz_bz) + mz_bz + bx_bx * q2q2_plus_q3q3 + a.z)
262                - q2 * m_bx.y
263                - q3 * (a_dash.x + m_bx.z + _4bx_bz * (0.5 * sum_squares_minus_one + q1q1)),
264
265            y: q0 * (a_dash.x - m_bx.z) - q1 * m_bx.y
266                + q2 * 2.0 * (xy_common * (1.0 + bz_bz) + mz_bz + m_bx.x + bx_bx * yz_common + a.z)
267                - q3 * (a_dash.y + _4bx_bz * q1q2),
268
269            z: q0 * m_bx.y - q1 * (a_dash.x + m_bx.z + _4bx_bz * (0.5 * sum_squares_minus_one + q3q3)) - q2 * a_dash.y
270                + q3 * 2.0 * (wz_common + m_bx.x * yz_common),
271        }
272        .normalize()
273    }
274}
275
276impl SensorFusionMath for f64 {
277    #[inline(always)]
278    fn estimate_gravity(q: Quaternion<Self>) -> Vector3d<Self> {
279        Vector3d {
280            x: 2.0 * (q.x * q.z - q.w * q.y),
281            y: 2.0 * (q.y * q.z + q.w * q.x),
282            z: q.w * q.w - q.x * q.x - q.y * q.y + q.z * q.z,
283        }
284    }
285    #[inline(always)]
286    fn derivative(q: Quaternion<Self>, gyro_rps: Vector3d<Self>) -> Quaternion<Self> {
287        Quaternion {
288            w: (-q.x * gyro_rps.x - q.y * gyro_rps.y - q.z * gyro_rps.z) * 0.5,
289            x: (q.w * gyro_rps.x + q.y * gyro_rps.z - q.z * gyro_rps.y) * 0.5,
290            y: (q.w * gyro_rps.y - q.x * gyro_rps.z + q.z * gyro_rps.x) * 0.5,
291            z: (q.w * gyro_rps.z + q.x * gyro_rps.y - q.y * gyro_rps.x) * 0.5,
292        }
293    }
294    #[inline(always)]
295    fn madgwick_step_acc(
296        q: Quaternion<Self>,
297        acc: Vector3d<Self>,
298        max_acc_magnitude_squared: Self,
299    ) -> Quaternion<Self> {
300        use num_traits::Zero;
301        use vqm::SqrtMethods;
302
303        let acc_magnitude_squared = acc.norm_squared();
304        // Acceleration is an unreliable indicator of orientation when in high-g maneuvers,
305        // so exclude it from the calculation in these cases
306        let mut a = acc;
307        if acc_magnitude_squared > max_acc_magnitude_squared || acc_magnitude_squared == 0.0 {
308            a = Vector3d::zero();
309        } else {
310            a *= acc_magnitude_squared.sqrt_reciprocal();
311        }
312        let wz_common = 2.0 * (q.x * q.x + q.y * q.y);
313        let xy_common = 2.0 * (q.w * q.w + q.z * q.z - 1.0 + 2.0 * wz_common + a.z);
314        Quaternion {
315            w: q.w * wz_common + q.y * a.x - q.x * a.y,
316            x: q.x * xy_common - q.z * a.x - q.w * a.y,
317            y: q.y * xy_common + q.w * a.x - q.z * a.y,
318            z: q.z * wz_common - q.x * a.x - q.y * a.y,
319        }
320    }
321
322    #[inline(always)]
323    fn madgwick_step_acc_mag(
324        q: Quaternion<Self>,
325        _acc: Vector3d<Self>,
326        _mag: Vector3d<Self>,
327        _max_acc_magnitude_squared: Self,
328    ) -> Quaternion<Self> {
329        q
330    }
331}