1use std::collections::HashMap;
2
3use sphereql_core::{SphericalPoint, angular_distance};
4
5use crate::kernel_pca::KernelPcaProjection;
6use crate::projection::{PcaProjection, Projection};
7use crate::types::{Embedding, RadialStrategy};
8
/// Minimum number of members a category must have before
/// `CategoryLayer::build_inner_spheres` tries to fit a dedicated inner
/// projection for it; smaller categories never get an inner sphere.
const MIN_INNER_SPHERE_SIZE: usize = 20;

/// An inner sphere is kept only if the category's own linear-PCA explained
/// variance ratio exceeds the global projection's score on the same members
/// (`global_subset_evr`) by at least this amount.
const MIN_EVR_IMPROVEMENT: f64 = 0.10;

/// Minimum member count at which kernel PCA is also fitted as a candidate
/// inner projection (below this only linear PCA is considered).
const KERNEL_PCA_MIN_SIZE: usize = 80;

/// Kernel PCA replaces linear PCA for an inner sphere only if its explained
/// variance ratio is higher by strictly more than this margin.
const MIN_KERNEL_IMPROVEMENT: f64 = 0.05;
22
/// Aggregate statistics for one category, computed by `CategoryLayer::build`.
#[derive(Debug, Clone)]
pub struct CategorySummary {
    /// Category label exactly as it appeared in the `categories` input.
    pub name: String,
    /// Indices (into the inputs passed to `build`) of the items in this
    /// category, in input order.
    pub member_indices: Vec<usize>,
    /// Component-wise mean of the member embeddings.
    pub centroid_embedding: Vec<f64>,
    /// `centroid_embedding` projected with the same projection used for the
    /// items, so it lives on the same (outer) sphere as their positions.
    pub centroid_position: SphericalPoint,
    /// Mean angular distance from the members' projected positions to
    /// `centroid_position`; `0.0` for single-member categories.
    pub angular_spread: f64,
    /// `1 / (1 + angular_spread)`: `1.0` for a category with no spread,
    /// decreasing as spread grows.
    pub cohesion: f64,
    /// Number of members; equal to `member_indices.len()`.
    pub member_count: usize,
}
50
/// An item of one category whose embedding is also close to another
/// category's centroid, making it a "bridge" between the two.
#[derive(Debug, Clone)]
pub struct BridgeItem {
    /// Index of the item in the inputs passed to `CategoryLayer::build`.
    pub item_index: usize,
    /// Index of the category the item belongs to.
    pub source_category: usize,
    /// Index of the other category the item bridges to.
    pub target_category: usize,
    /// Cosine similarity between the item's embedding and the source
    /// category's `centroid_embedding`.
    pub affinity_to_source: f64,
    /// Cosine similarity between the item's embedding and the target
    /// category's `centroid_embedding`.
    pub affinity_to_target: f64,
    /// Harmonic mean of the two affinities, so it is high only when both are
    /// high. Bridge lists are sorted by this value, strongest first.
    pub bridge_strength: f64,
}
74
/// A directed edge from the category owning an adjacency list to `target`.
#[derive(Debug, Clone)]
pub struct CategoryEdge {
    /// Index of the neighbouring category.
    pub target: usize,
    /// Angular distance between the two categories' `centroid_position`s.
    pub centroid_distance: f64,
    /// Number of bridge items stored for `(owner, target)` in
    /// `CategoryGraph::bridges`; direction matters.
    pub bridge_count: usize,
    /// `centroid_distance / (1 + bridge_count)`: more bridges make the
    /// categories effectively closer. Used as the edge cost by
    /// `CategoryLayer::category_path`; adjacency lists are sorted by it,
    /// ascending.
    pub weight: f64,
}
90
/// Complete directed graph over categories plus the bridge items behind it.
#[derive(Debug, Clone)]
pub struct CategoryGraph {
    /// `adjacency[i]` holds one edge from category `i` to every other
    /// category, sorted by ascending `weight`.
    pub adjacency: Vec<Vec<CategoryEdge>>,
    /// Bridge items keyed by `(source_category, target_category)`. Only pairs
    /// with at least one bridge have an entry; each list is sorted by
    /// descending `bridge_strength`.
    pub bridges: HashMap<(usize, usize), Vec<BridgeItem>>,
}
100
/// One category visited along a `CategoryPath`.
#[derive(Debug, Clone)]
pub struct CategoryPathStep {
    /// Index of the category at this step.
    pub category_index: usize,
    /// Name of the category at this step.
    pub category_name: String,
    /// Sum of edge weights from the start of the path up to this step
    /// (`0.0` for the first step).
    pub cumulative_distance: f64,
    /// Up to three of the strongest bridge items from this category to the
    /// next step's category; empty for the final step.
    pub bridges_to_next: Vec<BridgeItem>,
}
115
/// Lowest-weight route between two categories, as found by
/// `CategoryLayer::category_path`.
#[derive(Debug, Clone)]
pub struct CategoryPath {
    /// Steps from the source category (first) to the target category (last).
    pub steps: Vec<CategoryPathStep>,
    /// Total edge weight of the path; equals the last step's
    /// `cumulative_distance`.
    pub total_distance: f64,
}
124
/// Projection fitted on a single category's members for its inner sphere.
///
/// Kernel PCA is chosen only for categories with at least
/// `KERNEL_PCA_MIN_SIZE` members, and only when it beats linear PCA by more
/// than `MIN_KERNEL_IMPROVEMENT`. `Debug` is implemented by hand and prints
/// only the variant name.
#[derive(Clone)]
pub enum InnerProjection {
    /// Linear PCA fitted on the category's member embeddings.
    LinearPca(PcaProjection),
    /// Kernel PCA fitted on the category's member embeddings.
    KernelPca(KernelPcaProjection),
}
141
142impl std::fmt::Debug for InnerProjection {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 match self {
145 Self::LinearPca(_) => write!(f, "LinearPca"),
146 Self::KernelPca(_) => write!(f, "KernelPca"),
147 }
148 }
149}
150
151impl Projection for InnerProjection {
152 fn project(&self, embedding: &Embedding) -> SphericalPoint {
153 match self {
154 Self::LinearPca(p) => p.project(embedding),
155 Self::KernelPca(p) => p.project(embedding),
156 }
157 }
158 fn project_rich(&self, embedding: &Embedding) -> crate::types::ProjectedPoint {
159 match self {
160 Self::LinearPca(p) => p.project_rich(embedding),
161 Self::KernelPca(p) => p.project_rich(embedding),
162 }
163 }
164 fn dimensionality(&self) -> usize {
165 match self {
166 Self::LinearPca(p) => p.dimensionality(),
167 Self::KernelPca(p) => p.dimensionality(),
168 }
169 }
170}
171
/// A category-specific projection ("inner sphere") together with the
/// positions of the category's members under it.
#[derive(Debug, Clone)]
pub struct InnerSphere {
    /// Projection fitted on the category's member embeddings only.
    pub projection: InnerProjection,
    /// `projection` applied to each member; `inner_positions[k]` belongs to
    /// item `member_indices[k]`.
    pub inner_positions: Vec<SphericalPoint>,
    /// Indices of the members in the inputs passed to `CategoryLayer::build`.
    pub member_indices: Vec<usize>,
    /// Explained variance ratio of the chosen inner projection.
    pub explained_variance_ratio: f64,
    /// Mean `certainty` reported by the global projection's `project_rich`
    /// over the members; used as the baseline to beat. NOTE(review): assumes
    /// `certainty` is on the same scale as an explained variance ratio —
    /// confirm.
    pub global_subset_evr: f64,
    /// `explained_variance_ratio - global_subset_evr`.
    pub evr_improvement: f64,
}
197
/// One ranked member returned by `CategoryLayer::drill_down*`.
#[derive(Debug, Clone)]
pub struct DrillDownResult {
    /// Index of the item in the inputs passed to `CategoryLayer::build`.
    pub item_index: usize,
    /// Angular distance used for ranking (smaller is closer). Measured on the
    /// inner sphere when `used_inner_sphere` is true, otherwise on the outer
    /// sphere.
    pub distance: f64,
    /// Whether the category's inner sphere was used to compute `distance`.
    pub used_inner_sphere: bool,
}
209
/// Summary of one inner sphere, produced by `CategoryLayer::inner_sphere_stats`.
#[derive(Debug, Clone)]
pub struct InnerSphereReport {
    /// Name of the category the inner sphere belongs to.
    pub category_name: String,
    /// Index of that category.
    pub category_index: usize,
    /// Number of members projected onto the inner sphere.
    pub member_count: usize,
    /// `"LinearPca"` or `"KernelPca"`.
    pub projection_type: &'static str,
    /// Explained variance ratio of the inner projection.
    pub inner_evr: f64,
    /// Baseline score of the global projection on the same members.
    pub global_subset_evr: f64,
    /// `inner_evr - global_subset_evr`.
    pub evr_improvement: f64,
}
229
/// Category-level view of a projected corpus: per-category summaries, a
/// weighted graph between categories, and optional per-category inner
/// spheres. Built with `CategoryLayer::build`.
#[derive(Debug, Clone)]
pub struct CategoryLayer {
    /// One summary per category, indexed by category index (categories are
    /// numbered in order of first appearance in the input).
    pub summaries: Vec<CategorySummary>,
    /// Category name -> index into `summaries`.
    pub name_to_index: HashMap<String, usize>,
    /// Inter-category graph and bridge items.
    pub graph: CategoryGraph,
    /// Copy of the `projected_positions` passed to `build`, indexed by item.
    outer_positions: Vec<SphericalPoint>,
    /// Inner spheres keyed by category index; only categories that qualified
    /// have an entry.
    pub inner_spheres: HashMap<usize, InnerSphere>,
}
253
254impl CategoryLayer {
255 pub fn build<P: Projection>(
272 categories: &[String],
273 embeddings: &[Embedding],
274 projected_positions: &[SphericalPoint],
275 projection: &P,
276 ) -> Self {
277 let n = categories.len();
278 assert_eq!(n, embeddings.len());
279 assert_eq!(n, projected_positions.len());
280
281 let mut name_to_index: HashMap<String, usize> = HashMap::new();
283 let mut cat_names: Vec<String> = Vec::new();
284 let mut cat_members: Vec<Vec<usize>> = Vec::new();
285
286 for (i, cat) in categories.iter().enumerate() {
287 let idx = if let Some(&idx) = name_to_index.get(cat) {
288 idx
289 } else {
290 let idx = cat_names.len();
291 name_to_index.insert(cat.clone(), idx);
292 cat_names.push(cat.clone());
293 cat_members.push(Vec::new());
294 idx
295 };
296 cat_members[idx].push(i);
297 }
298
299 let num_cats = cat_names.len();
300 let dim = if n > 0 { embeddings[0].dimension() } else { 0 };
301
302 let mut summaries: Vec<CategorySummary> = Vec::with_capacity(num_cats);
304
305 for (ci, name) in cat_names.iter().enumerate() {
306 let members = &cat_members[ci];
307 let count = members.len();
308
309 let mut centroid_emb = vec![0.0; dim];
311 for &mi in members {
312 for (j, &v) in embeddings[mi].values.iter().enumerate() {
313 centroid_emb[j] += v;
314 }
315 }
316 if count > 0 {
317 for v in &mut centroid_emb {
318 *v /= count as f64;
319 }
320 }
321
322 let centroid_embedding_obj = Embedding::new(centroid_emb.clone());
324 let centroid_position = projection.project(¢roid_embedding_obj);
325
326 let angular_spread = if count > 1 {
328 let total: f64 = members
329 .iter()
330 .map(|&mi| angular_distance(&projected_positions[mi], ¢roid_position))
331 .sum();
332 total / count as f64
333 } else {
334 0.0
335 };
336
337 let cohesion = 1.0 / (1.0 + angular_spread);
338
339 summaries.push(CategorySummary {
340 name: name.clone(),
341 member_indices: members.clone(),
342 centroid_embedding: centroid_emb,
343 centroid_position,
344 angular_spread,
345 cohesion,
346 member_count: count,
347 });
348 }
349
350 let graph = Self::build_graph(&summaries, embeddings, num_cats);
352
353 let inner_spheres = Self::build_inner_spheres(&summaries, embeddings, projection);
355
356 CategoryLayer {
357 summaries,
358 name_to_index,
359 graph,
360 outer_positions: projected_positions.to_vec(),
361 inner_spheres,
362 }
363 }
364
365 fn build_graph(
367 summaries: &[CategorySummary],
368 embeddings: &[Embedding],
369 num_cats: usize,
370 ) -> CategoryGraph {
371 let mut centroid_dists = vec![vec![0.0; num_cats]; num_cats];
373 for i in 0..num_cats {
374 for j in (i + 1)..num_cats {
375 let d = angular_distance(
376 &summaries[i].centroid_position,
377 &summaries[j].centroid_position,
378 );
379 centroid_dists[i][j] = d;
380 centroid_dists[j][i] = d;
381 }
382 }
383
384 let mut bridges: HashMap<(usize, usize), Vec<BridgeItem>> = HashMap::new();
386
387 for (ci, summary) in summaries.iter().enumerate() {
388 let centroid_a = &summary.centroid_embedding;
389
390 for &mi in &summary.member_indices {
391 let item_emb = &embeddings[mi];
392 let sim_to_own = cosine_similarity(&item_emb.values, centroid_a);
393
394 for (cj, other_summary) in summaries.iter().enumerate() {
395 if ci == cj {
396 continue;
397 }
398
399 let sim_to_other =
400 cosine_similarity(&item_emb.values, &other_summary.centroid_embedding);
401
402 if sim_to_other > 0.0 && sim_to_other > sim_to_own * 0.5 {
403 let bridge_strength = if sim_to_own + sim_to_other > f64::EPSILON {
404 2.0 * sim_to_own * sim_to_other / (sim_to_own + sim_to_other)
405 } else {
406 0.0
407 };
408
409 bridges.entry((ci, cj)).or_default().push(BridgeItem {
410 item_index: mi,
411 source_category: ci,
412 target_category: cj,
413 affinity_to_source: sim_to_own,
414 affinity_to_target: sim_to_other,
415 bridge_strength,
416 });
417 }
418 }
419 }
420 }
421
422 for list in bridges.values_mut() {
423 list.sort_by(|a, b| {
424 b.bridge_strength
425 .partial_cmp(&a.bridge_strength)
426 .unwrap_or(std::cmp::Ordering::Equal)
427 });
428 }
429
430 let mut adjacency: Vec<Vec<CategoryEdge>> = vec![Vec::new(); num_cats];
431 for i in 0..num_cats {
432 for (j, &cd) in centroid_dists[i].iter().enumerate() {
433 if i == j {
434 continue;
435 }
436 let bridge_count = bridges.get(&(i, j)).map_or(0, |b| b.len());
437 let weight = cd / (1.0 + bridge_count as f64);
438
439 adjacency[i].push(CategoryEdge {
440 target: j,
441 centroid_distance: cd,
442 bridge_count,
443 weight,
444 });
445 }
446 adjacency[i].sort_by(|a, b| {
447 a.weight
448 .partial_cmp(&b.weight)
449 .unwrap_or(std::cmp::Ordering::Equal)
450 });
451 }
452
453 CategoryGraph { adjacency, bridges }
454 }
455
456 fn build_inner_spheres<P: Projection>(
458 summaries: &[CategorySummary],
459 embeddings: &[Embedding],
460 projection: &P,
461 ) -> HashMap<usize, InnerSphere> {
462 let mut result = HashMap::new();
463
464 for (ci, summary) in summaries.iter().enumerate() {
465 if summary.member_count < MIN_INNER_SPHERE_SIZE {
466 continue;
467 }
468
469 let member_embs: Vec<Embedding> = summary
470 .member_indices
471 .iter()
472 .map(|&i| embeddings[i].clone())
473 .collect();
474
475 let global_subset_evr: f64 = member_embs
477 .iter()
478 .map(|e| projection.project_rich(e).certainty)
479 .sum::<f64>()
480 / member_embs.len() as f64;
481
482 let inner_pca = PcaProjection::fit(&member_embs, RadialStrategy::Fixed(1.0));
484 let inner_linear_evr = inner_pca.explained_variance_ratio();
485
486 if inner_linear_evr - global_subset_evr < MIN_EVR_IMPROVEMENT {
487 continue;
488 }
489
490 let (inner_proj, inner_evr) = if summary.member_count >= KERNEL_PCA_MIN_SIZE {
491 let inner_kpca = KernelPcaProjection::fit(&member_embs, RadialStrategy::Fixed(1.0));
492 let kernel_evr = inner_kpca.explained_variance_ratio();
493
494 if kernel_evr > inner_linear_evr + MIN_KERNEL_IMPROVEMENT {
495 (InnerProjection::KernelPca(inner_kpca), kernel_evr)
496 } else {
497 (InnerProjection::LinearPca(inner_pca), inner_linear_evr)
498 }
499 } else {
500 (InnerProjection::LinearPca(inner_pca), inner_linear_evr)
501 };
502
503 let inner_positions: Vec<SphericalPoint> =
504 member_embs.iter().map(|e| inner_proj.project(e)).collect();
505
506 result.insert(
507 ci,
508 InnerSphere {
509 projection: inner_proj,
510 inner_positions,
511 member_indices: summary.member_indices.clone(),
512 explained_variance_ratio: inner_evr,
513 global_subset_evr,
514 evr_improvement: inner_evr - global_subset_evr,
515 },
516 );
517 }
518
519 result
520 }
521
522 pub fn num_categories(&self) -> usize {
526 self.summaries.len()
527 }
528
529 pub fn get_category(&self, name: &str) -> Option<&CategorySummary> {
531 self.name_to_index
532 .get(name)
533 .map(|&idx| &self.summaries[idx])
534 }
535
536 pub fn category_neighbors(&self, category_name: &str, k: usize) -> Vec<&CategorySummary> {
538 let Some(&ci) = self.name_to_index.get(category_name) else {
539 return Vec::new();
540 };
541 self.graph.adjacency[ci]
542 .iter()
543 .take(k)
544 .map(|edge| &self.summaries[edge.target])
545 .collect()
546 }
547
548 pub fn bridge_items(
550 &self,
551 source_category: &str,
552 target_category: &str,
553 max_bridges: usize,
554 ) -> Vec<&BridgeItem> {
555 let Some(&si) = self.name_to_index.get(source_category) else {
556 return Vec::new();
557 };
558 let Some(&ti) = self.name_to_index.get(target_category) else {
559 return Vec::new();
560 };
561 self.graph
562 .bridges
563 .get(&(si, ti))
564 .map(|list| list.iter().take(max_bridges).collect())
565 .unwrap_or_default()
566 }
567
568 pub fn category_path(
570 &self,
571 source_category: &str,
572 target_category: &str,
573 ) -> Option<CategoryPath> {
574 let &si = self.name_to_index.get(source_category)?;
575 let &ti = self.name_to_index.get(target_category)?;
576 if si == ti {
577 return Some(CategoryPath {
578 steps: vec![CategoryPathStep {
579 category_index: si,
580 category_name: self.summaries[si].name.clone(),
581 cumulative_distance: 0.0,
582 bridges_to_next: Vec::new(),
583 }],
584 total_distance: 0.0,
585 });
586 }
587
588 let n = self.summaries.len();
589 let mut dist = vec![f64::INFINITY; n];
590 let mut prev: Vec<Option<usize>> = vec![None; n];
591 let mut visited = vec![false; n];
592
593 dist[si] = 0.0;
594
595 for _ in 0..n {
596 let mut u = None;
597 let mut best = f64::INFINITY;
598 for (i, (&d, &v)) in dist.iter().zip(visited.iter()).enumerate() {
599 if !v && d < best {
600 best = d;
601 u = Some(i);
602 }
603 }
604 let Some(u) = u else { break };
605 if u == ti {
606 break;
607 }
608 visited[u] = true;
609
610 for edge in &self.graph.adjacency[u] {
611 let nd = dist[u] + edge.weight;
612 if nd < dist[edge.target] {
613 dist[edge.target] = nd;
614 prev[edge.target] = Some(u);
615 }
616 }
617 }
618
619 if dist[ti].is_infinite() {
620 return None;
621 }
622
623 let mut path_indices = Vec::new();
624 let mut cur = ti;
625 loop {
626 path_indices.push(cur);
627 match prev[cur] {
628 Some(p) => cur = p,
629 None => break,
630 }
631 }
632 path_indices.reverse();
633
634 let mut steps = Vec::with_capacity(path_indices.len());
635 for (step_idx, &ci) in path_indices.iter().enumerate() {
636 let bridges_to_next = if step_idx + 1 < path_indices.len() {
637 let next_ci = path_indices[step_idx + 1];
638 self.graph
639 .bridges
640 .get(&(ci, next_ci))
641 .map(|list| list.iter().take(3).cloned().collect())
642 .unwrap_or_default()
643 } else {
644 Vec::new()
645 };
646
647 steps.push(CategoryPathStep {
648 category_index: ci,
649 category_name: self.summaries[ci].name.clone(),
650 cumulative_distance: dist[ci],
651 bridges_to_next,
652 });
653 }
654
655 Some(CategoryPath {
656 total_distance: dist[ti],
657 steps,
658 })
659 }
660
661 pub fn categories_near_embedding<P: Projection>(
664 &self,
665 embedding: &Embedding,
666 projection: &P,
667 max_angle: f64,
668 ) -> Vec<(usize, f64)> {
669 let pos = projection.project(embedding);
670 let mut results: Vec<(usize, f64)> = self
671 .summaries
672 .iter()
673 .enumerate()
674 .map(|(i, s)| (i, angular_distance(&pos, &s.centroid_position)))
675 .filter(|&(_, d)| d <= max_angle)
676 .collect();
677 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
678 results
679 }
680
681 pub fn has_inner_sphere(&self, category_name: &str) -> bool {
685 self.name_to_index
686 .get(category_name)
687 .is_some_and(|&ci| self.inner_spheres.contains_key(&ci))
688 }
689
690 pub fn get_inner_sphere(&self, category_name: &str) -> Option<&InnerSphere> {
692 self.name_to_index
693 .get(category_name)
694 .and_then(|&ci| self.inner_spheres.get(&ci))
695 }
696
697 pub fn num_inner_spheres(&self) -> usize {
699 self.inner_spheres.len()
700 }
701
702 pub fn drill_down(
708 &self,
709 category_name: &str,
710 embedding: &Embedding,
711 k: usize,
712 ) -> Vec<DrillDownResult> {
713 let Some(&ci) = self.name_to_index.get(category_name) else {
714 return Vec::new();
715 };
716 let summary = &self.summaries[ci];
717
718 if let Some(inner) = self.inner_spheres.get(&ci) {
719 let query_pos = inner.projection.project(embedding);
720 let mut results: Vec<DrillDownResult> = inner
721 .inner_positions
722 .iter()
723 .enumerate()
724 .map(|(local_idx, pos)| DrillDownResult {
725 item_index: inner.member_indices[local_idx],
726 distance: angular_distance(&query_pos, pos),
727 used_inner_sphere: true,
728 })
729 .collect();
730 results.sort_by(|a, b| {
731 a.distance
732 .partial_cmp(&b.distance)
733 .unwrap_or(std::cmp::Ordering::Equal)
734 });
735 results.truncate(k);
736 results
737 } else {
738 let centroid = &summary.centroid_position;
740 let mut results: Vec<DrillDownResult> = summary
741 .member_indices
742 .iter()
743 .map(|&mi| DrillDownResult {
744 item_index: mi,
745 distance: angular_distance(&self.outer_positions[mi], centroid),
746 used_inner_sphere: false,
747 })
748 .collect();
749 results.sort_by(|a, b| {
750 a.distance
751 .partial_cmp(&b.distance)
752 .unwrap_or(std::cmp::Ordering::Equal)
753 });
754 results.truncate(k);
755 results
756 }
757 }
758
759 pub fn drill_down_with_projection<P: Projection>(
764 &self,
765 category_name: &str,
766 embedding: &Embedding,
767 projection: &P,
768 k: usize,
769 ) -> Vec<DrillDownResult> {
770 let Some(&ci) = self.name_to_index.get(category_name) else {
771 return Vec::new();
772 };
773 let summary = &self.summaries[ci];
774
775 if let Some(inner) = self.inner_spheres.get(&ci) {
776 let query_pos = inner.projection.project(embedding);
777 let mut results: Vec<DrillDownResult> = inner
778 .inner_positions
779 .iter()
780 .enumerate()
781 .map(|(local_idx, pos)| DrillDownResult {
782 item_index: inner.member_indices[local_idx],
783 distance: angular_distance(&query_pos, pos),
784 used_inner_sphere: true,
785 })
786 .collect();
787 results.sort_by(|a, b| {
788 a.distance
789 .partial_cmp(&b.distance)
790 .unwrap_or(std::cmp::Ordering::Equal)
791 });
792 results.truncate(k);
793 results
794 } else {
795 let query_pos = projection.project(embedding);
796 let mut results: Vec<DrillDownResult> = summary
797 .member_indices
798 .iter()
799 .map(|&mi| DrillDownResult {
800 item_index: mi,
801 distance: angular_distance(&self.outer_positions[mi], &query_pos),
802 used_inner_sphere: false,
803 })
804 .collect();
805 results.sort_by(|a, b| {
806 a.distance
807 .partial_cmp(&b.distance)
808 .unwrap_or(std::cmp::Ordering::Equal)
809 });
810 results.truncate(k);
811 results
812 }
813 }
814
815 pub fn inner_sphere_stats(&self) -> Vec<InnerSphereReport> {
818 let mut reports: Vec<InnerSphereReport> = self
819 .inner_spheres
820 .iter()
821 .map(|(&ci, inner)| {
822 let proj_type = match &inner.projection {
823 InnerProjection::LinearPca(_) => "LinearPca",
824 InnerProjection::KernelPca(_) => "KernelPca",
825 };
826 InnerSphereReport {
827 category_name: self.summaries[ci].name.clone(),
828 category_index: ci,
829 member_count: inner.member_indices.len(),
830 projection_type: proj_type,
831 inner_evr: inner.explained_variance_ratio,
832 global_subset_evr: inner.global_subset_evr,
833 evr_improvement: inner.evr_improvement,
834 }
835 })
836 .collect();
837 reports.sort_by_key(|r| r.category_index);
838 reports
839 }
840}
841
/// Cosine of the angle between `a` and `b`, clamped to `[-1, 1]`.
///
/// The dot product only covers the overlapping prefix when the lengths
/// differ, while each magnitude uses its whole slice. Returns `0.0` when the
/// product of the magnitudes is below `f64::EPSILON` (e.g. a zero vector).
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    let magnitude = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f64>();
    let denom = magnitude(a) * magnitude(b);

    if denom < f64::EPSILON {
        0.0
    } else {
        (dot / denom).clamp(-1.0, 1.0)
    }
}
854
#[cfg(test)]
// Unit tests for the category layer: summaries, graph, paths, inner spheres
// and drill-down.
mod tests {
    use super::*;

    /// Shorthand for building an `Embedding` from a slice literal.
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    /// Twelve 5-d items in three well-separated categories of four items each;
    /// each category is dominated by one of the first three dimensions.
    fn test_corpus() -> (Vec<String>, Vec<Embedding>) {
        let categories = vec![
            "science".into(),
            "science".into(),
            "science".into(),
            "science".into(),
            "cooking".into(),
            "cooking".into(),
            "cooking".into(),
            "cooking".into(),
            "music".into(),
            "music".into(),
            "music".into(),
            "music".into(),
        ];
        let embeddings = vec![
            emb(&[1.0, 0.1, 0.0, 0.05, 0.02]),
            emb(&[0.9, 0.15, 0.05, 0.03, 0.01]),
            emb(&[0.95, 0.05, 0.1, 0.04, 0.03]),
            emb(&[0.85, 0.2, 0.0, 0.06, 0.01]),
            emb(&[0.1, 1.0, 0.0, 0.02, 0.05]),
            emb(&[0.15, 0.9, 0.05, 0.03, 0.04]),
            emb(&[0.05, 0.95, 0.1, 0.01, 0.06]),
            emb(&[0.2, 0.85, 0.0, 0.04, 0.03]),
            emb(&[0.0, 0.1, 1.0, 0.05, 0.02]),
            emb(&[0.05, 0.15, 0.9, 0.03, 0.01]),
            emb(&[0.1, 0.05, 0.95, 0.04, 0.03]),
            emb(&[0.0, 0.2, 0.85, 0.06, 0.01]),
        ];
        (categories, embeddings)
    }

    /// Fits a global PCA on `test_corpus`, projects it, and builds the layer.
    fn build_test_layer() -> (CategoryLayer, Vec<Embedding>, PcaProjection) {
        let (categories, embeddings) = test_corpus();
        let pca = PcaProjection::fit(&embeddings, RadialStrategy::Fixed(1.0));
        let projected: Vec<SphericalPoint> = embeddings.iter().map(|e| pca.project(e)).collect();
        let layer = CategoryLayer::build(&categories, &embeddings, &projected, &pca);
        (layer, embeddings, pca)
    }

    /// 10-d corpus with one 25-item category ("big", a loop in dims 0-1 with a
    /// ramp in dim 2) and two 4-item categories. Only "big" reaches
    /// `MIN_INNER_SPHERE_SIZE`.
    fn large_category_corpus() -> (Vec<String>, Vec<Embedding>) {
        let mut categories = Vec::new();
        let mut embeddings = Vec::new();

        for i in 0..25 {
            categories.push("big".into());
            let t = i as f64 / 25.0;
            let mut v = vec![0.0; 10];
            v[0] = 1.0 + 0.3 * (t * std::f64::consts::TAU).sin();
            v[1] = 0.5 + 0.3 * (t * std::f64::consts::TAU).cos();
            v[2] = 0.2 * t;
            // NOTE(review): `(i * 7 + d) as f64` is a whole number, so `% 1.0`
            // is always 0.0 and dims 3..10 stay zero — no noise is actually
            // added. Confirm whether small per-dim noise was intended.
            for (d, slot) in v.iter_mut().enumerate().take(10).skip(3) {
                *slot = 0.01 * ((i * 7 + d) as f64 % 1.0);
            }
            embeddings.push(emb(&v));
        }

        for i in 0..4 {
            categories.push("small_a".into());
            let mut v = vec![0.0; 10];
            v[5] = 1.0 + 0.1 * i as f64;
            v[6] = 0.05;
            embeddings.push(emb(&v));
        }

        for i in 0..4 {
            categories.push("small_b".into());
            let mut v = vec![0.0; 10];
            v[8] = 1.0 + 0.1 * i as f64;
            v[9] = 0.05;
            embeddings.push(emb(&v));
        }

        (categories, embeddings)
    }

    /// Fits a global PCA on `large_category_corpus`, projects it, and builds
    /// the layer.
    fn build_large_test_layer() -> (CategoryLayer, Vec<Embedding>, PcaProjection) {
        let (categories, embeddings) = large_category_corpus();
        let pca = PcaProjection::fit(&embeddings, RadialStrategy::Fixed(1.0));
        let projected: Vec<SphericalPoint> = embeddings.iter().map(|e| pca.project(e)).collect();
        let layer = CategoryLayer::build(&categories, &embeddings, &projected, &pca);
        (layer, embeddings, pca)
    }

    // --- summaries ---

    #[test]
    fn builds_correct_number_of_categories() {
        let (layer, _, _) = build_test_layer();
        assert_eq!(layer.num_categories(), 3);
    }

    #[test]
    fn category_names_correct() {
        let (layer, _, _) = build_test_layer();
        let names: Vec<&str> = layer.summaries.iter().map(|s| s.name.as_str()).collect();
        assert!(names.contains(&"science"));
        assert!(names.contains(&"cooking"));
        assert!(names.contains(&"music"));
    }

    #[test]
    fn member_counts_correct() {
        let (layer, _, _) = build_test_layer();
        for summary in &layer.summaries {
            assert_eq!(summary.member_count, 4);
            assert_eq!(summary.member_indices.len(), 4);
        }
    }

    /// The stored centroid must equal the component-wise mean of the four
    /// "science" embeddings (the first four items).
    #[test]
    fn centroid_embedding_is_mean() {
        let (layer, embeddings, _) = build_test_layer();
        let science = layer.get_category("science").unwrap();
        let mut expected = vec![0.0; 5];
        for emb in embeddings.iter().take(4) {
            for (j, &v) in emb.values.iter().enumerate() {
                expected[j] += v;
            }
        }
        for v in &mut expected {
            *v /= 4.0;
        }
        for (j, (&actual, &exp)) in science
            .centroid_embedding
            .iter()
            .zip(expected.iter())
            .enumerate()
        {
            assert!(
                (actual - exp).abs() < 1e-10,
                "centroid dim {j}: {actual} != {exp}"
            );
        }
    }

    #[test]
    fn angular_spread_is_nonnegative() {
        let (layer, _, _) = build_test_layer();
        for s in &layer.summaries {
            assert!(s.angular_spread >= 0.0);
        }
    }

    #[test]
    fn cohesion_in_range() {
        let (layer, _, _) = build_test_layer();
        for s in &layer.summaries {
            assert!(s.cohesion > 0.0 && s.cohesion <= 1.0);
        }
    }

    // --- graph ---

    #[test]
    fn graph_has_edges_for_all_pairs() {
        let (layer, _, _) = build_test_layer();
        for (i, edges) in layer.graph.adjacency.iter().enumerate() {
            assert_eq!(edges.len(), layer.num_categories() - 1, "cat {i}");
        }
    }

    #[test]
    fn edge_weights_positive() {
        let (layer, _, _) = build_test_layer();
        for edges in &layer.graph.adjacency {
            for e in edges {
                assert!(e.weight > 0.0);
                assert!(e.centroid_distance > 0.0);
            }
        }
    }

    #[test]
    fn edges_sorted_by_weight() {
        let (layer, _, _) = build_test_layer();
        for edges in &layer.graph.adjacency {
            for w in edges.windows(2) {
                assert!(w[0].weight <= w[1].weight);
            }
        }
    }

    #[test]
    fn get_category_by_name() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.get_category("science").is_some());
        assert!(layer.get_category("astrology").is_none());
    }

    #[test]
    fn category_neighbors_returns_sorted() {
        let (layer, _, _) = build_test_layer();
        assert_eq!(layer.category_neighbors("science", 2).len(), 2);
    }

    #[test]
    fn category_neighbors_k_larger_than_available() {
        let (layer, _, _) = build_test_layer();
        assert_eq!(layer.category_neighbors("science", 100).len(), 2);
    }

    #[test]
    fn category_neighbors_unknown_returns_empty() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.category_neighbors("nonexistent", 5).is_empty());
    }

    // NOTE(review): this test only checks that the call does not panic; it
    // asserts nothing about which bridges are found.
    #[test]
    fn bridge_items_detected() {
        let (layer, _, _) = build_test_layer();
        let _ = layer.bridge_items("science", "cooking", 10);
    }

    #[test]
    fn bridge_items_unknown_category_returns_empty() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.bridge_items("science", "nonexistent", 10).is_empty());
    }

    #[test]
    fn bridge_strength_in_valid_range() {
        let (layer, _, _) = build_test_layer();
        for list in layer.graph.bridges.values() {
            for b in list {
                assert!(b.bridge_strength >= 0.0 && b.bridge_strength <= 1.0);
            }
        }
    }

    #[test]
    fn bridges_sorted_by_strength() {
        let (layer, _, _) = build_test_layer();
        for list in layer.graph.bridges.values() {
            for w in list.windows(2) {
                assert!(w[0].bridge_strength >= w[1].bridge_strength);
            }
        }
    }

    // --- paths ---

    #[test]
    fn category_path_same_category() {
        let (layer, _, _) = build_test_layer();
        let path = layer.category_path("science", "science").unwrap();
        assert_eq!(path.steps.len(), 1);
        assert!(path.total_distance.abs() < 1e-12);
    }

    #[test]
    fn category_path_adjacent() {
        let (layer, _, _) = build_test_layer();
        let path = layer.category_path("science", "cooking").unwrap();
        assert!(path.steps.len() >= 2);
        assert_eq!(path.steps.first().unwrap().category_name, "science");
        assert_eq!(path.steps.last().unwrap().category_name, "cooking");
        assert!(path.total_distance > 0.0);
    }

    #[test]
    fn category_path_unknown_returns_none() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.category_path("science", "nonexistent").is_none());
    }

    #[test]
    fn category_path_distances_monotonic() {
        let (layer, _, _) = build_test_layer();
        let path = layer.category_path("science", "music").unwrap();
        for w in path.steps.windows(2) {
            assert!(w[1].cumulative_distance >= w[0].cumulative_distance);
        }
    }

    // --- nearest categories ---

    #[test]
    fn categories_near_embedding_finds_correct() {
        let (layer, _, pca) = build_test_layer();
        let near = layer.categories_near_embedding(
            &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]),
            &pca,
            std::f64::consts::PI,
        );
        assert!(!near.is_empty());
        assert_eq!(layer.summaries[near[0].0].name, "science");
    }

    #[test]
    fn categories_near_embedding_sorted_by_distance() {
        let (layer, _, pca) = build_test_layer();
        let near = layer.categories_near_embedding(
            &emb(&[0.5, 0.5, 0.5, 0.0, 0.0]),
            &pca,
            std::f64::consts::PI,
        );
        for w in near.windows(2) {
            assert!(w[0].1 <= w[1].1);
        }
    }

    #[test]
    fn categories_near_embedding_respects_threshold() {
        let (layer, _, pca) = build_test_layer();
        let near = layer.categories_near_embedding(&emb(&[1.0, 0.0, 0.0, 0.0, 0.0]), &pca, 0.01);
        for &(_, d) in &near {
            assert!(d <= 0.01);
        }
    }

    // --- cosine_similarity ---

    #[test]
    fn cosine_similarity_identical() {
        assert!((cosine_similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        assert!(cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).abs() < 1e-12);
    }

    #[test]
    fn cosine_similarity_opposite() {
        assert!((cosine_similarity(&[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0]) + 1.0).abs() < 1e-12);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        assert!(cosine_similarity(&[0.0, 0.0, 0.0], &[1.0, 0.0, 0.0]).abs() < 1e-12);
    }

    // --- inner spheres ---

    /// Four-member categories are below MIN_INNER_SPHERE_SIZE.
    #[test]
    fn small_categories_get_no_inner_sphere() {
        let (layer, _, _) = build_test_layer();
        assert_eq!(layer.num_inner_spheres(), 0);
        assert!(!layer.has_inner_sphere("science"));
    }

    // NOTE(review): the result for "big" is discarded, so whether it actually
    // gets an inner sphere is not asserted.
    #[test]
    fn large_category_may_get_inner_sphere() {
        let (layer, _, _) = build_large_test_layer();
        assert!(!layer.has_inner_sphere("small_a"));
        assert!(!layer.has_inner_sphere("small_b"));
        let _ = layer.has_inner_sphere("big");
    }

    #[test]
    fn inner_sphere_stats_count_matches() {
        let (layer, _, _) = build_large_test_layer();
        assert_eq!(layer.inner_sphere_stats().len(), layer.num_inner_spheres());
    }

    #[test]
    fn inner_sphere_stats_sorted_by_index() {
        let (layer, _, _) = build_large_test_layer();
        let stats = layer.inner_sphere_stats();
        for w in stats.windows(2) {
            assert!(w[0].category_index <= w[1].category_index);
        }
    }

    #[test]
    fn inner_sphere_evr_improvement_positive() {
        let (layer, _, _) = build_large_test_layer();
        for inner in layer.inner_spheres.values() {
            assert!(inner.evr_improvement >= MIN_EVR_IMPROVEMENT);
        }
    }

    #[test]
    fn inner_sphere_positions_match_member_count() {
        let (layer, _, _) = build_large_test_layer();
        for (&ci, inner) in &layer.inner_spheres {
            assert_eq!(inner.inner_positions.len(), inner.member_indices.len());
            assert_eq!(inner.member_indices.len(), layer.summaries[ci].member_count);
        }
    }

    #[test]
    fn inner_sphere_member_indices_valid() {
        let (layer, _, _) = build_large_test_layer();
        let total = layer.outer_positions.len();
        for inner in layer.inner_spheres.values() {
            for &mi in &inner.member_indices {
                assert!(mi < total);
            }
        }
    }

    #[test]
    fn inner_sphere_report_projection_type_valid() {
        let (layer, _, _) = build_large_test_layer();
        for r in layer.inner_sphere_stats() {
            assert!(r.projection_type == "LinearPca" || r.projection_type == "KernelPca");
        }
    }

    #[test]
    fn inner_sphere_evr_in_range() {
        let (layer, _, _) = build_large_test_layer();
        for inner in layer.inner_spheres.values() {
            assert!(inner.explained_variance_ratio >= 0.0 && inner.explained_variance_ratio <= 1.0);
            assert!(inner.global_subset_evr >= 0.0 && inner.global_subset_evr <= 1.0);
        }
    }

    #[test]
    fn has_inner_sphere_unknown_category() {
        let (layer, _, _) = build_test_layer();
        assert!(!layer.has_inner_sphere("nonexistent"));
    }

    #[test]
    fn get_inner_sphere_returns_none_for_small() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.get_inner_sphere("science").is_none());
    }

    // --- drill-down ---

    #[test]
    fn drill_down_returns_results() {
        let (layer, _, pca) = build_large_test_layer();
        let q = emb(&[1.0, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let results = layer.drill_down_with_projection("big", &q, &pca, 5);
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
    }

    #[test]
    fn drill_down_sorted_by_distance() {
        let (layer, _, pca) = build_large_test_layer();
        let q = emb(&[1.0, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let results = layer.drill_down_with_projection("big", &q, &pca, 10);
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn drill_down_unknown_category_empty() {
        let (layer, _, pca) = build_large_test_layer();
        assert!(
            layer
                .drill_down_with_projection("nonexistent", &emb(&[1.0; 10]), &pca, 5)
                .is_empty()
        );
    }

    #[test]
    fn drill_down_item_indices_valid() {
        let (layer, _, pca) = build_large_test_layer();
        let q = emb(&[1.0, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        let total = layer.outer_positions.len();
        for r in layer.drill_down_with_projection("big", &q, &pca, 25) {
            assert!(r.item_index < total);
        }
    }

    /// "small_a" has no inner sphere, so drill-down must fall back to the
    /// outer sphere.
    #[test]
    fn drill_down_small_category_uses_outer() {
        let (layer, _, pca) = build_large_test_layer();
        let q = emb(&[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
        for r in layer.drill_down_with_projection("small_a", &q, &pca, 4) {
            assert!(!r.used_inner_sphere);
        }
    }

    #[test]
    fn drill_down_distances_nonnegative() {
        let (layer, _, pca) = build_large_test_layer();
        for r in layer.drill_down_with_projection("big", &emb(&[1.0; 10]), &pca, 10) {
            assert!(r.distance >= 0.0);
        }
    }

    #[test]
    fn drill_down_without_projection_works() {
        let (layer, _, _) = build_large_test_layer();
        let q = emb(&[1.0, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        assert!(layer.drill_down("big", &q, 5).len() <= 5);
    }

    // --- InnerProjection ---

    #[test]
    fn inner_projection_enum_debug() {
        let corpus: Vec<Embedding> = (0..5)
            .map(|i| emb(&[i as f64, 0.0, 0.0, 0.0, 0.0]))
            .collect();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        assert_eq!(
            format!("{:?}", InnerProjection::LinearPca(pca)),
            "LinearPca"
        );
    }

    /// Wrapping a projection in the enum must not change its output.
    #[test]
    fn inner_projection_projects_correctly() {
        let corpus: Vec<Embedding> = (0..5)
            .map(|i| emb(&[i as f64, 0.0, 0.0, 0.0, 0.0]))
            .collect();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        let proj = InnerProjection::LinearPca(pca.clone());
        let e = emb(&[1.0, 0.0, 0.0, 0.0, 0.0]);
        let sp_enum = proj.project(&e);
        let sp_direct = pca.project(&e);
        assert!((sp_enum.theta - sp_direct.theta).abs() < 1e-12);
        assert!((sp_enum.phi - sp_direct.phi).abs() < 1e-12);
    }

    /// Expects `dimensionality()` to report the input embedding width (5).
    #[test]
    fn inner_projection_dimensionality() {
        let corpus: Vec<Embedding> = (0..5)
            .map(|i| emb(&[i as f64, 0.0, 0.0, 0.0, 0.0]))
            .collect();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
        assert_eq!(InnerProjection::LinearPca(pca).dimensionality(), 5);
    }
}