1use sphereql_core::{Band, Cap, Cone, Region, Shell, SphericalPoint, Wedge};
2use sphereql_embed::pipeline::PipelineInput;
3
/// A spherical point as returned to GraphQL clients.
///
/// `theta` and `phi` are the raw angles of the core point. When built via
/// `From<&SphericalPoint>`, `theta_degrees` / `phi_degrees` are those same
/// angles passed through `f64::to_degrees`, so `theta` / `phi` are radians.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct SphericalPointOutput {
    /// Radial coordinate.
    pub r: f64,
    /// Angle `theta`, in radians.
    pub theta: f64,
    /// Angle `phi`, in radians.
    pub phi: f64,
    /// `theta` converted to degrees.
    pub theta_degrees: f64,
    /// `phi` converted to degrees.
    pub phi_degrees: f64,
}
14
/// A point expressed in Cartesian `(x, y, z)` coordinates, for GraphQL output.
// NOTE(review): the axis convention relative to `theta`/`phi` is fixed by the
// code that builds this value, which is not in this file — confirm there.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct CartesianPointOutput {
    /// X coordinate.
    pub x: f64,
    /// Y coordinate.
    pub y: f64,
    /// Z coordinate.
    pub z: f64,
}
21
/// A point expressed as latitude / longitude / altitude, for GraphQL output.
// NOTE(review): angle units (degrees vs radians) and the altitude reference are
// chosen by whatever constructs this value (not visible here) — confirm there.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct GeoPointOutput {
    /// Latitude.
    pub lat: f64,
    /// Longitude.
    pub lon: f64,
    /// Altitude.
    pub alt: f64,
}
28
/// Result of a spatial query: the returned points plus a scan counter.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct SpatialQueryResultOutput {
    /// Points returned by the query.
    pub items: Vec<SphericalPointOutput>,
    /// Scan counter reported by the query; `i32` maps to GraphQL `Int`.
    // NOTE(review): presumably the number of candidate points examined — confirm
    // in the resolver, and check that the usize -> i32 conversion cannot truncate.
    pub total_scanned: i32,
}
34
/// One hit from a nearest-neighbour style query: a point and its distance.
// NOTE(review): which metric `distance` is measured in is decided by the
// resolver that fills this in — confirm there.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct NearestResultOutput {
    /// The matched point.
    pub point: SphericalPointOutput,
    /// Distance from the query point to `point`.
    pub distance: f64,
}
40
/// Distances between two points reported under several measures at once.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct DistanceResultOutput {
    /// Angular separation.
    // NOTE(review): presumably radians — confirm in the resolver.
    pub angular: f64,
    /// Great-circle distance, when the resolver provides one.
    // NOTE(review): the condition under which this is `None` is not visible in
    // this file — document it once confirmed in the resolver.
    pub great_circle: Option<f64>,
    /// Chord distance.
    pub chord: f64,
}
47
/// A spherical point supplied by a GraphQL client.
///
/// The raw values are not checked until `SphericalPointInput::to_core` passes
/// them to `SphericalPoint::new`, which rejects invalid input such as a
/// negative `r`.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct SphericalPointInput {
    /// Radial coordinate.
    pub r: f64,
    /// Angle `theta`, in radians.
    pub theta: f64,
    /// Angle `phi`, in radians.
    pub phi: f64,
}
56
/// A cone region described by an apex point, an axis point and a half-angle.
///
/// Validated by `Cone::new` when converted with `ConeInput::to_core`.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct ConeInput {
    /// Apex of the cone.
    pub apex: SphericalPointInput,
    /// Axis of the cone, given as a spherical point.
    // NOTE(review): presumably only the direction of this point matters —
    // confirm against `Cone::new`.
    pub axis: SphericalPointInput,
    /// Half-opening angle of the cone, in radians.
    pub half_angle: f64,
}
63
/// A cap region described by a center point and a half-angle.
///
/// Validated by `Cap::new` when converted with `CapInput::to_core`.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct CapInput {
    /// Center of the cap.
    pub center: SphericalPointInput,
    /// Half-angle of the cap, in radians.
    pub half_angle: f64,
}
69
/// A shell region bounded by an inner and an outer radius.
///
/// Validated by `Shell::new` when converted with `ShellInput::to_core`; an
/// input whose `inner` exceeds `outer` is rejected.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct ShellInput {
    /// Inner radius.
    pub inner: f64,
    /// Outer radius.
    pub outer: f64,
}
75
/// A band region bounded by a `phi` range, in radians.
///
/// Validated by `Band::new` when converted with `BandInput::to_core`; for
/// example `phi_min = π` with `phi_max = 0.1` is rejected.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct BandInput {
    /// Lower bound on `phi`.
    pub phi_min: f64,
    /// Upper bound on `phi`.
    pub phi_max: f64,
}
81
/// A wedge region bounded by a `theta` range, in radians.
///
/// Validated by `Wedge::new` when converted with `WedgeInput::to_core`.
// NOTE(review): whether a wedge may wrap around (theta_min > theta_max) is
// decided by `Wedge::new` — confirm there.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct WedgeInput {
    /// Lower bound on `theta`.
    pub theta_min: f64,
    /// Upper bound on `theta`.
    pub theta_max: f64,
}
87
/// A region description supplied by a GraphQL client.
///
/// Exactly one field is expected to be set. `RegionInput::to_core` enforces
/// this at conversion time and rejects inputs with zero or several fields set.
/// `intersection` and `union` hold nested `RegionInput`s that are converted
/// recursively.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct RegionInput {
    /// A cone region.
    pub cone: Option<ConeInput>,
    /// A cap region.
    pub cap: Option<CapInput>,
    /// A shell region between two radii.
    pub shell: Option<ShellInput>,
    /// A band region bounded in `phi`.
    pub band: Option<BandInput>,
    /// A wedge region bounded in `theta`.
    pub wedge: Option<WedgeInput>,
    /// Child regions combined as an intersection.
    pub intersection: Option<Vec<RegionInput>>,
    /// Child regions combined as a union.
    pub union: Option<Vec<RegionInput>>,
}
98
/// An embedding vector tagged with an identifier and a category label, as
/// supplied by a GraphQL client.
///
/// `items_to_pipeline_input` forwards `category` and `embedding` to the
/// embedding pipeline; `id` is not forwarded there.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct CategorizedItemInput {
    /// Identifier supplied by the client.
    pub id: String,
    /// Category label.
    pub category: String,
    /// Embedding vector.
    // NOTE(review): nothing in this file checks that all items share the same
    // embedding length — confirm the pipeline validates it.
    pub embedding: Vec<f64>,
}
117
/// Output counterpart of `CategorizedItemInput`, with the same three fields.
///
/// The `From<&CategorizedItemInput>` impl builds it by cloning each field.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct CategorizedItemOutput {
    /// Identifier of the item.
    pub id: String,
    /// Category label.
    pub category: String,
    /// Embedding vector.
    pub embedding: Vec<f64>,
}
126
127impl From<&CategorizedItemInput> for CategorizedItemOutput {
128 fn from(i: &CategorizedItemInput) -> Self {
129 Self {
130 id: i.id.clone(),
131 category: i.category.clone(),
132 embedding: i.embedding.clone(),
133 }
134 }
135}
136
137pub fn items_to_pipeline_input(items: &[CategorizedItemInput]) -> PipelineInput {
149 PipelineInput {
150 categories: items.iter().map(|i| i.category.clone()).collect(),
151 embeddings: items.iter().map(|i| i.embedding.clone()).collect(),
152 }
153}
154
/// Distance measure a client can request.
// NOTE(review): the formula behind each variant lives in the resolver / core
// crate, not in this file. The per-variant notes below are inferred from the
// names only — confirm before promoting them to `///` schema descriptions.
#[derive(async_graphql::Enum, Copy, Clone, Eq, PartialEq, Debug)]
pub enum DistanceMetric {
    // Presumably the central angle between the two points.
    Angular,
    // Presumably arc length along a great circle.
    GreatCircle,
    // Presumably straight-line chord length.
    Chord,
    // Presumably Euclidean distance between Cartesian positions.
    Euclidean,
}
164
165impl From<&SphericalPoint> for SphericalPointOutput {
168 fn from(p: &SphericalPoint) -> Self {
169 Self {
170 r: p.r,
171 theta: p.theta,
172 phi: p.phi,
173 theta_degrees: p.theta.to_degrees(),
174 phi_degrees: p.phi.to_degrees(),
175 }
176 }
177}
178
179impl SphericalPointInput {
180 pub fn to_core(&self) -> Result<SphericalPoint, async_graphql::Error> {
181 SphericalPoint::new(self.r, self.theta, self.phi)
182 .map_err(|e| async_graphql::Error::new(e.to_string()))
183 }
184}
185
186impl ConeInput {
187 pub fn to_core(&self) -> Result<Cone, async_graphql::Error> {
188 let apex = self.apex.to_core()?;
189 let axis = self.axis.to_core()?;
190 Cone::new(apex, axis, self.half_angle).map_err(|e| async_graphql::Error::new(e.to_string()))
191 }
192}
193
194impl CapInput {
195 pub fn to_core(&self) -> Result<Cap, async_graphql::Error> {
196 let center = self.center.to_core()?;
197 Cap::new(center, self.half_angle).map_err(|e| async_graphql::Error::new(e.to_string()))
198 }
199}
200
201impl ShellInput {
202 pub fn to_core(&self) -> Result<Shell, async_graphql::Error> {
203 Shell::new(self.inner, self.outer).map_err(|e| async_graphql::Error::new(e.to_string()))
204 }
205}
206
207impl BandInput {
208 pub fn to_core(&self) -> Result<Band, async_graphql::Error> {
209 Band::new(self.phi_min, self.phi_max).map_err(|e| async_graphql::Error::new(e.to_string()))
210 }
211}
212
213impl WedgeInput {
214 pub fn to_core(&self) -> Result<Wedge, async_graphql::Error> {
215 Wedge::new(self.theta_min, self.theta_max)
216 .map_err(|e| async_graphql::Error::new(e.to_string()))
217 }
218}
219
220impl RegionInput {
221 pub fn to_core(&self) -> Result<Region, async_graphql::Error> {
222 let set_count = [
223 self.cone.is_some(),
224 self.cap.is_some(),
225 self.shell.is_some(),
226 self.band.is_some(),
227 self.wedge.is_some(),
228 self.intersection.is_some(),
229 self.union.is_some(),
230 ]
231 .iter()
232 .filter(|&&v| v)
233 .count();
234
235 if set_count == 0 {
236 return Err(async_graphql::Error::new(
237 "RegionInput must have exactly one field set, but none were provided",
238 ));
239 }
240 if set_count > 1 {
241 return Err(async_graphql::Error::new(format!(
242 "RegionInput must have exactly one field set, but {set_count} were provided",
243 )));
244 }
245
246 if let Some(cone) = &self.cone {
247 return Ok(Region::Cone(cone.to_core()?));
248 }
249 if let Some(cap) = &self.cap {
250 return Ok(Region::Cap(cap.to_core()?));
251 }
252 if let Some(shell) = &self.shell {
253 return Ok(Region::Shell(shell.to_core()?));
254 }
255 if let Some(band) = &self.band {
256 return Ok(Region::Band(band.to_core()?));
257 }
258 if let Some(wedge) = &self.wedge {
259 return Ok(Region::Wedge(wedge.to_core()?));
260 }
261 if let Some(regions) = &self.intersection {
262 let converted: Result<Vec<Region>, _> = regions.iter().map(|r| r.to_core()).collect();
263 return Ok(Region::Intersection(converted?));
264 }
265 if let Some(regions) = &self.union {
266 let converted: Result<Vec<Region>, _> = regions.iter().map(|r| r.to_core()).collect();
267 return Ok(Region::Union(converted?));
268 }
269
270 unreachable!()
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};

    /// Shorthand for building a point input.
    fn sp_input(r: f64, theta: f64, phi: f64) -> SphericalPointInput {
        SphericalPointInput { r, theta, phi }
    }

    /// A `RegionInput` with every field unset. Tests set the one field they
    /// care about with struct-update syntax (`..no_region()`).
    fn no_region() -> RegionInput {
        RegionInput {
            cone: None,
            cap: None,
            shell: None,
            band: None,
            wedge: None,
            intersection: None,
            union: None,
        }
    }

    /// A shell region that `Shell::new` accepts.
    fn shell_region() -> RegionInput {
        RegionInput {
            shell: Some(ShellInput {
                inner: 1.0,
                outer: 5.0,
            }),
            ..no_region()
        }
    }

    /// A band region that `Band::new` accepts.
    fn band_region() -> RegionInput {
        RegionInput {
            band: Some(BandInput {
                phi_min: FRAC_PI_4,
                phi_max: FRAC_PI_2,
            }),
            ..no_region()
        }
    }

    /// A cone region that `Cone::new` accepts.
    fn cone_region() -> RegionInput {
        RegionInput {
            cone: Some(ConeInput {
                apex: sp_input(0.0, 0.0, 0.0),
                axis: sp_input(1.0, 0.5, FRAC_PI_2),
                half_angle: FRAC_PI_4,
            }),
            ..no_region()
        }
    }

    #[test]
    fn spherical_point_input_to_core_roundtrip() {
        let input = sp_input(2.0, 1.0, FRAC_PI_4);
        let core = input.to_core().unwrap();
        assert!((core.r - 2.0).abs() < 1e-12);
        assert!((core.theta - 1.0).abs() < 1e-12);
        assert!((core.phi - FRAC_PI_4).abs() < 1e-12);

        let output = SphericalPointOutput::from(&core);
        assert!((output.r - 2.0).abs() < 1e-12);
        assert!((output.theta - 1.0).abs() < 1e-12);
        assert!((output.phi - FRAC_PI_4).abs() < 1e-12);
        assert!((output.theta_degrees - 1.0_f64.to_degrees()).abs() < 1e-9);
        assert!((output.phi_degrees - FRAC_PI_4.to_degrees()).abs() < 1e-9);
    }

    #[test]
    fn region_input_cone_converts() {
        let core = cone_region().to_core().unwrap();
        assert!(matches!(core, Region::Cone(_)));
    }

    #[test]
    fn region_input_intersection_recursive() {
        let compound = RegionInput {
            intersection: Some(vec![shell_region(), band_region()]),
            ..no_region()
        };

        match compound.to_core().unwrap() {
            Region::Intersection(regions) => {
                assert_eq!(regions.len(), 2);
                assert!(matches!(regions[0], Region::Shell(_)));
                assert!(matches!(regions[1], Region::Band(_)));
            }
            other => panic!("expected Intersection, got {other:?}"),
        }
    }

    #[test]
    fn region_input_union_recursive() {
        let compound = RegionInput {
            union: Some(vec![band_region(), shell_region()]),
            ..no_region()
        };

        match compound.to_core().unwrap() {
            Region::Union(regions) => {
                assert_eq!(regions.len(), 2);
                assert!(matches!(regions[0], Region::Band(_)));
                assert!(matches!(regions[1], Region::Shell(_)));
            }
            other => panic!("expected Union, got {other:?}"),
        }
    }

    #[test]
    fn nested_region_errors_propagate() {
        // An invalid shape nested inside an intersection fails the whole input.
        let bad_child = RegionInput {
            shell: Some(ShellInput {
                inner: 5.0,
                outer: 1.0,
            }),
            ..no_region()
        };
        let intersection = RegionInput {
            intersection: Some(vec![shell_region(), bad_child]),
            ..no_region()
        };
        assert!(intersection.to_core().is_err());

        // So does a child that violates the exactly-one-field rule.
        let union = RegionInput {
            union: Some(vec![band_region(), no_region()]),
            ..no_region()
        };
        assert!(union.to_core().is_err());
    }

    #[test]
    fn invalid_inputs_produce_errors() {
        let bad_point = sp_input(-1.0, 0.0, 0.0);
        assert!(bad_point.to_core().is_err());

        let bad_shell = ShellInput {
            inner: 5.0,
            outer: 1.0,
        };
        assert!(bad_shell.to_core().is_err());

        let bad_band = BandInput {
            phi_min: PI,
            phi_max: 0.1,
        };
        assert!(bad_band.to_core().is_err());

        assert!(no_region().to_core().is_err());

        let multi_region = RegionInput {
            shell: shell_region().shell,
            ..cone_region()
        };
        assert!(multi_region.to_core().is_err());
    }

    #[test]
    fn region_field_count_errors_are_descriptive() {
        let none_err = no_region().to_core().unwrap_err();
        assert!(
            none_err.message.contains("none were provided"),
            "unexpected message: {}",
            none_err.message
        );

        let multi_region = RegionInput {
            shell: shell_region().shell,
            ..cone_region()
        };
        let multi_err = multi_region.to_core().unwrap_err();
        assert!(
            multi_err.message.contains("2 were provided"),
            "unexpected message: {}",
            multi_err.message
        );
    }

    #[test]
    fn categorized_items_convert() {
        let items = [
            CategorizedItemInput {
                id: "a".to_string(),
                category: "x".to_string(),
                embedding: vec![1.0, 2.0],
            },
            CategorizedItemInput {
                id: "b".to_string(),
                category: "y".to_string(),
                embedding: vec![3.0],
            },
        ];

        let echoed = CategorizedItemOutput::from(&items[0]);
        assert_eq!(echoed.id, "a");
        assert_eq!(echoed.category, "x");
        assert_eq!(echoed.embedding, vec![1.0, 2.0]);

        // Columns stay aligned with the input order.
        let pipeline = items_to_pipeline_input(&items);
        assert_eq!(pipeline.categories, vec!["x".to_string(), "y".to_string()]);
        assert_eq!(pipeline.embeddings, vec![vec![1.0, 2.0], vec![3.0]]);

        let empty = items_to_pipeline_input(&[]);
        assert!(empty.categories.is_empty());
        assert!(empty.embeddings.is_empty());
    }
}