Skip to main content

sphereql_graphql/
category.rs

//! GraphQL root for the category-enrichment surface.
//!
//! Seven resolvers cover the full category API exposed by the Rust
//! `SphereQLPipeline`:
//!
//! - `conceptPath` — item-to-item path through the concept graph
//! - `categoryConceptPath` — category-to-category path through the graph
//! - `categoryNeighbors` — k-NN over categories
//! - `drillDown` — within-category search using inner-sphere projection
//!   when available
//! - `hierarchicalNearest` — domain-group routing fallback for low-EVR
//!   projections
//! - `categoryStats` — summaries + inner-sphere reports
//! - `domainGroups` — coarse category clusters from Voronoi geometry
//!
//! Resolvers that take a natural-language `queryText` embed it through
//! the `EmbedderHandle` stored in context. Schemas that haven't been
//! given a real embedder use [`crate::default_no_embedder_handle`],
//! which returns a descriptive error when `embed()` is called.
20
21use async_graphql::{Context, Object, Result};
22
23use sphereql_embed::pipeline::{SphereQLOutput, SphereQLQuery};
24use sphereql_embed::text_embedder::TextEmbedder;
25use sphereql_embed::types::Embedding;
26
27use crate::category_types::{
28    CategoryNearestResultOutput, CategoryPathOutput, CategoryStatsOutput, CategorySummaryOutput,
29    ConceptPathOutput, DomainGroupOutput, DrillDownOutput,
30};
31use crate::context::{CategoryPipelineHandle, EmbedderHandle};
32
/// Stateless GraphQL query root for the category resolvers implemented
/// below; each resolver fetches what it needs from the request
/// [`Context`] rather than from fields on this type.
pub struct CategoryQueryRoot;
34
35#[Object]
36impl CategoryQueryRoot {
37    /// Shortest path between two indexed items through the concept graph,
38    /// anchored at the supplied `queryText` embedding.
39    async fn concept_path(
40        &self,
41        ctx: &Context<'_>,
42        source_id: String,
43        target_id: String,
44        graph_k: i32,
45        query_text: String,
46    ) -> Result<Option<ConceptPathOutput>> {
47        let embedding = embed_query_text(ctx, &query_text)?;
48        let pipeline = pipeline_handle(ctx)?;
49        let guard = pipeline.read().await;
50        let query = sphereql_embed::pipeline::PipelineQuery {
51            embedding: embedding.values,
52        };
53        let out = guard
54            .query(
55                SphereQLQuery::ConceptPath {
56                    source_id: &source_id,
57                    target_id: &target_id,
58                    graph_k: graph_k.max(0) as usize,
59                },
60                &query,
61            )
62            .map_err(gql_err)?;
63        match out {
64            SphereQLOutput::ConceptPath(path) => Ok(path.as_ref().map(ConceptPathOutput::from)),
65            _ => Err(unexpected("ConceptPath")),
66        }
67    }
68
69    /// Shortest path between two categories through the category-level graph.
70    async fn category_concept_path(
71        &self,
72        ctx: &Context<'_>,
73        source_category: String,
74        target_category: String,
75    ) -> Result<Option<CategoryPathOutput>> {
76        let pipeline = pipeline_handle(ctx)?;
77        let guard = pipeline.read().await;
78        // CategoryConceptPath doesn't consult the query embedding, but the
79        // pipeline's uniform `query` entry point requires one. Pass a
80        // dimensionality-matched zero vector.
81        let query = zero_query(&guard);
82        let out = guard
83            .query(
84                SphereQLQuery::CategoryConceptPath {
85                    source_category: &source_category,
86                    target_category: &target_category,
87                },
88                &query,
89            )
90            .map_err(gql_err)?;
91        match out {
92            SphereQLOutput::CategoryConceptPath(path) => {
93                Ok(path.as_ref().map(CategoryPathOutput::from))
94            }
95            _ => Err(unexpected("CategoryConceptPath")),
96        }
97    }
98
99    /// k nearest neighbor categories to the given category.
100    async fn category_neighbors(
101        &self,
102        ctx: &Context<'_>,
103        category: String,
104        k: i32,
105    ) -> Result<Vec<CategorySummaryOutput>> {
106        let pipeline = pipeline_handle(ctx)?;
107        let guard = pipeline.read().await;
108        let query = zero_query(&guard);
109        let out = guard
110            .query(
111                SphereQLQuery::CategoryNeighbors {
112                    category: &category,
113                    k: k.max(0) as usize,
114                },
115                &query,
116            )
117            .map_err(gql_err)?;
118        match out {
119            SphereQLOutput::CategoryNeighbors(items) => {
120                Ok(items.iter().map(CategorySummaryOutput::from).collect())
121            }
122            _ => Err(unexpected("CategoryNeighbors")),
123        }
124    }
125
126    /// k-NN within a single category, using the category's inner-sphere
127    /// projection when one exists.
128    async fn drill_down(
129        &self,
130        ctx: &Context<'_>,
131        category: String,
132        query_text: String,
133        k: i32,
134    ) -> Result<Vec<DrillDownOutput>> {
135        let embedding = embed_query_text(ctx, &query_text)?;
136        let pipeline = pipeline_handle(ctx)?;
137        let guard = pipeline.read().await;
138        let query = sphereql_embed::pipeline::PipelineQuery {
139            embedding: embedding.values,
140        };
141        let out = guard
142            .query(
143                SphereQLQuery::DrillDown {
144                    category: &category,
145                    k: k.max(0) as usize,
146                },
147                &query,
148            )
149            .map_err(gql_err)?;
150        match out {
151            SphereQLOutput::DrillDown(items) => {
152                Ok(items.iter().map(DrillDownOutput::from).collect())
153            }
154            _ => Err(unexpected("DrillDown")),
155        }
156    }
157
158    /// Domain-group-routed nearest-neighbor search. Falls back to a plain
159    /// outer-sphere k-NN when the projection EVR is above the configured
160    /// routing threshold.
161    async fn hierarchical_nearest(
162        &self,
163        ctx: &Context<'_>,
164        query_text: String,
165        k: i32,
166    ) -> Result<Vec<CategoryNearestResultOutput>> {
167        let embedding = embed_query_text(ctx, &query_text)?;
168        let pipeline = pipeline_handle(ctx)?;
169        let guard = pipeline.read().await;
170        let items = guard.hierarchical_nearest(&embedding, k.max(0) as usize);
171        Ok(items
172            .iter()
173            .map(CategoryNearestResultOutput::from)
174            .collect())
175    }
176
177    /// Per-category summaries + inner-sphere reports.
178    async fn category_stats(&self, ctx: &Context<'_>) -> Result<CategoryStatsOutput> {
179        let pipeline = pipeline_handle(ctx)?;
180        let guard = pipeline.read().await;
181        let query = zero_query(&guard);
182        let out = guard
183            .query(SphereQLQuery::CategoryStats, &query)
184            .map_err(gql_err)?;
185        match out {
186            SphereQLOutput::CategoryStats {
187                summaries,
188                inner_sphere_reports,
189            } => Ok(CategoryStatsOutput {
190                summaries: summaries.iter().map(CategorySummaryOutput::from).collect(),
191                inner_sphere_reports: inner_sphere_reports
192                    .iter()
193                    .map(crate::category_types::InnerSphereReportOutput::from)
194                    .collect(),
195            }),
196            _ => Err(unexpected("CategoryStats")),
197        }
198    }
199
200    /// Coarse domain groups detected from category geometry.
201    async fn domain_groups(&self, ctx: &Context<'_>) -> Result<Vec<DomainGroupOutput>> {
202        let pipeline = pipeline_handle(ctx)?;
203        let guard = pipeline.read().await;
204        Ok(guard
205            .domain_groups()
206            .iter()
207            .map(DomainGroupOutput::from)
208            .collect())
209    }
210}
211
212// ── Helpers ────────────────────────────────────────────────────────────
213
214fn pipeline_handle<'c>(ctx: &Context<'c>) -> Result<&'c CategoryPipelineHandle> {
215    ctx.data::<CategoryPipelineHandle>()
216        .map_err(|_| async_graphql::Error::new("SphereQLPipeline not found in context"))
217}
218
219fn embed_query_text(ctx: &Context<'_>, text: &str) -> Result<Embedding> {
220    let embedder = ctx
221        .data::<EmbedderHandle>()
222        .map_err(|_| async_graphql::Error::new("TextEmbedder not found in context"))?;
223    embedder
224        .embed(text)
225        .map_err(|e| async_graphql::Error::new(e.to_string()))
226}
227
228/// Build a dimensionality-matched zero query for resolvers that dispatch
229/// category-only `SphereQLQuery` variants (those variants never touch the
230/// embedding, but the pipeline's uniform entry point still requires one
231/// of matching dimension).
232fn zero_query(
233    pipeline: &sphereql_embed::pipeline::SphereQLPipeline,
234) -> sphereql_embed::pipeline::PipelineQuery {
235    use sphereql_embed::projection::Projection;
236    let dim = pipeline.projection().dimensionality();
237    sphereql_embed::pipeline::PipelineQuery {
238        embedding: vec![0.0; dim],
239    }
240}
241
242fn gql_err<E: std::fmt::Display>(e: E) -> async_graphql::Error {
243    async_graphql::Error::new(e.to_string())
244}
245
246fn unexpected(expected: &str) -> async_graphql::Error {
247    async_graphql::Error::new(format!(
248        "unexpected SphereQLOutput variant (expected {expected})"
249    ))
250}
251
252// ── Tests ──────────────────────────────────────────────────────────────
253
254#[cfg(test)]
255mod tests {
256    use super::*;
257    use std::sync::Arc;
258
259    use async_graphql::{EmptyMutation, EmptySubscription, Schema};
260    use sphereql_embed::text_embedder::{EmbedderError, FnEmbedder};
261    use sphereql_embed::types::Embedding;
262
263    use crate::context::{build_pipeline_handle_from_items, default_no_embedder_handle};
264    use crate::types::CategorizedItemInput;
265
266    fn items() -> Vec<CategorizedItemInput> {
267        // 12 items, 4-dim embeddings, 3 categories with 4 each — enough
268        // for category neighbors, stats, and domain groups to have
269        // content without triggering inner-sphere thresholds.
270        let rows: &[(&str, &str, [f64; 4])] = &[
271            ("a1", "alpha", [1.0, 0.1, 0.0, 0.2]),
272            ("a2", "alpha", [0.9, 0.2, 0.0, 0.3]),
273            ("a3", "alpha", [1.0, 0.15, 0.05, 0.25]),
274            ("a4", "alpha", [0.95, 0.1, 0.05, 0.2]),
275            ("b1", "beta", [0.1, 1.0, 0.0, 0.2]),
276            ("b2", "beta", [0.2, 0.9, 0.1, 0.3]),
277            ("b3", "beta", [0.15, 0.95, 0.05, 0.25]),
278            ("b4", "beta", [0.05, 0.85, 0.05, 0.2]),
279            ("g1", "gamma", [0.3, 0.3, 0.9, 0.1]),
280            ("g2", "gamma", [0.25, 0.35, 0.85, 0.15]),
281            ("g3", "gamma", [0.35, 0.25, 0.9, 0.1]),
282            ("g4", "gamma", [0.3, 0.3, 0.95, 0.05]),
283        ];
284        rows.iter()
285            .map(|(id, cat, emb)| CategorizedItemInput {
286                id: (*id).into(),
287                category: (*cat).into(),
288                embedding: emb.to_vec(),
289            })
290            .collect()
291    }
292
293    fn test_embedder() -> EmbedderHandle {
294        // Deterministic closure: pad-or-truncate the text's bytes to a 4-d
295        // vector. Enough to exercise the resolver plumbing; the embedding
296        // correlates weakly with corpus items so queries produce real
297        // results without being hand-tuned.
298        Arc::new(FnEmbedder::new(|text: &str| {
299            let b = text.as_bytes();
300            let mut v = [0.0f64; 4];
301            for (i, slot) in v.iter_mut().enumerate() {
302                *slot = *b.get(i).unwrap_or(&1) as f64 / 128.0;
303            }
304            Ok::<_, EmbedderError>(Embedding::new(v.to_vec()))
305        }))
306    }
307
308    fn build_schema_for_tests(
309        embedder: EmbedderHandle,
310    ) -> Schema<CategoryQueryRoot, EmptyMutation, EmptySubscription> {
311        let pipeline = build_pipeline_handle_from_items(&items()).expect("pipeline build failed");
312        Schema::build(CategoryQueryRoot, EmptyMutation, EmptySubscription)
313            .data(pipeline)
314            .data(embedder)
315            .finish()
316    }
317
318    #[tokio::test]
319    async fn category_neighbors_returns_results() {
320        let schema = build_schema_for_tests(test_embedder());
321        let res = schema
322            .execute(r#"{ categoryNeighbors(category: "alpha", k: 2) { name memberCount } }"#)
323            .await;
324        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
325        let data = res.data.into_json().unwrap();
326        let arr = data["categoryNeighbors"].as_array().unwrap();
327        assert!(!arr.is_empty());
328    }
329
330    #[tokio::test]
331    async fn category_stats_returns_summaries() {
332        let schema = build_schema_for_tests(test_embedder());
333        let res = schema
334            .execute(r#"{ categoryStats { summaries { name memberCount } innerSphereReports { categoryName } } }"#)
335            .await;
336        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
337        let data = res.data.into_json().unwrap();
338        let summaries = data["categoryStats"]["summaries"].as_array().unwrap();
339        assert_eq!(summaries.len(), 3);
340    }
341
342    #[tokio::test]
343    async fn category_concept_path_finds_path() {
344        let schema = build_schema_for_tests(test_embedder());
345        let res = schema
346            .execute(
347                r#"{ categoryConceptPath(sourceCategory: "alpha", targetCategory: "gamma") { totalDistance steps { categoryName } } }"#,
348            )
349            .await;
350        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
351        let data = res.data.into_json().unwrap();
352        assert!(data["categoryConceptPath"].is_object());
353    }
354
355    #[tokio::test]
356    async fn drill_down_uses_embedder() {
357        let schema = build_schema_for_tests(test_embedder());
358        let res = schema
359            .execute(
360                r#"{ drillDown(category: "alpha", queryText: "alpha query", k: 3) { itemIndex distance usedInnerSphere } }"#,
361            )
362            .await;
363        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
364        let data = res.data.into_json().unwrap();
365        let arr = data["drillDown"].as_array().unwrap();
366        assert!(!arr.is_empty());
367    }
368
369    #[tokio::test]
370    async fn hierarchical_nearest_uses_embedder() {
371        let schema = build_schema_for_tests(test_embedder());
372        let res = schema
373            .execute(
374                r#"{ hierarchicalNearest(queryText: "something", k: 3) { id category distance } }"#,
375            )
376            .await;
377        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
378        let data = res.data.into_json().unwrap();
379        let arr = data["hierarchicalNearest"].as_array().unwrap();
380        assert!(!arr.is_empty());
381    }
382
383    #[tokio::test]
384    async fn domain_groups_returns_groups() {
385        let schema = build_schema_for_tests(test_embedder());
386        let res = schema
387            .execute(r#"{ domainGroups { categoryNames totalItems } }"#)
388            .await;
389        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
390        let data = res.data.into_json().unwrap();
391        assert!(data["domainGroups"].is_array());
392    }
393
394    #[tokio::test]
395    async fn concept_path_uses_embedder() {
396        // Pipeline assigns its own ids of the form `s-NNNN` (see
397        // `items_to_pipeline_input` docs) — use those, not the values
398        // passed on `CategorizedItemInput.id`.
399        let schema = build_schema_for_tests(test_embedder());
400        let res = schema
401            .execute(
402                r#"{ conceptPath(sourceId: "s-0000", targetId: "s-0008", graphK: 4, queryText: "bridge") { totalDistance steps { id category } } }"#,
403            )
404            .await;
405        // ConceptPath may return null if no path is found — both shapes
406        // are schema-valid. Just assert no error.
407        assert!(res.errors.is_empty(), "errors: {:?}", res.errors);
408    }
409
410    #[tokio::test]
411    async fn text_query_without_embedder_errors_descriptively() {
412        let schema = build_schema_for_tests(default_no_embedder_handle());
413        let res = schema
414            .execute(r#"{ hierarchicalNearest(queryText: "whatever", k: 3) { id } }"#)
415            .await;
416        assert!(!res.errors.is_empty(), "expected an error");
417        let msg = res.errors[0].message.clone();
418        assert!(msg.contains("no TextEmbedder configured"), "got: {msg}");
419    }
420}