1use sphereql_core::{Band, Cap, Cone, Region, Shell, SphericalPoint, Wedge};
2use sphereql_embed::pipeline::PipelineInput;
3
/// A spherical-coordinate point as returned to GraphQL clients.
///
/// Each angle is exposed twice: in radians (`theta`, `phi`) and converted to
/// degrees (`theta_degrees`, `phi_degrees`).
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct SphericalPointOutput {
    /// Radial coordinate.
    pub r: f64,
    /// Angle `theta`, in radians.
    pub theta: f64,
    /// Angle `phi`, in radians.
    pub phi: f64,
    /// `theta` converted to degrees.
    pub theta_degrees: f64,
    /// `phi` converted to degrees.
    pub phi_degrees: f64,
}
14
/// Result of a spatial query: the matching points plus a scan counter.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct SpatialQueryResultOutput {
    /// Points returned by the query.
    pub items: Vec<SphericalPointOutput>,
    /// Total number of points scanned.
    // NOTE(review): meaning inferred from the field name; confirm against the
    // resolver that fills it. `i32` (not `usize`) is presumably because the
    // GraphQL `Int` scalar is a signed 32-bit integer.
    pub total_scanned: i32,
}
20
/// A single nearest-neighbour hit: the point found and its distance.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct NearestResultOutput {
    /// The point that was found.
    pub point: SphericalPointOutput,
    /// Distance associated with `point`.
    // NOTE(review): presumably measured with the requested `DistanceMetric`
    // from the query point; units depend on that metric. Confirm in the resolver.
    pub distance: f64,
}
26
/// A spherical-coordinate point supplied by a GraphQL client.
///
/// Values are unchecked until `SphericalPointInput::to_core` passes them to
/// `SphericalPoint::new`, which may reject them (e.g. a negative `r`).
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct SphericalPointInput {
    /// Radial coordinate.
    pub r: f64,
    /// Angle `theta`.
    pub theta: f64,
    /// Angle `phi`.
    pub phi: f64,
}
35
/// Client-supplied parameters for a cone region.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct ConeInput {
    /// Apex of the cone.
    pub apex: SphericalPointInput,
    /// Axis of the cone, given as a spherical point.
    pub axis: SphericalPointInput,
    /// Half opening angle of the cone.
    // NOTE(review): presumably radians (tests pass `FRAC_PI_4`); confirm.
    pub half_angle: f64,
}
42
/// Client-supplied parameters for a cap region.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct CapInput {
    /// Centre point of the cap.
    pub center: SphericalPointInput,
    /// Half angle of the cap.
    // NOTE(review): presumably radians, like `ConeInput::half_angle`; confirm.
    pub half_angle: f64,
}
48
/// Client-supplied radial bounds for a shell region.
///
/// `Shell::new` validates the bounds; the tests show `inner > outer` is rejected.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct ShellInput {
    /// Inner radial bound.
    pub inner: f64,
    /// Outer radial bound.
    pub outer: f64,
}
54
/// Client-supplied `phi` bounds for a band region.
///
/// `Band::new` validates the bounds; the tests show `phi_min = PI`,
/// `phi_max = 0.1` is rejected.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct BandInput {
    /// Lower `phi` bound.
    pub phi_min: f64,
    /// Upper `phi` bound.
    pub phi_max: f64,
}
60
/// Client-supplied `theta` bounds for a wedge region.
///
/// Validation is delegated to `Wedge::new`.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct WedgeInput {
    /// Lower `theta` bound.
    pub theta_min: f64,
    /// Upper `theta` bound.
    pub theta_max: f64,
}
66
/// Client-supplied description of a spatial region.
///
/// Used as a tagged union: `RegionInput::to_core` rejects inputs with zero
/// or more than one field set. `intersection` and `union` nest further
/// `RegionInput`s, so compound regions can be built recursively.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct RegionInput {
    /// A cone region.
    pub cone: Option<ConeInput>,
    /// A cap region.
    pub cap: Option<CapInput>,
    /// A radial shell region.
    pub shell: Option<ShellInput>,
    /// A `phi` band region.
    pub band: Option<BandInput>,
    /// A `theta` wedge region.
    pub wedge: Option<WedgeInput>,
    /// Nested regions combined into a `Region::Intersection`.
    pub intersection: Option<Vec<RegionInput>>,
    /// Nested regions combined into a `Region::Union`.
    pub union: Option<Vec<RegionInput>>,
}
77
/// An embedding with its identifier and category label, as sent by a client.
#[derive(async_graphql::InputObject, Debug, Clone)]
pub struct CategorizedItemInput {
    /// Caller-chosen identifier for the item.
    pub id: String,
    /// Category label for the item.
    pub category: String,
    /// Embedding vector for the item.
    // NOTE(review): dimensionality is not checked here; presumably all items
    // in a request must share one length. Confirm with the pipeline.
    pub embedding: Vec<f64>,
}
96
/// Output mirror of `CategorizedItemInput`, returned to GraphQL clients.
#[derive(async_graphql::SimpleObject, Debug, Clone)]
pub struct CategorizedItemOutput {
    /// Identifier copied from the input item.
    pub id: String,
    /// Category label copied from the input item.
    pub category: String,
    /// Embedding vector copied from the input item.
    pub embedding: Vec<f64>,
}
105
106impl From<&CategorizedItemInput> for CategorizedItemOutput {
107 fn from(i: &CategorizedItemInput) -> Self {
108 Self {
109 id: i.id.clone(),
110 category: i.category.clone(),
111 embedding: i.embedding.clone(),
112 }
113 }
114}
115
116pub fn items_to_pipeline_input(items: &[CategorizedItemInput]) -> PipelineInput {
128 PipelineInput {
129 categories: items.iter().map(|i| i.category.clone()).collect(),
130 embeddings: items.iter().map(|i| i.embedding.clone()).collect(),
131 }
132}
133
/// Distance metric a GraphQL client can select.
///
/// `#[non_exhaustive]` forces code in other crates to include a wildcard arm
/// when matching, so new metrics can be added without a breaking change.
#[derive(async_graphql::Enum, Copy, Clone, Eq, PartialEq, Debug)]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Angular distance.
    // NOTE(review): presumably the angle between directions; confirm against
    // the implementation that consumes this enum.
    Angular,
    /// Great-circle distance.
    GreatCircle,
    /// Chord distance.
    Chord,
    /// Euclidean distance.
    Euclidean,
}
144
145impl From<&SphericalPoint> for SphericalPointOutput {
148 fn from(p: &SphericalPoint) -> Self {
149 Self {
150 r: p.r,
151 theta: p.theta,
152 phi: p.phi,
153 theta_degrees: p.theta.to_degrees(),
154 phi_degrees: p.phi.to_degrees(),
155 }
156 }
157}
158
159impl SphericalPointInput {
160 pub fn to_core(&self) -> Result<SphericalPoint, async_graphql::Error> {
161 SphericalPoint::new(self.r, self.theta, self.phi)
162 .map_err(|e| async_graphql::Error::new(e.to_string()))
163 }
164}
165
166impl ConeInput {
167 pub fn to_core(&self) -> Result<Cone, async_graphql::Error> {
168 let apex = self.apex.to_core()?;
169 let axis = self.axis.to_core()?;
170 Cone::new(apex, axis, self.half_angle).map_err(|e| async_graphql::Error::new(e.to_string()))
171 }
172}
173
174impl CapInput {
175 pub fn to_core(&self) -> Result<Cap, async_graphql::Error> {
176 let center = self.center.to_core()?;
177 Cap::new(center, self.half_angle).map_err(|e| async_graphql::Error::new(e.to_string()))
178 }
179}
180
181impl ShellInput {
182 pub fn to_core(&self) -> Result<Shell, async_graphql::Error> {
183 Shell::new(self.inner, self.outer).map_err(|e| async_graphql::Error::new(e.to_string()))
184 }
185}
186
187impl BandInput {
188 pub fn to_core(&self) -> Result<Band, async_graphql::Error> {
189 Band::new(self.phi_min, self.phi_max).map_err(|e| async_graphql::Error::new(e.to_string()))
190 }
191}
192
193impl WedgeInput {
194 pub fn to_core(&self) -> Result<Wedge, async_graphql::Error> {
195 Wedge::new(self.theta_min, self.theta_max)
196 .map_err(|e| async_graphql::Error::new(e.to_string()))
197 }
198}
199
200impl RegionInput {
201 pub fn to_core(&self) -> Result<Region, async_graphql::Error> {
202 let set_count = [
203 self.cone.is_some(),
204 self.cap.is_some(),
205 self.shell.is_some(),
206 self.band.is_some(),
207 self.wedge.is_some(),
208 self.intersection.is_some(),
209 self.union.is_some(),
210 ]
211 .iter()
212 .filter(|&&v| v)
213 .count();
214
215 if set_count == 0 {
216 return Err(async_graphql::Error::new(
217 "RegionInput must have exactly one field set, but none were provided",
218 ));
219 }
220 if set_count > 1 {
221 return Err(async_graphql::Error::new(format!(
222 "RegionInput must have exactly one field set, but {set_count} were provided",
223 )));
224 }
225
226 if let Some(cone) = &self.cone {
227 return Ok(Region::Cone(cone.to_core()?));
228 }
229 if let Some(cap) = &self.cap {
230 return Ok(Region::Cap(cap.to_core()?));
231 }
232 if let Some(shell) = &self.shell {
233 return Ok(Region::Shell(shell.to_core()?));
234 }
235 if let Some(band) = &self.band {
236 return Ok(Region::Band(band.to_core()?));
237 }
238 if let Some(wedge) = &self.wedge {
239 return Ok(Region::Wedge(wedge.to_core()?));
240 }
241 if let Some(regions) = &self.intersection {
242 let converted: Result<Vec<Region>, _> = regions.iter().map(|r| r.to_core()).collect();
243 return Ok(Region::Intersection(converted?));
244 }
245 if let Some(regions) = &self.union {
246 let converted: Result<Vec<Region>, _> = regions.iter().map(|r| r.to_core()).collect();
247 return Ok(Region::Union(converted?));
248 }
249
250 unreachable!()
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};

    /// Shorthand for building a `SphericalPointInput`.
    fn sp_input(r: f64, theta: f64, phi: f64) -> SphericalPointInput {
        SphericalPointInput { r, theta, phi }
    }

    /// A `RegionInput` with every field unset. Used as the base for
    /// struct-update syntax so each test spells out only the field it sets.
    fn empty_region() -> RegionInput {
        RegionInput {
            cone: None,
            cap: None,
            shell: None,
            band: None,
            wedge: None,
            intersection: None,
            union: None,
        }
    }

    /// A region with only `shell` set.
    fn shell_region(inner: f64, outer: f64) -> RegionInput {
        RegionInput {
            shell: Some(ShellInput { inner, outer }),
            ..empty_region()
        }
    }

    /// A region with only `band` set.
    fn band_region(phi_min: f64, phi_max: f64) -> RegionInput {
        RegionInput {
            band: Some(BandInput { phi_min, phi_max }),
            ..empty_region()
        }
    }

    /// A region with only a known-valid `cone` set.
    fn cone_region() -> RegionInput {
        RegionInput {
            cone: Some(ConeInput {
                apex: sp_input(0.0, 0.0, 0.0),
                axis: sp_input(1.0, 0.5, FRAC_PI_2),
                half_angle: FRAC_PI_4,
            }),
            ..empty_region()
        }
    }

    #[test]
    fn spherical_point_input_to_core_roundtrip() {
        let input = sp_input(2.0, 1.0, FRAC_PI_4);
        let core = input.to_core().unwrap();
        assert!((core.r - 2.0).abs() < 1e-12);
        assert!((core.theta - 1.0).abs() < 1e-12);
        assert!((core.phi - FRAC_PI_4).abs() < 1e-12);

        let output = SphericalPointOutput::from(&core);
        assert!((output.r - 2.0).abs() < 1e-12);
        assert!((output.theta - 1.0).abs() < 1e-12);
        assert!((output.phi - FRAC_PI_4).abs() < 1e-12);
        assert!((output.theta_degrees - 1.0_f64.to_degrees()).abs() < 1e-9);
        assert!((output.phi_degrees - FRAC_PI_4.to_degrees()).abs() < 1e-9);
    }

    #[test]
    fn region_input_cone_converts() {
        let core = cone_region().to_core().unwrap();
        assert!(matches!(core, Region::Cone(_)));
    }

    #[test]
    fn region_input_intersection_recursive() {
        let compound = RegionInput {
            intersection: Some(vec![
                shell_region(1.0, 5.0),
                band_region(FRAC_PI_4, FRAC_PI_2),
            ]),
            ..empty_region()
        };

        match compound.to_core().unwrap() {
            Region::Intersection(regions) => {
                assert_eq!(regions.len(), 2);
                assert!(matches!(regions[0], Region::Shell(_)));
                assert!(matches!(regions[1], Region::Band(_)));
            }
            other => panic!("expected Intersection, got {other:?}"),
        }
    }

    #[test]
    fn region_input_union_recursive() {
        let compound = RegionInput {
            union: Some(vec![
                band_region(FRAC_PI_4, FRAC_PI_2),
                shell_region(1.0, 5.0),
            ]),
            ..empty_region()
        };

        match compound.to_core().unwrap() {
            Region::Union(regions) => {
                assert_eq!(regions.len(), 2);
                assert!(matches!(regions[0], Region::Band(_)));
                assert!(matches!(regions[1], Region::Shell(_)));
            }
            other => panic!("expected Union, got {other:?}"),
        }
    }

    #[test]
    fn invalid_inputs_produce_errors() {
        let bad_point = sp_input(-1.0, 0.0, 0.0);
        assert!(bad_point.to_core().is_err());

        let bad_shell = ShellInput {
            inner: 5.0,
            outer: 1.0,
        };
        assert!(bad_shell.to_core().is_err());

        let bad_band = BandInput {
            phi_min: PI,
            phi_max: 0.1,
        };
        assert!(bad_band.to_core().is_err());

        assert!(empty_region().to_core().is_err());

        let multi_region = RegionInput {
            shell: Some(ShellInput {
                inner: 1.0,
                outer: 5.0,
            }),
            ..cone_region()
        };
        assert!(multi_region.to_core().is_err());
    }

    #[test]
    fn nested_invalid_region_propagates_error() {
        // An invalid leaf inside a compound region fails the whole conversion.
        let bad_child = RegionInput {
            intersection: Some(vec![shell_region(1.0, 5.0), shell_region(5.0, 1.0)]),
            ..empty_region()
        };
        assert!(bad_child.to_core().is_err());

        // A nested RegionInput with no field set is rejected as well.
        let empty_child = RegionInput {
            union: Some(vec![empty_region()]),
            ..empty_region()
        };
        assert!(empty_child.to_core().is_err());
    }

    #[test]
    fn region_input_errors_report_field_count() {
        let none_err = empty_region().to_core().unwrap_err();
        assert_eq!(
            none_err.message,
            "RegionInput must have exactly one field set, but none were provided"
        );

        let two_fields = RegionInput {
            shell: Some(ShellInput {
                inner: 1.0,
                outer: 5.0,
            }),
            ..cone_region()
        };
        assert_eq!(
            two_fields.to_core().unwrap_err().message,
            "RegionInput must have exactly one field set, but 2 were provided"
        );

        let three_fields = RegionInput {
            shell: Some(ShellInput {
                inner: 1.0,
                outer: 5.0,
            }),
            band: Some(BandInput {
                phi_min: FRAC_PI_4,
                phi_max: FRAC_PI_2,
            }),
            ..cone_region()
        };
        assert_eq!(
            three_fields.to_core().unwrap_err().message,
            "RegionInput must have exactly one field set, but 3 were provided"
        );
    }
}