1use crate::{FunctionRegistry, QuaternionFn, QuaternionValue, Signature, Type, Value};
6use num_traits::Float;
7
8pub struct Conj;
14
15impl<T, V> QuaternionFn<T, V> for Conj
16where
17 T: Float,
18 V: QuaternionValue<T>,
19{
20 fn name(&self) -> &str {
21 "conj"
22 }
23
24 fn signatures(&self) -> Vec<Signature> {
25 vec![Signature {
26 args: vec![Type::Quaternion],
27 ret: Type::Quaternion,
28 }]
29 }
30
31 fn call(&self, args: &[V]) -> V {
32 let q = args[0].as_quaternion().unwrap();
33 V::from_quaternion([-q[0], -q[1], -q[2], q[3]])
34 }
35}
36
37pub struct Length;
43
44impl<T, V> QuaternionFn<T, V> for Length
45where
46 T: Float,
47 V: QuaternionValue<T>,
48{
49 fn name(&self) -> &str {
50 "length"
51 }
52
53 fn signatures(&self) -> Vec<Signature> {
54 vec![
55 Signature {
56 args: vec![Type::Vec3],
57 ret: Type::Scalar,
58 },
59 Signature {
60 args: vec![Type::Quaternion],
61 ret: Type::Scalar,
62 },
63 ]
64 }
65
66 fn call(&self, args: &[V]) -> V {
67 match args[0].typ() {
68 Type::Vec3 => {
69 let v = args[0].as_vec3().unwrap();
70 V::from_scalar((v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt())
71 }
72 Type::Quaternion => {
73 let q = args[0].as_quaternion().unwrap();
74 V::from_scalar((q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt())
75 }
76 _ => unreachable!(),
77 }
78 }
79}
80
81pub struct Normalize;
87
88impl<T, V> QuaternionFn<T, V> for Normalize
89where
90 T: Float,
91 V: QuaternionValue<T>,
92{
93 fn name(&self) -> &str {
94 "normalize"
95 }
96
97 fn signatures(&self) -> Vec<Signature> {
98 vec![
99 Signature {
100 args: vec![Type::Vec3],
101 ret: Type::Vec3,
102 },
103 Signature {
104 args: vec![Type::Quaternion],
105 ret: Type::Quaternion,
106 },
107 ]
108 }
109
110 fn call(&self, args: &[V]) -> V {
111 match args[0].typ() {
112 Type::Vec3 => {
113 let v = args[0].as_vec3().unwrap();
114 let len = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
115 V::from_vec3([v[0] / len, v[1] / len, v[2] / len])
116 }
117 Type::Quaternion => {
118 let q = args[0].as_quaternion().unwrap();
119 let len = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
120 V::from_quaternion([q[0] / len, q[1] / len, q[2] / len, q[3] / len])
121 }
122 _ => unreachable!(),
123 }
124 }
125}
126
127pub struct Inverse;
134
135impl<T, V> QuaternionFn<T, V> for Inverse
136where
137 T: Float,
138 V: QuaternionValue<T>,
139{
140 fn name(&self) -> &str {
141 "inverse"
142 }
143
144 fn signatures(&self) -> Vec<Signature> {
145 vec![Signature {
146 args: vec![Type::Quaternion],
147 ret: Type::Quaternion,
148 }]
149 }
150
151 fn call(&self, args: &[V]) -> V {
152 let q = args[0].as_quaternion().unwrap();
153 let norm_sq = q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3];
154 V::from_quaternion([
155 -q[0] / norm_sq,
156 -q[1] / norm_sq,
157 -q[2] / norm_sq,
158 q[3] / norm_sq,
159 ])
160 }
161}
162
163pub struct Dot;
169
170impl<T, V> QuaternionFn<T, V> for Dot
171where
172 T: Float,
173 V: QuaternionValue<T>,
174{
175 fn name(&self) -> &str {
176 "dot"
177 }
178
179 fn signatures(&self) -> Vec<Signature> {
180 vec![
181 Signature {
182 args: vec![Type::Vec3, Type::Vec3],
183 ret: Type::Scalar,
184 },
185 Signature {
186 args: vec![Type::Quaternion, Type::Quaternion],
187 ret: Type::Scalar,
188 },
189 ]
190 }
191
192 fn call(&self, args: &[V]) -> V {
193 match (args[0].typ(), args[1].typ()) {
194 (Type::Vec3, Type::Vec3) => {
195 let a = args[0].as_vec3().unwrap();
196 let b = args[1].as_vec3().unwrap();
197 V::from_scalar(a[0] * b[0] + a[1] * b[1] + a[2] * b[2])
198 }
199 (Type::Quaternion, Type::Quaternion) => {
200 let a = args[0].as_quaternion().unwrap();
201 let b = args[1].as_quaternion().unwrap();
202 V::from_scalar(a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3])
203 }
204 _ => unreachable!(),
205 }
206 }
207}
208
209pub struct Lerp;
215
216impl<T, V> QuaternionFn<T, V> for Lerp
217where
218 T: Float,
219 V: QuaternionValue<T>,
220{
221 fn name(&self) -> &str {
222 "lerp"
223 }
224
225 fn signatures(&self) -> Vec<Signature> {
226 vec![
227 Signature {
228 args: vec![Type::Vec3, Type::Vec3, Type::Scalar],
229 ret: Type::Vec3,
230 },
231 Signature {
232 args: vec![Type::Quaternion, Type::Quaternion, Type::Scalar],
233 ret: Type::Quaternion,
234 },
235 ]
236 }
237
238 fn call(&self, args: &[V]) -> V {
239 let t = args[2].as_scalar().unwrap();
240 match (args[0].typ(), args[1].typ()) {
241 (Type::Vec3, Type::Vec3) => {
242 let a = args[0].as_vec3().unwrap();
243 let b = args[1].as_vec3().unwrap();
244 V::from_vec3([
245 a[0] + (b[0] - a[0]) * t,
246 a[1] + (b[1] - a[1]) * t,
247 a[2] + (b[2] - a[2]) * t,
248 ])
249 }
250 (Type::Quaternion, Type::Quaternion) => {
251 let a = args[0].as_quaternion().unwrap();
252 let b = args[1].as_quaternion().unwrap();
253 V::from_quaternion([
254 a[0] + (b[0] - a[0]) * t,
255 a[1] + (b[1] - a[1]) * t,
256 a[2] + (b[2] - a[2]) * t,
257 a[3] + (b[3] - a[3]) * t,
258 ])
259 }
260 _ => unreachable!(),
261 }
262 }
263}
264
265pub struct Slerp;
271
272impl<T, V> QuaternionFn<T, V> for Slerp
273where
274 T: Float,
275 V: QuaternionValue<T>,
276{
277 fn name(&self) -> &str {
278 "slerp"
279 }
280
281 fn signatures(&self) -> Vec<Signature> {
282 vec![Signature {
283 args: vec![Type::Quaternion, Type::Quaternion, Type::Scalar],
284 ret: Type::Quaternion,
285 }]
286 }
287
288 fn call(&self, args: &[V]) -> V {
289 let a = args[0].as_quaternion().unwrap();
290 let b = args[1].as_quaternion().unwrap();
291 let t = args[2].as_scalar().unwrap();
292 V::from_quaternion(slerp_impl(&a, &b, t))
293 }
294}
295
296fn slerp_impl<T: Float>(a: &[T; 4], b: &[T; 4], t: T) -> [T; 4] {
297 let mut dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
299
300 let mut b = *b;
302 if dot < T::zero() {
303 b = [-b[0], -b[1], -b[2], -b[3]];
304 dot = -dot;
305 }
306
307 let one = T::one();
309 if dot > one {
310 dot = one;
311 }
312
313 let threshold = T::from(0.9995).unwrap();
315 if dot > threshold {
316 let result = [
318 a[0] + (b[0] - a[0]) * t,
319 a[1] + (b[1] - a[1]) * t,
320 a[2] + (b[2] - a[2]) * t,
321 a[3] + (b[3] - a[3]) * t,
322 ];
323 let len = (result[0] * result[0]
325 + result[1] * result[1]
326 + result[2] * result[2]
327 + result[3] * result[3])
328 .sqrt();
329 return [
330 result[0] / len,
331 result[1] / len,
332 result[2] / len,
333 result[3] / len,
334 ];
335 }
336
337 let theta = dot.acos();
339 let sin_theta = theta.sin();
340 let s0 = ((one - t) * theta).sin() / sin_theta;
341 let s1 = (t * theta).sin() / sin_theta;
342
343 [
344 a[0] * s0 + b[0] * s1,
345 a[1] * s0 + b[1] * s1,
346 a[2] * s0 + b[2] * s1,
347 a[3] * s0 + b[3] * s1,
348 ]
349}
350
351pub struct AxisAngle;
357
358impl<T, V> QuaternionFn<T, V> for AxisAngle
359where
360 T: Float,
361 V: QuaternionValue<T>,
362{
363 fn name(&self) -> &str {
364 "axis_angle"
365 }
366
367 fn signatures(&self) -> Vec<Signature> {
368 vec![Signature {
369 args: vec![Type::Vec3, Type::Scalar],
370 ret: Type::Quaternion,
371 }]
372 }
373
374 fn call(&self, args: &[V]) -> V {
375 let axis = args[0].as_vec3().unwrap();
376 let angle = args[1].as_scalar().unwrap();
377 let half_angle = angle / T::from(2.0).unwrap();
378 let s = half_angle.sin();
379 let c = half_angle.cos();
380 let len = (axis[0] * axis[0] + axis[1] * axis[1] + axis[2] * axis[2]).sqrt();
382 V::from_quaternion([axis[0] / len * s, axis[1] / len * s, axis[2] / len * s, c])
383 }
384}
385
386pub struct Rotate;
392
393impl<T, V> QuaternionFn<T, V> for Rotate
394where
395 T: Float,
396 V: QuaternionValue<T>,
397{
398 fn name(&self) -> &str {
399 "rotate"
400 }
401
402 fn signatures(&self) -> Vec<Signature> {
403 vec![Signature {
404 args: vec![Type::Vec3, Type::Quaternion],
405 ret: Type::Vec3,
406 }]
407 }
408
409 fn call(&self, args: &[V]) -> V {
410 let v = args[0].as_vec3().unwrap();
411 let q = args[1].as_quaternion().unwrap();
412 V::from_vec3(rotate_vec3_by_quat(&v, &q))
413 }
414}
415
416fn rotate_vec3_by_quat<T: Float>(v: &[T; 3], q: &[T; 4]) -> [T; 3] {
418 let (qx, qy, qz, qw) = (q[0], q[1], q[2], q[3]);
419 let two = T::from(2.0).unwrap();
420
421 let tx = two * (qy * v[2] - qz * v[1]);
423 let ty = two * (qz * v[0] - qx * v[2]);
424 let tz = two * (qx * v[1] - qy * v[0]);
425
426 [
428 v[0] + qw * tx + (qy * tz - qz * ty),
429 v[1] + qw * ty + (qz * tx - qx * tz),
430 v[2] + qw * tz + (qx * ty - qy * tx),
431 ]
432}
433
434pub struct Vec3Constructor;
440
441impl<T, V> QuaternionFn<T, V> for Vec3Constructor
442where
443 T: Float,
444 V: QuaternionValue<T>,
445{
446 fn name(&self) -> &str {
447 "vec3"
448 }
449
450 fn signatures(&self) -> Vec<Signature> {
451 vec![Signature {
452 args: vec![Type::Scalar, Type::Scalar, Type::Scalar],
453 ret: Type::Vec3,
454 }]
455 }
456
457 fn call(&self, args: &[V]) -> V {
458 let x = args[0].as_scalar().unwrap();
459 let y = args[1].as_scalar().unwrap();
460 let z = args[2].as_scalar().unwrap();
461 V::from_vec3([x, y, z])
462 }
463}
464
465pub struct QuatConstructor;
472
473impl<T, V> QuaternionFn<T, V> for QuatConstructor
474where
475 T: Float,
476 V: QuaternionValue<T>,
477{
478 fn name(&self) -> &str {
479 "quat"
480 }
481
482 fn signatures(&self) -> Vec<Signature> {
483 vec![Signature {
484 args: vec![Type::Scalar, Type::Scalar, Type::Scalar, Type::Scalar],
485 ret: Type::Quaternion,
486 }]
487 }
488
489 fn call(&self, args: &[V]) -> V {
490 let x = args[0].as_scalar().unwrap();
491 let y = args[1].as_scalar().unwrap();
492 let z = args[2].as_scalar().unwrap();
493 let w = args[3].as_scalar().unwrap();
494 V::from_quaternion([x, y, z, w])
495 }
496}
497
498pub fn register_quaternion<T, V>(registry: &mut FunctionRegistry<T, V>)
504where
505 T: Float + 'static,
506 V: QuaternionValue<T> + 'static,
507{
508 registry.register(Conj);
509 registry.register(Length);
510 registry.register(Normalize);
511 registry.register(Inverse);
512 registry.register(Dot);
513 registry.register(Lerp);
514 registry.register(Slerp);
515 registry.register(AxisAngle);
516 registry.register(Rotate);
517 registry.register(Vec3Constructor);
518 registry.register(QuatConstructor);
519}
520
521pub fn quaternion_registry<T: Float + std::fmt::Debug + 'static>() -> FunctionRegistry<T, Value<T>>
523{
524 let mut registry = FunctionRegistry::new();
525 register_quaternion(&mut registry);
526 registry
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use wick_core::Expr;

    /// Absolute-tolerance comparison for results that pass through trig functions.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 0.0001
    }

    /// Parses `expr`, binds `vars`, and evaluates it against the quaternion registry.
    ///
    /// Panics (failing the calling test) if parsing or evaluation fails.
    fn eval_expr(expr: &str, vars: &[(&str, Value<f32>)]) -> Value<f32> {
        let expr = Expr::parse(expr).unwrap();
        let var_map: HashMap<String, Value<f32>> = vars
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect();
        let registry = quaternion_registry();
        // Previously `®istry` (an HTML-entity mangling of `&registry`), which
        // did not compile.
        crate::eval(expr.ast(), &var_map, &registry).unwrap()
    }

    #[test]
    fn test_conj() {
        let result = eval_expr("conj(q)", &[("q", Value::Quaternion([1.0, 2.0, 3.0, 4.0]))]);
        assert_eq!(result, Value::Quaternion([-1.0, -2.0, -3.0, 4.0]));
    }

    #[test]
    fn test_normalize() {
        let result = eval_expr(
            "normalize(q)",
            &[("q", Value::Quaternion([0.0, 0.0, 0.0, 2.0]))],
        );
        assert_eq!(result, Value::Quaternion([0.0, 0.0, 0.0, 1.0]));
    }

    #[test]
    fn test_length() {
        let result = eval_expr(
            "length(q)",
            &[("q", Value::Quaternion([0.0, 0.0, 3.0, 4.0]))],
        );
        assert_eq!(result, Value::Scalar(5.0));
    }

    #[test]
    fn test_length_vec3() {
        let result = eval_expr("length(v)", &[("v", Value::Vec3([3.0, 4.0, 0.0]))]);
        assert_eq!(result, Value::Scalar(5.0));
    }

    #[test]
    fn test_inverse() {
        // |q|² = 4, so the inverse is conj(q) / 4 — all values exact in f32.
        let result = eval_expr(
            "inverse(q)",
            &[("q", Value::Quaternion([1.0, 1.0, 1.0, 1.0]))],
        );
        assert_eq!(result, Value::Quaternion([-0.25, -0.25, -0.25, 0.25]));
    }

    #[test]
    fn test_dot() {
        let result = eval_expr(
            "dot(a, b)",
            &[
                ("a", Value::Quaternion([1.0, 0.0, 0.0, 0.0])),
                ("b", Value::Quaternion([1.0, 0.0, 0.0, 0.0])),
            ],
        );
        assert_eq!(result, Value::Scalar(1.0));
    }

    #[test]
    fn test_lerp() {
        let result = eval_expr(
            "lerp(a, b, t)",
            &[
                ("a", Value::Vec3([0.0, 0.0, 0.0])),
                ("b", Value::Vec3([2.0, 4.0, 6.0])),
                ("t", Value::Scalar(0.5)),
            ],
        );
        assert_eq!(result, Value::Vec3([1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_axis_angle() {
        // 90° about +z: half-angle is π/4.
        let result = eval_expr(
            "axis_angle(axis, angle)",
            &[
                ("axis", Value::Vec3([0.0, 0.0, 1.0])),
                ("angle", Value::Scalar(std::f32::consts::FRAC_PI_2)),
            ],
        );
        if let Value::Quaternion(q) = result {
            assert!(approx_eq(q[0], 0.0));
            assert!(approx_eq(q[1], 0.0));
            assert!(approx_eq(q[2], std::f32::consts::FRAC_PI_4.sin()));
            assert!(approx_eq(q[3], std::f32::consts::FRAC_PI_4.cos()));
        } else {
            panic!("expected quaternion");
        }
    }

    #[test]
    fn test_rotate() {
        // Quaternion for a 90° rotation about +z maps +x onto +y.
        let half_angle = std::f32::consts::FRAC_PI_4;
        let result = eval_expr(
            "rotate(v, q)",
            &[
                ("v", Value::Vec3([1.0, 0.0, 0.0])),
                (
                    "q",
                    Value::Quaternion([0.0, 0.0, half_angle.sin(), half_angle.cos()]),
                ),
            ],
        );
        if let Value::Vec3(v) = result {
            assert!(approx_eq(v[0], 0.0));
            assert!(approx_eq(v[1], 1.0));
            assert!(approx_eq(v[2], 0.0));
        } else {
            panic!("expected vec3");
        }
    }

    #[test]
    fn test_slerp() {
        let identity = Value::Quaternion([0.0, 0.0, 0.0, 1.0]);
        // 180° about +z.
        let half_turn = Value::Quaternion([0.0, 0.0, 1.0, 0.0]);
        let result = eval_expr(
            "slerp(a, b, t)",
            &[("a", identity), ("b", half_turn), ("t", Value::Scalar(0.5))],
        );
        // Halfway is a 90° rotation about +z.
        if let Value::Quaternion(q) = result {
            assert!(approx_eq(q[0], 0.0));
            assert!(approx_eq(q[1], 0.0));
            assert!(approx_eq(q[2], std::f32::consts::FRAC_PI_4.sin()));
            assert!(approx_eq(q[3], std::f32::consts::FRAC_PI_4.cos()));
        } else {
            panic!("expected quaternion");
        }
    }
}