Skip to main content

sphereql_graphql/
types.rs

1use sphereql_core::{Band, Cap, Cone, Region, Shell, SphericalPoint, Wedge};
2use sphereql_embed::pipeline::PipelineInput;
3
4// --- Output types ---
5
/// A point in spherical coordinates, as returned to GraphQL clients.
///
/// Both angles are reported twice: in radians (`theta`, `phi`) and in
/// degrees (`theta_degrees`, `phi_degrees`).
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct SphericalPointOutput {
    /// Radial coordinate.
    pub r: f64,
    /// The `theta` angle, in radians.
    pub theta: f64,
    /// The `phi` angle, in radians.
    pub phi: f64,
    /// `theta` converted to degrees.
    pub theta_degrees: f64,
    /// `phi` converted to degrees.
    pub phi_degrees: f64,
}

/// A point in Cartesian coordinates.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct CartesianPointOutput {
    pub x: f64,
    pub y: f64,
    pub z: f64,
}

/// A point in geographic form: latitude, longitude and altitude.
// NOTE(review): units (degrees vs radians) and the altitude reference are
// chosen by whichever resolver builds this value, which is not in this
// file — confirm before describing them in the schema.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct GeoPointOutput {
    pub lat: f64,
    pub lon: f64,
    pub alt: f64,
}
28
/// Result of a spatial query: the returned points plus a scan counter.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct SpatialQueryResultOutput {
    /// Points returned by the query.
    pub items: Vec<SphericalPointOutput>,
    // `i32` because GraphQL `Int` is a signed 32-bit integer. Presumably the
    // number of points the resolver examined — TODO confirm at the call site.
    pub total_scanned: i32,
}

/// A nearest-neighbour result: a point together with a distance value.
// NOTE(review): the metric and unit of `distance` are decided by the
// resolver that fills this in (not in this file) — confirm.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct NearestResultOutput {
    pub point: SphericalPointOutput,
    pub distance: f64,
}

/// Distance values under the angular, great-circle and chord metrics.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct DistanceResultOutput {
    pub angular: f64,
    // Optional: the resolver decides when no great-circle distance is
    // reported. Presumably when the two points are at different radii —
    // not visible here, confirm.
    pub great_circle: Option<f64>,
    pub chord: f64,
}
47
48// --- Input types ---
49
/// Spherical coordinates supplied by a client; angles are in radians.
/// Validated by `SphericalPointInput::to_core`.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct SphericalPointInput {
    pub r: f64,
    pub theta: f64,
    pub phi: f64,
}

/// A cone described by its apex, an axis point and a half-angle.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct ConeInput {
    pub apex: SphericalPointInput,
    pub axis: SphericalPointInput,
    // Presumably radians like the other angles (the tests pass FRAC_PI_4);
    // confirm against `sphereql_core::Cone::new`.
    pub half_angle: f64,
}

/// A spherical cap around `center` with the given `half_angle`.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct CapInput {
    pub center: SphericalPointInput,
    pub half_angle: f64,
}

/// A radial shell between the `inner` and `outer` radii.
// Inverted bounds (inner > outer) are rejected by `Shell::new` (see tests).
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct ShellInput {
    pub inner: f64,
    pub outer: f64,
}

/// A band bounded by two `phi` values.
// Inverted bounds (phi_min > phi_max) are rejected by `Band::new` (see tests).
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct BandInput {
    pub phi_min: f64,
    pub phi_max: f64,
}

/// A wedge bounded by two `theta` values.
// How `theta_min > theta_max` is treated is up to `sphereql_core::Wedge::new`
// (not visible here) — confirm before documenting it in the schema.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct WedgeInput {
    pub theta_min: f64,
    pub theta_max: f64,
}
87
/// A region expressed as a one-of input: exactly one field must be set.
///
/// GraphQL input objects cannot be unions, so each shape is an optional
/// field. `intersection` and `union` combine nested regions recursively.
/// The exactly-one rule is enforced by `RegionInput::to_core`.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct RegionInput {
    pub cone: Option<ConeInput>,
    pub cap: Option<CapInput>,
    pub shell: Option<ShellInput>,
    pub band: Option<BandInput>,
    pub wedge: Option<WedgeInput>,
    pub intersection: Option<Vec<RegionInput>>,
    pub union: Option<Vec<RegionInput>>,
}
98
99// --- Categorized item (for the category-enrichment surface) ---
100
/// Input type for items consumed by the category-enrichment pipeline.
///
/// `embedding` is the high-dimensional vector that the pipeline projects
/// onto S²; `category` is the label the enrichment layer groups by; `id`
/// is a caller-supplied identifier.
///
/// Note: `items_to_pipeline_input` currently **drops** `id`. The pipeline
/// assigns its own ids, so query results do not echo the value supplied
/// here.
///
/// Used as the input shape when (re)building the pipeline from GraphQL
/// or from in-process Rust callers; the pipeline itself stores categories
/// and embeddings in parallel `Vec`s rather than as items, so this type
/// exists only at the boundary.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct CategorizedItemInput {
    pub id: String,
    pub category: String,
    pub embedding: Vec<f64>,
}

/// Output mirror of [`CategorizedItemInput`] — useful for echoing items
/// back to clients, or for resolvers that surface the underlying vectors.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct CategorizedItemOutput {
    pub id: String,
    pub category: String,
    pub embedding: Vec<f64>,
}
126
127impl From<&CategorizedItemInput> for CategorizedItemOutput {
128    fn from(i: &CategorizedItemInput) -> Self {
129        Self {
130            id: i.id.clone(),
131            category: i.category.clone(),
132            embedding: i.embedding.clone(),
133        }
134    }
135}
136
137/// Convert a slice of [`CategorizedItemInput`] into the pipeline's
138/// expected input shape (parallel `categories` / `embeddings` vecs).
139///
140/// # Id handling
141///
142/// The `id` field on the input is **dropped** — the pipeline assigns its
143/// own stable internal ids of the form `s-0000`, `s-0001`, … in input
144/// order. Query results surface those generated ids, not the ones the
145/// caller supplied. The field is kept on the input type so future
146/// sphereql-embed work can round-trip user ids without a breaking shape
147/// change here.
148pub fn items_to_pipeline_input(items: &[CategorizedItemInput]) -> PipelineInput {
149    PipelineInput {
150        categories: items.iter().map(|i| i.category.clone()).collect(),
151        embeddings: items.iter().map(|i| i.embedding.clone()).collect(),
152    }
153}
154
155// --- Enum ---
156
/// Distance metric selector, exposed as a GraphQL enum.
// NOTE(review): the resolvers that switch on this enum are not in this file.
// `DistanceResultOutput` has no Euclidean field, so `Euclidean` is presumably
// consumed by other endpoints only — confirm.
#[derive(async_graphql::Enum, Copy, Clone, Eq, PartialEq, Debug)]
pub enum DistanceMetric {
    Angular,
    GreatCircle,
    Chord,
    Euclidean,
}
164
165// --- Conversions ---
166
167impl From<&SphericalPoint> for SphericalPointOutput {
168    fn from(p: &SphericalPoint) -> Self {
169        Self {
170            r: p.r,
171            theta: p.theta,
172            phi: p.phi,
173            theta_degrees: p.theta.to_degrees(),
174            phi_degrees: p.phi.to_degrees(),
175        }
176    }
177}
178
179impl SphericalPointInput {
180    pub fn to_core(&self) -> Result<SphericalPoint, async_graphql::Error> {
181        SphericalPoint::new(self.r, self.theta, self.phi)
182            .map_err(|e| async_graphql::Error::new(e.to_string()))
183    }
184}
185
186impl ConeInput {
187    pub fn to_core(&self) -> Result<Cone, async_graphql::Error> {
188        let apex = self.apex.to_core()?;
189        let axis = self.axis.to_core()?;
190        Cone::new(apex, axis, self.half_angle).map_err(|e| async_graphql::Error::new(e.to_string()))
191    }
192}
193
194impl CapInput {
195    pub fn to_core(&self) -> Result<Cap, async_graphql::Error> {
196        let center = self.center.to_core()?;
197        Cap::new(center, self.half_angle).map_err(|e| async_graphql::Error::new(e.to_string()))
198    }
199}
200
201impl ShellInput {
202    pub fn to_core(&self) -> Result<Shell, async_graphql::Error> {
203        Shell::new(self.inner, self.outer).map_err(|e| async_graphql::Error::new(e.to_string()))
204    }
205}
206
207impl BandInput {
208    pub fn to_core(&self) -> Result<Band, async_graphql::Error> {
209        Band::new(self.phi_min, self.phi_max).map_err(|e| async_graphql::Error::new(e.to_string()))
210    }
211}
212
213impl WedgeInput {
214    pub fn to_core(&self) -> Result<Wedge, async_graphql::Error> {
215        Wedge::new(self.theta_min, self.theta_max)
216            .map_err(|e| async_graphql::Error::new(e.to_string()))
217    }
218}
219
220impl RegionInput {
221    pub fn to_core(&self) -> Result<Region, async_graphql::Error> {
222        let set_count = [
223            self.cone.is_some(),
224            self.cap.is_some(),
225            self.shell.is_some(),
226            self.band.is_some(),
227            self.wedge.is_some(),
228            self.intersection.is_some(),
229            self.union.is_some(),
230        ]
231        .iter()
232        .filter(|&&v| v)
233        .count();
234
235        if set_count == 0 {
236            return Err(async_graphql::Error::new(
237                "RegionInput must have exactly one field set, but none were provided",
238            ));
239        }
240        if set_count > 1 {
241            return Err(async_graphql::Error::new(format!(
242                "RegionInput must have exactly one field set, but {set_count} were provided",
243            )));
244        }
245
246        if let Some(cone) = &self.cone {
247            return Ok(Region::Cone(cone.to_core()?));
248        }
249        if let Some(cap) = &self.cap {
250            return Ok(Region::Cap(cap.to_core()?));
251        }
252        if let Some(shell) = &self.shell {
253            return Ok(Region::Shell(shell.to_core()?));
254        }
255        if let Some(band) = &self.band {
256            return Ok(Region::Band(band.to_core()?));
257        }
258        if let Some(wedge) = &self.wedge {
259            return Ok(Region::Wedge(wedge.to_core()?));
260        }
261        if let Some(regions) = &self.intersection {
262            let converted: Result<Vec<Region>, _> = regions.iter().map(|r| r.to_core()).collect();
263            return Ok(Region::Intersection(converted?));
264        }
265        if let Some(regions) = &self.union {
266            let converted: Result<Vec<Region>, _> = regions.iter().map(|r| r.to_core()).collect();
267            return Ok(Region::Union(converted?));
268        }
269
270        unreachable!()
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};

    fn sp_input(r: f64, theta: f64, phi: f64) -> SphericalPointInput {
        SphericalPointInput { r, theta, phi }
    }

    /// A `RegionInput` with no field set. Tests fill in the field(s) they
    /// need via struct-update syntax: `RegionInput { x: …, ..empty_region() }`.
    fn empty_region() -> RegionInput {
        RegionInput {
            cone: None,
            cap: None,
            shell: None,
            band: None,
            wedge: None,
            intersection: None,
            union: None,
        }
    }

    /// A `RegionInput` with only `shell` set.
    fn shell_region(inner: f64, outer: f64) -> RegionInput {
        RegionInput {
            shell: Some(ShellInput { inner, outer }),
            ..empty_region()
        }
    }

    /// A `RegionInput` with only `band` set.
    fn band_region(phi_min: f64, phi_max: f64) -> RegionInput {
        RegionInput {
            band: Some(BandInput { phi_min, phi_max }),
            ..empty_region()
        }
    }

    #[test]
    fn spherical_point_input_to_core_roundtrip() {
        let input = sp_input(2.0, 1.0, FRAC_PI_4);
        let core = input.to_core().unwrap();
        assert!((core.r - 2.0).abs() < 1e-12);
        assert!((core.theta - 1.0).abs() < 1e-12);
        assert!((core.phi - FRAC_PI_4).abs() < 1e-12);

        let output = SphericalPointOutput::from(&core);
        assert!((output.r - 2.0).abs() < 1e-12);
        assert!((output.theta - 1.0).abs() < 1e-12);
        assert!((output.phi - FRAC_PI_4).abs() < 1e-12);
        assert!((output.theta_degrees - 1.0_f64.to_degrees()).abs() < 1e-9);
        assert!((output.phi_degrees - FRAC_PI_4.to_degrees()).abs() < 1e-9);
    }

    #[test]
    fn region_input_cone_converts() {
        let region = RegionInput {
            cone: Some(ConeInput {
                apex: sp_input(0.0, 0.0, 0.0),
                axis: sp_input(1.0, 0.5, FRAC_PI_2),
                half_angle: FRAC_PI_4,
            }),
            ..empty_region()
        };
        let core = region.to_core().unwrap();
        assert!(matches!(core, Region::Cone(_)));
    }

    #[test]
    fn region_input_intersection_recursive() {
        let compound = RegionInput {
            intersection: Some(vec![
                shell_region(1.0, 5.0),
                band_region(FRAC_PI_4, FRAC_PI_2),
            ]),
            ..empty_region()
        };

        match compound.to_core().unwrap() {
            Region::Intersection(regions) => {
                assert_eq!(regions.len(), 2);
                assert!(matches!(regions[0], Region::Shell(_)));
                assert!(matches!(regions[1], Region::Band(_)));
            }
            other => panic!("expected Intersection, got {other:?}"),
        }
    }

    #[test]
    fn region_input_union_recursive() {
        let compound = RegionInput {
            union: Some(vec![
                shell_region(1.0, 5.0),
                band_region(FRAC_PI_4, FRAC_PI_2),
            ]),
            ..empty_region()
        };

        match compound.to_core().unwrap() {
            Region::Union(regions) => {
                assert_eq!(regions.len(), 2);
                assert!(matches!(regions[0], Region::Shell(_)));
                assert!(matches!(regions[1], Region::Band(_)));
            }
            other => panic!("expected Union, got {other:?}"),
        }
    }

    #[test]
    fn nested_region_errors_propagate() {
        // The second child is an inverted shell; its error must surface.
        let compound = RegionInput {
            intersection: Some(vec![
                band_region(FRAC_PI_4, FRAC_PI_2),
                shell_region(5.0, 1.0),
            ]),
            ..empty_region()
        };
        assert!(compound.to_core().is_err());
    }

    #[test]
    fn invalid_inputs_produce_errors() {
        let bad_point = sp_input(-1.0, 0.0, 0.0);
        assert!(bad_point.to_core().is_err());

        let bad_shell = ShellInput {
            inner: 5.0,
            outer: 1.0,
        };
        assert!(bad_shell.to_core().is_err());

        let bad_band = BandInput {
            phi_min: PI,
            phi_max: 0.1,
        };
        assert!(bad_band.to_core().is_err());

        assert!(empty_region().to_core().is_err());

        let multi_region = RegionInput {
            cone: Some(ConeInput {
                apex: sp_input(0.0, 0.0, 0.0),
                axis: sp_input(1.0, 0.0, FRAC_PI_2),
                half_angle: FRAC_PI_4,
            }),
            shell: Some(ShellInput {
                inner: 1.0,
                outer: 5.0,
            }),
            ..empty_region()
        };
        assert!(multi_region.to_core().is_err());
    }
}