1use sphereql_core::*;
2use sphereql_index::SpatialItem;
3
4use crate::category::{
5 BridgeItem, CategoryLayer, CategoryPath, CategorySummary, DrillDownResult, InnerSphereReport,
6};
7use crate::confidence::{ProjectionWarning, QualityConfig, QualitySignal};
8use crate::config::{PipelineConfig, ProjectionKind};
9use crate::configured_projection::ConfiguredProjection;
10use crate::corpus_features::CorpusFeatures;
11use crate::domain_groups::{DomainGroup, detect_domain_groups};
12use crate::kernel_pca::KernelPcaProjection;
13use crate::laplacian::LaplacianEigenmapProjection;
14use crate::meta_model::MetaModel;
15use crate::projection::{PcaProjection, Projection};
16use crate::quality_metric::QualityMetric;
17use crate::query::{EmbeddingIndex, GlobResult, SlicingManifold};
18use crate::tuner::{SearchSpace, SearchStrategy, TuneReport, auto_tune};
19use crate::types::{Embedding, RadialStrategy};
20
/// Errors produced while building or querying a [`SphereQLPipeline`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum PipelineError {
    /// `categories` and `embeddings` differ in length; exactly one category
    /// label is required per embedding.
    #[error("categories length ({cat}) must equal embeddings length ({emb})")]
    LengthMismatch { cat: usize, emb: usize },
    /// Fewer than three embeddings were supplied; the payload is the count
    /// that was received.
    #[error("need at least 3 embeddings, got {0}")]
    TooFewEmbeddings(usize),
    /// Fitting the projection failed; wraps the underlying projection error
    /// (converted automatically via `From`).
    #[error("projection fit failed: {0}")]
    Projection(#[from] crate::projection::ProjectionError),
    /// A category name passed to a category-level query is not known to the
    /// pipeline's category layer.
    #[error("unknown category: {0:?}")]
    UnknownCategory(String),
    /// An item id passed to a query is not present in the pipeline's index.
    #[error("unknown id: {0:?}")]
    UnknownId(String),
    /// Auto-tuning ran but no trial succeeded. Each entry pairs the attempted
    /// configuration with a `String` that is presumably that trial's failure
    /// description (constructed in the tuner; confirm there).
    #[error("auto_tune produced no successful trials ({} failures)", failures.len())]
    AllTrialsFailed {
        failures: Vec<(crate::config::PipelineConfig, String)>,
    },
}
59
/// Raw corpus used to build a [`SphereQLPipeline`].
///
/// `categories[i]` is the label of `embeddings[i]`. The two vectors must have
/// the same length, otherwise construction fails with
/// [`PipelineError::LengthMismatch`].
#[derive(Debug, Clone)]
pub struct PipelineInput {
    /// Category label of each embedding, index-aligned with `embeddings`.
    pub categories: Vec<String>,
    /// Raw embedding vectors, wrapped into `Embedding`s during construction.
    pub embeddings: Vec<Vec<f64>>,
}
71
/// The query vector passed alongside a [`SphereQLQuery`] to
/// [`SphereQLPipeline::query`].
#[derive(Debug, Clone)]
pub struct PipelineQuery {
    /// Raw query embedding; wrapped into an `Embedding` when the query runs.
    pub embedding: Vec<f64>,
}
77
/// One item-level hit returned by nearest / similarity searches.
#[derive(Debug, Clone)]
pub struct NearestResult {
    /// Item id as assigned at construction (`s-{index:04}`).
    pub id: String,
    /// Category label of the item, or `"unknown"` if it cannot be resolved.
    pub category: String,
    /// Distance from the query. For index / drill-down searches this is the
    /// distance those searches report; for `SimilarAbove` it is the angular
    /// distance between the projected query and the item.
    pub distance: f64,
    /// The item's certainty as reported by the index (1.0 when the item is
    /// not found in the index on the hierarchical path).
    pub certainty: f64,
    /// The item's intensity as reported by the index (1.0 when the item is
    /// not found in the index on the hierarchical path).
    pub intensity: f64,
    /// Quality signal derived from the projection's explained variance ratio
    /// and `certainty`. Every constructor in this module sets it to `Some`.
    pub quality: Option<QualitySignal>,
}
105
/// Item-level concept path returned for `SphereQLQuery::ConceptPath`.
#[derive(Debug, Clone)]
pub struct PathResult {
    /// The steps of the path, from source to target.
    pub steps: Vec<PipelinePathStep>,
    /// Total path length as reported by the index's path search.
    pub total_distance: f64,
}
113
/// One step of a [`PathResult`], enriched with the item's category label.
#[derive(Debug, Clone)]
pub struct PipelinePathStep {
    /// Item id of this step.
    pub id: String,
    /// Category label of the item, or `"unknown"` if it cannot be resolved.
    pub category: String,
    /// Distance accumulated from the source up to and including this step.
    pub cumulative_distance: f64,
    /// Distance of the hop that reached this step.
    pub hop_distance: f64,
    /// Copied from the index's path step. Presumably set only on hops that
    /// bridge categories; TODO confirm against `EmbeddingIndex::concept_path`.
    pub bridge_strength: Option<f64>,
}
125
/// Summary of one cluster ("glob") found by `SphereQLQuery::DetectGlobs`.
#[derive(Debug, Clone)]
pub struct GlobSummary {
    /// Glob identifier as assigned by the glob detector.
    pub id: usize,
    /// Centroid as reported by the detector, which runs on the pipeline's
    /// Cartesian projected points.
    pub centroid: [f64; 3],
    /// Number of items assigned to this glob.
    pub member_count: usize,
    /// Spread of the glob as reported by `GlobResult::detect` (exact
    /// definition lives there).
    pub radius: f64,
    /// Up to three `(category, member count)` pairs, most frequent first.
    pub top_categories: Vec<(String, usize)>,
}
135
/// Local plane fitted around the query's projected point for
/// `SphereQLQuery::LocalManifold`.
#[derive(Debug, Clone)]
pub struct ManifoldResult {
    /// Centroid of the fitted neighbourhood, in Cartesian coordinates.
    pub centroid: [f64; 3],
    /// Normal vector of the fitted plane.
    pub normal: [f64; 3],
    /// Variance ratio reported by `SlicingManifold::fit_local`. Presumably the
    /// share of neighbourhood variance captured by the plane; confirm there.
    pub variance_ratio: f64,
}
143
/// Result of [`SphereQLPipeline::query`]; one variant per kind of query.
#[derive(Debug, Clone)]
pub enum SphereQLOutput {
    /// Item-level nearest neighbours (`SphereQLQuery::Nearest`).
    Nearest(Vec<NearestResult>),
    /// Items passing a similarity threshold (`SphereQLQuery::SimilarAbove`).
    KNearest(Vec<NearestResult>),
    /// Item-level concept path; `None` when the index returns no path.
    ConceptPath(Option<PathResult>),
    /// Detected clusters of projected points (`SphereQLQuery::DetectGlobs`).
    Globs(Vec<GlobSummary>),
    /// Local plane fit around the query (`SphereQLQuery::LocalManifold`).
    LocalManifold(ManifoldResult),
    /// Category-level path; `None` when the category layer finds no path.
    CategoryConceptPath(Option<CategoryPath>),
    /// Neighbouring categories of a category (`SphereQLQuery::CategoryNeighbors`).
    CategoryNeighbors(Vec<CategorySummary>),
    /// Items within one category closest to the query (`SphereQLQuery::DrillDown`).
    DrillDown(Vec<DrillDownResult>),
    /// Snapshot of category-level statistics (`SphereQLQuery::CategoryStats`).
    CategoryStats {
        /// One summary per category.
        summaries: Vec<CategorySummary>,
        /// One report per category that has an inner sphere.
        inner_sphere_reports: Vec<InnerSphereReport>,
    },
}
165
/// A query to run with [`SphereQLPipeline::query`].
///
/// Borrowed `&str` fields are item ids or category names owned by the caller.
/// The type is `Copy`, so a query value can be reused across calls.
#[derive(Debug, Clone, Copy)]
pub enum SphereQLQuery<'a> {
    /// The `k` items nearest to the query embedding, after quality filtering.
    Nearest { k: usize },
    /// Items returned by the index's similarity search with threshold
    /// `min_cosine` (presumably cosine similarity >= `min_cosine`), after
    /// quality filtering.
    SimilarAbove { min_cosine: f64 },
    /// Item-level path between two indexed items.
    ConceptPath {
        /// Id of the starting item.
        source_id: &'a str,
        /// Id of the destination item.
        target_id: &'a str,
        /// Neighbourhood size passed to the index's path search
        /// (presumably the k of a k-NN graph).
        graph_k: usize,
    },
    /// Cluster all projected points. `k` is forwarded to the glob detector
    /// together with the upper bound `max_k`; presumably `None` lets the
    /// detector choose the count.
    DetectGlobs { k: Option<usize>, max_k: usize },
    /// Fit a local plane to the `neighborhood_k` projected points around the
    /// query's projected position.
    LocalManifold { neighborhood_k: usize },
    /// Category-level path between two categories.
    CategoryConceptPath {
        /// Name of the starting category.
        source_category: &'a str,
        /// Name of the destination category.
        target_category: &'a str,
    },
    /// Up to `k` neighbouring categories of `category`.
    CategoryNeighbors { category: &'a str, k: usize },
    /// Up to `k` items inside `category` closest to the query embedding.
    DrillDown { category: &'a str, k: usize },
    /// Category summaries plus inner-sphere reports.
    CategoryStats,
}
196
/// One projected item in a flat export-friendly shape, produced by
/// `SphereQLPipeline::exported_points` and used by `to_json` / `to_csv`.
///
/// serde serializes fields in declaration order, so this order is the JSON
/// key order; `to_csv` uses the same column order.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ExportedPoint {
    /// Item id (`s-{index:04}`).
    pub id: String,
    /// Category label, or `"unknown"` if missing.
    pub category: String,
    /// Radial coordinate of the indexed item's position (0.0 if the item is
    /// missing from the index).
    pub r: f64,
    /// `theta` coordinate of the indexed item's position (0.0 if missing).
    pub theta: f64,
    /// `phi` coordinate of the indexed item's position (0.0 if missing).
    pub phi: f64,
    /// Cartesian x of the projected position.
    pub x: f64,
    /// Cartesian y of the projected position.
    pub y: f64,
    /// Cartesian z of the projected position.
    pub z: f64,
    /// Item certainty from the index (1.0 if the item is missing).
    pub certainty: f64,
    /// Item intensity from the index (1.0 if the item is missing).
    pub intensity: f64,
}
211
/// A labelled embedding corpus projected onto a sphere. It bundles the spatial
/// index and the category-level structures used to answer [`SphereQLQuery`]s.
pub struct SphereQLPipeline {
    // Fitted projection, used both to build the index and to project queries.
    projection: ConfiguredProjection,
    // Spatial index over all items, keyed by the generated `s-{i:04}` ids.
    index: EmbeddingIndex<ConfiguredProjection>,
    // Category label per item, index-aligned with `ids` and `cart_points`.
    categories: Vec<String>,
    // Cartesian coordinates of each item's projected position.
    cart_points: Vec<[f64; 3]>,
    // Generated item ids in input order (`s-0000`, `s-0001`, ...).
    ids: Vec<String>,
    // Category summaries, paths, neighbours, drill-down and inner spheres.
    category_layer: CategoryLayer,
    // Thresholds applied to item-level results; replaceable at runtime.
    quality_config: QualityConfig,
    // Warnings derived from the explained variance ratio at construction time.
    projection_warnings: Vec<ProjectionWarning>,
    // Domain groups of categories, used by `route_to_group` / hierarchical search.
    domain_groups: Vec<DomainGroup>,
    // Configuration the pipeline was built with.
    config: PipelineConfig,
}
240
241impl SphereQLPipeline {
242 pub fn new(input: PipelineInput) -> Result<Self, PipelineError> {
248 Self::new_with_config(input, PipelineConfig::default())
249 }
250
251 pub fn new_with_config(
255 input: PipelineInput,
256 config: PipelineConfig,
257 ) -> Result<Self, PipelineError> {
258 let embeddings: Vec<Embedding> = input
259 .embeddings
260 .iter()
261 .map(|v| Embedding::new(v.clone()))
262 .collect();
263
264 let projection = fit_projection_for_config(&embeddings, &config)?;
265 Self::with_configured_projection_and_config(
266 input.categories,
267 embeddings,
268 projection,
269 config,
270 )
271 }
272
273 pub fn new_from_metamodel<M: MetaModel>(
287 input: PipelineInput,
288 model: &M,
289 ) -> Result<(Self, CorpusFeatures, PipelineConfig), PipelineError> {
290 let features = CorpusFeatures::extract(&input.categories, &input.embeddings);
291 let predicted = model.predict(&features);
292 let pipeline = Self::new_with_config(input, predicted.clone())?;
293 Ok((pipeline, features, predicted))
294 }
295
296 pub fn new_from_metamodel_tuned<M, Q>(
311 input: PipelineInput,
312 model: &M,
313 space: &SearchSpace,
314 metric: &Q,
315 strategy: SearchStrategy,
316 ) -> Result<(Self, CorpusFeatures, TuneReport), PipelineError>
317 where
318 M: MetaModel,
319 Q: QualityMetric,
320 {
321 let features = CorpusFeatures::extract(&input.categories, &input.embeddings);
322 let predicted = model.predict(&features);
323 let (pipeline, report) = auto_tune(input, space, metric, strategy, &predicted)?;
324 Ok((pipeline, features, report))
325 }
326
327 pub fn with_projection(
334 categories: Vec<String>,
335 embeddings: Vec<Embedding>,
336 pca: PcaProjection,
337 ) -> Result<Self, PipelineError> {
338 Self::with_configured_projection_and_config(
339 categories,
340 embeddings,
341 ConfiguredProjection::Pca(pca),
342 PipelineConfig::default(),
343 )
344 }
345
346 pub fn with_projection_and_config(
349 categories: Vec<String>,
350 embeddings: Vec<Embedding>,
351 pca: PcaProjection,
352 config: PipelineConfig,
353 ) -> Result<Self, PipelineError> {
354 Self::with_configured_projection_and_config(
355 categories,
356 embeddings,
357 ConfiguredProjection::Pca(pca),
358 config,
359 )
360 }
361
362 pub fn with_configured_projection_and_config(
365 categories: Vec<String>,
366 embeddings: Vec<Embedding>,
367 projection: ConfiguredProjection,
368 config: PipelineConfig,
369 ) -> Result<Self, PipelineError> {
370 let n = embeddings.len();
371 if n != categories.len() {
372 return Err(PipelineError::LengthMismatch {
373 cat: categories.len(),
374 emb: n,
375 });
376 }
377 if n < 3 {
378 return Err(PipelineError::TooFewEmbeddings(n));
379 }
380
381 let mut index = EmbeddingIndex::builder(projection.clone())
382 .uniform_shells(10, 1.0)
383 .theta_divisions(12)
384 .phi_divisions(6)
385 .build();
386
387 let mut ids = Vec::with_capacity(n);
388 for (i, emb) in embeddings.iter().enumerate() {
389 let id = format!("s-{i:04}");
390 index.insert(&id, emb);
391 ids.push(id);
392 }
393
394 let projected_positions: Vec<SphericalPoint> =
401 embeddings.iter().map(|e| projection.project(e)).collect();
402 let cart_points: Vec<[f64; 3]> = projected_positions
403 .iter()
404 .map(|sp| {
405 let c = spherical_to_cartesian(sp);
406 [c.x, c.y, c.z]
407 })
408 .collect();
409
410 let evr = projection.explained_variance_ratio();
411 let category_layer = CategoryLayer::build_with_config(
412 &categories,
413 &embeddings,
414 &projected_positions,
415 &projection,
416 evr,
417 &config,
418 );
419
420 let quality_config = QualityConfig::default();
421 let projection_warnings = ProjectionWarning::from_evr(evr, quality_config.warn_below_evr)
422 .into_iter()
423 .collect();
424
425 let domain_groups = detect_domain_groups(&category_layer, config.routing.num_domain_groups);
426
427 Ok(Self {
428 projection,
429 index,
430 categories,
431 cart_points,
432 ids,
433 category_layer,
434 quality_config,
435 projection_warnings,
436 domain_groups,
437 config,
438 })
439 }
440
441 pub fn has_category(&self, name: &str) -> bool {
446 self.category_layer.name_to_index.contains_key(name)
447 }
448
449 pub fn has_id(&self, id: &str) -> bool {
451 self.index.get(id).is_some()
452 }
453
454 pub fn query(
463 &self,
464 q: SphereQLQuery<'_>,
465 query_embedding: &PipelineQuery,
466 ) -> Result<SphereQLOutput, PipelineError> {
467 let emb = Embedding::new(query_embedding.embedding.clone());
468
469 match q {
470 SphereQLQuery::Nearest { k } => {
471 let evr = self.projection.explained_variance_ratio();
472 let results = self.index.search_nearest(&emb, k);
473 Ok(SphereQLOutput::Nearest(
474 results
475 .iter()
476 .map(|r| {
477 let certainty = r.item.certainty();
478 let quality = QualitySignal::from_certainty(evr, certainty);
479 NearestResult {
480 id: r.item.id.clone(),
481 category: self.cat_for(&r.item.id),
482 distance: r.distance,
483 certainty,
484 intensity: r.item.intensity(),
485 quality: Some(quality),
486 }
487 })
488 .filter(|r| self.passes_quality(r))
489 .collect(),
490 ))
491 }
492
493 SphereQLQuery::SimilarAbove { min_cosine } => {
494 let evr = self.projection.explained_variance_ratio();
495 let results = self.index.search_similar(&emb, min_cosine);
496 let sp_q = self.projection.project(&emb);
497 Ok(SphereQLOutput::KNearest(
498 results
499 .items
500 .iter()
501 .map(|item| {
502 let d = angular_distance(&sp_q, item.position());
503 let certainty = item.certainty();
504 let quality = QualitySignal::from_certainty(evr, certainty);
505 NearestResult {
506 id: item.id.clone(),
507 category: self.cat_for(&item.id),
508 distance: d,
509 certainty,
510 intensity: item.intensity(),
511 quality: Some(quality),
512 }
513 })
514 .filter(|r| self.passes_quality(r))
515 .collect(),
516 ))
517 }
518
519 SphereQLQuery::ConceptPath {
520 source_id,
521 target_id,
522 graph_k,
523 } => {
524 if !self.has_id(source_id) {
525 return Err(PipelineError::UnknownId(source_id.to_string()));
526 }
527 if !self.has_id(target_id) {
528 return Err(PipelineError::UnknownId(target_id.to_string()));
529 }
530 let path = self.index.concept_path(source_id, target_id, graph_k);
531 Ok(SphereQLOutput::ConceptPath(path.map(|p| {
532 PathResult {
533 total_distance: p.total_distance,
534 steps: p
535 .steps
536 .iter()
537 .map(|s| PipelinePathStep {
538 id: s.id.clone(),
539 category: self.cat_for(&s.id),
540 cumulative_distance: s.cumulative_distance,
541 hop_distance: s.hop_distance,
542 bridge_strength: s.bridge_strength,
543 })
544 .collect(),
545 }
546 })))
547 }
548
549 SphereQLQuery::DetectGlobs { k, max_k } => {
550 let result = GlobResult::detect(&self.cart_points, &self.ids, k, max_k);
551 Ok(SphereQLOutput::Globs(
552 result
553 .globs
554 .iter()
555 .map(|g| {
556 let mut cat_counts = std::collections::HashMap::<String, usize>::new();
557 for mid in &g.member_ids {
558 let cat = self.cat_for(mid);
559 *cat_counts.entry(cat).or_default() += 1;
560 }
561 let mut top: Vec<_> = cat_counts.into_iter().collect();
562 top.sort_by_key(|(_, c)| std::cmp::Reverse(*c));
563 top.truncate(3);
564
565 GlobSummary {
566 id: g.id,
567 centroid: g.centroid,
568 member_count: g.member_ids.len(),
569 radius: g.radius,
570 top_categories: top,
571 }
572 })
573 .collect(),
574 ))
575 }
576
577 SphereQLQuery::LocalManifold { neighborhood_k } => {
578 let sp = self.projection.project(&emb);
579 let c = spherical_to_cartesian(&sp);
580 let qpt = [c.x, c.y, c.z];
581 let m = SlicingManifold::fit_local(&qpt, &self.cart_points, neighborhood_k);
582 Ok(SphereQLOutput::LocalManifold(ManifoldResult {
583 centroid: m.centroid,
584 normal: m.normal,
585 variance_ratio: m.variance_ratio,
586 }))
587 }
588
589 SphereQLQuery::CategoryConceptPath {
591 source_category,
592 target_category,
593 } => {
594 if !self.has_category(source_category) {
595 return Err(PipelineError::UnknownCategory(source_category.to_string()));
596 }
597 if !self.has_category(target_category) {
598 return Err(PipelineError::UnknownCategory(target_category.to_string()));
599 }
600 let path = self
601 .category_layer
602 .category_path(source_category, target_category);
603 Ok(SphereQLOutput::CategoryConceptPath(path))
604 }
605
606 SphereQLQuery::CategoryNeighbors { category, k } => {
607 if !self.has_category(category) {
608 return Err(PipelineError::UnknownCategory(category.to_string()));
609 }
610 let neighbors = self.category_layer.category_neighbors(category, k);
611 Ok(SphereQLOutput::CategoryNeighbors(
612 neighbors.into_iter().cloned().collect(),
613 ))
614 }
615
616 SphereQLQuery::DrillDown { category, k } => {
617 if !self.has_category(category) {
618 return Err(PipelineError::UnknownCategory(category.to_string()));
619 }
620 let results = self.category_layer.drill_down_with_projection(
621 category,
622 &emb,
623 &self.projection,
624 k,
625 );
626 Ok(SphereQLOutput::DrillDown(results))
627 }
628
629 SphereQLQuery::CategoryStats => Ok(SphereQLOutput::CategoryStats {
630 summaries: self.category_layer.summaries.clone(),
631 inner_sphere_reports: self.category_layer.inner_sphere_stats(),
632 }),
633 }
634 }
635
636 fn cat_for(&self, id: &str) -> String {
638 if let Some(idx_str) = id.strip_prefix("s-")
639 && let Ok(idx) = idx_str.parse::<usize>()
640 && idx < self.categories.len()
641 {
642 return self.categories[idx].clone();
643 }
644 "unknown".into()
645 }
646
647 pub fn num_items(&self) -> usize {
649 self.ids.len()
650 }
651
652 pub fn categories(&self) -> &[String] {
654 &self.categories
655 }
656
657 pub fn projected_points(&self) -> Vec<(&str, &str, [f64; 3])> {
659 self.ids
660 .iter()
661 .enumerate()
662 .map(|(i, id)| {
663 let cat = self
664 .categories
665 .get(i)
666 .map(|s| s.as_str())
667 .unwrap_or("unknown");
668 (id.as_str(), cat, self.cart_points[i])
669 })
670 .collect()
671 }
672
673 pub fn projection(&self) -> &ConfiguredProjection {
682 &self.projection
683 }
684
685 pub fn projection_kind(&self) -> ProjectionKind {
687 self.projection.kind()
688 }
689
690 pub fn exported_points(&self) -> Vec<ExportedPoint> {
694 self.ids
695 .iter()
696 .enumerate()
697 .map(|(i, id)| {
698 let [x, y, z] = self.cart_points[i];
699 let category = self
700 .categories
701 .get(i)
702 .cloned()
703 .unwrap_or_else(|| "unknown".into());
704 let item = self.index.get(id);
705 let (r, theta, phi) = item
706 .map(|it| {
707 let pos = it.position();
708 (pos.r, pos.theta, pos.phi)
709 })
710 .unwrap_or((0.0, 0.0, 0.0));
711 let certainty = item.map_or(1.0, |it| it.certainty());
712 let intensity = item.map_or(1.0, |it| it.intensity());
713 ExportedPoint {
714 id: id.clone(),
715 category,
716 r,
717 theta,
718 phi,
719 x,
720 y,
721 z,
722 certainty,
723 intensity,
724 }
725 })
726 .collect()
727 }
728
729 pub fn explained_variance_ratio(&self) -> f64 {
736 self.projection.explained_variance_ratio()
737 }
738
739 pub fn num_categories(&self) -> usize {
741 self.category_layer.num_categories()
742 }
743
744 pub fn unique_categories(&self) -> Vec<String> {
746 self.category_layer
747 .summaries
748 .iter()
749 .map(|s| s.name.clone())
750 .collect()
751 }
752
753 pub fn category_layer(&self) -> &CategoryLayer {
757 &self.category_layer
758 }
759
760 pub fn category_path(&self, source: &str, target: &str) -> Option<CategoryPath> {
762 self.category_layer.category_path(source, target)
763 }
764
765 pub fn bridge_items(&self, source: &str, target: &str, max: usize) -> Vec<&BridgeItem> {
767 self.category_layer.bridge_items(source, target, max)
768 }
769
770 pub fn has_inner_sphere(&self, category: &str) -> bool {
772 self.category_layer.has_inner_sphere(category)
773 }
774
775 pub fn num_inner_spheres(&self) -> usize {
777 self.category_layer.num_inner_spheres()
778 }
779
780 pub fn inner_sphere_stats(&self) -> Vec<InnerSphereReport> {
782 self.category_layer.inner_sphere_stats()
783 }
784
785 pub fn projection_warnings(&self) -> &[ProjectionWarning] {
787 &self.projection_warnings
788 }
789
790 pub fn domain_groups(&self) -> &[DomainGroup] {
794 &self.domain_groups
795 }
796
797 pub fn route_to_group(&self, embedding: &Embedding) -> Option<&DomainGroup> {
800 if self.domain_groups.is_empty() {
801 return None;
802 }
803 let pos = self.projection.project(embedding);
804 self.domain_groups.iter().min_by(|a, b| {
805 let da = angular_distance(&pos, &a.centroid);
806 let db = angular_distance(&pos, &b.centroid);
807 da.total_cmp(&db)
808 })
809 }
810
811 pub fn hierarchical_nearest(&self, embedding: &Embedding, k: usize) -> Vec<NearestResult> {
823 let evr = self.projection.explained_variance_ratio();
824
825 if evr >= self.config.routing.low_evr_threshold {
826 return self.nearest_filtered(embedding, k, evr);
827 }
828
829 let Some(group) = self.route_to_group(embedding) else {
830 return self.nearest_filtered(embedding, k, evr);
831 };
832
833 let mut candidates: Vec<NearestResult> = Vec::new();
836 for &ci in &group.member_categories {
837 let cat_name = &self.category_layer.summaries[ci].name;
838 for r in self.category_layer.drill_down_with_projection(
839 cat_name,
840 embedding,
841 &self.projection,
842 k,
843 ) {
844 candidates.push(self.drill_result_to_nearest(&r, evr));
845 }
846 }
847
848 candidates.sort_by(|a, b| {
849 a.distance
850 .partial_cmp(&b.distance)
851 .unwrap_or(std::cmp::Ordering::Equal)
852 });
853 let filtered: Vec<NearestResult> = candidates
854 .into_iter()
855 .filter(|r| self.passes_quality(r))
856 .take(k)
857 .collect();
858
859 if filtered.is_empty() {
866 self.nearest_filtered(embedding, k, evr)
867 } else {
868 filtered
869 }
870 }
871
872 #[inline]
881 fn passes_quality(&self, r: &NearestResult) -> bool {
882 r.certainty >= self.quality_config.min_certainty
883 && r.quality
884 .is_none_or(|q| q.passes_threshold(self.quality_config.min_combined))
885 }
886
887 fn nearest_filtered(&self, embedding: &Embedding, k: usize, evr: f64) -> Vec<NearestResult> {
889 self.index
890 .search_nearest(embedding, k)
891 .iter()
892 .map(|r| {
893 let certainty = r.item.certainty();
894 let quality = QualitySignal::from_certainty(evr, certainty);
895 NearestResult {
896 id: r.item.id.clone(),
897 category: self.cat_for(&r.item.id),
898 distance: r.distance,
899 certainty,
900 intensity: r.item.intensity(),
901 quality: Some(quality),
902 }
903 })
904 .filter(|r| self.passes_quality(r))
905 .collect()
906 }
907
908 fn drill_result_to_nearest(&self, r: &DrillDownResult, evr: f64) -> NearestResult {
909 let id = self.ids[r.item_index].clone();
910 let item = self.index.get(&id);
911 let certainty = item.map_or(1.0, |it| it.certainty());
912 let intensity = item.map_or(1.0, |it| it.intensity());
913 let quality = QualitySignal::from_certainty(evr, certainty);
914 NearestResult {
915 id,
916 category: self
917 .categories
918 .get(r.item_index)
919 .cloned()
920 .unwrap_or_else(|| "unknown".into()),
921 distance: r.distance,
922 certainty,
923 intensity,
924 quality: Some(quality),
925 }
926 }
927
928 pub fn quality_config(&self) -> &QualityConfig {
930 &self.quality_config
931 }
932
933 pub fn set_quality_config(&mut self, config: QualityConfig) {
935 self.quality_config = config;
936 }
937
938 pub fn config(&self) -> &PipelineConfig {
940 &self.config
941 }
942
943 pub fn to_json(&self) -> String {
945 serde_json::to_string(&self.exported_points())
946 .expect("ExportedPoint is always serializable")
947 }
948
949 pub fn to_csv(&self) -> String {
954 let points = self.exported_points();
955 let mut out = String::from("id,category,r,theta,phi,x,y,z,certainty,intensity\n");
956 for p in &points {
957 out.push_str(&format!(
958 "\"{}\",\"{}\",{:.6},{:.6},{:.6},{:.6},{:.6},{:.6},{:.6},{:.6}\n",
959 p.id.replace('"', "\"\""),
960 p.category.replace('"', "\"\""),
961 p.r,
962 p.theta,
963 p.phi,
964 p.x,
965 p.y,
966 p.z,
967 p.certainty,
968 p.intensity,
969 ));
970 }
971 out
972 }
973}
974
975pub fn fit_projection_for_config(
980 embeddings: &[Embedding],
981 config: &PipelineConfig,
982) -> Result<ConfiguredProjection, crate::projection::ProjectionError> {
983 match config.projection_kind {
984 ProjectionKind::Pca => Ok(ConfiguredProjection::Pca(
985 PcaProjection::fit(embeddings, RadialStrategy::Magnitude)?.with_volumetric(true),
986 )),
987 ProjectionKind::KernelPca => Ok(ConfiguredProjection::KernelPca(KernelPcaProjection::fit(
988 embeddings,
989 RadialStrategy::Magnitude,
990 )?)),
991 ProjectionKind::LaplacianEigenmap => {
992 let lc = &config.laplacian;
993 Ok(ConfiguredProjection::Laplacian(
994 LaplacianEigenmapProjection::fit_with_params(
995 embeddings,
996 lc.k_neighbors,
997 lc.active_threshold,
998 RadialStrategy::Magnitude,
999 )?,
1000 ))
1001 }
1002 }
1003}
1004
1005#[cfg(test)]
1006mod tests {
1007 use super::*;
1008
1009 fn make_input(n: usize, dim: usize) -> (PipelineInput, PipelineQuery) {
1010 let mut embeddings = Vec::with_capacity(n);
1011 let mut categories = Vec::with_capacity(n);
1012 for i in 0..n {
1013 let mut v = vec![0.0; dim];
1014 if i < n / 2 {
1015 v[0] = 1.0 + (i as f64 * 0.01);
1016 v[1] = 0.1;
1017 categories.push("group_a".into());
1018 } else {
1019 v[0] = 0.1;
1020 v[1] = 1.0 + (i as f64 * 0.01);
1021 categories.push("group_b".into());
1022 }
1023 v[2] = 0.05 * (i as f64);
1024 embeddings.push(v);
1025 }
1026 let query = PipelineQuery {
1027 embedding: vec![0.9; dim],
1028 };
1029 (
1030 PipelineInput {
1031 categories,
1032 embeddings,
1033 },
1034 query,
1035 )
1036 }
1037
1038 #[test]
1041 fn pipeline_nearest() {
1042 let (input, query) = make_input(20, 10);
1043 let pipeline = SphereQLPipeline::new(input).unwrap();
1044 let result = pipeline
1045 .query(SphereQLQuery::Nearest { k: 5 }, &query)
1046 .unwrap();
1047 match result {
1048 SphereQLOutput::Nearest(items) => {
1049 assert_eq!(items.len(), 5);
1050 assert!(items[0].distance <= items[1].distance);
1051 }
1052 _ => panic!("expected Nearest"),
1053 }
1054 }
1055
1056 #[test]
1057 fn pipeline_globs() {
1058 let (input, query) = make_input(30, 10);
1059 let pipeline = SphereQLPipeline::new(input).unwrap();
1060 let result = pipeline
1061 .query(
1062 SphereQLQuery::DetectGlobs {
1063 k: Some(2),
1064 max_k: 5,
1065 },
1066 &query,
1067 )
1068 .unwrap();
1069 match result {
1070 SphereQLOutput::Globs(globs) => {
1071 assert_eq!(globs.len(), 2);
1072 let total: usize = globs.iter().map(|g| g.member_count).sum();
1073 assert_eq!(total, 30);
1074 }
1075 _ => panic!("expected Globs"),
1076 }
1077 }
1078
1079 #[test]
1080 fn pipeline_concept_path() {
1081 let (input, query) = make_input(20, 10);
1082 let pipeline = SphereQLPipeline::new(input).unwrap();
1083 let result = pipeline
1084 .query(
1085 SphereQLQuery::ConceptPath {
1086 source_id: "s-0000",
1087 target_id: "s-0015",
1088 graph_k: 10,
1089 },
1090 &query,
1091 )
1092 .unwrap();
1093 match result {
1094 SphereQLOutput::ConceptPath(Some(path)) => {
1095 assert!(path.steps.len() >= 2);
1096 assert_eq!(path.steps.first().unwrap().id, "s-0000");
1097 assert_eq!(path.steps.last().unwrap().id, "s-0015");
1098 }
1099 _ => panic!("expected ConceptPath(Some)"),
1100 }
1101 }
1102
1103 #[test]
1104 fn pipeline_local_manifold() {
1105 let (input, query) = make_input(20, 10);
1106 let pipeline = SphereQLPipeline::new(input).unwrap();
1107 let result = pipeline
1108 .query(SphereQLQuery::LocalManifold { neighborhood_k: 10 }, &query)
1109 .unwrap();
1110 match result {
1111 SphereQLOutput::LocalManifold(m) => {
1112 assert!(m.variance_ratio > 0.0);
1113 assert!(m.variance_ratio <= 1.0);
1114 }
1115 _ => panic!("expected LocalManifold"),
1116 }
1117 }
1118
1119 #[test]
1120 fn test_exported_points_count() {
1121 let (input, _) = make_input(20, 10);
1122 let pipeline = SphereQLPipeline::new(input).unwrap();
1123 assert_eq!(pipeline.exported_points().len(), 20);
1124 }
1125
1126 #[test]
1127 fn test_exported_points_fields() {
1128 let (input, _) = make_input(20, 10);
1129 let pipeline = SphereQLPipeline::new(input).unwrap();
1130 for p in pipeline.exported_points() {
1131 assert!(p.r >= 0.0, "r must be non-negative");
1132 assert!(
1133 p.theta >= 0.0 && p.theta < std::f64::consts::TAU,
1134 "theta out of range"
1135 );
1136 assert!(
1137 p.phi >= 0.0 && p.phi <= std::f64::consts::PI,
1138 "phi out of range"
1139 );
1140 }
1141 }
1142
1143 #[test]
1144 fn test_exported_points_categories() {
1145 let (input, _) = make_input(20, 10);
1146 let pipeline = SphereQLPipeline::new(input).unwrap();
1147 let points = pipeline.exported_points();
1148 for (i, p) in points.iter().enumerate() {
1149 let expected = if i < 10 { "group_a" } else { "group_b" };
1150 assert_eq!(p.category, expected);
1151 }
1152 }
1153
1154 #[test]
1155 fn test_to_json_parseable() {
1156 let (input, _) = make_input(20, 10);
1157 let pipeline = SphereQLPipeline::new(input).unwrap();
1158 let json = pipeline.to_json();
1159 let parsed: Vec<serde_json::Value> = serde_json::from_str(&json).expect("valid JSON");
1160 assert_eq!(parsed.len(), 20);
1161 }
1162
1163 #[test]
1164 fn test_to_csv_lines() {
1165 let (input, _) = make_input(20, 10);
1166 let pipeline = SphereQLPipeline::new(input).unwrap();
1167 let csv = pipeline.to_csv();
1168 let lines: Vec<&str> = csv.lines().collect();
1169 assert_eq!(
1170 lines[0],
1171 "id,category,r,theta,phi,x,y,z,certainty,intensity"
1172 );
1173 assert_eq!(lines.len(), 21);
1174 }
1175
1176 #[test]
1177 fn test_to_csv_quoted_fields() {
1178 let (input, _) = make_input(20, 10);
1179 let pipeline = SphereQLPipeline::new(input).unwrap();
1180 let csv = pipeline.to_csv();
1181 let data_line = csv.lines().nth(1).unwrap();
1182 assert!(data_line.starts_with('"'), "id field should be quoted");
1183 }
1184
1185 #[test]
1186 fn test_explained_variance() {
1187 let (input, _) = make_input(20, 10);
1188 let pipeline = SphereQLPipeline::new(input).unwrap();
1189 let ratio = pipeline.explained_variance_ratio();
1190 assert!(ratio > 0.0 && ratio <= 1.0);
1191 }
1192
1193 #[test]
1194 fn test_unique_categories() {
1195 let (input, _) = make_input(20, 10);
1196 let pipeline = SphereQLPipeline::new(input).unwrap();
1197 let cats = pipeline.unique_categories();
1198 assert_eq!(cats.len(), 2);
1199 assert_eq!(cats[0], "group_a");
1200 assert_eq!(cats[1], "group_b");
1201 assert_eq!(pipeline.num_categories(), 2);
1202 }
1203
1204 #[test]
1207 fn pipeline_builds_category_layer() {
1208 let (input, _) = make_input(20, 10);
1209 let pipeline = SphereQLPipeline::new(input).unwrap();
1210 assert_eq!(pipeline.category_layer().num_categories(), 2);
1211 }
1212
1213 #[test]
1214 fn pipeline_category_path_query() {
1215 let (input, query) = make_input(20, 10);
1216 let pipeline = SphereQLPipeline::new(input).unwrap();
1217 let result = pipeline
1218 .query(
1219 SphereQLQuery::CategoryConceptPath {
1220 source_category: "group_a",
1221 target_category: "group_b",
1222 },
1223 &query,
1224 )
1225 .unwrap();
1226 match result {
1227 SphereQLOutput::CategoryConceptPath(Some(path)) => {
1228 assert!(path.steps.len() >= 2);
1229 assert_eq!(path.steps.first().unwrap().category_name, "group_a");
1230 assert_eq!(path.steps.last().unwrap().category_name, "group_b");
1231 assert!(path.total_distance > 0.0);
1232 }
1233 _ => panic!("expected CategoryConceptPath(Some)"),
1234 }
1235 }
1236
1237 #[test]
1238 fn pipeline_category_path_shortcut() {
1239 let (input, _) = make_input(20, 10);
1240 let pipeline = SphereQLPipeline::new(input).unwrap();
1241 let path = pipeline.category_path("group_a", "group_b");
1242 assert!(path.is_some());
1243 let path = path.unwrap();
1244 assert_eq!(path.steps.first().unwrap().category_name, "group_a");
1245 assert_eq!(path.steps.last().unwrap().category_name, "group_b");
1246 }
1247
1248 #[test]
1249 fn pipeline_category_path_unknown() {
1250 let (input, _) = make_input(20, 10);
1251 let pipeline = SphereQLPipeline::new(input).unwrap();
1252 assert!(pipeline.category_path("group_a", "nonexistent").is_none());
1253 }
1254
1255 #[test]
1256 fn pipeline_category_neighbors_query() {
1257 let (input, query) = make_input(20, 10);
1258 let pipeline = SphereQLPipeline::new(input).unwrap();
1259 let result = pipeline
1260 .query(
1261 SphereQLQuery::CategoryNeighbors {
1262 category: "group_a",
1263 k: 5,
1264 },
1265 &query,
1266 )
1267 .unwrap();
1268 match result {
1269 SphereQLOutput::CategoryNeighbors(neighbors) => {
1270 assert_eq!(neighbors.len(), 1);
1271 assert_eq!(neighbors[0].name, "group_b");
1272 }
1273 _ => panic!("expected CategoryNeighbors"),
1274 }
1275 }
1276
1277 #[test]
1278 fn pipeline_drill_down_query() {
1279 let (input, query) = make_input(20, 10);
1280 let pipeline = SphereQLPipeline::new(input).unwrap();
1281 let result = pipeline
1282 .query(
1283 SphereQLQuery::DrillDown {
1284 category: "group_a",
1285 k: 5,
1286 },
1287 &query,
1288 )
1289 .unwrap();
1290 match result {
1291 SphereQLOutput::DrillDown(results) => {
1292 assert!(!results.is_empty());
1293 assert!(results.len() <= 5);
1294 for w in results.windows(2) {
1295 assert!(w[0].distance <= w[1].distance);
1296 }
1297 }
1298 _ => panic!("expected DrillDown"),
1299 }
1300 }
1301
1302 #[test]
1303 fn pipeline_category_stats_query() {
1304 let (input, query) = make_input(20, 10);
1305 let pipeline = SphereQLPipeline::new(input).unwrap();
1306 let result = pipeline
1307 .query(SphereQLQuery::CategoryStats, &query)
1308 .unwrap();
1309 match result {
1310 SphereQLOutput::CategoryStats {
1311 summaries,
1312 inner_sphere_reports,
1313 } => {
1314 assert_eq!(summaries.len(), 2);
1315 assert_eq!(inner_sphere_reports.len(), 0);
1316 }
1317 _ => panic!("expected CategoryStats"),
1318 }
1319 }
1320
1321 #[test]
1322 fn pipeline_bridge_items_shortcut() {
1323 let (input, _) = make_input(20, 10);
1324 let pipeline = SphereQLPipeline::new(input).unwrap();
1325 let _ = pipeline.bridge_items("group_a", "group_b", 5);
1326 }
1327
1328 #[test]
1329 fn pipeline_inner_sphere_shortcuts() {
1330 let (input, _) = make_input(20, 10);
1331 let pipeline = SphereQLPipeline::new(input).unwrap();
1332 assert!(!pipeline.has_inner_sphere("group_a"));
1333 assert_eq!(pipeline.num_inner_spheres(), 0);
1334 assert!(pipeline.inner_sphere_stats().is_empty());
1335 }
1336
1337 #[test]
1338 fn pipeline_category_layer_accessor() {
1339 let (input, _) = make_input(20, 10);
1340 let pipeline = SphereQLPipeline::new(input).unwrap();
1341 let layer = pipeline.category_layer();
1342 assert_eq!(layer.num_categories(), 2);
1343 assert!(layer.get_category("group_a").is_some());
1344 assert!(layer.get_category("group_b").is_some());
1345 }
1346
1347 #[test]
1350 fn domain_groups_detected() {
1351 let (input, _) = make_input(20, 10);
1352 let pipeline = SphereQLPipeline::new(input).unwrap();
1353 let groups = pipeline.domain_groups();
1354 assert!(!groups.is_empty());
1355 let total: usize = groups.iter().map(|g| g.total_items).sum();
1356 assert_eq!(total, pipeline.num_items());
1357 }
1358
1359 #[test]
1360 fn domain_groups_cover_all_categories() {
1361 let (input, _) = make_input(20, 10);
1362 let pipeline = SphereQLPipeline::new(input).unwrap();
1363 let groups = pipeline.domain_groups();
1364 let mut all_cats: Vec<usize> = groups
1365 .iter()
1366 .flat_map(|g| g.member_categories.iter().copied())
1367 .collect();
1368 all_cats.sort();
1369 all_cats.dedup();
1370 assert_eq!(all_cats.len(), pipeline.num_categories());
1371 }
1372
1373 #[test]
1374 fn route_to_group_returns_something() {
1375 let (input, _) = make_input(20, 10);
1376 let pipeline = SphereQLPipeline::new(input).unwrap();
1377 let emb = Embedding::new(vec![0.5; 10]);
1378 assert!(pipeline.route_to_group(&emb).is_some());
1379 }
1380
#[test]
/// With the default config, hierarchical search returns between one and
/// `k` (= 5) hits, ordered by non-decreasing distance.
fn hierarchical_nearest_matches_standard_when_evr_high() {
    let (corpus, query) = make_input(20, 10);
    let sphere = SphereQLPipeline::new(corpus).unwrap();
    let probe = Embedding::new(query.embedding.clone());
    let hits = sphere.hierarchical_nearest(&probe, 5);

    assert!(!hits.is_empty());
    assert!(hits.len() <= 5);
    assert!(hits
        .windows(2)
        .all(|pair| pair[0].distance <= pair[1].distance));
}
1395
#[test]
/// Forces the low-EVR routing branch and runs hierarchical search twice:
/// first under a quality filter meant to reject every candidate, then under
/// the default filter. This exercises both the "filter kills everything"
/// path and the normal path.
fn hierarchical_nearest_falls_back_when_filter_kills_candidates() {
    let (input, query) = make_input(20, 10);
    let probe = Embedding::new(query.embedding.clone());
    let mut pipeline = SphereQLPipeline::new_with_config(
        input,
        PipelineConfig {
            routing: crate::config::RoutingConfig {
                num_domain_groups: 2,
                // Assuming EVR is an explained-variance ratio in [0, 1], a
                // threshold above 1 always takes the low-EVR branch.
                low_evr_threshold: 1.1,
            },
            ..Default::default()
        },
    )
    .unwrap();

    // A certainty floor above 1 is intended to reject every candidate
    // (presumably certainty lies in [0, 1]). Previously this config was
    // overwritten before any query ran, so the path was never exercised.
    // It must not panic and must still honour `k`.
    pipeline.set_quality_config(crate::confidence::QualityConfig {
        min_certainty: 1.1,
        ..Default::default()
    });
    let strict = pipeline.hierarchical_nearest(&probe, 5);
    assert!(
        strict.len() <= 5,
        "strict filter must not return more than k results"
    );

    pipeline.set_quality_config(crate::confidence::QualityConfig::default());
    let hier = pipeline.hierarchical_nearest(&probe, 5);
    assert!(
        !hier.is_empty(),
        "low-EVR branch should return results with default filter"
    );
}
1435
#[test]
/// `FeedbackAggregator`'s serde derive produces a JSON array, and
/// `FeedbackAggregator::load` can read that derive-produced JSON back.
fn feedback_aggregator_derive_and_save_load_round_trip() {
    use crate::feedback::{FeedbackAggregator, FeedbackEvent};
    let mut agg = FeedbackAggregator::default();
    agg.record(FeedbackEvent {
        corpus_id: "c".into(),
        query_id: "q".into(),
        score: 0.5,
        timestamp: "0".into(),
    });

    let json_via_derive = serde_json::to_string(&agg).unwrap();
    // A leading `[` shows the aggregator serializes as a bare sequence
    // rather than a wrapping object.
    assert!(json_via_derive.starts_with('['));

    // The file name includes the process id so concurrently running test
    // binaries do not clobber each other's file.
    let path = std::env::temp_dir().join(format!(
        "sphereql_serde_transparent_{}.json",
        std::process::id()
    ));
    std::fs::write(&path, &json_via_derive).unwrap();

    // Clean up before anything can panic, so a failing `load` (or a failing
    // assertion) does not leave the temp file behind.
    let loaded = FeedbackAggregator::load(&path);
    let _ = std::fs::remove_file(&path);

    let loaded = loaded.unwrap();
    assert_eq!(loaded.len(), 1);
}
1466
#[test]
/// A nearest-neighbour meta-model trained on a single record for this exact
/// corpus predicts that record's config. The resulting pipeline is built
/// with the predicted projection kind.
fn new_from_metamodel_uses_predicted_config() {
    use crate::corpus_features::CorpusFeatures;
    use crate::meta_model::{MetaTrainingRecord, NearestNeighborMetaModel};

    let (corpus, _) = make_input(20, 10);
    let features = CorpusFeatures::extract(&corpus.categories, &corpus.embeddings);

    // `features` and the config are not needed after this, so move them in.
    let seed = MetaTrainingRecord {
        corpus_id: "seed".into(),
        features,
        best_config: PipelineConfig {
            projection_kind: ProjectionKind::LaplacianEigenmap,
            ..Default::default()
        },
        best_score: 0.5,
        metric_name: "test".into(),
        strategy: "manual".into(),
        timestamp: "0".into(),
    };

    let mut model = NearestNeighborMetaModel::new();
    model.fit(&[seed]);

    let (pipeline, _extracted, predicted) =
        SphereQLPipeline::new_from_metamodel(corpus, &model).unwrap();
    assert_eq!(predicted.projection_kind, ProjectionKind::LaplacianEigenmap);
    assert_eq!(pipeline.projection_kind(), ProjectionKind::LaplacianEigenmap);
}
1503
#[test]
/// Tuning seeded by a meta-model prediction runs one trial per grid point of
/// the supplied search space. The trials use the search space's
/// `overlap_artifact_territorial` value, not the predicted one.
fn new_from_metamodel_tuned_runs_and_carries_prediction() {
    use crate::corpus_features::CorpusFeatures;
    use crate::meta_model::{MetaTrainingRecord, NearestNeighborMetaModel};
    use crate::quality_metric::TerritorialHealth;
    use crate::tuner::{SearchSpace, SearchStrategy};

    let (corpus, _) = make_input(20, 10);
    let features = CorpusFeatures::extract(&corpus.categories, &corpus.embeddings);

    // Tag the predicted config with a distinctive value that does not occur
    // in the search space below.
    let mut predicted_cfg = PipelineConfig::default();
    predicted_cfg.bridges.overlap_artifact_territorial = 0.123;

    let mut model = NearestNeighborMetaModel::new();
    model.fit(&[MetaTrainingRecord {
        corpus_id: "seed".into(),
        features,
        best_config: predicted_cfg,
        best_score: 0.5,
        metric_name: "test".into(),
        strategy: "manual".into(),
        timestamp: "0".into(),
    }]);

    // Only `num_domain_groups` has more than one candidate, so a grid search
    // produces exactly two trials.
    let space = SearchSpace {
        projection_kinds: vec![ProjectionKind::Pca],
        laplacian_k_neighbors: vec![15],
        laplacian_active_threshold: vec![0.05],
        num_domain_groups: vec![3, 5],
        low_evr_threshold: vec![0.35],
        overlap_artifact_territorial: vec![0.3],
        threshold_base: vec![0.5],
        threshold_evr_penalty: vec![0.4],
        min_evr_improvement: vec![0.10],
    };

    let scorer = TerritorialHealth;
    let (pipeline, _feats, report) = SphereQLPipeline::new_from_metamodel_tuned(
        corpus,
        &model,
        &space,
        &scorer,
        SearchStrategy::Grid,
    )
    .unwrap();

    assert_eq!(report.trials.len(), 2);
    assert!(report
        .trials
        .iter()
        .all(|trial| (trial.config.bridges.overlap_artifact_territorial - 0.3).abs() < 1e-9));
    assert_eq!(pipeline.projection_kind(), ProjectionKind::Pca);
}
1568}