tensor_rs/
quaternion.rs

1use std::ops::Div;
2
/// A quaternion `w + x*i + y*j + z*k` with components of type `T`.
///
/// The type is a plain value: it is `Copy` whenever `T` is, so operations that
/// take `self` by value (such as scalar division) do not consume the caller's
/// quaternion.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Quaternion<T> {
    /// Components stored as `(w, x, y, z)`: the scalar part first, followed by
    /// the three vector-part coefficients.
    d: (T, T, T, T),
}
7
8impl<T> Default for Quaternion<T>
9where T: num_traits::Float {
10    fn default() -> Self {
11        Quaternion {
12            d: (T::one(), T::zero(), T::zero(), T::zero())
13        }
14    }
15}
16
17impl<T> Div<T> for Quaternion<T>
18where T: num_traits::Float {
19    type Output = Self;
20    
21    fn div(self, rhs: T) -> Self::Output {
22        Quaternion {
23            d: (self.d.0/rhs, self.d.1/rhs, self.d.2/rhs, self.d.3/rhs)
24        }
25    }
26}
27
28impl<T> Quaternion<T>
29where T: num_traits::Float {
30
31    pub fn new(a: T, b: T, c: T, d: T) -> Self {
32        Quaternion {
33            d: (a, b, c, d)
34        }
35    }
36
37    pub fn scalar_part(&self) -> T {
38        self.d.0
39    }
40
41    pub fn vector_part(&self) -> (T, T, T) {
42        (self.d.1, self.d.2, self.d.3)
43    }
44
45    pub fn conjugate(&self) -> Self {
46        Quaternion {
47            d: (self.d.0, -self.d.1, -self.d.2, -self.d.3)
48        }
49    }
50
51    pub fn dot(&self, o: &Quaternion<T>) -> T {
52        self.d.0*o.d.0
53            + self.d.1*o.d.1
54            + self.d.2*o.d.2
55            + self.d.3*o.d.3
56    }
57
58    pub fn len(&self) -> T {
59        T::sqrt(self.dot(self))
60    }
61
62    pub fn norm(&self) -> T {
63        self.len()
64    }
65
66    pub fn inverse(&self) -> Self {
67        self.conjugate()/self.dot(self)
68    }
69
70    /// Make it a unit quaternion
71    pub fn normalize(&self) -> Self {
72        let n = self.norm();
73        
74        Quaternion {
75            d: (self.d.0/n, self.d.1/n, self.d.2/n, self.d.3/n)
76        }
77    }
78
79    /// Quaternion multiplication
80    pub fn qm(&self, o: &Quaternion<T>) -> Self {
81        Quaternion {
82            d: (self.d.0*o.d.0 - self.d.1*o.d.1 - self.d.2*o.d.2 - self.d.3*o.d.3,
83                self.d.0*o.d.1 + self.d.1*o.d.0 + self.d.2*o.d.3 - self.d.3*o.d.2,
84                self.d.0*o.d.2 - self.d.1*o.d.3 + self.d.2*o.d.0 + self.d.3*o.d.1,
85                self.d.0*o.d.3 + self.d.1*o.d.2 - self.d.2*o.d.1 + self.d.3*o.d.0,)
86        }
87    }
88
89    /// Create a quaternion ready to apply to vector for rotation.
90    pub fn rotation_around_axis(axis: (T, T, T), theta: T) -> Self {
91        let a = T::cos(theta/(T::one() + T::one()));
92        let coef = T::sin(theta/(T::one() + T::one()));
93        let norm = T::sqrt(axis.0*axis.0 + axis.1*axis.1 + axis.2*axis.2);
94        
95        Quaternion {
96            d: (a, coef*axis.0/norm,
97                coef*axis.1/norm,
98                coef*axis.2/norm)
99        }
100    }
101
102    pub fn rotate_around_x(theta: T) -> Self {
103        Self::rotation_around_axis((T::one(), T::zero(), T::zero()), theta)
104    }
105
106    pub fn rotate_around_y(theta: T) -> Self {
107        Self::rotation_around_axis((T::zero(), T::one(), T::zero()), theta)
108    }
109
110    pub fn rotate_around_z(theta: T) -> Self {
111        Self::rotation_around_axis((T::zero(), T::zero(), T::one()), theta)
112    }
113
114    /// Apply unit quaternion to 3d vector for rotation.
115    pub fn apply_rotation(&self, v: (T, T, T)) -> (T, T, T) {
116        if self.norm() != T::one() {
117            println!("Apply a non unit quaternion for rotation!");
118        }
119        
120        let x = Quaternion {
121            d: (T::zero(), v.0, v.1, v.2)
122        };
123        let xp = self.qm(&x).qm(&self.conjugate());
124        (xp.d.1, xp.d.2, xp.d.3)
125    }
126
127    pub fn rotate_around_axis(axis: (T, T, T), theta: T, v: (T, T, T)) -> (T, T, T) {
128        let q = Self::rotation_around_axis(axis, theta);
129        q.apply_rotation(v)
130    }
131
132    pub fn unit_exp(&self, t: T) -> Self {
133        if self.norm() != T::one() {
134            println!("unit_exp needs unit quaternion!");
135        }
136
137        let omega = T::acos(self.d.0);
138
139        if T::sin(omega) == T::zero() {
140            return Self::default();
141        }
142        
143        let i = self.d.1/T::sin(omega);
144        let j = self.d.2/T::sin(omega);
145        let k = self.d.3/T::sin(omega);
146        
147        Quaternion {
148            d: (T::cos(t*omega), T::sin(t*omega)*i,
149                T::sin(t*omega)*j, T::sin(t*omega)*k)
150        }
151    }
152
153    pub fn slerp(p: &Self, q: &Self, t: T) -> Self {
154        if p.norm() != T::one() || q.norm() != T::one() {
155            println!("slerp need unit quaternion!");
156        }
157
158        let p1 = p.normalize();
159        let q1 = q.normalize();
160
161        p1.qm(&p1.inverse().qm(&q1).unit_exp(t))
162    }
163
164    
165    
166}
167
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_1_SQRT_2, FRAC_PI_2};

    /// Absolute tolerance for results that go through trigonometric functions;
    /// exact equality on such values depends on the platform's math library.
    const EPS: f64 = 1e-12;

    /// Asserts that two 3d vectors agree within `EPS` (summed absolute error).
    fn assert_vec_close(actual: (f64, f64, f64), expected: (f64, f64, f64)) {
        let err = (actual.0 - expected.0).abs()
            + (actual.1 - expected.1).abs()
            + (actual.2 - expected.2).abs();
        assert!(err < EPS, "expected {:?}, got {:?}", expected, actual);
    }

    /// Asserts that two quaternions agree within `EPS` (summed absolute error).
    fn assert_quat_close(actual: &Quaternion<f64>, expected: &Quaternion<f64>) {
        let err = (actual.d.0 - expected.d.0).abs()
            + (actual.d.1 - expected.d.1).abs()
            + (actual.d.2 - expected.d.2).abs()
            + (actual.d.3 - expected.d.3).abs();
        assert!(err < EPS, "expected {:?}, got {:?}", expected, actual);
    }

    #[test]
    fn test_qm() {
        let a = Quaternion::<f64>::new(1., 2., 3., 4.);
        let b = Quaternion::<f64>::new(2., 3., 4., 5.);

        // Small integer inputs: every intermediate value is exactly representable.
        let c = a.qm(&b);
        assert_eq!(c, Quaternion::<f64>::new(-36., 6., 12., 12.));
    }

    #[test]
    fn test_rotate_around_axis() {
        // A quarter turn about x carries +z onto -y.
        let a = Quaternion::<f64>::rotate_around_x(FRAC_PI_2);
        let r = a.apply_rotation((0., 0., 1.));

        assert_vec_close(r, (0., -1., 0.));
    }

    #[test]
    fn test_unit_exp() {
        // The identity raised to any power stays the identity.
        let b = Quaternion::<f64>::default();
        assert_eq!(b.unit_exp(1.), Quaternion::<f64>::new(1., 0., 0., 0.));

        // Any unit quaternion raised to the power zero is the identity.
        let b = Quaternion::<f64>::new(0.5, 0.5, 0.5, 0.5);
        assert_eq!(b.unit_exp(0.), Quaternion::<f64>::new(1., 0., 0., 0.));

        // Power one reproduces the input up to rounding.
        assert_quat_close(&b.unit_exp(1.), &Quaternion::<f64>::new(0.5, 0.5, 0.5, 0.5));

        // Here omega = acos(0.5) = pi/3, so the half power has scalar part
        // cos(pi/6) and each vector component sin(pi/6) * 0.5 / sin(pi/3).
        let s = 0.5 / 3f64.sqrt();
        assert_quat_close(
            &b.unit_exp(0.5),
            &Quaternion::<f64>::new(3f64.sqrt() / 2., s, s, s),
        );
    }

    #[test]
    fn test_slerp() {
        // Halfway between a quarter turn about x and the identity is an eighth
        // turn about x, which carries +z onto (0, -sin(pi/4), cos(pi/4)).
        let a = Quaternion::<f64>::rotate_around_x(FRAC_PI_2);
        let b = Quaternion::<f64>::default();
        let c = Quaternion::<f64>::slerp(&a, &b, 0.5);

        let r = c.apply_rotation((0., 0., 1.));

        assert_vec_close(r, (0., -FRAC_1_SQRT_2, FRAC_1_SQRT_2));
    }
}