Skip to main content

sphereql_graphql/
context.rs

1use std::sync::Arc;
2
3use sphereql_embed::pipeline::{PipelineError, SphereQLPipeline};
4use sphereql_embed::text_embedder::{NoEmbedder, TextEmbedder};
5use sphereql_index::SpatialIndexBuilder;
6
7use crate::query::{PointIndex, SphericalQueryRoot};
8use crate::subscription::{SpatialEventBus, SphericalSubscriptionRoot};
9use crate::types::{CategorizedItemInput, items_to_pipeline_input};
10
/// Concrete `async_graphql` schema type produced by [`build_schema`].
///
/// Queries are served by [`SphericalQueryRoot`], subscriptions by
/// [`SphericalSubscriptionRoot`], and the mutation slot is filled with
/// [`async_graphql::EmptyMutation`].
pub type SphericalSchema = async_graphql::Schema<
    SphericalQueryRoot,
    async_graphql::EmptyMutation,
    SphericalSubscriptionRoot,
>;
16
17pub fn build_schema(index: PointIndex, event_bus: SpatialEventBus) -> SphericalSchema {
18    async_graphql::Schema::build(
19        SphericalQueryRoot,
20        async_graphql::EmptyMutation,
21        SphericalSubscriptionRoot,
22    )
23    .limit_depth(crate::DEFAULT_MAX_DEPTH)
24    .limit_complexity(crate::DEFAULT_MAX_COMPLEXITY)
25    .data(index)
26    .data(event_bus)
27    .finish()
28}
29
30pub fn create_default_index() -> PointIndex {
31    Arc::new(tokio::sync::RwLock::new(
32        SpatialIndexBuilder::new()
33            .uniform_shells(5, 10.0)
34            .theta_divisions(12)
35            .phi_divisions(6)
36            .build(),
37    ))
38}
39
40pub fn create_schema_with_defaults() -> SphericalSchema {
41    build_schema(create_default_index(), SpatialEventBus::new(256))
42}
43
44// ── Category-enrichment context ────────────────────────────────────────
45//
46// The category resolvers (Phase 4) consume two extra context resources:
47//
48// - [`CategoryPipelineHandle`] — the fitted [`SphereQLPipeline`] wrapped in
49//   `Arc<RwLock<…>>` so resolvers can read concurrently and the (eventual)
50//   mutation surface can swap it under exclusive lock.
51// - [`EmbedderHandle`] — a type-erased [`TextEmbedder`] used by resolvers
52//   that take a `queryText: String` argument (drillDown,
53//   hierarchicalNearest, etc.). Defaults to [`NoEmbedder`], which returns
54//   a clear error rather than silently degrading.
55
/// Shared, lockable handle to the fitted category-enrichment pipeline.
///
/// The pipeline sits behind an [`Arc`] and a `tokio::sync::RwLock`, so
/// cloning the handle shares the same underlying [`SphereQLPipeline`].
pub type CategoryPipelineHandle = Arc<tokio::sync::RwLock<SphereQLPipeline>>;
58
/// Type-erased text embedder shared across resolvers.
///
/// Any [`TextEmbedder`] implementation can be stored behind this handle as
/// an `Arc<dyn TextEmbedder>` trait object.
pub type EmbedderHandle = Arc<dyn TextEmbedder>;
61
62/// Wrap an existing [`SphereQLPipeline`] for use as GraphQL context data.
63pub fn into_pipeline_handle(pipeline: SphereQLPipeline) -> CategoryPipelineHandle {
64    Arc::new(tokio::sync::RwLock::new(pipeline))
65}
66
67/// Build a fresh [`SphereQLPipeline`] from a slice of
68/// [`CategorizedItemInput`]s using [`SphereQLPipeline::new`] (default
69/// `PipelineConfig`). Returns the wrapped handle ready to feed into a
70/// schema.
71///
72/// For finer control (custom projection kind, Laplacian params, etc.)
73/// build the pipeline directly via [`SphereQLPipeline::new_with_config`]
74/// and pass the result through [`into_pipeline_handle`].
75pub fn build_pipeline_handle_from_items(
76    items: &[CategorizedItemInput],
77) -> Result<CategoryPipelineHandle, PipelineError> {
78    let input = items_to_pipeline_input(items);
79    let pipeline = SphereQLPipeline::new(input)?;
80    Ok(into_pipeline_handle(pipeline))
81}
82
83/// Default embedder handle that always errors. Use this when wiring a
84/// schema for a deployment that only exposes resolvers which don't
85/// require text embedding (or as a placeholder until a real embedder
86/// is plugged in).
87pub fn default_no_embedder_handle() -> EmbedderHandle {
88    Arc::new(NoEmbedder)
89}
90
#[cfg(test)]
mod category_context_tests {
    use super::*;

    /// Six hand-written items, alternating between the `science` and
    /// `cooking` categories, each with a four-component embedding.
    fn synthetic_items() -> Vec<CategorizedItemInput> {
        let rows = [
            ("a", "science", vec![1.0, 0.1, 0.0, 0.2]),
            ("b", "cooking", vec![0.1, 1.0, 0.0, 0.2]),
            ("c", "science", vec![0.9, 0.2, 0.1, 0.3]),
            ("d", "cooking", vec![0.2, 0.9, 0.1, 0.3]),
            ("e", "science", vec![0.8, 0.3, 0.2, 0.1]),
            ("f", "cooking", vec![0.3, 0.8, 0.2, 0.1]),
        ];

        let mut items = Vec::with_capacity(rows.len());
        for (id, category, embedding) in rows {
            items.push(CategorizedItemInput {
                id: id.into(),
                category: category.into(),
                embedding,
            });
        }
        items
    }

    /// The handle built from the synthetic items exposes a pipeline that
    /// holds all six of them and reports the default projection kind.
    #[tokio::test]
    async fn build_pipeline_handle_from_items_constructs_pipeline() {
        let fixture = synthetic_items();
        let handle = build_pipeline_handle_from_items(&fixture).expect("pipeline build failed");

        let pipeline = handle.read().await;
        assert_eq!(pipeline.num_items(), 6);
        // Default projection kind is PCA.
        assert_eq!(pipeline.projection_kind().name(), "pca");
    }

    /// The default embedder handle refuses to embed and says why.
    #[tokio::test]
    async fn no_embedder_handle_errors_on_embed() {
        let embedder = default_no_embedder_handle();
        let failure = embedder.embed("hi").unwrap_err();
        let message = failure.to_string();
        assert!(message.contains("no TextEmbedder configured"));
    }
}
131
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4};

    use sphereql_core::SphericalPoint;

    use crate::query::PointItem;

    /// Maximum number of SDL characters echoed in a failure message.
    const SDL_PREVIEW_CHARS: usize = 2000;

    /// Shorthand for [`SphericalPoint::new_unchecked`].
    fn point(r: f64, theta: f64, phi: f64) -> SphericalPoint {
        SphericalPoint::new_unchecked(r, theta, phi)
    }

    /// Build a [`PointItem`] with the given id at `(r, theta, phi)`.
    fn item(id: &str, r: f64, theta: f64, phi: f64) -> PointItem {
        PointItem {
            id: id.into(),
            position: point(r, theta, phi),
        }
    }

    /// Build a schema over a default index pre-populated with `items`.
    async fn schema_with_items(items: Vec<PointItem>) -> SphericalSchema {
        let index = create_default_index();
        {
            // Scope the write guard so it is released before the index is
            // moved into the schema.
            let mut idx = index.write().await;
            for it in items {
                idx.insert(it);
            }
        }
        build_schema(index, SpatialEventBus::new(16))
    }

    /// `withinCone` returns only the items that fall inside the cone.
    #[tokio::test]
    async fn test_within_cone_query() {
        // Three items: two near (theta=0.5, phi=PI/4) and one far away
        let schema = schema_with_items(vec![
            item("a", 1.0, 0.5, FRAC_PI_4),
            item("b", 1.0, 0.6, FRAC_PI_4 + 0.1),
            item("c", 1.0, 2.5, FRAC_PI_2 + 1.0),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ withinCone(
                    cone: {
                        apex: { r: 0.0, theta: 0.0, phi: 0.0 },
                        axis: { r: 1.0, theta: 0.5, phi: 0.7854 },
                        halfAngle: 0.5
                    }
                ) { items { r theta phi } totalScanned } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinCone"]["items"].as_array().unwrap();
        assert_eq!(
            items.len(),
            2,
            "expected exactly 2 items in cone, got {}",
            items.len()
        );
    }

    /// `withinShell` returns only the item whose radius lies in the shell.
    #[tokio::test]
    async fn test_within_shell_query() {
        // Items at different radii: r=1, r=3, r=7
        let schema = schema_with_items(vec![
            item("near", 1.0, 0.5, FRAC_PI_4),
            item("mid", 3.0, 0.5, FRAC_PI_4),
            item("far", 7.0, 0.5, FRAC_PI_4),
        ])
        .await;

        let res = schema
            .execute(r#"{ withinShell(shell: { inner: 2.0, outer: 5.0 }) { items { r } totalScanned } }"#)
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinShell"]["items"].as_array().unwrap();
        assert_eq!(items.len(), 1, "only the r=3 item should be in shell [2,5]");
        let r = items[0]["r"].as_f64().unwrap();
        assert!((r - 3.0).abs() < 1e-6);
    }

    /// `nearestTo` honours `k` and orders results by ascending distance.
    #[tokio::test]
    async fn test_nearest_to_query() {
        let schema = schema_with_items(vec![
            item("a", 1.0, 0.5, FRAC_PI_4),
            item("b", 1.0, 0.6, FRAC_PI_4 + 0.1),
            item("c", 1.0, 2.0, FRAC_PI_2),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ nearestTo(
                    point: { r: 1.0, theta: 0.5, phi: 0.7854 },
                    k: 2
                ) { point { r theta phi } distance } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let results = data["nearestTo"].as_array().unwrap();
        assert_eq!(results.len(), 2, "expected 2 nearest results");

        let d0 = results[0]["distance"].as_f64().unwrap();
        let d1 = results[1]["distance"].as_f64().unwrap();
        assert!(d0 <= d1, "results should be sorted by distance ascending");
    }

    /// `distanceBetween` works on an empty index and is positive for two
    /// distinct points.
    #[tokio::test]
    async fn test_distance_between_query() {
        let schema = create_schema_with_defaults();

        let res = schema
            .execute(
                r#"{ distanceBetween(
                    a: { r: 1.0, theta: 0.0, phi: 0.7854 },
                    b: { r: 1.0, theta: 1.5, phi: 1.2 }
                ) }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let distance = data["distanceBetween"].as_f64().unwrap();
        assert!(
            distance > 0.0,
            "distance between distinct points should be positive"
        );
    }

    /// `withinRegion` with a shell region keeps only the item inside it.
    #[tokio::test]
    async fn test_within_region_query() {
        let schema = schema_with_items(vec![
            item("in_shell", 3.0, 0.5, FRAC_PI_4),
            item("out_shell", 8.0, 0.5, FRAC_PI_4),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ withinRegion(
                    region: { shell: { inner: 2.0, outer: 5.0 } }
                ) { items { r } totalScanned } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinRegion"]["items"].as_array().unwrap();
        assert_eq!(items.len(), 1);
        let r = items[0]["r"].as_f64().unwrap();
        assert!((r - 3.0).abs() < 1e-6);
    }

    /// Queries against an empty index succeed and return empty results.
    #[tokio::test]
    async fn test_empty_index_queries() {
        let schema = create_schema_with_defaults();

        let cone_res = schema
            .execute(
                r#"{ withinCone(
                    cone: {
                        apex: { r: 0.0, theta: 0.0, phi: 0.0 },
                        axis: { r: 1.0, theta: 0.5, phi: 0.7854 },
                        halfAngle: 1.0
                    }
                ) { items { r } totalScanned } }"#,
            )
            .await;
        assert!(
            cone_res.errors.is_empty(),
            "cone errors: {:?}",
            cone_res.errors
        );
        let data = cone_res.data.into_json().unwrap();
        let items = data["withinCone"]["items"].as_array().unwrap();
        assert!(items.is_empty());

        let nearest_res = schema
            .execute(
                r#"{ nearestTo(
                    point: { r: 1.0, theta: 0.5, phi: 0.7854 },
                    k: 5
                ) { point { r } distance } }"#,
            )
            .await;
        assert!(
            nearest_res.errors.is_empty(),
            "nearest errors: {:?}",
            nearest_res.errors
        );
        let data = nearest_res.data.into_json().unwrap();
        let results = data["nearestTo"].as_array().unwrap();
        assert!(results.is_empty());
    }

    /// The exported SDL mentions every type and query field these tests
    /// rely on.
    #[tokio::test]
    async fn test_schema_sdl_contains_expected_types() {
        let schema = create_schema_with_defaults();
        let sdl = schema.sdl();

        // Truncate by characters, not bytes: slicing a `str` at a byte offset
        // that lands inside a multi-byte character panics, which would hide
        // the "missing token" message this preview is meant to support.
        let sdl_preview: String = sdl.chars().take(SDL_PREVIEW_CHARS).collect();

        let expected = [
            "SphericalPointOutput",
            "SpatialQueryResultOutput",
            "NearestResultOutput",
            "withinCone",
            "withinShell",
            "nearestTo",
            "distanceBetween",
            // Exercised by `test_within_region_query`, so it must be exported.
            "withinRegion",
            "SphericalPointInput",
        ];

        for token in expected {
            assert!(
                sdl.contains(token),
                "SDL missing expected token '{}'. SDL:\n{}",
                token,
                sdl_preview,
            );
        }
    }
}