Skip to main content

wick_quaternion/
funcs.rs

1//! Quaternion functions: normalize, conjugate, slerp, rotate, etc.
2//!
3//! Quaternion uses [x, y, z, w] order (scalar last).
4
5use crate::{FunctionRegistry, QuaternionFn, QuaternionValue, Signature, Type, Value};
6use num_traits::Float;
7
8// ============================================================================
9// Conjugate
10// ============================================================================
11
12/// Quaternion conjugate: conj([x, y, z, w]) = [-x, -y, -z, w]
13pub struct Conj;
14
15impl<T, V> QuaternionFn<T, V> for Conj
16where
17    T: Float,
18    V: QuaternionValue<T>,
19{
20    fn name(&self) -> &str {
21        "conj"
22    }
23
24    fn signatures(&self) -> Vec<Signature> {
25        vec![Signature {
26            args: vec![Type::Quaternion],
27            ret: Type::Quaternion,
28        }]
29    }
30
31    fn call(&self, args: &[V]) -> V {
32        let q = args[0].as_quaternion().unwrap();
33        V::from_quaternion([-q[0], -q[1], -q[2], q[3]])
34    }
35}
36
37// ============================================================================
38// Length
39// ============================================================================
40
41/// Quaternion magnitude: length(q) = sqrt(x² + y² + z² + w²)
42pub struct Length;
43
44impl<T, V> QuaternionFn<T, V> for Length
45where
46    T: Float,
47    V: QuaternionValue<T>,
48{
49    fn name(&self) -> &str {
50        "length"
51    }
52
53    fn signatures(&self) -> Vec<Signature> {
54        vec![
55            Signature {
56                args: vec![Type::Vec3],
57                ret: Type::Scalar,
58            },
59            Signature {
60                args: vec![Type::Quaternion],
61                ret: Type::Scalar,
62            },
63        ]
64    }
65
66    fn call(&self, args: &[V]) -> V {
67        match args[0].typ() {
68            Type::Vec3 => {
69                let v = args[0].as_vec3().unwrap();
70                V::from_scalar((v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt())
71            }
72            Type::Quaternion => {
73                let q = args[0].as_quaternion().unwrap();
74                V::from_scalar((q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt())
75            }
76            _ => unreachable!(),
77        }
78    }
79}
80
81// ============================================================================
82// Normalize
83// ============================================================================
84
85/// Normalize to unit quaternion/vector
86pub struct Normalize;
87
88impl<T, V> QuaternionFn<T, V> for Normalize
89where
90    T: Float,
91    V: QuaternionValue<T>,
92{
93    fn name(&self) -> &str {
94        "normalize"
95    }
96
97    fn signatures(&self) -> Vec<Signature> {
98        vec![
99            Signature {
100                args: vec![Type::Vec3],
101                ret: Type::Vec3,
102            },
103            Signature {
104                args: vec![Type::Quaternion],
105                ret: Type::Quaternion,
106            },
107        ]
108    }
109
110    fn call(&self, args: &[V]) -> V {
111        match args[0].typ() {
112            Type::Vec3 => {
113                let v = args[0].as_vec3().unwrap();
114                let len = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
115                V::from_vec3([v[0] / len, v[1] / len, v[2] / len])
116            }
117            Type::Quaternion => {
118                let q = args[0].as_quaternion().unwrap();
119                let len = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
120                V::from_quaternion([q[0] / len, q[1] / len, q[2] / len, q[3] / len])
121            }
122            _ => unreachable!(),
123        }
124    }
125}
126
127// ============================================================================
128// Inverse
129// ============================================================================
130
131/// Quaternion inverse: inverse(q) = conj(q) / |q|²
132/// For unit quaternions, inverse = conjugate
133pub struct Inverse;
134
135impl<T, V> QuaternionFn<T, V> for Inverse
136where
137    T: Float,
138    V: QuaternionValue<T>,
139{
140    fn name(&self) -> &str {
141        "inverse"
142    }
143
144    fn signatures(&self) -> Vec<Signature> {
145        vec![Signature {
146            args: vec![Type::Quaternion],
147            ret: Type::Quaternion,
148        }]
149    }
150
151    fn call(&self, args: &[V]) -> V {
152        let q = args[0].as_quaternion().unwrap();
153        let norm_sq = q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3];
154        V::from_quaternion([
155            -q[0] / norm_sq,
156            -q[1] / norm_sq,
157            -q[2] / norm_sq,
158            q[3] / norm_sq,
159        ])
160    }
161}
162
163// ============================================================================
164// Dot product
165// ============================================================================
166
167/// Dot product (4D for quaternions, 3D for vectors)
168pub struct Dot;
169
170impl<T, V> QuaternionFn<T, V> for Dot
171where
172    T: Float,
173    V: QuaternionValue<T>,
174{
175    fn name(&self) -> &str {
176        "dot"
177    }
178
179    fn signatures(&self) -> Vec<Signature> {
180        vec![
181            Signature {
182                args: vec![Type::Vec3, Type::Vec3],
183                ret: Type::Scalar,
184            },
185            Signature {
186                args: vec![Type::Quaternion, Type::Quaternion],
187                ret: Type::Scalar,
188            },
189        ]
190    }
191
192    fn call(&self, args: &[V]) -> V {
193        match (args[0].typ(), args[1].typ()) {
194            (Type::Vec3, Type::Vec3) => {
195                let a = args[0].as_vec3().unwrap();
196                let b = args[1].as_vec3().unwrap();
197                V::from_scalar(a[0] * b[0] + a[1] * b[1] + a[2] * b[2])
198            }
199            (Type::Quaternion, Type::Quaternion) => {
200                let a = args[0].as_quaternion().unwrap();
201                let b = args[1].as_quaternion().unwrap();
202                V::from_scalar(a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3])
203            }
204            _ => unreachable!(),
205        }
206    }
207}
208
209// ============================================================================
210// Lerp (linear interpolation)
211// ============================================================================
212
213/// Linear interpolation (use slerp for rotations)
214pub struct Lerp;
215
216impl<T, V> QuaternionFn<T, V> for Lerp
217where
218    T: Float,
219    V: QuaternionValue<T>,
220{
221    fn name(&self) -> &str {
222        "lerp"
223    }
224
225    fn signatures(&self) -> Vec<Signature> {
226        vec![
227            Signature {
228                args: vec![Type::Vec3, Type::Vec3, Type::Scalar],
229                ret: Type::Vec3,
230            },
231            Signature {
232                args: vec![Type::Quaternion, Type::Quaternion, Type::Scalar],
233                ret: Type::Quaternion,
234            },
235        ]
236    }
237
238    fn call(&self, args: &[V]) -> V {
239        let t = args[2].as_scalar().unwrap();
240        match (args[0].typ(), args[1].typ()) {
241            (Type::Vec3, Type::Vec3) => {
242                let a = args[0].as_vec3().unwrap();
243                let b = args[1].as_vec3().unwrap();
244                V::from_vec3([
245                    a[0] + (b[0] - a[0]) * t,
246                    a[1] + (b[1] - a[1]) * t,
247                    a[2] + (b[2] - a[2]) * t,
248                ])
249            }
250            (Type::Quaternion, Type::Quaternion) => {
251                let a = args[0].as_quaternion().unwrap();
252                let b = args[1].as_quaternion().unwrap();
253                V::from_quaternion([
254                    a[0] + (b[0] - a[0]) * t,
255                    a[1] + (b[1] - a[1]) * t,
256                    a[2] + (b[2] - a[2]) * t,
257                    a[3] + (b[3] - a[3]) * t,
258                ])
259            }
260            _ => unreachable!(),
261        }
262    }
263}
264
265// ============================================================================
266// Slerp (spherical linear interpolation)
267// ============================================================================
268
269/// Spherical linear interpolation for quaternions
270pub struct Slerp;
271
272impl<T, V> QuaternionFn<T, V> for Slerp
273where
274    T: Float,
275    V: QuaternionValue<T>,
276{
277    fn name(&self) -> &str {
278        "slerp"
279    }
280
281    fn signatures(&self) -> Vec<Signature> {
282        vec![Signature {
283            args: vec![Type::Quaternion, Type::Quaternion, Type::Scalar],
284            ret: Type::Quaternion,
285        }]
286    }
287
288    fn call(&self, args: &[V]) -> V {
289        let a = args[0].as_quaternion().unwrap();
290        let b = args[1].as_quaternion().unwrap();
291        let t = args[2].as_scalar().unwrap();
292        V::from_quaternion(slerp_impl(&a, &b, t))
293    }
294}
295
296fn slerp_impl<T: Float>(a: &[T; 4], b: &[T; 4], t: T) -> [T; 4] {
297    // Compute cosine of angle between quaternions
298    let mut dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
299
300    // If dot < 0, negate one quaternion to take shorter path
301    let mut b = *b;
302    if dot < T::zero() {
303        b = [-b[0], -b[1], -b[2], -b[3]];
304        dot = -dot;
305    }
306
307    // Clamp dot to valid range for acos
308    let one = T::one();
309    if dot > one {
310        dot = one;
311    }
312
313    // If quaternions are very close, use linear interpolation
314    let threshold = T::from(0.9995).unwrap();
315    if dot > threshold {
316        // Linear interpolation
317        let result = [
318            a[0] + (b[0] - a[0]) * t,
319            a[1] + (b[1] - a[1]) * t,
320            a[2] + (b[2] - a[2]) * t,
321            a[3] + (b[3] - a[3]) * t,
322        ];
323        // Normalize
324        let len = (result[0] * result[0]
325            + result[1] * result[1]
326            + result[2] * result[2]
327            + result[3] * result[3])
328            .sqrt();
329        return [
330            result[0] / len,
331            result[1] / len,
332            result[2] / len,
333            result[3] / len,
334        ];
335    }
336
337    // Spherical interpolation
338    let theta = dot.acos();
339    let sin_theta = theta.sin();
340    let s0 = ((one - t) * theta).sin() / sin_theta;
341    let s1 = (t * theta).sin() / sin_theta;
342
343    [
344        a[0] * s0 + b[0] * s1,
345        a[1] * s0 + b[1] * s1,
346        a[2] * s0 + b[2] * s1,
347        a[3] * s0 + b[3] * s1,
348    ]
349}
350
351// ============================================================================
352// Axis-Angle construction
353// ============================================================================
354
355/// Create quaternion from axis and angle: axis_angle(axis, angle)
356pub struct AxisAngle;
357
358impl<T, V> QuaternionFn<T, V> for AxisAngle
359where
360    T: Float,
361    V: QuaternionValue<T>,
362{
363    fn name(&self) -> &str {
364        "axis_angle"
365    }
366
367    fn signatures(&self) -> Vec<Signature> {
368        vec![Signature {
369            args: vec![Type::Vec3, Type::Scalar],
370            ret: Type::Quaternion,
371        }]
372    }
373
374    fn call(&self, args: &[V]) -> V {
375        let axis = args[0].as_vec3().unwrap();
376        let angle = args[1].as_scalar().unwrap();
377        let half_angle = angle / T::from(2.0).unwrap();
378        let s = half_angle.sin();
379        let c = half_angle.cos();
380        // Normalize axis
381        let len = (axis[0] * axis[0] + axis[1] * axis[1] + axis[2] * axis[2]).sqrt();
382        V::from_quaternion([axis[0] / len * s, axis[1] / len * s, axis[2] / len * s, c])
383    }
384}
385
386// ============================================================================
387// Rotate vector
388// ============================================================================
389
390/// Rotate a vector by a quaternion: rotate(vec, quat)
391pub struct Rotate;
392
393impl<T, V> QuaternionFn<T, V> for Rotate
394where
395    T: Float,
396    V: QuaternionValue<T>,
397{
398    fn name(&self) -> &str {
399        "rotate"
400    }
401
402    fn signatures(&self) -> Vec<Signature> {
403        vec![Signature {
404            args: vec![Type::Vec3, Type::Quaternion],
405            ret: Type::Vec3,
406        }]
407    }
408
409    fn call(&self, args: &[V]) -> V {
410        let v = args[0].as_vec3().unwrap();
411        let q = args[1].as_quaternion().unwrap();
412        V::from_vec3(rotate_vec3_by_quat(&v, &q))
413    }
414}
415
416/// Rotate a vec3 by a quaternion using the optimized formula.
417fn rotate_vec3_by_quat<T: Float>(v: &[T; 3], q: &[T; 4]) -> [T; 3] {
418    let (qx, qy, qz, qw) = (q[0], q[1], q[2], q[3]);
419    let two = T::from(2.0).unwrap();
420
421    // t = 2 * (q_xyz × v)
422    let tx = two * (qy * v[2] - qz * v[1]);
423    let ty = two * (qz * v[0] - qx * v[2]);
424    let tz = two * (qx * v[1] - qy * v[0]);
425
426    // v' = v + w * t + (q_xyz × t)
427    [
428        v[0] + qw * tx + (qy * tz - qz * ty),
429        v[1] + qw * ty + (qz * tx - qx * tz),
430        v[2] + qw * tz + (qx * ty - qy * tx),
431    ]
432}
433
434// ============================================================================
435// Vec3 construction
436// ============================================================================
437
438/// Construct Vec3 from three scalars: vec3(x, y, z) -> Vec3
439pub struct Vec3Constructor;
440
441impl<T, V> QuaternionFn<T, V> for Vec3Constructor
442where
443    T: Float,
444    V: QuaternionValue<T>,
445{
446    fn name(&self) -> &str {
447        "vec3"
448    }
449
450    fn signatures(&self) -> Vec<Signature> {
451        vec![Signature {
452            args: vec![Type::Scalar, Type::Scalar, Type::Scalar],
453            ret: Type::Vec3,
454        }]
455    }
456
457    fn call(&self, args: &[V]) -> V {
458        let x = args[0].as_scalar().unwrap();
459        let y = args[1].as_scalar().unwrap();
460        let z = args[2].as_scalar().unwrap();
461        V::from_vec3([x, y, z])
462    }
463}
464
465// ============================================================================
466// Quaternion construction
467// ============================================================================
468
469/// Construct Quaternion from four scalars: quat(x, y, z, w) -> Quaternion
470/// Uses [x, y, z, w] order (scalar last, matching GLM/glTF convention).
471pub struct QuatConstructor;
472
473impl<T, V> QuaternionFn<T, V> for QuatConstructor
474where
475    T: Float,
476    V: QuaternionValue<T>,
477{
478    fn name(&self) -> &str {
479        "quat"
480    }
481
482    fn signatures(&self) -> Vec<Signature> {
483        vec![Signature {
484            args: vec![Type::Scalar, Type::Scalar, Type::Scalar, Type::Scalar],
485            ret: Type::Quaternion,
486        }]
487    }
488
489    fn call(&self, args: &[V]) -> V {
490        let x = args[0].as_scalar().unwrap();
491        let y = args[1].as_scalar().unwrap();
492        let z = args[2].as_scalar().unwrap();
493        let w = args[3].as_scalar().unwrap();
494        V::from_quaternion([x, y, z, w])
495    }
496}
497
498// ============================================================================
499// Registry helper
500// ============================================================================
501
502/// Register all standard quaternion functions.
503pub fn register_quaternion<T, V>(registry: &mut FunctionRegistry<T, V>)
504where
505    T: Float + 'static,
506    V: QuaternionValue<T> + 'static,
507{
508    registry.register(Conj);
509    registry.register(Length);
510    registry.register(Normalize);
511    registry.register(Inverse);
512    registry.register(Dot);
513    registry.register(Lerp);
514    registry.register(Slerp);
515    registry.register(AxisAngle);
516    registry.register(Rotate);
517    registry.register(Vec3Constructor);
518    registry.register(QuatConstructor);
519}
520
521/// Create a new registry with all standard quaternion functions.
522pub fn quaternion_registry<T: Float + std::fmt::Debug + 'static>() -> FunctionRegistry<T, Value<T>>
523{
524    let mut registry = FunctionRegistry::new();
525    register_quaternion(&mut registry);
526    registry
527}
528
529// ============================================================================
530// Tests
531// ============================================================================
532
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use wick_core::Expr;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 0.0001
    }

    /// Assert that `actual` matches `expected` component-wise within `approx_eq` tolerance.
    fn assert_components_approx(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                approx_eq(*a, *e),
                "component {}: got {}, expected {}",
                i,
                a,
                e
            );
        }
    }

    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Value<f32> {
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, Value<f32>> = vars
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        let registry = quaternion_registry();
        crate::eval(expr.ast(), &var_map, &registry).unwrap()
    }

    #[test]
    fn test_conj() {
        let result = eval_expr("conj(q)", &[("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0]))]);
        assert_eq!(result, Value::Quaternion([-1.0, -2.0, -3.0, 4.0]));
    }

    #[test]
    fn test_normalize() {
        let result = eval_expr(
            "normalize(q)",
            &[("q", Value::Quaternion([0.0, 0.0, 0.0, 2.0]))],
        );
        assert_eq!(result, Value::Quaternion([0.0, 0.0, 0.0, 1.0]));
    }

    #[test]
    fn test_length() {
        let result = eval_expr(
            "length(q)",
            &[("q", Value::Quaternion([0.0, 0.0, 3.0, 4.0]))],
        );
        assert_eq!(result, Value::Scalar(5.0));
    }

    #[test]
    fn test_length_vec3() {
        let result = eval_expr("length(v)", &[("v", Value::Vec3([3.0, 4.0, 0.0]))]);
        assert_eq!(result, Value::Scalar(5.0));
    }

    #[test]
    fn test_dot() {
        let result = eval_expr(
            "dot(a, b)",
            &[
                ("a", Value::Quaternion([1.0, 0.0, 0.0, 0.0])),
                ("b", Value::Quaternion([1.0, 0.0, 0.0, 0.0])),
            ],
        );
        assert_eq!(result, Value::Scalar(1.0));
    }

    #[test]
    fn test_dot_vec3() {
        let result = eval_expr(
            "dot(a, b)",
            &[
                ("a", Value::Vec3([1.0, 2.0, 3.0])),
                ("b", Value::Vec3([4.0, 5.0, 6.0])),
            ],
        );
        assert_eq!(result, Value::Scalar(32.0));
    }

    #[test]
    fn test_inverse() {
        // |q|² = 30, so inverse(q) = [-1, -2, -3, 4] / 30.
        let result = eval_expr("inverse(q)", &[("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0]))]);
        if let Value::Quaternion(q) = result {
            assert_components_approx(&q, &[-1.0 / 30.0, -2.0 / 30.0, -3.0 / 30.0, 4.0 / 30.0]);
        } else {
            panic!("expected quaternion");
        }
    }

    #[test]
    fn test_lerp_vec3() {
        let result = eval_expr(
            "lerp(a, b, t)",
            &[
                ("a", Value::Vec3([0.0, 0.0, 0.0])),
                ("b", Value::Vec3([2.0, 4.0, 6.0])),
                ("t", Value::Scalar(0.5)),
            ],
        );
        assert_eq!(result, Value::Vec3([1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_constructors() {
        let scalars = [
            ("x", Value::Scalar(1.0)),
            ("y", Value::Scalar(2.0)),
            ("z", Value::Scalar(3.0)),
            ("w", Value::Scalar(4.0)),
        ];
        assert_eq!(
            eval_expr("vec3(x, y, z)", &scalars),
            Value::Vec3([1.0, 2.0, 3.0])
        );
        assert_eq!(
            eval_expr("quat(x, y, z, w)", &scalars),
            Value::Quaternion([1.0, 2.0, 3.0, 4.0])
        );
    }

    #[test]
    fn test_axis_angle() {
        // 90° rotation around Z axis
        let result = eval_expr(
            "axis_angle(axis, angle)",
            &[
                ("axis", Value::Vec3([0.0, 0.0, 1.0])),
                ("angle", Value::Scalar(std::f32::consts::FRAC_PI_2)),
            ],
        );
        if let Value::Quaternion(q) = result {
            // half angle = 45°, sin(45°) ≈ 0.707, cos(45°) ≈ 0.707
            assert!(approx_eq(q[0], 0.0));
            assert!(approx_eq(q[1], 0.0));
            assert!(approx_eq(q[2], std::f32::consts::FRAC_PI_4.sin()));
            assert!(approx_eq(q[3], std::f32::consts::FRAC_PI_4.cos()));
        } else {
            panic!("expected quaternion");
        }
    }

    #[test]
    fn test_rotate() {
        // Rotate [1, 0, 0] by 90° around Z axis -> [0, 1, 0]
        let half_angle = std::f32::consts::FRAC_PI_4;
        let result = eval_expr(
            "rotate(v, q)",
            &[
                ("v", Value::Vec3([1.0, 0.0, 0.0])),
                (
                    "q",
                    Value::Quaternion([0.0, 0.0, half_angle.sin(), half_angle.cos()]),
                ),
            ],
        );
        if let Value::Vec3(v) = result {
            assert!(approx_eq(v[0], 0.0));
            assert!(approx_eq(v[1], 1.0));
            assert!(approx_eq(v[2], 0.0));
        } else {
            panic!("expected vec3");
        }
    }

    #[test]
    fn test_rotate_identity_is_noop() {
        let result = eval_expr(
            "rotate(v, q)",
            &[
                ("v", Value::Vec3([1.0, 2.0, 3.0])),
                ("q", Value::Quaternion([0.0, 0.0, 0.0, 1.0])),
            ],
        );
        if let Value::Vec3(v) = result {
            assert_components_approx(&v, &[1.0, 2.0, 3.0]);
        } else {
            panic!("expected vec3");
        }
    }

    #[test]
    fn test_slerp() {
        // Slerp between identity and 180° rotation should give 90° at t=0.5
        let identity = Value::Quaternion([0.0, 0.0, 0.0, 1.0]);
        // 180° around Z = [0, 0, 1, 0]
        let half_turn = Value::Quaternion([0.0, 0.0, 1.0, 0.0]);
        let result = eval_expr(
            "slerp(a, b, t)",
            &[("a", identity), ("b", half_turn), ("t", Value::Scalar(0.5))],
        );
        if let Value::Quaternion(q) = result {
            // Should be 90° rotation: [0, 0, sin(45°), cos(45°)]
            assert!(approx_eq(q[0], 0.0));
            assert!(approx_eq(q[1], 0.0));
            assert!(approx_eq(q[2], std::f32::consts::FRAC_PI_4.sin()));
            assert!(approx_eq(q[3], std::f32::consts::FRAC_PI_4.cos()));
        } else {
            panic!("expected quaternion");
        }
    }

    #[test]
    fn test_slerp_endpoints() {
        let a = [0.0, 0.0, 0.0, 1.0];
        let b = [0.0, 0.0, 1.0, 0.0];
        for (t, expected) in [(0.0_f32, a), (1.0_f32, b)] {
            let result = eval_expr(
                "slerp(a, b, t)",
                &[
                    ("a", Value::Quaternion(a)),
                    ("b", Value::Quaternion(b)),
                    ("t", Value::Scalar(t)),
                ],
            );
            if let Value::Quaternion(q) = result {
                assert_components_approx(&q, &expected);
            } else {
                panic!("expected quaternion");
            }
        }
    }

    #[test]
    fn test_slerp_takes_shorter_path() {
        // q and -q encode the same rotation, so slerp must not swing through the long arc.
        let result = eval_expr(
            "slerp(a, b, t)",
            &[
                ("a", Value::Quaternion([0.0, 0.0, 0.0, 1.0])),
                ("b", Value::Quaternion([0.0, 0.0, 0.0, -1.0])),
                ("t", Value::Scalar(0.5)),
            ],
        );
        if let Value::Quaternion(q) = result {
            assert_components_approx(&q, &[0.0, 0.0, 0.0, 1.0]);
        } else {
            panic!("expected quaternion");
        }
    }
}