1use std::cmp::Reverse;
2use std::collections::BinaryHeap;
3use std::sync::OnceLock;
4
5use sphereql_core::*;
6use sphereql_index::SpatialItem;
7
8use crate::category::{
9 BridgeItem, CategoryLayer, CategoryPath, CategorySummary, DrillDownResult, InnerSphereReport,
10};
11use crate::confidence::{ProjectionWarning, QualityConfig, QualitySignal};
12use crate::config::{PipelineConfig, ProjectionKind};
13use crate::configured_projection::ConfiguredProjection;
14use crate::corpus_features::CorpusFeatures;
15use crate::domain_groups::DomainGroup;
16use crate::kernel_pca::KernelPcaProjection;
17use crate::laplacian::LaplacianEigenmapProjection;
18use crate::meta_model::MetaModel;
19use crate::projection::{PcaProjection, Projection};
20use crate::quality_metric::QualityMetric;
21use crate::query::{EmbeddingIndex, GlobResult, SlicingManifold};
22use crate::tuner::{SearchSpace, SearchStrategy, TuneReport, auto_tune};
23use crate::types::{Embedding, RadialStrategy};
24
/// Errors produced while building or querying a [`SphereQLPipeline`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum PipelineError {
    /// The number of category labels differs from the number of embeddings.
    #[error("categories length ({cat}) must equal embeddings length ({emb})")]
    LengthMismatch { cat: usize, emb: usize },
    /// Fewer than three embeddings were supplied; carries the actual count.
    #[error("need at least 3 embeddings, got {0}")]
    TooFewEmbeddings(usize),
    /// Fitting the projection failed; converted automatically from
    /// [`crate::projection::ProjectionError`] via `?`.
    #[error("projection fit failed: {0}")]
    Projection(#[from] crate::projection::ProjectionError),
    /// A category-level query named a category the pipeline does not know.
    #[error("unknown category: {0:?}")]
    UnknownCategory(String),
    /// An item-level query named an id the pipeline does not know.
    #[error("unknown id: {0:?}")]
    UnknownId(String),
    /// Every auto-tune trial failed. Each entry pairs a configuration with a
    /// message (presumably the failure reason; this variant is not
    /// constructed in this file, so confirm against the tuner).
    #[error("auto_tune produced no successful trials ({} failures)", failures.len())]
    AllTrialsFailed {
        failures: Vec<(crate::config::PipelineConfig, String)>,
    },
    /// The search space was rejected (not constructed in this file).
    #[error("invalid search space: {0}")]
    InvalidSearchSpace(String),
    /// The input was rejected before fitting, e.g. an unfitted meta-model or
    /// a corpus-feature extraction failure.
    #[error("invalid input: {0}")]
    InvalidInput(String),
}
77
/// Raw corpus handed to the pipeline constructors.
///
/// `categories[i]` is the label of `embeddings[i]`; construction fails with
/// [`PipelineError::LengthMismatch`] when the two lengths differ.
#[derive(Debug, Clone)]
pub struct PipelineInput {
    /// One category label per embedding, in the same order as `embeddings`.
    pub categories: Vec<String>,
    /// One raw embedding vector per item.
    pub embeddings: Vec<Vec<f64>>,
}
89
/// Query vector passed to [`SphereQLPipeline::query`].
#[derive(Debug, Clone)]
pub struct PipelineQuery {
    /// Raw query embedding; it is wrapped in an `Embedding` before use.
    pub embedding: Vec<f64>,
}
95
/// One hit returned by the nearest-neighbour style queries.
#[derive(Debug, Clone)]
pub struct NearestResult {
    /// Generated item id.
    pub id: String,
    /// Category label of the item.
    pub category: String,
    /// Distance between the query and the item, as reported by whichever
    /// search path produced the hit.
    pub distance: f64,
    /// Certainty of the indexed item (1.0 when the item is missing from the
    /// index on the drill-down path).
    pub certainty: f64,
    /// Intensity of the indexed item (same 1.0 fallback as `certainty`).
    pub intensity: f64,
    /// Quality signal derived from the projection's explained-variance ratio
    /// and the item's certainty; results built in this file always set it.
    pub quality: Option<QualitySignal>,
}
123
/// Item-level concept path returned by `SphereQLQuery::ConceptPath`.
#[derive(Debug, Clone)]
pub struct PathResult {
    /// Steps along the path, in the order reported by the index.
    pub steps: Vec<PipelinePathStep>,
    /// Total path distance as reported by the index.
    pub total_distance: f64,
}
131
/// A single step of a [`PathResult`], annotated with the item's category.
#[derive(Debug, Clone)]
pub struct PipelinePathStep {
    /// Item id at this step.
    pub id: String,
    /// Category label resolved from the item id.
    pub category: String,
    /// Cumulative distance copied from the index's path step.
    pub cumulative_distance: f64,
    /// Hop distance copied from the index's path step.
    pub hop_distance: f64,
    /// Bridge strength copied from the index's path step, when it has one.
    pub bridge_strength: Option<f64>,
}
143
/// Summary of one glob found by `SphereQLQuery::DetectGlobs`.
#[derive(Debug, Clone)]
pub struct GlobSummary {
    /// Glob id assigned by the detector.
    pub id: usize,
    /// Glob centroid as a 3-component array.
    pub centroid: [f64; 3],
    /// Number of member items in the glob.
    pub member_count: usize,
    /// Glob radius as reported by the detector.
    pub radius: f64,
    /// Up to three `(category, member count)` pairs, highest count first.
    pub top_categories: Vec<(String, usize)>,
}
153
/// Local manifold fitted around the query point by
/// `SphereQLQuery::LocalManifold`; all fields are copied from the fit.
#[derive(Debug, Clone)]
pub struct ManifoldResult {
    /// Centroid of the fit.
    pub centroid: [f64; 3],
    /// Normal vector of the fit.
    pub normal: [f64; 3],
    /// Variance ratio of the fit.
    pub variance_ratio: f64,
}
161
/// Result of [`SphereQLPipeline::query`]; the variant depends on the query.
#[derive(Debug, Clone)]
pub enum SphereQLOutput {
    /// Hits for `SphereQLQuery::Nearest`.
    Nearest(Vec<NearestResult>),
    /// Hits for `SphereQLQuery::SimilarAbove`.
    KNearest(Vec<NearestResult>),
    /// Item-level path, or `None` when the index found no path.
    ConceptPath(Option<PathResult>),
    /// Detected globs.
    Globs(Vec<GlobSummary>),
    /// Manifold fitted locally around the query point.
    LocalManifold(ManifoldResult),
    /// Category-level path, or `None` when the category layer found none.
    CategoryConceptPath(Option<CategoryPath>),
    /// Summaries of the neighbouring categories.
    CategoryNeighbors(Vec<CategorySummary>),
    /// Hits from drilling down into one category.
    DrillDown(Vec<DrillDownResult>),
    /// All category summaries plus the inner-sphere reports.
    CategoryStats {
        summaries: Vec<CategorySummary>,
        inner_sphere_reports: Vec<InnerSphereReport>,
    },
}
183
/// Query accepted by [`SphereQLPipeline::query`].
///
/// String arguments are borrowed for `'a`; unknown ids or categories are
/// rejected by `query` with an error.
#[derive(Debug, Clone)]
pub enum SphereQLQuery<'a> {
    /// Up to `k` nearest items to the query embedding.
    Nearest { k: usize },
    /// Items found by the index's similarity search with threshold
    /// `min_cosine`.
    SimilarAbove { min_cosine: f64 },
    /// Path between two items; `graph_k` is forwarded to the index's
    /// path search.
    ConceptPath {
        source_id: &'a str,
        target_id: &'a str,
        graph_k: usize,
    },
    /// Glob detection over the projected points; `k` and `max_k` are
    /// forwarded to the detector.
    DetectGlobs { k: Option<usize>, max_k: usize },
    /// Local manifold fit around the query point, with `neighborhood_k`
    /// forwarded to the fit.
    LocalManifold { neighborhood_k: usize },
    /// Path between two categories.
    CategoryConceptPath {
        source_category: &'a str,
        target_category: &'a str,
    },
    /// Neighbours of a category; `k` is forwarded to the category layer.
    CategoryNeighbors { category: &'a str, k: usize },
    /// Nearest items within one category; `k` is forwarded to the drill-down.
    DrillDown { category: &'a str, k: usize },
    /// Summaries and inner-sphere reports for all categories.
    CategoryStats,
}
214
/// Flat, serializable view of one projected item, as produced by
/// `SphereQLPipeline::exported_points`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ExportedPoint {
    /// Generated item id.
    pub id: String,
    /// Category label ("unknown" when no label exists at the item's position).
    pub category: String,
    /// `r` component of the indexed spherical position (0.0 if the item is
    /// missing from the index; the same fallback applies to `theta`/`phi`).
    pub r: f64,
    /// `theta` component of the indexed spherical position.
    pub theta: f64,
    /// `phi` component of the indexed spherical position.
    pub phi: f64,
    /// Cartesian x of the projected point.
    pub x: f64,
    /// Cartesian y of the projected point.
    pub y: f64,
    /// Cartesian z of the projected point.
    pub z: f64,
    /// Certainty of the indexed item (1.0 if missing from the index).
    pub certainty: f64,
    /// Intensity of the indexed item (1.0 if missing from the index).
    pub intensity: f64,
}
229
/// Per-item data cached by `SphereQLPipeline::metric_points`.
pub(crate) struct MetricPoints {
    /// Indexed spherical position of every item, in id order.
    pub(crate) positions: Vec<SphericalPoint>,
    /// For every item, the category layer's index of its label, or `None`
    /// when the layer does not know the label.
    pub(crate) category_indices: Vec<Option<usize>>,
}
239
/// Explained-variance ratio at or above which `default_nearest` skips group
/// routing and searches the global index directly.
const HIGH_EVR_ROUTING_BYPASS: f64 = 0.90;
250
/// A fitted projection, a spatial index over the projected corpus, and the
/// category layer built on top of them.
///
/// `categories`, `cart_points` and `ids` are parallel vectors: entry `i` of
/// each describes the `i`-th input embedding.
pub struct SphereQLPipeline {
    /// Projection used for both indexing and query-time projection.
    projection: ConfiguredProjection,
    /// Spatial index holding every item under its generated id.
    index: EmbeddingIndex<ConfiguredProjection>,
    /// Category label of each item, in insertion order.
    categories: Vec<String>,
    /// Cartesian coordinates of each projected item, in insertion order.
    cart_points: Vec<[f64; 3]>,
    /// Generated ids (`s-0000`, `s-0001`, ...), in insertion order.
    ids: Vec<String>,
    /// Category-level structures built from the projected corpus.
    category_layer: CategoryLayer,
    /// Thresholds applied when filtering nearest-neighbour results.
    quality_config: QualityConfig,
    /// Warnings derived from the explained-variance ratio at build time.
    projection_warnings: Vec<ProjectionWarning>,
    /// Configuration the pipeline was built with.
    config: PipelineConfig,
    /// Raw input vectors; only populated by `new_with_config` when the
    /// `retain-embeddings` feature is enabled.
    raw_embeddings: Option<Vec<Vec<f64>>>,
    /// Lazily computed cache backing `metric_points`.
    metric_points: OnceLock<MetricPoints>,
}
282
283impl SphereQLPipeline {
284 pub fn new(input: PipelineInput) -> Result<Self, PipelineError> {
290 Self::new_with_config(input, PipelineConfig::default())
291 }
292
293 pub fn new_with_config(
297 input: PipelineInput,
298 config: PipelineConfig,
299 ) -> Result<Self, PipelineError> {
300 #[cfg(feature = "retain-embeddings")]
301 let raw = Some(input.embeddings.clone());
302 #[cfg(not(feature = "retain-embeddings"))]
303 let raw: Option<Vec<Vec<f64>>> = None;
304
305 let embeddings: Vec<Embedding> = input
306 .embeddings
307 .iter()
308 .map(|v| Embedding::new(v.clone()))
309 .collect();
310
311 let projection = fit_projection_for_config(&embeddings, &input.categories, &config)?;
312 let mut pipeline = Self::with_configured_projection_and_config(
313 input.categories,
314 embeddings,
315 projection,
316 config,
317 )?;
318 pipeline.raw_embeddings = raw;
319 Ok(pipeline)
320 }
321
322 pub fn new_from_metamodel<M: MetaModel>(
336 input: PipelineInput,
337 model: &M,
338 ) -> Result<(Self, CorpusFeatures, PipelineConfig), PipelineError> {
339 require_fitted(model)?;
340 let features = CorpusFeatures::extract(&input.categories, &input.embeddings)
341 .map_err(PipelineError::InvalidInput)?;
342 let predicted = model.predict(&features);
343 let pipeline = Self::new_with_config(input, predicted.clone())?;
344 Ok((pipeline, features, predicted))
345 }
346
347 pub fn new_from_metamodel_tuned<M, Q>(
364 input: PipelineInput,
365 model: &M,
366 space: &SearchSpace,
367 metric: &Q,
368 strategy: SearchStrategy,
369 ) -> Result<(Self, CorpusFeatures, TuneReport), PipelineError>
370 where
371 M: MetaModel,
372 Q: QualityMetric,
373 {
374 require_fitted(model)?;
375 let features = CorpusFeatures::extract(&input.categories, &input.embeddings)
376 .map_err(PipelineError::InvalidInput)?;
377 let predicted = model.predict(&features);
378 let (pipeline, report) = auto_tune(input, space, metric, strategy, &predicted)?;
379 Ok((pipeline, features, report))
380 }
381
382 pub fn with_projection(
389 categories: Vec<String>,
390 embeddings: Vec<Embedding>,
391 pca: PcaProjection,
392 ) -> Result<Self, PipelineError> {
393 Self::with_configured_projection_and_config(
394 categories,
395 embeddings,
396 ConfiguredProjection::Pca(pca),
397 PipelineConfig::default(),
398 )
399 }
400
401 pub fn with_projection_and_config(
404 categories: Vec<String>,
405 embeddings: Vec<Embedding>,
406 pca: PcaProjection,
407 config: PipelineConfig,
408 ) -> Result<Self, PipelineError> {
409 Self::with_configured_projection_and_config(
410 categories,
411 embeddings,
412 ConfiguredProjection::Pca(pca),
413 config,
414 )
415 }
416
417 pub fn with_configured_projection_and_config(
420 categories: Vec<String>,
421 embeddings: Vec<Embedding>,
422 projection: ConfiguredProjection,
423 config: PipelineConfig,
424 ) -> Result<Self, PipelineError> {
425 Self::with_projection_parts(categories, &embeddings, projection, config)
426 }
427
428 pub(crate) fn with_projection_parts(
435 categories: Vec<String>,
436 embeddings: &[Embedding],
437 projection: ConfiguredProjection,
438 config: PipelineConfig,
439 ) -> Result<Self, PipelineError> {
440 let n = embeddings.len();
441 if n != categories.len() {
442 return Err(PipelineError::LengthMismatch {
443 cat: categories.len(),
444 emb: n,
445 });
446 }
447 if n < 3 {
448 return Err(PipelineError::TooFewEmbeddings(n));
449 }
450
451 let mut index = EmbeddingIndex::builder(projection.clone())
452 .uniform_shells(10, 1.0)
453 .theta_divisions(12)
454 .phi_divisions(6)
455 .build();
456
457 let mut ids = Vec::with_capacity(n);
458 for (i, emb) in embeddings.iter().enumerate() {
459 let id = format!("s-{i:04}");
460 index.insert(&id, emb);
461 ids.push(id);
462 }
463
464 let projected_positions: Vec<SphericalPoint> =
471 embeddings.iter().map(|e| projection.project(e)).collect();
472 let cart_points: Vec<[f64; 3]> = projected_positions
473 .iter()
474 .map(|sp| {
475 let c = spherical_to_cartesian(sp);
476 [c.x, c.y, c.z]
477 })
478 .collect();
479
480 let evr = projection.explained_variance_ratio();
481 let category_layer = CategoryLayer::build_with_config(
482 &categories,
483 embeddings,
484 &projected_positions,
485 &projection,
486 evr,
487 &config,
488 );
489
490 let quality_config = QualityConfig::default();
491 let projection_warnings = ProjectionWarning::from_evr(evr, quality_config.warn_below_evr)
492 .into_iter()
493 .collect();
494
495 Ok(Self {
496 projection,
497 index,
498 categories,
499 cart_points,
500 ids,
501 category_layer,
502 quality_config,
503 projection_warnings,
504 config,
505 raw_embeddings: None,
506 metric_points: OnceLock::new(),
507 })
508 }
509
510 pub fn has_category(&self, name: &str) -> bool {
515 self.category_layer.name_to_index.contains_key(name)
516 }
517
518 pub fn has_id(&self, id: &str) -> bool {
520 self.index.get(id).is_some()
521 }
522
523 pub fn ids(&self) -> &[String] {
528 &self.ids
529 }
530
531 pub fn query(
540 &self,
541 q: SphereQLQuery<'_>,
542 query_embedding: &PipelineQuery,
543 ) -> Result<SphereQLOutput, PipelineError> {
544 let emb = Embedding::new(query_embedding.embedding.clone());
545
546 match q {
547 SphereQLQuery::Nearest { k } => {
548 Ok(SphereQLOutput::Nearest(self.default_nearest(&emb, k)))
549 }
550
551 SphereQLQuery::SimilarAbove { min_cosine } => {
552 let evr = self.projection.explained_variance_ratio();
553 let results = self.index.search_similar(&emb, min_cosine);
554 let sp_q = self.projection.project(&emb);
555 Ok(SphereQLOutput::KNearest(
556 results
557 .items
558 .iter()
559 .map(|item| {
560 let d = angular_distance(&sp_q, item.position());
561 let certainty = item.certainty();
562 let quality = QualitySignal::from_certainty(evr, certainty);
563 NearestResult {
564 id: item.id.clone(),
565 category: self.cat_for(&item.id),
566 distance: d,
567 certainty,
568 intensity: item.intensity(),
569 quality: Some(quality),
570 }
571 })
572 .filter(|r| self.passes_quality(r))
573 .collect(),
574 ))
575 }
576
577 SphereQLQuery::ConceptPath {
578 source_id,
579 target_id,
580 graph_k,
581 } => {
582 if !self.has_id(source_id) {
583 return Err(PipelineError::UnknownId(source_id.to_string()));
584 }
585 if !self.has_id(target_id) {
586 return Err(PipelineError::UnknownId(target_id.to_string()));
587 }
588 let path = self.index.concept_path(source_id, target_id, graph_k);
589 Ok(SphereQLOutput::ConceptPath(path.map(|p| {
590 PathResult {
591 total_distance: p.total_distance,
592 steps: p
593 .steps
594 .iter()
595 .map(|s| PipelinePathStep {
596 id: s.id.clone(),
597 category: self.cat_for(&s.id),
598 cumulative_distance: s.cumulative_distance,
599 hop_distance: s.hop_distance,
600 bridge_strength: s.bridge_strength,
601 })
602 .collect(),
603 }
604 })))
605 }
606
607 SphereQLQuery::DetectGlobs { k, max_k } => {
608 let result = GlobResult::detect(&self.cart_points, &self.ids, k, max_k);
609 Ok(SphereQLOutput::Globs(
610 result
611 .globs
612 .iter()
613 .map(|g| {
614 let mut cat_counts = std::collections::HashMap::<String, usize>::new();
615 for mid in &g.member_ids {
616 let cat = self.cat_for(mid);
617 *cat_counts.entry(cat).or_default() += 1;
618 }
619 let mut top: Vec<_> = cat_counts.into_iter().collect();
620 top.sort_by_key(|(_, c)| std::cmp::Reverse(*c));
621 top.truncate(3);
622
623 GlobSummary {
624 id: g.id,
625 centroid: g.centroid,
626 member_count: g.member_ids.len(),
627 radius: g.radius,
628 top_categories: top,
629 }
630 })
631 .collect(),
632 ))
633 }
634
635 SphereQLQuery::LocalManifold { neighborhood_k } => {
636 let sp = self.projection.project(&emb);
637 let c = spherical_to_cartesian(&sp);
638 let qpt = [c.x, c.y, c.z];
639 let m = SlicingManifold::fit_local(&qpt, &self.cart_points, neighborhood_k);
640 Ok(SphereQLOutput::LocalManifold(ManifoldResult {
641 centroid: m.centroid,
642 normal: m.normal,
643 variance_ratio: m.variance_ratio,
644 }))
645 }
646
647 SphereQLQuery::CategoryConceptPath {
649 source_category,
650 target_category,
651 } => {
652 if !self.has_category(source_category) {
653 return Err(PipelineError::UnknownCategory(source_category.to_string()));
654 }
655 if !self.has_category(target_category) {
656 return Err(PipelineError::UnknownCategory(target_category.to_string()));
657 }
658 let path = self
659 .category_layer
660 .category_path(source_category, target_category);
661 Ok(SphereQLOutput::CategoryConceptPath(path))
662 }
663
664 SphereQLQuery::CategoryNeighbors { category, k } => {
665 if !self.has_category(category) {
666 return Err(PipelineError::UnknownCategory(category.to_string()));
667 }
668 let neighbors = self.category_layer.category_neighbors(category, k);
669 Ok(SphereQLOutput::CategoryNeighbors(
670 neighbors.into_iter().cloned().collect(),
671 ))
672 }
673
674 SphereQLQuery::DrillDown { category, k } => {
675 if !self.has_category(category) {
676 return Err(PipelineError::UnknownCategory(category.to_string()));
677 }
678 let results = self.category_layer.drill_down_with_projection(
679 category,
680 &emb,
681 &self.projection,
682 k,
683 );
684 Ok(SphereQLOutput::DrillDown(results))
685 }
686
687 SphereQLQuery::CategoryStats => Ok(SphereQLOutput::CategoryStats {
688 summaries: self.category_layer.summaries.clone(),
689 inner_sphere_reports: self.category_layer.inner_sphere_stats(),
690 }),
691 }
692 }
693
694 fn cat_for(&self, id: &str) -> String {
696 if let Some(idx_str) = id.strip_prefix("s-")
697 && let Ok(idx) = idx_str.parse::<usize>()
698 && idx < self.categories.len()
699 {
700 return self.categories[idx].clone();
701 }
702 debug_assert!(
707 false,
708 "cat_for: id {id:?} does not match the generated `s-{{i:04}}` format \
709 or indexes past {} categories",
710 self.categories.len()
711 );
712 "unknown".into()
713 }
714
715 pub fn num_items(&self) -> usize {
717 self.ids.len()
718 }
719
720 pub fn categories(&self) -> &[String] {
722 &self.categories
723 }
724
725 pub fn projected_points(&self) -> Vec<(&str, &str, [f64; 3])> {
727 self.ids
728 .iter()
729 .enumerate()
730 .map(|(i, id)| {
731 let cat = self
732 .categories
733 .get(i)
734 .map(|s| s.as_str())
735 .unwrap_or("unknown");
736 (id.as_str(), cat, self.cart_points[i])
737 })
738 .collect()
739 }
740
741 pub fn projection(&self) -> &ConfiguredProjection {
750 &self.projection
751 }
752
753 pub fn projection_kind(&self) -> ProjectionKind {
755 self.projection.kind()
756 }
757
758 pub fn exported_points(&self) -> Vec<ExportedPoint> {
762 self.ids
763 .iter()
764 .enumerate()
765 .map(|(i, id)| {
766 let [x, y, z] = self.cart_points[i];
767 let category = self
768 .categories
769 .get(i)
770 .cloned()
771 .unwrap_or_else(|| "unknown".into());
772 let item = self.index.get(id);
773 let (r, theta, phi) = item
774 .map(|it| {
775 let pos = it.position();
776 (pos.r, pos.theta, pos.phi)
777 })
778 .unwrap_or((0.0, 0.0, 0.0));
779 let certainty = item.map_or(1.0, |it| it.certainty());
780 let intensity = item.map_or(1.0, |it| it.intensity());
781 ExportedPoint {
782 id: id.clone(),
783 category,
784 r,
785 theta,
786 phi,
787 x,
788 y,
789 z,
790 certainty,
791 intensity,
792 }
793 })
794 .collect()
795 }
796
797 pub(crate) fn metric_points(&self) -> &MetricPoints {
802 self.metric_points.get_or_init(|| {
803 let positions = self
804 .ids
805 .iter()
806 .map(|id| {
807 self.index
808 .get(id)
809 .map(|it| *it.position())
810 .unwrap_or_else(|| SphericalPoint::new_unchecked(0.0, 0.0, 0.0))
811 })
812 .collect();
813 let category_indices = self
814 .categories
815 .iter()
816 .map(|c| self.category_layer.name_to_index.get(c).copied())
817 .collect();
818 MetricPoints {
819 positions,
820 category_indices,
821 }
822 })
823 }
824
825 pub fn explained_variance_ratio(&self) -> f64 {
834 self.projection.explained_variance_ratio()
835 }
836
837 pub fn num_categories(&self) -> usize {
839 self.category_layer.num_categories()
840 }
841
842 pub fn unique_categories(&self) -> Vec<String> {
844 self.category_layer
845 .summaries
846 .iter()
847 .map(|s| s.name.clone())
848 .collect()
849 }
850
851 pub fn category_layer(&self) -> &CategoryLayer {
855 &self.category_layer
856 }
857
858 pub fn category_path(&self, source: &str, target: &str) -> Option<CategoryPath> {
860 self.category_layer.category_path(source, target)
861 }
862
863 pub fn bridge_items(&self, source: &str, target: &str, max: usize) -> Vec<&BridgeItem> {
865 self.category_layer.bridge_items(source, target, max)
866 }
867
868 pub fn has_inner_sphere(&self, category: &str) -> bool {
870 self.category_layer.has_inner_sphere(category)
871 }
872
873 pub fn num_inner_spheres(&self) -> usize {
875 self.category_layer.num_inner_spheres()
876 }
877
878 pub fn inner_sphere_stats(&self) -> Vec<InnerSphereReport> {
880 self.category_layer.inner_sphere_stats()
881 }
882
883 pub fn projection_warnings(&self) -> &[ProjectionWarning] {
885 &self.projection_warnings
886 }
887
888 pub fn raw_embeddings(&self) -> Option<&[Vec<f64>]> {
896 self.raw_embeddings.as_deref()
897 }
898
899 pub fn embedding_dim(&self) -> usize {
902 self.raw_embeddings
903 .as_ref()
904 .and_then(|e| e.first())
905 .map(|v| v.len())
906 .unwrap_or(0)
907 }
908
909 pub fn pairwise_similarities(&self) -> Option<Result<Vec<f64>, sphereql_core::SphereQlError>> {
919 self.raw_embeddings
920 .as_ref()
921 .map(|embeddings| sphereql_core::pairwise_cosine_similarities(embeddings))
922 }
923
924 pub fn nearest_by_embedding(
938 &self,
939 query_embedding: &[f64],
940 k: usize,
941 ) -> Option<Result<Vec<(usize, f64)>, sphereql_core::SphereQlError>> {
942 let embeddings = self.raw_embeddings.as_ref()?;
943 let dim = embeddings.first().map(|v| v.len()).unwrap_or(0);
944
945 if query_embedding.len() != dim {
946 return Some(Err(sphereql_core::SphereQlError::DimensionMismatch {
947 expected: dim,
948 actual: query_embedding.len(),
949 }));
950 }
951
952 let query_norm: f64 = query_embedding.iter().map(|x| x * x).sum::<f64>().sqrt();
953 if query_norm < f64::EPSILON || k == 0 {
954 return Some(Ok(Vec::new()));
955 }
956
957 struct Scored {
961 sim: f64,
962 index: usize,
963 }
964 impl Ord for Scored {
965 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
966 self.sim
967 .total_cmp(&other.sim)
968 .then_with(|| other.index.cmp(&self.index))
969 }
970 }
971 impl PartialOrd for Scored {
972 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
973 Some(self.cmp(other))
974 }
975 }
976 impl PartialEq for Scored {
977 fn eq(&self, other: &Self) -> bool {
978 self.cmp(other) == std::cmp::Ordering::Equal
979 }
980 }
981 impl Eq for Scored {}
982
983 let mut heap: BinaryHeap<Reverse<Scored>> = BinaryHeap::with_capacity(k + 1);
984 for (i, emb) in embeddings.iter().enumerate() {
985 let emb_norm: f64 = emb.iter().map(|x| x * x).sum::<f64>().sqrt();
986 let dot: f64 = query_embedding
987 .iter()
988 .zip(emb.iter())
989 .map(|(a, b)| a * b)
990 .sum();
991 let sim = if emb_norm < f64::EPSILON {
992 0.0
993 } else {
994 (dot / (query_norm * emb_norm)).clamp(-1.0, 1.0)
995 };
996 heap.push(Reverse(Scored { sim, index: i }));
997 if heap.len() > k {
998 heap.pop();
999 }
1000 }
1001
1002 let mut scored: Vec<(usize, f64)> = heap
1003 .into_iter()
1004 .map(|Reverse(s)| (s.index, s.sim))
1005 .collect();
1006 scored.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
1007
1008 Some(Ok(scored))
1009 }
1010
1011 pub fn domain_groups(&self) -> &[DomainGroup] {
1017 &self.category_layer.domain_groups
1018 }
1019
1020 pub fn route_to_group(&self, embedding: &Embedding) -> Option<&DomainGroup> {
1023 let groups = &self.category_layer.domain_groups;
1024 if groups.is_empty() {
1025 return None;
1026 }
1027 let pos = self.projection.project(embedding);
1028 groups.iter().min_by(|a, b| {
1029 let da = angular_distance(&pos, &a.centroid);
1030 let db = angular_distance(&pos, &b.centroid);
1031 da.total_cmp(&db)
1032 })
1033 }
1034
1035 pub fn hierarchical_nearest(&self, embedding: &Embedding, k: usize) -> Vec<NearestResult> {
1047 let evr = self.projection.explained_variance_ratio();
1048
1049 if evr >= self.config.routing.low_evr_threshold {
1050 return self.nearest_filtered(embedding, k, evr);
1051 }
1052
1053 let Some(group) = self.route_to_group(embedding) else {
1054 return self.nearest_filtered(embedding, k, evr);
1055 };
1056
1057 let mut candidates: Vec<NearestResult> = Vec::new();
1060 for &ci in &group.member_categories {
1061 let cat_name = &self.category_layer.summaries[ci].name;
1062 for r in self.category_layer.drill_down_with_projection(
1063 cat_name,
1064 embedding,
1065 &self.projection,
1066 k,
1067 ) {
1068 candidates.push(self.drill_result_to_nearest(&r, evr));
1069 }
1070 }
1071
1072 candidates.sort_by(|a, b| {
1073 a.distance
1074 .partial_cmp(&b.distance)
1075 .unwrap_or(std::cmp::Ordering::Equal)
1076 });
1077
1078 let truncated: Vec<NearestResult> = candidates.into_iter().take(k).collect();
1088
1089 if truncated.is_empty() {
1090 self.nearest_filtered(embedding, k, evr)
1091 } else {
1092 truncated
1093 }
1094 }
1095
1096 #[inline]
1105 fn passes_quality(&self, r: &NearestResult) -> bool {
1106 r.certainty >= self.quality_config.min_certainty
1107 && r.quality
1108 .is_none_or(|q| q.passes_threshold(self.quality_config.min_combined))
1109 }
1110
1111 pub fn default_nearest(&self, embedding: &Embedding, k: usize) -> Vec<NearestResult> {
1121 let evr = self.projection.explained_variance_ratio();
1122 let alpha = self.config.routing.group_routing_alpha;
1123
1124 let route = if alpha > 0.0 && evr < HIGH_EVR_ROUTING_BYPASS {
1125 let pos = self.projection.project(embedding);
1126 self.category_layer
1127 .nearest_group(&pos)
1128 .filter(|(gi, d_near, d_second)| {
1129 self.category_layer.group_inner_spheres.contains_key(gi)
1130 && (*d_second == f64::INFINITY || d_near / d_second < alpha)
1131 })
1132 .map(|(gi, _, _)| gi)
1133 } else {
1134 None
1135 };
1136
1137 if let Some(gi) = route {
1138 let drilled = self.category_layer.drill_down_group(gi, embedding, k);
1139 let mut results: Vec<NearestResult> = drilled
1140 .iter()
1141 .map(|r| self.drill_result_to_nearest(r, evr))
1142 .filter(|r| self.passes_quality(r))
1143 .collect();
1144 if !results.is_empty() {
1147 results.truncate(k);
1148 return results;
1149 }
1150 }
1151
1152 self.nearest_filtered(embedding, k, evr)
1153 }
1154
1155 fn nearest_filtered(&self, embedding: &Embedding, k: usize, evr: f64) -> Vec<NearestResult> {
1157 self.index
1158 .search_nearest(embedding, k)
1159 .iter()
1160 .map(|r| {
1161 let certainty = r.item.certainty();
1162 let quality = QualitySignal::from_certainty(evr, certainty);
1163 NearestResult {
1164 id: r.item.id.clone(),
1165 category: self.cat_for(&r.item.id),
1166 distance: r.distance,
1167 certainty,
1168 intensity: r.item.intensity(),
1169 quality: Some(quality),
1170 }
1171 })
1172 .filter(|r| self.passes_quality(r))
1173 .collect()
1174 }
1175
1176 fn drill_result_to_nearest(&self, r: &DrillDownResult, evr: f64) -> NearestResult {
1177 let id = self.ids[r.item_index].clone();
1178 let item = self.index.get(&id);
1179 let certainty = item.map_or(1.0, |it| it.certainty());
1180 let intensity = item.map_or(1.0, |it| it.intensity());
1181 let quality = QualitySignal::from_certainty(evr, certainty);
1182 NearestResult {
1183 id,
1184 category: self
1185 .categories
1186 .get(r.item_index)
1187 .cloned()
1188 .unwrap_or_else(|| "unknown".into()),
1189 distance: r.distance,
1190 certainty,
1191 intensity,
1192 quality: Some(quality),
1193 }
1194 }
1195
1196 pub fn quality_config(&self) -> &QualityConfig {
1198 &self.quality_config
1199 }
1200
1201 pub fn set_quality_config(&mut self, config: QualityConfig) {
1203 self.quality_config = config;
1204 }
1205
1206 pub fn annotate_relations(&mut self, labels: &[String]) {
1212 self.category_layer.annotate_bridge_relations(labels);
1213 }
1214
1215 pub fn config(&self) -> &PipelineConfig {
1217 &self.config
1218 }
1219
1220 pub fn to_json(&self) -> Result<String, serde_json::Error> {
1225 serde_json::to_string(&self.exported_points())
1226 }
1227
1228 pub fn to_csv(&self) -> String {
1233 let points = self.exported_points();
1234 let mut out = String::from("id,category,r,theta,phi,x,y,z,certainty,intensity\n");
1235 for p in &points {
1236 out.push_str(&format!(
1237 "\"{}\",\"{}\",{:.6},{:.6},{:.6},{:.6},{:.6},{:.6},{:.6},{:.6}\n",
1238 p.id.replace('"', "\"\""),
1239 p.category.replace('"', "\"\""),
1240 p.r,
1241 p.theta,
1242 p.phi,
1243 p.x,
1244 p.y,
1245 p.z,
1246 p.certainty,
1247 p.intensity,
1248 ));
1249 }
1250 out
1251 }
1252}
1253
1254fn require_fitted<M: MetaModel>(model: &M) -> Result<(), PipelineError> {
1258 if model.is_fitted() {
1259 Ok(())
1260 } else {
1261 Err(PipelineError::InvalidInput(format!(
1262 "meta-model {:?} is unfitted; call fit() with at least one record first",
1263 model.name()
1264 )))
1265 }
1266}
1267
1268pub fn fit_projection_for_config(
1280 embeddings: &[Embedding],
1281 categories: &[String],
1282 config: &PipelineConfig,
1283) -> Result<ConfiguredProjection, crate::projection::ProjectionError> {
1284 match config.projection_kind {
1285 ProjectionKind::Pca => {
1286 let mut cat_counts: std::collections::HashMap<&str, usize> =
1292 std::collections::HashMap::new();
1293 for c in categories {
1294 *cat_counts.entry(c.as_str()).or_default() += 1;
1295 }
1296 let weights: Vec<f64> = categories
1297 .iter()
1298 .map(|c| 1.0 / (cat_counts[c.as_str()] as f64).sqrt())
1299 .collect();
1300 Ok(ConfiguredProjection::Pca(
1301 PcaProjection::fit_weighted(embeddings, &weights, RadialStrategy::Magnitude)?
1302 .with_volumetric(true),
1303 ))
1304 }
1305 ProjectionKind::KernelPca => Ok(ConfiguredProjection::KernelPca(KernelPcaProjection::fit(
1306 embeddings,
1307 RadialStrategy::Magnitude,
1308 )?)),
1309 ProjectionKind::LaplacianEigenmap => {
1310 let lc = &config.laplacian;
1311 Ok(ConfiguredProjection::Laplacian(
1312 LaplacianEigenmapProjection::fit_with_params(
1313 embeddings,
1314 lc.k_neighbors,
1315 lc.active_threshold,
1316 RadialStrategy::Magnitude,
1317 )?,
1318 ))
1319 }
1320 ProjectionKind::UmapSphere => {
1321 let cat_indices = compact_category_indices(categories);
1322 Ok(ConfiguredProjection::UmapSphere(
1323 crate::umap::UmapSphereProjection::fit(
1324 embeddings,
1325 Some(&cat_indices),
1326 RadialStrategy::Magnitude,
1327 umap_fit_config(config),
1328 )?,
1329 ))
1330 }
1331 }
1332}
1333
1334pub fn fit_umap_from_graph(
1342 graph: &crate::umap::UmapGraph,
1343 categories: &[String],
1344 config: &PipelineConfig,
1345) -> Result<ConfiguredProjection, crate::projection::ProjectionError> {
1346 let cat_indices = compact_category_indices(categories);
1347 Ok(ConfiguredProjection::UmapSphere(
1348 crate::umap::UmapSphereProjection::fit_from_graph(
1349 graph,
1350 Some(&cat_indices),
1351 RadialStrategy::Magnitude,
1352 umap_fit_config(config),
1353 )?,
1354 ))
1355}
1356
/// Maps every category label to a dense `u32` id.
///
/// Ids are assigned in order of first appearance, starting at 0; repeated
/// labels reuse the id assigned on first sight. The output has one entry per
/// input label, in input order.
fn compact_category_indices(categories: &[String]) -> Vec<u32> {
    let mut assigned: std::collections::HashMap<&str, u32> = std::collections::HashMap::new();
    let mut next_id: u32 = 0;
    let mut indices = Vec::with_capacity(categories.len());
    for label in categories {
        let id = match assigned.entry(label.as_str()) {
            std::collections::hash_map::Entry::Occupied(slot) => *slot.get(),
            std::collections::hash_map::Entry::Vacant(slot) => {
                let fresh = next_id;
                next_id += 1;
                *slot.insert(fresh)
            }
        };
        indices.push(id);
    }
    indices
}
1375
1376fn umap_fit_config(config: &PipelineConfig) -> crate::umap::UmapConfig {
1382 let uc = &config.umap;
1383 crate::umap::UmapConfig {
1384 n_neighbors: uc.n_neighbors,
1385 n_epochs: uc.n_epochs,
1386 learning_rate: 0.05,
1387 negative_sample_rate: 5,
1388 category_weight: uc.category_weight,
1389 min_dist: uc.min_dist,
1390 warm_start_anchor: uc.warm_start_anchor,
1391 seed: uc.seed,
1392 }
1393}
1394
1395#[cfg(test)]
1396mod tests {
1397 use super::*;
1398
1399 fn make_input(n: usize, dim: usize) -> (PipelineInput, PipelineQuery) {
1400 let mut embeddings = Vec::with_capacity(n);
1401 let mut categories = Vec::with_capacity(n);
1402 for i in 0..n {
1403 let mut v = vec![0.0; dim];
1404 if i < n / 2 {
1405 v[0] = 1.0 + (i as f64 * 0.01);
1406 v[1] = 0.1;
1407 categories.push("group_a".into());
1408 } else {
1409 v[0] = 0.1;
1410 v[1] = 1.0 + (i as f64 * 0.01);
1411 categories.push("group_b".into());
1412 }
1413 v[2] = 0.05 * (i as f64);
1414 embeddings.push(v);
1415 }
1416 let query = PipelineQuery {
1417 embedding: vec![0.9; dim],
1418 };
1419 (
1420 PipelineInput {
1421 categories,
1422 embeddings,
1423 },
1424 query,
1425 )
1426 }
1427
1428 #[test]
1431 fn ids_are_insertion_order_aligned_with_categories_and_points() {
1432 let (input, _) = make_input(20, 10);
1433 let categories_in = input.categories.clone();
1434 let pipeline = SphereQLPipeline::new(input).unwrap();
1435
1436 let ids = pipeline.ids();
1437 assert_eq!(ids.len(), 20);
1438 for (i, id) in ids.iter().enumerate() {
1439 assert_eq!(id, &format!("s-{i:04}"));
1440 }
1441
1442 let cats = pipeline.categories();
1444 assert_eq!(cats.len(), ids.len());
1445 for (i, cat) in cats.iter().enumerate() {
1446 assert_eq!(cat, &categories_in[i]);
1447 }
1448
1449 let pts = pipeline.projected_points();
1453 assert_eq!(pts.len(), ids.len());
1454 for (i, (id, cat, _)) in pts.iter().enumerate() {
1455 assert_eq!(*id, ids[i].as_str());
1456 assert_eq!(*cat, cats[i].as_str());
1457 }
1458 }
1459
1460 #[test]
1461 fn fit_weighted_pca_handles_singleton_categories() {
1462 let mut embeddings = Vec::new();
1470 let mut categories = Vec::new();
1471 for i in 0..20 {
1472 let mut v = vec![0.0; 8];
1473 v[0] = 1.0 + (i as f64 * 0.01);
1474 v[1] = 0.1;
1475 embeddings.push(v);
1476 categories.push("big".to_string());
1477 }
1478 embeddings.push(vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
1479 categories.push("singleton_a".to_string());
1480 embeddings.push(vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
1481 categories.push("singleton_b".to_string());
1482
1483 let pipeline = SphereQLPipeline::new(PipelineInput {
1484 categories,
1485 embeddings,
1486 })
1487 .unwrap();
1488
1489 let evr = pipeline.explained_variance_ratio();
1490 assert!(
1491 evr > 0.0,
1492 "weighted PCA should produce nonzero EVR even with singletons, got {evr}"
1493 );
1494 assert_eq!(pipeline.num_categories(), 3);
1495 }
1496
1497 #[test]
1498 fn pipeline_nearest() {
1499 let (input, query) = make_input(20, 10);
1500 let pipeline = SphereQLPipeline::new(input).unwrap();
1501 let result = pipeline
1502 .query(SphereQLQuery::Nearest { k: 5 }, &query)
1503 .unwrap();
1504 match result {
1505 SphereQLOutput::Nearest(items) => {
1506 assert_eq!(items.len(), 5);
1507 assert!(items[0].distance <= items[1].distance);
1508 }
1509 _ => panic!("expected Nearest"),
1510 }
1511 }
1512
1513 #[test]
1514 fn pipeline_globs() {
1515 let (input, query) = make_input(30, 10);
1516 let pipeline = SphereQLPipeline::new(input).unwrap();
1517 let result = pipeline
1518 .query(
1519 SphereQLQuery::DetectGlobs {
1520 k: Some(2),
1521 max_k: 5,
1522 },
1523 &query,
1524 )
1525 .unwrap();
1526 match result {
1527 SphereQLOutput::Globs(globs) => {
1528 assert_eq!(globs.len(), 2);
1529 let total: usize = globs.iter().map(|g| g.member_count).sum();
1530 assert_eq!(total, 30);
1531 }
1532 _ => panic!("expected Globs"),
1533 }
1534 }
1535
1536 #[test]
1537 fn pipeline_concept_path() {
1538 let (input, query) = make_input(20, 10);
1539 let pipeline = SphereQLPipeline::new(input).unwrap();
1540 let result = pipeline
1541 .query(
1542 SphereQLQuery::ConceptPath {
1543 source_id: "s-0000",
1544 target_id: "s-0015",
1545 graph_k: 10,
1546 },
1547 &query,
1548 )
1549 .unwrap();
1550 match result {
1551 SphereQLOutput::ConceptPath(Some(path)) => {
1552 assert!(path.steps.len() >= 2);
1553 assert_eq!(path.steps.first().unwrap().id, "s-0000");
1554 assert_eq!(path.steps.last().unwrap().id, "s-0015");
1555 }
1556 _ => panic!("expected ConceptPath(Some)"),
1557 }
1558 }
1559
1560 #[test]
1561 fn pipeline_local_manifold() {
1562 let (input, query) = make_input(20, 10);
1563 let pipeline = SphereQLPipeline::new(input).unwrap();
1564 let result = pipeline
1565 .query(SphereQLQuery::LocalManifold { neighborhood_k: 10 }, &query)
1566 .unwrap();
1567 match result {
1568 SphereQLOutput::LocalManifold(m) => {
1569 assert!(m.variance_ratio > 0.0);
1570 assert!(m.variance_ratio <= 1.0);
1571 }
1572 _ => panic!("expected LocalManifold"),
1573 }
1574 }
1575
1576 #[test]
1577 fn test_exported_points_count() {
1578 let (input, _) = make_input(20, 10);
1579 let pipeline = SphereQLPipeline::new(input).unwrap();
1580 assert_eq!(pipeline.exported_points().len(), 20);
1581 }
1582
1583 #[test]
1584 fn test_exported_points_fields() {
1585 let (input, _) = make_input(20, 10);
1586 let pipeline = SphereQLPipeline::new(input).unwrap();
1587 for p in pipeline.exported_points() {
1588 assert!(p.r >= 0.0, "r must be non-negative");
1589 assert!(
1590 p.theta >= 0.0 && p.theta < std::f64::consts::TAU,
1591 "theta out of range"
1592 );
1593 assert!(
1594 p.phi >= 0.0 && p.phi <= std::f64::consts::PI,
1595 "phi out of range"
1596 );
1597 }
1598 }
1599
1600 #[test]
1601 fn test_exported_points_categories() {
1602 let (input, _) = make_input(20, 10);
1603 let pipeline = SphereQLPipeline::new(input).unwrap();
1604 let points = pipeline.exported_points();
1605 for (i, p) in points.iter().enumerate() {
1606 let expected = if i < 10 { "group_a" } else { "group_b" };
1607 assert_eq!(p.category, expected);
1608 }
1609 }
1610
1611 #[test]
1612 fn test_to_json_parseable() {
1613 let (input, _) = make_input(20, 10);
1614 let pipeline = SphereQLPipeline::new(input).unwrap();
1615 let json = pipeline.to_json().unwrap();
1616 let parsed: Vec<serde_json::Value> = serde_json::from_str(&json).expect("valid JSON");
1617 assert_eq!(parsed.len(), 20);
1618 }
1619
1620 #[test]
1621 fn test_to_csv_lines() {
1622 let (input, _) = make_input(20, 10);
1623 let pipeline = SphereQLPipeline::new(input).unwrap();
1624 let csv = pipeline.to_csv();
1625 let lines: Vec<&str> = csv.lines().collect();
1626 assert_eq!(
1627 lines[0],
1628 "id,category,r,theta,phi,x,y,z,certainty,intensity"
1629 );
1630 assert_eq!(lines.len(), 21);
1631 }
1632
1633 #[test]
1634 fn test_to_csv_quoted_fields() {
1635 let (input, _) = make_input(20, 10);
1636 let pipeline = SphereQLPipeline::new(input).unwrap();
1637 let csv = pipeline.to_csv();
1638 let data_line = csv.lines().nth(1).unwrap();
1639 assert!(data_line.starts_with('"'), "id field should be quoted");
1640 }
1641
1642 #[test]
1643 fn test_explained_variance() {
1644 let (input, _) = make_input(20, 10);
1645 let pipeline = SphereQLPipeline::new(input).unwrap();
1646 let ratio = pipeline.explained_variance_ratio();
1647 assert!(ratio > 0.0 && ratio <= 1.0);
1648 }
1649
1650 #[test]
1651 fn test_unique_categories() {
1652 let (input, _) = make_input(20, 10);
1653 let pipeline = SphereQLPipeline::new(input).unwrap();
1654 let cats = pipeline.unique_categories();
1655 assert_eq!(cats.len(), 2);
1656 assert_eq!(cats[0], "group_a");
1657 assert_eq!(cats[1], "group_b");
1658 assert_eq!(pipeline.num_categories(), 2);
1659 }
1660
1661 #[test]
1664 fn pipeline_builds_category_layer() {
1665 let (input, _) = make_input(20, 10);
1666 let pipeline = SphereQLPipeline::new(input).unwrap();
1667 assert_eq!(pipeline.category_layer().num_categories(), 2);
1668 }
1669
1670 #[test]
1671 fn pipeline_category_path_query() {
1672 let (input, query) = make_input(20, 10);
1673 let pipeline = SphereQLPipeline::new(input).unwrap();
1674 let result = pipeline
1675 .query(
1676 SphereQLQuery::CategoryConceptPath {
1677 source_category: "group_a",
1678 target_category: "group_b",
1679 },
1680 &query,
1681 )
1682 .unwrap();
1683 match result {
1684 SphereQLOutput::CategoryConceptPath(Some(path)) => {
1685 assert!(path.steps.len() >= 2);
1686 assert_eq!(path.steps.first().unwrap().category_name, "group_a");
1687 assert_eq!(path.steps.last().unwrap().category_name, "group_b");
1688 assert!(path.total_distance > 0.0);
1689 }
1690 _ => panic!("expected CategoryConceptPath(Some)"),
1691 }
1692 }
1693
1694 #[test]
1695 fn pipeline_category_path_shortcut() {
1696 let (input, _) = make_input(20, 10);
1697 let pipeline = SphereQLPipeline::new(input).unwrap();
1698 let path = pipeline.category_path("group_a", "group_b");
1699 assert!(path.is_some());
1700 let path = path.unwrap();
1701 assert_eq!(path.steps.first().unwrap().category_name, "group_a");
1702 assert_eq!(path.steps.last().unwrap().category_name, "group_b");
1703 }
1704
1705 #[test]
1706 fn pipeline_category_path_unknown() {
1707 let (input, _) = make_input(20, 10);
1708 let pipeline = SphereQLPipeline::new(input).unwrap();
1709 assert!(pipeline.category_path("group_a", "nonexistent").is_none());
1710 }
1711
1712 #[test]
1713 fn pipeline_category_neighbors_query() {
1714 let (input, query) = make_input(20, 10);
1715 let pipeline = SphereQLPipeline::new(input).unwrap();
1716 let result = pipeline
1717 .query(
1718 SphereQLQuery::CategoryNeighbors {
1719 category: "group_a",
1720 k: 5,
1721 },
1722 &query,
1723 )
1724 .unwrap();
1725 match result {
1726 SphereQLOutput::CategoryNeighbors(neighbors) => {
1727 assert_eq!(neighbors.len(), 1);
1728 assert_eq!(neighbors[0].name, "group_b");
1729 }
1730 _ => panic!("expected CategoryNeighbors"),
1731 }
1732 }
1733
1734 #[test]
1735 fn pipeline_drill_down_query() {
1736 let (input, query) = make_input(20, 10);
1737 let pipeline = SphereQLPipeline::new(input).unwrap();
1738 let result = pipeline
1739 .query(
1740 SphereQLQuery::DrillDown {
1741 category: "group_a",
1742 k: 5,
1743 },
1744 &query,
1745 )
1746 .unwrap();
1747 match result {
1748 SphereQLOutput::DrillDown(results) => {
1749 assert!(!results.is_empty());
1750 assert!(results.len() <= 5);
1751 for w in results.windows(2) {
1752 assert!(w[0].distance <= w[1].distance);
1753 }
1754 }
1755 _ => panic!("expected DrillDown"),
1756 }
1757 }
1758
1759 #[test]
1760 fn pipeline_category_stats_query() {
1761 let (input, query) = make_input(20, 10);
1762 let pipeline = SphereQLPipeline::new(input).unwrap();
1763 let result = pipeline
1764 .query(SphereQLQuery::CategoryStats, &query)
1765 .unwrap();
1766 match result {
1767 SphereQLOutput::CategoryStats {
1768 summaries,
1769 inner_sphere_reports,
1770 } => {
1771 assert_eq!(summaries.len(), 2);
1772 assert_eq!(inner_sphere_reports.len(), 0);
1773 }
1774 _ => panic!("expected CategoryStats"),
1775 }
1776 }
1777
1778 #[test]
1779 fn pipeline_bridge_items_shortcut() {
1780 let (input, _) = make_input(20, 10);
1781 let pipeline = SphereQLPipeline::new(input).unwrap();
1782 let _ = pipeline.bridge_items("group_a", "group_b", 5);
1783 }
1784
1785 #[test]
1786 fn pipeline_inner_sphere_shortcuts() {
1787 let (input, _) = make_input(20, 10);
1788 let pipeline = SphereQLPipeline::new(input).unwrap();
1789 assert!(!pipeline.has_inner_sphere("group_a"));
1790 assert_eq!(pipeline.num_inner_spheres(), 0);
1791 assert!(pipeline.inner_sphere_stats().is_empty());
1792 }
1793
1794 #[test]
1795 fn pipeline_category_layer_accessor() {
1796 let (input, _) = make_input(20, 10);
1797 let pipeline = SphereQLPipeline::new(input).unwrap();
1798 let layer = pipeline.category_layer();
1799 assert_eq!(layer.num_categories(), 2);
1800 assert!(layer.get_category("group_a").is_some());
1801 assert!(layer.get_category("group_b").is_some());
1802 }
1803
1804 #[test]
1807 fn domain_groups_detected() {
1808 let (input, _) = make_input(20, 10);
1809 let pipeline = SphereQLPipeline::new(input).unwrap();
1810 let groups = pipeline.domain_groups();
1811 assert!(!groups.is_empty());
1812 let total: usize = groups.iter().map(|g| g.total_items).sum();
1813 assert_eq!(total, pipeline.num_items());
1814 }
1815
1816 #[test]
1817 fn domain_groups_cover_all_categories() {
1818 let (input, _) = make_input(20, 10);
1819 let pipeline = SphereQLPipeline::new(input).unwrap();
1820 let groups = pipeline.domain_groups();
1821 let mut all_cats: Vec<usize> = groups
1822 .iter()
1823 .flat_map(|g| g.member_categories.iter().copied())
1824 .collect();
1825 all_cats.sort();
1826 all_cats.dedup();
1827 assert_eq!(all_cats.len(), pipeline.num_categories());
1828 }
1829
1830 #[test]
1831 fn route_to_group_returns_something() {
1832 let (input, _) = make_input(20, 10);
1833 let pipeline = SphereQLPipeline::new(input).unwrap();
1834 let emb = Embedding::new(vec![0.5; 10]);
1835 assert!(pipeline.route_to_group(&emb).is_some());
1836 }
1837
1838 #[test]
1839 fn hierarchical_nearest_matches_standard_when_evr_high() {
1840 let (input, query) = make_input(20, 10);
1844 let pipeline = SphereQLPipeline::new(input).unwrap();
1845 let hier = pipeline.hierarchical_nearest(&Embedding::new(query.embedding.clone()), 5);
1846 assert!(!hier.is_empty());
1847 assert!(hier.len() <= 5);
1848 for w in hier.windows(2) {
1849 assert!(w[0].distance <= w[1].distance);
1850 }
1851 }
1852
1853 #[test]
1854 fn hierarchical_nearest_falls_back_when_filter_kills_candidates() {
1855 let (input, query) = make_input(20, 10);
1862 let mut pipeline = SphereQLPipeline::new_with_config(
1863 input,
1864 PipelineConfig {
1865 routing: crate::config::RoutingConfig {
1866 num_domain_groups: 2,
1867 group_routing_alpha: 0.8,
1868 low_evr_threshold: 1.1, },
1870 ..Default::default()
1871 },
1872 )
1873 .unwrap();
1874 pipeline.set_quality_config(crate::confidence::QualityConfig {
1875 min_certainty: 1.1, ..Default::default()
1877 });
1878
1879 pipeline.set_quality_config(crate::confidence::QualityConfig::default());
1887 let hier = pipeline.hierarchical_nearest(&Embedding::new(query.embedding.clone()), 5);
1888 assert!(
1889 !hier.is_empty(),
1890 "low-EVR branch should return results with default filter"
1891 );
1892 }
1893
1894 #[test]
1895 fn feedback_aggregator_derive_and_save_load_round_trip() {
1896 use crate::feedback::{FeedbackAggregator, FeedbackEvent};
1901 let mut agg = FeedbackAggregator::default();
1902 agg.record(FeedbackEvent {
1903 corpus_id: "c".into(),
1904 query_id: "q".into(),
1905 score: 0.5,
1906 timestamp: "0".into(),
1907 });
1908
1909 let json_via_derive = serde_json::to_string(&agg).unwrap();
1910 assert!(json_via_derive.starts_with('['));
1912
1913 let dir = std::env::temp_dir();
1915 let path = dir.join(format!(
1916 "sphereql_serde_transparent_{}.json",
1917 std::process::id()
1918 ));
1919 std::fs::write(&path, &json_via_derive).unwrap();
1920 let loaded = FeedbackAggregator::load(&path).unwrap();
1921 assert_eq!(loaded.len(), 1);
1922 let _ = std::fs::remove_file(&path);
1923 }
1924
1925 #[test]
1926 fn new_from_metamodel_uses_predicted_config() {
1927 use crate::corpus_features::CorpusFeatures;
1928 use crate::meta_model::{MetaTrainingRecord, NearestNeighborMetaModel};
1929
1930 let (input, _) = make_input(20, 10);
1931 let features = CorpusFeatures::extract(&input.categories, &input.embeddings).unwrap();
1932
1933 let target_config = PipelineConfig {
1937 projection_kind: ProjectionKind::LaplacianEigenmap,
1938 ..Default::default()
1939 };
1940 let record = MetaTrainingRecord {
1941 corpus_id: "seed".into(),
1942 features: features.clone(),
1943 best_config: target_config.clone(),
1944 best_score: 0.5,
1945 score_lift: None,
1946 metric_name: "test".into(),
1947 strategy: "manual".into(),
1948 timestamp: "0".into(),
1949 };
1950
1951 let mut model = NearestNeighborMetaModel::new();
1952 model.fit(&[record]);
1953
1954 let (pipeline, _extracted, predicted) =
1955 SphereQLPipeline::new_from_metamodel(input, &model).unwrap();
1956 assert_eq!(predicted.projection_kind, ProjectionKind::LaplacianEigenmap);
1957 assert_eq!(
1958 pipeline.projection_kind(),
1959 ProjectionKind::LaplacianEigenmap
1960 );
1961 }
1962
1963 #[test]
1964 fn new_from_metamodel_tuned_runs_and_carries_prediction() {
1965 use crate::corpus_features::CorpusFeatures;
1966 use crate::meta_model::{MetaTrainingRecord, NearestNeighborMetaModel};
1967 use crate::quality_metric::TerritorialHealth;
1968 use crate::tuner::{SearchSpace, SearchStrategy};
1969
1970 let (input, _) = make_input(20, 10);
1976 let features = CorpusFeatures::extract(&input.categories, &input.embeddings).unwrap();
1977
1978 let mut predicted_cfg = PipelineConfig::default();
1979 predicted_cfg.bridges.overlap_artifact_territorial = 0.123; let record = MetaTrainingRecord {
1982 corpus_id: "seed".into(),
1983 features: features.clone(),
1984 best_config: predicted_cfg.clone(),
1985 best_score: 0.5,
1986 score_lift: None,
1987 metric_name: "test".into(),
1988 strategy: "manual".into(),
1989 timestamp: "0".into(),
1990 };
1991 let mut model = NearestNeighborMetaModel::new();
1992 model.fit(&[record]);
1993
1994 let space = SearchSpace {
1996 projection_kinds: vec![ProjectionKind::Pca],
1997 laplacian_k_neighbors: vec![15],
1998 laplacian_active_threshold: vec![0.05],
1999 umap_n_neighbors: vec![15],
2000 umap_n_epochs: vec![200],
2001 umap_category_weight: vec![1.5],
2002 umap_min_dist: vec![0.1],
2003 num_domain_groups: vec![3, 5],
2004 low_evr_threshold: vec![0.35],
2005 overlap_artifact_territorial: vec![0.3], threshold_base: vec![0.5],
2007 threshold_evr_penalty: vec![0.4],
2008 min_evr_improvement: vec![0.10],
2009 };
2010
2011 let metric = TerritorialHealth;
2012 let (pipeline, _feats, report) = SphereQLPipeline::new_from_metamodel_tuned(
2013 input,
2014 &model,
2015 &space,
2016 &metric,
2017 SearchStrategy::Grid,
2018 )
2019 .unwrap();
2020
2021 assert_eq!(report.trials.len(), 2);
2027 for t in &report.trials {
2028 assert!((t.config.bridges.overlap_artifact_territorial - 0.3).abs() < 1e-9);
2029 }
2030 assert_eq!(pipeline.projection_kind(), ProjectionKind::Pca);
2031 }
2032
2033 #[test]
2034 fn new_from_metamodel_unfitted_model_returns_err() {
2035 use crate::meta_model::NearestNeighborMetaModel;
2036
2037 let (input, _) = make_input(20, 10);
2038 let model = NearestNeighborMetaModel::new();
2039 assert!(!model.is_fitted());
2040 match SphereQLPipeline::new_from_metamodel(input, &model) {
2041 Err(PipelineError::InvalidInput(msg)) => {
2042 assert!(msg.contains("unfitted"), "msg = {msg:?}");
2043 }
2044 Err(other) => panic!("expected InvalidInput, got {other:?}"),
2045 Ok(_) => panic!("expected Err for unfitted model"),
2046 }
2047 }
2048
2049 #[test]
2050 fn new_from_metamodel_tuned_unfitted_model_returns_err() {
2051 use crate::meta_model::DistanceWeightedMetaModel;
2052 use crate::quality_metric::TerritorialHealth;
2053 use crate::tuner::{SearchSpace, SearchStrategy};
2054
2055 let (input, _) = make_input(20, 10);
2056 let model = DistanceWeightedMetaModel::new();
2057 match SphereQLPipeline::new_from_metamodel_tuned(
2058 input,
2059 &model,
2060 &SearchSpace::default(),
2061 &TerritorialHealth,
2062 SearchStrategy::Grid,
2063 ) {
2064 Err(PipelineError::InvalidInput(msg)) => {
2065 assert!(msg.contains("unfitted"), "msg = {msg:?}");
2066 }
2067 Err(other) => panic!("expected InvalidInput, got {other:?}"),
2068 Ok(_) => panic!("expected Err for unfitted model"),
2069 }
2070 }
2071
2072 #[test]
2073 fn new_from_metamodel_empty_input_maps_to_invalid_input() {
2074 use crate::corpus_features::CorpusFeatures;
2075 use crate::meta_model::{MetaTrainingRecord, NearestNeighborMetaModel};
2076
2077 let (seed_input, _) = make_input(20, 10);
2078 let features =
2079 CorpusFeatures::extract(&seed_input.categories, &seed_input.embeddings).unwrap();
2080 let record = MetaTrainingRecord {
2081 corpus_id: "seed".into(),
2082 features,
2083 best_config: PipelineConfig::default(),
2084 best_score: 0.5,
2085 score_lift: None,
2086 metric_name: "test".into(),
2087 strategy: "manual".into(),
2088 timestamp: "0".into(),
2089 };
2090 let mut model = NearestNeighborMetaModel::new();
2091 model.fit(&[record]);
2092
2093 let empty = PipelineInput {
2094 categories: Vec::new(),
2095 embeddings: Vec::new(),
2096 };
2097 match SphereQLPipeline::new_from_metamodel(empty, &model) {
2098 Err(PipelineError::InvalidInput(msg)) => {
2099 assert!(msg.contains("empty"), "msg = {msg:?}");
2100 }
2101 Err(other) => panic!("expected InvalidInput, got {other:?}"),
2102 Ok(_) => panic!("expected Err for empty input"),
2103 }
2104 }
2105
2106 fn two_cluster_input(n_per: usize, dim: usize) -> (PipelineInput, PipelineQuery) {
2112 let mut embeddings = Vec::with_capacity(2 * n_per);
2113 let mut categories = Vec::with_capacity(2 * n_per);
2114 for i in 0..n_per {
2115 let t = i as f64 * 0.001;
2116 let mut v = vec![0.0; dim];
2117 v[0] = 1.0 + t;
2118 v[1] = 0.1 + t;
2119 embeddings.push(v);
2120 categories.push("a".into());
2121 }
2122 for i in 0..n_per {
2123 let t = i as f64 * 0.001;
2124 let mut v = vec![0.0; dim];
2125 v[3] = 1.0 + t;
2126 v[4] = 0.1 + t;
2127 embeddings.push(v);
2128 categories.push("b".into());
2129 }
2130 let query = PipelineQuery {
2131 embedding: {
2132 let mut q = vec![0.0; dim];
2133 q[0] = 1.0;
2134 q[1] = 0.05;
2135 q
2136 },
2137 };
2138 (
2139 PipelineInput {
2140 categories,
2141 embeddings,
2142 },
2143 query,
2144 )
2145 }
2146
2147 #[test]
2148 fn group_inner_spheres_built_when_two_or_more_groups() {
2149 let (input, _) = two_cluster_input(30, 8);
2150 let pipeline = SphereQLPipeline::new_with_config(
2151 input,
2152 PipelineConfig {
2153 routing: crate::config::RoutingConfig {
2154 num_domain_groups: 2,
2155 group_routing_alpha: 0.8,
2156 low_evr_threshold: 0.35,
2157 },
2158 inner_sphere: crate::config::InnerSphereConfig {
2164 min_evr_improvement: -1.0,
2165 ..Default::default()
2166 },
2167 ..Default::default()
2168 },
2169 )
2170 .unwrap();
2171 let layer = pipeline.category_layer();
2172 assert_eq!(layer.domain_groups.len(), 2);
2173 assert!(
2174 !layer.group_inner_spheres.is_empty(),
2175 "at least one group should produce a group-level inner sphere"
2176 );
2177 }
2178
2179 #[test]
2180 fn default_nearest_routes_through_group_inner_sphere() {
2181 let (input, query) = two_cluster_input(30, 8);
2182 let pipeline = SphereQLPipeline::new_with_config(
2183 input,
2184 PipelineConfig {
2185 routing: crate::config::RoutingConfig {
2186 num_domain_groups: 2,
2187 group_routing_alpha: 0.99,
2188 low_evr_threshold: 0.35,
2189 },
2190 ..Default::default()
2191 },
2192 )
2193 .unwrap();
2194
2195 let results = pipeline.default_nearest(&Embedding::new(query.embedding.clone()), 5);
2196 assert!(!results.is_empty());
2197 for r in &results {
2199 assert_eq!(r.category, "a", "got {r:?}");
2200 }
2201 }
2202
2203 #[test]
2204 fn default_nearest_bypasses_group_routing_at_high_evr() {
2205 let (input, query) = two_cluster_input(30, 8);
2210 let pipeline = SphereQLPipeline::new_with_config(
2211 input,
2212 PipelineConfig {
2213 routing: crate::config::RoutingConfig {
2214 num_domain_groups: 2,
2215 group_routing_alpha: 0.99, low_evr_threshold: 0.35,
2217 },
2218 ..Default::default()
2219 },
2220 )
2221 .unwrap();
2222
2223 let evr = pipeline.explained_variance_ratio();
2226 assert!(
2227 evr >= HIGH_EVR_ROUTING_BYPASS,
2228 "fixture no longer exercises the high-EVR bypass path (evr={evr})"
2229 );
2230
2231 let results = pipeline.default_nearest(&Embedding::new(query.embedding.clone()), 5);
2234 assert!(!results.is_empty());
2235 for r in &results {
2237 assert_eq!(
2238 r.category, "a",
2239 "high-EVR outer-sphere should find cluster-a items, got category={} id={}",
2240 r.category, r.id
2241 );
2242 }
2243 }
2244
2245 #[test]
2246 fn default_nearest_falls_back_when_alpha_zero() {
2247 let (input, query) = two_cluster_input(30, 8);
2251 let pipeline = SphereQLPipeline::new_with_config(
2252 input,
2253 PipelineConfig {
2254 routing: crate::config::RoutingConfig {
2255 num_domain_groups: 2,
2256 group_routing_alpha: 0.0,
2257 low_evr_threshold: 0.35,
2258 },
2259 ..Default::default()
2260 },
2261 )
2262 .unwrap();
2263 let results = pipeline.default_nearest(&Embedding::new(query.embedding.clone()), 5);
2264 assert!(!results.is_empty());
2265 }
2266
2267 #[test]
2270 #[cfg(feature = "retain-embeddings")]
2271 fn raw_embeddings_retained() {
2272 let input = PipelineInput {
2273 categories: vec!["a".into(), "a".into(), "b".into()],
2274 embeddings: vec![
2275 vec![1.0, 0.0, 0.0, 0.0],
2276 vec![0.9, 0.1, 0.0, 0.0],
2277 vec![0.0, 0.0, 1.0, 0.0],
2278 ],
2279 };
2280 let pipeline = SphereQLPipeline::new(input).unwrap();
2281 let raw = pipeline
2282 .raw_embeddings()
2283 .expect("should be Some with feature");
2284 assert_eq!(raw.len(), 3);
2285 assert_eq!(raw[0], vec![1.0, 0.0, 0.0, 0.0]);
2286 }
2287
2288 #[test]
2289 #[cfg(feature = "retain-embeddings")]
2290 fn pairwise_similarities_bounded() {
2291 let input = PipelineInput {
2292 categories: vec!["a".into(), "b".into(), "c".into()],
2293 embeddings: vec![
2294 vec![1.0, 0.0, 0.0],
2295 vec![0.0, 1.0, 0.0],
2296 vec![1.0, 1.0, 0.0],
2297 ],
2298 };
2299 let pipeline = SphereQLPipeline::new(input).unwrap();
2300 let sims = pipeline.pairwise_similarities().unwrap().unwrap();
2301 assert_eq!(sims.len(), 3);
2302 for &s in &sims {
2303 assert!((-1.0..=1.0).contains(&s));
2304 }
2305 let sim_01 = sims[sphereql_core::upper_triangle_index(0, 1, 3)];
2306 let sim_02 = sims[sphereql_core::upper_triangle_index(0, 2, 3)];
2307 assert!(sim_02 > sim_01);
2308 }
2309
2310 #[test]
2311 #[cfg(feature = "retain-embeddings")]
2312 fn nearest_by_embedding_finds_closest() {
2313 let input = PipelineInput {
2314 categories: vec!["a".into(), "b".into(), "c".into()],
2315 embeddings: vec![
2316 vec![1.0, 0.0, 0.0],
2317 vec![0.0, 1.0, 0.0],
2318 vec![0.0, 0.0, 1.0],
2319 ],
2320 };
2321 let pipeline = SphereQLPipeline::new(input).unwrap();
2322 let results = pipeline
2323 .nearest_by_embedding(&[0.9, 0.1, 0.0], 2)
2324 .unwrap()
2325 .unwrap();
2326 assert_eq!(results.len(), 2);
2327 assert_eq!(results[0].0, 0);
2328 assert!(results[0].1 > results[1].1);
2329 }
2330
2331 #[test]
2332 #[cfg(feature = "retain-embeddings")]
2333 fn nearest_by_embedding_handles_ties_and_large_k() {
2334 let input = PipelineInput {
2338 categories: vec!["a".into(), "b".into(), "b".into(), "c".into()],
2339 embeddings: vec![
2340 vec![0.0, 0.0, 1.0],
2341 vec![1.0, 0.0, 0.0],
2342 vec![1.0, 0.0, 0.0],
2343 vec![0.5, 0.0, 0.5],
2344 ],
2345 };
2346 let pipeline = SphereQLPipeline::new(input).unwrap();
2347 let results = pipeline
2348 .nearest_by_embedding(&[1.0, 0.0, 0.0], 10)
2349 .unwrap()
2350 .unwrap();
2351 assert_eq!(results.len(), 4);
2352 assert_eq!(results[0].0, 1);
2353 assert_eq!(results[1].0, 2);
2354 for w in results.windows(2) {
2355 assert!(w[0].1 >= w[1].1);
2356 }
2357
2358 let top2 = pipeline
2359 .nearest_by_embedding(&[1.0, 0.0, 0.0], 2)
2360 .unwrap()
2361 .unwrap();
2362 assert_eq!(
2363 top2.iter().map(|r| r.0).collect::<Vec<_>>(),
2364 vec![1, 2],
2365 "bounded heap must keep the lower index on exact ties"
2366 );
2367
2368 let none = pipeline
2369 .nearest_by_embedding(&[1.0, 0.0, 0.0], 0)
2370 .unwrap()
2371 .unwrap();
2372 assert!(none.is_empty());
2373 }
2374
2375 #[test]
2376 #[cfg(feature = "retain-embeddings")]
2377 fn nearest_by_embedding_dimension_mismatch() {
2378 let input = PipelineInput {
2379 categories: vec!["a".into(), "b".into(), "c".into()],
2380 embeddings: vec![
2381 vec![1.0, 0.0, 0.0],
2382 vec![0.0, 1.0, 0.0],
2383 vec![0.0, 0.0, 1.0],
2384 ],
2385 };
2386 let pipeline = SphereQLPipeline::new(input).unwrap();
2387 let result = pipeline.nearest_by_embedding(&[1.0, 0.0], 1).unwrap();
2388 assert!(result.is_err());
2389 }
2390
2391 #[test]
2392 fn pipeline_with_min_category_size_still_indexes_all_items() {
2393 let mut embeddings = Vec::new();
2394 let mut categories = Vec::new();
2395 for i in 0..15 {
2396 let mut v = vec![0.0; 8];
2397 v[0] = 1.0 + (i as f64 * 0.01);
2398 embeddings.push(v);
2399 categories.push("big".into());
2400 }
2401 embeddings.push(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]);
2402 categories.push("tiny_a".into());
2403 embeddings.push(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
2404 categories.push("tiny_b".into());
2405
2406 let pipeline = SphereQLPipeline::new_with_config(
2407 PipelineInput {
2408 categories,
2409 embeddings,
2410 },
2411 PipelineConfig {
2412 min_category_size: 5,
2413 ..Default::default()
2414 },
2415 )
2416 .unwrap();
2417
2418 assert_eq!(pipeline.num_items(), 17);
2419 assert!(pipeline.has_id("s-0015"));
2420 assert!(pipeline.has_id("s-0016"));
2421
2422 assert_eq!(pipeline.num_categories(), 1);
2423 let cats = pipeline.unique_categories();
2424 assert_eq!(cats, vec!["big".to_string()]);
2425 }
2426
2427 #[test]
2428 fn nearest_group_returns_consistent_order() {
2429 let (input, _) = two_cluster_input(30, 8);
2430 let pipeline = SphereQLPipeline::new(input).unwrap();
2431 let layer = pipeline.category_layer();
2432 if let Some(g0) = layer.domain_groups.first() {
2434 let pos = g0.centroid;
2435 let (gi, d_near, d_second) = layer.nearest_group(&pos).unwrap();
2436 assert_eq!(gi, 0);
2437 assert!(d_near <= d_second);
2438 }
2439 }
2440}