1use async_graphql::{Context, Object, Result};
22
23use sphereql_embed::pipeline::{SphereQLOutput, SphereQLQuery};
24use sphereql_embed::text_embedder::TextEmbedder;
25use sphereql_embed::types::Embedding;
26
27use crate::category_types::{
28 CategoryNearestResultOutput, CategoryPathOutput, CategoryStatsOutput, CategorySummaryOutput,
29 ConceptPathOutput, DomainGroupOutput, DrillDownOutput,
30};
31use crate::context::{CategoryPipelineHandle, EmbedderHandle};
32
/// GraphQL query root for the category-level SphereQL queries.
///
/// The resolvers fetch the pipeline and the text embedder from the schema's data,
/// so the root itself carries no state.
#[derive(Debug, Default, Clone, Copy)]
pub struct CategoryQueryRoot;
34
35#[Object]
36impl CategoryQueryRoot {
37 async fn concept_path(
40 &self,
41 ctx: &Context<'_>,
42 source_id: String,
43 target_id: String,
44 graph_k: i32,
45 query_text: String,
46 ) -> Result<Option<ConceptPathOutput>> {
47 let embedding = embed_query_text(ctx, &query_text)?;
48 let pipeline = pipeline_handle(ctx)?;
49 let guard = pipeline.read().await;
50 let query = sphereql_embed::pipeline::PipelineQuery {
51 embedding: embedding.values,
52 };
53 let out = guard
54 .query(
55 SphereQLQuery::ConceptPath {
56 source_id: &source_id,
57 target_id: &target_id,
58 graph_k: graph_k.max(0) as usize,
59 },
60 &query,
61 )
62 .map_err(gql_err)?;
63 match out {
64 SphereQLOutput::ConceptPath(path) => Ok(path.as_ref().map(ConceptPathOutput::from)),
65 _ => Err(unexpected("ConceptPath")),
66 }
67 }
68
69 async fn category_concept_path(
71 &self,
72 ctx: &Context<'_>,
73 source_category: String,
74 target_category: String,
75 ) -> Result<Option<CategoryPathOutput>> {
76 let pipeline = pipeline_handle(ctx)?;
77 let guard = pipeline.read().await;
78 let query = zero_query(&guard);
82 let out = guard
83 .query(
84 SphereQLQuery::CategoryConceptPath {
85 source_category: &source_category,
86 target_category: &target_category,
87 },
88 &query,
89 )
90 .map_err(gql_err)?;
91 match out {
92 SphereQLOutput::CategoryConceptPath(path) => {
93 Ok(path.as_ref().map(CategoryPathOutput::from))
94 }
95 _ => Err(unexpected("CategoryConceptPath")),
96 }
97 }
98
99 async fn category_neighbors(
101 &self,
102 ctx: &Context<'_>,
103 category: String,
104 k: i32,
105 ) -> Result<Vec<CategorySummaryOutput>> {
106 let pipeline = pipeline_handle(ctx)?;
107 let guard = pipeline.read().await;
108 let query = zero_query(&guard);
109 let out = guard
110 .query(
111 SphereQLQuery::CategoryNeighbors {
112 category: &category,
113 k: k.max(0) as usize,
114 },
115 &query,
116 )
117 .map_err(gql_err)?;
118 match out {
119 SphereQLOutput::CategoryNeighbors(items) => {
120 Ok(items.iter().map(CategorySummaryOutput::from).collect())
121 }
122 _ => Err(unexpected("CategoryNeighbors")),
123 }
124 }
125
126 async fn drill_down(
129 &self,
130 ctx: &Context<'_>,
131 category: String,
132 query_text: String,
133 k: i32,
134 ) -> Result<Vec<DrillDownOutput>> {
135 let embedding = embed_query_text(ctx, &query_text)?;
136 let pipeline = pipeline_handle(ctx)?;
137 let guard = pipeline.read().await;
138 let query = sphereql_embed::pipeline::PipelineQuery {
139 embedding: embedding.values,
140 };
141 let out = guard
142 .query(
143 SphereQLQuery::DrillDown {
144 category: &category,
145 k: k.max(0) as usize,
146 },
147 &query,
148 )
149 .map_err(gql_err)?;
150 match out {
151 SphereQLOutput::DrillDown(items) => {
152 Ok(items.iter().map(DrillDownOutput::from).collect())
153 }
154 _ => Err(unexpected("DrillDown")),
155 }
156 }
157
158 async fn hierarchical_nearest(
162 &self,
163 ctx: &Context<'_>,
164 query_text: String,
165 k: i32,
166 ) -> Result<Vec<CategoryNearestResultOutput>> {
167 let embedding = embed_query_text(ctx, &query_text)?;
168 let pipeline = pipeline_handle(ctx)?;
169 let guard = pipeline.read().await;
170 let items = guard.hierarchical_nearest(&embedding, k.max(0) as usize);
171 Ok(items
172 .iter()
173 .map(CategoryNearestResultOutput::from)
174 .collect())
175 }
176
177 async fn category_stats(&self, ctx: &Context<'_>) -> Result<CategoryStatsOutput> {
179 let pipeline = pipeline_handle(ctx)?;
180 let guard = pipeline.read().await;
181 let query = zero_query(&guard);
182 let out = guard
183 .query(SphereQLQuery::CategoryStats, &query)
184 .map_err(gql_err)?;
185 match out {
186 SphereQLOutput::CategoryStats {
187 summaries,
188 inner_sphere_reports,
189 } => Ok(CategoryStatsOutput {
190 summaries: summaries.iter().map(CategorySummaryOutput::from).collect(),
191 inner_sphere_reports: inner_sphere_reports
192 .iter()
193 .map(crate::category_types::InnerSphereReportOutput::from)
194 .collect(),
195 }),
196 _ => Err(unexpected("CategoryStats")),
197 }
198 }
199
200 async fn domain_groups(&self, ctx: &Context<'_>) -> Result<Vec<DomainGroupOutput>> {
202 let pipeline = pipeline_handle(ctx)?;
203 let guard = pipeline.read().await;
204 Ok(guard
205 .domain_groups()
206 .iter()
207 .map(DomainGroupOutput::from)
208 .collect())
209 }
210}
211
212fn pipeline_handle<'c>(ctx: &Context<'c>) -> Result<&'c CategoryPipelineHandle> {
215 ctx.data::<CategoryPipelineHandle>()
216 .map_err(|_| async_graphql::Error::new("SphereQLPipeline not found in context"))
217}
218
219fn embed_query_text(ctx: &Context<'_>, text: &str) -> Result<Embedding> {
220 let embedder = ctx
221 .data::<EmbedderHandle>()
222 .map_err(|_| async_graphql::Error::new("TextEmbedder not found in context"))?;
223 embedder
224 .embed(text)
225 .map_err(|e| async_graphql::Error::new(e.to_string()))
226}
227
228fn zero_query(
233 pipeline: &sphereql_embed::pipeline::SphereQLPipeline,
234) -> sphereql_embed::pipeline::PipelineQuery {
235 use sphereql_embed::projection::Projection;
236 let dim = pipeline.projection().dimensionality();
237 sphereql_embed::pipeline::PipelineQuery {
238 embedding: vec![0.0; dim],
239 }
240}
241
242fn gql_err<E: std::fmt::Display>(e: E) -> async_graphql::Error {
243 async_graphql::Error::new(e.to_string())
244}
245
246fn unexpected(expected: &str) -> async_graphql::Error {
247 async_graphql::Error::new(format!(
248 "unexpected SphereQLOutput variant (expected {expected})"
249 ))
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use async_graphql::{EmptyMutation, EmptySubscription, Schema};
    use sphereql_embed::text_embedder::{EmbedderError, FnEmbedder};
    use sphereql_embed::types::Embedding;

    use crate::context::{build_pipeline_handle_from_items, default_no_embedder_handle};
    use crate::types::CategorizedItemInput;

    /// Schema type shared by every test in this module.
    type TestSchema = Schema<CategoryQueryRoot, EmptyMutation, EmptySubscription>;

    /// Fixture rows `(id, category, embedding)`: four 4-d items in each of the
    /// categories `alpha`, `beta` and `gamma`.
    const ROWS: [(&str, &str, [f64; 4]); 12] = [
        ("a1", "alpha", [1.0, 0.1, 0.0, 0.2]),
        ("a2", "alpha", [0.9, 0.2, 0.0, 0.3]),
        ("a3", "alpha", [1.0, 0.15, 0.05, 0.25]),
        ("a4", "alpha", [0.95, 0.1, 0.05, 0.2]),
        ("b1", "beta", [0.1, 1.0, 0.0, 0.2]),
        ("b2", "beta", [0.2, 0.9, 0.1, 0.3]),
        ("b3", "beta", [0.15, 0.95, 0.05, 0.25]),
        ("b4", "beta", [0.05, 0.85, 0.05, 0.2]),
        ("g1", "gamma", [0.3, 0.3, 0.9, 0.1]),
        ("g2", "gamma", [0.25, 0.35, 0.85, 0.15]),
        ("g3", "gamma", [0.35, 0.25, 0.9, 0.1]),
        ("g4", "gamma", [0.3, 0.3, 0.95, 0.05]),
    ];

    /// Converts [`ROWS`] into pipeline input items.
    fn items() -> Vec<CategorizedItemInput> {
        ROWS.iter()
            .map(|&(id, category, embedding)| CategorizedItemInput {
                id: id.into(),
                category: category.into(),
                embedding: embedding.to_vec(),
            })
            .collect()
    }

    /// Deterministic 4-d embedder. Coordinate `i` is byte `i` of the text divided by 128;
    /// positions past the end of the text use the byte value 1.
    fn test_embedder() -> EmbedderHandle {
        Arc::new(FnEmbedder::new(|text: &str| {
            let bytes = text.as_bytes();
            let values: Vec<f64> = (0..4)
                .map(|i| f64::from(bytes.get(i).copied().unwrap_or(1)) / 128.0)
                .collect();
            Ok::<_, EmbedderError>(Embedding::new(values))
        }))
    }

    /// Builds a schema over the [`items`] fixture with `embedder` attached as schema data.
    fn build_schema_for_tests(embedder: EmbedderHandle) -> TestSchema {
        let pipeline = build_pipeline_handle_from_items(&items()).expect("pipeline build failed");
        let builder =
            Schema::build(CategoryQueryRoot, EmptyMutation, EmptySubscription).data(pipeline);
        builder.data(embedder).finish()
    }

    /// Executes `query`, fails the test if the response has any errors, and returns
    /// the response data.
    async fn execute_ok(schema: &TestSchema, query: &str) -> async_graphql::Value {
        let response = schema.execute(query).await;
        assert!(response.errors.is_empty(), "errors: {:?}", response.errors);
        response.data
    }

    #[tokio::test]
    async fn category_neighbors_returns_results() {
        let schema = build_schema_for_tests(test_embedder());
        let data = execute_ok(
            &schema,
            r#"{ categoryNeighbors(category: "alpha", k: 2) { name memberCount } }"#,
        )
        .await
        .into_json()
        .unwrap();
        let neighbors = data["categoryNeighbors"].as_array().unwrap();
        assert!(!neighbors.is_empty());
    }

    #[tokio::test]
    async fn category_stats_returns_summaries() {
        let schema = build_schema_for_tests(test_embedder());
        let data = execute_ok(
            &schema,
            r#"{ categoryStats { summaries { name memberCount } innerSphereReports { categoryName } } }"#,
        )
        .await
        .into_json()
        .unwrap();
        let summaries = data["categoryStats"]["summaries"].as_array().unwrap();
        assert_eq!(summaries.len(), 3);
    }

    #[tokio::test]
    async fn category_concept_path_finds_path() {
        let schema = build_schema_for_tests(test_embedder());
        let data = execute_ok(
            &schema,
            r#"{ categoryConceptPath(sourceCategory: "alpha", targetCategory: "gamma") { totalDistance steps { categoryName } } }"#,
        )
        .await
        .into_json()
        .unwrap();
        assert!(data["categoryConceptPath"].is_object());
    }

    #[tokio::test]
    async fn drill_down_uses_embedder() {
        let schema = build_schema_for_tests(test_embedder());
        let data = execute_ok(
            &schema,
            r#"{ drillDown(category: "alpha", queryText: "alpha query", k: 3) { itemIndex distance usedInnerSphere } }"#,
        )
        .await
        .into_json()
        .unwrap();
        let hits = data["drillDown"].as_array().unwrap();
        assert!(!hits.is_empty());
    }

    #[tokio::test]
    async fn hierarchical_nearest_uses_embedder() {
        let schema = build_schema_for_tests(test_embedder());
        let data = execute_ok(
            &schema,
            r#"{ hierarchicalNearest(queryText: "something", k: 3) { id category distance } }"#,
        )
        .await
        .into_json()
        .unwrap();
        let hits = data["hierarchicalNearest"].as_array().unwrap();
        assert!(!hits.is_empty());
    }

    #[tokio::test]
    async fn domain_groups_returns_groups() {
        let schema = build_schema_for_tests(test_embedder());
        let data = execute_ok(&schema, r#"{ domainGroups { categoryNames totalItems } }"#)
            .await
            .into_json()
            .unwrap();
        assert!(data["domainGroups"].is_array());
    }

    #[tokio::test]
    async fn concept_path_uses_embedder() {
        let schema = build_schema_for_tests(test_embedder());
        // NOTE(review): `s-0000` / `s-0008` do not appear among the fixture ids (`a1`..`g4`);
        // this test only checks that the resolver returns no error. Confirm the intended ids.
        execute_ok(
            &schema,
            r#"{ conceptPath(sourceId: "s-0000", targetId: "s-0008", graphK: 4, queryText: "bridge") { totalDistance steps { id category } } }"#,
        )
        .await;
    }

    #[tokio::test]
    async fn text_query_without_embedder_errors_descriptively() {
        let schema = build_schema_for_tests(default_no_embedder_handle());
        let response = schema
            .execute(r#"{ hierarchicalNearest(queryText: "whatever", k: 3) { id } }"#)
            .await;
        assert!(!response.errors.is_empty(), "expected an error");
        let msg = &response.errors[0].message;
        assert!(msg.contains("no TextEmbedder configured"), "got: {msg}");
    }
}