// sphereql_graphql/types.rs

1use sphereql_core::{Band, Cap, Cone, Region, Shell, SphericalPoint, Wedge};
2use sphereql_embed::pipeline::PipelineInput;
3
// --- Output types ---
5
6#[derive(async_graphql::SimpleObject, Debug, Clone)]
7pub struct SphericalPointOutput {
8    pub r: f64,
9    pub theta: f64,
10    pub phi: f64,
11    pub theta_degrees: f64,
12    pub phi_degrees: f64,
13}
14
15#[derive(async_graphql::SimpleObject, Debug, Clone)]
16pub struct SpatialQueryResultOutput {
17    pub items: Vec<SphericalPointOutput>,
18    pub total_scanned: i32,
19}
20
21#[derive(async_graphql::SimpleObject, Debug, Clone)]
22pub struct NearestResultOutput {
23    pub point: SphericalPointOutput,
24    pub distance: f64,
25}
26
// --- Input types ---
28
29#[derive(async_graphql::InputObject, Debug, Clone)]
30pub struct SphericalPointInput {
31    pub r: f64,
32    pub theta: f64,
33    pub phi: f64,
34}
35
36#[derive(async_graphql::InputObject, Debug, Clone)]
37pub struct ConeInput {
38    pub apex: SphericalPointInput,
39    pub axis: SphericalPointInput,
40    pub half_angle: f64,
41}
42
43#[derive(async_graphql::InputObject, Debug, Clone)]
44pub struct CapInput {
45    pub center: SphericalPointInput,
46    pub half_angle: f64,
47}
48
49#[derive(async_graphql::InputObject, Debug, Clone)]
50pub struct ShellInput {
51    pub inner: f64,
52    pub outer: f64,
53}
54
55#[derive(async_graphql::InputObject, Debug, Clone)]
56pub struct BandInput {
57    pub phi_min: f64,
58    pub phi_max: f64,
59}
60
61#[derive(async_graphql::InputObject, Debug, Clone)]
62pub struct WedgeInput {
63    pub theta_min: f64,
64    pub theta_max: f64,
65}
66
67#[derive(async_graphql::InputObject, Debug, Clone)]
68pub struct RegionInput {
69    pub cone: Option<ConeInput>,
70    pub cap: Option<CapInput>,
71    pub shell: Option<ShellInput>,
72    pub band: Option<BandInput>,
73    pub wedge: Option<WedgeInput>,
74    pub intersection: Option<Vec<RegionInput>>,
75    pub union: Option<Vec<RegionInput>>,
76}
77
// --- Categorized item (for the category-enrichment surface) ---
79
80/// Input type for items consumed by the category-enrichment pipeline.
81///
82/// `embedding` is the high-dimensional vector that the pipeline projects
83/// onto S²; `category` is the label the enrichment layer groups by; `id`
84/// is a stable string returned in query results.
85///
86/// Used as the input shape when (re)building the pipeline from GraphQL
87/// or from in-process Rust callers; the pipeline itself stores categories
88/// and embeddings in parallel `Vec`s rather than as items, so this type
89/// exists only at the boundary.
90#[derive(async_graphql::InputObject, Debug, Clone)]
91pub struct CategorizedItemInput {
92    pub id: String,
93    pub category: String,
94    pub embedding: Vec<f64>,
95}
96
97/// Output mirror of [`CategorizedItemInput`] — useful for echoing items
98/// back to clients, or for resolvers that surface the underlying vectors.
99#[derive(async_graphql::SimpleObject, Debug, Clone)]
100pub struct CategorizedItemOutput {
101    pub id: String,
102    pub category: String,
103    pub embedding: Vec<f64>,
104}
105
106impl From<&CategorizedItemInput> for CategorizedItemOutput {
107    fn from(i: &CategorizedItemInput) -> Self {
108        Self {
109            id: i.id.clone(),
110            category: i.category.clone(),
111            embedding: i.embedding.clone(),
112        }
113    }
114}
115
116/// Convert a slice of [`CategorizedItemInput`] into the pipeline's
117/// expected input shape (parallel `categories` / `embeddings` vecs).
118///
119/// # Id handling
120///
121/// The `id` field on the input is **dropped** — the pipeline assigns its
122/// own stable internal ids of the form `s-0000`, `s-0001`, … in input
123/// order. Query results surface those generated ids, not the ones the
124/// caller supplied. The field is kept on the input type so future
125/// sphereql-embed work can round-trip user ids without a breaking shape
126/// change here.
127pub fn items_to_pipeline_input(items: &[CategorizedItemInput]) -> PipelineInput {
128    PipelineInput {
129        categories: items.iter().map(|i| i.category.clone()).collect(),
130        embeddings: items.iter().map(|i| i.embedding.clone()).collect(),
131    }
132}
133
// --- Enum ---
135
136#[derive(async_graphql::Enum, Copy, Clone, Eq, PartialEq, Debug)]
137#[non_exhaustive]
138pub enum DistanceMetric {
139    Angular,
140    GreatCircle,
141    Chord,
142    Euclidean,
143}
144
// --- Conversions ---
146
147impl From<&SphericalPoint> for SphericalPointOutput {
148    fn from(p: &SphericalPoint) -> Self {
149        Self {
150            r: p.r,
151            theta: p.theta,
152            phi: p.phi,
153            theta_degrees: p.theta.to_degrees(),
154            phi_degrees: p.phi.to_degrees(),
155        }
156    }
157}
158
159impl SphericalPointInput {
160    pub fn to_core(&self) -> Result<SphericalPoint, async_graphql::Error> {
161        SphericalPoint::new(self.r, self.theta, self.phi)
162            .map_err(|e| async_graphql::Error::new(e.to_string()))
163    }
164}
165
166impl ConeInput {
167    pub fn to_core(&self) -> Result<Cone, async_graphql::Error> {
168        let apex = self.apex.to_core()?;
169        let axis = self.axis.to_core()?;
170        Cone::new(apex, axis, self.half_angle).map_err(|e| async_graphql::Error::new(e.to_string()))
171    }
172}
173
174impl CapInput {
175    pub fn to_core(&self) -> Result<Cap, async_graphql::Error> {
176        let center = self.center.to_core()?;
177        Cap::new(center, self.half_angle).map_err(|e| async_graphql::Error::new(e.to_string()))
178    }
179}
180
181impl ShellInput {
182    pub fn to_core(&self) -> Result<Shell, async_graphql::Error> {
183        Shell::new(self.inner, self.outer).map_err(|e| async_graphql::Error::new(e.to_string()))
184    }
185}
186
187impl BandInput {
188    pub fn to_core(&self) -> Result<Band, async_graphql::Error> {
189        Band::new(self.phi_min, self.phi_max).map_err(|e| async_graphql::Error::new(e.to_string()))
190    }
191}
192
193impl WedgeInput {
194    pub fn to_core(&self) -> Result<Wedge, async_graphql::Error> {
195        Wedge::new(self.theta_min, self.theta_max)
196            .map_err(|e| async_graphql::Error::new(e.to_string()))
197    }
198}
199
200impl RegionInput {
201    pub fn to_core(&self) -> Result<Region, async_graphql::Error> {
202        let set_count = [
203            self.cone.is_some(),
204            self.cap.is_some(),
205            self.shell.is_some(),
206            self.band.is_some(),
207            self.wedge.is_some(),
208            self.intersection.is_some(),
209            self.union.is_some(),
210        ]
211        .iter()
212        .filter(|&&v| v)
213        .count();
214
215        if set_count == 0 {
216            return Err(async_graphql::Error::new(
217                "RegionInput must have exactly one field set, but none were provided",
218            ));
219        }
220        if set_count > 1 {
221            return Err(async_graphql::Error::new(format!(
222                "RegionInput must have exactly one field set, but {set_count} were provided",
223            )));
224        }
225
226        if let Some(cone) = &self.cone {
227            return Ok(Region::Cone(cone.to_core()?));
228        }
229        if let Some(cap) = &self.cap {
230            return Ok(Region::Cap(cap.to_core()?));
231        }
232        if let Some(shell) = &self.shell {
233            return Ok(Region::Shell(shell.to_core()?));
234        }
235        if let Some(band) = &self.band {
236            return Ok(Region::Band(band.to_core()?));
237        }
238        if let Some(wedge) = &self.wedge {
239            return Ok(Region::Wedge(wedge.to_core()?));
240        }
241        if let Some(regions) = &self.intersection {
242            let converted: Result<Vec<Region>, _> = regions.iter().map(|r| r.to_core()).collect();
243            return Ok(Region::Intersection(converted?));
244        }
245        if let Some(regions) = &self.union {
246            let converted: Result<Vec<Region>, _> = regions.iter().map(|r| r.to_core()).collect();
247            return Ok(Region::Union(converted?));
248        }
249
250        unreachable!()
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};

    /// Shorthand for building a point input.
    fn sp_input(r: f64, theta: f64, phi: f64) -> SphericalPointInput {
        SphericalPointInput { r, theta, phi }
    }

    /// A region with no variant selected; tests fill in the fields they
    /// need with struct-update syntax.
    fn blank_region() -> RegionInput {
        RegionInput {
            cone: None,
            cap: None,
            shell: None,
            band: None,
            wedge: None,
            intersection: None,
            union: None,
        }
    }

    /// Assert that two floats differ by strictly less than `tolerance`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Input -> core -> output keeps the coordinates and derives degrees.
    #[test]
    fn spherical_point_input_to_core_roundtrip() {
        let core = sp_input(2.0, 1.0, FRAC_PI_4).to_core().unwrap();
        assert_close(core.r, 2.0, 1e-12);
        assert_close(core.theta, 1.0, 1e-12);
        assert_close(core.phi, FRAC_PI_4, 1e-12);

        let output = SphericalPointOutput::from(&core);
        assert_close(output.r, 2.0, 1e-12);
        assert_close(output.theta, 1.0, 1e-12);
        assert_close(output.phi, FRAC_PI_4, 1e-12);
        assert_close(output.theta_degrees, 1.0_f64.to_degrees(), 1e-9);
        assert_close(output.phi_degrees, FRAC_PI_4.to_degrees(), 1e-9);
    }

    /// A region with only `cone` set converts to `Region::Cone`.
    #[test]
    fn region_input_cone_converts() {
        let region = RegionInput {
            cone: Some(ConeInput {
                apex: sp_input(0.0, 0.0, 0.0),
                axis: sp_input(1.0, 0.5, FRAC_PI_2),
                half_angle: FRAC_PI_4,
            }),
            ..blank_region()
        };

        let core = region.to_core().unwrap();
        assert!(matches!(core, Region::Cone(_)));
    }

    /// Nested regions inside `intersection` are converted in order.
    #[test]
    fn region_input_intersection_recursive() {
        let shell_region = RegionInput {
            shell: Some(ShellInput {
                inner: 1.0,
                outer: 5.0,
            }),
            ..blank_region()
        };
        let band_region = RegionInput {
            band: Some(BandInput {
                phi_min: FRAC_PI_4,
                phi_max: FRAC_PI_2,
            }),
            ..blank_region()
        };
        let compound = RegionInput {
            intersection: Some(vec![shell_region, band_region]),
            ..blank_region()
        };

        match compound.to_core().unwrap() {
            Region::Intersection(regions) => {
                assert_eq!(regions.len(), 2);
                assert!(matches!(regions[0], Region::Shell(_)));
                assert!(matches!(regions[1], Region::Band(_)));
            }
            other => panic!("expected Intersection, got {other:?}"),
        }
    }

    /// Each kind of malformed input surfaces as an error, not a panic.
    #[test]
    fn invalid_inputs_produce_errors() {
        // Negative radius.
        assert!(sp_input(-1.0, 0.0, 0.0).to_core().is_err());

        // Inner bound larger than outer bound.
        let bad_shell = ShellInput {
            inner: 5.0,
            outer: 1.0,
        };
        assert!(bad_shell.to_core().is_err());

        // Lower phi bound larger than upper phi bound.
        let bad_band = BandInput {
            phi_min: PI,
            phi_max: 0.1,
        };
        assert!(bad_band.to_core().is_err());

        // No variant selected at all.
        assert!(blank_region().to_core().is_err());

        // Two variants selected at once.
        let multi_region = RegionInput {
            cone: Some(ConeInput {
                apex: sp_input(0.0, 0.0, 0.0),
                axis: sp_input(1.0, 0.0, FRAC_PI_2),
                half_angle: FRAC_PI_4,
            }),
            shell: Some(ShellInput {
                inner: 1.0,
                outer: 5.0,
            }),
            ..blank_region()
        };
        assert!(multi_region.to_core().is_err());
    }
}