sable_core/math/quat.rs

1//! Quaternion type for rotations.
2
3use super::{EPSILON, Float, Mat4, Vec3, approx_eq, lerp};
4
5/// A quaternion for representing rotations.
6///
7/// Stored as (x, y, z, w) where w is the scalar component.
8#[derive(Debug, Clone, Copy, PartialEq)]
9#[repr(C)]
10pub struct Quat {
11    /// X component (vector part).
12    pub x: Float,
13    /// Y component (vector part).
14    pub y: Float,
15    /// Z component (vector part).
16    pub z: Float,
17    /// W component (scalar part).
18    pub w: Float,
19}
20
21impl Default for Quat {
22    fn default() -> Self {
23        Self::IDENTITY
24    }
25}
26
27impl Quat {
28    /// Identity quaternion (no rotation).
29    pub const IDENTITY: Self = Self {
30        x: 0.0,
31        y: 0.0,
32        z: 0.0,
33        w: 1.0,
34    };
35
36    /// Creates a quaternion from components.
37    #[inline]
38    #[must_use]
39    pub const fn new(x: Float, y: Float, z: Float, w: Float) -> Self {
40        Self { x, y, z, w }
41    }
42
43    /// Creates a quaternion from an axis and angle.
44    #[inline]
45    #[must_use]
46    pub fn from_axis_angle(axis: Vec3, angle: Float) -> Self {
47        let axis = axis.normalize();
48        let half_angle = angle / 2.0;
49        let (s, c) = half_angle.sin_cos();
50        Self {
51            x: axis.x * s,
52            y: axis.y * s,
53            z: axis.z * s,
54            w: c,
55        }
56    }
57
58    /// Creates a quaternion from Euler angles (XYZ order).
59    #[inline]
60    #[must_use]
61    pub fn from_euler_xyz(x: Float, y: Float, z: Float) -> Self {
62        let (sx, cx) = (x / 2.0).sin_cos();
63        let (sy, cy) = (y / 2.0).sin_cos();
64        let (sz, cz) = (z / 2.0).sin_cos();
65
66        Self {
67            x: sx * cy * cz + cx * sy * sz,
68            y: cx * sy * cz - sx * cy * sz,
69            z: cx * cy * sz + sx * sy * cz,
70            w: cx * cy * cz - sx * sy * sz,
71        }
72    }
73
74    /// Creates a quaternion from a rotation matrix.
75    #[must_use]
76    pub fn from_mat4(m: &Mat4) -> Self {
77        let trace = m.cols[0].x + m.cols[1].y + m.cols[2].z;
78
79        if trace > 0.0 {
80            let s = (trace + 1.0).sqrt() * 2.0;
81            Self {
82                w: 0.25 * s,
83                x: (m.cols[1].z - m.cols[2].y) / s,
84                y: (m.cols[2].x - m.cols[0].z) / s,
85                z: (m.cols[0].y - m.cols[1].x) / s,
86            }
87        } else if m.cols[0].x > m.cols[1].y && m.cols[0].x > m.cols[2].z {
88            let s = (1.0 + m.cols[0].x - m.cols[1].y - m.cols[2].z).sqrt() * 2.0;
89            Self {
90                w: (m.cols[1].z - m.cols[2].y) / s,
91                x: 0.25 * s,
92                y: (m.cols[1].x + m.cols[0].y) / s,
93                z: (m.cols[2].x + m.cols[0].z) / s,
94            }
95        } else if m.cols[1].y > m.cols[2].z {
96            let s = (1.0 + m.cols[1].y - m.cols[0].x - m.cols[2].z).sqrt() * 2.0;
97            Self {
98                w: (m.cols[2].x - m.cols[0].z) / s,
99                x: (m.cols[1].x + m.cols[0].y) / s,
100                y: 0.25 * s,
101                z: (m.cols[2].y + m.cols[1].z) / s,
102            }
103        } else {
104            let s = (1.0 + m.cols[2].z - m.cols[0].x - m.cols[1].y).sqrt() * 2.0;
105            Self {
106                w: (m.cols[0].y - m.cols[1].x) / s,
107                x: (m.cols[2].x + m.cols[0].z) / s,
108                y: (m.cols[2].y + m.cols[1].z) / s,
109                z: 0.25 * s,
110            }
111        }
112    }
113
114    /// Converts to a rotation matrix.
115    #[must_use]
116    pub fn to_mat4(self) -> Mat4 {
117        let x2 = self.x + self.x;
118        let y2 = self.y + self.y;
119        let z2 = self.z + self.z;
120
121        let xx = self.x * x2;
122        let xy = self.x * y2;
123        let xz = self.x * z2;
124        let yy = self.y * y2;
125        let yz = self.y * z2;
126        let zz = self.z * z2;
127        let wx = self.w * x2;
128        let wy = self.w * y2;
129        let wz = self.w * z2;
130
131        Mat4::from_cols(
132            super::Vec4::new(1.0 - (yy + zz), xy + wz, xz - wy, 0.0),
133            super::Vec4::new(xy - wz, 1.0 - (xx + zz), yz + wx, 0.0),
134            super::Vec4::new(xz + wy, yz - wx, 1.0 - (xx + yy), 0.0),
135            super::Vec4::W,
136        )
137    }
138
139    /// Returns the conjugate (inverse for unit quaternions).
140    #[inline]
141    #[must_use]
142    pub fn conjugate(self) -> Self {
143        Self {
144            x: -self.x,
145            y: -self.y,
146            z: -self.z,
147            w: self.w,
148        }
149    }
150
151    /// Computes the dot product.
152    #[inline]
153    #[must_use]
154    pub fn dot(self, other: Self) -> Float {
155        self.x * other.x + self.y * other.y + self.z * other.z + self.w * other.w
156    }
157
158    /// Computes the squared length.
159    #[inline]
160    #[must_use]
161    pub fn length_squared(self) -> Float {
162        self.dot(self)
163    }
164
165    /// Computes the length.
166    #[inline]
167    #[must_use]
168    pub fn length(self) -> Float {
169        self.length_squared().sqrt()
170    }
171
172    /// Returns a normalized quaternion.
173    #[inline]
174    #[must_use]
175    pub fn normalize(self) -> Self {
176        let len = self.length();
177        if len > EPSILON {
178            Self {
179                x: self.x / len,
180                y: self.y / len,
181                z: self.z / len,
182                w: self.w / len,
183            }
184        } else {
185            Self::IDENTITY
186        }
187    }
188
189    /// Returns the inverse.
190    #[inline]
191    #[must_use]
192    pub fn inverse(self) -> Self {
193        let len_sq = self.length_squared();
194        if len_sq > EPSILON {
195            let inv_len_sq = 1.0 / len_sq;
196            Self {
197                x: -self.x * inv_len_sq,
198                y: -self.y * inv_len_sq,
199                z: -self.z * inv_len_sq,
200                w: self.w * inv_len_sq,
201            }
202        } else {
203            Self::IDENTITY
204        }
205    }
206
207    /// Rotates a vector by this quaternion.
208    #[inline]
209    #[must_use]
210    pub fn rotate(self, v: Vec3) -> Vec3 {
211        let qv = Vec3::new(self.x, self.y, self.z);
212        let uv = qv.cross(v);
213        let uuv = qv.cross(uv);
214        v + (uv * self.w + uuv) * 2.0
215    }
216
217    /// Linear interpolation between quaternions.
218    ///
219    /// Note: For most cases, prefer `slerp` for smooth interpolation.
220    #[inline]
221    #[must_use]
222    pub fn lerp(self, other: Self, t: Float) -> Self {
223        Self {
224            x: lerp(self.x, other.x, t),
225            y: lerp(self.y, other.y, t),
226            z: lerp(self.z, other.z, t),
227            w: lerp(self.w, other.w, t),
228        }
229        .normalize()
230    }
231
232    /// Normalized linear interpolation (faster than slerp, nearly as good).
233    #[inline]
234    #[must_use]
235    pub fn nlerp(self, mut other: Self, t: Float) -> Self {
236        // Ensure we take the short path
237        if self.dot(other) < 0.0 {
238            other = Self {
239                x: -other.x,
240                y: -other.y,
241                z: -other.z,
242                w: -other.w,
243            };
244        }
245        self.lerp(other, t)
246    }
247
248    /// Spherical linear interpolation between quaternions.
249    #[must_use]
250    pub fn slerp(self, mut other: Self, t: Float) -> Self {
251        let mut dot = self.dot(other);
252
253        // Ensure we take the short path
254        if dot < 0.0 {
255            other = Self {
256                x: -other.x,
257                y: -other.y,
258                z: -other.z,
259                w: -other.w,
260            };
261            dot = -dot;
262        }
263
264        // If quaternions are very close, use lerp to avoid division by zero
265        if dot > 0.9995 {
266            return self.lerp(other, t);
267        }
268
269        let theta_0 = dot.acos();
270        let theta = theta_0 * t;
271        let sin_theta = theta.sin();
272        let sin_theta_0 = theta_0.sin();
273
274        let s0 = theta.cos() - dot * sin_theta / sin_theta_0;
275        let s1 = sin_theta / sin_theta_0;
276
277        Self {
278            x: self.x * s0 + other.x * s1,
279            y: self.y * s0 + other.y * s1,
280            z: self.z * s0 + other.z * s1,
281            w: self.w * s0 + other.w * s1,
282        }
283    }
284
285    /// Converts to Euler angles (XYZ order) in radians.
286    #[must_use]
287    pub fn to_euler_xyz(self) -> (Float, Float, Float) {
288        let sinr_cosp = 2.0 * (self.w * self.x + self.y * self.z);
289        let cosr_cosp = 1.0 - 2.0 * (self.x * self.x + self.y * self.y);
290        let x = sinr_cosp.atan2(cosr_cosp);
291
292        let sinp = 2.0 * (self.w * self.y - self.z * self.x);
293        let y = if sinp.abs() >= 1.0 {
294            (super::PI / 2.0).copysign(sinp)
295        } else {
296            sinp.asin()
297        };
298
299        let siny_cosp = 2.0 * (self.w * self.z + self.x * self.y);
300        let cosy_cosp = 1.0 - 2.0 * (self.y * self.y + self.z * self.z);
301        let z = siny_cosp.atan2(cosy_cosp);
302
303        (x, y, z)
304    }
305
306    /// Returns the axis and angle of this rotation.
307    #[must_use]
308    pub fn to_axis_angle(self) -> (Vec3, Float) {
309        let angle = 2.0 * self.w.acos();
310        let s = (1.0 - self.w * self.w).sqrt();
311
312        if s < EPSILON {
313            (Vec3::X, angle)
314        } else {
315            (Vec3::new(self.x / s, self.y / s, self.z / s), angle)
316        }
317    }
318
319    /// Checks if this quaternion is approximately equal to another.
320    #[inline]
321    #[must_use]
322    pub fn approx_eq(self, other: Self) -> bool {
323        // Quaternions q and -q represent the same rotation
324        let dot = self.dot(other).abs();
325        approx_eq(dot, 1.0)
326    }
327}
328
329impl std::ops::Mul for Quat {
330    type Output = Self;
331
332    #[inline]
333    fn mul(self, other: Self) -> Self {
334        Self {
335            x: self.w * other.x + self.x * other.w + self.y * other.z - self.z * other.y,
336            y: self.w * other.y - self.x * other.z + self.y * other.w + self.z * other.x,
337            z: self.w * other.z + self.x * other.y - self.y * other.x + self.z * other.w,
338            w: self.w * other.w - self.x * other.x - self.y * other.y - self.z * other.z,
339        }
340    }
341}
342
343impl std::ops::Mul<Vec3> for Quat {
344    type Output = Vec3;
345
346    #[inline]
347    fn mul(self, v: Vec3) -> Vec3 {
348        self.rotate(v)
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::super::{Float, PI};
    use super::*;

    /// Shorthand for a rotation of `angle` radians about +Z.
    fn about_z(angle: Float) -> Quat {
        Quat::from_axis_angle(Vec3::Z, angle)
    }

    /// Checks that `v` lies strictly between +X and +Y and has unit length.
    fn assert_between_x_and_y(v: Vec3) {
        assert!(v.x > 0.0 && v.x < 1.0);
        assert!(v.y > 0.0 && v.y < 1.0);
        assert!(approx_eq(v.length(), 1.0));
    }

    #[test]
    fn test_identity() {
        let identity = Quat::IDENTITY;
        assert!(identity.approx_eq(Quat::default()));
        assert!(approx_eq(identity.length(), 1.0));
    }

    #[test]
    fn test_from_axis_angle() {
        let rotated = about_z(PI / 2.0).rotate(Vec3::X);
        assert!(rotated.approx_eq(Vec3::Y));
    }

    #[test]
    fn test_from_euler_xyz() {
        let rotated = Quat::from_euler_xyz(0.0, 0.0, PI / 2.0).rotate(Vec3::X);
        assert!(rotated.approx_eq(Vec3::Y));
    }

    #[test]
    fn test_to_euler_xyz_basic() {
        // A pure Z rotation should decompose into (0, 0, PI/4).
        let (x, y, z) = about_z(PI / 4.0).to_euler_xyz();
        assert!(approx_eq(x, 0.0));
        assert!(approx_eq(y, 0.0));
        assert!(approx_eq(z, PI / 4.0));
    }

    #[test]
    fn test_to_mat4_roundtrip() {
        let axis = Vec3::new(1.0, 1.0, 1.0).normalize();
        let original = Quat::from_axis_angle(axis, 0.5);
        let recovered = Quat::from_mat4(&original.to_mat4());
        assert!(original.approx_eq(recovered));
    }

    #[test]
    fn test_conjugate() {
        let c = Quat::new(1.0, 2.0, 3.0, 4.0).conjugate();
        let pairs: [(Float, Float); 4] = [(c.x, -1.0), (c.y, -2.0), (c.z, -3.0), (c.w, 4.0)];
        for (actual, wanted) in pairs {
            assert!(approx_eq(actual, wanted));
        }
    }

    #[test]
    fn test_normalize() {
        let unit = Quat::new(1.0, 2.0, 3.0, 4.0).normalize();
        assert!(approx_eq(unit.length(), 1.0));
    }

    #[test]
    fn test_inverse() {
        let q = Quat::from_axis_angle(Vec3::Y, 0.5);
        assert!((q * q.inverse()).approx_eq(Quat::IDENTITY));
    }

    #[test]
    fn test_mul() {
        // Two quarter turns about Z compose into a half turn.
        let quarter = about_z(PI / 2.0);
        let half = quarter * quarter;
        assert!(half.rotate(Vec3::X).approx_eq(Vec3::NEG_X));
    }

    #[test]
    fn test_lerp() {
        // Halfway between 0 and 90 degrees must land strictly between X and Y.
        let mid = Quat::IDENTITY.lerp(about_z(PI / 2.0), 0.5);
        assert_between_x_and_y(mid.rotate(Vec3::X));
    }

    #[test]
    fn test_nlerp() {
        let mid = Quat::IDENTITY.nlerp(about_z(PI / 2.0), 0.5);
        assert_between_x_and_y(mid.rotate(Vec3::X));
    }

    #[test]
    fn test_slerp() {
        // Exactly halfway through a quarter turn about Z is a 45-degree turn.
        let mid = Quat::IDENTITY.slerp(about_z(PI / 2.0), 0.5);
        let h = (0.5 as Float).sqrt();
        assert!(mid.rotate(Vec3::X).approx_eq(Vec3::new(h, h, 0.0)));
    }

    #[test]
    fn test_slerp_endpoints() {
        let from = Quat::from_axis_angle(Vec3::X, 0.5);
        let to = Quat::from_axis_angle(Vec3::Y, 1.0);
        assert!(from.slerp(to, 0.0).approx_eq(from));
        assert!(from.slerp(to, 1.0).approx_eq(to));
    }

    #[test]
    fn test_to_axis_angle() {
        let (axis, angle) = Quat::from_axis_angle(Vec3::Y, 0.5).to_axis_angle();
        assert!(axis.approx_eq(Vec3::Y));
        assert!(approx_eq(angle, 0.5));
    }

    #[test]
    fn test_rotate() {
        // A quarter turn about +Y carries +X onto -Z.
        let q = Quat::from_axis_angle(Vec3::Y, PI / 2.0);
        assert!((q * Vec3::X).approx_eq(Vec3::NEG_Z));
    }
}