scirs2_spatial/transform/
rotation_spline.rs

1//! RotationSpline for smooth interpolation between multiple rotations
2//!
//! This module provides the `RotationSpline` type, which allows for smooth interpolation
//! between multiple rotations, creating a continuous curve in rotation space.
5
6use crate::error::{SpatialError, SpatialResult};
7use crate::transform::{Rotation, Slerp};
8use scirs2_core::ndarray::{array, Array1};
9
// Helper function to create an array from values.
// Packs the three arguments into a 3-element `Array1<f64>` in the order
// `[x, y, z]`; no unit conversion or validation is performed here.
#[allow(dead_code)]
fn euler_array(x: f64, y: f64, z: f64) -> Array1<f64> {
    array![x, y, z]
}
15
16// Helper function to create a rotation from Euler angles
17#[allow(dead_code)]
18fn rotation_from_euler(x: f64, y: f64, z: f64, convention: &str) -> SpatialResult<Rotation> {
19    let angles = euler_array(x, y, z);
20    let angles_view = angles.view();
21    Rotation::from_euler(&angles_view, convention)
22}
23
24/// RotationSpline provides smooth interpolation between multiple rotations.
25///
26/// A rotation spline allows for smooth interpolation between a sequence of rotations,
27/// creating a continuous curve in rotation space. It can be used to create smooth
28/// camera paths, character animations, or any other application requiring smooth
29/// rotation transitions.
30///
31/// # Examples
32///
33/// ```
34/// use scirs2_spatial::transform::{Rotation, RotationSpline};
35/// use scirs2_core::ndarray::array;
36/// use std::f64::consts::PI;
37///
38/// // Create some rotations
39/// let rotations = vec![
40///     Rotation::identity(),
41///     Rotation::from_euler(&array![0.0, 0.0, PI/2.0].view(), "xyz").unwrap(),
42///     Rotation::from_euler(&array![0.0, 0.0, PI].view(), "xyz").unwrap(),
43/// ];
44///
45/// // Create times at which these rotations occur
46/// let times = vec![0.0, 0.5, 1.0];
47///
48/// // Create a rotation spline
49/// let spline = RotationSpline::new(&rotations, &times).unwrap();
50///
51/// // Get the interpolated rotation at t=0.25 (between the first two rotations)
52/// let rot_25 = spline.interpolate(0.25);
53///
54/// // Get the interpolated rotation at t=0.75 (between the second two rotations)
55/// let rot_75 = spline.interpolate(0.75);
56/// ```
#[derive(Clone, Debug)]
pub struct RotationSpline {
    /// Sequence of key rotations; `new` checks that there is one per entry in `times`
    rotations: Vec<Rotation>,
    /// Times at which these rotations occur; `new` rejects sequences that are
    /// not strictly increasing
    times: Vec<f64>,
    /// Cached per-key velocities (rotation-vector differences divided by time
    /// differences) used by cubic interpolation. Filled in by
    /// `compute_velocities`; reset to `None` when "slerp" is selected.
    velocities: Option<Vec<Array1<f64>>>,
    /// Type of interpolation to use ("slerp" or "cubic"); `new` starts with "slerp"
    interpolation_type: String,
}
68
69impl RotationSpline {
70    /// Create a new rotation spline from a sequence of rotations and times
71    ///
72    /// # Arguments
73    ///
74    /// * `rotations` - A sequence of rotations
75    /// * `times` - The times at which these rotations occur
76    ///
77    /// # Returns
78    ///
79    /// A `SpatialResult` containing the RotationSpline if valid, or an error if invalid
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
85    /// use scirs2_core::ndarray::array;
86    /// use std::f64::consts::PI;
87    ///
88    /// let rotations = vec![
89    ///     Rotation::identity(),
90    ///     Rotation::from_euler(&array![0.0, 0.0, PI/2.0].view(), "xyz").unwrap(),
91    ///     Rotation::from_euler(&array![0.0, 0.0, PI].view(), "xyz").unwrap(),
92    /// ];
93    /// let times = vec![0.0, 1.0, 2.0];
94    /// let spline = RotationSpline::new(&rotations, &times).unwrap();
95    /// ```
96    pub fn new(rotations: &[Rotation], times: &[f64]) -> SpatialResult<Self> {
97        if rotations.is_empty() {
98            return Err(SpatialError::ValueError("Rotations cannot be empty".into()));
99        }
100
101        if times.is_empty() {
102            return Err(SpatialError::ValueError("Times cannot be empty".into()));
103        }
104
105        if rotations.len() != times.len() {
106            return Err(SpatialError::ValueError(format!(
107                "Number of _rotations ({}) must match number of times ({})",
108                rotations.len(),
109                times.len()
110            )));
111        }
112
113        // Check if times are strictly increasing
114        for i in 1..times.len() {
115            if times[i] <= times[i - 1] {
116                return Err(SpatialError::ValueError(format!(
117                    "Times must be strictly increasing, but times[{}] = {} <= times[{}] = {}",
118                    i,
119                    times[i],
120                    i - 1,
121                    times[i - 1]
122                )));
123            }
124        }
125
126        // Make a copy of the _rotations and times
127        let rotations = rotations.to_vec();
128        let times = times.to_vec();
129
130        Ok(RotationSpline {
131            rotations,
132            times,
133            velocities: None,
134            interpolation_type: "slerp".to_string(),
135        })
136    }
137
138    /// Set the interpolation type for the rotation spline
139    ///
140    /// # Arguments
141    ///
142    /// * `_interptype` - The interpolation type ("slerp" or "cubic")
143    ///
144    /// # Returns
145    ///
146    /// A `SpatialResult` containing nothing if successful, or an error if the interpolation type is invalid
147    ///
148    /// # Examples
149    ///
150    /// ```
151    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
152    /// use scirs2_core::ndarray::array;
153    /// use std::f64::consts::PI;
154    ///
155    /// let rotations = vec![
156    ///     Rotation::identity(),
157    ///     Rotation::from_euler(&array![0.0, 0.0, PI/2.0].view(), "xyz").unwrap(),
158    ///     Rotation::from_euler(&array![0.0, 0.0, PI].view(), "xyz").unwrap(),
159    /// ];
160    /// let times = vec![0.0, 1.0, 2.0];
161    /// let mut spline = RotationSpline::new(&rotations, &times).unwrap();
162    ///
163    /// // Set the interpolation type to cubic (natural cubic spline)
164    /// spline.set_interpolation_type("cubic").unwrap();
165    /// ```
166    pub fn set_interpolation_type(&mut self, _interptype: &str) -> SpatialResult<()> {
167        match _interptype.to_lowercase().as_str() {
168            "slerp" => {
169                self.interpolation_type = "slerp".to_string();
170                self.velocities = None;
171                Ok(())
172            }
173            "cubic" => {
174                self.interpolation_type = "cubic".to_string();
175                // Compute velocities for cubic interpolation if needed
176                self.compute_velocities();
177                Ok(())
178            }
179            _ => Err(SpatialError::ValueError(format!(
180                "Invalid interpolation _type: {_interptype}. Must be 'slerp' or 'cubic'"
181            ))),
182        }
183    }
184
185    /// Compute velocities for natural cubic spline interpolation
186    fn compute_velocities(&mut self) {
187        if self.velocities.is_some() {
188            return; // Already computed
189        }
190
191        let n = self.times.len();
192        if n <= 2 {
193            // For 2 or fewer points, use zero velocities
194            let mut vels = Vec::with_capacity(n);
195            for _ in 0..n {
196                vels.push(Array1::zeros(3));
197            }
198            self.velocities = Some(vels);
199            return;
200        }
201
202        // Convert rotations to rotation vectors (axis-angle representation)
203        let mut rotvecs = Vec::with_capacity(n);
204        for rot in &self.rotations {
205            rotvecs.push(rot.as_rotvec());
206        }
207
208        // Compute velocities using finite differences and natural boundary conditions
209        let mut vels = Vec::with_capacity(n);
210
211        // For endpoints, we'll use one-sided differences
212        // For internal points, we'll use centered differences
213        for i in 0..n {
214            let vel = if i == 0 {
215                // Forward difference for the first point
216                let dt = self.times[1] - self.times[0];
217                (&rotvecs[1] - &rotvecs[0]) / dt
218            } else if i == n - 1 {
219                // Backward difference for the last point
220                let dt = self.times[n - 1] - self.times[n - 2];
221                (&rotvecs[n - 1] - &rotvecs[n - 2]) / dt
222            } else {
223                // Centered difference for internal points
224                let dt_prev = self.times[i] - self.times[i - 1];
225                let dt_next = self.times[i + 1] - self.times[i];
226
227                // Use weighted average based on time intervals
228                let vel_prev = (&rotvecs[i] - &rotvecs[i - 1]) / dt_prev;
229                let vel_next = (&rotvecs[i + 1] - &rotvecs[i]) / dt_next;
230
231                // Weighted average
232                let weight_prev = dt_next / (dt_prev + dt_next);
233                let weight_next = dt_prev / (dt_prev + dt_next);
234                &vel_prev * weight_prev + &vel_next * weight_next
235            };
236
237            vels.push(vel);
238        }
239
240        self.velocities = Some(vels);
241    }
242
243    /// Compute the second derivatives for natural cubic spline interpolation
244    #[allow(dead_code)]
245    fn compute_natural_spline_second_derivatives(&self, values: &[f64]) -> Vec<f64> {
246        let n = values.len();
247        if n <= 2 {
248            return vec![0.0; n];
249        }
250
251        // Set up the tridiagonal system for natural cubic spline
252        // The system is in the form: A * x = b
253        // where A is a tridiagonal matrix, x is the second derivatives we're solving for,
254        // and b is the right-hand side of the system
255
256        // Allocate arrays for the diagonals of the tridiagonal matrix
257        let mut a = vec![0.0; n - 2]; // Lower diagonal
258        let mut b = vec![0.0; n - 2]; // Main diagonal
259        let mut c = vec![0.0; n - 2]; // Upper diagonal
260        let mut d = vec![0.0; n - 2]; // Right-hand side
261
262        // Set up the tridiagonal system
263        for i in 0..n - 2 {
264            let h_i = self.times[i + 1] - self.times[i];
265            let h_ip1 = self.times[i + 2] - self.times[i + 1];
266
267            a[i] = h_i;
268            b[i] = 2.0 * (h_i + h_ip1);
269            c[i] = h_ip1;
270
271            let fd_i = (values[i + 1] - values[i]) / h_i;
272            let fd_ip1 = (values[i + 2] - values[i + 1]) / h_ip1;
273            d[i] = 6.0 * (fd_ip1 - fd_i);
274        }
275
276        // Solve the tridiagonal system using the Thomas algorithm
277        let mut x = vec![0.0; n - 2];
278        self.solve_tridiagonal(&a, &b, &c, &d, &mut x);
279
280        // The second derivatives at the endpoints are set to zero (natural spline)
281        let mut second_derivs = vec![0.0; n];
282        second_derivs[1..((n - 2) + 1)].copy_from_slice(&x[..(n - 2)]);
283
284        second_derivs
285    }
286
287    /// Solve a tridiagonal system using the Thomas algorithm
288    #[allow(dead_code)]
289    fn solve_tridiagonal(
290        &self,
291        a: &[f64],     // Lower diagonal
292        b: &[f64],     // Main diagonal
293        c: &[f64],     // Upper diagonal
294        d: &[f64],     // Right-hand side
295        x: &mut [f64], // Solution vector
296    ) {
297        let n = x.len();
298        if n == 0 {
299            return;
300        }
301
302        // Forward sweep
303        let mut c_prime = vec![0.0; n];
304        let mut d_prime = vec![0.0; n];
305
306        c_prime[0] = c[0] / b[0];
307        d_prime[0] = d[0] / b[0];
308
309        for i in 1..n {
310            let m = b[i] - a[i - 1] * c_prime[i - 1];
311            c_prime[i] = if i < n - 1 { c[i] / m } else { 0.0 };
312            d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / m;
313        }
314
315        // Back substitution
316        x[n - 1] = d_prime[n - 1];
317        for i in (0..n - 1).rev() {
318            x[i] = d_prime[i] - c_prime[i] * x[i + 1];
319        }
320    }
321
322    /// Interpolate the rotation spline at a given time
323    ///
324    /// # Arguments
325    ///
326    /// * `t` - The time at which to interpolate
327    ///
328    /// # Returns
329    ///
330    /// The interpolated rotation
331    ///
332    /// # Examples
333    ///
334    /// ```
335    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
336    /// use scirs2_core::ndarray::array;
337    /// use std::f64::consts::PI;
338    ///
339    /// let rotations = vec![
340    ///     Rotation::identity(),
341    ///     Rotation::from_euler(&array![0.0, 0.0, PI/2.0].view(), "xyz").unwrap(),
342    ///     Rotation::from_euler(&array![0.0, 0.0, PI].view(), "xyz").unwrap(),
343    /// ];
344    /// let times = vec![0.0, 1.0, 2.0];
345    /// let spline = RotationSpline::new(&rotations, &times).unwrap();
346    ///
347    /// // Interpolate at t=0.5 (halfway between the first two rotations)
348    /// let rot_half = spline.interpolate(0.5);
349    /// ```
350    pub fn interpolate(&self, t: f64) -> Rotation {
351        let n = self.times.len();
352
353        // Handle boundary cases
354        if t <= self.times[0] {
355            return self.rotations[0].clone();
356        }
357        if t >= self.times[n - 1] {
358            return self.rotations[n - 1].clone();
359        }
360
361        // Find the segment containing t
362        let mut idx = 0;
363        for i in 0..n - 1 {
364            if t >= self.times[i] && t < self.times[i + 1] {
365                idx = i;
366                break;
367            }
368        }
369
370        // Interpolate within the segment based on interpolation type
371        match self.interpolation_type.as_str() {
372            "slerp" => self.interpolate_slerp(t, idx),
373            "cubic" => self.interpolate_cubic(t, idx),
374            _ => self.interpolate_slerp(t, idx), // Default to slerp
375        }
376    }
377
378    /// Interpolate the rotation spline at a given time using Slerp
379    fn interpolate_slerp(&self, t: f64, idx: usize) -> Rotation {
380        let t0 = self.times[idx];
381        let t1 = self.times[idx + 1];
382        let normalized_t = (t - t0) / (t1 - t0);
383
384        // Create a Slerp between the two rotations
385        let slerp =
386            Slerp::new(self.rotations[idx].clone(), self.rotations[idx + 1].clone()).unwrap();
387
388        slerp.interpolate(normalized_t)
389    }
390
391    /// Interpolate the rotation spline at a given time using cubic spline
392    fn interpolate_cubic(&self, t: f64, idx: usize) -> Rotation {
393        // Ensure velocities are computed
394        if self.velocities.is_none() {
395            let mut mutable_self = self.clone();
396            mutable_self.compute_velocities();
397            return mutable_self.interpolate_cubic(t, idx);
398        }
399
400        let t0 = self.times[idx];
401        let t1 = self.times[idx + 1];
402        let dt = t1 - t0;
403        let normalized_t = (t - t0) / dt;
404
405        let rot0 = &self.rotations[idx];
406        let rot1 = &self.rotations[idx + 1];
407
408        // Convert rotations to rotation vectors
409        let rotvec0 = rot0.as_rotvec();
410        let rotvec1 = rot1.as_rotvec();
411
412        // Get velocities
413        let velocities = self.velocities.as_ref().unwrap();
414        let vel0 = &velocities[idx];
415        let vel1 = &velocities[idx + 1];
416
417        // Use Hermite cubic interpolation formula
418        // h(t) = (2t³ - 3t² + 1)p0 + (t³ - 2t² + t)m0 + (-2t³ + 3t²)p1 + (t³ - t²)m1
419        // where p0, p1 are the start and end values, m0, m1 are the scaled tangents
420        let t2 = normalized_t * normalized_t;
421        let t3 = t2 * normalized_t;
422
423        // Hermite basis functions
424        let h00 = 2.0 * t3 - 3.0 * t2 + 1.0;
425        let h10 = t3 - 2.0 * t2 + normalized_t;
426        let h01 = -2.0 * t3 + 3.0 * t2;
427        let h11 = t3 - t2;
428
429        // Compute the interpolated rotation vector
430        let mut result = rotvec0 * h00;
431        result = &result + &(vel0 * dt * h10);
432        result = &result + &(rotvec1 * h01);
433        result = &result + &(vel1 * dt * h11);
434
435        // Convert back to rotation
436        Rotation::from_rotvec(&result.view()).unwrap()
437    }
438
    /// Get the times at which the rotations are defined
    ///
    /// # Returns
    ///
    /// A reference to the spline's stored times vector (the copy of the
    /// `times` slice made by the constructor)
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
    /// use scirs2_core::ndarray::array;
    ///
    /// let rotations = vec![
    ///     Rotation::identity(),
    ///     Rotation::identity(),
    /// ];
    /// let times = vec![0.0, 1.0];
    /// let spline = RotationSpline::new(&rotations, &times).unwrap();
    ///
    /// let retrieved_times = spline.times();
    /// assert_eq!(retrieved_times, &vec![0.0, 1.0]);
    /// ```
    pub fn times(&self) -> &Vec<f64> {
        &self.times
    }
464
    /// Get the rotations that define the spline
    ///
    /// # Returns
    ///
    /// A reference to the spline's stored rotations vector (the copy of the
    /// `rotations` slice made by the constructor)
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
    /// use scirs2_core::ndarray::array;
    ///
    /// let rotations = vec![
    ///     Rotation::identity(),
    ///     Rotation::identity(),
    /// ];
    /// let times = vec![0.0, 1.0];
    /// let spline = RotationSpline::new(&rotations, &times).unwrap();
    ///
    /// let retrieved_rotations = spline.rotations();
    /// assert_eq!(retrieved_rotations.len(), 2);
    /// ```
    pub fn rotations(&self) -> &Vec<Rotation> {
        &self.rotations
    }
490
491    /// Generate evenly spaced samples from the rotation spline
492    ///
493    /// # Arguments
494    ///
495    /// * `n` - The number of samples to generate
496    ///
497    /// # Returns
498    ///
499    /// A vector of sampled rotations and the corresponding times
500    ///
501    /// # Examples
502    ///
503    /// ```
504    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
505    /// use scirs2_core::ndarray::array;
506    /// use std::f64::consts::PI;
507    ///
508    /// let rotations = vec![
509    ///     Rotation::identity(),
510    ///     Rotation::from_euler(&array![0.0, 0.0, PI].view(), "xyz").unwrap(),
511    /// ];
512    /// let times = vec![0.0, 1.0];
513    /// let spline = RotationSpline::new(&rotations, &times).unwrap();
514    ///
515    /// // Generate 5 samples from the spline
516    /// let (sample_times, sample_rotations) = spline.sample(5);
517    /// assert_eq!(sample_times.len(), 5);
518    /// assert_eq!(sample_rotations.len(), 5);
519    /// ```
520    pub fn sample(&self, n: usize) -> (Vec<f64>, Vec<Rotation>) {
521        if n <= 1 {
522            return (vec![self.times[0]], vec![self.rotations[0].clone()]);
523        }
524
525        let t_min = self.times[0];
526        let t_max = self.times[self.times.len() - 1];
527
528        let mut sampled_times = Vec::with_capacity(n);
529        let mut sampled_rotations = Vec::with_capacity(n);
530
531        for i in 0..n {
532            let t = t_min + (t_max - t_min) * (i as f64 / (n - 1) as f64);
533            sampled_times.push(t);
534            sampled_rotations.push(self.interpolate(t));
535        }
536
537        (sampled_times, sampled_rotations)
538    }
539
    /// Create a new rotation spline from key rotations at specific times
    ///
    /// This forwards both arguments unchanged to `RotationSpline::new`, so it
    /// performs the same validation and returns the same errors; it exists
    /// only to offer a more explicit name.
    ///
    /// # Arguments
    ///
    /// * `_key_rots` - The key rotations
    /// * `keytimes` - The times at which these key rotations occur
    ///
    /// # Returns
    ///
    /// A `SpatialResult` containing the RotationSpline if valid, or an error if invalid
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
    /// use scirs2_core::ndarray::array;
    /// use std::f64::consts::PI;
    ///
    /// let key_rots = vec![
    ///     Rotation::identity(),
    ///     Rotation::from_euler(&array![0.0, 0.0, PI/2.0].view(), "xyz").unwrap(),
    ///     Rotation::from_euler(&array![0.0, 0.0, PI].view(), "xyz").unwrap(),
    /// ];
    /// let keytimes = vec![0.0, 1.0, 2.0];
    ///
    /// let spline = RotationSpline::from_key_rotations(&key_rots, &keytimes).unwrap();
    /// ```
    pub fn from_key_rotations(_key_rots: &[Rotation], keytimes: &[f64]) -> SpatialResult<Self> {
        Self::new(_key_rots, keytimes)
    }
572
    /// Get the current interpolation type
    ///
    /// # Returns
    ///
    /// The current interpolation type ("slerp" or "cubic"), borrowed from the
    /// spline. A spline fresh from `new` reports "slerp".
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
    /// use scirs2_core::ndarray::array;
    ///
    /// let rotations = vec![
    ///     Rotation::identity(),
    ///     Rotation::identity(),
    /// ];
    /// let times = vec![0.0, 1.0];
    /// let spline = RotationSpline::new(&rotations, &times).unwrap();
    ///
    /// assert_eq!(spline.interpolation_type(), "slerp");
    /// ```
    pub fn interpolation_type(&self) -> &'_ str {
        &self.interpolation_type
    }
597
598    /// Calculate the angular velocity at a specific time
599    ///
600    /// # Arguments
601    ///
602    /// * `t` - The time at which to calculate the angular velocity
603    ///
604    /// # Returns
605    ///
606    /// The angular velocity as a 3-element array
607    ///
608    /// # Examples
609    ///
610    /// ```
611    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
612    /// use scirs2_core::ndarray::array;
613    /// use std::f64::consts::PI;
614    ///
615    /// let rotations = vec![
616    ///     Rotation::identity(),
617    ///     Rotation::from_euler(&array![0.0, 0.0, PI].view(), "xyz").unwrap(),
618    /// ];
619    /// let times = vec![0.0, 1.0];
620    /// let spline = RotationSpline::new(&rotations, &times).unwrap();
621    ///
622    /// // Calculate angular velocity at t=0.5
623    /// let velocity = spline.angular_velocity(0.5);
624    /// // Should be approximately [0, 0, PI]
625    /// ```
626    pub fn angular_velocity(&self, t: f64) -> SpatialResult<Array1<f64>> {
627        let n = self.times.len();
628
629        // Handle boundary cases
630        if t <= self.times[0] || t >= self.times[n - 1] {
631            return Ok(Array1::zeros(3));
632        }
633
634        // Find the segment containing t
635        let mut idx = 0;
636        for i in 0..n - 1 {
637            if t >= self.times[i] && t < self.times[i + 1] {
638                idx = i;
639                break;
640            }
641        }
642
643        // Calculate angular velocity based on interpolation type
644        match self.interpolation_type.as_str() {
645            "slerp" => self.angular_velocity_slerp(t, idx),
646            "cubic" => Ok(self.angular_velocity_cubic(t, idx)),
647            _ => self.angular_velocity_slerp(t, idx), // Default to slerp
648        }
649    }
650
651    /// Calculate angular velocity using Slerp interpolation
652    fn angular_velocity_slerp(&self, t: f64, idx: usize) -> SpatialResult<Array1<f64>> {
653        let t0 = self.times[idx];
654        let t1 = self.times[idx + 1];
655        let dt = t1 - t0;
656        let normalized_t = (t - t0) / dt;
657
658        // Get rotations at the endpoints of the segment
659        let r0 = &self.rotations[idx];
660        let r1 = &self.rotations[idx + 1];
661
662        // Calculate the delta rotation from r0 to r1
663        let delta_rot = r0.inv().compose(r1);
664
665        // Convert to axis-angle representation via rotation vector
666        let rotvec = delta_rot.as_rotvec();
667        let angle = (rotvec.dot(&rotvec)).sqrt();
668        let axis = if angle > 1e-10 {
669            &rotvec / angle
670        } else {
671            Array1::zeros(3)
672        };
673
674        // For slerp, the angular velocity is constant and equals angle/dt along the axis
675        // The angular velocity vector in the current frame is:
676        // ω = (angle / dt) * axis
677
678        // However, we need to transform this to the frame at time t
679        // First interpolate to get the rotation at time t
680        let slerp = Slerp::new(r0.clone(), r1.clone()).unwrap();
681        let rot_t = slerp.interpolate(normalized_t);
682
683        // The angular velocity in the global frame is the axis scaled by angular rate
684        let angular_rate = angle / dt;
685        let omega_global = axis * angular_rate;
686
687        // Transform to the body frame at time t
688        // ω_body = R(t)^T * ω_global
689        rot_t.inv().apply(&omega_global.view())
690    }
691
692    /// Calculate angular velocity using cubic spline interpolation
693    fn angular_velocity_cubic(&self, t: f64, idx: usize) -> Array1<f64> {
694        // Ensure velocities are computed
695        if self.velocities.is_none() {
696            let mut mutable_self = self.clone();
697            mutable_self.compute_velocities();
698            return mutable_self.angular_velocity_cubic(t, idx);
699        }
700
701        let t0 = self.times[idx];
702        let t1 = self.times[idx + 1];
703        let dt = t1 - t0;
704        let normalized_t = (t - t0) / dt;
705
706        let rot0 = &self.rotations[idx];
707        let rot1 = &self.rotations[idx + 1];
708
709        // Convert rotations to rotation vectors
710        let rotvec0 = rot0.as_rotvec();
711        let rotvec1 = rot1.as_rotvec();
712
713        // Get velocities
714        let velocities = self.velocities.as_ref().unwrap();
715        let vel0 = &velocities[idx];
716        let vel1 = &velocities[idx + 1];
717
718        // Derivatives of Hermite basis functions
719        let dh00_dt = (6.0 * normalized_t.powi(2) - 6.0 * normalized_t) / dt;
720        let dh10_dt = (3.0 * normalized_t.powi(2) - 4.0 * normalized_t + 1.0) / dt;
721        let dh01_dt = (-6.0 * normalized_t.powi(2) + 6.0 * normalized_t) / dt;
722        let dh11_dt = (3.0 * normalized_t.powi(2) - 2.0 * normalized_t) / dt;
723
724        // Compute derivative of rotation vector interpolation
725        let mut d_rotvec_dt = &rotvec0 * dh00_dt;
726        d_rotvec_dt = &d_rotvec_dt + &(vel0 * dt * dh10_dt);
727        d_rotvec_dt = &d_rotvec_dt + &(&rotvec1 * dh01_dt);
728        d_rotvec_dt = &d_rotvec_dt + &(vel1 * dt * dh11_dt);
729
730        // The derivative gives us the angular velocity in the rotation vector space
731        // This is already the angular velocity we want
732        d_rotvec_dt
733    }
734
735    /// Calculate the angular acceleration at a specific time
736    ///
737    /// # Arguments
738    ///
739    /// * `t` - The time at which to calculate the angular acceleration
740    ///
741    /// # Returns
742    ///
743    /// The angular acceleration as a 3-element array
744    ///
745    /// # Examples
746    ///
747    /// ```
748    /// use scirs2_spatial::transform::{Rotation, RotationSpline};
749    /// use scirs2_core::ndarray::array;
750    /// use std::f64::consts::PI;
751    ///
752    /// let rotations = vec![
753    ///     Rotation::identity(),
754    ///     Rotation::from_euler(&array![0.0, 0.0, PI].view(), "xyz").unwrap(),
755    ///     Rotation::identity(),
756    /// ];
757    /// let times = vec![0.0, 1.0, 2.0];
758    /// let mut spline = RotationSpline::new(&rotations, &times).unwrap();
759    ///
760    /// // Set to cubic interpolation for non-zero acceleration
761    /// spline.set_interpolation_type("cubic").unwrap();
762    ///
763    /// // Calculate angular acceleration at t=0.5
764    /// let acceleration = spline.angular_acceleration(0.5);
765    /// ```
766    pub fn angular_acceleration(&self, t: f64) -> Array1<f64> {
767        // Cubic interpolation is needed for meaningful acceleration
768        if self.interpolation_type != "cubic" {
769            return Array1::zeros(3); // Slerp has constant velocity, so acceleration is zero
770        }
771
772        let n = self.times.len();
773
774        // Handle boundary cases
775        if t <= self.times[0] || t >= self.times[n - 1] {
776            return Array1::zeros(3);
777        }
778
779        // Find the segment containing t
780        let mut idx = 0;
781        for i in 0..n - 1 {
782            if t >= self.times[i] && t < self.times[i + 1] {
783                idx = i;
784                break;
785            }
786        }
787
788        // Calculate angular acceleration
789        self.angular_acceleration_cubic(t, idx)
790    }
791
792    /// Calculate angular acceleration using cubic spline interpolation
793    fn angular_acceleration_cubic(&self, t: f64, idx: usize) -> Array1<f64> {
794        // Ensure velocities are computed
795        if self.velocities.is_none() {
796            let mut mutable_self = self.clone();
797            mutable_self.compute_velocities();
798            return mutable_self.angular_acceleration_cubic(t, idx);
799        }
800
801        let t0 = self.times[idx];
802        let t1 = self.times[idx + 1];
803        let dt = t1 - t0;
804        let normalized_t = (t - t0) / dt;
805
806        let rot0 = &self.rotations[idx];
807        let rot1 = &self.rotations[idx + 1];
808
809        // Convert rotations to rotation vectors
810        let rotvec0 = rot0.as_rotvec();
811        let rotvec1 = rot1.as_rotvec();
812
813        // Get velocities
814        let velocities = self.velocities.as_ref().unwrap();
815        let vel0 = &velocities[idx];
816        let vel1 = &velocities[idx + 1];
817
818        // Second derivatives of Hermite basis functions
819        let d2h00_dt2 = (12.0 * normalized_t - 6.0) / (dt * dt);
820        let d2h10_dt2 = (6.0 * normalized_t - 4.0) / (dt * dt);
821        let d2h01_dt2 = (-12.0 * normalized_t + 6.0) / (dt * dt);
822        let d2h11_dt2 = (6.0 * normalized_t - 2.0) / (dt * dt);
823
824        // Compute second derivative of rotation vector interpolation
825        let mut d2_rotvec_dt2 = &rotvec0 * d2h00_dt2;
826        d2_rotvec_dt2 = &d2_rotvec_dt2 + &(vel0 * dt * d2h10_dt2);
827        d2_rotvec_dt2 = &d2_rotvec_dt2 + &(&rotvec1 * d2h01_dt2);
828        d2_rotvec_dt2 = &d2_rotvec_dt2 + &(vel1 * dt * d2h11_dt2);
829
830        // This gives us the angular acceleration
831        d2_rotvec_dt2
832    }
833}
834
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::PI;

    /// Three key rotations about the z-axis (0, PI/2 and PI in the "xyz"
    /// Euler convention) paired with the key times 0, 1 and 2.
    fn z_axis_keyframes() -> (Vec<Rotation>, Vec<f64>) {
        let keys = vec![
            Rotation::identity(),
            rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").unwrap(),
            rotation_from_euler(0.0, 0.0, PI, "xyz").unwrap(),
        ];
        (keys, vec![0.0, 1.0, 2.0])
    }

    /// Two key rotations (identity, then PI about the z-axis) paired with the
    /// key times 0 and 1.
    fn half_turn_keyframes() -> (Vec<Rotation>, Vec<f64>) {
        let keys = vec![
            Rotation::identity(),
            rotation_from_euler(0.0, 0.0, PI, "xyz").unwrap(),
        ];
        (keys, vec![0.0, 1.0])
    }

    /// A new spline stores every key and starts out in slerp mode.
    #[test]
    fn test_rotation_spline_creation() {
        let (keys, stamps) = z_axis_keyframes();
        let spline = RotationSpline::new(&keys, &stamps).unwrap();

        assert_eq!(spline.rotations().len(), 3);
        assert_eq!(spline.times().len(), 3);
        assert_eq!(spline.interpolation_type(), "slerp");
    }

    /// Queries at or beyond the time range must return the boundary keys.
    #[test]
    fn test_rotation_spline_interpolation_endpoints() {
        let (keys, stamps) = z_axis_keyframes();
        let spline = RotationSpline::new(&keys, &stamps).unwrap();

        // (query time, index of the key rotation that must be returned).
        // The first two entries hit the endpoints exactly; the last two lie
        // outside the range and are expected to clamp.
        let expectations = [(0.0, 0_usize), (2.0, 2), (-1.0, 0), (3.0, 2)];

        for &(query, key_idx) in expectations.iter() {
            let interpolated = spline.interpolate(query);
            assert_eq!(interpolated.as_quat(), keys[key_idx].as_quat());
        }
    }

    /// Halfway through each segment the x unit vector should be turned by
    /// roughly 45 and 135 degrees respectively.
    #[test]
    fn test_rotation_spline_interpolation_midpoints() {
        let (keys, stamps) = z_axis_keyframes();
        let spline = RotationSpline::new(&keys, &stamps).unwrap();

        let probe = array![1.0, 0.0, 0.0];
        let half_sqrt2 = 2.0_f64.sqrt() / 2.0;

        // (query time, expected image of the probe vector)
        let cases = [
            (0.5, [half_sqrt2, half_sqrt2, 0.0]),
            (1.5, [-half_sqrt2, half_sqrt2, 0.0]),
        ];

        for &(query, expected) in cases.iter() {
            let rotation = spline.interpolate(query);
            let image = rotation.apply(&probe.view()).unwrap();
            for (axis, want) in expected.iter().enumerate() {
                assert_relative_eq!(image[axis], *want, epsilon = 1e-3);
            }
        }
    }

    /// `sample(5)` yields five evenly spaced times with matching rotations.
    #[test]
    fn test_rotation_spline_sampling() {
        let (keys, stamps) = half_turn_keyframes();
        let spline = RotationSpline::new(&keys, &stamps).unwrap();

        let (sample_times, sample_rotations) = spline.sample(5);
        assert_eq!(sample_times.len(), 5);
        assert_eq!(sample_rotations.len(), 5);

        // Sample times must be evenly spaced over [0, 1].
        for (idx, want) in [0.0, 0.25, 0.5, 0.75, 1.0].iter().enumerate() {
            assert_relative_eq!(sample_times[idx], *want, epsilon = 1e-10);
        }

        let probe = array![1.0, 0.0, 0.0];

        // t = 0.0: identity, so the probe keeps its x and y components.
        let at_start = sample_rotations[0].apply(&probe.view()).unwrap();
        assert_relative_eq!(at_start[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(at_start[1], 0.0, epsilon = 1e-10);

        // t = 0.5: quarter turn, checked with the looser tolerance.
        let at_middle = sample_rotations[2].apply(&probe.view()).unwrap();
        for (axis, want) in [0.0, 1.0, 0.0].iter().enumerate() {
            assert_relative_eq!(at_middle[axis], *want, epsilon = 1e-3);
        }

        // t = 1.0: half turn.
        let at_end = sample_rotations[4].apply(&probe.view()).unwrap();
        for (axis, want) in [-1.0, 0.0, 0.0].iter().enumerate() {
            assert_relative_eq!(at_end[axis], *want, epsilon = 1e-10);
        }
    }

    /// Every malformed input must be rejected with an error.
    #[test]
    fn test_rotation_spline_errors() {
        let single = vec![Rotation::identity()];
        let pair = vec![Rotation::identity(), Rotation::identity()];

        // No rotations at all.
        assert!(RotationSpline::new(&[], &[0.0]).is_err());

        // No times at all.
        assert!(RotationSpline::new(&single, &[]).is_err());

        // Two rotations but only one time.
        let too_few_times = vec![0.0];
        assert!(RotationSpline::new(&pair, &too_few_times).is_err());

        // Times running backwards.
        let decreasing_times = vec![1.0, 0.0];
        assert!(RotationSpline::new(&pair, &decreasing_times).is_err());

        // Repeated time value.
        let repeated_times = vec![0.0, 0.0];
        assert!(RotationSpline::new(&pair, &repeated_times).is_err());

        // A valid spline must still refuse an unknown interpolation type.
        let valid_times = vec![0.0, 1.0];
        let mut spline = RotationSpline::new(&pair, &valid_times).unwrap();
        assert!(spline.set_interpolation_type("invalid").is_err());
    }

    /// Switching modes updates the reported type and the cached velocities.
    #[test]
    fn test_interpolation_types() {
        let (keys, stamps) = z_axis_keyframes();
        let mut spline = RotationSpline::new(&keys, &stamps).unwrap();

        // Starts out as slerp.
        assert_eq!(spline.interpolation_type(), "slerp");

        // Cubic mode populates the velocity cache.
        spline.set_interpolation_type("cubic").unwrap();
        assert_eq!(spline.interpolation_type(), "cubic");
        assert!(spline.velocities.is_some());

        // Returning to slerp drops the cache again.
        spline.set_interpolation_type("slerp").unwrap();
        assert_eq!(spline.interpolation_type(), "slerp");
        assert!(spline.velocities.is_none());
    }

    /// Under slerp the angular velocity is constant across a segment.
    #[test]
    fn test_angular_velocity() {
        let (keys, stamps) = half_turn_keyframes();
        let spline = RotationSpline::new(&keys, &stamps).unwrap();

        // A half turn about z spread over one time unit: roughly [0, 0, PI].
        let reference = spline.angular_velocity(0.5).unwrap();
        for (axis, want) in [0.0, 0.0, PI].iter().enumerate() {
            assert_relative_eq!(reference[axis], *want, epsilon = 1e-3);
        }

        // Other points of the same segment must report the same velocity.
        for &query in [0.25, 0.75].iter() {
            let elsewhere = spline.angular_velocity(query).unwrap();
            for axis in 0..3 {
                assert_relative_eq!(elsewhere[axis], reference[axis], epsilon = 1e-10);
            }
        }
    }

    /// Cubic mode still passes through the keys and yields valid rotations
    /// in between.
    #[test]
    fn test_cubic_interpolation() {
        let (keys, stamps) = z_axis_keyframes();
        let mut spline = RotationSpline::new(&keys, &stamps).unwrap();
        spline.set_interpolation_type("cubic").unwrap();

        let probe = array![1.0, 0.0, 0.0];

        // At every key time the spline must act on the probe exactly like the
        // corresponding key rotation.
        for (key, &stamp) in keys.iter().zip(stamps.iter()) {
            let interpolated = spline.interpolate(stamp);
            let got = interpolated.apply(&probe.view()).unwrap();
            let want = key.apply(&probe.view()).unwrap();
            for axis in 0..3 {
                assert_relative_eq!(got[axis], want[axis], epsilon = 1e-10);
            }
        }

        // Between keys the result must still preserve the length of the
        // (unit) probe vector.
        for &query in [0.5, 1.5].iter() {
            let interpolated = spline.interpolate(query);
            let image = interpolated.apply(&probe.view()).unwrap();
            let norm = image.dot(&image).sqrt();
            assert_relative_eq!(norm, 1.0, epsilon = 1e-10);
        }
    }

    /// Slerp reports zero angular acceleration; cubic mode on a sequence that
    /// changes rotation axis reports a clearly non-zero one.
    #[test]
    fn test_angular_acceleration() {
        let (keys, stamps) = z_axis_keyframes();
        let mut spline = RotationSpline::new(&keys, &stamps).unwrap();

        // Slerp: no acceleration on any axis.
        let accel_slerp = spline.angular_acceleration(0.5);
        for axis in 0..3 {
            assert_relative_eq!(accel_slerp[axis], 0.0, epsilon = 1e-10);
        }

        // Cubic on the same keys: only exercised, not asserted, because the
        // evenly stepped single-axis sequence may still give a value close to
        // zero.
        spline.set_interpolation_type("cubic").unwrap();
        let _accel_cubic = spline.angular_acceleration(0.5);

        // A sequence whose second step turns about a different axis.
        let complex_rotations = vec![
            Rotation::identity(),
            rotation_from_euler(PI / 2.0, 0.0, 0.0, "xyz").unwrap(),
            rotation_from_euler(PI / 2.0, PI / 2.0, 0.0, "xyz").unwrap(),
        ];
        let complex_times = vec![0.0, 1.0, 2.0];

        let mut complex_spline = RotationSpline::new(&complex_rotations, &complex_times).unwrap();
        complex_spline.set_interpolation_type("cubic").unwrap();

        let complex_accel = complex_spline.angular_acceleration(0.5);

        // The acceleration magnitude must be meaningfully above zero.
        let magnitude = complex_accel.dot(&complex_accel).sqrt();
        assert!(magnitude > 1e-6);
    }
}