Skip to main content

sphereql_core/
regions.rs

1use std::f64::consts::{PI, TAU};
2
3use crate::distance::angular_distance;
4use crate::error::SphereQlError;
5use crate::types::SphericalPoint;
6
7pub trait Contains {
8    fn contains(&self, point: &SphericalPoint) -> bool;
9}
10
11#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
12pub struct Cone {
13    pub apex: SphericalPoint,
14    pub axis: SphericalPoint,
15    pub half_angle: f64,
16}
17
18impl Cone {
19    pub fn new(
20        apex: SphericalPoint,
21        axis: SphericalPoint,
22        half_angle: f64,
23    ) -> Result<Self, SphereQlError> {
24        if half_angle <= 0.0 || half_angle > PI {
25            return Err(SphereQlError::InvalidConeAngle(half_angle));
26        }
27        Ok(Self {
28            apex,
29            axis,
30            half_angle,
31        })
32    }
33}
34
35impl Contains for Cone {
36    fn contains(&self, point: &SphericalPoint) -> bool {
37        let point_unit = SphericalPoint::new_unchecked(1.0, point.theta, point.phi);
38        let axis_unit = SphericalPoint::new_unchecked(1.0, self.axis.theta, self.axis.phi);
39        angular_distance(&point_unit, &axis_unit) <= self.half_angle
40    }
41}
42
43#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
44pub struct Cap {
45    pub center: SphericalPoint,
46    pub half_angle: f64,
47}
48
49impl Cap {
50    pub fn new(center: SphericalPoint, half_angle: f64) -> Result<Self, SphereQlError> {
51        if half_angle <= 0.0 || half_angle > PI {
52            return Err(SphereQlError::InvalidCapAngle(half_angle));
53        }
54        Ok(Self { center, half_angle })
55    }
56}
57
58impl Contains for Cap {
59    fn contains(&self, point: &SphericalPoint) -> bool {
60        let point_unit = SphericalPoint::new_unchecked(1.0, point.theta, point.phi);
61        let center_unit = SphericalPoint::new_unchecked(1.0, self.center.theta, self.center.phi);
62        angular_distance(&point_unit, &center_unit) <= self.half_angle
63    }
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
67pub struct Shell {
68    pub inner: f64,
69    pub outer: f64,
70}
71
72impl Shell {
73    pub fn new(inner: f64, outer: f64) -> Result<Self, SphereQlError> {
74        if inner < 0.0 || inner >= outer {
75            return Err(SphereQlError::InvalidShellBounds { inner, outer });
76        }
77        Ok(Self { inner, outer })
78    }
79}
80
81impl Contains for Shell {
82    fn contains(&self, point: &SphericalPoint) -> bool {
83        point.r >= self.inner && point.r <= self.outer
84    }
85}
86
87#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
88pub struct Band {
89    pub phi_min: f64,
90    pub phi_max: f64,
91}
92
93impl Band {
94    pub fn new(phi_min: f64, phi_max: f64) -> Result<Self, SphereQlError> {
95        if phi_min < 0.0 || phi_min >= phi_max || phi_max > PI {
96            return Err(SphereQlError::InvalidBandBounds { phi_min, phi_max });
97        }
98        Ok(Self { phi_min, phi_max })
99    }
100}
101
102impl Contains for Band {
103    fn contains(&self, point: &SphericalPoint) -> bool {
104        point.phi >= self.phi_min && point.phi <= self.phi_max
105    }
106}
107
108#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
109pub struct Wedge {
110    pub theta_min: f64,
111    pub theta_max: f64,
112}
113
114impl Wedge {
115    pub fn new(theta_min: f64, theta_max: f64) -> Result<Self, SphereQlError> {
116        if !(0.0..TAU).contains(&theta_min)
117            || !(0.0..TAU).contains(&theta_max)
118            || theta_min == theta_max
119        {
120            return Err(SphereQlError::InvalidWedgeBounds {
121                theta_min,
122                theta_max,
123            });
124        }
125        Ok(Self {
126            theta_min,
127            theta_max,
128        })
129    }
130
131    fn wraps(&self) -> bool {
132        self.theta_min > self.theta_max
133    }
134}
135
136impl Contains for Wedge {
137    fn contains(&self, point: &SphericalPoint) -> bool {
138        if self.wraps() {
139            point.theta >= self.theta_min || point.theta <= self.theta_max
140        } else {
141            point.theta >= self.theta_min && point.theta <= self.theta_max
142        }
143    }
144}
145
146#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
147pub enum Region {
148    Cone(Cone),
149    Cap(Cap),
150    Shell(Shell),
151    Band(Band),
152    Wedge(Wedge),
153    Intersection(Vec<Region>),
154    Union(Vec<Region>),
155}
156
157impl Region {
158    pub fn intersection(regions: Vec<Region>) -> Self {
159        Region::Intersection(regions)
160    }
161
162    pub fn union(regions: Vec<Region>) -> Self {
163        Region::Union(regions)
164    }
165}
166
167impl Contains for Region {
168    fn contains(&self, point: &SphericalPoint) -> bool {
169        match self {
170            Region::Cone(c) => c.contains(point),
171            Region::Cap(c) => c.contains(point),
172            Region::Shell(s) => s.contains(point),
173            Region::Band(b) => b.contains(point),
174            Region::Wedge(w) => w.contains(point),
175            Region::Intersection(regions) => regions.iter().all(|r| r.contains(point)),
176            Region::Union(regions) => regions.iter().any(|r| r.contains(point)),
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI, TAU};

    /// Shorthand for `SphericalPoint::new_unchecked`.
    fn point(r: f64, theta: f64, phi: f64) -> SphericalPoint {
        SphericalPoint::new_unchecked(r, theta, phi)
    }

    // --- Cone tests ---

    #[test]
    fn cone_contains_point_inside() {
        let cone = Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, FRAC_PI_4), FRAC_PI_2).unwrap();
        let p = point(2.0, 0.0, FRAC_PI_4);
        assert!(cone.contains(&p));
    }

    #[test]
    fn cone_excludes_point_outside() {
        let cone = Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, 0.1), 0.05).unwrap();
        let p = point(1.0, PI, FRAC_PI_2);
        assert!(!cone.contains(&p));
    }

    #[test]
    fn cone_contains_point_on_boundary() {
        let half = FRAC_PI_4;
        let cone = Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, 0.0), half).unwrap();
        // axis is at phi=0 (north pole), point at phi=half_angle should be on boundary
        let p = point(1.0, 0.0, half);
        assert!(cone.contains(&p));
    }

    #[test]
    fn cone_various_half_angles() {
        // Narrow cone
        let narrow = Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, FRAC_PI_2), 0.01).unwrap();
        let near = point(1.0, 0.0, FRAC_PI_2);
        let far = point(1.0, 0.0, FRAC_PI_2 + 0.1);
        assert!(narrow.contains(&near));
        assert!(!narrow.contains(&far));

        // Full hemisphere
        let wide = Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, FRAC_PI_2), FRAC_PI_2).unwrap();
        assert!(wide.contains(&point(1.0, 0.5, FRAC_PI_2 + 0.3)));
    }

    #[test]
    fn cone_invalid_half_angle() {
        assert!(Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, 0.0), 0.0).is_err());
        assert!(Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, 0.0), -0.1).is_err());
        assert!(Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, 0.0), PI + 0.1).is_err());
        // PI is valid
        assert!(Cone::new(point(0.0, 0.0, 0.0), point(1.0, 0.0, 0.0), PI).is_ok());
    }

    // --- Cap tests ---

    #[test]
    fn cap_contains_point_inside() {
        let cap = Cap::new(point(1.0, 0.0, FRAC_PI_2), FRAC_PI_4).unwrap();
        let p = point(5.0, 0.0, FRAC_PI_2);
        assert!(cap.contains(&p));
    }

    #[test]
    fn cap_excludes_point_outside() {
        let cap = Cap::new(point(1.0, 0.0, 0.1), 0.05).unwrap();
        let p = point(1.0, PI, PI - 0.1);
        assert!(!cap.contains(&p));
    }

    #[test]
    fn cap_ignores_radius() {
        let cap = Cap::new(point(1.0, 0.0, FRAC_PI_2), FRAC_PI_4).unwrap();
        let near = point(0.1, 0.0, FRAC_PI_2);
        let far = point(1000.0, 0.0, FRAC_PI_2);
        assert!(cap.contains(&near));
        assert!(cap.contains(&far));
    }

    #[test]
    fn cap_invalid_half_angle() {
        assert!(Cap::new(point(1.0, 0.0, 0.0), 0.0).is_err());
        assert!(Cap::new(point(1.0, 0.0, 0.0), -1.0).is_err());
        assert!(Cap::new(point(1.0, 0.0, 0.0), PI + 0.1).is_err());
        // PI (a cap covering the whole sphere) is valid
        assert!(Cap::new(point(1.0, 0.0, 0.0), PI).is_ok());
    }

    #[test]
    fn cap_error_is_cap_specific() {
        let err = Cap::new(point(1.0, 0.0, 0.0), 0.0).unwrap_err();
        assert!(
            matches!(err, SphereQlError::InvalidCapAngle(_)),
            "expected InvalidCapAngle, got {err:?}"
        );
    }

    // --- Shell tests ---

    #[test]
    fn shell_contains_point_inside() {
        let shell = Shell::new(1.0, 5.0).unwrap();
        assert!(shell.contains(&point(3.0, 0.0, 0.0)));
    }

    #[test]
    fn shell_excludes_point_outside() {
        let shell = Shell::new(1.0, 5.0).unwrap();
        assert!(!shell.contains(&point(0.5, 0.0, 0.0)));
        assert!(!shell.contains(&point(6.0, 0.0, 0.0)));
    }

    #[test]
    fn shell_boundary_inclusive() {
        let shell = Shell::new(1.0, 5.0).unwrap();
        assert!(shell.contains(&point(1.0, 0.0, 0.0)));
        assert!(shell.contains(&point(5.0, 0.0, 0.0)));
    }

    #[test]
    fn shell_accepts_unbounded_outer() {
        let shell = Shell::new(1.0, f64::INFINITY).unwrap();
        assert!(shell.contains(&point(1.0e300, 0.0, 0.0)));
        assert!(!shell.contains(&point(0.5, 0.0, 0.0)));
    }

    #[test]
    fn shell_invalid_bounds() {
        assert!(Shell::new(5.0, 1.0).is_err());
        assert!(Shell::new(3.0, 3.0).is_err());
        assert!(Shell::new(-1.0, 5.0).is_err());
    }

    // --- Band tests ---

    #[test]
    fn band_contains_point_inside() {
        let band = Band::new(FRAC_PI_4, 3.0 * FRAC_PI_4).unwrap();
        assert!(band.contains(&point(1.0, 0.0, FRAC_PI_2)));
    }

    #[test]
    fn band_excludes_point_outside() {
        let band = Band::new(FRAC_PI_4, FRAC_PI_2).unwrap();
        assert!(!band.contains(&point(1.0, 0.0, 0.1)));
        assert!(!band.contains(&point(1.0, 0.0, PI - 0.1)));
    }

    #[test]
    fn band_boundary_inclusive() {
        let band = Band::new(FRAC_PI_4, FRAC_PI_2).unwrap();
        assert!(band.contains(&point(1.0, 0.0, FRAC_PI_4)));
        assert!(band.contains(&point(1.0, 0.0, FRAC_PI_2)));
    }

    #[test]
    fn band_poles() {
        // Band covering the north pole area, including the pole itself (phi = 0)
        let north = Band::new(0.0, FRAC_PI_4).unwrap();
        assert!(north.contains(&point(1.0, 0.0, 0.0)));
        assert!(north.contains(&point(1.0, 0.0, 0.01)));
        assert!(!north.contains(&point(1.0, 0.0, FRAC_PI_2)));

        // Band covering the south pole area, including the pole itself (phi = PI)
        let south = Band::new(3.0 * FRAC_PI_4, PI).unwrap();
        assert!(south.contains(&point(1.0, 0.0, PI)));
        assert!(south.contains(&point(1.0, 0.0, PI - 0.1)));
        assert!(!south.contains(&point(1.0, 0.0, FRAC_PI_4)));
    }

    #[test]
    fn band_invalid_bounds() {
        assert!(Band::new(FRAC_PI_2, FRAC_PI_4).is_err());
        assert!(Band::new(FRAC_PI_4, FRAC_PI_4).is_err());
        assert!(Band::new(-0.1, FRAC_PI_2).is_err());
        assert!(Band::new(0.0, PI + 0.1).is_err());
    }

    // --- Wedge tests ---

    #[test]
    fn wedge_contains_normal_range() {
        let wedge = Wedge::new(0.5, 2.0).unwrap();
        assert!(wedge.contains(&point(1.0, 1.0, FRAC_PI_2)));
        assert!(!wedge.contains(&point(1.0, 3.0, FRAC_PI_2)));
    }

    #[test]
    fn wedge_wraparound() {
        // 350° to 10° in radians: ~6.1087 to ~0.1745
        let theta_min = 350.0_f64.to_radians();
        let theta_max = 10.0_f64.to_radians();
        let wedge = Wedge::new(theta_min, theta_max).unwrap();

        let inside_high = point(1.0, 355.0_f64.to_radians(), FRAC_PI_2);
        let inside_low = point(1.0, 5.0_f64.to_radians(), FRAC_PI_2);
        let outside = point(1.0, 180.0_f64.to_radians(), FRAC_PI_2);

        assert!(wedge.contains(&inside_high));
        assert!(wedge.contains(&inside_low));
        assert!(!wedge.contains(&outside));
    }

    #[test]
    fn wedge_boundary_inclusive() {
        let wedge = Wedge::new(1.0, 2.0).unwrap();
        assert!(wedge.contains(&point(1.0, 1.0, FRAC_PI_2)));
        assert!(wedge.contains(&point(1.0, 2.0, FRAC_PI_2)));
    }

    #[test]
    fn wedge_rejects_invalid_theta() {
        assert!(Wedge::new(-0.1, 1.0).is_err());
        assert!(Wedge::new(0.0, 7.0).is_err());
        // TAU itself is excluded: the valid range is [0, TAU)
        assert!(Wedge::new(0.0, TAU).is_err());
        // Equal bounds are rejected
        assert!(Wedge::new(1.0, 1.0).is_err());
    }

    // --- Compound region tests ---

    #[test]
    fn intersection_shell_and_band() {
        let shell = Region::Shell(Shell::new(1.0, 5.0).unwrap());
        let band = Region::Band(Band::new(FRAC_PI_4, 3.0 * FRAC_PI_4).unwrap());
        let region = Region::intersection(vec![shell, band]);

        // Inside both
        assert!(region.contains(&point(3.0, 0.0, FRAC_PI_2)));
        // Inside shell but outside band
        assert!(!region.contains(&point(3.0, 0.0, 0.1)));
        // Inside band but outside shell
        assert!(!region.contains(&point(10.0, 0.0, FRAC_PI_2)));
    }

    #[test]
    fn union_two_caps() {
        let cap_a = Region::Cap(Cap::new(point(1.0, 0.0, 0.1), 0.2).unwrap());
        let cap_b = Region::Cap(Cap::new(point(1.0, 0.0, PI - 0.1), 0.2).unwrap());
        let region = Region::union(vec![cap_a, cap_b]);

        // Near north pole (cap_a)
        assert!(region.contains(&point(1.0, 0.0, 0.05)));
        // Near south pole (cap_b)
        assert!(region.contains(&point(1.0, 0.0, PI - 0.05)));
        // Equator (neither)
        assert!(!region.contains(&point(1.0, 0.0, FRAC_PI_2)));
    }

    #[test]
    fn empty_intersection_contains_everything() {
        let region = Region::intersection(vec![]);
        assert!(region.contains(&point(1.0, 0.0, FRAC_PI_2)));
    }

    #[test]
    fn empty_union_contains_nothing() {
        let region = Region::union(vec![]);
        assert!(!region.contains(&point(1.0, 0.0, FRAC_PI_2)));
    }

    // --- Region enum dispatch ---

    #[test]
    fn region_dispatches_to_inner_types() {
        let shell_region = Region::Shell(Shell::new(1.0, 5.0).unwrap());
        assert!(shell_region.contains(&point(3.0, 0.0, 0.0)));
        assert!(!shell_region.contains(&point(10.0, 0.0, 0.0)));

        let wedge_region = Region::Wedge(Wedge::new(0.5, 2.0).unwrap());
        assert!(wedge_region.contains(&point(1.0, 1.0, FRAC_PI_2)));
    }
}