Skip to main content

symtropy_math/
compound.rs

// Copyright (C) 2024-2026 Tristan Stoltz / Luminous Dynamics
// SPDX-License-Identifier: Apache-2.0 OR MIT
// Commercial licensing: see COMMERCIAL_LICENSE.md at repository root
//! Compound shape: a rigid body composed of multiple convex sub-shapes.
//!
//! Each child is placed at a fixed offset (`Transform`) relative to the
//! compound origin. The support function reduces to an argmax over all
//! children's individual support points — O(n_children) per GJK query,
//! which is acceptable since compound bodies typically have ≤16 parts.
//! Because a support mapping can only describe a convex set, GJK sees the
//! *convex hull* of the children: gaps between parts (such as the waist of
//! a dumbbell) are treated as solid.
//!
11//! # Usage
12//! ```
13//! use symtropy_math::{CompoundShape, Point, Sphere, Transform};
14//!
15//! // A dumbbell: two spheres connected along the x axis.
16//! let mut compound = CompoundShape::<3>::new();
17//! compound.add_child(
18//!     Transform::from_translation(Point::new([-2.0, 0.0, 0.0])),
19//!     Box::new(Sphere::unit()),
20//! );
21//! compound.add_child(
22//!     Transform::from_translation(Point::new([2.0, 0.0, 0.0])),
23//!     Box::new(Sphere::unit()),
24//! );
25//! ```
26
27use nalgebra::SVector;
28
29use crate::point::Point;
30use crate::shape::Shape;
31use crate::transform::Transform;
32
33/// A convex rigid body composed of multiple child shapes.
34///
35/// The overall support function is the argmax of all children's support
36/// points transformed into the compound's local frame. This makes compound
37/// shapes first-class `Shape<D>` objects compatible with GJK.
38pub struct CompoundShape<const D: usize> {
39    /// Children: (local-frame transform, shape).
40    parts: Vec<(Transform<D>, Box<dyn Shape<D>>)>,
41    /// Cached bounding sphere: (center-in-local-frame, radius).
42    cached_center: Point<D>,
43    cached_radius: f64,
44}
45
46impl<const D: usize> Clone for CompoundShape<D> {
47    fn clone(&self) -> Self {
48        let parts = self
49            .parts
50            .iter()
51            .map(|(tf, child)| (tf.clone(), child.clone_box()))
52            .collect();
53        Self {
54            parts,
55            cached_center: self.cached_center,
56            cached_radius: self.cached_radius,
57        }
58    }
59}
60
61impl<const D: usize> CompoundShape<D> {
62    /// Create an empty compound shape.
63    pub fn new() -> Self {
64        Self {
65            parts: Vec::new(),
66            cached_center: Point::origin(),
67            cached_radius: 0.0,
68        }
69    }
70
71    /// Create a compound shape from a single shape.
72    pub fn from_shape(shape: Box<dyn Shape<D>>) -> Self {
73        let mut s = Self::new();
74        s.add_child(Transform::identity(), shape);
75        s
76    }
77
78    /// Add a child shape at the given local-frame transform.
79    ///
80    /// Recomputes the cached bounding sphere after each addition.
81    pub fn add_child(&mut self, transform: Transform<D>, shape: Box<dyn Shape<D>>) {
82        self.parts.push((transform, shape));
83        self.recompute_bounding();
84    }
85
86    /// Access the child shapes and their local transforms.
87    pub fn children(&self) -> &[(Transform<D>, Box<dyn Shape<D>>)] {
88        &self.parts
89    }
90
91    /// Number of child shapes.
92    pub fn child_count(&self) -> usize {
93        self.parts.len()
94    }
95
96    /// Recompute the enclosing bounding sphere over all children.
97    ///
98    /// Algorithm:
99    /// 1. Centroid = mean of all child bounding-sphere centers in local frame.
100    /// 2. Radius = max(dist(centroid, child_center_world) + child_radius).
101    fn recompute_bounding(&mut self) {
102        if self.parts.is_empty() {
103            self.cached_center = Point::origin();
104            self.cached_radius = 0.0;
105            return;
106        }
107
108        // Step 1: centroid of all child sphere centers (in compound local frame)
109        let mut center = SVector::<f64, D>::zeros();
110        for (tf, child) in &self.parts {
111            let (local_c, _) = child.bounding_sphere();
112            let world_c = tf.transform_point(&local_c).0;
113            center += world_c;
114        }
115        center /= self.parts.len() as f64;
116
117        // Step 2: tightest enclosing sphere around centroid
118        let mut radius = 0.0f64;
119        for (tf, child) in &self.parts {
120            let (local_c, child_r) = child.bounding_sphere();
121            let world_c = tf.transform_point(&local_c).0;
122            let d = (world_c - center).norm() + child_r;
123            if d > radius {
124                radius = d;
125            }
126        }
127
128        self.cached_center = Point(center);
129        self.cached_radius = radius;
130    }
131}
132
133impl<const D: usize> Default for CompoundShape<D> {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl<const D: usize> Shape<D> for CompoundShape<D> {
140    /// Support function: furthest point on any child in the given direction.
141    ///
142    /// For each child:
143    ///   1. Rotate `direction` into child's local frame (inverse rotation).
144    ///   2. Query child's support in that local direction.
145    ///   3. Transform the result back to compound-local frame.
146    ///   4. Return the child whose result has the maximum dot with `direction`.
147    fn support(&self, direction: &SVector<f64, D>) -> SVector<f64, D> {
148        self.parts
149            .iter()
150            .map(|(tf, child)| {
151                // Rotate direction into child's local frame
152                let local_dir = tf.rotation.reverse().rotate_vector(direction);
153                // Support in child's local frame
154                let local_pt = child.support(&local_dir);
155                // Transform back to compound's frame (rotate + translate)
156                tf.transform_point(&Point(local_pt)).0
157            })
158            .max_by(|a, b| {
159                let da = a.dot(direction);
160                let db = b.dot(direction);
161                da.total_cmp(&db)
162            })
163            .unwrap_or_else(SVector::zeros)
164    }
165
166    fn bounding_sphere(&self) -> (Point<D>, f64) {
167        (self.cached_center, self.cached_radius)
168    }
169
170    fn as_any(&self) -> &dyn std::any::Any {
171        self
172    }
173
174    fn clone_box(&self) -> Box<dyn Shape<D>> {
175        Box::new(self.clone())
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bivector::Bivector;
    use crate::rotor::Rotor;
    use crate::sphere::Sphere;

    fn vec3(x: f64, y: f64, z: f64) -> SVector<f64, 3> {
        SVector::from([x, y, z])
    }

    /// Dumbbell: two unit spheres placed ±2 along X.
    fn dumbbell() -> CompoundShape<3> {
        let mut c = CompoundShape::new();
        c.add_child(
            Transform::from_translation(Point::new([-2.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        c.add_child(
            Transform::from_translation(Point::new([2.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        c
    }

    /// A sphere of radius 0.5 centred at (1, 0, 0) in its own frame.
    ///
    /// The child is placed with a half-turn in the XY plane and no
    /// translation. A half-turn maps the child's +X onto the compound's −X
    /// whichever way it turns, so the child occupies x ∈ [−1.5, −0.5].
    /// (Assumes `from_plane_angle` takes the full rotation angle, as the
    /// 90° test below does.)
    fn rotated_offset_sphere() -> CompoundShape<3> {
        let plane = Bivector::<3>::unit_plane(0, 1);
        let rot = Rotor::from_plane_angle(&plane, std::f64::consts::PI);
        let tf = Transform {
            translation: Point::new([0.0, 0.0, 0.0]),
            rotation: rot,
        };
        let mut c = CompoundShape::new();
        c.add_child(tf, Box::new(Sphere::new(Point::new([1.0, 0.0, 0.0]), 0.5)));
        c
    }

    #[test]
    fn support_along_positive_x_reaches_far_child() {
        let db = dumbbell();
        let s = db.support(&vec3(1.0, 0.0, 0.0));
        // Right sphere center at x=2, radius=1 → support at x=3
        assert!((s[0] - 3.0).abs() < 1e-10, "support x = {}", s[0]);
    }

    #[test]
    fn support_along_negative_x_reaches_near_child() {
        let db = dumbbell();
        let s = db.support(&vec3(-1.0, 0.0, 0.0));
        // Left sphere center at x=-2, radius=1 → support at x=-3
        assert!((s[0] - (-3.0)).abs() < 1e-10, "support x = {}", s[0]);
    }

    #[test]
    fn support_along_y_is_symmetric() {
        let db = dumbbell();
        let s_pos = db.support(&vec3(0.0, 1.0, 0.0));
        let s_neg = db.support(&vec3(0.0, -1.0, 0.0));
        // Both spheres have radius 1; y component should be ±1
        assert!((s_pos[1] - 1.0).abs() < 1e-10);
        assert!((s_neg[1] - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn bounding_sphere_covers_all_children() {
        let db = dumbbell();
        let (center, radius) = db.bounding_sphere();
        // Centroid is at origin (symmetric dumbbell)
        assert!(center.0.norm() < 1e-10, "centroid should be at origin");
        // Farthest point: sphere at x=2, radius=1 → dist 2 + 1 = 3
        assert!(radius >= 3.0 - 1e-10, "radius = {}", radius);
    }

    #[test]
    fn single_child_matches_original_shape() {
        let r2 = Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0);
        let mut c = CompoundShape::<3>::new();
        c.add_child(Transform::identity(), Box::new(r2));

        let dir = vec3(1.0, 0.0, 0.0);
        let s_compound = c.support(&dir);
        let s_direct = Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0).support(&dir);
        assert!((s_compound - s_direct).norm() < 1e-10);
    }

    #[test]
    fn from_shape_wraps_a_single_child_at_identity() {
        let c = CompoundShape::<3>::from_shape(Box::new(Sphere::unit()));
        assert_eq!(c.child_count(), 1);
        let s = c.support(&vec3(0.0, 0.0, 1.0));
        assert!((s[2] - 1.0).abs() < 1e-10, "support z = {}", s[2]);
    }

    #[test]
    fn child_with_rotation_transforms_direction_correctly() {
        // Unit sphere centred on its child origin, translated to x=1, with a
        // 90° XY rotation. A sphere is isotropic, so rotation shouldn't matter,
        // but transform_point must not produce garbage for a rotated child.
        let plane = Bivector::<3>::unit_plane(0, 1);
        let rot = Rotor::from_plane_angle(&plane, std::f64::consts::FRAC_PI_2);
        let tf = Transform {
            translation: Point::new([1.0, 0.0, 0.0]),
            rotation: rot,
        };
        let mut c = CompoundShape::<3>::new();
        c.add_child(tf, Box::new(Sphere::unit()));

        // Support along +X: sphere center at x=1, radius=1 → x=2
        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!((s[0] - 2.0).abs() < 1e-10, "s[0] = {}", s[0]);
    }

    #[test]
    fn rotated_offset_child_support_accounts_for_rotation() {
        // Unlike the isotropic test above, this child is off-centre in its own
        // frame. If `support` ignored the rotation, the results would be +1.5
        // and +0.5 instead.
        let c = rotated_offset_sphere();
        let s_pos = c.support(&vec3(1.0, 0.0, 0.0));
        let s_neg = c.support(&vec3(-1.0, 0.0, 0.0));
        assert!((s_pos[0] - (-0.5)).abs() < 1e-10, "s_pos x = {}", s_pos[0]);
        assert!((s_neg[0] - (-1.5)).abs() < 1e-10, "s_neg x = {}", s_neg[0]);
    }

    #[test]
    fn bounding_sphere_encloses_support_points_of_rotated_child() {
        let c = rotated_offset_sphere();
        let (center, radius) = c.bounding_sphere();
        let axes = [
            vec3(1.0, 0.0, 0.0),
            vec3(-1.0, 0.0, 0.0),
            vec3(0.0, 1.0, 0.0),
            vec3(0.0, -1.0, 0.0),
            vec3(0.0, 0.0, 1.0),
            vec3(0.0, 0.0, -1.0),
        ];
        for dir in &axes {
            let p = c.support(dir);
            let dist = (p - center.0).norm();
            assert!(
                dist <= radius + 1e-10,
                "direction {:?}: support at distance {} exceeds radius {}",
                dir,
                dist,
                radius
            );
        }
    }

    #[test]
    fn clone_is_an_independent_deep_copy() {
        let mut original = dumbbell();
        let copy = original.clone();

        let dir = vec3(1.0, 0.0, 0.0);
        assert!((original.support(&dir) - copy.support(&dir)).norm() < 1e-10);
        let (c_orig, r_orig) = original.bounding_sphere();
        let (c_copy, r_copy) = copy.bounding_sphere();
        assert!((c_orig.0 - c_copy.0).norm() < 1e-10);
        assert!((r_orig - r_copy).abs() < 1e-10);

        // Mutating the original must not leak into the copy.
        original.add_child(
            Transform::from_translation(Point::new([10.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        assert_eq!(original.child_count(), 3);
        assert_eq!(copy.child_count(), 2);
        assert!((copy.support(&dir)[0] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn empty_compound_returns_zero_support() {
        let c = CompoundShape::<3>::new();
        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!(s.norm() < 1e-10);
    }

    #[test]
    fn bounding_sphere_empty_is_zero() {
        let c = CompoundShape::<3>::new();
        let (center, radius) = c.bounding_sphere();
        assert!(center.0.norm() < 1e-10);
        assert!(radius < 1e-10);
    }

    #[test]
    fn works_in_2d() {
        let mut c = CompoundShape::<2>::new();
        c.add_child(
            Transform::from_translation(Point::new([-1.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0]), 0.5)),
        );
        c.add_child(
            Transform::from_translation(Point::new([1.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0]), 0.5)),
        );
        let s = c.support(&SVector::<f64, 2>::from([1.0, 0.0]));
        assert!((s[0] - 1.5).abs() < 1e-10);
    }

    #[test]
    fn works_in_4d() {
        let mut c = CompoundShape::<4>::new();
        c.add_child(
            Transform::from_translation(Point::new([0.0, 0.0, 0.0, 3.0])),
            Box::new(Sphere::unit()),
        );
        let s = c.support(&SVector::<f64, 4>::from([0.0, 0.0, 0.0, 1.0]));
        assert!((s[3] - 4.0).abs() < 1e-10);
    }
}