1use std::collections::{BinaryHeap, HashMap};
2use std::sync::{Arc, Mutex};
3
4use sphereql_core::*;
5use sphereql_index::*;
6
7use crate::category::BridgeClassification;
8use crate::projection::Projection;
9use crate::types::{Embedding, ProjectedPoint};
10
/// Memoized k-nearest-neighbour graph together with the `k` it was built for.
struct KnnCache {
    /// Neighbour count that was requested when `adj` was computed.
    k: usize,
    /// Shared adjacency list: `adj[i]` holds `(neighbour_index, distance)` pairs.
    adj: Arc<Vec<Vec<(usize, f64)>>>,
}
20
21#[derive(Debug, Clone)]
22pub struct EmbeddingItem {
23 pub id: String,
24 pub position: SphericalPoint,
25 pub original_magnitude: f64,
26 pub projected: Option<ProjectedPoint>,
28}
29
30impl SpatialItem for EmbeddingItem {
31 type Id = String;
32 fn id(&self) -> &String {
33 &self.id
34 }
35 fn position(&self) -> &SphericalPoint {
36 &self.position
37 }
38}
39
40impl EmbeddingItem {
41 pub fn certainty(&self) -> f64 {
43 self.projected.map_or(1.0, |p| p.certainty)
44 }
45
46 pub fn intensity(&self) -> f64 {
48 self.projected
49 .map_or(self.original_magnitude, |p| p.intensity)
50 }
51
52 pub fn projection_magnitude(&self) -> f64 {
55 self.projected.map_or(1.0, |p| p.projection_magnitude)
56 }
57}
58
59#[must_use = "builders must be terminated with `.build()` to construct the index"]
60pub struct EmbeddingIndexBuilder<P> {
61 projection: P,
62 inner: SpatialIndexBuilder,
63}
64
65impl<P: Projection> EmbeddingIndexBuilder<P> {
66 pub fn new(projection: P) -> Self {
67 Self {
68 projection,
69 inner: SpatialIndexBuilder::new(),
70 }
71 }
72
73 pub fn shell_boundary(mut self, r: f64) -> Self {
74 self.inner = self.inner.shell_boundary(r);
75 self
76 }
77
78 pub fn uniform_shells(mut self, count: usize, max_r: f64) -> Self {
79 self.inner = self.inner.uniform_shells(count, max_r);
80 self
81 }
82
83 pub fn theta_divisions(mut self, n: usize) -> Self {
84 self.inner = self.inner.theta_divisions(n);
85 self
86 }
87
88 pub fn phi_divisions(mut self, n: usize) -> Self {
89 self.inner = self.inner.phi_divisions(n);
90 self
91 }
92
93 pub fn build(self) -> EmbeddingIndex<P> {
94 EmbeddingIndex {
95 projection: self.projection,
96 index: self.inner.build(),
97 knn_cache: Mutex::new(None),
98 }
99 }
100}
101
102pub struct EmbeddingIndex<P> {
103 projection: P,
104 index: SpatialIndex<EmbeddingItem>,
105 knn_cache: Mutex<Option<KnnCache>>,
110}
111
112impl<P: Projection> EmbeddingIndex<P> {
113 pub fn builder(projection: P) -> EmbeddingIndexBuilder<P> {
114 EmbeddingIndexBuilder::new(projection)
115 }
116
117 pub fn insert(&mut self, id: impl Into<String>, embedding: &Embedding) {
118 let rich = self.projection.project_rich(embedding);
119 self.index.insert(EmbeddingItem {
120 id: id.into(),
121 position: rich.position,
122 original_magnitude: embedding.magnitude(),
123 projected: Some(rich),
124 });
125 self.invalidate_knn_cache();
126 }
127
128 pub fn insert_with_radius(&mut self, id: impl Into<String>, embedding: &Embedding, r: f64) {
132 let rich = self.projection.project_rich(embedding);
133 let position = SphericalPoint::new_unchecked(r, rich.position.theta, rich.position.phi);
134 self.index.insert(EmbeddingItem {
135 id: id.into(),
136 position,
137 original_magnitude: embedding.magnitude(),
138 projected: Some(ProjectedPoint { position, ..rich }),
139 });
140 self.invalidate_knn_cache();
141 }
142
143 fn invalidate_knn_cache(&mut self) {
147 if let Ok(slot) = self.knn_cache.get_mut() {
148 *slot = None;
149 }
150 }
151
152 fn knn_adjacency(&self, items: &[&EmbeddingItem], k: usize) -> Arc<Vec<Vec<(usize, f64)>>> {
159 {
160 let cache = self.knn_cache.lock().expect("knn cache mutex poisoned");
161 if let Some(cached) = cache.as_ref()
162 && cached.k == k
163 && cached.adj.len() == items.len()
164 {
165 return Arc::clone(&cached.adj);
166 }
167 }
168
169 let n = items.len();
173 let id_to_idx: HashMap<&str, usize> = items
174 .iter()
175 .enumerate()
176 .map(|(i, item)| (item.id.as_str(), i))
177 .collect();
178 let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::with_capacity(k); n];
179 let mut seen: std::collections::HashSet<(usize, usize)> =
180 std::collections::HashSet::with_capacity(n * k);
181 for (i, item) in items.iter().enumerate() {
182 let nearest = self.index.nearest(item.position(), k + 1);
183 for result in &nearest {
184 let Some(&j) = id_to_idx.get(result.item.id.as_str()) else {
185 continue;
186 };
187 if i == j {
188 continue;
189 }
190 let key = if i < j { (i, j) } else { (j, i) };
191 if seen.insert(key) {
192 adj[i].push((j, result.distance));
193 adj[j].push((i, result.distance));
194 }
195 }
196 }
197
198 let adj = Arc::new(adj);
199 let mut cache = self.knn_cache.lock().expect("knn cache mutex poisoned");
200 *cache = Some(KnnCache {
201 k,
202 adj: Arc::clone(&adj),
203 });
204 adj
205 }
206
207 pub fn search_nearest(&self, query: &Embedding, k: usize) -> Vec<NearestResult<EmbeddingItem>> {
209 let projected = self.projection.project(query);
210 self.index.nearest(&projected, k)
211 }
212
213 pub fn search_similar(
218 &self,
219 query: &Embedding,
220 min_cosine_similarity: f64,
221 ) -> SpatialQueryResult<EmbeddingItem> {
222 let projected = self.projection.project(query);
223 let max_angle = min_cosine_similarity.clamp(-1.0, 1.0).acos();
224 self.index.within_distance(&projected, max_angle)
225 }
226
227 pub fn search_region(&self, region: &Region) -> SpatialQueryResult<EmbeddingItem> {
228 self.index.query_region(region)
229 }
230
231 pub fn remove(&mut self, id: &str) -> Option<EmbeddingItem> {
232 let removed = self.index.remove(&id.to_string());
233 if removed.is_some() {
234 self.invalidate_knn_cache();
235 }
236 removed
237 }
238
239 pub fn get(&self, id: &str) -> Option<&EmbeddingItem> {
240 self.index.get(&id.to_string())
241 }
242
243 pub fn len(&self) -> usize {
244 self.index.len()
245 }
246
247 pub fn is_empty(&self) -> bool {
248 self.index.is_empty()
249 }
250
251 pub fn projection(&self) -> &P {
252 &self.projection
253 }
254
255 pub fn all_items(&self) -> Vec<&EmbeddingItem> {
256 self.index.all_items()
257 }
258
259 pub fn concept_path(&self, source_id: &str, target_id: &str, k: usize) -> Option<ConceptPath> {
271 let items = self.index.all_items();
272 let n = items.len();
273 if n < 2 {
274 return None;
275 }
276
277 let id_to_idx: HashMap<&str, usize> = items
278 .iter()
279 .enumerate()
280 .map(|(i, item)| (item.id.as_str(), i))
281 .collect();
282
283 let source_idx = *id_to_idx.get(source_id)?;
284 let target_idx = *id_to_idx.get(target_id)?;
285
286 let adj = self.knn_adjacency(&items, k);
287
288 let mut dist = vec![f64::INFINITY; n];
290 let mut prev: Vec<Option<usize>> = vec![None; n];
291 let mut heap = BinaryHeap::new();
292
293 dist[source_idx] = 0.0;
294 heap.push(DijkstraEntry {
295 dist: 0.0,
296 node: source_idx,
297 });
298
299 while let Some(entry) = heap.pop() {
300 let u = entry.node;
301 if entry.dist > dist[u] {
302 continue;
303 }
304 if u == target_idx {
305 break;
306 }
307 for &(v, w) in &adj[u] {
308 let nd = dist[u] + w;
309 if nd < dist[v] {
310 dist[v] = nd;
311 prev[v] = Some(u);
312 heap.push(DijkstraEntry { dist: nd, node: v });
313 }
314 }
315 }
316
317 if dist[target_idx].is_infinite() {
318 return None;
319 }
320
321 let mut path = Vec::new();
323 let mut cur = target_idx;
324 loop {
325 let hop_distance = prev[cur]
326 .and_then(|p| adj[p].iter().find(|&&(v, _)| v == cur).map(|&(_, d)| d))
327 .unwrap_or(0.0);
328 path.push(PathStep {
329 id: items[cur].id.clone(),
330 cumulative_distance: dist[cur],
331 hop_distance,
332 category: None,
333 bridge_strength: None,
334 });
335 match prev[cur] {
336 Some(p) => cur = p,
337 None => break,
338 }
339 }
340 path.reverse();
341
342 Some(ConceptPath {
343 total_distance: dist[target_idx],
344 steps: path,
345 })
346 }
347
348 pub fn concept_path_bridged(
365 &self,
366 source_id: &str,
367 target_id: &str,
368 k: usize,
369 categories: &HashMap<&str, usize>,
370 bridge_strengths: &HashMap<(usize, usize), (f64, BridgeClassification)>,
371 ) -> Option<ConceptPath> {
372 let items = self.index.all_items();
373 let n = items.len();
374 if n < 2 {
375 return None;
376 }
377
378 let id_to_idx: HashMap<&str, usize> = items
379 .iter()
380 .enumerate()
381 .map(|(i, item)| (item.id.as_str(), i))
382 .collect();
383
384 let source_idx = *id_to_idx.get(source_id)?;
385 let target_idx = *id_to_idx.get(target_id)?;
386
387 let item_cats: Vec<Option<usize>> = items
389 .iter()
390 .map(|item| categories.get(item.id.as_str()).copied())
391 .collect();
392
393 let adj = self.knn_adjacency(&items, k);
399
400 let mut dist = vec![f64::INFINITY; n];
402 let mut prev: Vec<Option<usize>> = vec![None; n];
403 let mut heap = BinaryHeap::new();
404
405 dist[source_idx] = 0.0;
406 heap.push(DijkstraEntry {
407 dist: 0.0,
408 node: source_idx,
409 });
410
411 while let Some(entry) = heap.pop() {
412 let u = entry.node;
413 if entry.dist > dist[u] {
414 continue;
415 }
416 if u == target_idx {
417 break;
418 }
419 for &(v, raw_d) in &adj[u] {
420 let (w, _) = cross_category_weight(raw_d, &item_cats, u, v, bridge_strengths);
421 let nd = dist[u] + w;
422 if nd < dist[v] {
423 dist[v] = nd;
424 prev[v] = Some(u);
425 heap.push(DijkstraEntry { dist: nd, node: v });
426 }
427 }
428 }
429
430 if dist[target_idx].is_infinite() {
431 return None;
432 }
433
434 let mut path = Vec::new();
436 let mut cur = target_idx;
437 loop {
438 let edge_info = prev[cur].and_then(|p| {
439 adj[p].iter().find(|&&(v, _)| v == cur).map(|&(_, raw_d)| {
440 let (_, bs) =
441 cross_category_weight(raw_d, &item_cats, p, cur, bridge_strengths);
442 (raw_d, bs)
443 })
444 });
445 let hop_distance = edge_info.map_or(0.0, |(d, _)| d);
446 let bridge_str = edge_info.and_then(|(_, bs)| bs);
447
448 path.push(PathStep {
449 id: items[cur].id.clone(),
450 cumulative_distance: dist[cur],
451 hop_distance,
452 category: item_cats[cur],
453 bridge_strength: bridge_str,
454 });
455 match prev[cur] {
456 Some(p) => cur = p,
457 None => break,
458 }
459 }
460 path.reverse();
461
462 Some(ConceptPath {
463 total_distance: dist[target_idx],
464 steps: path,
465 })
466 }
467}
468
469#[derive(Debug, Clone)]
472pub struct ConceptPath {
473 pub steps: Vec<PathStep>,
474 pub total_distance: f64,
475}
476
/// One node along a [`ConceptPath`].
#[derive(Debug, Clone)]
pub struct PathStep {
    /// Id of the item visited at this step.
    pub id: String,
    /// Path cost accumulated from the source up to and including this step.
    pub cumulative_distance: f64,
    /// Raw distance of the edge leading into this step; `0.0` for the first step.
    pub hop_distance: f64,
    /// Category of the item, when the path was computed with category information.
    pub category: Option<usize>,
    /// Bridge strength of the edge leading into this step, set only when that edge
    /// crosses between two different categories.
    pub bridge_strength: Option<f64>,
}
489
/// Priority-queue entry for Dijkstra: a node with its tentative distance.
///
/// The ordering is reversed on `dist` so that `BinaryHeap` (a max-heap) pops the
/// smallest distance first. Ties are broken on `node`, which keeps `Ord`
/// consistent with `Eq`: two entries compare `Equal` only when both fields match.
struct DijkstraEntry {
    /// Tentative distance from the source to `node`.
    dist: f64,
    /// Index of the node this entry refers to.
    node: usize,
}

impl PartialEq for DijkstraEntry {
    /// Equality is derived from `cmp`, so it is a total equivalence relation even
    /// when `dist` is NaN.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for DijkstraEntry {}

impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraEntry {
    /// Smaller `dist` compares as greater (min-heap behaviour); `total_cmp` gives
    /// NaN a definite place instead of collapsing it to `Equal`. For equal
    /// distances the smaller node index compares as greater.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}
515
/// A plane fitted through a 3-D point set by principal component analysis.
#[derive(Debug, Clone)]
pub struct SlicingManifold {
    /// Mean of the fitted points; the origin of the plane's 2-D coordinates.
    pub centroid: [f64; 3],
    /// Eigenvector of the covariance with the smallest eigenvalue (plane normal).
    pub normal: [f64; 3],
    /// Eigenvector with the largest eigenvalue (first in-plane axis).
    pub basis_u: [f64; 3],
    /// Eigenvector with the second-largest eigenvalue (second in-plane axis).
    pub basis_v: [f64; 3],
    /// Share of total variance carried by the two in-plane axes; `1.0` when the
    /// total variance is not positive.
    pub variance_ratio: f64,
}
534
535impl SlicingManifold {
536 pub fn fit(points: &[[f64; 3]]) -> Self {
542 let n = points.len() as f64;
543 debug_assert!(n >= 3.0, "need at least 3 points to fit a plane");
544
545 let mut c = [0.0; 3];
547 for p in points {
548 for i in 0..3 {
549 c[i] += p[i];
550 }
551 }
552 for ci in &mut c {
553 *ci /= n;
554 }
555
556 let mut cov = [[0.0f64; 3]; 3];
558 for p in points {
559 let d = [p[0] - c[0], p[1] - c[1], p[2] - c[2]];
560 for i in 0..3 {
561 for j in 0..3 {
562 cov[i][j] += d[i] * d[j];
563 }
564 }
565 }
566 for row in &mut cov {
567 for v in row.iter_mut() {
568 *v /= n;
569 }
570 }
571
572 let (eigenvalues, eigenvectors) = eigen_symmetric_3x3(&cov);
574
575 let total_var = eigenvalues[0] + eigenvalues[1] + eigenvalues[2];
578 let variance_ratio = if total_var > 0.0 {
579 (eigenvalues[0] + eigenvalues[1]) / total_var
580 } else {
581 1.0
582 };
583
584 Self {
585 centroid: c,
586 normal: eigenvectors[2],
587 basis_u: eigenvectors[0],
588 basis_v: eigenvectors[1],
589 variance_ratio,
590 }
591 }
592
593 pub fn project_2d(&self, point: &[f64; 3]) -> (f64, f64) {
595 let d = [
596 point[0] - self.centroid[0],
597 point[1] - self.centroid[1],
598 point[2] - self.centroid[2],
599 ];
600 let u = d[0] * self.basis_u[0] + d[1] * self.basis_u[1] + d[2] * self.basis_u[2];
601 let v = d[0] * self.basis_v[0] + d[1] * self.basis_v[1] + d[2] * self.basis_v[2];
602 (u, v)
603 }
604
605 pub fn distance(&self, point: &[f64; 3]) -> f64 {
607 let d = [
608 point[0] - self.centroid[0],
609 point[1] - self.centroid[1],
610 point[2] - self.centroid[2],
611 ];
612 d[0] * self.normal[0] + d[1] * self.normal[1] + d[2] * self.normal[2]
613 }
614
615 pub fn fit_local(query: &[f64; 3], all_points: &[[f64; 3]], k: usize) -> Self {
626 let mut dists: Vec<(usize, f64)> = all_points
627 .iter()
628 .enumerate()
629 .map(|(i, p)| (i, dist3(query, p)))
630 .collect();
631 dists.sort_by(|a, b| a.1.total_cmp(&b.1));
637
638 let neighborhood: Vec<[f64; 3]> = dists
639 .iter()
640 .take(k.max(3))
641 .map(|&(i, _)| all_points[i])
642 .collect();
643
644 Self::fit(&neighborhood)
645 }
646}
647
/// Eigen-decomposition of a symmetric 3×3 matrix by Jacobi rotations.
///
/// Each iteration zeroes the largest off-diagonal entry; iteration stops after 50
/// rotations or once that entry falls below `1e-15`.
///
/// Returns `(eigenvalues, eigenvectors)` sorted by descending eigenvalue, where
/// `eigenvectors[i]` is the vector belonging to `eigenvalues[i]`.
#[allow(clippy::needless_range_loop)]
fn eigen_symmetric_3x3(m: &[[f64; 3]; 3]) -> ([f64; 3], [[f64; 3]; 3]) {
    const MAX_ROTATIONS: usize = 50;
    const UPPER_TRIANGLE: [(usize, usize); 3] = [(0, 1), (0, 2), (1, 2)];

    let mut a = *m;
    // Columns of `v` accumulate the eigenvectors.
    let mut v: [[f64; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    for _ in 0..MAX_ROTATIONS {
        // Pivot on the off-diagonal entry with the largest magnitude.
        let (mut p, mut q) = (0, 1);
        let mut largest = a[0][1].abs();
        for &(i, j) in &UPPER_TRIANGLE {
            let magnitude = a[i][j].abs();
            if magnitude > largest {
                largest = magnitude;
                p = i;
                q = j;
            }
        }
        if largest < 1e-15 {
            break;
        }

        // Rotation angle that annihilates a[p][q]; equal diagonals give 45 degrees.
        let diag_gap = a[p][p] - a[q][q];
        let theta = if diag_gap.abs() < 1e-30 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * a[p][q] / diag_gap).atan()
        };
        let cos_t = theta.cos();
        let sin_t = theta.sin();

        // A <- A * R (mix columns p and q).
        let before = a;
        for row in 0..3 {
            a[row][p] = cos_t * before[row][p] + sin_t * before[row][q];
            a[row][q] = -sin_t * before[row][p] + cos_t * before[row][q];
        }
        // A <- R^T * A (mix rows p and q).
        let mixed = a;
        for col in 0..3 {
            a[p][col] = cos_t * mixed[p][col] + sin_t * mixed[q][col];
            a[q][col] = -sin_t * mixed[p][col] + cos_t * mixed[q][col];
        }
        // The pivot is zero by construction; set it exactly to drop rounding residue.
        a[p][q] = 0.0;
        a[q][p] = 0.0;

        // V <- V * R.
        for row in v.iter_mut() {
            let (x, y) = (row[p], row[q]);
            row[p] = cos_t * x + sin_t * y;
            row[q] = -sin_t * x + cos_t * y;
        }
    }

    let eigenvalues = [a[0][0], a[1][1], a[2][2]];

    // Stable sort of the column indices by descending eigenvalue.
    let mut order = [0usize, 1, 2];
    order.sort_by(|&lhs, &rhs| eigenvalues[rhs].total_cmp(&eigenvalues[lhs]));

    let sorted_vals = order.map(|col| eigenvalues[col]);
    let sorted_vecs = order.map(|col| [v[0][col], v[1][col], v[2][col]]);

    (sorted_vals, sorted_vecs)
}
728
/// One cluster of points found by [`GlobResult::detect`].
#[derive(Debug, Clone)]
pub struct ConceptGlob {
    /// Cluster label (index into the k-means assignment space).
    pub id: usize,
    /// Normalized sum of the members' unit vectors.
    pub centroid: [f64; 3],
    /// Ids of the member points, in input order.
    pub member_ids: Vec<String>,
    /// Angular distance from each member to `centroid`, parallel to `member_ids`.
    pub member_distances: Vec<f64>,
    /// Largest entry of `member_distances` (at least `0.0`).
    pub radius: f64,
}
740
741#[derive(Debug, Clone)]
743pub struct GlobResult {
744 pub globs: Vec<ConceptGlob>,
745 pub k: usize,
746 pub silhouette: f64,
747}
748
749impl GlobResult {
750 pub fn detect(points: &[[f64; 3]], ids: &[String], k: Option<usize>, max_k: usize) -> Self {
755 let n = points.len();
756 debug_assert_eq!(n, ids.len());
759 debug_assert!(n >= 2, "need at least 2 points for clustering");
760
761 let max_k = max_k.max(2).min(n);
765
766 if let Some(k) = k {
767 let k = k.clamp(2, max_k);
768 let (assignments, silhouette) = kmeans_3d(points, k);
769 let globs = build_globs(points, ids, &assignments, k);
770 return Self {
771 globs,
772 k,
773 silhouette,
774 };
775 }
776
777 let mut best_k = 2;
779 let mut best_sil = f64::NEG_INFINITY;
780 let mut best_assignments = vec![0usize; n];
781
782 for trial_k in 2..=max_k {
783 let (assignments, sil) = kmeans_3d(points, trial_k);
784 if sil > best_sil {
785 best_sil = sil;
786 best_k = trial_k;
787 best_assignments = assignments;
788 }
789 }
790
791 let globs = build_globs(points, ids, &best_assignments, best_k);
792 Self {
793 globs,
794 k: best_k,
795 silhouette: best_sil,
796 }
797 }
798}
799
800fn kmeans_3d(points: &[[f64; 3]], k: usize) -> (Vec<usize>, f64) {
801 let n = points.len();
802 let max_iter = 50;
803
804 let mut centers: Vec<[f64; 3]> = (0..k).map(|i| points[i * n / k]).collect();
806
807 let mut assignments = vec![0usize; n];
808
809 for _ in 0..max_iter {
810 let mut changed = false;
811
812 for (i, p) in points.iter().enumerate() {
814 let mut best = 0;
815 let mut best_d = f64::MAX;
816 for (j, c) in centers.iter().enumerate() {
817 let d = angular_dist3(p, c);
818 if d < best_d {
819 best_d = d;
820 best = j;
821 }
822 }
823 if assignments[i] != best {
824 assignments[i] = best;
825 changed = true;
826 }
827 }
828
829 if !changed {
830 break;
831 }
832
833 let mut sums = vec![[0.0f64; 3]; k];
836 let mut counts = vec![0usize; k];
837 for (i, &a) in assignments.iter().enumerate() {
838 let norm_p = normalize3(&points[i]);
839 for (d, &np) in norm_p.iter().enumerate() {
840 sums[a][d] += np;
841 }
842 counts[a] += 1;
843 }
844 for j in 0..k {
845 if counts[j] > 0 {
846 centers[j] = normalize3(&sums[j]);
847 }
848 }
849 }
850
851 let sil = silhouette_score(points, &assignments, k);
852 (assignments, sil)
853}
854
855fn silhouette_score(points: &[[f64; 3]], assignments: &[usize], k: usize) -> f64 {
856 let n = points.len();
857 if n <= 1 || k <= 1 {
858 return 0.0;
859 }
860
861 let mut total = 0.0;
862 for i in 0..n {
863 let ci = assignments[i];
864
865 let mut a_sum = 0.0;
867 let mut a_cnt = 0;
868 for j in 0..n {
869 if j != i && assignments[j] == ci {
870 a_sum += angular_dist3(&points[i], &points[j]);
871 a_cnt += 1;
872 }
873 }
874 let a = if a_cnt > 0 { a_sum / a_cnt as f64 } else { 0.0 };
875
876 let mut b = f64::MAX;
878 for ck in 0..k {
879 if ck == ci {
880 continue;
881 }
882 let mut b_sum = 0.0;
883 let mut b_cnt = 0;
884 for j in 0..n {
885 if assignments[j] == ck {
886 b_sum += angular_dist3(&points[i], &points[j]);
887 b_cnt += 1;
888 }
889 }
890 if b_cnt > 0 {
891 b = b.min(b_sum / b_cnt as f64);
892 }
893 }
894 if b == f64::MAX {
895 b = 0.0;
896 }
897
898 let denom = a.max(b);
899 total += if denom > 0.0 { (b - a) / denom } else { 0.0 };
900 }
901
902 total / n as f64
903}
904
905fn build_globs(
906 points: &[[f64; 3]],
907 ids: &[String],
908 assignments: &[usize],
909 k: usize,
910) -> Vec<ConceptGlob> {
911 let mut globs = Vec::with_capacity(k);
912
913 for cluster_id in 0..k {
914 let member_indices: Vec<usize> = assignments
915 .iter()
916 .enumerate()
917 .filter(|&(_, &a)| a == cluster_id)
918 .map(|(i, _)| i)
919 .collect();
920
921 if member_indices.is_empty() {
922 continue;
923 }
924
925 let mut centroid = [0.0; 3];
927 for &i in &member_indices {
928 let norm_p = normalize3(&points[i]);
929 for (d, c) in centroid.iter_mut().enumerate() {
930 *c += norm_p[d];
931 }
932 }
933 centroid = normalize3(¢roid);
934
935 let member_distances: Vec<f64> = member_indices
937 .iter()
938 .map(|&i| angular_dist3(&points[i], ¢roid))
939 .collect();
940
941 let radius = member_distances.iter().cloned().fold(0.0f64, f64::max);
942
943 let member_ids: Vec<String> = member_indices.iter().map(|&i| ids[i].clone()).collect();
944
945 globs.push(ConceptGlob {
946 id: cluster_id,
947 centroid,
948 member_ids,
949 member_distances,
950 radius,
951 });
952 }
953
954 globs
955}
956
/// Euclidean distance between two points in 3-D space.
fn dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;
    let (dx, dy, dz) = (ax - bx, ay - by, az - bz);
    let squared = dx * dx + dy * dy + dz * dz;
    squared.sqrt()
}
963
/// Angle in radians between two 3-D vectors, in `[0, π]`.
///
/// Returns `0.0` when the product of the two magnitudes is below `f64::EPSILON`,
/// i.e. when a vector is (nearly) zero and the angle is undefined.
fn angular_dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;

    let dot = ax * bx + ay * by + az * bz;
    let len_a = (ax * ax + ay * ay + az * az).sqrt();
    let len_b = (bx * bx + by * by + bz * bz).sqrt();

    let scale = len_a * len_b;
    if scale < f64::EPSILON {
        0.0
    } else {
        // Clamp guards `acos` against cosines pushed just outside [-1, 1] by rounding.
        let cosine = (dot / scale).clamp(-1.0, 1.0);
        cosine.acos()
    }
}
976
/// Scales `v` to unit length; vectors shorter than `f64::EPSILON` become all zeros.
fn normalize3(v: &[f64; 3]) -> [f64; 3] {
    let [x, y, z] = *v;
    let length = (x * x + y * y + z * z).sqrt();
    if length < f64::EPSILON {
        [0.0; 3]
    } else {
        v.map(|component| component / length)
    }
}
985
986fn cross_category_weight(
1001 angular_dist: f64,
1002 item_cats: &[Option<usize>],
1003 i: usize,
1004 j: usize,
1005 bridge_strengths: &HashMap<(usize, usize), (f64, BridgeClassification)>,
1006) -> (f64, Option<f64>) {
1007 match (item_cats[i], item_cats[j]) {
1008 (Some(ci), Some(cj)) if ci != cj => {
1009 let (strength, classification) = bridge_strengths
1010 .get(&(ci, cj))
1011 .or_else(|| bridge_strengths.get(&(cj, ci)))
1012 .copied()
1013 .unwrap_or((0.0, BridgeClassification::Weak));
1014 let weight = match classification {
1015 BridgeClassification::Genuine => angular_dist / (strength + 0.1),
1016 BridgeClassification::OverlapArtifact => angular_dist * 2.0,
1017 BridgeClassification::Weak => angular_dist / (strength + 0.01),
1018 };
1019 (weight, Some(strength))
1020 }
1021 _ => (angular_dist, None),
1022 }
1023}
1024
/// Namespace for constructors that turn embedding-level queries into [`Region`]s.
pub struct SemanticQuery;
1027
1028impl SemanticQuery {
1029 pub fn within_angle<P: Projection>(
1031 query: &Embedding,
1032 projection: &P,
1033 max_angular_distance: f64,
1034 ) -> Region {
1035 let point = projection.project(query);
1036 let half_angle = max_angular_distance.clamp(1e-10, std::f64::consts::PI);
1038 Region::Cap(
1039 Cap::new(
1040 SphericalPoint::new_unchecked(1.0, point.theta, point.phi),
1041 half_angle,
1042 )
1043 .expect("clamped half_angle is always a valid Cap angle"),
1045 )
1046 }
1047
1048 pub fn above_similarity<P: Projection>(
1051 query: &Embedding,
1052 projection: &P,
1053 min_similarity: f64,
1054 ) -> Region {
1055 let half_angle = min_similarity.clamp(-1.0, 1.0).acos();
1056 Self::within_angle(query, projection, half_angle)
1057 }
1058
1059 pub fn in_shell(inner: f64, outer: f64) -> Region {
1066 Region::Shell(
1067 Shell::new(inner, outer).expect("inner must be <= outer and both non-negative"),
1068 )
1069 }
1070
1071 pub fn similar_in_shell<P: Projection>(
1074 query: &Embedding,
1075 projection: &P,
1076 min_similarity: f64,
1077 shell_inner: f64,
1078 shell_outer: f64,
1079 ) -> Region {
1080 Region::intersection(vec![
1081 Self::above_similarity(query, projection, min_similarity),
1082 Self::in_shell(shell_inner, shell_outer),
1083 ])
1084 }
1085}
1086
1087#[cfg(test)]
1088mod tests {
1089 use super::*;
1090 use crate::projection::{PcaProjection, RandomProjection};
1091 use crate::types::RadialStrategy;
1092 use sphereql_core::angular_distance;
1093
1094 fn emb(vals: &[f64]) -> Embedding {
1095 Embedding::new(vals.to_vec())
1096 }
1097
1098 fn test_corpus() -> Vec<Embedding> {
1099 vec![
1100 emb(&[1.0, 0.0, 0.0, 0.1, 0.0]),
1101 emb(&[0.0, 1.0, 0.0, 0.0, 0.1]),
1102 emb(&[0.0, 0.0, 1.0, 0.1, 0.0]),
1103 emb(&[1.0, 1.0, 0.0, 0.05, 0.05]),
1104 emb(&[-1.0, 0.0, 0.0, -0.1, 0.0]),
1105 emb(&[0.0, -1.0, 0.0, 0.0, -0.1]),
1106 ]
1107 }
1108
1109 #[test]
1112 fn insert_and_get() {
1113 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1114 let mut idx = EmbeddingIndex::builder(rp)
1115 .theta_divisions(4)
1116 .phi_divisions(3)
1117 .build();
1118
1119 idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1120 idx.insert("b", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
1121
1122 assert_eq!(idx.len(), 2);
1123 assert!(!idx.is_empty());
1124 assert!(idx.get("a").is_some());
1125 assert!(idx.get("b").is_some());
1126 assert!(idx.get("c").is_none());
1127 }
1128
1129 #[test]
1130 fn remove() {
1131 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1132 let mut idx = EmbeddingIndex::builder(rp).build();
1133
1134 idx.insert("a", &emb(&[1.0; 5]));
1135 assert_eq!(idx.len(), 1);
1136
1137 let removed = idx.remove("a");
1138 assert!(removed.is_some());
1139 assert_eq!(removed.unwrap().id, "a");
1140 assert_eq!(idx.len(), 0);
1141 assert!(idx.get("a").is_none());
1142 }
1143
1144 #[test]
1145 fn remove_nonexistent() {
1146 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1147 let mut idx = EmbeddingIndex::builder(rp).build();
1148 assert!(idx.remove("nope").is_none());
1149 }
1150
1151 #[test]
1152 fn search_nearest_returns_sorted() {
1153 let corpus = test_corpus();
1154 let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1155 let mut idx = EmbeddingIndex::builder(pca)
1156 .theta_divisions(4)
1157 .phi_divisions(3)
1158 .build();
1159
1160 for (i, e) in corpus.iter().enumerate() {
1161 idx.insert(format!("item-{i}"), e);
1162 }
1163
1164 let query = emb(&[0.95, 0.1, 0.0, 0.05, 0.0]);
1165 let results = idx.search_nearest(&query, 3);
1166
1167 assert_eq!(results.len(), 3);
1168 assert!(results[0].distance <= results[1].distance);
1169 assert!(results[1].distance <= results[2].distance);
1170 }
1171
1172 #[test]
1173 fn search_similar_respects_threshold() {
1174 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1175 let mut idx = EmbeddingIndex::builder(rp)
1176 .theta_divisions(4)
1177 .phi_divisions(3)
1178 .build();
1179
1180 idx.insert("close_a", &emb(&[1.0, 0.1, 0.0, 0.0, 0.0]));
1181 idx.insert("close_b", &emb(&[0.9, 0.2, 0.0, 0.0, 0.0]));
1182 idx.insert("far", &emb(&[-1.0, 0.0, 0.0, 0.0, 0.0]));
1183
1184 let query = emb(&[1.0, 0.0, 0.0, 0.0, 0.0]);
1185 let projected_query = idx.projection().project(&query);
1186 let result = idx.search_similar(&query, 0.5);
1187
1188 let max_angle = 0.5_f64.acos();
1189 for item in &result.items {
1190 let d = angular_distance(&projected_query, item.position());
1191 assert!(d <= max_angle + 1e-10, "item {} too far: {d}", item.id);
1192 }
1193 }
1194
1195 #[test]
1196 fn insert_with_radius_overrides() {
1197 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1198 let mut idx = EmbeddingIndex::builder(rp).build();
1199
1200 idx.insert_with_radius("custom", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]), 42.0);
1201 let item = idx.get("custom").unwrap();
1202 assert!((item.position.r - 42.0).abs() < 1e-12);
1203 }
1204
1205 #[test]
1206 fn original_magnitude_stored() {
1207 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1208 let mut idx = EmbeddingIndex::builder(rp).build();
1209
1210 let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0]);
1211 idx.insert("vec", &e);
1212 let item = idx.get("vec").unwrap();
1213 assert!((item.original_magnitude - 5.0).abs() < 1e-10);
1214 }
1215
1216 #[test]
1217 fn magnitude_radial_with_shell_query() {
1218 let corpus = test_corpus();
1219 let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude).unwrap();
1220 let mut idx = EmbeddingIndex::builder(pca)
1221 .uniform_shells(5, 10.0)
1222 .theta_divisions(4)
1223 .phi_divisions(3)
1224 .build();
1225
1226 idx.insert("small", &emb(&[0.1, 0.0, 0.0, 0.0, 0.0]));
1227 idx.insert("medium", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1228 idx.insert("large", &emb(&[5.0, 0.0, 0.0, 0.0, 0.0]));
1229
1230 let shell = Shell::new(0.5, 2.0).unwrap();
1231 let result = idx.search_region(&Region::Shell(shell));
1232
1233 let ids: Vec<&str> = result.items.iter().map(|i| i.id.as_str()).collect();
1234 assert!(
1235 ids.contains(&"medium"),
1236 "medium (mag=1.0) should be in [0.5, 2.0]"
1237 );
1238 assert!(
1239 !ids.contains(&"large"),
1240 "large (mag=5.0) should not be in [0.5, 2.0]"
1241 );
1242 }
1243
1244 #[test]
1245 fn empty_index() {
1246 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1247 let idx = EmbeddingIndex::builder(rp).build();
1248
1249 assert!(idx.is_empty());
1250 assert_eq!(idx.len(), 0);
1251 assert!(idx.get("x").is_none());
1252
1253 let results = idx.search_nearest(&emb(&[1.0; 5]), 5);
1254 assert!(results.is_empty());
1255 }
1256
1257 #[test]
1258 fn projection_accessor() {
1259 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1260 let idx = EmbeddingIndex::builder(rp).build();
1261 assert_eq!(idx.projection().dimensionality(), 5);
1262 }
1263
1264 #[test]
1267 fn above_similarity_creates_cap() {
1268 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1269 let region = SemanticQuery::above_similarity(&emb(&[1.0; 5]), &rp, 0.8);
1270 assert!(matches!(region, Region::Cap(_)));
1271 }
1272
1273 #[test]
1274 fn within_angle_creates_cap() {
1275 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1276 let region = SemanticQuery::within_angle(&emb(&[1.0; 5]), &rp, 0.5);
1277 assert!(matches!(region, Region::Cap(_)));
1278 }
1279
1280 #[test]
1281 fn in_shell_creates_shell() {
1282 let region = SemanticQuery::in_shell(1.0, 5.0);
1283 assert!(matches!(region, Region::Shell(_)));
1284 }
1285
1286 #[test]
1287 fn similar_in_shell_creates_intersection() {
1288 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1289 let region = SemanticQuery::similar_in_shell(&emb(&[1.0; 5]), &rp, 0.7, 1.0, 5.0);
1290
1291 match region {
1292 Region::Intersection(parts) => {
1293 assert_eq!(parts.len(), 2);
1294 assert!(matches!(parts[0], Region::Cap(_)));
1295 assert!(matches!(parts[1], Region::Shell(_)));
1296 }
1297 other => panic!("expected Intersection, got {other:?}"),
1298 }
1299 }
1300
1301 #[test]
1302 fn semantic_query_region_used_in_index() {
1303 let corpus = test_corpus();
1304 let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1305 let projection_clone = pca.clone();
1306 let mut idx = EmbeddingIndex::builder(pca)
1307 .theta_divisions(4)
1308 .phi_divisions(3)
1309 .build();
1310
1311 for (i, e) in corpus.iter().enumerate() {
1312 idx.insert(format!("item-{i}"), e);
1313 }
1314
1315 let query = emb(&[1.0, 0.0, 0.0, 0.05, 0.0]);
1316 let region = SemanticQuery::above_similarity(&query, &projection_clone, 0.5);
1317 let result = idx.search_region(®ion);
1318
1319 for item in &result.items {
1320 assert!(region.contains(item.position()));
1321 }
1322 }
1323
1324 #[test]
1327 fn concept_path_populates_hop_distance() {
1328 let corpus = test_corpus();
1329 let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1330 let mut idx = EmbeddingIndex::builder(pca)
1331 .theta_divisions(4)
1332 .phi_divisions(3)
1333 .build();
1334
1335 for (i, e) in corpus.iter().enumerate() {
1336 idx.insert(format!("item-{i}"), e);
1337 }
1338
1339 if let Some(path) = idx.concept_path("item-0", "item-4", 3) {
1340 assert!(path.steps[0].hop_distance == 0.0, "first step has no hop");
1341 for step in &path.steps[1..] {
1342 assert!(
1343 step.hop_distance > 0.0,
1344 "subsequent steps should have a hop distance"
1345 );
1346 }
1347 assert!(path.steps.iter().all(|s| s.category.is_none()));
1348 assert!(path.steps.iter().all(|s| s.bridge_strength.is_none()));
1349 }
1350 }
1351
1352 #[test]
1355 fn concept_path_bridged_same_category_equals_unbridged() {
1356 let corpus = test_corpus();
1357 let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1358 let mut idx = EmbeddingIndex::builder(pca)
1359 .theta_divisions(4)
1360 .phi_divisions(3)
1361 .build();
1362
1363 for (i, e) in corpus.iter().enumerate() {
1364 idx.insert(format!("item-{i}"), e);
1365 }
1366
1367 let categories: HashMap<&str, usize> = (0..6)
1369 .map(|i| {
1370 (
1371 ["item-0", "item-1", "item-2", "item-3", "item-4", "item-5"][i],
1372 0,
1373 )
1374 })
1375 .collect();
1376 let bridges = HashMap::new();
1377
1378 let unbridged = idx.concept_path("item-0", "item-3", 3);
1379 let bridged = idx.concept_path_bridged("item-0", "item-3", 3, &categories, &bridges);
1380
1381 match (unbridged, bridged) {
1382 (Some(u), Some(b)) => {
1383 assert_eq!(u.steps.len(), b.steps.len());
1384 assert!((u.total_distance - b.total_distance).abs() < 1e-10);
1385 for step in &b.steps {
1386 assert_eq!(step.category, Some(0));
1387 assert!(step.bridge_strength.is_none());
1388 }
1389 }
1390 (None, None) => {} _ => panic!("bridged and unbridged should agree on reachability"),
1392 }
1393 }
1394
/// Routes "a0" -> "b0" across two categories twice, once with a 0.05 `Weak`
/// bridge and once with a 0.95 `Genuine` bridge between them. If both routes
/// exist, the weak-bridge route must have the strictly larger total distance.
#[test]
fn concept_path_bridged_penalizes_weak_bridges() {
    let projection = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
    let mut idx = EmbeddingIndex::builder(projection)
        .theta_divisions(4)
        .phi_divisions(3)
        .build();

    // Category 0 items lean on the first axis, category 1 items on the second.
    idx.insert("a0", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
    idx.insert("a1", &emb(&[0.9, 0.1, 0.0, 0.0, 0.0]));
    idx.insert("b0", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
    idx.insert("b1", &emb(&[0.1, 0.9, 0.0, 0.0, 0.0]));

    let categories: HashMap<&str, usize> = [("a0", 0), ("a1", 0), ("b0", 1), ("b1", 1)]
        .iter()
        .copied()
        .collect();

    let weak_bridges = HashMap::from([((0, 1), (0.05, BridgeClassification::Weak))]);
    let strong_bridges = HashMap::from([((0, 1), (0.95, BridgeClassification::Genuine))]);

    let weak_cost = idx
        .concept_path_bridged("a0", "b0", 3, &categories, &weak_bridges)
        .map(|path| path.total_distance);
    let strong_cost = idx
        .concept_path_bridged("a0", "b0", 3, &categories, &strong_bridges)
        .map(|path| path.total_distance);

    // Only comparable when both bridge tables yield a route.
    if let (Some(weak_cost), Some(strong_cost)) = (weak_cost, strong_cost) {
        assert!(
            weak_cost > strong_cost,
            "weak bridge ({:.4}) should produce higher cost than strong ({:.4})",
            weak_cost,
            strong_cost
        );
    }
}
1439
/// For a route from "a" (category 0) to "c" (category 1), every step of a
/// found path must carry a category, and at least one step must carry a
/// bridge strength.
#[test]
fn concept_path_bridged_populates_bridge_metadata() {
    let projection = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
    let mut idx = EmbeddingIndex::builder(projection)
        .theta_divisions(4)
        .phi_divisions(3)
        .build();

    idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
    idx.insert("b", &emb(&[0.5, 0.5, 0.0, 0.0, 0.0]));
    idx.insert("c", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));

    let categories: HashMap<&str, usize> =
        [("a", 0), ("b", 0), ("c", 1)].iter().copied().collect();
    let bridges = HashMap::from([((0, 1), (0.7, BridgeClassification::Genuine))]);

    if let Some(path) = idx.concept_path_bridged("a", "c", 3, &categories, &bridges) {
        // Every inserted id has a category entry, so no step may lack one.
        assert!(path.steps.iter().all(|step| step.category.is_some()));

        // Start and goal are in different categories.
        assert!(
            path.steps.iter().any(|step| step.bridge_strength.is_some()),
            "should record bridge strength on cross-category hop"
        );
    }
}
1473
/// Both endpoints are in category 0: the returned weight must equal the
/// input 0.5 and no bridge strength may be reported.
#[test]
fn cross_category_weight_same_category() {
    let node_categories = vec![Some(0), Some(0)];
    let no_bridges = HashMap::new();

    let (weight, strength) = cross_category_weight(0.5, &node_categories, 0, 1, &no_bridges);

    let error = (weight - 0.5).abs();
    assert!(error < 1e-10);
    assert_eq!(strength, None);
}
1484
/// Endpoints are in categories 0 and 1 with an empty bridge table: the
/// returned weight must be 50.0 for an input of 0.5, and the reported bridge
/// strength must be `Some(0.0)`.
#[test]
fn cross_category_weight_different_categories_no_bridge() {
    let node_categories = vec![Some(0), Some(1)];
    let no_bridges = HashMap::new();

    let (weight, strength) = cross_category_weight(0.5, &node_categories, 0, 1, &no_bridges);

    let error = (weight - 50.0).abs();
    assert!(error < 1e-10);
    assert_eq!(strength, Some(0.0));
}
1494
/// Endpoints are in categories 0 and 1, joined by a 0.9 `Genuine` bridge:
/// the returned weight must equal the input 0.5 and the reported strength
/// must be `Some(0.9)`.
#[test]
fn cross_category_weight_genuine_bridge() {
    let node_categories = vec![Some(0), Some(1)];
    let bridge_table = HashMap::from([((0, 1), (0.9, BridgeClassification::Genuine))]);

    let (weight, strength) = cross_category_weight(0.5, &node_categories, 0, 1, &bridge_table);

    let error = (weight - 0.5).abs();
    assert!(error < 1e-10);
    assert_eq!(strength, Some(0.9));
}
1505
/// Endpoints are in categories 0 and 1, joined by a 0.3 `Weak` bridge: the
/// returned weight must be `0.5 / 0.31` for an input of 0.5, and the reported
/// strength must be `Some(0.3)`.
#[test]
fn cross_category_weight_weak_bridge() {
    let node_categories = vec![Some(0), Some(1)];
    let bridge_table = HashMap::from([((0, 1), (0.3, BridgeClassification::Weak))]);

    let (weight, strength) = cross_category_weight(0.5, &node_categories, 0, 1, &bridge_table);

    let expected = 0.5 / 0.31;
    assert!((weight - expected).abs() < 1e-10);
    assert_eq!(strength, Some(0.3));
}
1516
/// Endpoints are in categories 0 and 1, joined by a 0.8 `OverlapArtifact`
/// bridge: the returned weight must be 1.0 for an input of 0.5, and the
/// reported strength must be `Some(0.8)`.
#[test]
fn cross_category_weight_overlap_artifact_discouraged() {
    let node_categories = vec![Some(0), Some(1)];
    let bridge_table = HashMap::from([((0, 1), (0.8, BridgeClassification::OverlapArtifact))]);

    let (weight, strength) = cross_category_weight(0.5, &node_categories, 0, 1, &bridge_table);

    let error = (weight - 1.0).abs();
    assert!(error < 1e-10);
    assert_eq!(strength, Some(0.8));
}
1527
/// The first endpoint has no category (`None`): the returned weight must
/// equal the input 0.5 and no bridge strength may be reported.
#[test]
fn cross_category_weight_no_category_info() {
    let node_categories = vec![None, Some(1)];
    let no_bridges = HashMap::new();

    let (weight, strength) = cross_category_weight(0.5, &node_categories, 0, 1, &no_bridges);

    let error = (weight - 0.5).abs();
    assert!(error < 1e-10);
    assert_eq!(strength, None);
}
1536}