1use std::sync::Arc;
2
3use sphereql_embed::pipeline::{PipelineError, SphereQLPipeline};
4use sphereql_embed::text_embedder::{NoEmbedder, TextEmbedder};
5use sphereql_index::SpatialIndexBuilder;
6
7use crate::query::{PointIndex, SphericalQueryRoot};
8use crate::subscription::{SpatialEventBus, SphericalSubscriptionRoot};
9use crate::types::{CategorizedItemInput, items_to_pipeline_input};
10
11pub type SphericalSchema = async_graphql::Schema<
12 SphericalQueryRoot,
13 async_graphql::EmptyMutation,
14 SphericalSubscriptionRoot,
15>;
16
17pub fn build_schema(index: PointIndex, event_bus: SpatialEventBus) -> SphericalSchema {
18 async_graphql::Schema::build(
19 SphericalQueryRoot,
20 async_graphql::EmptyMutation,
21 SphericalSubscriptionRoot,
22 )
23 .limit_depth(crate::DEFAULT_MAX_DEPTH)
24 .limit_complexity(crate::DEFAULT_MAX_COMPLEXITY)
25 .data(index)
26 .data(event_bus)
27 .finish()
28}
29
30pub fn create_default_index() -> PointIndex {
31 Arc::new(tokio::sync::RwLock::new(
32 SpatialIndexBuilder::new()
33 .uniform_shells(5, 10.0)
34 .theta_divisions(12)
35 .phi_divisions(6)
36 .build(),
37 ))
38}
39
40pub fn create_schema_with_defaults() -> SphericalSchema {
41 build_schema(create_default_index(), SpatialEventBus::new(256))
42}
43
44pub type CategoryPipelineHandle = Arc<tokio::sync::RwLock<SphereQLPipeline>>;
58
59pub type EmbedderHandle = Arc<dyn TextEmbedder>;
61
62pub fn into_pipeline_handle(pipeline: SphereQLPipeline) -> CategoryPipelineHandle {
64 Arc::new(tokio::sync::RwLock::new(pipeline))
65}
66
67pub fn build_pipeline_handle_from_items(
76 items: &[CategorizedItemInput],
77) -> Result<CategoryPipelineHandle, PipelineError> {
78 let input = items_to_pipeline_input(items);
79 let pipeline = SphereQLPipeline::new(input)?;
80 Ok(into_pipeline_handle(pipeline))
81}
82
83pub fn default_no_embedder_handle() -> EmbedderHandle {
88 Arc::new(NoEmbedder)
89}
90
91#[cfg(test)]
92mod category_context_tests {
93 use super::*;
94
95 fn synthetic_items() -> Vec<CategorizedItemInput> {
96 let pairs = [
97 ("a", "science", vec![1.0, 0.1, 0.0, 0.2]),
98 ("b", "cooking", vec![0.1, 1.0, 0.0, 0.2]),
99 ("c", "science", vec![0.9, 0.2, 0.1, 0.3]),
100 ("d", "cooking", vec![0.2, 0.9, 0.1, 0.3]),
101 ("e", "science", vec![0.8, 0.3, 0.2, 0.1]),
102 ("f", "cooking", vec![0.3, 0.8, 0.2, 0.1]),
103 ];
104 pairs
105 .into_iter()
106 .map(|(id, cat, emb)| CategorizedItemInput {
107 id: id.into(),
108 category: cat.into(),
109 embedding: emb,
110 })
111 .collect()
112 }
113
114 #[tokio::test]
115 async fn build_pipeline_handle_from_items_constructs_pipeline() {
116 let items = synthetic_items();
117 let handle = build_pipeline_handle_from_items(&items).expect("pipeline build failed");
118 let read = handle.read().await;
119 assert_eq!(read.num_items(), 6);
120 assert_eq!(read.projection_kind().name(), "pca");
122 }
123
124 #[tokio::test]
125 async fn no_embedder_handle_errors_on_embed() {
126 let h = default_no_embedder_handle();
127 let err = h.embed("hi").unwrap_err();
128 assert!(err.to_string().contains("no TextEmbedder configured"));
129 }
130}
131
132#[cfg(test)]
133mod tests {
134 use super::*;
135 use std::f64::consts::{FRAC_PI_2, FRAC_PI_4};
136
137 use sphereql_core::SphericalPoint;
138
139 use crate::query::PointItem;
140
141 fn point(r: f64, theta: f64, phi: f64) -> SphericalPoint {
142 SphericalPoint::new_unchecked(r, theta, phi)
143 }
144
145 fn item(id: &str, r: f64, theta: f64, phi: f64) -> PointItem {
146 PointItem {
147 id: id.into(),
148 position: point(r, theta, phi),
149 }
150 }
151
152 async fn schema_with_items(items: Vec<PointItem>) -> SphericalSchema {
153 let index = create_default_index();
154 {
155 let mut idx = index.write().await;
156 for it in items {
157 idx.insert(it);
158 }
159 }
160 build_schema(index, SpatialEventBus::new(16))
161 }
162
163 #[tokio::test]
164 async fn test_within_cone_query() {
165 let schema = schema_with_items(vec![
167 item("a", 1.0, 0.5, FRAC_PI_4),
168 item("b", 1.0, 0.6, FRAC_PI_4 + 0.1),
169 item("c", 1.0, 2.5, FRAC_PI_2 + 1.0),
170 ])
171 .await;
172
173 let res = schema
174 .execute(
175 r#"{ withinCone(
176 cone: {
177 apex: { r: 0.0, theta: 0.0, phi: 0.0 },
178 axis: { r: 1.0, theta: 0.5, phi: 0.7854 },
179 halfAngle: 0.5
180 }
181 ) { items { r theta phi } totalScanned } }"#,
182 )
183 .await;
184
185 assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
186 let data = res.data.into_json().unwrap();
187 let items = data["withinCone"]["items"].as_array().unwrap();
188 assert_eq!(
189 items.len(),
190 2,
191 "expected exactly 2 items in cone, got {}",
192 items.len()
193 );
194 }
195
196 #[tokio::test]
197 async fn test_within_shell_query() {
198 let schema = schema_with_items(vec![
200 item("near", 1.0, 0.5, FRAC_PI_4),
201 item("mid", 3.0, 0.5, FRAC_PI_4),
202 item("far", 7.0, 0.5, FRAC_PI_4),
203 ])
204 .await;
205
206 let res = schema
207 .execute(r#"{ withinShell(shell: { inner: 2.0, outer: 5.0 }) { items { r } totalScanned } }"#)
208 .await;
209
210 assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
211 let data = res.data.into_json().unwrap();
212 let items = data["withinShell"]["items"].as_array().unwrap();
213 assert_eq!(items.len(), 1, "only the r=3 item should be in shell [2,5]");
214 let r = items[0]["r"].as_f64().unwrap();
215 assert!((r - 3.0).abs() < 1e-6);
216 }
217
218 #[tokio::test]
219 async fn test_nearest_to_query() {
220 let schema = schema_with_items(vec![
221 item("a", 1.0, 0.5, FRAC_PI_4),
222 item("b", 1.0, 0.6, FRAC_PI_4 + 0.1),
223 item("c", 1.0, 2.0, FRAC_PI_2),
224 ])
225 .await;
226
227 let res = schema
228 .execute(
229 r#"{ nearestTo(
230 point: { r: 1.0, theta: 0.5, phi: 0.7854 },
231 k: 2
232 ) { point { r theta phi } distance } }"#,
233 )
234 .await;
235
236 assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
237 let data = res.data.into_json().unwrap();
238 let results = data["nearestTo"].as_array().unwrap();
239 assert_eq!(results.len(), 2, "expected 2 nearest results");
240
241 let d0 = results[0]["distance"].as_f64().unwrap();
242 let d1 = results[1]["distance"].as_f64().unwrap();
243 assert!(d0 <= d1, "results should be sorted by distance ascending");
244 }
245
246 #[tokio::test]
247 async fn test_distance_between_query() {
248 let schema = create_schema_with_defaults();
249
250 let res = schema
251 .execute(
252 r#"{ distanceBetween(
253 a: { r: 1.0, theta: 0.0, phi: 0.7854 },
254 b: { r: 1.0, theta: 1.5, phi: 1.2 }
255 ) }"#,
256 )
257 .await;
258
259 assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
260 let data = res.data.into_json().unwrap();
261 let distance = data["distanceBetween"].as_f64().unwrap();
262 assert!(
263 distance > 0.0,
264 "distance between distinct points should be positive"
265 );
266 }
267
268 #[tokio::test]
269 async fn test_within_region_query() {
270 let schema = schema_with_items(vec![
271 item("in_shell", 3.0, 0.5, FRAC_PI_4),
272 item("out_shell", 8.0, 0.5, FRAC_PI_4),
273 ])
274 .await;
275
276 let res = schema
277 .execute(
278 r#"{ withinRegion(
279 region: { shell: { inner: 2.0, outer: 5.0 } }
280 ) { items { r } totalScanned } }"#,
281 )
282 .await;
283
284 assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
285 let data = res.data.into_json().unwrap();
286 let items = data["withinRegion"]["items"].as_array().unwrap();
287 assert_eq!(items.len(), 1);
288 let r = items[0]["r"].as_f64().unwrap();
289 assert!((r - 3.0).abs() < 1e-6);
290 }
291
292 #[tokio::test]
293 async fn test_empty_index_queries() {
294 let schema = create_schema_with_defaults();
295
296 let cone_res = schema
297 .execute(
298 r#"{ withinCone(
299 cone: {
300 apex: { r: 0.0, theta: 0.0, phi: 0.0 },
301 axis: { r: 1.0, theta: 0.5, phi: 0.7854 },
302 halfAngle: 1.0
303 }
304 ) { items { r } totalScanned } }"#,
305 )
306 .await;
307 assert!(
308 cone_res.errors.is_empty(),
309 "cone errors: {:?}",
310 cone_res.errors
311 );
312 let data = cone_res.data.into_json().unwrap();
313 let items = data["withinCone"]["items"].as_array().unwrap();
314 assert!(items.is_empty());
315
316 let nearest_res = schema
317 .execute(
318 r#"{ nearestTo(
319 point: { r: 1.0, theta: 0.5, phi: 0.7854 },
320 k: 5
321 ) { point { r } distance } }"#,
322 )
323 .await;
324 assert!(
325 nearest_res.errors.is_empty(),
326 "nearest errors: {:?}",
327 nearest_res.errors
328 );
329 let data = nearest_res.data.into_json().unwrap();
330 let results = data["nearestTo"].as_array().unwrap();
331 assert!(results.is_empty());
332 }
333
334 #[tokio::test]
335 async fn test_schema_sdl_contains_expected_types() {
336 let schema = create_schema_with_defaults();
337 let sdl = schema.sdl();
338
339 let expected = [
340 "SphericalPointOutput",
341 "SpatialQueryResultOutput",
342 "NearestResultOutput",
343 "withinCone",
344 "withinShell",
345 "nearestTo",
346 "distanceBetween",
347 "SphericalPointInput",
348 ];
349
350 for token in &expected {
351 assert!(
352 sdl.contains(token),
353 "SDL missing expected token '{}'. SDL:\n{}",
354 token,
355 &sdl[..sdl.len().min(2000)],
356 );
357 }
358 }
359}