// sphereql_graphql/context.rs

1use std::sync::Arc;
2
3use sphereql_embed::pipeline::{PipelineError, SphereQLPipeline};
4use sphereql_embed::text_embedder::{NoEmbedder, TextEmbedder};
5use sphereql_index::SpatialIndexBuilder;
6
7use crate::query::{PointIndex, SphericalQueryRoot};
8use crate::subscription::{SpatialEventBus, SphericalSubscriptionRoot};
9use crate::types::{CategorizedItemInput, items_to_pipeline_input};
10
11pub type SphericalSchema = async_graphql::Schema<
12    SphericalQueryRoot,
13    async_graphql::EmptyMutation,
14    SphericalSubscriptionRoot,
15>;
16
17pub fn build_schema(index: PointIndex, event_bus: SpatialEventBus) -> SphericalSchema {
18    async_graphql::Schema::build(
19        SphericalQueryRoot,
20        async_graphql::EmptyMutation,
21        SphericalSubscriptionRoot,
22    )
23    .data(index)
24    .data(event_bus)
25    .finish()
26}
27
28pub fn create_default_index() -> PointIndex {
29    Arc::new(tokio::sync::RwLock::new(
30        SpatialIndexBuilder::new()
31            .uniform_shells(5, 10.0)
32            .theta_divisions(12)
33            .phi_divisions(6)
34            .build(),
35    ))
36}
37
38pub fn create_schema_with_defaults() -> SphericalSchema {
39    build_schema(create_default_index(), SpatialEventBus::new(256))
40}
41
// ── Category-enrichment context ────────────────────────────────────────
//
// The category resolvers (Phase 4) consume two extra context resources:
//
// - [`CategoryPipelineHandle`] — the fitted [`SphereQLPipeline`] wrapped in
//   `Arc<RwLock<…>>` so resolvers can read it concurrently and a future
//   mutation surface can swap it under an exclusive lock.
// - [`EmbedderHandle`] — a type-erased [`TextEmbedder`] used by resolvers
//   that take a `queryText: String` argument (e.g. `drillDown`,
//   `hierarchicalNearest`). The default handle, from
//   [`default_no_embedder_handle`], wraps [`NoEmbedder`], which returns a
//   clear error rather than silently degrading.
53
54/// Shared, lockable handle to the fitted category-enrichment pipeline.
55pub type CategoryPipelineHandle = Arc<tokio::sync::RwLock<SphereQLPipeline>>;
56
57/// Type-erased text embedder shared across resolvers.
58pub type EmbedderHandle = Arc<dyn TextEmbedder>;
59
60/// Wrap an existing [`SphereQLPipeline`] for use as GraphQL context data.
61pub fn into_pipeline_handle(pipeline: SphereQLPipeline) -> CategoryPipelineHandle {
62    Arc::new(tokio::sync::RwLock::new(pipeline))
63}
64
65/// Build a fresh [`SphereQLPipeline`] from a slice of
66/// [`CategorizedItemInput`]s using [`SphereQLPipeline::new`] (default
67/// `PipelineConfig`). Returns the wrapped handle ready to feed into a
68/// schema.
69///
70/// For finer control (custom projection kind, Laplacian params, etc.)
71/// build the pipeline directly via [`SphereQLPipeline::new_with_config`]
72/// and pass the result through [`into_pipeline_handle`].
73pub fn build_pipeline_handle_from_items(
74    items: &[CategorizedItemInput],
75) -> Result<CategoryPipelineHandle, PipelineError> {
76    let input = items_to_pipeline_input(items);
77    let pipeline = SphereQLPipeline::new(input)?;
78    Ok(into_pipeline_handle(pipeline))
79}
80
81/// Default embedder handle that always errors. Use this when wiring a
82/// schema for a deployment that only exposes resolvers which don't
83/// require text embedding (or as a placeholder until a real embedder
84/// is plugged in).
85pub fn default_no_embedder_handle() -> EmbedderHandle {
86    Arc::new(NoEmbedder)
87}
88
#[cfg(test)]
mod category_context_tests {
    //! Tests for the category-enrichment context helpers.

    use super::*;

    /// Six items split evenly between "science" (largest first embedding
    /// component) and "cooking" (largest second embedding component).
    fn synthetic_items() -> Vec<CategorizedItemInput> {
        let pairs = [
            ("a", "science", vec![1.0, 0.1, 0.0, 0.2]),
            ("b", "cooking", vec![0.1, 1.0, 0.0, 0.2]),
            ("c", "science", vec![0.9, 0.2, 0.1, 0.3]),
            ("d", "cooking", vec![0.2, 0.9, 0.1, 0.3]),
            ("e", "science", vec![0.8, 0.3, 0.2, 0.1]),
            ("f", "cooking", vec![0.3, 0.8, 0.2, 0.1]),
        ];
        pairs
            .into_iter()
            .map(|(id, cat, emb)| CategorizedItemInput {
                id: id.into(),
                category: cat.into(),
                embedding: emb,
            })
            .collect()
    }

    #[tokio::test]
    async fn build_pipeline_handle_from_items_constructs_pipeline() {
        let items = synthetic_items();
        let handle = build_pipeline_handle_from_items(&items).expect("pipeline build failed");
        // Reading through the handle needs the async `RwLock`, hence tokio.
        let read = handle.read().await;
        assert_eq!(read.num_items(), 6);
        // Default projection kind is PCA.
        assert_eq!(read.projection_kind().name(), "pca");
    }

    // `embed` is called synchronously (no `.await`), so this test needs no
    // async runtime; a plain `#[test]` avoids spinning one up.
    #[test]
    fn no_embedder_handle_errors_on_embed() {
        let h = default_no_embedder_handle();
        let err = h.embed("hi").unwrap_err();
        assert!(err.to_string().contains("no TextEmbedder configured"));
    }
}
129
#[cfg(test)]
mod tests {
    //! End-to-end tests that execute GraphQL documents against schemas built
    //! with this module's constructors.

    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4};

    use sphereql_core::SphericalPoint;

    use crate::query::PointItem;

    /// Shorthand for `SphericalPoint::new_unchecked(r, theta, phi)`.
    fn point(r: f64, theta: f64, phi: f64) -> SphericalPoint {
        SphericalPoint::new_unchecked(r, theta, phi)
    }

    /// Build a [`PointItem`] with the given id at `(r, theta, phi)`.
    fn item(id: &str, r: f64, theta: f64, phi: f64) -> PointItem {
        PointItem {
            id: id.into(),
            position: point(r, theta, phi),
        }
    }

    /// Build a schema over a default index pre-populated with `items`.
    async fn schema_with_items(items: Vec<PointItem>) -> SphericalSchema {
        let index = create_default_index();
        {
            // Scope the write guard so it is dropped before the index is
            // handed to the schema.
            let mut idx = index.write().await;
            for it in items {
                idx.insert(it);
            }
        }
        build_schema(index, SpatialEventBus::new(16))
    }

    #[tokio::test]
    async fn test_within_cone_query() {
        // Three items: two near (theta=0.5, phi=PI/4) and one far away
        let schema = schema_with_items(vec![
            item("a", 1.0, 0.5, FRAC_PI_4),
            item("b", 1.0, 0.6, FRAC_PI_4 + 0.1),
            item("c", 1.0, 2.5, FRAC_PI_2 + 1.0),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ withinCone(
                    cone: {
                        apex: { r: 0.0, theta: 0.0, phi: 0.0 },
                        axis: { r: 1.0, theta: 0.5, phi: 0.7854 },
                        halfAngle: 0.5
                    }
                ) { items { r theta phi } totalScanned } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinCone"]["items"].as_array().unwrap();
        assert!(
            items.len() >= 2,
            "expected at least 2 items in cone, got {}",
            items.len()
        );
    }

    #[tokio::test]
    async fn test_within_shell_query() {
        // Items at different radii: r=1, r=3, r=7
        let schema = schema_with_items(vec![
            item("near", 1.0, 0.5, FRAC_PI_4),
            item("mid", 3.0, 0.5, FRAC_PI_4),
            item("far", 7.0, 0.5, FRAC_PI_4),
        ])
        .await;

        let res = schema
            .execute(r#"{ withinShell(shell: { inner: 2.0, outer: 5.0 }) { items { r } totalScanned } }"#)
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinShell"]["items"].as_array().unwrap();
        assert_eq!(items.len(), 1, "only the r=3 item should be in shell [2,5]");
        let r = items[0]["r"].as_f64().unwrap();
        assert!((r - 3.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_nearest_to_query() {
        let schema = schema_with_items(vec![
            item("a", 1.0, 0.5, FRAC_PI_4),
            item("b", 1.0, 0.6, FRAC_PI_4 + 0.1),
            item("c", 1.0, 2.0, FRAC_PI_2),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ nearestTo(
                    point: { r: 1.0, theta: 0.5, phi: 0.7854 },
                    k: 2
                ) { point { r theta phi } distance } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let results = data["nearestTo"].as_array().unwrap();
        assert_eq!(results.len(), 2, "expected 2 nearest results");

        let d0 = results[0]["distance"].as_f64().unwrap();
        let d1 = results[1]["distance"].as_f64().unwrap();
        assert!(d0 <= d1, "results should be sorted by distance ascending");
    }

    #[tokio::test]
    async fn test_distance_between_query() {
        let schema = create_schema_with_defaults();

        let res = schema
            .execute(
                r#"{ distanceBetween(
                    a: { r: 1.0, theta: 0.0, phi: 0.7854 },
                    b: { r: 1.0, theta: 1.5, phi: 1.2 }
                ) }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let distance = data["distanceBetween"].as_f64().unwrap();
        assert!(
            distance > 0.0,
            "distance between distinct points should be positive"
        );
    }

    #[tokio::test]
    async fn test_within_region_query() {
        let schema = schema_with_items(vec![
            item("in_shell", 3.0, 0.5, FRAC_PI_4),
            item("out_shell", 8.0, 0.5, FRAC_PI_4),
        ])
        .await;

        let res = schema
            .execute(
                r#"{ withinRegion(
                    region: { shell: { inner: 2.0, outer: 5.0 } }
                ) { items { r } totalScanned } }"#,
            )
            .await;

        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
        let data = res.data.into_json().unwrap();
        let items = data["withinRegion"]["items"].as_array().unwrap();
        assert_eq!(items.len(), 1);
        let r = items[0]["r"].as_f64().unwrap();
        assert!((r - 3.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_empty_index_queries() {
        let schema = create_schema_with_defaults();

        let cone_res = schema
            .execute(
                r#"{ withinCone(
                    cone: {
                        apex: { r: 0.0, theta: 0.0, phi: 0.0 },
                        axis: { r: 1.0, theta: 0.5, phi: 0.7854 },
                        halfAngle: 1.0
                    }
                ) { items { r } totalScanned } }"#,
            )
            .await;
        assert!(
            cone_res.errors.is_empty(),
            "cone errors: {:?}",
            cone_res.errors
        );
        let data = cone_res.data.into_json().unwrap();
        let items = data["withinCone"]["items"].as_array().unwrap();
        assert!(items.is_empty());

        let nearest_res = schema
            .execute(
                r#"{ nearestTo(
                    point: { r: 1.0, theta: 0.5, phi: 0.7854 },
                    k: 5
                ) { point { r } distance } }"#,
            )
            .await;
        assert!(
            nearest_res.errors.is_empty(),
            "nearest errors: {:?}",
            nearest_res.errors
        );
        let data = nearest_res.data.into_json().unwrap();
        let results = data["nearestTo"].as_array().unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_schema_sdl_contains_expected_types() {
        let schema = create_schema_with_defaults();
        let sdl = schema.sdl();

        let expected = [
            "SphericalPointOutput",
            "SpatialQueryResultOutput",
            "NearestResultOutput",
            "withinCone",
            "withinShell",
            "nearestTo",
            "distanceBetween",
            "SphericalPointInput",
        ];

        // Cap the failure-message preview at ~2000 bytes, backing off to a
        // char boundary: slicing a `String` at an arbitrary byte offset panics
        // if it lands inside a multi-byte character (e.g. non-ASCII text in a
        // description), which would mask the real assertion failure.
        let mut preview_end = sdl.len().min(2000);
        while !sdl.is_char_boundary(preview_end) {
            preview_end -= 1;
        }
        let preview = &sdl[..preview_end];

        for token in &expected {
            assert!(
                sdl.contains(token),
                "SDL missing expected token '{}'. SDL:\n{}",
                token,
                preview,
            );
        }
    }
}