1use std::sync::Arc;
2
3use sphereql_embed::pipeline::{PipelineError, SphereQLPipeline};
4use sphereql_embed::text_embedder::{NoEmbedder, TextEmbedder};
5use sphereql_index::SpatialIndexBuilder;
6
7use crate::query::{PointIndex, SphericalQueryRoot};
8use crate::subscription::{SpatialEventBus, SphericalSubscriptionRoot};
9use crate::types::{CategorizedItemInput, items_to_pipeline_input};
10
/// Concrete GraphQL schema type exposed by this module.
///
/// Queries are served by [`SphericalQueryRoot`], subscriptions by
/// [`SphericalSubscriptionRoot`], and there is no mutation root
/// (`async_graphql::EmptyMutation`).
pub type SphericalSchema = async_graphql::Schema<
    SphericalQueryRoot,
    async_graphql::EmptyMutation,
    SphericalSubscriptionRoot,
>;
16
17pub fn build_schema(index: PointIndex, event_bus: SpatialEventBus) -> SphericalSchema {
18 async_graphql::Schema::build(
19 SphericalQueryRoot,
20 async_graphql::EmptyMutation,
21 SphericalSubscriptionRoot,
22 )
23 .data(index)
24 .data(event_bus)
25 .finish()
26}
27
28pub fn create_default_index() -> PointIndex {
29 Arc::new(tokio::sync::RwLock::new(
30 SpatialIndexBuilder::new()
31 .uniform_shells(5, 10.0)
32 .theta_divisions(12)
33 .phi_divisions(6)
34 .build(),
35 ))
36}
37
38pub fn create_schema_with_defaults() -> SphericalSchema {
39 build_schema(create_default_index(), SpatialEventBus::new(256))
40}
41
/// Shared handle to a [`SphereQLPipeline`]: an `Arc` around a
/// `tokio::sync::RwLock`, so clones of the handle refer to the same pipeline
/// and access goes through the async read/write lock.
pub type CategoryPipelineHandle = Arc<tokio::sync::RwLock<SphereQLPipeline>>;
56
/// Shared, type-erased handle to any [`TextEmbedder`] implementation.
pub type EmbedderHandle = Arc<dyn TextEmbedder>;
59
60pub fn into_pipeline_handle(pipeline: SphereQLPipeline) -> CategoryPipelineHandle {
62 Arc::new(tokio::sync::RwLock::new(pipeline))
63}
64
65pub fn build_pipeline_handle_from_items(
74 items: &[CategorizedItemInput],
75) -> Result<CategoryPipelineHandle, PipelineError> {
76 let input = items_to_pipeline_input(items);
77 let pipeline = SphereQLPipeline::new(input)?;
78 Ok(into_pipeline_handle(pipeline))
79}
80
81pub fn default_no_embedder_handle() -> EmbedderHandle {
86 Arc::new(NoEmbedder)
87}
88
/// Tests for the category-pipeline and embedder handle helpers.
#[cfg(test)]
mod category_context_tests {
    use super::*;

    /// Six hand-written items, alternating between the "science" and
    /// "cooking" categories, each with a 4-dimensional embedding.
    fn synthetic_items() -> Vec<CategorizedItemInput> {
        // The embedding parameter type is inferred from the struct field.
        let make = |id: &str, category: &str, embedding| CategorizedItemInput {
            id: id.into(),
            category: category.into(),
            embedding,
        };

        vec![
            make("a", "science", vec![1.0, 0.1, 0.0, 0.2]),
            make("b", "cooking", vec![0.1, 1.0, 0.0, 0.2]),
            make("c", "science", vec![0.9, 0.2, 0.1, 0.3]),
            make("d", "cooking", vec![0.2, 0.9, 0.1, 0.3]),
            make("e", "science", vec![0.8, 0.3, 0.2, 0.1]),
            make("f", "cooking", vec![0.3, 0.8, 0.2, 0.1]),
        ]
    }

    /// The helper must yield a pipeline holding every input item and
    /// reporting the "pca" projection kind.
    #[tokio::test]
    async fn build_pipeline_handle_from_items_constructs_pipeline() {
        let handle = build_pipeline_handle_from_items(&synthetic_items())
            .expect("pipeline build failed");

        let pipeline = handle.read().await;
        assert_eq!(pipeline.num_items(), 6);
        assert_eq!(pipeline.projection_kind().name(), "pca");
    }

    /// The default handle has no real embedder, so `embed` must return an
    /// error whose message names the missing embedder.
    #[tokio::test]
    async fn no_embedder_handle_errors_on_embed() {
        let embedder = default_no_embedder_handle();
        let message = embedder.embed("hi").unwrap_err().to_string();
        assert!(message.contains("no TextEmbedder configured"));
    }
}
129
/// End-to-end tests that execute GraphQL queries against schemas built by
/// this module.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4};

    use sphereql_core::SphericalPoint;

    use crate::query::PointItem;

    /// Maximum number of characters of SDL echoed in a failure message.
    const SDL_PREVIEW_CHARS: usize = 2000;

    /// Shorthand for a `SphericalPoint` built without validation.
    fn point(r: f64, theta: f64, phi: f64) -> SphericalPoint {
        SphericalPoint::new_unchecked(r, theta, phi)
    }

    /// Shorthand for a `PointItem` with the given id and coordinates.
    fn item(id: &str, r: f64, theta: f64, phi: f64) -> PointItem {
        PointItem {
            id: id.into(),
            position: point(r, theta, phi),
        }
    }

    /// Builds a schema over the default index pre-populated with `items`.
    async fn schema_with_items(items: Vec<PointItem>) -> SphericalSchema {
        let index = create_default_index();
        {
            // Scope the write guard so it is released before the index is
            // handed to the schema.
            let mut idx = index.write().await;
            for it in items {
                idx.insert(it);
            }
        }
        build_schema(index, SpatialEventBus::new(16))
    }

    /// A cone query must return at least the two items near its axis.
    #[tokio::test]
    async fn test_within_cone_query() {
        let schema = schema_with_items(vec![
            // "a" and "b" sit close to the cone axis direction used below;
            // "c" points in a clearly different direction.
            item("a", 1.0, 0.5, FRAC_PI_4),
            item("b", 1.0, 0.6, FRAC_PI_4 + 0.1),
            item("c", 1.0, 2.5, FRAC_PI_2 + 1.0),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ withinCone(
                    cone: {
                        apex: { r: 0.0, theta: 0.0, phi: 0.0 },
                        axis: { r: 1.0, theta: 0.5, phi: 0.7854 },
                        halfAngle: 0.5
                    }
                ) { items { r theta phi } totalScanned } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinCone"]["items"].as_array().unwrap();
        assert!(
            items.len() >= 2,
            "expected at least 2 items in cone, got {}",
            items.len()
        );
    }

    /// A shell query must return only the item whose radius lies in range.
    #[tokio::test]
    async fn test_within_shell_query() {
        let schema = schema_with_items(vec![
            // Same direction, three different radii: only r = 3 is in [2, 5].
            item("near", 1.0, 0.5, FRAC_PI_4),
            item("mid", 3.0, 0.5, FRAC_PI_4),
            item("far", 7.0, 0.5, FRAC_PI_4),
        ])
        .await;

        let res = schema
            .execute(r#"{ withinShell(shell: { inner: 2.0, outer: 5.0 }) { items { r } totalScanned } }"#)
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinShell"]["items"].as_array().unwrap();
        assert_eq!(items.len(), 1, "only the r=3 item should be in shell [2,5]");
        let r = items[0]["r"].as_f64().unwrap();
        assert!((r - 3.0).abs() < 1e-6);
    }

    /// `nearestTo` must honour `k` and order results by ascending distance.
    #[tokio::test]
    async fn test_nearest_to_query() {
        let schema = schema_with_items(vec![
            item("a", 1.0, 0.5, FRAC_PI_4),
            item("b", 1.0, 0.6, FRAC_PI_4 + 0.1),
            item("c", 1.0, 2.0, FRAC_PI_2),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ nearestTo(
                    point: { r: 1.0, theta: 0.5, phi: 0.7854 },
                    k: 2
                ) { point { r theta phi } distance } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let results = data["nearestTo"].as_array().unwrap();
        assert_eq!(results.len(), 2, "expected 2 nearest results");

        let d0 = results[0]["distance"].as_f64().unwrap();
        let d1 = results[1]["distance"].as_f64().unwrap();
        assert!(d0 <= d1, "results should be sorted by distance ascending");
    }

    /// `distanceBetween` needs no indexed data and must be positive for
    /// two distinct points.
    #[tokio::test]
    async fn test_distance_between_query() {
        let schema = create_schema_with_defaults();

        let res = schema
            .execute(
                r#"{ distanceBetween(
                    a: { r: 1.0, theta: 0.0, phi: 0.7854 },
                    b: { r: 1.0, theta: 1.5, phi: 1.2 }
                ) }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let distance = data["distanceBetween"].as_f64().unwrap();
        assert!(
            distance > 0.0,
            "distance between distinct points should be positive"
        );
    }

    /// A region query with a shell must return only the item inside it.
    #[tokio::test]
    async fn test_within_region_query() {
        let schema = schema_with_items(vec![
            item("in_shell", 3.0, 0.5, FRAC_PI_4),
            item("out_shell", 8.0, 0.5, FRAC_PI_4),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ withinRegion(
                    region: { shell: { inner: 2.0, outer: 5.0 } }
                ) { items { r } totalScanned } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinRegion"]["items"].as_array().unwrap();
        assert_eq!(items.len(), 1);
        let r = items[0]["r"].as_f64().unwrap();
        assert!((r - 3.0).abs() < 1e-6);
    }

    /// Queries against an empty index must succeed and return no results.
    #[tokio::test]
    async fn test_empty_index_queries() {
        let schema = create_schema_with_defaults();

        let cone_res = schema
            .execute(
                r#"{ withinCone(
                    cone: {
                        apex: { r: 0.0, theta: 0.0, phi: 0.0 },
                        axis: { r: 1.0, theta: 0.5, phi: 0.7854 },
                        halfAngle: 1.0
                    }
                ) { items { r } totalScanned } }"#,
            )
            .await;
        assert!(
            cone_res.errors.is_empty(),
            "cone errors: {:?}",
            cone_res.errors
        );
        let data = cone_res.data.into_json().unwrap();
        let items = data["withinCone"]["items"].as_array().unwrap();
        assert!(items.is_empty());

        let nearest_res = schema
            .execute(
                r#"{ nearestTo(
                    point: { r: 1.0, theta: 0.5, phi: 0.7854 },
                    k: 5
                ) { point { r } distance } }"#,
            )
            .await;
        assert!(
            nearest_res.errors.is_empty(),
            "nearest errors: {:?}",
            nearest_res.errors
        );
        let data = nearest_res.data.into_json().unwrap();
        let results = data["nearestTo"].as_array().unwrap();
        assert!(results.is_empty());
    }

    /// The generated SDL must mention the core output/input types and
    /// query fields.
    #[tokio::test]
    async fn test_schema_sdl_contains_expected_types() {
        let schema = create_schema_with_defaults();
        let sdl = schema.sdl();

        let expected = [
            "SphericalPointOutput",
            "SpatialQueryResultOutput",
            "NearestResultOutput",
            "withinCone",
            "withinShell",
            "nearestTo",
            "distanceBetween",
            "SphericalPointInput",
        ];

        for token in &expected {
            assert!(
                sdl.contains(token),
                "SDL missing expected token '{}'. SDL:\n{}",
                token,
                // Truncate by characters, not bytes: byte-slicing the SDL
                // would panic if the cut fell inside a multi-byte character,
                // hiding the real failure message.
                sdl.chars().take(SDL_PREVIEW_CHARS).collect::<String>(),
            );
        }
    }
}