1use nalgebra::SVector;
28
29use crate::point::Point;
30use crate::shape::Shape;
31use crate::transform::Transform;
32
/// A shape assembled from child shapes, each placed in this compound's frame
/// by its own [`Transform`].
///
/// Support queries return the farthest child support point, i.e. the support
/// mapping of the union of the children (equivalently, of their convex hull).
pub struct CompoundShape<const D: usize> {
    /// Children in insertion order. Each transform maps that child's local
    /// coordinates into the compound's frame.
    parts: Vec<(Transform<D>, Box<dyn Shape<D>>)>,
    /// Center of the cached bounding sphere, in the compound's frame.
    /// Recomputed whenever a child is added.
    cached_center: Point<D>,
    /// Radius of the cached bounding sphere; 0.0 while there are no children.
    cached_radius: f64,
}
45
46impl<const D: usize> Clone for CompoundShape<D> {
47 fn clone(&self) -> Self {
48 let parts = self
49 .parts
50 .iter()
51 .map(|(tf, child)| (tf.clone(), child.clone_box()))
52 .collect();
53 Self {
54 parts,
55 cached_center: self.cached_center,
56 cached_radius: self.cached_radius,
57 }
58 }
59}
60
61impl<const D: usize> CompoundShape<D> {
62 pub fn new() -> Self {
64 Self {
65 parts: Vec::new(),
66 cached_center: Point::origin(),
67 cached_radius: 0.0,
68 }
69 }
70
71 pub fn from_shape(shape: Box<dyn Shape<D>>) -> Self {
73 let mut s = Self::new();
74 s.add_child(Transform::identity(), shape);
75 s
76 }
77
78 pub fn add_child(&mut self, transform: Transform<D>, shape: Box<dyn Shape<D>>) {
82 self.parts.push((transform, shape));
83 self.recompute_bounding();
84 }
85
86 pub fn children(&self) -> &[(Transform<D>, Box<dyn Shape<D>>)] {
88 &self.parts
89 }
90
91 pub fn child_count(&self) -> usize {
93 self.parts.len()
94 }
95
96 fn recompute_bounding(&mut self) {
102 if self.parts.is_empty() {
103 self.cached_center = Point::origin();
104 self.cached_radius = 0.0;
105 return;
106 }
107
108 let mut center = SVector::<f64, D>::zeros();
110 for (tf, child) in &self.parts {
111 let (local_c, _) = child.bounding_sphere();
112 let world_c = tf.transform_point(&local_c).0;
113 center += world_c;
114 }
115 center /= self.parts.len() as f64;
116
117 let mut radius = 0.0f64;
119 for (tf, child) in &self.parts {
120 let (local_c, child_r) = child.bounding_sphere();
121 let world_c = tf.transform_point(&local_c).0;
122 let d = (world_c - center).norm() + child_r;
123 if d > radius {
124 radius = d;
125 }
126 }
127
128 self.cached_center = Point(center);
129 self.cached_radius = radius;
130 }
131}
132
133impl<const D: usize> Default for CompoundShape<D> {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139impl<const D: usize> Shape<D> for CompoundShape<D> {
140 fn support(&self, direction: &SVector<f64, D>) -> SVector<f64, D> {
148 self.parts
149 .iter()
150 .map(|(tf, child)| {
151 let local_dir = tf.rotation.reverse().rotate_vector(direction);
153 let local_pt = child.support(&local_dir);
155 tf.transform_point(&Point(local_pt)).0
157 })
158 .max_by(|a, b| {
159 let da = a.dot(direction);
160 let db = b.dot(direction);
161 da.total_cmp(&db)
162 })
163 .unwrap_or_else(SVector::zeros)
164 }
165
166 fn bounding_sphere(&self) -> (Point<D>, f64) {
167 (self.cached_center, self.cached_radius)
168 }
169
170 fn as_any(&self) -> &dyn std::any::Any {
171 self
172 }
173
174 fn clone_box(&self) -> Box<dyn Shape<D>> {
175 Box::new(self.clone())
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bivector::Bivector;
    use crate::rotor::Rotor;
    use crate::sphere::Sphere;

    /// Shorthand for a 3-D vector literal.
    fn vec3(x: f64, y: f64, z: f64) -> SVector<f64, 3> {
        SVector::from([x, y, z])
    }

    /// Two unit spheres placed at x = -2 and x = +2.
    fn dumbbell() -> CompoundShape<3> {
        let mut c = CompoundShape::new();
        c.add_child(
            Transform::from_translation(Point::new([-2.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        c.add_child(
            Transform::from_translation(Point::new([2.0, 0.0, 0.0])),
            Box::new(Sphere::unit()),
        );
        c
    }

    #[test]
    fn support_along_positive_x_reaches_far_child() {
        let db = dumbbell();
        let s = db.support(&vec3(1.0, 0.0, 0.0));
        assert!((s[0] - 3.0).abs() < 1e-10, "support x = {}", s[0]);
    }

    #[test]
    fn support_along_negative_x_reaches_near_child() {
        let db = dumbbell();
        let s = db.support(&vec3(-1.0, 0.0, 0.0));
        assert!((s[0] - (-3.0)).abs() < 1e-10, "support x = {}", s[0]);
    }

    #[test]
    fn support_along_y_is_symmetric() {
        let db = dumbbell();
        let s_pos = db.support(&vec3(0.0, 1.0, 0.0));
        let s_neg = db.support(&vec3(0.0, -1.0, 0.0));
        assert!((s_pos[1] - 1.0).abs() < 1e-10);
        assert!((s_neg[1] - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn bounding_sphere_covers_all_children() {
        let db = dumbbell();
        let (center, radius) = db.bounding_sphere();
        assert!(center.0.norm() < 1e-10, "centroid should be at origin");
        // Each child reaches 2 (offset) + 1 (radius) from the centroid, so the
        // sphere must extend at least 3 ...
        assert!(radius >= 3.0 - 1e-10, "radius = {}", radius);
        // ... and for this symmetric layout the centroid-based bound is exact,
        // so it must not be looser than 3 either.
        assert!(radius <= 3.0 + 1e-10, "radius = {}", radius);
    }

    #[test]
    fn dumbbell_reports_two_children() {
        let db = dumbbell();
        assert_eq!(db.child_count(), 2);
        assert_eq!(db.children().len(), 2);
    }

    #[test]
    fn single_child_matches_original_shape() {
        let r2 = Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0);
        let mut c = CompoundShape::<3>::new();
        c.add_child(Transform::identity(), Box::new(r2));

        let dir = vec3(1.0, 0.0, 0.0);
        let s_compound = c.support(&dir);
        let s_direct = Sphere::new(Point::new([0.0, 0.0, 0.0]), 2.0).support(&dir);
        assert!((s_compound - s_direct).norm() < 1e-10);
    }

    #[test]
    fn from_shape_wraps_single_child_with_identity() {
        let c = CompoundShape::<3>::from_shape(Box::new(Sphere::unit()));
        assert_eq!(c.child_count(), 1);
        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!((s[0] - 1.0).abs() < 1e-10, "s[0] = {}", s[0]);
    }

    #[test]
    fn child_with_rotation_transforms_direction_correctly() {
        // A quarter turn in the xy-plane. If the query direction were not
        // rotated back into the child's frame, the support along +x would not
        // land at translation.x + radius.
        let plane = Bivector::<3>::unit_plane(0, 1);
        let rot = Rotor::from_plane_angle(&plane, std::f64::consts::FRAC_PI_2);
        let tf = Transform {
            translation: Point::new([1.0, 0.0, 0.0]),
            rotation: rot,
        };
        let mut c = CompoundShape::<3>::new();
        c.add_child(tf, Box::new(Sphere::unit()));

        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!((s[0] - 2.0).abs() < 1e-10, "s[0] = {}", s[0]);
    }

    #[test]
    fn clone_preserves_children_and_bounds() {
        let original = dumbbell();
        let copy = original.clone();
        assert_eq!(copy.child_count(), original.child_count());

        let (c0, r0) = original.bounding_sphere();
        let (c1, r1) = copy.bounding_sphere();
        assert!((c0.0 - c1.0).norm() < 1e-10);
        assert!((r0 - r1).abs() < 1e-10);

        let dir = vec3(0.3, -0.7, 0.2);
        assert!((original.support(&dir) - copy.support(&dir)).norm() < 1e-10);
    }

    #[test]
    fn nested_compound_is_transformed_as_a_unit() {
        let mut outer = CompoundShape::<3>::new();
        outer.add_child(
            Transform::from_translation(Point::new([0.0, 5.0, 0.0])),
            Box::new(dumbbell()),
        );

        let s = outer.support(&vec3(1.0, 0.0, 0.0));
        assert!((s[0] - 3.0).abs() < 1e-10, "s[0] = {}", s[0]);
        assert!((s[1] - 5.0).abs() < 1e-10, "s[1] = {}", s[1]);

        let (center, radius) = outer.bounding_sphere();
        assert!((center.0 - vec3(0.0, 5.0, 0.0)).norm() < 1e-10);
        assert!((radius - 3.0).abs() < 1e-10, "radius = {}", radius);
    }

    #[test]
    fn empty_compound_returns_zero_support() {
        let c = CompoundShape::<3>::new();
        let s = c.support(&vec3(1.0, 0.0, 0.0));
        assert!(s.norm() < 1e-10);
    }

    #[test]
    fn bounding_sphere_empty_is_zero() {
        let c = CompoundShape::<3>::new();
        let (center, radius) = c.bounding_sphere();
        assert!(center.0.norm() < 1e-10);
        assert!(radius < 1e-10);
    }

    #[test]
    fn default_is_empty() {
        let c: CompoundShape<3> = CompoundShape::default();
        assert_eq!(c.child_count(), 0);
        assert!(c.bounding_sphere().1 < 1e-10);
    }

    #[test]
    fn works_in_2d() {
        let mut c = CompoundShape::<2>::new();
        c.add_child(
            Transform::from_translation(Point::new([-1.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0]), 0.5)),
        );
        c.add_child(
            Transform::from_translation(Point::new([1.0, 0.0])),
            Box::new(Sphere::new(Point::new([0.0, 0.0]), 0.5)),
        );
        let s = c.support(&SVector::<f64, 2>::from([1.0, 0.0]));
        assert!((s[0] - 1.5).abs() < 1e-10);
    }

    #[test]
    fn works_in_4d() {
        let mut c = CompoundShape::<4>::new();
        c.add_child(
            Transform::from_translation(Point::new([0.0, 0.0, 0.0, 3.0])),
            Box::new(Sphere::unit()),
        );
        let s = c.support(&SVector::<f64, 4>::from([0.0, 0.0, 0.0, 1.0]));
        assert!((s[3] - 4.0).abs() < 1e-10);
    }
}