1use nalgebra::SVector;
28
29use crate::point::Point;
30use crate::shape::Shape;
31use crate::transform::Transform;
32
33pub struct CompoundShape<const D: usize> {
39 children: Vec<(Transform<D>, Box<dyn Shape<D>>)>,
41 cached_center: Point<D>,
43 cached_radius: f64,
44}
45
46impl<const D: usize> CompoundShape<D> {
47 pub fn new() -> Self {
49 Self {
50 children: Vec::new(),
51 cached_center: Point::origin(),
52 cached_radius: 0.0,
53 }
54 }
55
56 pub fn add_child(&mut self, transform: Transform<D>, shape: Box<dyn Shape<D>>) {
60 self.children.push((transform, shape));
61 self.recompute_bounding();
62 }
63
64 pub fn child_count(&self) -> usize {
66 self.children.len()
67 }
68
69 fn recompute_bounding(&mut self) {
75 if self.children.is_empty() {
76 self.cached_center = Point::origin();
77 self.cached_radius = 0.0;
78 return;
79 }
80
81 let mut center = SVector::<f64, D>::zeros();
83 for (tf, child) in &self.children {
84 let (local_c, _) = child.bounding_sphere();
85 let world_c = tf.transform_point(&local_c).0;
86 center += world_c;
87 }
88 center /= self.children.len() as f64;
89
90 let mut radius = 0.0f64;
92 for (tf, child) in &self.children {
93 let (local_c, child_r) = child.bounding_sphere();
94 let world_c = tf.transform_point(&local_c).0;
95 let d = (world_c - center).norm() + child_r;
96 if d > radius {
97 radius = d;
98 }
99 }
100
101 self.cached_center = Point(center);
102 self.cached_radius = radius;
103 }
104}
105
106impl<const D: usize> Default for CompoundShape<D> {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
112impl<const D: usize> Shape<D> for CompoundShape<D> {
113 fn support(&self, direction: &SVector<f64, D>) -> SVector<f64, D> {
121 self.children
122 .iter()
123 .map(|(tf, child)| {
124 let local_dir = tf.rotation.reverse().rotate_vector(direction);
126 let local_pt = child.support(&local_dir);
128 tf.transform_point(&Point(local_pt)).0
130 })
131 .max_by(|a, b| {
132 let da = a.dot(direction);
133 let db = b.dot(direction);
134 da.total_cmp(&db)
135 })
136 .unwrap_or_else(SVector::zeros)
137 }
138
139 fn bounding_sphere(&self) -> (Point<D>, f64) {
140 (self.cached_center.clone(), self.cached_radius)
141 }
142
143 fn as_any(&self) -> &dyn std::any::Any {
144 self
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bivector::Bivector;
    use crate::rotor::Rotor;
    use crate::sphere::Sphere;

    fn vec3(x: f64, y: f64, z: f64) -> SVector<f64, 3> {
        SVector::from([x, y, z])
    }

    /// Two unit spheres centered at x = -2 and x = +2.
    fn dumbbell() -> CompoundShape<3> {
        let mut c = CompoundShape::new();
        c.add_child(
            Transform::from_translation(Point::new([-2.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        c.add_child(
            Transform::from_translation(Point::new([2.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        c
    }

    #[test]
    fn support_along_positive_x_reaches_far_child() {
        let db = dumbbell();
        let s = db.support(&vec3(1.0, 0.0, 0.0));
        assert!((s[0] - 3.0).abs() < 1e-10, "support x = {}", s[0]);
    }

    #[test]
    fn support_along_negative_x_reaches_near_child() {
        let db = dumbbell();
        let s = db.support(&vec3(-1.0, 0.0, 0.0));
        assert!((s[0] - (-3.0)).abs() < 1e-10, "support x = {}", s[0]);
    }

    #[test]
    fn support_along_y_is_symmetric() {
        let db = dumbbell();
        let s_pos = db.support(&vec3(0.0, 1.0, 0.0));
        let s_neg = db.support(&vec3(0.0, -1.0, 0.0));
        assert!((s_pos[1] - 1.0).abs() < 1e-10);
        assert!((s_neg[1] - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn bounding_sphere_covers_all_children() {
        let db = dumbbell();
        let (center, radius) = db.bounding_sphere();
        assert!(center.0.norm() < 1e-10, "centroid should be at origin");
        // The outermost child surface sits at distance 2 + 1 from the centroid.
        assert!(radius >= 3.0 - 1e-10, "radius = {}", radius);
    }

    #[test]
    fn single_child_matches_original_shape() {
        let r2 = Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0);
        let mut c = CompoundShape::<3>::new();
        c.add_child(Transform::identity(), Box::new(r2));

        let dir = vec3(1.0, 0.0, 0.0);
        let s_compound = c.support(&dir);
        let s_direct = Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0).support(&dir);
        assert!((s_compound - s_direct).norm() < 1e-10);
    }

    #[test]
    fn child_with_rotation_transforms_direction_correctly() {
        let plane = Bivector::<3>::unit_plane(0, 1);
        let rot = Rotor::from_plane_angle(&plane, std::f64::consts::FRAC_PI_2);
        let tf = Transform {
            translation: Point::new([1.0, 0.0, 0.0]),
            rotation: rot,
        };
        let mut c = CompoundShape::<3>::new();
        c.add_child(tf, Box::new(Sphere::unit()));

        // The unit sphere is centered at its local origin, so after the transform its
        // support along +x is the translation's x (1) plus the radius (1).
        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!((s[0] - 2.0).abs() < 1e-10, "s[0] = {}", s[0]);
    }

    #[test]
    fn rotated_off_center_child_follows_its_transform() {
        // Expected values are derived from `Transform::transform_point` itself, so this
        // test does not depend on the rotor's sign or angle convention; it checks that
        // `support` and `bounding_sphere` agree with the transform for a child whose
        // local center is NOT at its local origin.
        let make_tf = || Transform {
            translation: Point::new([1.0, 0.0, 0.0]),
            rotation: Rotor::from_plane_angle(
                &Bivector::<3>::unit_plane(0, 1),
                std::f64::consts::FRAC_PI_2,
            ),
        };
        let local_center = [1.0, 0.0, 0.0];
        let r = 0.5;

        let mut c = CompoundShape::<3>::new();
        c.add_child(make_tf(), Box::new(Sphere::new(Point::new(local_center), r)));
        let world_center = make_tf().transform_point(&Point::new(local_center)).0;

        for dir in [
            vec3(1.0, 0.0, 0.0),
            vec3(0.0, 1.0, 0.0),
            vec3(-1.0, 0.0, 0.0),
            vec3(0.0, -1.0, 0.0),
        ] {
            let s = c.support(&dir);
            let expected = world_center + dir * r;
            assert!(
                (s - expected).norm() < 1e-10,
                "support {:?} != expected {:?}",
                s,
                expected
            );
        }

        let (center, radius) = c.bounding_sphere();
        assert!((center.0 - world_center).norm() < 1e-10);
        assert!((radius - r).abs() < 1e-10, "radius = {}", radius);
    }

    #[test]
    fn child_count_tracks_added_children() {
        assert_eq!(CompoundShape::<3>::new().child_count(), 0);
        assert_eq!(dumbbell().child_count(), 2);
    }

    #[test]
    fn empty_compound_returns_zero_support() {
        let c = CompoundShape::<3>::new();
        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!(s.norm() < 1e-10);
    }

    #[test]
    fn bounding_sphere_empty_is_zero() {
        let c = CompoundShape::<3>::new();
        let (center, radius) = c.bounding_sphere();
        assert!(center.0.norm() < 1e-10);
        assert!(radius < 1e-10);
    }

    #[test]
    fn works_in_2d() {
        let mut c = CompoundShape::<2>::new();
        c.add_child(
            Transform::from_translation(Point::new([-1.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0]), 0.5)),
        );
        c.add_child(
            Transform::from_translation(Point::new([1.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0]), 0.5)),
        );
        let s = c.support(&SVector::<f64, 2>::from([1.0, 0.0]));
        assert!((s[0] - 1.5).abs() < 1e-10);
    }

    #[test]
    fn works_in_4d() {
        let mut c = CompoundShape::<4>::new();
        c.add_child(
            Transform::from_translation(Point::new([0.0, 0.0, 0.0, 3.0])),
            Box::new(Sphere::unit()),
        );
        let s = c.support(&SVector::<f64, 4>::from([0.0, 0.0, 0.0, 1.0]));
        assert!((s[3] - 4.0).abs() < 1e-10);
    }
}