Skip to main content

symtropy_math/
compound.rs

1// Copyright (C) 2024-2026 Tristan Stoltz / Luminous Dynamics
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing: see COMMERCIAL_LICENSE.md at repository root
4//! Compound shape: a rigid body composed of multiple convex sub-shapes.
5//!
6//! Each child is placed at a fixed offset (Transform) relative to the
7//! compound origin. The support function reduces to an argmax over all
8//! children's individual support points — O(n_children) per GJK query,
9//! which is acceptable since compound bodies typically have ≤16 parts.
10//!
11//! # Usage
12//! ```
13//! use symtropy_math::{CompoundShape, Point, Sphere, Transform};
14//!
15//! // A dumbbell: two spheres connected along the x axis.
16//! let mut compound = CompoundShape::<3>::new();
17//! compound.add_child(
18//!     Transform::from_translation(Point::new([-2.0, 0.0, 0.0])),
19//!     Box::new(Sphere::unit()),
20//! );
21//! compound.add_child(
22//!     Transform::from_translation(Point::new([2.0, 0.0, 0.0])),
23//!     Box::new(Sphere::unit()),
24//! );
25//! ```
26
27use nalgebra::SVector;
28
29use crate::point::Point;
30use crate::shape::Shape;
31use crate::transform::Transform;
32
33/// A convex rigid body composed of multiple child shapes.
34///
35/// The overall support function is the argmax of all children's support
36/// points transformed into the compound's local frame. This makes compound
37/// shapes first-class `Shape<D>` objects compatible with GJK.
38pub struct CompoundShape<const D: usize> {
39    /// Children: (local-frame transform, shape).
40    children: Vec<(Transform<D>, Box<dyn Shape<D>>)>,
41    /// Cached bounding sphere: (center-in-local-frame, radius).
42    cached_center: Point<D>,
43    cached_radius: f64,
44}
45
46impl<const D: usize> CompoundShape<D> {
47    /// Create an empty compound shape.
48    pub fn new() -> Self {
49        Self {
50            children: Vec::new(),
51            cached_center: Point::origin(),
52            cached_radius: 0.0,
53        }
54    }
55
56    /// Add a child shape at the given local-frame transform.
57    ///
58    /// Recomputes the cached bounding sphere after each addition.
59    pub fn add_child(&mut self, transform: Transform<D>, shape: Box<dyn Shape<D>>) {
60        self.children.push((transform, shape));
61        self.recompute_bounding();
62    }
63
64    /// Number of child shapes.
65    pub fn child_count(&self) -> usize {
66        self.children.len()
67    }
68
69    /// Recompute the enclosing bounding sphere over all children.
70    ///
71    /// Algorithm:
72    /// 1. Centroid = mean of all child bounding-sphere centers in local frame.
73    /// 2. Radius = max(dist(centroid, child_center_world) + child_radius).
74    fn recompute_bounding(&mut self) {
75        if self.children.is_empty() {
76            self.cached_center = Point::origin();
77            self.cached_radius = 0.0;
78            return;
79        }
80
81        // Step 1: centroid of all child sphere centers (in compound local frame)
82        let mut center = SVector::<f64, D>::zeros();
83        for (tf, child) in &self.children {
84            let (local_c, _) = child.bounding_sphere();
85            let world_c = tf.transform_point(&local_c).0;
86            center += world_c;
87        }
88        center /= self.children.len() as f64;
89
90        // Step 2: tightest enclosing sphere around centroid
91        let mut radius = 0.0f64;
92        for (tf, child) in &self.children {
93            let (local_c, child_r) = child.bounding_sphere();
94            let world_c = tf.transform_point(&local_c).0;
95            let d = (world_c - center).norm() + child_r;
96            if d > radius {
97                radius = d;
98            }
99        }
100
101        self.cached_center = Point(center);
102        self.cached_radius = radius;
103    }
104}
105
106impl<const D: usize> Default for CompoundShape<D> {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112impl<const D: usize> Shape<D> for CompoundShape<D> {
113    /// Support function: furthest point on any child in the given direction.
114    ///
115    /// For each child:
116    ///   1. Rotate `direction` into child's local frame (inverse rotation).
117    ///   2. Query child's support in that local direction.
118    ///   3. Transform the result back to compound-local frame.
119    ///   4. Return the child whose result has the maximum dot with `direction`.
120    fn support(&self, direction: &SVector<f64, D>) -> SVector<f64, D> {
121        self.children
122            .iter()
123            .map(|(tf, child)| {
124                // Rotate direction into child's local frame
125                let local_dir = tf.rotation.reverse().rotate_vector(direction);
126                // Support in child's local frame
127                let local_pt = child.support(&local_dir);
128                // Transform back to compound's frame (rotate + translate)
129                tf.transform_point(&Point(local_pt)).0
130            })
131            .max_by(|a, b| {
132                let da = a.dot(direction);
133                let db = b.dot(direction);
134                da.total_cmp(&db)
135            })
136            .unwrap_or_else(SVector::zeros)
137    }
138
139    fn bounding_sphere(&self) -> (Point<D>, f64) {
140        (self.cached_center.clone(), self.cached_radius)
141    }
142
143    fn as_any(&self) -> &dyn std::any::Any {
144        self
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bivector::Bivector;
    use crate::rotor::Rotor;
    use crate::sphere::Sphere;

    fn vec3(x: f64, y: f64, z: f64) -> SVector<f64, 3> {
        SVector::from([x, y, z])
    }

    /// Dumbbell: two unit spheres placed ±2 along X.
    fn dumbbell() -> CompoundShape<3> {
        let mut c = CompoundShape::new();
        c.add_child(
            Transform::from_translation(Point::new([-2.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        c.add_child(
            Transform::from_translation(Point::new([2.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        c
    }

    #[test]
    fn support_along_positive_x_reaches_far_child() {
        let db = dumbbell();
        let s = db.support(&vec3(1.0, 0.0, 0.0));
        // Right sphere center at x=2, radius=1 → support at x=3
        assert!((s[0] - 3.0).abs() < 1e-10, "support x = {}", s[0]);
    }

    #[test]
    fn support_along_negative_x_reaches_near_child() {
        let db = dumbbell();
        let s = db.support(&vec3(-1.0, 0.0, 0.0));
        // Left sphere center at x=-2, radius=1 → support at x=-3
        assert!((s[0] - (-3.0)).abs() < 1e-10, "support x = {}", s[0]);
    }

    #[test]
    fn support_along_y_is_symmetric() {
        let db = dumbbell();
        let s_pos = db.support(&vec3(0.0, 1.0, 0.0));
        let s_neg = db.support(&vec3(0.0, -1.0, 0.0));
        // Both spheres have radius 1; y component should be ±1
        assert!((s_pos[1] - 1.0).abs() < 1e-10);
        assert!((s_neg[1] - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn bounding_sphere_covers_all_children() {
        let db = dumbbell();
        let (center, radius) = db.bounding_sphere();
        // Centroid is at origin (symmetric dumbbell)
        assert!(center.0.norm() < 1e-10, "centroid should be at origin");
        // Farthest point: sphere at x=2, radius=1 → dist 2 + 1 = 3
        assert!(radius >= 3.0 - 1e-10, "radius = {}", radius);
    }

    #[test]
    fn bounding_sphere_centers_on_centroid_with_unequal_children() {
        // Big sphere (r = 2) at the origin and small sphere (r = 0.5) at x = 5.
        let mut c = CompoundShape::<3>::new();
        c.add_child(
            Transform::identity(),
            Box::new(Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0)),
        );
        c.add_child(
            Transform::from_translation(Point::new([5.0, 0.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0, 0.0]), 0.5)),
        );

        let (center, radius) = c.bounding_sphere();
        // Centroid of child centers: (0 + 5) / 2 = 2.5 on the x axis.
        assert!(
            (center.0 - vec3(2.5, 0.0, 0.0)).norm() < 1e-10,
            "center = {:?}",
            center.0
        );
        // max(2.5 + 2.0, 2.5 + 0.5) = 4.5
        assert!((radius - 4.5).abs() < 1e-10, "radius = {}", radius);

        // Every support point must lie inside the bounding sphere.
        for d in [
            vec3(1.0, 0.0, 0.0),
            vec3(-1.0, 0.0, 0.0),
            vec3(0.0, 1.0, 0.0),
            vec3(0.0, 0.0, -1.0),
            vec3(0.6, 0.8, 0.0),
        ] {
            let s = c.support(&d);
            assert!(
                (s - center.0).norm() <= radius + 1e-10,
                "support {:?} escapes bounding sphere",
                s
            );
        }
    }

    #[test]
    fn single_child_matches_original_shape() {
        let r2 = Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0);
        let mut c = CompoundShape::<3>::new();
        c.add_child(Transform::identity(), Box::new(r2));

        let dir = vec3(1.0, 0.0, 0.0);
        let s_compound = c.support(&dir);
        let s_direct = Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0).support(&dir);
        assert!((s_compound - s_direct).norm() < 1e-10);
    }

    #[test]
    fn child_with_rotation_transforms_direction_correctly() {
        // Sphere at origin with 90° XY rotation applied.
        // A sphere is isotropic, so rotation shouldn't matter — but transform_point
        // must not produce garbage for a rotated child offset.
        let plane = Bivector::<3>::unit_plane(0, 1);
        let rot = Rotor::from_plane_angle(&plane, std::f64::consts::FRAC_PI_2);
        let tf = Transform {
            translation: Point::new([1.0, 0.0, 0.0]),
            rotation: rot,
        };
        let mut c = CompoundShape::<3>::new();
        c.add_child(tf, Box::new(Sphere::unit()));

        // Support along +X: sphere center at x=1, radius=1 → x=2
        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!((s[0] - 2.0).abs() < 1e-10, "s[0] = {}", s[0]);
    }

    #[test]
    fn rotated_offset_child_support_follows_world_direction() {
        // Off-center child sphere under a non-trivial rotation. Unlike the
        // isotropic case above, a wrong inverse rotation of the query
        // direction (or a missing rotation of the result) would move the
        // support point away from `world_center + r * d`.
        let plane = Bivector::<3>::unit_plane(0, 1);
        let rot = Rotor::from_plane_angle(&plane, 0.7);
        let tf = Transform {
            translation: Point::new([0.5, -1.0, 2.0]),
            rotation: rot,
        };
        // Child sphere center, expressed in the compound's frame.
        let world_c = tf.transform_point(&Point::new([1.0, 0.0, 0.0])).0;

        let mut c = CompoundShape::<3>::new();
        c.add_child(tf, Box::new(Sphere::new(Point::new([1.0, 0.0, 0.0]), 0.5)));

        for d in [
            vec3(1.0, 0.0, 0.0),
            vec3(0.0, 1.0, 0.0),
            vec3(0.0, 0.0, -1.0),
            vec3(0.6, 0.8, 0.0),
        ] {
            let s = c.support(&d);
            let expected = world_c + d * 0.5;
            assert!(
                (s - expected).norm() < 1e-9,
                "d = {:?}: support {:?}, expected {:?}",
                d,
                s,
                expected
            );
        }
    }

    #[test]
    fn empty_compound_returns_zero_support() {
        let c = CompoundShape::<3>::new();
        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!(s.norm() < 1e-10);
    }

    #[test]
    fn bounding_sphere_empty_is_zero() {
        let c = CompoundShape::<3>::new();
        let (center, radius) = c.bounding_sphere();
        assert!(center.0.norm() < 1e-10);
        assert!(radius < 1e-10);
    }

    #[test]
    fn works_in_2d() {
        let mut c = CompoundShape::<2>::new();
        c.add_child(
            Transform::from_translation(Point::new([-1.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0]), 0.5)),
        );
        c.add_child(
            Transform::from_translation(Point::new([1.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0]), 0.5)),
        );
        let s = c.support(&SVector::<f64, 2>::from([1.0, 0.0]));
        assert!((s[0] - 1.5).abs() < 1e-10);
    }

    #[test]
    fn works_in_4d() {
        let mut c = CompoundShape::<4>::new();
        c.add_child(
            Transform::from_translation(Point::new([0.0, 0.0, 0.0, 3.0])),
            Box::new(Sphere::unit()),
        );
        let s = c.support(&SVector::<f64, 4>::from([0.0, 0.0, 0.0, 1.0]));
        assert!((s[3] - 4.0).abs() < 1e-10);
    }
}