Skip to main content

sphereql_embed/
pipeline.rs

1use sphereql_core::*;
2use sphereql_index::SpatialItem;
3
4use crate::category::{
5    BridgeItem, CategoryLayer, CategoryPath, CategorySummary, DrillDownResult, InnerSphereReport,
6};
7use crate::confidence::{ProjectionWarning, QualityConfig, QualitySignal};
8use crate::config::{PipelineConfig, ProjectionKind};
9use crate::configured_projection::ConfiguredProjection;
10use crate::corpus_features::CorpusFeatures;
11use crate::domain_groups::{DomainGroup, detect_domain_groups};
12use crate::kernel_pca::KernelPcaProjection;
13use crate::laplacian::LaplacianEigenmapProjection;
14use crate::meta_model::MetaModel;
15use crate::projection::{PcaProjection, Projection};
16use crate::quality_metric::QualityMetric;
17use crate::query::{EmbeddingIndex, GlobResult, SlicingManifold};
18use crate::tuner::{SearchSpace, SearchStrategy, TuneReport, auto_tune};
19use crate::types::{Embedding, RadialStrategy};
20
// ── Errors ─────────────────────────────────────────────────────────────────

/// Reasons a pipeline build or query can fail.
///
/// Construction validates input shape (`LengthMismatch`,
/// `TooFewEmbeddings`) and surfaces projection-fit failures
/// (`Projection`); [`SphereQLPipeline::query`] reports unknown ids and
/// categories (`UnknownId`, `UnknownCategory`); the auto-tuner reports
/// `AllTrialsFailed` when no trial produced a usable pipeline.
#[derive(Debug, Clone, thiserror::Error)]
pub enum PipelineError {
    /// `categories` and `embeddings` had different lengths — they must
    /// match one-to-one. `cat` is the number of category labels, `emb`
    /// the number of embeddings.
    #[error("categories length ({cat}) must equal embeddings length ({emb})")]
    LengthMismatch { cat: usize, emb: usize },
    /// Fewer than 3 embeddings — not enough to fit a 3D projection.
    /// Carries the number of embeddings actually supplied.
    #[error("need at least 3 embeddings, got {0}")]
    TooFewEmbeddings(usize),
    /// Projection fit rejected the input. Wraps
    /// [`ProjectionError`](crate::projection::ProjectionError) — empty
    /// corpus, dim-too-low, inconsistent dim, invalid sigma, etc.
    #[error("projection fit failed: {0}")]
    Projection(#[from] crate::projection::ProjectionError),
    /// A category-keyed query referenced a name that doesn't exist in
    /// the pipeline. Previously these paths silently returned empty
    /// results / `None`, which callers couldn't distinguish from
    /// "category exists but is disconnected on the graph".
    #[error("unknown category: {0:?}")]
    UnknownCategory(String),
    /// A concept-path query referenced an id that doesn't exist in the
    /// pipeline.
    #[error("unknown id: {0:?}")]
    UnknownId(String),
    /// Every [`auto_tune`](crate::tuner::auto_tune) trial failed a
    /// downstream validator (e.g. every candidate config was rejected
    /// by the pipeline builder). The attached `failures` list carries
    /// the `(config, error)` pairs the tuner observed — callers should
    /// inspect them to find the real cause; the outer error is just a
    /// roll-up saying "none of the trials produced a scorable pipeline".
    #[error("auto_tune produced no successful trials ({} failures)", failures.len())]
    AllTrialsFailed {
        failures: Vec<(crate::config::PipelineConfig, String)>,
    },
}
59
// ── Input contract ──────────────────────────────────────────────────────────

/// Input to construct a SphereQL pipeline.
///
/// - `categories`: one category label per sentence; `categories[i]`
///   labels `embeddings[i]`.
/// - `embeddings`: one `Vec<f64>` per sentence, all of the same
///   dimensionality.
///
/// Both vectors must have the same length and hold at least 3 entries.
/// Violations are reported when the pipeline is built, as
/// [`PipelineError::LengthMismatch`] / [`PipelineError::TooFewEmbeddings`].
///
/// Derives `Debug` and `Clone` so callers can log an input or keep a
/// copy before handing it to a consuming constructor.
#[derive(Debug, Clone)]
pub struct PipelineInput {
    /// Category label for each sentence, index-aligned with `embeddings`.
    pub categories: Vec<String>,
    /// Raw embedding vector for each sentence.
    pub embeddings: Vec<Vec<f64>>,
}
71
/// A query into the pipeline.
///
/// Holds a single raw embedding, which must have the same dimensionality
/// as the pipeline's corpus. Only the embedding-driven query kinds
/// (nearest / similarity / local-manifold / drill-down) use its contents.
/// Id- and category-keyed queries ignore it.
#[derive(Debug, Clone)]
pub struct PipelineQuery {
    /// Raw query embedding (not yet projected).
    pub embedding: Vec<f64>,
}
77
// ── Output types ────────────────────────────────────────────────────────────

/// One item returned from a nearest-neighbor or similarity query.
///
/// All fields use the pipeline's configured projection to derive
/// distances and quality signals; callers should treat results as
/// comparable within a single pipeline but not across pipelines with
/// different projection kinds.
#[derive(Debug, Clone)]
pub struct NearestResult {
    /// Item id assigned by the pipeline when the corpus was indexed:
    /// `"s-0000"`, `"s-0001"`, … in input order. Callers do not supply
    /// ids to [`SphereQLPipeline::new`]; the pipeline generates them.
    pub id: String,
    /// Category label from the input (`"unknown"` if the id cannot be
    /// mapped back to an input position).
    pub category: String,
    /// Angular distance on S² between the query and this item's
    /// projected position, in radians. Similarity-threshold queries
    /// compute it with `angular_distance` on the projected query;
    /// nearest-neighbor queries report the spatial index's own search
    /// distance.
    pub distance: f64,
    /// Certainty of this point's projection (0–1). Higher = more faithfully represented.
    pub certainty: f64,
    /// Semantic intensity (pre-normalization magnitude of original embedding).
    pub intensity: f64,
    /// Combined quality signal: EVR × certainty × gap_confidence, built by
    /// `QualitySignal::from_certainty` from the projection EVR and
    /// `certainty`.
    /// Always `Some(...)` for results the pipeline produces today; the
    /// `Option` is kept so callers that construct `NearestResult`
    /// outside the pipeline (e.g. mocks, tests) can omit it.
    pub quality: Option<QualitySignal>,
}
105
/// Concept-path result: ordered steps between two indexed items, with
/// cumulative angular distance along the path.
#[derive(Debug, Clone)]
pub struct PathResult {
    /// Steps in path order; the first step has `hop_distance == 0.0`.
    pub steps: Vec<PipelinePathStep>,
    /// Total path length as reported by the index's concept-path search.
    pub total_distance: f64,
}
113
/// One step along a [`PathResult`].
#[derive(Debug, Clone)]
pub struct PipelinePathStep {
    /// Pipeline-assigned item id (`"s-NNNN"`).
    pub id: String,
    /// Category of the item at this step (`"unknown"` if the id cannot be
    /// mapped back to an input position).
    pub category: String,
    /// Distance accumulated from the start of the path up to this step.
    pub cumulative_distance: f64,
    /// Angular distance of this individual hop (0.0 for the first step).
    pub hop_distance: f64,
    /// Bridge strength used on cross-category hops (None for same-category or unbridged paths).
    pub bridge_strength: Option<f64>,
}
125
/// Summary of one cluster detected by `DetectGlobs`.
#[derive(Debug, Clone)]
pub struct GlobSummary {
    /// Glob identifier as assigned by the glob detector.
    pub id: usize,
    /// Cluster centroid. The detector runs on the pipeline's Cartesian
    /// projected points, so this is presumably `[x, y, z]` in that space.
    pub centroid: [f64; 3],
    /// Number of items assigned to this glob.
    pub member_count: usize,
    /// Cluster radius as reported by the glob detector.
    pub radius: f64,
    /// Up to three `(category, member count)` pairs, most frequent first.
    pub top_categories: Vec<(String, usize)>,
}
135
/// Local 3-D manifold fitted around the query point.
///
/// Produced by fitting `SlicingManifold::fit_local` to the query's
/// projected Cartesian position and the pipeline's projected points.
#[derive(Debug, Clone)]
pub struct ManifoldResult {
    /// Centroid of the fitted neighborhood, as reported by the fit.
    pub centroid: [f64; 3],
    /// Normal vector of the fitted manifold, as reported by the fit.
    pub normal: [f64; 3],
    /// Variance ratio reported by the fit (presumably how much of the
    /// neighborhood's variance the fitted surface captures — confirm in
    /// `SlicingManifold`).
    pub variance_ratio: f64,
}
143
/// Typed output from a pipeline query.
///
/// Each variant is produced by one [`SphereQLQuery`] variant; see the
/// per-variant notes for the mapping.
#[derive(Debug, Clone)]
pub enum SphereQLOutput {
    /// Result of [`SphereQLQuery::Nearest`]. Hits rejected by the
    /// pipeline's quality filter are dropped, so this may hold fewer
    /// than `k` items.
    Nearest(Vec<NearestResult>),
    /// Result of [`SphereQLQuery::SimilarAbove`]. Despite the name, this
    /// is the cosine-threshold search, not a k-NN search; it is also
    /// quality-filtered.
    KNearest(Vec<NearestResult>),
    /// Result of [`SphereQLQuery::ConceptPath`]; `None` when the index
    /// found no path between the two (known) items.
    ConceptPath(Option<PathResult>),
    /// Result of [`SphereQLQuery::DetectGlobs`]: one summary per glob.
    Globs(Vec<GlobSummary>),
    /// Result of [`SphereQLQuery::LocalManifold`].
    LocalManifold(ManifoldResult),
    // ── Phase 3: category-level outputs ─────────────────────────────────
    /// Result of a category-level concept path query.
    CategoryConceptPath(Option<CategoryPath>),
    /// Nearest neighbor categories to a given category.
    CategoryNeighbors(Vec<CategorySummary>),
    /// Drill-down results within a single category.
    DrillDown(Vec<DrillDownResult>),
    /// Summary statistics for all categories and inner spheres.
    CategoryStats {
        summaries: Vec<CategorySummary>,
        inner_sphere_reports: Vec<InnerSphereReport>,
    },
}
165
/// Typed query request.
///
/// Borrowed `&str` fields tie a request to the caller's strings for the
/// duration of the query. Every field is `Copy`, so the request itself
/// derives `Copy` and can be reused across several `query` calls without
/// cloning.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SphereQLQuery<'a> {
    /// Find the k nearest neighbors to the query embedding.
    Nearest { k: usize },
    /// Find all neighbors within a cosine similarity threshold.
    SimilarAbove { min_cosine: f64 },
    /// Find the shortest concept path between two indexed items.
    ///
    /// `graph_k` is forwarded to the index's concept-path search
    /// (presumably the neighbor count of the search graph — confirm in
    /// `EmbeddingIndex::concept_path`).
    ConceptPath {
        source_id: &'a str,
        target_id: &'a str,
        graph_k: usize,
    },
    /// Detect concept globs. k=None for auto-detection; `max_k` is
    /// forwarded to the glob detector alongside `k`.
    DetectGlobs { k: Option<usize>, max_k: usize },
    /// Fit a local manifold around the query point using
    /// `neighborhood_k` neighboring points.
    LocalManifold { neighborhood_k: usize },
    // ── Phase 3: category-level queries ─────────────────────────────────
    /// Find the shortest path between two categories through the category graph.
    CategoryConceptPath {
        source_category: &'a str,
        target_category: &'a str,
    },
    /// Find the k nearest neighbor categories to the given category.
    CategoryNeighbors { category: &'a str, k: usize },
    /// Drill down into a category: k-NN within the category, using the
    /// inner sphere's projection if available.
    DrillDown { category: &'a str, k: usize },
    /// Get summary statistics for all categories and inner spheres.
    CategoryStats,
}
196
/// Projected data for a single item, suitable for export or visualization.
///
/// Produced by [`SphereQLPipeline::exported_points`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct ExportedPoint {
    /// Pipeline-assigned item id (`"s-NNNN"`).
    pub id: String,
    /// Category label from the input (`"unknown"` if missing).
    pub category: String,
    /// Spherical radius of the indexed position (0.0 if the item is not
    /// found in the index).
    pub r: f64,
    /// Spherical `theta` of the indexed position (0.0 if not found).
    pub theta: f64,
    /// Spherical `phi` of the indexed position (0.0 if not found).
    pub phi: f64,
    /// Cartesian x of the projected position.
    pub x: f64,
    /// Cartesian y of the projected position.
    pub y: f64,
    /// Cartesian z of the projected position.
    pub z: f64,
    /// Projection certainty of the indexed item (1.0 if not found).
    pub certainty: f64,
    /// Semantic intensity of the indexed item (1.0 if not found).
    pub intensity: f64,
}
211
// ── Pipeline ──────────────────────────────────────────────────────────────

/// The main SphereQL pipeline: fitted projection + spatial index +
/// category enrichment layer + optional tunable config.
///
/// Build one with [`Self::new`] for defaults,
/// [`Self::new_with_config`] for an explicit [`PipelineConfig`], or
/// [`Self::new_from_metamodel`] / [`Self::new_from_metamodel_tuned`]
/// to consult a trained meta-model on past tuner runs.
///
/// `ids`, `categories` and `cart_points` are index-aligned: entry `i` of
/// each describes input item `i`.
pub struct SphereQLPipeline {
    /// Fitted outer-sphere projection (PCA, kernel PCA, or Laplacian).
    projection: ConfiguredProjection,
    /// Spatial index over the projected items, keyed by `ids`.
    index: EmbeddingIndex<ConfiguredProjection>,
    /// Per-item category labels, in input order.
    categories: Vec<String>,
    /// Per-item projected positions in Cartesian form, in input order.
    cart_points: Vec<[f64; 3]>,
    /// Per-item ids generated at build time (`"s-0000"`, `"s-0001"`, …).
    ids: Vec<String>,
    /// Category enrichment layer: summaries, graph, bridges, inner spheres.
    category_layer: CategoryLayer,
    /// Quality configuration for filtering and warnings.
    quality_config: QualityConfig,
    /// Projection quality warnings (empty if EVR is above threshold).
    projection_warnings: Vec<ProjectionWarning>,
    /// Hierarchical domain groups detected from Voronoi adjacency + cap overlap.
    /// Used by [`SphereQLPipeline::route_to_group`] and
    /// [`SphereQLPipeline::hierarchical_nearest`] for coarse routing when EVR is low.
    domain_groups: Vec<DomainGroup>,
    /// Full tunable configuration used at build time.
    config: PipelineConfig,
}
240
241impl SphereQLPipeline {
242    /// Build a pipeline from raw inputs with [`PipelineConfig::default`].
243    ///
244    /// - `input.categories[i]` is the category for sentence `i`
245    /// - `input.embeddings[i]` is the embedding vector for sentence `i`
246    /// - All embedding vectors must have the same dimensionality (>= 3).
247    pub fn new(input: PipelineInput) -> Result<Self, PipelineError> {
248        Self::new_with_config(input, PipelineConfig::default())
249    }
250
251    /// Build a pipeline with an explicit configuration. Fits the projection
252    /// internally using [`PipelineConfig::projection_kind`] and any relevant
253    /// sub-config (e.g. [`LaplacianConfig`](crate::config::LaplacianConfig)).
254    pub fn new_with_config(
255        input: PipelineInput,
256        config: PipelineConfig,
257    ) -> Result<Self, PipelineError> {
258        let embeddings: Vec<Embedding> = input
259            .embeddings
260            .iter()
261            .map(|v| Embedding::new(v.clone()))
262            .collect();
263
264        let projection = fit_projection_for_config(&embeddings, &config)?;
265        Self::with_configured_projection_and_config(
266            input.categories,
267            embeddings,
268            projection,
269            config,
270        )
271    }
272
273    /// Build a pipeline using a config predicted by a [`MetaModel`].
274    ///
275    /// Extracts [`CorpusFeatures`] from the input, asks the model for a
276    /// predicted [`PipelineConfig`], then builds the pipeline with it.
277    /// Returns the pipeline alongside the extracted features and the
278    /// predicted config so the caller can log, audit, or save them as a
279    /// new [`MetaTrainingRecord`](crate::meta_model::MetaTrainingRecord).
280    ///
281    /// This is the "tune-or-recall" entry point: once you've accumulated
282    /// a handful of training records, call this instead of
283    /// [`crate::tuner::auto_tune`] when you want to skip
284    /// search entirely. For a warm-start hybrid that does some tuning
285    /// on top of the prediction, use [`Self::new_from_metamodel_tuned`].
286    pub fn new_from_metamodel<M: MetaModel>(
287        input: PipelineInput,
288        model: &M,
289    ) -> Result<(Self, CorpusFeatures, PipelineConfig), PipelineError> {
290        let features = CorpusFeatures::extract(&input.categories, &input.embeddings);
291        let predicted = model.predict(&features);
292        let pipeline = Self::new_with_config(input, predicted.clone())?;
293        Ok((pipeline, features, predicted))
294    }
295
296    /// Warm-started hybrid: predict a config with `model`, then run a
297    /// small-budget tuner pass using that prediction as `base_config`.
298    ///
299    /// Non-tuned knobs stay at the model's predicted values; the
300    /// searched knobs explore the given [`SearchSpace`] from there.
301    /// When the meta-model has seen a similar corpus before the
302    /// prediction is usually close to optimal and the tuner only needs
303    /// a handful of trials to confirm or refine it — meaningfully
304    /// cheaper than cold-starting at [`PipelineConfig::default`].
305    ///
306    /// Returns the winning pipeline, the extracted corpus features, and
307    /// the full [`TuneReport`]. Callers can feed the report back into
308    /// [`MetaTrainingRecord::from_tune_result`](crate::meta_model::MetaTrainingRecord::from_tune_result)
309    /// to accumulate more training data for the next recall.
310    pub fn new_from_metamodel_tuned<M, Q>(
311        input: PipelineInput,
312        model: &M,
313        space: &SearchSpace,
314        metric: &Q,
315        strategy: SearchStrategy,
316    ) -> Result<(Self, CorpusFeatures, TuneReport), PipelineError>
317    where
318        M: MetaModel,
319        Q: QualityMetric,
320    {
321        let features = CorpusFeatures::extract(&input.categories, &input.embeddings);
322        let predicted = model.predict(&features);
323        let (pipeline, report) = auto_tune(input, space, metric, strategy, &predicted)?;
324        Ok((pipeline, features, report))
325    }
326
327    /// Build a pipeline from pre-computed embeddings and an existing PCA
328    /// projection, with [`PipelineConfig::default`].
329    ///
330    /// This is the legacy entry point — use
331    /// [`Self::with_configured_projection_and_config`] directly when you
332    /// have a non-PCA [`ConfiguredProjection`].
333    pub fn with_projection(
334        categories: Vec<String>,
335        embeddings: Vec<Embedding>,
336        pca: PcaProjection,
337    ) -> Result<Self, PipelineError> {
338        Self::with_configured_projection_and_config(
339            categories,
340            embeddings,
341            ConfiguredProjection::Pca(pca),
342            PipelineConfig::default(),
343        )
344    }
345
346    /// Legacy configurable PCA entry point. Prefer
347    /// [`Self::with_configured_projection_and_config`] for new code.
348    pub fn with_projection_and_config(
349        categories: Vec<String>,
350        embeddings: Vec<Embedding>,
351        pca: PcaProjection,
352        config: PipelineConfig,
353    ) -> Result<Self, PipelineError> {
354        Self::with_configured_projection_and_config(
355            categories,
356            embeddings,
357            ConfiguredProjection::Pca(pca),
358            config,
359        )
360    }
361
362    /// Core pipeline constructor: accepts any [`ConfiguredProjection`] and
363    /// a [`PipelineConfig`].
364    pub fn with_configured_projection_and_config(
365        categories: Vec<String>,
366        embeddings: Vec<Embedding>,
367        projection: ConfiguredProjection,
368        config: PipelineConfig,
369    ) -> Result<Self, PipelineError> {
370        let n = embeddings.len();
371        if n != categories.len() {
372            return Err(PipelineError::LengthMismatch {
373                cat: categories.len(),
374                emb: n,
375            });
376        }
377        if n < 3 {
378            return Err(PipelineError::TooFewEmbeddings(n));
379        }
380
381        let mut index = EmbeddingIndex::builder(projection.clone())
382            .uniform_shells(10, 1.0)
383            .theta_divisions(12)
384            .phi_divisions(6)
385            .build();
386
387        let mut ids = Vec::with_capacity(n);
388        for (i, emb) in embeddings.iter().enumerate() {
389            let id = format!("s-{i:04}");
390            index.insert(&id, emb);
391            ids.push(id);
392        }
393
394        // Project each embedding exactly once. The category layer
395        // needs the spherical positions and the glob / manifold paths
396        // need their Cartesian form; both are derived from the same
397        // `projection.project(e)` so building them together saves an
398        // N-way pass over the projection (which can be expensive for
399        // kernel PCA / Laplacian).
400        let projected_positions: Vec<SphericalPoint> =
401            embeddings.iter().map(|e| projection.project(e)).collect();
402        let cart_points: Vec<[f64; 3]> = projected_positions
403            .iter()
404            .map(|sp| {
405                let c = spherical_to_cartesian(sp);
406                [c.x, c.y, c.z]
407            })
408            .collect();
409
410        let evr = projection.explained_variance_ratio();
411        let category_layer = CategoryLayer::build_with_config(
412            &categories,
413            &embeddings,
414            &projected_positions,
415            &projection,
416            evr,
417            &config,
418        );
419
420        let quality_config = QualityConfig::default();
421        let projection_warnings = ProjectionWarning::from_evr(evr, quality_config.warn_below_evr)
422            .into_iter()
423            .collect();
424
425        let domain_groups = detect_domain_groups(&category_layer, config.routing.num_domain_groups);
426
427        Ok(Self {
428            projection,
429            index,
430            categories,
431            cart_points,
432            ids,
433            category_layer,
434            quality_config,
435            projection_warnings,
436            domain_groups,
437            config,
438        })
439    }
440
441    /// True if `name` is a known category in this pipeline. Pair with
442    /// [`Self::query`] to disambiguate "unknown category" from
443    /// "category exists but is disconnected on the graph" without
444    /// pattern-matching on `PipelineError::UnknownCategory`.
445    pub fn has_category(&self, name: &str) -> bool {
446        self.category_layer.name_to_index.contains_key(name)
447    }
448
449    /// True if `id` is an indexed item in this pipeline.
450    pub fn has_id(&self, id: &str) -> bool {
451        self.index.get(id).is_some()
452    }
453
454    /// Execute a typed query against the pipeline.
455    ///
456    /// Returns [`PipelineError::UnknownCategory`] when a category
457    /// query references a name not in the pipeline, and
458    /// [`PipelineError::UnknownId`] when a concept-path query
459    /// references an id not in the index. Previously those paths
460    /// collapsed into empty results / `None`, which callers couldn't
461    /// distinguish from legitimate "found nothing" outcomes.
462    pub fn query(
463        &self,
464        q: SphereQLQuery<'_>,
465        query_embedding: &PipelineQuery,
466    ) -> Result<SphereQLOutput, PipelineError> {
467        let emb = Embedding::new(query_embedding.embedding.clone());
468
469        match q {
470            SphereQLQuery::Nearest { k } => {
471                let evr = self.projection.explained_variance_ratio();
472                let results = self.index.search_nearest(&emb, k);
473                Ok(SphereQLOutput::Nearest(
474                    results
475                        .iter()
476                        .map(|r| {
477                            let certainty = r.item.certainty();
478                            let quality = QualitySignal::from_certainty(evr, certainty);
479                            NearestResult {
480                                id: r.item.id.clone(),
481                                category: self.cat_for(&r.item.id),
482                                distance: r.distance,
483                                certainty,
484                                intensity: r.item.intensity(),
485                                quality: Some(quality),
486                            }
487                        })
488                        .filter(|r| self.passes_quality(r))
489                        .collect(),
490                ))
491            }
492
493            SphereQLQuery::SimilarAbove { min_cosine } => {
494                let evr = self.projection.explained_variance_ratio();
495                let results = self.index.search_similar(&emb, min_cosine);
496                let sp_q = self.projection.project(&emb);
497                Ok(SphereQLOutput::KNearest(
498                    results
499                        .items
500                        .iter()
501                        .map(|item| {
502                            let d = angular_distance(&sp_q, item.position());
503                            let certainty = item.certainty();
504                            let quality = QualitySignal::from_certainty(evr, certainty);
505                            NearestResult {
506                                id: item.id.clone(),
507                                category: self.cat_for(&item.id),
508                                distance: d,
509                                certainty,
510                                intensity: item.intensity(),
511                                quality: Some(quality),
512                            }
513                        })
514                        .filter(|r| self.passes_quality(r))
515                        .collect(),
516                ))
517            }
518
519            SphereQLQuery::ConceptPath {
520                source_id,
521                target_id,
522                graph_k,
523            } => {
524                if !self.has_id(source_id) {
525                    return Err(PipelineError::UnknownId(source_id.to_string()));
526                }
527                if !self.has_id(target_id) {
528                    return Err(PipelineError::UnknownId(target_id.to_string()));
529                }
530                let path = self.index.concept_path(source_id, target_id, graph_k);
531                Ok(SphereQLOutput::ConceptPath(path.map(|p| {
532                    PathResult {
533                        total_distance: p.total_distance,
534                        steps: p
535                            .steps
536                            .iter()
537                            .map(|s| PipelinePathStep {
538                                id: s.id.clone(),
539                                category: self.cat_for(&s.id),
540                                cumulative_distance: s.cumulative_distance,
541                                hop_distance: s.hop_distance,
542                                bridge_strength: s.bridge_strength,
543                            })
544                            .collect(),
545                    }
546                })))
547            }
548
549            SphereQLQuery::DetectGlobs { k, max_k } => {
550                let result = GlobResult::detect(&self.cart_points, &self.ids, k, max_k);
551                Ok(SphereQLOutput::Globs(
552                    result
553                        .globs
554                        .iter()
555                        .map(|g| {
556                            let mut cat_counts = std::collections::HashMap::<String, usize>::new();
557                            for mid in &g.member_ids {
558                                let cat = self.cat_for(mid);
559                                *cat_counts.entry(cat).or_default() += 1;
560                            }
561                            let mut top: Vec<_> = cat_counts.into_iter().collect();
562                            top.sort_by_key(|(_, c)| std::cmp::Reverse(*c));
563                            top.truncate(3);
564
565                            GlobSummary {
566                                id: g.id,
567                                centroid: g.centroid,
568                                member_count: g.member_ids.len(),
569                                radius: g.radius,
570                                top_categories: top,
571                            }
572                        })
573                        .collect(),
574                ))
575            }
576
577            SphereQLQuery::LocalManifold { neighborhood_k } => {
578                let sp = self.projection.project(&emb);
579                let c = spherical_to_cartesian(&sp);
580                let qpt = [c.x, c.y, c.z];
581                let m = SlicingManifold::fit_local(&qpt, &self.cart_points, neighborhood_k);
582                Ok(SphereQLOutput::LocalManifold(ManifoldResult {
583                    centroid: m.centroid,
584                    normal: m.normal,
585                    variance_ratio: m.variance_ratio,
586                }))
587            }
588
589            // ── Phase 3: category-level query dispatch ─────────────────
590            SphereQLQuery::CategoryConceptPath {
591                source_category,
592                target_category,
593            } => {
594                if !self.has_category(source_category) {
595                    return Err(PipelineError::UnknownCategory(source_category.to_string()));
596                }
597                if !self.has_category(target_category) {
598                    return Err(PipelineError::UnknownCategory(target_category.to_string()));
599                }
600                let path = self
601                    .category_layer
602                    .category_path(source_category, target_category);
603                Ok(SphereQLOutput::CategoryConceptPath(path))
604            }
605
606            SphereQLQuery::CategoryNeighbors { category, k } => {
607                if !self.has_category(category) {
608                    return Err(PipelineError::UnknownCategory(category.to_string()));
609                }
610                let neighbors = self.category_layer.category_neighbors(category, k);
611                Ok(SphereQLOutput::CategoryNeighbors(
612                    neighbors.into_iter().cloned().collect(),
613                ))
614            }
615
616            SphereQLQuery::DrillDown { category, k } => {
617                if !self.has_category(category) {
618                    return Err(PipelineError::UnknownCategory(category.to_string()));
619                }
620                let results = self.category_layer.drill_down_with_projection(
621                    category,
622                    &emb,
623                    &self.projection,
624                    k,
625                );
626                Ok(SphereQLOutput::DrillDown(results))
627            }
628
629            SphereQLQuery::CategoryStats => Ok(SphereQLOutput::CategoryStats {
630                summaries: self.category_layer.summaries.clone(),
631                inner_sphere_reports: self.category_layer.inner_sphere_stats(),
632            }),
633        }
634    }
635
636    /// Get the category for an indexed item by its id.
637    fn cat_for(&self, id: &str) -> String {
638        if let Some(idx_str) = id.strip_prefix("s-")
639            && let Ok(idx) = idx_str.parse::<usize>()
640            && idx < self.categories.len()
641        {
642            return self.categories[idx].clone();
643        }
644        "unknown".into()
645    }
646
647    /// Total number of indexed items.
648    pub fn num_items(&self) -> usize {
649        self.ids.len()
650    }
651
652    /// Slice of per-item category labels (index-aligned with insertion order).
653    pub fn categories(&self) -> &[String] {
654        &self.categories
655    }
656
657    /// Export (id, category, cartesian [x, y, z]) triples for every indexed item.
658    pub fn projected_points(&self) -> Vec<(&str, &str, [f64; 3])> {
659        self.ids
660            .iter()
661            .enumerate()
662            .map(|(i, id)| {
663                let cat = self
664                    .categories
665                    .get(i)
666                    .map(|s| s.as_str())
667                    .unwrap_or("unknown");
668                (id.as_str(), cat, self.cart_points[i])
669            })
670            .collect()
671    }
672
673    /// Borrow the fitted projection regardless of kind.
674    ///
675    /// Returns a `&ConfiguredProjection`, which implements the
676    /// [`Projection`](crate::projection::Projection) trait — so most
677    /// callers never need to pattern-match on the enum. The old
678    /// `.pca()` accessor was removed because it panicked under any
679    /// non-PCA config and every caller already worked through this
680    /// method or its trait impl.
681    pub fn projection(&self) -> &ConfiguredProjection {
682        &self.projection
683    }
684
685    /// Active outer-sphere projection kind.
686    pub fn projection_kind(&self) -> ProjectionKind {
687        self.projection.kind()
688    }
689
690    /// Export all projected points with their Cartesian and spherical coordinates.
691    ///
692    /// Returns one `ExportedPoint` per indexed item, in insertion order.
693    pub fn exported_points(&self) -> Vec<ExportedPoint> {
694        self.ids
695            .iter()
696            .enumerate()
697            .map(|(i, id)| {
698                let [x, y, z] = self.cart_points[i];
699                let category = self
700                    .categories
701                    .get(i)
702                    .cloned()
703                    .unwrap_or_else(|| "unknown".into());
704                let item = self.index.get(id);
705                let (r, theta, phi) = item
706                    .map(|it| {
707                        let pos = it.position();
708                        (pos.r, pos.theta, pos.phi)
709                    })
710                    .unwrap_or((0.0, 0.0, 0.0));
711                let certainty = item.map_or(1.0, |it| it.certainty());
712                let intensity = item.map_or(1.0, |it| it.intensity());
713                ExportedPoint {
714                    id: id.clone(),
715                    category,
716                    r,
717                    theta,
718                    phi,
719                    x,
720                    y,
721                    z,
722                    certainty,
723                    intensity,
724                }
725            })
726            .collect()
727    }
728
729    /// The active projection's explained-variance-ratio-equivalent
730    /// quality score, in `[0, 1]`. PCA returns the classical EVR;
731    /// kernel PCA returns its kernel-space EVR; Laplacian eigenmap
732    /// returns a compatible connectivity ratio (see
733    /// [`LaplacianEigenmapProjection::connectivity_ratio`](crate::laplacian::LaplacianEigenmapProjection::connectivity_ratio)).
734    /// All three feed the EVR-adaptive thresholds downstream.
735    pub fn explained_variance_ratio(&self) -> f64 {
736        self.projection.explained_variance_ratio()
737    }
738
739    /// Number of unique categories in the corpus.
740    pub fn num_categories(&self) -> usize {
741        self.category_layer.num_categories()
742    }
743
744    /// Unique category names in insertion order.
745    pub fn unique_categories(&self) -> Vec<String> {
746        self.category_layer
747            .summaries
748            .iter()
749            .map(|s| s.name.clone())
750            .collect()
751    }
752
753    // ── Phase 3: category-level accessors ──────────────────────────────
754
755    /// Access the category enrichment layer directly.
756    pub fn category_layer(&self) -> &CategoryLayer {
757        &self.category_layer
758    }
759
760    /// Shortcut: find the shortest path between two categories.
761    pub fn category_path(&self, source: &str, target: &str) -> Option<CategoryPath> {
762        self.category_layer.category_path(source, target)
763    }
764
765    /// Shortcut: get bridge items between two categories.
766    pub fn bridge_items(&self, source: &str, target: &str, max: usize) -> Vec<&BridgeItem> {
767        self.category_layer.bridge_items(source, target, max)
768    }
769
770    /// Shortcut: check if a category has an inner sphere.
771    pub fn has_inner_sphere(&self, category: &str) -> bool {
772        self.category_layer.has_inner_sphere(category)
773    }
774
775    /// Shortcut: number of categories with inner spheres.
776    pub fn num_inner_spheres(&self) -> usize {
777        self.category_layer.num_inner_spheres()
778    }
779
780    /// Shortcut: inner sphere statistics for all categories.
781    pub fn inner_sphere_stats(&self) -> Vec<InnerSphereReport> {
782        self.category_layer.inner_sphere_stats()
783    }
784
785    /// Projection quality warnings. Empty if EVR is above threshold.
786    pub fn projection_warnings(&self) -> &[ProjectionWarning] {
787        &self.projection_warnings
788    }
789
790    // ── Phase 5: hierarchical domain groups ────────────────────────────
791
792    /// Coarse-grained domain groups detected from Voronoi adjacency + cap overlap.
793    pub fn domain_groups(&self) -> &[DomainGroup] {
794        &self.domain_groups
795    }
796
797    /// Coarse routing: find the domain group whose centroid is angularly
798    /// nearest to the query's projected position.
799    pub fn route_to_group(&self, embedding: &Embedding) -> Option<&DomainGroup> {
800        if self.domain_groups.is_empty() {
801            return None;
802        }
803        let pos = self.projection.project(embedding);
804        self.domain_groups.iter().min_by(|a, b| {
805            let da = angular_distance(&pos, &a.centroid);
806            let db = angular_distance(&pos, &b.centroid);
807            da.total_cmp(&db)
808        })
809    }
810
811    /// Hierarchical nearest-neighbor search: group → category → items.
812    ///
813    /// When EVR is at or above
814    /// [`RoutingConfig::low_evr_threshold`](crate::config::RoutingConfig::low_evr_threshold),
815    /// this is a plain outer-sphere k-NN (identical to [`SphereQLQuery::Nearest`]).
816    ///
817    /// Below that threshold the outer sphere is unreliable, so we:
818    ///   1. Route the query to its nearest domain group.
819    ///   2. Drill down into each member category using its inner sphere
820    ///      (or the outer sphere if none exists).
821    ///   3. Merge the per-category results, sort by distance, truncate to `k`.
822    pub fn hierarchical_nearest(&self, embedding: &Embedding, k: usize) -> Vec<NearestResult> {
823        let evr = self.projection.explained_variance_ratio();
824
825        if evr >= self.config.routing.low_evr_threshold {
826            return self.nearest_filtered(embedding, k, evr);
827        }
828
829        let Some(group) = self.route_to_group(embedding) else {
830            return self.nearest_filtered(embedding, k, evr);
831        };
832
833        // Collect candidates from every category in the routed group, using
834        // inner-sphere distances where available.
835        let mut candidates: Vec<NearestResult> = Vec::new();
836        for &ci in &group.member_categories {
837            let cat_name = &self.category_layer.summaries[ci].name;
838            for r in self.category_layer.drill_down_with_projection(
839                cat_name,
840                embedding,
841                &self.projection,
842                k,
843            ) {
844                candidates.push(self.drill_result_to_nearest(&r, evr));
845            }
846        }
847
848        candidates.sort_by(|a, b| {
849            a.distance
850                .partial_cmp(&b.distance)
851                .unwrap_or(std::cmp::Ordering::Equal)
852        });
853        let filtered: Vec<NearestResult> = candidates
854            .into_iter()
855            .filter(|r| self.passes_quality(r))
856            .take(k)
857            .collect();
858
859        // If the quality filter discards every routed-group candidate,
860        // fall back to the outer-sphere path. The low-EVR branch exists
861        // *because* the outer sphere is unreliable, and the drill-down
862        // certainty scores come from that same unreliable projection —
863        // returning an empty Vec in exactly this regime would be a
864        // correctness inversion.
865        if filtered.is_empty() {
866            self.nearest_filtered(embedding, k, evr)
867        } else {
868            filtered
869        }
870    }
871
872    /// Shared quality-filter predicate. A result passes when its
873    /// certainty meets [`QualityConfig::min_certainty`] and, if a
874    /// [`QualitySignal`] is attached, it clears
875    /// [`QualityConfig::min_combined`].
876    ///
877    /// Factored out because four of the query-path filter closures had
878    /// the same body verbatim; the duplication made threshold changes
879    /// a four-way edit.
880    #[inline]
881    fn passes_quality(&self, r: &NearestResult) -> bool {
882        r.certainty >= self.quality_config.min_certainty
883            && r.quality
884                .is_none_or(|q| q.passes_threshold(self.quality_config.min_combined))
885    }
886
887    /// Shared helper: outer-sphere k-NN with quality filtering.
888    fn nearest_filtered(&self, embedding: &Embedding, k: usize, evr: f64) -> Vec<NearestResult> {
889        self.index
890            .search_nearest(embedding, k)
891            .iter()
892            .map(|r| {
893                let certainty = r.item.certainty();
894                let quality = QualitySignal::from_certainty(evr, certainty);
895                NearestResult {
896                    id: r.item.id.clone(),
897                    category: self.cat_for(&r.item.id),
898                    distance: r.distance,
899                    certainty,
900                    intensity: r.item.intensity(),
901                    quality: Some(quality),
902                }
903            })
904            .filter(|r| self.passes_quality(r))
905            .collect()
906    }
907
908    fn drill_result_to_nearest(&self, r: &DrillDownResult, evr: f64) -> NearestResult {
909        let id = self.ids[r.item_index].clone();
910        let item = self.index.get(&id);
911        let certainty = item.map_or(1.0, |it| it.certainty());
912        let intensity = item.map_or(1.0, |it| it.intensity());
913        let quality = QualitySignal::from_certainty(evr, certainty);
914        NearestResult {
915            id,
916            category: self
917                .categories
918                .get(r.item_index)
919                .cloned()
920                .unwrap_or_else(|| "unknown".into()),
921            distance: r.distance,
922            certainty,
923            intensity,
924            quality: Some(quality),
925        }
926    }
927
928    /// Current quality configuration.
929    pub fn quality_config(&self) -> &QualityConfig {
930        &self.quality_config
931    }
932
933    /// Update the quality configuration (e.g., to enable filtering).
934    pub fn set_quality_config(&mut self, config: QualityConfig) {
935        self.quality_config = config;
936    }
937
938    /// Full tunable configuration this pipeline was built with.
939    pub fn config(&self) -> &PipelineConfig {
940        &self.config
941    }
942
943    /// Serialize all projected points as a JSON array string.
944    pub fn to_json(&self) -> String {
945        serde_json::to_string(&self.exported_points())
946            .expect("ExportedPoint is always serializable")
947    }
948
949    /// Serialize all projected points as RFC 4180-compliant CSV with a header row.
950    ///
951    /// String fields (id, category) are quoted to handle embedded commas
952    /// and special characters safely.
953    pub fn to_csv(&self) -> String {
954        let points = self.exported_points();
955        let mut out = String::from("id,category,r,theta,phi,x,y,z,certainty,intensity\n");
956        for p in &points {
957            out.push_str(&format!(
958                "\"{}\",\"{}\",{:.6},{:.6},{:.6},{:.6},{:.6},{:.6},{:.6},{:.6}\n",
959                p.id.replace('"', "\"\""),
960                p.category.replace('"', "\"\""),
961                p.r,
962                p.theta,
963                p.phi,
964                p.x,
965                p.y,
966                p.z,
967                p.certainty,
968                p.intensity,
969            ));
970        }
971        out
972    }
973}
974
975/// Fit the projection family specified by `config.projection_kind` on the
976/// given corpus. Called by [`SphereQLPipeline::new_with_config`] and the
977/// auto-tuner prefit step. Default radial strategy mirrors
978/// [`SphereQLPipeline::new`]'s legacy behavior (magnitude + volumetric).
979pub fn fit_projection_for_config(
980    embeddings: &[Embedding],
981    config: &PipelineConfig,
982) -> Result<ConfiguredProjection, crate::projection::ProjectionError> {
983    match config.projection_kind {
984        ProjectionKind::Pca => Ok(ConfiguredProjection::Pca(
985            PcaProjection::fit(embeddings, RadialStrategy::Magnitude)?.with_volumetric(true),
986        )),
987        ProjectionKind::KernelPca => Ok(ConfiguredProjection::KernelPca(KernelPcaProjection::fit(
988            embeddings,
989            RadialStrategy::Magnitude,
990        )?)),
991        ProjectionKind::LaplacianEigenmap => {
992            let lc = &config.laplacian;
993            Ok(ConfiguredProjection::Laplacian(
994                LaplacianEigenmapProjection::fit_with_params(
995                    embeddings,
996                    lc.k_neighbors,
997                    lc.active_threshold,
998                    RadialStrategy::Magnitude,
999                )?,
1000            ))
1001        }
1002    }
1003}
1004
#[cfg(test)]
mod tests {
    use super::*;

    fn make_input(n: usize, dim: usize) -> (PipelineInput, PipelineQuery) {
        let mut embeddings = Vec::with_capacity(n);
        let mut categories = Vec::with_capacity(n);
        for i in 0..n {
            let mut v = vec![0.0; dim];
            if i < n / 2 {
                v[0] = 1.0 + (i as f64 * 0.01);
                v[1] = 0.1;
                categories.push("group_a".into());
            } else {
                v[0] = 0.1;
                v[1] = 1.0 + (i as f64 * 0.01);
                categories.push("group_b".into());
            }
            v[2] = 0.05 * (i as f64);
            embeddings.push(v);
        }
        let query = PipelineQuery {
            embedding: vec![0.9; dim],
        };
        (
            PipelineInput {
                categories,
                embeddings,
            },
            query,
        )
    }

    // ── Existing tests (unchanged) ─────────────────────────────────────

    #[test]
    fn pipeline_nearest() {
        let (input, query) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let result = pipeline
            .query(SphereQLQuery::Nearest { k: 5 }, &query)
            .unwrap();
        match result {
            SphereQLOutput::Nearest(items) => {
                assert_eq!(items.len(), 5);
                assert!(items[0].distance <= items[1].distance);
            }
            _ => panic!("expected Nearest"),
        }
    }

    #[test]
    fn pipeline_globs() {
        let (input, query) = make_input(30, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let result = pipeline
            .query(
                SphereQLQuery::DetectGlobs {
                    k: Some(2),
                    max_k: 5,
                },
                &query,
            )
            .unwrap();
        match result {
            SphereQLOutput::Globs(globs) => {
                assert_eq!(globs.len(), 2);
                let total: usize = globs.iter().map(|g| g.member_count).sum();
                assert_eq!(total, 30);
            }
            _ => panic!("expected Globs"),
        }
    }

    #[test]
    fn pipeline_concept_path() {
        let (input, query) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let result = pipeline
            .query(
                SphereQLQuery::ConceptPath {
                    source_id: "s-0000",
                    target_id: "s-0015",
                    graph_k: 10,
                },
                &query,
            )
            .unwrap();
        match result {
            SphereQLOutput::ConceptPath(Some(path)) => {
                assert!(path.steps.len() >= 2);
                assert_eq!(path.steps.first().unwrap().id, "s-0000");
                assert_eq!(path.steps.last().unwrap().id, "s-0015");
            }
            _ => panic!("expected ConceptPath(Some)"),
        }
    }

    #[test]
    fn pipeline_local_manifold() {
        let (input, query) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let result = pipeline
            .query(SphereQLQuery::LocalManifold { neighborhood_k: 10 }, &query)
            .unwrap();
        match result {
            SphereQLOutput::LocalManifold(m) => {
                assert!(m.variance_ratio > 0.0);
                assert!(m.variance_ratio <= 1.0);
            }
            _ => panic!("expected LocalManifold"),
        }
    }

    #[test]
    fn test_exported_points_count() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        assert_eq!(pipeline.exported_points().len(), 20);
    }

    #[test]
    fn test_exported_points_fields() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        for p in pipeline.exported_points() {
            assert!(p.r >= 0.0, "r must be non-negative");
            assert!(
                p.theta >= 0.0 && p.theta < std::f64::consts::TAU,
                "theta out of range"
            );
            assert!(
                p.phi >= 0.0 && p.phi <= std::f64::consts::PI,
                "phi out of range"
            );
        }
    }

    #[test]
    fn test_exported_points_categories() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let points = pipeline.exported_points();
        for (i, p) in points.iter().enumerate() {
            let expected = if i < 10 { "group_a" } else { "group_b" };
            assert_eq!(p.category, expected);
        }
    }

    #[test]
    fn test_to_json_parseable() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let json = pipeline.to_json();
        let parsed: Vec<serde_json::Value> = serde_json::from_str(&json).expect("valid JSON");
        assert_eq!(parsed.len(), 20);
    }

    #[test]
    fn test_to_csv_lines() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let csv = pipeline.to_csv();
        let lines: Vec<&str> = csv.lines().collect();
        assert_eq!(
            lines[0],
            "id,category,r,theta,phi,x,y,z,certainty,intensity"
        );
        assert_eq!(lines.len(), 21);
    }

    #[test]
    fn test_to_csv_quoted_fields() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let csv = pipeline.to_csv();
        let data_line = csv.lines().nth(1).unwrap();
        assert!(data_line.starts_with('"'), "id field should be quoted");
    }

    #[test]
    fn test_explained_variance() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let ratio = pipeline.explained_variance_ratio();
        assert!(ratio > 0.0 && ratio <= 1.0);
    }

    #[test]
    fn test_unique_categories() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let cats = pipeline.unique_categories();
        assert_eq!(cats.len(), 2);
        assert_eq!(cats[0], "group_a");
        assert_eq!(cats[1], "group_b");
        assert_eq!(pipeline.num_categories(), 2);
    }

    // ── Phase 3 tests: category layer integration ──────────────────────

    #[test]
    fn pipeline_builds_category_layer() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        assert_eq!(pipeline.category_layer().num_categories(), 2);
    }

    #[test]
    fn pipeline_category_path_query() {
        let (input, query) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let result = pipeline
            .query(
                SphereQLQuery::CategoryConceptPath {
                    source_category: "group_a",
                    target_category: "group_b",
                },
                &query,
            )
            .unwrap();
        match result {
            SphereQLOutput::CategoryConceptPath(Some(path)) => {
                assert!(path.steps.len() >= 2);
                assert_eq!(path.steps.first().unwrap().category_name, "group_a");
                assert_eq!(path.steps.last().unwrap().category_name, "group_b");
                assert!(path.total_distance > 0.0);
            }
            _ => panic!("expected CategoryConceptPath(Some)"),
        }
    }

    #[test]
    fn pipeline_category_path_shortcut() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let path = pipeline.category_path("group_a", "group_b");
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.steps.first().unwrap().category_name, "group_a");
        assert_eq!(path.steps.last().unwrap().category_name, "group_b");
    }

    #[test]
    fn pipeline_category_path_unknown() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        assert!(pipeline.category_path("group_a", "nonexistent").is_none());
    }

    #[test]
    fn pipeline_category_neighbors_query() {
        let (input, query) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let result = pipeline
            .query(
                SphereQLQuery::CategoryNeighbors {
                    category: "group_a",
                    k: 5,
                },
                &query,
            )
            .unwrap();
        match result {
            SphereQLOutput::CategoryNeighbors(neighbors) => {
                assert_eq!(neighbors.len(), 1);
                assert_eq!(neighbors[0].name, "group_b");
            }
            _ => panic!("expected CategoryNeighbors"),
        }
    }

    #[test]
    fn pipeline_drill_down_query() {
        let (input, query) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let result = pipeline
            .query(
                SphereQLQuery::DrillDown {
                    category: "group_a",
                    k: 5,
                },
                &query,
            )
            .unwrap();
        match result {
            SphereQLOutput::DrillDown(results) => {
                assert!(!results.is_empty());
                assert!(results.len() <= 5);
                for w in results.windows(2) {
                    assert!(w[0].distance <= w[1].distance);
                }
            }
            _ => panic!("expected DrillDown"),
        }
    }

    #[test]
    fn pipeline_category_stats_query() {
        let (input, query) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let result = pipeline
            .query(SphereQLQuery::CategoryStats, &query)
            .unwrap();
        match result {
            SphereQLOutput::CategoryStats {
                summaries,
                inner_sphere_reports,
            } => {
                assert_eq!(summaries.len(), 2);
                assert_eq!(inner_sphere_reports.len(), 0);
            }
            _ => panic!("expected CategoryStats"),
        }
    }

    #[test]
    fn pipeline_bridge_items_shortcut() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let _ = pipeline.bridge_items("group_a", "group_b", 5);
    }

    #[test]
    fn pipeline_inner_sphere_shortcuts() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        assert!(!pipeline.has_inner_sphere("group_a"));
        assert_eq!(pipeline.num_inner_spheres(), 0);
        assert!(pipeline.inner_sphere_stats().is_empty());
    }

    #[test]
    fn pipeline_category_layer_accessor() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let layer = pipeline.category_layer();
        assert_eq!(layer.num_categories(), 2);
        assert!(layer.get_category("group_a").is_some());
        assert!(layer.get_category("group_b").is_some());
    }

    // ── Phase 5: domain groups ────────────────────────────────────────

    #[test]
    fn domain_groups_detected() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let groups = pipeline.domain_groups();
        assert!(!groups.is_empty());
        let total: usize = groups.iter().map(|g| g.total_items).sum();
        assert_eq!(total, pipeline.num_items());
    }

    #[test]
    fn domain_groups_cover_all_categories() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let groups = pipeline.domain_groups();
        let mut all_cats: Vec<usize> = groups
            .iter()
            .flat_map(|g| g.member_categories.iter().copied())
            .collect();
        all_cats.sort();
        all_cats.dedup();
        assert_eq!(all_cats.len(), pipeline.num_categories());
    }

    #[test]
    fn route_to_group_returns_something() {
        let (input, _) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let emb = Embedding::new(vec![0.5; 10]);
        assert!(pipeline.route_to_group(&emb).is_some());
    }

    #[test]
    fn hierarchical_nearest_matches_standard_when_evr_high() {
        // With only 20 items in two well-separated clusters, PCA EVR is
        // typically >= 0.35, so hierarchical_nearest should take the
        // standard outer-sphere path. Whichever path runs, the result must
        // be non-empty, at most k long, and sorted by ascending distance.
        let (input, query) = make_input(20, 10);
        let pipeline = SphereQLPipeline::new(input).unwrap();
        let hier = pipeline.hierarchical_nearest(&Embedding::new(query.embedding.clone()), 5);
        assert!(!hier.is_empty());
        assert!(hier.len() <= 5);
        for w in hier.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn hierarchical_nearest_falls_back_when_filter_kills_candidates() {
        // Force the low-EVR branch by setting low_evr_threshold = 1.1.
        // EVR lies in [0, 1], so every EVR is below that.
        let (input, query) = make_input(20, 10);
        let mut pipeline = SphereQLPipeline::new_with_config(
            input,
            PipelineConfig {
                routing: crate::config::RoutingConfig {
                    num_domain_groups: 2,
                    low_evr_threshold: 1.1, // force low-EVR branch
                },
                ..Default::default()
            },
        )
        .unwrap();
        let embedding = Embedding::new(query.embedding.clone());

        // With an unreachable min_certainty, every routed candidate is
        // discarded, so the method must take the outer-sphere fallback.
        // That path applies the same filter. Whatever comes back (in
        // practice, nothing) must still honor the filter, and the call must
        // not panic.
        pipeline.set_quality_config(crate::confidence::QualityConfig {
            min_certainty: 1.1, // unreachable -> every candidate filtered out
            ..Default::default()
        });
        let strict = pipeline.hierarchical_nearest(&embedding, 5);
        assert!(
            strict.iter().all(|r| r.certainty >= 1.1),
            "fallback path must apply the same quality filter"
        );

        // With the default filter, the low-EVR branch must produce results
        // rather than silently returning an empty Vec.
        pipeline.set_quality_config(crate::confidence::QualityConfig::default());
        let hier = pipeline.hierarchical_nearest(&embedding, 5);
        assert!(
            !hier.is_empty(),
            "low-EVR branch should return results with default filter"
        );
        assert!(hier.len() <= 5);
    }

    #[test]
    fn feedback_aggregator_derive_and_save_load_round_trip() {
        // #[serde(transparent)] means the derive-based serializer and
        // the hand-rolled save/load both use a flat JSON array. A file
        // written via serde_json::to_string(&agg) must be loadable via
        // FeedbackAggregator::load.
        use crate::feedback::{FeedbackAggregator, FeedbackEvent};
        let mut agg = FeedbackAggregator::default();
        agg.record(FeedbackEvent {
            corpus_id: "c".into(),
            query_id: "q".into(),
            score: 0.5,
            timestamp: "0".into(),
        });

        let json_via_derive = serde_json::to_string(&agg).unwrap();
        // Flat array shape: starts with '[', not '{'.
        assert!(json_via_derive.starts_with('['));

        // Reload via load() by routing through a temp file.
        let dir = std::env::temp_dir();
        let path = dir.join(format!(
            "sphereql_serde_transparent_{}.json",
            std::process::id()
        ));
        std::fs::write(&path, &json_via_derive).unwrap();
        let loaded = FeedbackAggregator::load(&path);
        // Remove the temp file before asserting, so a failed load or
        // assertion doesn't leave it behind.
        let _ = std::fs::remove_file(&path);
        let loaded = loaded.unwrap();
        assert_eq!(loaded.len(), 1);
    }

    #[test]
    fn new_from_metamodel_uses_predicted_config() {
        use crate::corpus_features::CorpusFeatures;
        use crate::meta_model::{MetaTrainingRecord, NearestNeighborMetaModel};

        let (input, _) = make_input(20, 10);
        let features = CorpusFeatures::extract(&input.categories, &input.embeddings);

        // Hand-built training record: "on a corpus shaped like this, a
        // LaplacianEigenmap config wins". The NN model has only one
        // point so it always returns this config.
        let target_config = PipelineConfig {
            projection_kind: ProjectionKind::LaplacianEigenmap,
            ..Default::default()
        };
        let record = MetaTrainingRecord {
            corpus_id: "seed".into(),
            features: features.clone(),
            best_config: target_config.clone(),
            best_score: 0.5,
            metric_name: "test".into(),
            strategy: "manual".into(),
            timestamp: "0".into(),
        };

        let mut model = NearestNeighborMetaModel::new();
        model.fit(&[record]);

        let (pipeline, _extracted, predicted) =
            SphereQLPipeline::new_from_metamodel(input, &model).unwrap();
        assert_eq!(predicted.projection_kind, ProjectionKind::LaplacianEigenmap);
        assert_eq!(
            pipeline.projection_kind(),
            ProjectionKind::LaplacianEigenmap
        );
    }

    #[test]
    fn new_from_metamodel_tuned_runs_and_carries_prediction() {
        use crate::corpus_features::CorpusFeatures;
        use crate::meta_model::{MetaTrainingRecord, NearestNeighborMetaModel};
        use crate::quality_metric::TerritorialHealth;
        use crate::tuner::{SearchSpace, SearchStrategy};

        // Predict a config that sets an unusual `overlap_artifact_territorial`
        // value NOT in the default search space; then run the tuner with
        // `num_domain_groups` as the only varying axis. The returned
        // pipeline should keep the predicted overlap value (base_config is
        // the prediction) while the tuner picks best num_domain_groups.
        let (input, _) = make_input(20, 10);
        let features = CorpusFeatures::extract(&input.categories, &input.embeddings);

        let mut predicted_cfg = PipelineConfig::default();
        predicted_cfg.bridges.overlap_artifact_territorial = 0.123; // unusual

        let record = MetaTrainingRecord {
            corpus_id: "seed".into(),
            features: features.clone(),
            best_config: predicted_cfg.clone(),
            best_score: 0.5,
            metric_name: "test".into(),
            strategy: "manual".into(),
            timestamp: "0".into(),
        };
        let mut model = NearestNeighborMetaModel::new();
        model.fit(&[record]);

        // Constrain the space so only num_domain_groups varies.
        let space = SearchSpace {
            projection_kinds: vec![ProjectionKind::Pca],
            laplacian_k_neighbors: vec![15],
            laplacian_active_threshold: vec![0.05],
            num_domain_groups: vec![3, 5],
            low_evr_threshold: vec![0.35],
            overlap_artifact_territorial: vec![0.3], // NOT the predicted 0.123
            threshold_base: vec![0.5],
            threshold_evr_penalty: vec![0.4],
            min_evr_improvement: vec![0.10],
        };

        let metric = TerritorialHealth;
        let (pipeline, _feats, report) = SphereQLPipeline::new_from_metamodel_tuned(
            input,
            &model,
            &space,
            &metric,
            SearchStrategy::Grid,
        )
        .unwrap();

        // Grid visits 2 trials (num_domain_groups × 2). Overlap-artifact
        // in every trial's config should be the SPACE's 0.3, not the
        // predicted 0.123 — the search space always overrides. That's
        // the intended contract: warm-start only helps when a knob is
        // NOT in the space.
        assert_eq!(report.trials.len(), 2);
        for t in &report.trials {
            assert!((t.config.bridges.overlap_artifact_territorial - 0.3).abs() < 1e-9);
        }
        assert_eq!(pipeline.projection_kind(), ProjectionKind::Pca);
    }
}