Skip to main content

sphereql_embed/
pipeline.rs

1use sphereql_core::*;
2use sphereql_index::SpatialItem;
3
4use crate::category::{
5    BridgeItem, CategoryLayer, CategoryPath, CategorySummary, DrillDownResult, InnerSphereReport,
6};
7use crate::projection::{PcaProjection, Projection};
8use crate::query::{EmbeddingIndex, GlobResult, SlicingManifold};
9use crate::types::{Embedding, RadialStrategy};
10
11// ── Errors ─────────────────────────────────────────────────────────────────
12
13#[derive(Debug, Clone, thiserror::Error)]
14pub enum PipelineError {
15    #[error("categories length ({cat}) must equal embeddings length ({emb})")]
16    LengthMismatch { cat: usize, emb: usize },
17    #[error("need at least 3 embeddings, got {0}")]
18    TooFewEmbeddings(usize),
19}
20
21// ── Input contract ──────────────────────────────────────────────────────────
22
/// Input to construct a SphereQL pipeline.
///
/// - `categories`: one category string per sentence.
/// - `embeddings`: one `Vec<f64>` per sentence, all of the same dimensionality.
///
/// Both vectors must have the same length; `categories[i]` labels `embeddings[i]`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct PipelineInput {
    /// Category label for each sentence, parallel to `embeddings`.
    pub categories: Vec<String>,
    /// Raw embedding vector for each sentence, parallel to `categories`.
    pub embeddings: Vec<Vec<f64>>,
}
32
/// A query into the pipeline.
///
/// The embedding must have the same dimensionality as the pipeline's corpus.
#[derive(Debug, Clone, PartialEq)]
pub struct PipelineQuery {
    /// Raw (unprojected) query embedding.
    pub embedding: Vec<f64>,
}
38
39// ── Output types ────────────────────────────────────────────────────────────
40
/// One neighbor returned by a nearest-neighbor or similarity query.
///
/// Implements `PartialEq` so result lists can be compared in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct NearestResult {
    /// Id of the indexed item.
    pub id: String,
    /// Category label resolved for the item ("unknown" if it cannot be resolved).
    pub category: String,
    /// Distance between the query and this item as reported by the search.
    pub distance: f64,
    /// Certainty of this point's projection (0–1). Higher = more faithfully represented.
    pub certainty: f64,
    /// Semantic intensity (pre-normalization magnitude of original embedding).
    pub intensity: f64,
}
51
52#[derive(Debug, Clone)]
53pub struct PathResult {
54    pub steps: Vec<PipelinePathStep>,
55    pub total_distance: f64,
56}
57
/// A single step of a [`PathResult`].
///
/// Implements `PartialEq` so paths can be compared step by step in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct PipelinePathStep {
    /// Id of the item at this step.
    pub id: String,
    /// Category label resolved for the item ("unknown" if it cannot be resolved).
    pub category: String,
    /// Distance accumulated from the start of the path up to this step.
    pub cumulative_distance: f64,
}
64
/// Summary of one detected concept glob.
///
/// Implements `PartialEq` so detection output can be compared in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct GlobSummary {
    /// Identifier assigned to the glob by glob detection.
    pub id: usize,
    /// Centroid of the glob as a Cartesian `[x, y, z]` triple.
    pub centroid: [f64; 3],
    /// Number of items assigned to the glob.
    pub member_count: usize,
    /// Radius reported for the glob by glob detection.
    pub radius: f64,
    /// Up to three `(category, member count)` pairs, most frequent first.
    pub top_categories: Vec<(String, usize)>,
}
73
/// Result of fitting a local manifold around a query point.
///
/// All fields are copied from the fitted `SlicingManifold`. Implements
/// `PartialEq` so results can be compared in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct ManifoldResult {
    /// Centroid of the fitted manifold, `[x, y, z]`.
    pub centroid: [f64; 3],
    /// Normal vector of the fitted manifold, `[x, y, z]`.
    pub normal: [f64; 3],
    /// Variance ratio reported by the fit.
    pub variance_ratio: f64,
}
80
81/// Typed output from a pipeline query.
82#[derive(Debug, Clone)]
83pub enum SphereQLOutput {
84    Nearest(Vec<NearestResult>),
85    KNearest(Vec<NearestResult>),
86    ConceptPath(Option<PathResult>),
87    Globs(Vec<GlobSummary>),
88    LocalManifold(ManifoldResult),
89    // ── Phase 3: category-level outputs ─────────────────────────────────
90    /// Result of a category-level concept path query.
91    CategoryConceptPath(Option<CategoryPath>),
92    /// Nearest neighbor categories to a given category.
93    CategoryNeighbors(Vec<CategorySummary>),
94    /// Drill-down results within a single category.
95    DrillDown(Vec<DrillDownResult>),
96    /// Summary statistics for all categories and inner spheres.
97    CategoryStats {
98        summaries: Vec<CategorySummary>,
99        inner_sphere_reports: Vec<InnerSphereReport>,
100    },
101}
102
/// Typed query request.
///
/// String parameters are borrowed for the lifetime `'a`, so building a query
/// does not allocate. Implements `Debug`, `Clone` and `PartialEq` so requests
/// can be logged, reused and compared.
#[derive(Debug, Clone, PartialEq)]
pub enum SphereQLQuery<'a> {
    /// Find the k nearest neighbors to the query embedding.
    Nearest { k: usize },
    /// Find all neighbors within a cosine similarity threshold.
    SimilarAbove { min_cosine: f64 },
    /// Find the shortest concept path between two indexed items.
    ConceptPath {
        source_id: &'a str,
        target_id: &'a str,
        graph_k: usize,
    },
    /// Detect concept globs. k=None for auto-detection.
    DetectGlobs { k: Option<usize>, max_k: usize },
    /// Fit a local manifold around the query point.
    LocalManifold { neighborhood_k: usize },
    // ── Phase 3: category-level queries ─────────────────────────────────
    /// Find the shortest path between two categories through the category graph.
    CategoryConceptPath {
        source_category: &'a str,
        target_category: &'a str,
    },
    /// Find the k nearest neighbor categories to the given category.
    CategoryNeighbors { category: &'a str, k: usize },
    /// Drill down into a category: k-NN within the category, using the
    /// inner sphere's projection if available.
    DrillDown { category: &'a str, k: usize },
    /// Get summary statistics for all categories and inner spheres.
    CategoryStats,
}
133
134/// Projected data for a single item, suitable for export or visualization.
135#[derive(Debug, Clone, serde::Serialize)]
136pub struct ExportedPoint {
137    pub id: String,
138    pub category: String,
139    pub r: f64,
140    pub theta: f64,
141    pub phi: f64,
142    pub x: f64,
143    pub y: f64,
144    pub z: f64,
145    pub certainty: f64,
146    pub intensity: f64,
147}
148
149// ── Pipeline ──────────────────────────────────────────────────────────────
150
151pub struct SphereQLPipeline {
152    pca: PcaProjection,
153    index: EmbeddingIndex<PcaProjection>,
154    categories: Vec<String>,
155    cart_points: Vec<[f64; 3]>,
156    ids: Vec<String>,
157    /// Stored embeddings for category layer queries (drill-down, etc.).
158    _embeddings: Vec<Embedding>,
159    /// Category enrichment layer: summaries, graph, bridges, inner spheres.
160    category_layer: CategoryLayer,
161}
162
163impl SphereQLPipeline {
164    /// Build a pipeline from raw inputs, fitting a new PCA internally.
165    ///
166    /// - `input.categories[i]` is the category for sentence `i`
167    /// - `input.embeddings[i]` is the embedding vector for sentence `i`
168    /// - All embedding vectors must have the same dimensionality (>= 3).
169    pub fn new(input: PipelineInput) -> Result<Self, PipelineError> {
170        let embeddings: Vec<Embedding> = input
171            .embeddings
172            .iter()
173            .map(|v| Embedding::new(v.clone()))
174            .collect();
175
176        let pca = PcaProjection::fit(&embeddings, RadialStrategy::Magnitude).with_volumetric(true);
177        Self::with_projection(input.categories, embeddings, pca)
178    }
179
180    /// Build a pipeline from pre-computed embeddings and an existing PCA projection.
181    ///
182    /// Use this when the projection has already been fitted externally (e.g.,
183    /// by `VectorStoreBridge`) to avoid fitting a second PCA on the same data.
184    pub fn with_projection(
185        categories: Vec<String>,
186        embeddings: Vec<Embedding>,
187        pca: PcaProjection,
188    ) -> Result<Self, PipelineError> {
189        let n = embeddings.len();
190        if n != categories.len() {
191            return Err(PipelineError::LengthMismatch {
192                cat: categories.len(),
193                emb: n,
194            });
195        }
196        if n < 3 {
197            return Err(PipelineError::TooFewEmbeddings(n));
198        }
199
200        let mut index = EmbeddingIndex::builder(pca.clone())
201            .uniform_shells(10, 1.0)
202            .theta_divisions(12)
203            .phi_divisions(6)
204            .build();
205
206        let mut ids = Vec::with_capacity(n);
207        for (i, emb) in embeddings.iter().enumerate() {
208            let id = format!("s-{i:04}");
209            index.insert(&id, emb);
210            ids.push(id);
211        }
212
213        let cart_points: Vec<[f64; 3]> = embeddings
214            .iter()
215            .map(|e| {
216                let sp = pca.project(e);
217                let c = spherical_to_cartesian(&sp);
218                [c.x, c.y, c.z]
219            })
220            .collect();
221
222        // Build the category enrichment layer (Phase 1+2)
223        let projected_positions: Vec<SphericalPoint> =
224            embeddings.iter().map(|e| pca.project(e)).collect();
225
226        let category_layer =
227            CategoryLayer::build(&categories, &embeddings, &projected_positions, &pca);
228
229        Ok(Self {
230            pca,
231            index,
232            categories,
233            cart_points,
234            ids,
235            _embeddings: embeddings,
236            category_layer,
237        })
238    }
239
240    /// Execute a typed query against the pipeline.
241    pub fn query(&self, q: SphereQLQuery<'_>, query_embedding: &PipelineQuery) -> SphereQLOutput {
242        let emb = Embedding::new(query_embedding.embedding.clone());
243
244        match q {
245            SphereQLQuery::Nearest { k } => {
246                let results = self.index.search_nearest(&emb, k);
247                SphereQLOutput::Nearest(
248                    results
249                        .iter()
250                        .map(|r| NearestResult {
251                            id: r.item.id.clone(),
252                            category: self.cat_for(&r.item.id),
253                            distance: r.distance,
254                            certainty: r.item.certainty(),
255                            intensity: r.item.intensity(),
256                        })
257                        .collect(),
258                )
259            }
260
261            SphereQLQuery::SimilarAbove { min_cosine } => {
262                let results = self.index.search_similar(&emb, min_cosine);
263                let sp_q = self.pca.project(&emb);
264                SphereQLOutput::KNearest(
265                    results
266                        .items
267                        .iter()
268                        .map(|item| {
269                            let d = angular_distance(&sp_q, item.position());
270                            NearestResult {
271                                id: item.id.clone(),
272                                category: self.cat_for(&item.id),
273                                distance: d,
274                                certainty: item.certainty(),
275                                intensity: item.intensity(),
276                            }
277                        })
278                        .collect(),
279                )
280            }
281
282            SphereQLQuery::ConceptPath {
283                source_id,
284                target_id,
285                graph_k,
286            } => {
287                let path = self.index.concept_path(source_id, target_id, graph_k);
288                SphereQLOutput::ConceptPath(path.map(|p| {
289                    PathResult {
290                        total_distance: p.total_distance,
291                        steps: p
292                            .steps
293                            .iter()
294                            .map(|s| PipelinePathStep {
295                                id: s.id.clone(),
296                                category: self.cat_for(&s.id),
297                                cumulative_distance: s.cumulative_distance,
298                            })
299                            .collect(),
300                    }
301                }))
302            }
303
304            SphereQLQuery::DetectGlobs { k, max_k } => {
305                let result = GlobResult::detect(&self.cart_points, &self.ids, k, max_k);
306                SphereQLOutput::Globs(
307                    result
308                        .globs
309                        .iter()
310                        .map(|g| {
311                            let mut cat_counts = std::collections::HashMap::<String, usize>::new();
312                            for mid in &g.member_ids {
313                                let cat = self.cat_for(mid);
314                                *cat_counts.entry(cat).or_default() += 1;
315                            }
316                            let mut top: Vec<_> = cat_counts.into_iter().collect();
317                            top.sort_by_key(|(_, c)| std::cmp::Reverse(*c));
318                            top.truncate(3);
319
320                            GlobSummary {
321                                id: g.id,
322                                centroid: g.centroid,
323                                member_count: g.member_ids.len(),
324                                radius: g.radius,
325                                top_categories: top,
326                            }
327                        })
328                        .collect(),
329                )
330            }
331
332            SphereQLQuery::LocalManifold { neighborhood_k } => {
333                let sp = self.pca.project(&emb);
334                let c = spherical_to_cartesian(&sp);
335                let qpt = [c.x, c.y, c.z];
336                let m = SlicingManifold::fit_local(&qpt, &self.cart_points, neighborhood_k);
337                SphereQLOutput::LocalManifold(ManifoldResult {
338                    centroid: m.centroid,
339                    normal: m.normal,
340                    variance_ratio: m.variance_ratio,
341                })
342            }
343
344            // ── Phase 3: category-level query dispatch ─────────────────
345            SphereQLQuery::CategoryConceptPath {
346                source_category,
347                target_category,
348            } => {
349                let path = self
350                    .category_layer
351                    .category_path(source_category, target_category);
352                SphereQLOutput::CategoryConceptPath(path)
353            }
354
355            SphereQLQuery::CategoryNeighbors { category, k } => {
356                let neighbors = self.category_layer.category_neighbors(category, k);
357                SphereQLOutput::CategoryNeighbors(neighbors.into_iter().cloned().collect())
358            }
359
360            SphereQLQuery::DrillDown { category, k } => {
361                let results = self
362                    .category_layer
363                    .drill_down_with_projection(category, &emb, &self.pca, k);
364                SphereQLOutput::DrillDown(results)
365            }
366
367            SphereQLQuery::CategoryStats => SphereQLOutput::CategoryStats {
368                summaries: self.category_layer.summaries.clone(),
369                inner_sphere_reports: self.category_layer.inner_sphere_stats(),
370            },
371        }
372    }
373
374    /// Get the category for an indexed item by its id.
375    fn cat_for(&self, id: &str) -> String {
376        if let Some(idx_str) = id.strip_prefix("s-")
377            && let Ok(idx) = idx_str.parse::<usize>()
378            && idx < self.categories.len()
379        {
380            return self.categories[idx].clone();
381        }
382        "unknown".into()
383    }
384
385    pub fn num_items(&self) -> usize {
386        self.ids.len()
387    }
388
389    pub fn categories(&self) -> &[String] {
390        &self.categories
391    }
392
393    /// Export (id, category, cartesian [x, y, z]) triples for every indexed item.
394    pub fn projected_points(&self) -> Vec<(&str, &str, [f64; 3])> {
395        self.ids
396            .iter()
397            .enumerate()
398            .map(|(i, id)| {
399                let cat = self
400                    .categories
401                    .get(i)
402                    .map(|s| s.as_str())
403                    .unwrap_or("unknown");
404                (id.as_str(), cat, self.cart_points[i])
405            })
406            .collect()
407    }
408
409    /// Access the fitted PCA projection.
410    pub fn pca(&self) -> &PcaProjection {
411        &self.pca
412    }
413
414    /// Export all projected points with their Cartesian and spherical coordinates.
415    ///
416    /// Returns one `ExportedPoint` per indexed item, in insertion order.
417    pub fn exported_points(&self) -> Vec<ExportedPoint> {
418        self.ids
419            .iter()
420            .enumerate()
421            .map(|(i, id)| {
422                let [x, y, z] = self.cart_points[i];
423                let category = self
424                    .categories
425                    .get(i)
426                    .cloned()
427                    .unwrap_or_else(|| "unknown".into());
428                let item = self.index.get(id);
429                let (r, theta, phi) = item
430                    .map(|it| {
431                        let pos = it.position();
432                        (pos.r, pos.theta, pos.phi)
433                    })
434                    .unwrap_or((0.0, 0.0, 0.0));
435                let certainty = item.map_or(1.0, |it| it.certainty());
436                let intensity = item.map_or(1.0, |it| it.intensity());
437                ExportedPoint {
438                    id: id.clone(),
439                    category,
440                    r,
441                    theta,
442                    phi,
443                    x,
444                    y,
445                    z,
446                    certainty,
447                    intensity,
448                }
449            })
450            .collect()
451    }
452
453    /// The PCA projection's explained variance ratio (0.0–1.0).
454    pub fn explained_variance_ratio(&self) -> f64 {
455        self.pca.explained_variance_ratio()
456    }
457
458    /// Number of unique categories in the corpus.
459    pub fn num_categories(&self) -> usize {
460        self.category_layer.num_categories()
461    }
462
463    /// Unique category names in insertion order.
464    pub fn unique_categories(&self) -> Vec<String> {
465        self.category_layer
466            .summaries
467            .iter()
468            .map(|s| s.name.clone())
469            .collect()
470    }
471
472    // ── Phase 3: category-level accessors ──────────────────────────────
473
474    /// Access the category enrichment layer directly.
475    pub fn category_layer(&self) -> &CategoryLayer {
476        &self.category_layer
477    }
478
479    /// Shortcut: find the shortest path between two categories.
480    pub fn category_path(&self, source: &str, target: &str) -> Option<CategoryPath> {
481        self.category_layer.category_path(source, target)
482    }
483
484    /// Shortcut: get bridge items between two categories.
485    pub fn bridge_items(&self, source: &str, target: &str, max: usize) -> Vec<&BridgeItem> {
486        self.category_layer.bridge_items(source, target, max)
487    }
488
489    /// Shortcut: check if a category has an inner sphere.
490    pub fn has_inner_sphere(&self, category: &str) -> bool {
491        self.category_layer.has_inner_sphere(category)
492    }
493
494    /// Shortcut: number of categories with inner spheres.
495    pub fn num_inner_spheres(&self) -> usize {
496        self.category_layer.num_inner_spheres()
497    }
498
499    /// Shortcut: inner sphere statistics for all categories.
500    pub fn inner_sphere_stats(&self) -> Vec<InnerSphereReport> {
501        self.category_layer.inner_sphere_stats()
502    }
503
504    /// Serialize all projected points as a JSON array string.
505    pub fn to_json(&self) -> String {
506        serde_json::to_string(&self.exported_points())
507            .expect("ExportedPoint is always serializable")
508    }
509
510    /// Serialize all projected points as RFC 4180-compliant CSV with a header row.
511    ///
512    /// String fields (id, category) are quoted to handle embedded commas
513    /// and special characters safely.
514    pub fn to_csv(&self) -> String {
515        let points = self.exported_points();
516        let mut out = String::from("id,category,r,theta,phi,x,y,z,certainty,intensity\n");
517        for p in &points {
518            out.push_str(&format!(
519                "\"{}\",\"{}\",{:.6},{:.6},{:.6},{:.6},{:.6},{:.6},{:.6},{:.6}\n",
520                p.id.replace('"', "\"\""),
521                p.category.replace('"', "\"\""),
522                p.r,
523                p.theta,
524                p.phi,
525                p.x,
526                p.y,
527                p.z,
528                p.certainty,
529                p.intensity,
530            ));
531        }
532        out
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a synthetic corpus of `n` embeddings of dimension `dim` (which
    /// must be at least 3): the first `n / 2` items are labelled `group_a`
    /// and dominate axis 0, the rest are labelled `group_b` and dominate
    /// axis 1. Also returns a probe query with every component set to 0.9.
    fn make_input(n: usize, dim: usize) -> (PipelineInput, PipelineQuery) {
        let half = n / 2;
        let mut categories: Vec<String> = Vec::with_capacity(n);
        let mut embeddings: Vec<Vec<f64>> = Vec::with_capacity(n);

        for i in 0..n {
            let step = i as f64;
            let dominant = 1.0 + (step * 0.01);
            let mut v = vec![0.0; dim];
            let label = if i < half {
                v[0] = dominant;
                v[1] = 0.1;
                "group_a"
            } else {
                v[0] = 0.1;
                v[1] = dominant;
                "group_b"
            };
            v[2] = 0.05 * step;
            categories.push(label.into());
            embeddings.push(v);
        }

        let input = PipelineInput {
            categories,
            embeddings,
        };
        let query = PipelineQuery {
            embedding: vec![0.9; dim],
        };
        (input, query)
    }

    /// Builds a pipeline over the synthetic corpus and returns it together
    /// with the probe query.
    fn build(n: usize, dim: usize) -> (SphereQLPipeline, PipelineQuery) {
        let (input, query) = make_input(n, dim);
        (SphereQLPipeline::new(input).unwrap(), query)
    }

    // ── Item-level queries and exports ─────────────────────────────────

    #[test]
    fn pipeline_nearest() {
        let (pipeline, query) = build(20, 10);
        let SphereQLOutput::Nearest(items) =
            pipeline.query(SphereQLQuery::Nearest { k: 5 }, &query)
        else {
            panic!("expected Nearest");
        };
        assert_eq!(items.len(), 5);
        assert!(items[0].distance <= items[1].distance);
    }

    #[test]
    fn pipeline_globs() {
        let (pipeline, query) = build(30, 10);
        let request = SphereQLQuery::DetectGlobs {
            k: Some(2),
            max_k: 5,
        };
        let SphereQLOutput::Globs(globs) = pipeline.query(request, &query) else {
            panic!("expected Globs");
        };
        assert_eq!(globs.len(), 2);
        let total: usize = globs.iter().map(|g| g.member_count).sum();
        assert_eq!(total, 30);
    }

    #[test]
    fn pipeline_concept_path() {
        let (pipeline, query) = build(20, 10);
        let request = SphereQLQuery::ConceptPath {
            source_id: "s-0000",
            target_id: "s-0015",
            graph_k: 10,
        };
        let SphereQLOutput::ConceptPath(Some(path)) = pipeline.query(request, &query) else {
            panic!("expected ConceptPath(Some)");
        };
        assert!(path.steps.len() >= 2);
        assert_eq!(path.steps.first().unwrap().id, "s-0000");
        assert_eq!(path.steps.last().unwrap().id, "s-0015");
    }

    #[test]
    fn pipeline_local_manifold() {
        let (pipeline, query) = build(20, 10);
        let request = SphereQLQuery::LocalManifold { neighborhood_k: 10 };
        let SphereQLOutput::LocalManifold(m) = pipeline.query(request, &query) else {
            panic!("expected LocalManifold");
        };
        assert!(m.variance_ratio > 0.0);
        assert!(m.variance_ratio <= 1.0);
    }

    #[test]
    fn test_exported_points_count() {
        let (pipeline, _) = build(20, 10);
        assert_eq!(pipeline.exported_points().len(), 20);
    }

    #[test]
    fn test_exported_points_fields() {
        use std::f64::consts::{PI, TAU};

        let (pipeline, _) = build(20, 10);
        for p in pipeline.exported_points() {
            assert!(p.r >= 0.0, "r must be non-negative");
            assert!((0.0..TAU).contains(&p.theta), "theta out of range");
            assert!((0.0..=PI).contains(&p.phi), "phi out of range");
        }
    }

    #[test]
    fn test_exported_points_categories() {
        let (pipeline, _) = build(20, 10);
        for (i, p) in pipeline.exported_points().iter().enumerate() {
            let expected = if i < 10 { "group_a" } else { "group_b" };
            assert_eq!(p.category, expected);
        }
    }

    #[test]
    fn test_to_json_parseable() {
        let (pipeline, _) = build(20, 10);
        let json = pipeline.to_json();
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&json).expect("valid JSON");
        assert_eq!(parsed.len(), 20);
    }

    #[test]
    fn test_to_csv_lines() {
        let (pipeline, _) = build(20, 10);
        let csv = pipeline.to_csv();
        let mut lines = csv.lines();
        assert_eq!(
            lines.next().unwrap(),
            "id,category,r,theta,phi,x,y,z,certainty,intensity"
        );
        // The header is already consumed, so 20 data rows must remain (21 lines in total).
        assert_eq!(lines.count(), 20);
    }

    #[test]
    fn test_to_csv_quoted_fields() {
        let (pipeline, _) = build(20, 10);
        let csv = pipeline.to_csv();
        let first_row = csv.lines().nth(1).unwrap();
        assert!(first_row.starts_with('"'), "id field should be quoted");
    }

    #[test]
    fn test_explained_variance() {
        let (pipeline, _) = build(20, 10);
        let ratio = pipeline.explained_variance_ratio();
        assert!(ratio > 0.0 && ratio <= 1.0);
    }

    #[test]
    fn test_unique_categories() {
        let (pipeline, _) = build(20, 10);
        let cats = pipeline.unique_categories();
        assert_eq!(cats.len(), 2);
        assert_eq!(cats[0], "group_a");
        assert_eq!(cats[1], "group_b");
        assert_eq!(pipeline.num_categories(), 2);
    }

    // ── Phase 3 tests: category layer integration ──────────────────────

    #[test]
    fn pipeline_builds_category_layer() {
        let (pipeline, _) = build(20, 10);
        assert_eq!(pipeline.category_layer().num_categories(), 2);
    }

    #[test]
    fn pipeline_category_path_query() {
        let (pipeline, query) = build(20, 10);
        let request = SphereQLQuery::CategoryConceptPath {
            source_category: "group_a",
            target_category: "group_b",
        };
        let SphereQLOutput::CategoryConceptPath(Some(path)) = pipeline.query(request, &query)
        else {
            panic!("expected CategoryConceptPath(Some)");
        };
        assert!(path.steps.len() >= 2);
        assert_eq!(path.steps.first().unwrap().category_name, "group_a");
        assert_eq!(path.steps.last().unwrap().category_name, "group_b");
        assert!(path.total_distance > 0.0);
    }

    #[test]
    fn pipeline_category_path_shortcut() {
        let (pipeline, _) = build(20, 10);
        let found = pipeline.category_path("group_a", "group_b");
        assert!(found.is_some());
        let path = found.unwrap();
        assert_eq!(path.steps.first().unwrap().category_name, "group_a");
        assert_eq!(path.steps.last().unwrap().category_name, "group_b");
    }

    #[test]
    fn pipeline_category_path_unknown() {
        let (pipeline, _) = build(20, 10);
        assert!(pipeline.category_path("group_a", "nonexistent").is_none());
    }

    #[test]
    fn pipeline_category_neighbors_query() {
        let (pipeline, query) = build(20, 10);
        let request = SphereQLQuery::CategoryNeighbors {
            category: "group_a",
            k: 5,
        };
        let SphereQLOutput::CategoryNeighbors(neighbors) = pipeline.query(request, &query) else {
            panic!("expected CategoryNeighbors");
        };
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].name, "group_b");
    }

    #[test]
    fn pipeline_drill_down_query() {
        let (pipeline, query) = build(20, 10);
        let request = SphereQLQuery::DrillDown {
            category: "group_a",
            k: 5,
        };
        let SphereQLOutput::DrillDown(results) = pipeline.query(request, &query) else {
            panic!("expected DrillDown");
        };
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        // Results must come back ordered by non-decreasing distance.
        assert!(results.windows(2).all(|w| w[0].distance <= w[1].distance));
    }

    #[test]
    fn pipeline_category_stats_query() {
        let (pipeline, query) = build(20, 10);
        let SphereQLOutput::CategoryStats {
            summaries,
            inner_sphere_reports,
        } = pipeline.query(SphereQLQuery::CategoryStats, &query)
        else {
            panic!("expected CategoryStats");
        };
        assert_eq!(summaries.len(), 2);
        assert_eq!(inner_sphere_reports.len(), 0);
    }

    #[test]
    fn pipeline_bridge_items_shortcut() {
        // Smoke test: the call only has to complete without panicking.
        let (pipeline, _) = build(20, 10);
        let _ = pipeline.bridge_items("group_a", "group_b", 5);
    }

    #[test]
    fn pipeline_inner_sphere_shortcuts() {
        let (pipeline, _) = build(20, 10);
        assert!(!pipeline.has_inner_sphere("group_a"));
        assert_eq!(pipeline.num_inner_spheres(), 0);
        assert!(pipeline.inner_sphere_stats().is_empty());
    }

    #[test]
    fn pipeline_category_layer_accessor() {
        let (pipeline, _) = build(20, 10);
        let layer = pipeline.category_layer();
        assert_eq!(layer.num_categories(), 2);
        assert!(layer.get_category("group_a").is_some());
        assert!(layer.get_category("group_b").is_some());
    }
}
861}