Skip to main content

sphereql_graphql/
category.rs

1//! GraphQL root for the category-enrichment surface.
2//!
3//! Seven resolvers cover the full category API exposed by the Rust
4//! `SphereQLPipeline`:
5//!
6//! - `conceptPath` — item-to-item path through the concept graph
7//! - `categoryConceptPath` — category-to-category path through the graph
8//! - `categoryNeighbors` — k-NN over categories
9//! - `drillDown` — within-category search using inner-sphere projection
10//!   when available
11//! - `hierarchicalNearest` — domain-group routing fallback for low-EVR
12//!   projections
13//! - `categoryStats` — summaries + inner-sphere reports
14//! - `domainGroups` — coarse category clusters from Voronoi geometry
15//!
16//! Resolvers that take a natural-language `queryText` embed it through
17//! the `EmbedderHandle` stored in context. Schemas that haven't been
18//! given a real embedder use [`crate::default_no_embedder_handle`],
19//! which returns a descriptive error when `embed()` is called.
20
21use async_graphql::{Context, Object, Result};
22
23use sphereql_embed::pipeline::{SphereQLOutput, SphereQLQuery};
24use sphereql_embed::text_embedder::TextEmbedder;
25use sphereql_embed::types::Embedding;
26
27use crate::category_types::{
28    CategoryNearestResultOutput, CategoryPathOutput, CategoryStatsOutput, CategorySummaryOutput,
29    ConceptPathOutput, DomainGroupOutput, DrillDownOutput, InnerSphereReportOutput,
30};
31use crate::context::{CategoryPipelineHandle, EmbedderHandle};
32
/// Query root exposing the category-enrichment resolvers.
///
/// Stateless marker type: all data (pipeline, embedder) is pulled from the
/// GraphQL context at resolve time, so it is trivially `Copy` and `Default`.
#[derive(Debug, Default, Clone, Copy)]
pub struct CategoryQueryRoot;
35
36#[Object]
37impl CategoryQueryRoot {
38    /// Shortest path between two indexed items through the concept graph,
39    /// anchored at the supplied `queryText` embedding.
40    async fn concept_path(
41        &self,
42        ctx: &Context<'_>,
43        source_id: String,
44        target_id: String,
45        graph_k: i32,
46        query_text: String,
47    ) -> Result<Option<ConceptPathOutput>> {
48        let embedding = embed_query_text(ctx, &query_text)?;
49        let pipeline = pipeline_handle(ctx)?;
50        let guard = pipeline.read().await;
51        let query = sphereql_embed::pipeline::PipelineQuery {
52            embedding: embedding.values,
53        };
54        let out = guard
55            .query(
56                SphereQLQuery::ConceptPath {
57                    source_id: &source_id,
58                    target_id: &target_id,
59                    graph_k: graph_k.max(0) as usize,
60                },
61                &query,
62            )
63            .map_err(gql_err)?;
64        match out {
65            SphereQLOutput::ConceptPath(path) => Ok(path.as_ref().map(ConceptPathOutput::from)),
66            _ => Err(unexpected("ConceptPath")),
67        }
68    }
69
70    /// Shortest path between two categories through the category-level graph.
71    async fn category_concept_path(
72        &self,
73        ctx: &Context<'_>,
74        source_category: String,
75        target_category: String,
76    ) -> Result<Option<CategoryPathOutput>> {
77        let pipeline = pipeline_handle(ctx)?;
78        let guard = pipeline.read().await;
79        // CategoryConceptPath doesn't consult the query embedding, but the
80        // pipeline's uniform `query` entry point requires one. Pass a
81        // dimensionality-matched zero vector.
82        let query = zero_query(&guard);
83        let out = guard
84            .query(
85                SphereQLQuery::CategoryConceptPath {
86                    source_category: &source_category,
87                    target_category: &target_category,
88                },
89                &query,
90            )
91            .map_err(gql_err)?;
92        match out {
93            SphereQLOutput::CategoryConceptPath(path) => {
94                Ok(path.as_ref().map(CategoryPathOutput::from))
95            }
96            _ => Err(unexpected("CategoryConceptPath")),
97        }
98    }
99
100    /// k nearest neighbor categories to the given category.
101    async fn category_neighbors(
102        &self,
103        ctx: &Context<'_>,
104        category: String,
105        k: i32,
106    ) -> Result<Vec<CategorySummaryOutput>> {
107        let pipeline = pipeline_handle(ctx)?;
108        let guard = pipeline.read().await;
109        let query = zero_query(&guard);
110        let out = guard
111            .query(
112                SphereQLQuery::CategoryNeighbors {
113                    category: &category,
114                    k: k.max(0) as usize,
115                },
116                &query,
117            )
118            .map_err(gql_err)?;
119        match out {
120            SphereQLOutput::CategoryNeighbors(items) => {
121                Ok(items.iter().map(CategorySummaryOutput::from).collect())
122            }
123            _ => Err(unexpected("CategoryNeighbors")),
124        }
125    }
126
127    /// k-NN within a single category, using the category's inner-sphere
128    /// projection when one exists.
129    async fn drill_down(
130        &self,
131        ctx: &Context<'_>,
132        category: String,
133        query_text: String,
134        k: i32,
135    ) -> Result<Vec<DrillDownOutput>> {
136        let embedding = embed_query_text(ctx, &query_text)?;
137        let pipeline = pipeline_handle(ctx)?;
138        let guard = pipeline.read().await;
139        let query = sphereql_embed::pipeline::PipelineQuery {
140            embedding: embedding.values,
141        };
142        let out = guard
143            .query(
144                SphereQLQuery::DrillDown {
145                    category: &category,
146                    k: k.max(0) as usize,
147                },
148                &query,
149            )
150            .map_err(gql_err)?;
151        match out {
152            SphereQLOutput::DrillDown(items) => {
153                Ok(items.iter().map(DrillDownOutput::from).collect())
154            }
155            _ => Err(unexpected("DrillDown")),
156        }
157    }
158
159    /// Domain-group-routed nearest-neighbor search. Falls back to a plain
160    /// outer-sphere k-NN when the projection EVR is above the configured
161    /// routing threshold.
162    async fn hierarchical_nearest(
163        &self,
164        ctx: &Context<'_>,
165        query_text: String,
166        k: i32,
167    ) -> Result<Vec<CategoryNearestResultOutput>> {
168        let embedding = embed_query_text(ctx, &query_text)?;
169        let pipeline = pipeline_handle(ctx)?;
170        let guard = pipeline.read().await;
171        let items = guard.hierarchical_nearest(&embedding, k.max(0) as usize);
172        Ok(items
173            .iter()
174            .map(CategoryNearestResultOutput::from)
175            .collect())
176    }
177
178    /// Per-category summaries + inner-sphere reports.
179    async fn category_stats(&self, ctx: &Context<'_>) -> Result<CategoryStatsOutput> {
180        let pipeline = pipeline_handle(ctx)?;
181        let guard = pipeline.read().await;
182        let query = zero_query(&guard);
183        let out = guard
184            .query(SphereQLQuery::CategoryStats, &query)
185            .map_err(gql_err)?;
186        match out {
187            SphereQLOutput::CategoryStats {
188                summaries,
189                inner_sphere_reports,
190            } => Ok(CategoryStatsOutput {
191                summaries: summaries.iter().map(CategorySummaryOutput::from).collect(),
192                inner_sphere_reports: inner_sphere_reports
193                    .iter()
194                    .map(InnerSphereReportOutput::from)
195                    .collect(),
196            }),
197            _ => Err(unexpected("CategoryStats")),
198        }
199    }
200
201    /// Coarse domain groups detected from category geometry.
202    async fn domain_groups(&self, ctx: &Context<'_>) -> Result<Vec<DomainGroupOutput>> {
203        let pipeline = pipeline_handle(ctx)?;
204        let guard = pipeline.read().await;
205        Ok(guard
206            .domain_groups()
207            .iter()
208            .map(DomainGroupOutput::from)
209            .collect())
210    }
211}
212
213// ── Helpers ────────────────────────────────────────────────────────────
214
215fn pipeline_handle<'c>(ctx: &Context<'c>) -> Result<&'c CategoryPipelineHandle> {
216    ctx.data::<CategoryPipelineHandle>()
217        .map_err(|_| async_graphql::Error::new("SphereQLPipeline not found in context"))
218}
219
220fn embed_query_text(ctx: &Context<'_>, text: &str) -> Result<Embedding> {
221    let embedder = ctx
222        .data::<EmbedderHandle>()
223        .map_err(|_| async_graphql::Error::new("TextEmbedder not found in context"))?;
224    embedder
225        .embed(text)
226        .map_err(|e| async_graphql::Error::new(e.to_string()))
227}
228
229/// Build a dimensionality-matched zero query for resolvers that dispatch
230/// category-only `SphereQLQuery` variants (those variants never touch the
231/// embedding, but the pipeline's uniform entry point still requires one
232/// of matching dimension).
233fn zero_query(
234    pipeline: &sphereql_embed::pipeline::SphereQLPipeline,
235) -> sphereql_embed::pipeline::PipelineQuery {
236    use sphereql_embed::projection::Projection;
237    let dim = pipeline.projection().dimensionality();
238    sphereql_embed::pipeline::PipelineQuery {
239        embedding: vec![0.0; dim],
240    }
241}
242
243fn gql_err<E: std::fmt::Display>(e: E) -> async_graphql::Error {
244    async_graphql::Error::new(e.to_string())
245}
246
247fn unexpected(expected: &str) -> async_graphql::Error {
248    async_graphql::Error::new(format!(
249        "unexpected SphereQLOutput variant (expected {expected})"
250    ))
251}
252
253// ── Tests ──────────────────────────────────────────────────────────────
254
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use async_graphql::{EmptyMutation, EmptySubscription, Response, Schema};
    use sphereql_embed::text_embedder::{EmbedderError, FnEmbedder};
    use sphereql_embed::types::Embedding;

    use crate::context::{build_pipeline_handle_from_items, default_no_embedder_handle};
    use crate::types::CategorizedItemInput;

    /// Twelve 4-d items split evenly across `alpha`, `beta` and `gamma`
    /// (four each). That is enough for category neighbors, stats and domain
    /// groups to have content without triggering inner-sphere thresholds.
    fn items() -> Vec<CategorizedItemInput> {
        const ROWS: [(&str, &str, [f64; 4]); 12] = [
            ("a1", "alpha", [1.0, 0.1, 0.0, 0.2]),
            ("a2", "alpha", [0.9, 0.2, 0.0, 0.3]),
            ("a3", "alpha", [1.0, 0.15, 0.05, 0.25]),
            ("a4", "alpha", [0.95, 0.1, 0.05, 0.2]),
            ("b1", "beta", [0.1, 1.0, 0.0, 0.2]),
            ("b2", "beta", [0.2, 0.9, 0.1, 0.3]),
            ("b3", "beta", [0.15, 0.95, 0.05, 0.25]),
            ("b4", "beta", [0.05, 0.85, 0.05, 0.2]),
            ("g1", "gamma", [0.3, 0.3, 0.9, 0.1]),
            ("g2", "gamma", [0.25, 0.35, 0.85, 0.15]),
            ("g3", "gamma", [0.35, 0.25, 0.9, 0.1]),
            ("g4", "gamma", [0.3, 0.3, 0.95, 0.05]),
        ];
        ROWS.iter()
            .map(|&(id, category, embedding)| CategorizedItemInput {
                id: id.into(),
                category: category.into(),
                embedding: embedding.to_vec(),
            })
            .collect()
    }

    /// Deterministic toy embedder. Component `i` (for `i` in `0..4`) is
    /// byte `i` of the text divided by 128, or `1 / 128` when the text is
    /// shorter. This exercises the resolver plumbing without being tuned
    /// to the corpus.
    fn test_embedder() -> EmbedderHandle {
        Arc::new(FnEmbedder::new(|text: &str| {
            let bytes = text.as_bytes();
            let values: Vec<f64> = (0..4)
                .map(|i| f64::from(bytes.get(i).copied().unwrap_or(1)) / 128.0)
                .collect();
            Ok::<_, EmbedderError>(Embedding::new(values))
        }))
    }

    /// Schema over the `items()` pipeline with the given embedder in context.
    fn build_schema_for_tests(
        embedder: EmbedderHandle,
    ) -> Schema<CategoryQueryRoot, EmptyMutation, EmptySubscription> {
        let handle = build_pipeline_handle_from_items(&items()).expect("pipeline build failed");
        Schema::build(CategoryQueryRoot, EmptyMutation, EmptySubscription)
            .data(handle)
            .data(embedder)
            .finish()
    }

    /// Runs `query` against a schema wired with `test_embedder()` and
    /// asserts that it resolved without errors.
    async fn run_ok(query: &str) -> Response {
        let response = build_schema_for_tests(test_embedder()).execute(query).await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
        response
    }

    #[tokio::test]
    async fn category_neighbors_returns_results() {
        let response =
            run_ok(r#"{ categoryNeighbors(category: "alpha", k: 2) { name memberCount } }"#).await;
        let json = response.data.into_json().unwrap();
        let neighbors = json["categoryNeighbors"].as_array().unwrap();
        assert!(!neighbors.is_empty());
    }

    #[tokio::test]
    async fn category_stats_returns_summaries() {
        let response = run_ok(
            r#"{ categoryStats { summaries { name memberCount } innerSphereReports { categoryName } } }"#,
        )
        .await;
        let json = response.data.into_json().unwrap();
        let summaries = json["categoryStats"]["summaries"].as_array().unwrap();
        assert_eq!(summaries.len(), 3);
    }

    #[tokio::test]
    async fn category_concept_path_finds_path() {
        let response = run_ok(
            r#"{ categoryConceptPath(sourceCategory: "alpha", targetCategory: "gamma") { totalDistance steps { categoryName } } }"#,
        )
        .await;
        let json = response.data.into_json().unwrap();
        assert!(json["categoryConceptPath"].is_object());
    }

    #[tokio::test]
    async fn drill_down_uses_embedder() {
        let response = run_ok(
            r#"{ drillDown(category: "alpha", queryText: "alpha query", k: 3) { itemIndex distance usedInnerSphere } }"#,
        )
        .await;
        let json = response.data.into_json().unwrap();
        let hits = json["drillDown"].as_array().unwrap();
        assert!(!hits.is_empty());
    }

    #[tokio::test]
    async fn hierarchical_nearest_uses_embedder() {
        let response = run_ok(
            r#"{ hierarchicalNearest(queryText: "something", k: 3) { id category distance } }"#,
        )
        .await;
        let json = response.data.into_json().unwrap();
        let hits = json["hierarchicalNearest"].as_array().unwrap();
        assert!(!hits.is_empty());
    }

    #[tokio::test]
    async fn domain_groups_returns_groups() {
        let response = run_ok(r#"{ domainGroups { categoryNames totalItems } }"#).await;
        let json = response.data.into_json().unwrap();
        assert!(json["domainGroups"].is_array());
    }

    #[tokio::test]
    async fn concept_path_uses_embedder() {
        // Pipeline assigns its own ids of the form `s-NNNN` (see
        // `items_to_pipeline_input` docs) — use those, not the values
        // passed on `CategorizedItemInput.id`. A null result (no path) is
        // schema-valid too, so only the absence of errors is checked.
        let _ = run_ok(
            r#"{ conceptPath(sourceId: "s-0000", targetId: "s-0008", graphK: 4, queryText: "bridge") { totalDistance steps { id category } } }"#,
        )
        .await;
    }

    #[tokio::test]
    async fn text_query_without_embedder_errors_descriptively() {
        let schema = build_schema_for_tests(default_no_embedder_handle());
        let response = schema
            .execute(r#"{ hierarchicalNearest(queryText: "whatever", k: 3) { id } }"#)
            .await;
        let first = response.errors.first().expect("expected an error");
        let msg = &first.message;
        assert!(msg.contains("no TextEmbedder configured"), "got: {msg}");
    }
}