Skip to main content

scirs2_spatial/transform/
rigid_transform.rs

//! `RigidTransform` type for combined rotation and translation
//!
//! This module provides a `RigidTransform` struct that represents a rigid
//! transformation in 3D space, combining a rotation and a translation.
5
6use crate::error::{SpatialError, SpatialResult};
7use crate::transform::Rotation;
8use scirs2_core::ndarray::{array, Array1, Array2, ArrayView1, ArrayView2};
9
10// Helper function to create an array from values
11#[allow(dead_code)]
12fn euler_array(x: f64, y: f64, z: f64) -> Array1<f64> {
13    array![x, y, z]
14}
15
16// Helper function to create a rotation from Euler angles
17#[allow(dead_code)]
18fn rotation_from_euler(x: f64, y: f64, z: f64, convention: &str) -> SpatialResult<Rotation> {
19    let angles = euler_array(x, y, z);
20    let angles_view = angles.view();
21    Rotation::from_euler(&angles_view, convention)
22}
23
24/// RigidTransform represents a rigid transformation in 3D space.
25///
26/// A rigid transformation is a combination of a rotation and a translation.
27/// It preserves the distance between any two points and the orientation of objects.
28///
29/// # Examples
30///
31/// ```
32/// use scirs2_spatial::transform::{Rotation, RigidTransform};
33/// use scirs2_core::ndarray::array;
34/// use std::f64::consts::PI;
35///
36/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
37/// // Create a rotation around Z and a translation
38/// let rotation = Rotation::from_euler(&array![0.0, 0.0, PI/2.0].view(), "xyz")?;
39/// let translation = array![1.0, 2.0, 3.0];
40///
41/// // Create a rigid transform from rotation and translation
42/// let transform = RigidTransform::from_rotation_and_translation(rotation, &translation.view())?;
43///
44/// // Apply the transform to a point
45/// let point = array![0.0, 0.0, 0.0];
46/// let transformed = transform.apply(&point.view());
47/// // Should be [1.0, 2.0, 3.0] (just the translation for the origin)
48///
49/// // Another point
50/// let point2 = array![1.0, 0.0, 0.0];
51/// let transformed2 = transform.apply(&point2.view());
52/// // Should be [1.0, 3.0, 3.0] (rotated then translated)
53/// # Ok(())
54/// # }
55/// ```
56#[derive(Clone, Debug)]
57pub struct RigidTransform {
58    /// The rotation component
59    rotation: Rotation,
60    /// The translation component
61    translation: Array1<f64>,
62}
63
64impl RigidTransform {
65    /// Create a new rigid transform from a rotation and translation
66    ///
67    /// # Arguments
68    ///
69    /// * `rotation` - The rotation component
70    /// * `translation` - The translation vector (3D)
71    ///
72    /// # Returns
73    ///
74    /// A `SpatialResult` containing the rigid transform if valid, or an error if invalid
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
80    /// use scirs2_core::ndarray::array;
81    ///
82    /// let rotation = Rotation::identity();
83    /// let translation = array![1.0, 2.0, 3.0];
84    /// let transform = RigidTransform::from_rotation_and_translation(rotation, &translation.view()).expect("Operation failed");
85    /// ```
86    pub fn from_rotation_and_translation(
87        rotation: Rotation,
88        translation: &ArrayView1<f64>,
89    ) -> SpatialResult<Self> {
90        if translation.len() != 3 {
91            return Err(SpatialError::DimensionError(format!(
92                "Translation must have 3 elements, got {}",
93                translation.len()
94            )));
95        }
96
97        Ok(RigidTransform {
98            rotation,
99            translation: translation.to_owned(),
100        })
101    }
102
103    /// Create a rigid transform from a 4x4 transformation matrix
104    ///
105    /// # Arguments
106    ///
107    /// * `matrix` - A 4x4 transformation matrix in homogeneous coordinates
108    ///
109    /// # Returns
110    ///
111    /// A `SpatialResult` containing the rigid transform if valid, or an error if invalid
112    ///
113    /// # Examples
114    ///
115    /// ```
116    /// use scirs2_spatial::transform::RigidTransform;
117    /// use scirs2_core::ndarray::array;
118    ///
119    /// // Create a transformation matrix for translation by [1, 2, 3]
120    /// let matrix = array![
121    ///     [1.0, 0.0, 0.0, 1.0],
122    ///     [0.0, 1.0, 0.0, 2.0],
123    ///     [0.0, 0.0, 1.0, 3.0],
124    ///     [0.0, 0.0, 0.0, 1.0]
125    /// ];
126    /// let transform = RigidTransform::from_matrix(&matrix.view()).expect("Operation failed");
127    /// ```
128    pub fn from_matrix(matrix: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
129        if matrix.shape() != [4, 4] {
130            return Err(SpatialError::DimensionError(format!(
131                "Matrix must be 4x4, got {:?}",
132                matrix.shape()
133            )));
134        }
135
136        // Check the last row is [0, 0, 0, 1]
137        for i in 0..3 {
138            if (matrix[[3, i]] - 0.0).abs() > 1e-10 {
139                return Err(SpatialError::ValueError(
140                    "Last row of matrix must be [0, 0, 0, 1]".into(),
141                ));
142            }
143        }
144        if (matrix[[3, 3]] - 1.0).abs() > 1e-10 {
145            return Err(SpatialError::ValueError(
146                "Last row of matrix must be [0, 0, 0, 1]".into(),
147            ));
148        }
149
150        // Extract the rotation part (3x3 upper-left submatrix)
151        let mut rotation_matrix = Array2::<f64>::zeros((3, 3));
152        for i in 0..3 {
153            for j in 0..3 {
154                rotation_matrix[[i, j]] = matrix[[i, j]];
155            }
156        }
157
158        // Extract the translation part (right column, first 3 elements)
159        let mut translation = Array1::<f64>::zeros(3);
160        for i in 0..3 {
161            translation[i] = matrix[[i, 3]];
162        }
163
164        // Create rotation from the extracted matrix
165        let rotation = Rotation::from_matrix(&rotation_matrix.view())?;
166
167        Ok(RigidTransform {
168            rotation,
169            translation,
170        })
171    }
172
173    /// Convert the rigid transform to a 4x4 matrix in homogeneous coordinates
174    ///
175    /// # Returns
176    ///
177    /// A 4x4 transformation matrix
178    ///
179    /// # Examples
180    ///
181    /// ```
182    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
183    /// use scirs2_core::ndarray::array;
184    ///
185    /// let rotation = Rotation::identity();
186    /// let translation = array![1.0, 2.0, 3.0];
187    /// let transform = RigidTransform::from_rotation_and_translation(rotation, &translation.view()).expect("Operation failed");
188    /// let matrix = transform.as_matrix();
189    /// // Should be a 4x4 identity matrix with the last column containing the translation
190    /// ```
191    pub fn as_matrix(&self) -> Array2<f64> {
192        let mut matrix = Array2::<f64>::zeros((4, 4));
193
194        // Set the rotation part
195        let rotation_matrix = self.rotation.as_matrix();
196        for i in 0..3 {
197            for j in 0..3 {
198                matrix[[i, j]] = rotation_matrix[[i, j]];
199            }
200        }
201
202        // Set the translation part
203        for i in 0..3 {
204            matrix[[i, 3]] = self.translation[i];
205        }
206
207        // Set the homogeneous coordinate part
208        matrix[[3, 3]] = 1.0;
209
210        matrix
211    }
212
213    /// Get the rotation component of the rigid transform
214    ///
215    /// # Returns
216    ///
217    /// The rotation component
218    ///
219    /// # Examples
220    ///
221    /// ```
222    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
223    /// use scirs2_core::ndarray::array;
224    ///
225    /// let rotation = Rotation::identity();
226    /// let translation = array![1.0, 2.0, 3.0];
227    /// let transform = RigidTransform::from_rotation_and_translation(rotation.clone(), &translation.view()).expect("Operation failed");
228    /// let retrieved_rotation = transform.rotation();
229    /// ```
230    pub fn rotation(&self) -> &Rotation {
231        &self.rotation
232    }
233
234    /// Get the translation component of the rigid transform
235    ///
236    /// # Returns
237    ///
238    /// The translation vector
239    ///
240    /// # Examples
241    ///
242    /// ```
243    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
244    /// use scirs2_core::ndarray::array;
245    ///
246    /// let rotation = Rotation::identity();
247    /// let translation = array![1.0, 2.0, 3.0];
248    /// let transform = RigidTransform::from_rotation_and_translation(rotation, &translation.view()).expect("Operation failed");
249    /// let retrieved_translation = transform.translation();
250    /// ```
251    pub fn translation(&self) -> &Array1<f64> {
252        &self.translation
253    }
254
255    /// Apply the rigid transform to a point or vector
256    ///
257    /// # Arguments
258    ///
259    /// * `point` - A 3D point or vector to transform
260    ///
261    /// # Returns
262    ///
263    /// The transformed point or vector
264    ///
265    /// # Examples
266    ///
267    /// ```
268    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
269    /// use scirs2_core::ndarray::array;
270    /// use std::f64::consts::PI;
271    ///
272    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
273    /// let rotation = Rotation::from_euler(&array![0.0, 0.0, PI/2.0].view(), "xyz")?;
274    /// let translation = array![1.0, 2.0, 3.0];
275    /// let transform = RigidTransform::from_rotation_and_translation(rotation, &translation.view())?;
276    /// let point = array![1.0, 0.0, 0.0];
277    /// let transformed = transform.apply(&point.view())?;
278    /// // Should be [1.0, 3.0, 3.0] (rotated then translated)
279    /// # Ok(())
280    /// # }
281    /// ```
282    pub fn apply(&self, point: &ArrayView1<f64>) -> SpatialResult<Array1<f64>> {
283        if point.len() != 3 {
284            return Err(SpatialError::DimensionError(
285                "Point must have 3 elements".to_string(),
286            ));
287        }
288
289        // Apply rotation then translation
290        let rotated = self.rotation.apply(point)?;
291        Ok(rotated + &self.translation)
292    }
293
294    /// Apply the rigid transform to multiple points
295    ///
296    /// # Arguments
297    ///
298    /// * `points` - A 2D array of points (each row is a 3D point)
299    ///
300    /// # Returns
301    ///
302    /// A 2D array of transformed points
303    ///
304    /// # Examples
305    ///
306    /// ```
307    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
308    /// use scirs2_core::ndarray::array;
309    ///
310    /// let rotation = Rotation::identity();
311    /// let translation = array![1.0, 2.0, 3.0];
312    /// let transform = RigidTransform::from_rotation_and_translation(rotation, &translation.view()).expect("Operation failed");
313    /// let points = array![[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]];
314    /// let transformed = transform.apply_multiple(&points.view());
315    /// ```
316    pub fn apply_multiple(&self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array2<f64>> {
317        if points.ncols() != 3 {
318            return Err(SpatialError::DimensionError(
319                "Each point must have 3 elements".to_string(),
320            ));
321        }
322
323        let npoints = points.nrows();
324        let mut result = Array2::<f64>::zeros((npoints, 3));
325
326        for i in 0..npoints {
327            let point = points.row(i);
328            let transformed = self.apply(&point)?;
329            for j in 0..3 {
330                result[[i, j]] = transformed[j];
331            }
332        }
333
334        Ok(result)
335    }
336
337    /// Get the inverse of the rigid transform
338    ///
339    /// # Returns
340    ///
341    /// A new RigidTransform that is the inverse of this one
342    ///
343    /// # Examples
344    ///
345    /// ```
346    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
347    /// use scirs2_core::ndarray::array;
348    ///
349    /// let rotation = Rotation::identity();
350    /// let translation = array![1.0, 2.0, 3.0];
351    /// let transform = RigidTransform::from_rotation_and_translation(rotation, &translation.view()).expect("Operation failed");
352    /// let inverse = transform.inv();
353    /// ```
354    pub fn inv(&self) -> SpatialResult<RigidTransform> {
355        // Inverse of a rigid transform: R^-1, -R^-1 * t
356        let inv_rotation = self.rotation.inv();
357        let inv_translation = -inv_rotation.apply(&self.translation.view())?;
358
359        Ok(RigidTransform {
360            rotation: inv_rotation,
361            translation: inv_translation,
362        })
363    }
364
365    /// Compose this rigid transform with another (apply the other transform after this one)
366    ///
367    /// # Arguments
368    ///
369    /// * `other` - The other rigid transform to combine with
370    ///
371    /// # Returns
372    ///
373    /// A new RigidTransform that represents the composition
374    ///
375    /// # Examples
376    ///
377    /// ```
378    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
379    /// use scirs2_core::ndarray::array;
380    ///
381    /// let t1 = RigidTransform::from_rotation_and_translation(
382    ///     Rotation::identity(),
383    ///     &array![1.0, 0.0, 0.0].view()
384    /// ).expect("Operation failed");
385    /// let t2 = RigidTransform::from_rotation_and_translation(
386    ///     Rotation::identity(),
387    ///     &array![0.0, 1.0, 0.0].view()
388    /// ).expect("Operation failed");
389    /// let combined = t1.compose(&t2);
390    /// // Should have a translation of [1.0, 1.0, 0.0]
391    /// ```
392    pub fn compose(&self, other: &RigidTransform) -> SpatialResult<RigidTransform> {
393        // Compose rotations: T2(T1(p)) = R2*(R1*p + t1) + t2
394        // Rotation: R2 * R1 = other.rotation applied after self.rotation
395        let rotation = other.rotation.compose(&self.rotation);
396
397        // Translation: R2*t1 + t2
398        let rotated_trans = other.rotation.apply(&self.translation.view())?;
399        let translation = &rotated_trans + &other.translation;
400
401        Ok(RigidTransform {
402            rotation,
403            translation,
404        })
405    }
406
407    /// Create an identity rigid transform (no rotation, no translation)
408    ///
409    /// # Returns
410    ///
411    /// A new RigidTransform that represents identity
412    ///
413    /// # Examples
414    ///
415    /// ```
416    /// use scirs2_spatial::transform::RigidTransform;
417    /// use scirs2_core::ndarray::array;
418    ///
419    /// let identity = RigidTransform::identity();
420    /// let point = array![1.0, 2.0, 3.0];
421    /// let transformed = identity.apply(&point.view());
422    /// // Should still be [1.0, 2.0, 3.0]
423    /// ```
424    pub fn identity() -> RigidTransform {
425        RigidTransform {
426            rotation: Rotation::from_quat(&array![1.0, 0.0, 0.0, 0.0].view())
427                .expect("Operation failed"),
428            translation: Array1::<f64>::zeros(3),
429        }
430    }
431
432    /// Create a rigid transform that only has a translation component
433    ///
434    /// # Arguments
435    ///
436    /// * `translation` - The translation vector
437    ///
438    /// # Returns
439    ///
440    /// A new RigidTransform with no rotation
441    ///
442    /// # Examples
443    ///
444    /// ```
445    /// use scirs2_spatial::transform::RigidTransform;
446    /// use scirs2_core::ndarray::array;
447    ///
448    /// let transform = RigidTransform::from_translation(&array![1.0, 2.0, 3.0].view()).expect("Operation failed");
449    /// let point = array![0.0, 0.0, 0.0];
450    /// let transformed = transform.apply(&point.view());
451    /// // Should be [1.0, 2.0, 3.0]
452    /// ```
453    pub fn from_translation(translation: &ArrayView1<f64>) -> SpatialResult<RigidTransform> {
454        if translation.len() != 3 {
455            return Err(SpatialError::DimensionError(format!(
456                "Translation must have 3 elements, got {}",
457                translation.len()
458            )));
459        }
460
461        Ok(RigidTransform {
462            rotation: Rotation::from_quat(&array![1.0, 0.0, 0.0, 0.0].view())
463                .expect("Operation failed"),
464            translation: translation.to_owned(),
465        })
466    }
467
468    /// Create a rigid transform that only has a rotation component
469    ///
470    /// # Arguments
471    ///
472    /// * `rotation` - The rotation component
473    ///
474    /// # Returns
475    ///
476    /// A new RigidTransform with no translation
477    ///
478    /// # Examples
479    ///
480    /// ```
481    /// use scirs2_spatial::transform::{Rotation, RigidTransform};
482    /// use scirs2_core::ndarray::array;
483    /// use std::f64::consts::PI;
484    ///
485    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
486    /// let rotation = Rotation::from_euler(&array![0.0, 0.0, PI/2.0].view(), "xyz")?;
487    /// let transform = RigidTransform::from_rotation(rotation);
488    /// let point = array![1.0, 0.0, 0.0];
489    /// let transformed = transform.apply(&point.view())?;
490    /// // Should be [0.0, 1.0, 0.0]
491    /// # Ok(())
492    /// # }
493    /// ```
494    pub fn from_rotation(rotation: Rotation) -> RigidTransform {
495        RigidTransform {
496            rotation,
497            translation: Array1::<f64>::zeros(3),
498        }
499    }
500}
501
502#[cfg(test)]
503mod tests {
504    use super::*;
505    use approx::assert_relative_eq;
506    use std::f64::consts::PI;
507
508    #[test]
509    fn test_rigid_transform_identity() {
510        let identity = RigidTransform::identity();
511        let point = array![1.0, 2.0, 3.0];
512        let transformed = identity.apply(&point.view()).expect("Operation failed");
513
514        assert_relative_eq!(transformed[0], point[0], epsilon = 1e-10);
515        assert_relative_eq!(transformed[1], point[1], epsilon = 1e-10);
516        assert_relative_eq!(transformed[2], point[2], epsilon = 1e-10);
517    }
518
519    #[test]
520    fn test_rigid_transform_translation_only() {
521        let translation = array![1.0, 2.0, 3.0];
522        let transform =
523            RigidTransform::from_translation(&translation.view()).expect("Operation failed");
524
525        let point = array![0.0, 0.0, 0.0];
526        let transformed = transform.apply(&point.view()).expect("Operation failed");
527
528        assert_relative_eq!(transformed[0], translation[0], epsilon = 1e-10);
529        assert_relative_eq!(transformed[1], translation[1], epsilon = 1e-10);
530        assert_relative_eq!(transformed[2], translation[2], epsilon = 1e-10);
531    }
532
533    #[test]
534    fn test_rigid_transform_rotation_only() {
535        // 90 degrees rotation around Z axis
536        let rotation = rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").expect("Operation failed");
537        let transform = RigidTransform::from_rotation(rotation);
538
539        let point = array![1.0, 0.0, 0.0];
540        let transformed = transform.apply(&point.view()).expect("Operation failed");
541
542        // 90 degrees rotation around Z axis of [1, 0, 0] should give [0, 1, 0]
543        assert_relative_eq!(transformed[0], 0.0, epsilon = 1e-10);
544        assert_relative_eq!(transformed[1], 1.0, epsilon = 1e-10);
545        assert_relative_eq!(transformed[2], 0.0, epsilon = 1e-10);
546    }
547
548    #[test]
549    fn test_rigid_transform_rotation_and_translation() {
550        // 90 degrees rotation around Z axis and translation by [1, 2, 3]
551        let rotation = rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").expect("Operation failed");
552        let translation = array![1.0, 2.0, 3.0];
553        let transform =
554            RigidTransform::from_rotation_and_translation(rotation, &translation.view())
555                .expect("Operation failed");
556
557        let point = array![1.0, 0.0, 0.0];
558        let transformed = transform.apply(&point.view()).expect("Operation failed");
559
560        // 90 degrees rotation around Z axis of [1, 0, 0] should give [0, 1, 0]
561        // Then translate by [1, 2, 3] to get [1, 3, 3]
562        assert_relative_eq!(transformed[0], 1.0, epsilon = 1e-10);
563        assert_relative_eq!(transformed[1], 3.0, epsilon = 1e-10);
564        assert_relative_eq!(transformed[2], 3.0, epsilon = 1e-10);
565    }
566
567    #[test]
568    fn test_rigid_transform_from_matrix() {
569        let matrix = array![
570            [0.0, -1.0, 0.0, 1.0],
571            [1.0, 0.0, 0.0, 2.0],
572            [0.0, 0.0, 1.0, 3.0],
573            [0.0, 0.0, 0.0, 1.0]
574        ];
575        let transform = RigidTransform::from_matrix(&matrix.view()).expect("Operation failed");
576
577        let point = array![1.0, 0.0, 0.0];
578        let transformed = transform.apply(&point.view()).expect("Operation failed");
579
580        // This matrix represents a 90-degree rotation around Z and translation by [1, 2, 3]
581        // So [1, 0, 0] -> [0, 1, 0] -> [1, 3, 3]
582        assert_relative_eq!(transformed[0], 1.0, epsilon = 1e-10);
583        assert_relative_eq!(transformed[1], 3.0, epsilon = 1e-10);
584        assert_relative_eq!(transformed[2], 3.0, epsilon = 1e-10);
585    }
586
587    #[test]
588    fn test_rigid_transform_as_matrix() {
589        // Create a transform and verify its matrix representation
590        let rotation = rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").expect("Operation failed");
591        let translation = array![1.0, 2.0, 3.0];
592        let transform =
593            RigidTransform::from_rotation_and_translation(rotation, &translation.view())
594                .expect("Operation failed");
595
596        let matrix = transform.as_matrix();
597
598        // Check the rotation part (90-degree rotation around Z)
599        assert_relative_eq!(matrix[[0, 0]], 0.0, epsilon = 1e-10);
600        assert_relative_eq!(matrix[[0, 1]], -1.0, epsilon = 1e-10);
601        assert_relative_eq!(matrix[[0, 2]], 0.0, epsilon = 1e-10);
602        assert_relative_eq!(matrix[[1, 0]], 1.0, epsilon = 1e-10);
603        assert_relative_eq!(matrix[[1, 1]], 0.0, epsilon = 1e-10);
604        assert_relative_eq!(matrix[[1, 2]], 0.0, epsilon = 1e-10);
605        assert_relative_eq!(matrix[[2, 0]], 0.0, epsilon = 1e-10);
606        assert_relative_eq!(matrix[[2, 1]], 0.0, epsilon = 1e-10);
607        assert_relative_eq!(matrix[[2, 2]], 1.0, epsilon = 1e-10);
608
609        // Check the translation part
610        assert_relative_eq!(matrix[[0, 3]], 1.0, epsilon = 1e-10);
611        assert_relative_eq!(matrix[[1, 3]], 2.0, epsilon = 1e-10);
612        assert_relative_eq!(matrix[[2, 3]], 3.0, epsilon = 1e-10);
613
614        // Check the homogeneous row
615        assert_relative_eq!(matrix[[3, 0]], 0.0, epsilon = 1e-10);
616        assert_relative_eq!(matrix[[3, 1]], 0.0, epsilon = 1e-10);
617        assert_relative_eq!(matrix[[3, 2]], 0.0, epsilon = 1e-10);
618        assert_relative_eq!(matrix[[3, 3]], 1.0, epsilon = 1e-10);
619    }
620
621    #[test]
622    fn test_rigid_transform_inverse() {
623        // Create a transform and verify its inverse
624        let rotation = rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").expect("Operation failed");
625        let translation = array![1.0, 2.0, 3.0];
626        let transform =
627            RigidTransform::from_rotation_and_translation(rotation, &translation.view())
628                .expect("Operation failed");
629
630        let inverse = transform.inv().expect("Operation failed");
631
632        // Apply transform and then its inverse to a point
633        let point = array![1.0, 2.0, 3.0];
634        let transformed = transform.apply(&point.view()).expect("Operation failed");
635        let back = inverse
636            .apply(&transformed.view())
637            .expect("Operation failed");
638
639        // Should get back to the original point
640        assert_relative_eq!(back[0], point[0], epsilon = 1e-10);
641        assert_relative_eq!(back[1], point[1], epsilon = 1e-10);
642        assert_relative_eq!(back[2], point[2], epsilon = 1e-10);
643    }
644
645    #[test]
646    fn test_rigid_transform_composition() {
647        // Create two transforms and compose them
648        let t1 = RigidTransform::from_rotation_and_translation(
649            rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").expect("Operation failed"),
650            &array![1.0, 0.0, 0.0].view(),
651        )
652        .expect("Operation failed");
653
654        let t2 = RigidTransform::from_rotation_and_translation(
655            rotation_from_euler(PI / 2.0, 0.0, 0.0, "xyz").expect("Operation failed"),
656            &array![0.0, 1.0, 0.0].view(),
657        )
658        .expect("Operation failed");
659
660        let composed = t1.compose(&t2).expect("Operation failed");
661
662        // Apply the composed transform to a point
663        let point = array![1.0, 0.0, 0.0];
664        let transformed = composed.apply(&point.view()).expect("Operation failed");
665
666        // Apply the transforms individually
667        let intermediate = t1.apply(&point.view()).expect("Operation failed");
668        let transformed2 = t2.apply(&intermediate.view()).expect("Operation failed");
669
670        // The composed transform and individual transforms should produce the same result
671        assert_relative_eq!(transformed[0], transformed2[0], epsilon = 1e-10);
672        assert_relative_eq!(transformed[1], transformed2[1], epsilon = 1e-10);
673        assert_relative_eq!(transformed[2], transformed2[2], epsilon = 1e-10);
674    }
675
676    #[test]
677    fn test_rigid_transform_multiple_points() {
678        let rotation = rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").expect("Operation failed");
679        let translation = array![1.0, 2.0, 3.0];
680        let transform =
681            RigidTransform::from_rotation_and_translation(rotation, &translation.view())
682                .expect("Operation failed");
683
684        let points = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
685
686        let transformed = transform
687            .apply_multiple(&points.view())
688            .expect("Operation failed");
689
690        // Check that we get the correct transformed points
691        assert_eq!(transformed.shape(), points.shape());
692
693        // [1, 0, 0] -> [0, 1, 0] -> [1, 3, 3]
694        assert_relative_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10);
695        assert_relative_eq!(transformed[[0, 1]], 3.0, epsilon = 1e-10);
696        assert_relative_eq!(transformed[[0, 2]], 3.0, epsilon = 1e-10);
697
698        // [0, 1, 0] -> [-1, 0, 0] -> [0, 2, 3]
699        assert_relative_eq!(transformed[[1, 0]], 0.0, epsilon = 1e-10);
700        assert_relative_eq!(transformed[[1, 1]], 2.0, epsilon = 1e-10);
701        assert_relative_eq!(transformed[[1, 2]], 3.0, epsilon = 1e-10);
702
703        // [0, 0, 1] -> [0, 0, 1] -> [1, 2, 4]
704        assert_relative_eq!(transformed[[2, 0]], 1.0, epsilon = 1e-10);
705        assert_relative_eq!(transformed[[2, 1]], 2.0, epsilon = 1e-10);
706        assert_relative_eq!(transformed[[2, 2]], 4.0, epsilon = 1e-10);
707    }
708}