1use async_graphql::{Context, Object, Result};
22
23use sphereql_embed::pipeline::{SphereQLOutput, SphereQLQuery};
24use sphereql_embed::text_embedder::TextEmbedder;
25use sphereql_embed::types::Embedding;
26
27use crate::category_types::{
28 CategoryNearestResultOutput, CategoryPathOutput, CategoryStatsOutput, CategorySummaryOutput,
29 ConceptPathOutput, DomainGroupOutput, DrillDownOutput, InnerSphereReportOutput,
30};
31use crate::context::{CategoryPipelineHandle, EmbedderHandle};
32
/// GraphQL query root exposing the category-level resolvers of the SphereQL pipeline.
///
/// The struct carries no state: every resolver pulls the pipeline handle and the text
/// embedder out of the schema's context data. It derives `Debug`/`Clone`/`Copy` (in
/// addition to `Default`) so it can be logged and copied like any other zero-sized value.
#[derive(Default, Debug, Clone, Copy)]
pub struct CategoryQueryRoot;
35
36#[Object]
37impl CategoryQueryRoot {
38 async fn concept_path(
41 &self,
42 ctx: &Context<'_>,
43 source_id: String,
44 target_id: String,
45 graph_k: i32,
46 query_text: String,
47 ) -> Result<Option<ConceptPathOutput>> {
48 let embedding = embed_query_text(ctx, &query_text)?;
49 let pipeline = pipeline_handle(ctx)?;
50 let guard = pipeline.read().await;
51 let query = sphereql_embed::pipeline::PipelineQuery {
52 embedding: embedding.values,
53 };
54 let out = guard
55 .query(
56 SphereQLQuery::ConceptPath {
57 source_id: &source_id,
58 target_id: &target_id,
59 graph_k: graph_k.max(0) as usize,
60 },
61 &query,
62 )
63 .map_err(gql_err)?;
64 match out {
65 SphereQLOutput::ConceptPath(path) => Ok(path.as_ref().map(ConceptPathOutput::from)),
66 _ => Err(unexpected("ConceptPath")),
67 }
68 }
69
70 async fn category_concept_path(
72 &self,
73 ctx: &Context<'_>,
74 source_category: String,
75 target_category: String,
76 ) -> Result<Option<CategoryPathOutput>> {
77 let pipeline = pipeline_handle(ctx)?;
78 let guard = pipeline.read().await;
79 let query = zero_query(&guard);
83 let out = guard
84 .query(
85 SphereQLQuery::CategoryConceptPath {
86 source_category: &source_category,
87 target_category: &target_category,
88 },
89 &query,
90 )
91 .map_err(gql_err)?;
92 match out {
93 SphereQLOutput::CategoryConceptPath(path) => {
94 Ok(path.as_ref().map(CategoryPathOutput::from))
95 }
96 _ => Err(unexpected("CategoryConceptPath")),
97 }
98 }
99
100 async fn category_neighbors(
102 &self,
103 ctx: &Context<'_>,
104 category: String,
105 k: i32,
106 ) -> Result<Vec<CategorySummaryOutput>> {
107 let pipeline = pipeline_handle(ctx)?;
108 let guard = pipeline.read().await;
109 let query = zero_query(&guard);
110 let out = guard
111 .query(
112 SphereQLQuery::CategoryNeighbors {
113 category: &category,
114 k: k.max(0) as usize,
115 },
116 &query,
117 )
118 .map_err(gql_err)?;
119 match out {
120 SphereQLOutput::CategoryNeighbors(items) => {
121 Ok(items.iter().map(CategorySummaryOutput::from).collect())
122 }
123 _ => Err(unexpected("CategoryNeighbors")),
124 }
125 }
126
127 async fn drill_down(
130 &self,
131 ctx: &Context<'_>,
132 category: String,
133 query_text: String,
134 k: i32,
135 ) -> Result<Vec<DrillDownOutput>> {
136 let embedding = embed_query_text(ctx, &query_text)?;
137 let pipeline = pipeline_handle(ctx)?;
138 let guard = pipeline.read().await;
139 let query = sphereql_embed::pipeline::PipelineQuery {
140 embedding: embedding.values,
141 };
142 let out = guard
143 .query(
144 SphereQLQuery::DrillDown {
145 category: &category,
146 k: k.max(0) as usize,
147 },
148 &query,
149 )
150 .map_err(gql_err)?;
151 match out {
152 SphereQLOutput::DrillDown(items) => {
153 Ok(items.iter().map(DrillDownOutput::from).collect())
154 }
155 _ => Err(unexpected("DrillDown")),
156 }
157 }
158
159 async fn hierarchical_nearest(
163 &self,
164 ctx: &Context<'_>,
165 query_text: String,
166 k: i32,
167 ) -> Result<Vec<CategoryNearestResultOutput>> {
168 let embedding = embed_query_text(ctx, &query_text)?;
169 let pipeline = pipeline_handle(ctx)?;
170 let guard = pipeline.read().await;
171 let items = guard.hierarchical_nearest(&embedding, k.max(0) as usize);
172 Ok(items
173 .iter()
174 .map(CategoryNearestResultOutput::from)
175 .collect())
176 }
177
178 async fn category_stats(&self, ctx: &Context<'_>) -> Result<CategoryStatsOutput> {
180 let pipeline = pipeline_handle(ctx)?;
181 let guard = pipeline.read().await;
182 let query = zero_query(&guard);
183 let out = guard
184 .query(SphereQLQuery::CategoryStats, &query)
185 .map_err(gql_err)?;
186 match out {
187 SphereQLOutput::CategoryStats {
188 summaries,
189 inner_sphere_reports,
190 } => Ok(CategoryStatsOutput {
191 summaries: summaries.iter().map(CategorySummaryOutput::from).collect(),
192 inner_sphere_reports: inner_sphere_reports
193 .iter()
194 .map(InnerSphereReportOutput::from)
195 .collect(),
196 }),
197 _ => Err(unexpected("CategoryStats")),
198 }
199 }
200
201 async fn domain_groups(&self, ctx: &Context<'_>) -> Result<Vec<DomainGroupOutput>> {
203 let pipeline = pipeline_handle(ctx)?;
204 let guard = pipeline.read().await;
205 Ok(guard
206 .domain_groups()
207 .iter()
208 .map(DomainGroupOutput::from)
209 .collect())
210 }
211}
212
213fn pipeline_handle<'c>(ctx: &Context<'c>) -> Result<&'c CategoryPipelineHandle> {
216 ctx.data::<CategoryPipelineHandle>()
217 .map_err(|_| async_graphql::Error::new("SphereQLPipeline not found in context"))
218}
219
220fn embed_query_text(ctx: &Context<'_>, text: &str) -> Result<Embedding> {
221 let embedder = ctx
222 .data::<EmbedderHandle>()
223 .map_err(|_| async_graphql::Error::new("TextEmbedder not found in context"))?;
224 embedder
225 .embed(text)
226 .map_err(|e| async_graphql::Error::new(e.to_string()))
227}
228
229fn zero_query(
234 pipeline: &sphereql_embed::pipeline::SphereQLPipeline,
235) -> sphereql_embed::pipeline::PipelineQuery {
236 use sphereql_embed::projection::Projection;
237 let dim = pipeline.projection().dimensionality();
238 sphereql_embed::pipeline::PipelineQuery {
239 embedding: vec![0.0; dim],
240 }
241}
242
243fn gql_err<E: std::fmt::Display>(e: E) -> async_graphql::Error {
244 async_graphql::Error::new(e.to_string())
245}
246
247fn unexpected(expected: &str) -> async_graphql::Error {
248 async_graphql::Error::new(format!(
249 "unexpected SphereQLOutput variant (expected {expected})"
250 ))
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use async_graphql::{EmptyMutation, EmptySubscription, Schema};
    use sphereql_embed::text_embedder::{EmbedderError, FnEmbedder};
    use sphereql_embed::types::Embedding;

    use crate::context::{build_pipeline_handle_from_items, default_no_embedder_handle};
    use crate::types::CategorizedItemInput;

    /// Fixture: twelve 4-d items, four per category; each category is dominant on a different
    /// coordinate (alpha on the 1st, beta on the 2nd, gamma on the 3rd).
    fn items() -> Vec<CategorizedItemInput> {
        const ROWS: [(&str, &str, [f64; 4]); 12] = [
            ("a1", "alpha", [1.0, 0.1, 0.0, 0.2]),
            ("a2", "alpha", [0.9, 0.2, 0.0, 0.3]),
            ("a3", "alpha", [1.0, 0.15, 0.05, 0.25]),
            ("a4", "alpha", [0.95, 0.1, 0.05, 0.2]),
            ("b1", "beta", [0.1, 1.0, 0.0, 0.2]),
            ("b2", "beta", [0.2, 0.9, 0.1, 0.3]),
            ("b3", "beta", [0.15, 0.95, 0.05, 0.25]),
            ("b4", "beta", [0.05, 0.85, 0.05, 0.2]),
            ("g1", "gamma", [0.3, 0.3, 0.9, 0.1]),
            ("g2", "gamma", [0.25, 0.35, 0.85, 0.15]),
            ("g3", "gamma", [0.35, 0.25, 0.9, 0.1]),
            ("g4", "gamma", [0.3, 0.3, 0.95, 0.05]),
        ];
        ROWS.iter()
            .map(|&(id, category, embedding)| CategorizedItemInput {
                id: id.into(),
                category: category.into(),
                embedding: embedding.to_vec(),
            })
            .collect()
    }

    /// Deterministic stand-in embedder: component `i` is the text's `i`-th byte divided by
    /// 128 (a missing byte counts as 1), for the first four bytes.
    fn test_embedder() -> EmbedderHandle {
        Arc::new(FnEmbedder::new(|text: &str| {
            let bytes = text.as_bytes();
            let values: Vec<f64> = (0..4)
                .map(|i| f64::from(bytes.get(i).copied().unwrap_or(1)) / 128.0)
                .collect();
            Ok::<_, EmbedderError>(Embedding::new(values))
        }))
    }

    /// Schema over the fixture pipeline, with `embedder` registered as context data.
    fn build_schema_for_tests(
        embedder: EmbedderHandle,
    ) -> Schema<CategoryQueryRoot, EmptyMutation, EmptySubscription> {
        let fixture = items();
        let pipeline = build_pipeline_handle_from_items(&fixture).expect("pipeline build failed");
        Schema::build(CategoryQueryRoot, EmptyMutation, EmptySubscription)
            .data(pipeline)
            .data(embedder)
            .finish()
    }

    /// Builds a fresh test schema around `embedder` and executes `query` against it.
    async fn run(embedder: EmbedderHandle, query: &str) -> async_graphql::Response {
        build_schema_for_tests(embedder).execute(query).await
    }

    #[tokio::test]
    async fn category_neighbors_returns_results() {
        let response = run(
            test_embedder(),
            r#"{ categoryNeighbors(category: "alpha", k: 2) { name memberCount } }"#,
        )
        .await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
        let json = response.data.into_json().unwrap();
        let neighbours = json["categoryNeighbors"].as_array().unwrap();
        assert!(!neighbours.is_empty());
    }

    #[tokio::test]
    async fn category_stats_returns_summaries() {
        let response = run(
            test_embedder(),
            r#"{ categoryStats { summaries { name memberCount } innerSphereReports { categoryName } } }"#,
        )
        .await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
        let json = response.data.into_json().unwrap();
        let summaries = json["categoryStats"]["summaries"].as_array().unwrap();
        // One summary per fixture category: alpha, beta, gamma.
        assert_eq!(summaries.len(), 3);
    }

    #[tokio::test]
    async fn category_concept_path_finds_path() {
        let response = run(
            test_embedder(),
            r#"{ categoryConceptPath(sourceCategory: "alpha", targetCategory: "gamma") { totalDistance steps { categoryName } } }"#,
        )
        .await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
        let json = response.data.into_json().unwrap();
        assert!(json["categoryConceptPath"].is_object());
    }

    #[tokio::test]
    async fn drill_down_uses_embedder() {
        let response = run(
            test_embedder(),
            r#"{ drillDown(category: "alpha", queryText: "alpha query", k: 3) { itemIndex distance usedInnerSphere } }"#,
        )
        .await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
        let json = response.data.into_json().unwrap();
        let hits = json["drillDown"].as_array().unwrap();
        assert!(!hits.is_empty());
    }

    #[tokio::test]
    async fn hierarchical_nearest_uses_embedder() {
        let response = run(
            test_embedder(),
            r#"{ hierarchicalNearest(queryText: "something", k: 3) { id category distance } }"#,
        )
        .await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
        let json = response.data.into_json().unwrap();
        let hits = json["hierarchicalNearest"].as_array().unwrap();
        assert!(!hits.is_empty());
    }

    #[tokio::test]
    async fn domain_groups_returns_groups() {
        let response = run(
            test_embedder(),
            r#"{ domainGroups { categoryNames totalItems } }"#,
        )
        .await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
        let json = response.data.into_json().unwrap();
        assert!(json["domainGroups"].is_array());
    }

    #[tokio::test]
    async fn concept_path_uses_embedder() {
        // Neither id appears in the fixture, so this only checks that the
        // embedder -> pipeline plumbing runs without a GraphQL error.
        let response = run(
            test_embedder(),
            r#"{ conceptPath(sourceId: "s-0000", targetId: "s-0008", graphK: 4, queryText: "bridge") { totalDistance steps { id category } } }"#,
        )
        .await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
    }

    #[tokio::test]
    async fn text_query_without_embedder_errors_descriptively() {
        let response = run(
            default_no_embedder_handle(),
            r#"{ hierarchicalNearest(queryText: "whatever", k: 3) { id } }"#,
        )
        .await;
        assert!(!response.errors.is_empty(), "expected an error");
        let msg = &response.errors[0].message;
        assert!(msg.contains("no TextEmbedder configured"), "got: {msg}");
    }
}