1use std::collections::{BinaryHeap, HashMap};
2use std::sync::{Arc, Mutex};
3
4use sphereql_core::*;
5use sphereql_index::*;
6
7use crate::category::BridgeClassification;
8use crate::projection::Projection;
9use crate::types::{Embedding, ProjectedPoint};
10
/// Memoized k-nearest-neighbour graph used by the concept-path queries.
///
/// Node indices refer to positions in the `all_items()` list the graph was built
/// from. `EmbeddingIndex` clears this cache on every insert/remove so that those
/// indices never refer to a different item order.
struct KnnCache {
    /// Neighbour count the graph was built with; a query with a different `k`
    /// triggers a rebuild.
    k: usize,
    /// Symmetric adjacency lists: `adj[i]` holds `(j, distance)` for every
    /// neighbour `j` of node `i`. Wrapped in `Arc` so callers can keep using the
    /// graph after the cache lock has been released.
    adj: Arc<Vec<Vec<(usize, f64)>>>,
}
20
/// An embedding stored in an `EmbeddingIndex`, placed on the sphere by the
/// index's projection.
#[derive(Debug, Clone)]
pub struct EmbeddingItem {
    /// Caller-supplied identifier; used by the index for `get`/`remove` lookups.
    pub id: String,
    /// Spherical position used by every spatial query. Items added through
    /// `EmbeddingIndex::insert`/`insert_with_radius` store the same point in
    /// `projected.position`.
    pub position: SphericalPoint,
    /// Magnitude of the original high-dimensional embedding (`Embedding::magnitude`).
    pub original_magnitude: f64,
    /// Full projection output, if available. The accessor methods fall back to
    /// defaults when this is `None`.
    pub projected: Option<ProjectedPoint>,
}
29
30impl SpatialItem for EmbeddingItem {
31 type Id = String;
32 fn id(&self) -> &String {
33 &self.id
34 }
35 fn position(&self) -> &SphericalPoint {
36 &self.position
37 }
38}
39
40impl EmbeddingItem {
41 pub fn certainty(&self) -> f64 {
43 self.projected.map_or(1.0, |p| p.certainty)
44 }
45
46 pub fn intensity(&self) -> f64 {
48 self.projected
49 .map_or(self.original_magnitude, |p| p.intensity)
50 }
51
52 pub fn projection_magnitude(&self) -> f64 {
55 self.projected.map_or(1.0, |p| p.projection_magnitude)
56 }
57}
58
/// Builder for `EmbeddingIndex`: pairs a projection with the settings of the
/// underlying `SpatialIndexBuilder`.
pub struct EmbeddingIndexBuilder<P> {
    /// Projection moved into the built index and applied to every embedding.
    projection: P,
    /// Spatial-index builder that receives the shell/division settings.
    inner: SpatialIndexBuilder,
}
63
64impl<P: Projection> EmbeddingIndexBuilder<P> {
65 pub fn new(projection: P) -> Self {
66 Self {
67 projection,
68 inner: SpatialIndexBuilder::new(),
69 }
70 }
71
72 pub fn shell_boundary(mut self, r: f64) -> Self {
73 self.inner = self.inner.shell_boundary(r);
74 self
75 }
76
77 pub fn uniform_shells(mut self, count: usize, max_r: f64) -> Self {
78 self.inner = self.inner.uniform_shells(count, max_r);
79 self
80 }
81
82 pub fn theta_divisions(mut self, n: usize) -> Self {
83 self.inner = self.inner.theta_divisions(n);
84 self
85 }
86
87 pub fn phi_divisions(mut self, n: usize) -> Self {
88 self.inner = self.inner.phi_divisions(n);
89 self
90 }
91
92 pub fn build(self) -> EmbeddingIndex<P> {
93 EmbeddingIndex {
94 projection: self.projection,
95 index: self.inner.build(),
96 knn_cache: Mutex::new(None),
97 }
98 }
99}
100
/// Spatial index over embeddings projected onto a sphere by `P`.
///
/// Supports nearest-neighbour, similarity, region and concept-path queries.
pub struct EmbeddingIndex<P> {
    /// Projection applied to inserted embeddings and to query embeddings.
    projection: P,
    /// Spatial index holding the projected items.
    index: SpatialIndex<EmbeddingItem>,
    /// Lazily built k-NN graph for concept-path queries. It sits behind a `Mutex`
    /// so `&self` queries can populate it; every mutating method clears it.
    knn_cache: Mutex<Option<KnnCache>>,
}
110
111impl<P: Projection> EmbeddingIndex<P> {
112 pub fn builder(projection: P) -> EmbeddingIndexBuilder<P> {
113 EmbeddingIndexBuilder::new(projection)
114 }
115
116 pub fn insert(&mut self, id: impl Into<String>, embedding: &Embedding) {
117 let rich = self.projection.project_rich(embedding);
118 self.index.insert(EmbeddingItem {
119 id: id.into(),
120 position: rich.position,
121 original_magnitude: embedding.magnitude(),
122 projected: Some(rich),
123 });
124 self.invalidate_knn_cache();
125 }
126
127 pub fn insert_with_radius(&mut self, id: impl Into<String>, embedding: &Embedding, r: f64) {
131 let rich = self.projection.project_rich(embedding);
132 let position = SphericalPoint::new_unchecked(r, rich.position.theta, rich.position.phi);
133 self.index.insert(EmbeddingItem {
134 id: id.into(),
135 position,
136 original_magnitude: embedding.magnitude(),
137 projected: Some(ProjectedPoint { position, ..rich }),
138 });
139 self.invalidate_knn_cache();
140 }
141
142 fn invalidate_knn_cache(&mut self) {
146 if let Ok(slot) = self.knn_cache.get_mut() {
147 *slot = None;
148 }
149 }
150
151 fn knn_adjacency(&self, items: &[&EmbeddingItem], k: usize) -> Arc<Vec<Vec<(usize, f64)>>> {
158 {
159 let cache = self.knn_cache.lock().expect("knn cache mutex poisoned");
160 if let Some(cached) = cache.as_ref()
161 && cached.k == k
162 && cached.adj.len() == items.len()
163 {
164 return Arc::clone(&cached.adj);
165 }
166 }
167
168 let n = items.len();
172 let id_to_idx: HashMap<&str, usize> = items
173 .iter()
174 .enumerate()
175 .map(|(i, item)| (item.id.as_str(), i))
176 .collect();
177 let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::with_capacity(k); n];
178 let mut seen: std::collections::HashSet<(usize, usize)> =
179 std::collections::HashSet::with_capacity(n * k);
180 for (i, item) in items.iter().enumerate() {
181 let nearest = self.index.nearest(item.position(), k + 1);
182 for result in &nearest {
183 let Some(&j) = id_to_idx.get(result.item.id.as_str()) else {
184 continue;
185 };
186 if i == j {
187 continue;
188 }
189 let key = if i < j { (i, j) } else { (j, i) };
190 if seen.insert(key) {
191 adj[i].push((j, result.distance));
192 adj[j].push((i, result.distance));
193 }
194 }
195 }
196
197 let adj = Arc::new(adj);
198 let mut cache = self.knn_cache.lock().expect("knn cache mutex poisoned");
199 *cache = Some(KnnCache {
200 k,
201 adj: Arc::clone(&adj),
202 });
203 adj
204 }
205
206 pub fn search_nearest(&self, query: &Embedding, k: usize) -> Vec<NearestResult<EmbeddingItem>> {
208 let projected = self.projection.project(query);
209 self.index.nearest(&projected, k)
210 }
211
212 pub fn search_similar(
217 &self,
218 query: &Embedding,
219 min_cosine_similarity: f64,
220 ) -> SpatialQueryResult<EmbeddingItem> {
221 let projected = self.projection.project(query);
222 let max_angle = min_cosine_similarity.clamp(-1.0, 1.0).acos();
223 self.index.within_distance(&projected, max_angle)
224 }
225
226 pub fn search_region(&self, region: &Region) -> SpatialQueryResult<EmbeddingItem> {
227 self.index.query_region(region)
228 }
229
230 pub fn remove(&mut self, id: &str) -> Option<EmbeddingItem> {
231 let removed = self.index.remove(&id.to_string());
232 if removed.is_some() {
233 self.invalidate_knn_cache();
234 }
235 removed
236 }
237
238 pub fn get(&self, id: &str) -> Option<&EmbeddingItem> {
239 self.index.get(&id.to_string())
240 }
241
242 pub fn len(&self) -> usize {
243 self.index.len()
244 }
245
246 pub fn is_empty(&self) -> bool {
247 self.index.is_empty()
248 }
249
250 pub fn projection(&self) -> &P {
251 &self.projection
252 }
253
254 pub fn all_items(&self) -> Vec<&EmbeddingItem> {
255 self.index.all_items()
256 }
257
258 pub fn concept_path(&self, source_id: &str, target_id: &str, k: usize) -> Option<ConceptPath> {
270 let items = self.index.all_items();
271 let n = items.len();
272 if n < 2 {
273 return None;
274 }
275
276 let id_to_idx: HashMap<&str, usize> = items
277 .iter()
278 .enumerate()
279 .map(|(i, item)| (item.id.as_str(), i))
280 .collect();
281
282 let source_idx = *id_to_idx.get(source_id)?;
283 let target_idx = *id_to_idx.get(target_id)?;
284
285 let adj = self.knn_adjacency(&items, k);
286
287 let mut dist = vec![f64::INFINITY; n];
289 let mut prev: Vec<Option<usize>> = vec![None; n];
290 let mut heap = BinaryHeap::new();
291
292 dist[source_idx] = 0.0;
293 heap.push(DijkstraEntry {
294 dist: 0.0,
295 node: source_idx,
296 });
297
298 while let Some(entry) = heap.pop() {
299 let u = entry.node;
300 if entry.dist > dist[u] {
301 continue;
302 }
303 if u == target_idx {
304 break;
305 }
306 for &(v, w) in &adj[u] {
307 let nd = dist[u] + w;
308 if nd < dist[v] {
309 dist[v] = nd;
310 prev[v] = Some(u);
311 heap.push(DijkstraEntry { dist: nd, node: v });
312 }
313 }
314 }
315
316 if dist[target_idx].is_infinite() {
317 return None;
318 }
319
320 let mut path = Vec::new();
322 let mut cur = target_idx;
323 loop {
324 let hop_distance = prev[cur]
325 .and_then(|p| adj[p].iter().find(|&&(v, _)| v == cur).map(|&(_, d)| d))
326 .unwrap_or(0.0);
327 path.push(PathStep {
328 id: items[cur].id.clone(),
329 cumulative_distance: dist[cur],
330 hop_distance,
331 category: None,
332 bridge_strength: None,
333 });
334 match prev[cur] {
335 Some(p) => cur = p,
336 None => break,
337 }
338 }
339 path.reverse();
340
341 Some(ConceptPath {
342 total_distance: dist[target_idx],
343 steps: path,
344 })
345 }
346
347 pub fn concept_path_bridged(
364 &self,
365 source_id: &str,
366 target_id: &str,
367 k: usize,
368 categories: &HashMap<&str, usize>,
369 bridge_strengths: &HashMap<(usize, usize), (f64, BridgeClassification)>,
370 ) -> Option<ConceptPath> {
371 let items = self.index.all_items();
372 let n = items.len();
373 if n < 2 {
374 return None;
375 }
376
377 let id_to_idx: HashMap<&str, usize> = items
378 .iter()
379 .enumerate()
380 .map(|(i, item)| (item.id.as_str(), i))
381 .collect();
382
383 let source_idx = *id_to_idx.get(source_id)?;
384 let target_idx = *id_to_idx.get(target_id)?;
385
386 let item_cats: Vec<Option<usize>> = items
388 .iter()
389 .map(|item| categories.get(item.id.as_str()).copied())
390 .collect();
391
392 let adj = self.knn_adjacency(&items, k);
398
399 let mut dist = vec![f64::INFINITY; n];
401 let mut prev: Vec<Option<usize>> = vec![None; n];
402 let mut heap = BinaryHeap::new();
403
404 dist[source_idx] = 0.0;
405 heap.push(DijkstraEntry {
406 dist: 0.0,
407 node: source_idx,
408 });
409
410 while let Some(entry) = heap.pop() {
411 let u = entry.node;
412 if entry.dist > dist[u] {
413 continue;
414 }
415 if u == target_idx {
416 break;
417 }
418 for &(v, raw_d) in &adj[u] {
419 let (w, _) = cross_category_weight(raw_d, &item_cats, u, v, bridge_strengths);
420 let nd = dist[u] + w;
421 if nd < dist[v] {
422 dist[v] = nd;
423 prev[v] = Some(u);
424 heap.push(DijkstraEntry { dist: nd, node: v });
425 }
426 }
427 }
428
429 if dist[target_idx].is_infinite() {
430 return None;
431 }
432
433 let mut path = Vec::new();
435 let mut cur = target_idx;
436 loop {
437 let edge_info = prev[cur].and_then(|p| {
438 adj[p].iter().find(|&&(v, _)| v == cur).map(|&(_, raw_d)| {
439 let (_, bs) =
440 cross_category_weight(raw_d, &item_cats, p, cur, bridge_strengths);
441 (raw_d, bs)
442 })
443 });
444 let hop_distance = edge_info.map_or(0.0, |(d, _)| d);
445 let bridge_str = edge_info.and_then(|(_, bs)| bs);
446
447 path.push(PathStep {
448 id: items[cur].id.clone(),
449 cumulative_distance: dist[cur],
450 hop_distance,
451 category: item_cats[cur],
452 bridge_strength: bridge_str,
453 });
454 match prev[cur] {
455 Some(p) => cur = p,
456 None => break,
457 }
458 }
459 path.reverse();
460
461 Some(ConceptPath {
462 total_distance: dist[target_idx],
463 steps: path,
464 })
465 }
466}
467
/// Result of a concept-path query: the chain of items from source to target.
#[derive(Debug, Clone)]
pub struct ConceptPath {
    /// Items along the path, from source to target (both inclusive).
    pub steps: Vec<PathStep>,
    /// Cost of reaching the target. For bridged queries this is the re-weighted
    /// cost, not the raw distance.
    pub total_distance: f64,
}
475
/// One item on a `ConceptPath`.
#[derive(Debug, Clone)]
pub struct PathStep {
    /// Id of the item at this step.
    pub id: String,
    /// Path cost from the source up to and including this item.
    pub cumulative_distance: f64,
    /// Raw graph distance of the edge leading into this item; `0.0` for the source.
    pub hop_distance: f64,
    /// Category of this item. Set only by bridged queries, and only for ids that
    /// appear in the category map.
    pub category: Option<usize>,
    /// Strength of the category bridge crossed by the incoming hop. Set only by
    /// bridged queries, and only when the hop crosses categories.
    pub bridge_strength: Option<f64>,
}
488
/// Priority-queue entry for the Dijkstra searches: a node and its tentative distance.
///
/// Ordered so that `BinaryHeap` (a max-heap) pops the *smallest* distance first.
/// Distances are compared with `f64::total_cmp`, so the ordering stays total even
/// with NaN. Ties are broken by node index, so `Ord` agrees with `PartialEq`/`Eq`,
/// as the `Ord` contract requires.
struct DijkstraEntry {
    dist: f64,
    node: usize,
}

impl PartialEq for DijkstraEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality and ordering can never disagree.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for DijkstraEntry {}

impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Both keys reversed: a smaller distance (then a smaller node) means higher
        // priority in the max-heap.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}
514
/// Best-fit plane through a set of 3-D points, obtained from the eigen-decomposition
/// of their covariance matrix.
#[derive(Debug, Clone)]
pub struct SlicingManifold {
    /// Mean of the fitted points; the plane passes through it.
    pub centroid: [f64; 3],
    /// Plane normal: eigenvector of the smallest covariance eigenvalue.
    pub normal: [f64; 3],
    /// First in-plane axis: eigenvector of the largest eigenvalue.
    pub basis_u: [f64; 3],
    /// Second in-plane axis: eigenvector of the middle eigenvalue.
    pub basis_v: [f64; 3],
    /// Share of total variance that lies in the plane, `(λ1 + λ2) / (λ1 + λ2 + λ3)`.
    /// Set to `1.0` when the total variance is not positive.
    pub variance_ratio: f64,
}
533
534impl SlicingManifold {
535 pub fn fit(points: &[[f64; 3]]) -> Self {
538 let n = points.len() as f64;
539 assert!(n >= 3.0, "need at least 3 points to fit a plane");
540
541 let mut c = [0.0; 3];
543 for p in points {
544 for i in 0..3 {
545 c[i] += p[i];
546 }
547 }
548 for ci in &mut c {
549 *ci /= n;
550 }
551
552 let mut cov = [[0.0f64; 3]; 3];
554 for p in points {
555 let d = [p[0] - c[0], p[1] - c[1], p[2] - c[2]];
556 for i in 0..3 {
557 for j in 0..3 {
558 cov[i][j] += d[i] * d[j];
559 }
560 }
561 }
562 for row in &mut cov {
563 for v in row.iter_mut() {
564 *v /= n;
565 }
566 }
567
568 let (eigenvalues, eigenvectors) = eigen_symmetric_3x3(&cov);
570
571 let total_var = eigenvalues[0] + eigenvalues[1] + eigenvalues[2];
574 let variance_ratio = if total_var > 0.0 {
575 (eigenvalues[0] + eigenvalues[1]) / total_var
576 } else {
577 1.0
578 };
579
580 Self {
581 centroid: c,
582 normal: eigenvectors[2],
583 basis_u: eigenvectors[0],
584 basis_v: eigenvectors[1],
585 variance_ratio,
586 }
587 }
588
589 pub fn project_2d(&self, point: &[f64; 3]) -> (f64, f64) {
591 let d = [
592 point[0] - self.centroid[0],
593 point[1] - self.centroid[1],
594 point[2] - self.centroid[2],
595 ];
596 let u = d[0] * self.basis_u[0] + d[1] * self.basis_u[1] + d[2] * self.basis_u[2];
597 let v = d[0] * self.basis_v[0] + d[1] * self.basis_v[1] + d[2] * self.basis_v[2];
598 (u, v)
599 }
600
601 pub fn distance(&self, point: &[f64; 3]) -> f64 {
603 let d = [
604 point[0] - self.centroid[0],
605 point[1] - self.centroid[1],
606 point[2] - self.centroid[2],
607 ];
608 d[0] * self.normal[0] + d[1] * self.normal[1] + d[2] * self.normal[2]
609 }
610
611 pub fn fit_local(query: &[f64; 3], all_points: &[[f64; 3]], k: usize) -> Self {
622 let mut dists: Vec<(usize, f64)> = all_points
623 .iter()
624 .enumerate()
625 .map(|(i, p)| (i, dist3(query, p)))
626 .collect();
627 dists.sort_by(|a, b| a.1.total_cmp(&b.1));
633
634 let neighborhood: Vec<[f64; 3]> = dists
635 .iter()
636 .take(k.max(3))
637 .map(|&(i, _)| all_points[i])
638 .collect();
639
640 Self::fit(&neighborhood)
641 }
642}
643
/// Eigen-decomposition of a symmetric 3×3 matrix using classical Jacobi rotations.
/// Each rotation zeroes the largest remaining off-diagonal entry.
///
/// Returns `(values, vectors)`: eigenvalues in descending order, and `vectors[i]`
/// is the eigenvector belonging to `values[i]`. Iteration stops once every
/// off-diagonal entry is below `1e-15` in magnitude, or after 50 rotations.
/// `m` is assumed to be symmetric.
///
/// Eigenvalues are ordered with `f64::total_cmp`, so non-finite input produces
/// non-finite output instead of a panic.
fn eigen_symmetric_3x3(m: &[[f64; 3]; 3]) -> ([f64; 3], [[f64; 3]; 3]) {
    const MAX_ROTATIONS: usize = 50;
    const OFF_DIAGONAL_TOL: f64 = 1e-15;

    let mut a = *m;
    // Accumulated rotation; its columns converge to the eigenvectors.
    let mut v = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    #[allow(clippy::needless_range_loop)] // index loops mirror the matrix formulas
    for _ in 0..MAX_ROTATIONS {
        // Pivot: the off-diagonal entry (p, q), p < q, with the largest magnitude.
        let mut p = 0;
        let mut q = 1;
        let mut max_val = a[0][1].abs();
        for i in 0..3 {
            for j in (i + 1)..3 {
                if a[i][j].abs() > max_val {
                    max_val = a[i][j].abs();
                    p = i;
                    q = j;
                }
            }
        }
        if max_val < OFF_DIAGONAL_TOL {
            break;
        }

        // Angle that zeroes a[p][q]: tan(2θ) = 2·a_pq / (a_pp − a_qq). Use θ = π/4
        // when the two diagonal entries are (numerically) equal.
        let theta = if (a[p][p] - a[q][q]).abs() < 1e-30 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * a[p][q] / (a[p][p] - a[q][q])).atan()
        };
        let c = theta.cos();
        let s = theta.sin();

        // A ← Jᵀ A J: rotate columns p and q first, then rows p and q.
        let mut new_a = a;
        for i in 0..3 {
            new_a[i][p] = c * a[i][p] + s * a[i][q];
            new_a[i][q] = -s * a[i][p] + c * a[i][q];
        }
        let snapshot = new_a;
        for j in 0..3 {
            new_a[p][j] = c * snapshot[p][j] + s * snapshot[q][j];
            new_a[q][j] = -s * snapshot[p][j] + c * snapshot[q][j];
        }
        // Exactly zero by construction; drop the rounding residue.
        new_a[p][q] = 0.0;
        new_a[q][p] = 0.0;
        a = new_a;

        // V ← V J.
        let mut new_v = v;
        for i in 0..3 {
            new_v[i][p] = c * v[i][p] + s * v[i][q];
            new_v[i][q] = -s * v[i][p] + c * v[i][q];
        }
        v = new_v;
    }

    let eigenvalues = [a[0][0], a[1][1], a[2][2]];

    // Column indices of `v`, ordered by descending eigenvalue. `total_cmp` keeps
    // the comparison total if a NaN is present.
    let mut order = [0usize, 1, 2];
    order.sort_by(|&x, &y| eigenvalues[y].total_cmp(&eigenvalues[x]));

    let sorted_vals = order.map(|col| eigenvalues[col]);
    // Eigenvectors are the columns of the accumulated rotation.
    let sorted_vecs = order.map(|col| [v[0][col], v[1][col], v[2][col]]);

    (sorted_vals, sorted_vecs)
}
722
/// One cluster ("glob") of concepts found by `GlobResult::detect`.
#[derive(Debug, Clone)]
pub struct ConceptGlob {
    /// Cluster index from k-means, in `0..k`. Empty clusters are omitted, so ids
    /// may have gaps.
    pub id: usize,
    /// Normalized sum of the members' unit vectors (all zeros if they cancel out).
    pub centroid: [f64; 3],
    /// Ids of the member points, in input order.
    pub member_ids: Vec<String>,
    /// Angular distance (radians) of each member from `centroid`, parallel to
    /// `member_ids`.
    pub member_distances: Vec<f64>,
    /// Largest entry of `member_distances`.
    pub radius: f64,
}
734
/// Output of concept-glob detection.
#[derive(Debug, Clone)]
pub struct GlobResult {
    /// Non-empty clusters.
    pub globs: Vec<ConceptGlob>,
    /// Number of clusters k-means was run with (may exceed `globs.len()`).
    pub k: usize,
    /// Mean silhouette score (angular distance) of the chosen clustering.
    pub silhouette: f64,
}
742
743impl GlobResult {
744 pub fn detect(points: &[[f64; 3]], ids: &[String], k: Option<usize>, max_k: usize) -> Self {
749 let n = points.len();
750 assert_eq!(n, ids.len());
751 assert!(n >= 2, "need at least 2 points for clustering");
752
753 let max_k = max_k.min(n);
754
755 if let Some(k) = k {
756 let k = k.clamp(2, max_k);
757 let (assignments, silhouette) = kmeans_3d(points, k);
758 let globs = build_globs(points, ids, &assignments, k);
759 return Self {
760 globs,
761 k,
762 silhouette,
763 };
764 }
765
766 let mut best_k = 2;
768 let mut best_sil = f64::NEG_INFINITY;
769 let mut best_assignments = vec![0usize; n];
770
771 for trial_k in 2..=max_k {
772 let (assignments, sil) = kmeans_3d(points, trial_k);
773 if sil > best_sil {
774 best_sil = sil;
775 best_k = trial_k;
776 best_assignments = assignments;
777 }
778 }
779
780 let globs = build_globs(points, ids, &best_assignments, best_k);
781 Self {
782 globs,
783 k: best_k,
784 silhouette: best_sil,
785 }
786 }
787}
788
789fn kmeans_3d(points: &[[f64; 3]], k: usize) -> (Vec<usize>, f64) {
790 let n = points.len();
791 let max_iter = 50;
792
793 let mut centers: Vec<[f64; 3]> = (0..k).map(|i| points[i * n / k]).collect();
795
796 let mut assignments = vec![0usize; n];
797
798 for _ in 0..max_iter {
799 let mut changed = false;
800
801 for (i, p) in points.iter().enumerate() {
803 let mut best = 0;
804 let mut best_d = f64::MAX;
805 for (j, c) in centers.iter().enumerate() {
806 let d = angular_dist3(p, c);
807 if d < best_d {
808 best_d = d;
809 best = j;
810 }
811 }
812 if assignments[i] != best {
813 assignments[i] = best;
814 changed = true;
815 }
816 }
817
818 if !changed {
819 break;
820 }
821
822 let mut sums = vec![[0.0f64; 3]; k];
825 let mut counts = vec![0usize; k];
826 for (i, &a) in assignments.iter().enumerate() {
827 let norm_p = normalize3(&points[i]);
828 for (d, &np) in norm_p.iter().enumerate() {
829 sums[a][d] += np;
830 }
831 counts[a] += 1;
832 }
833 for j in 0..k {
834 if counts[j] > 0 {
835 centers[j] = normalize3(&sums[j]);
836 }
837 }
838 }
839
840 let sil = silhouette_score(points, &assignments, k);
841 (assignments, sil)
842}
843
844fn silhouette_score(points: &[[f64; 3]], assignments: &[usize], k: usize) -> f64 {
845 let n = points.len();
846 if n <= 1 || k <= 1 {
847 return 0.0;
848 }
849
850 let mut total = 0.0;
851 for i in 0..n {
852 let ci = assignments[i];
853
854 let mut a_sum = 0.0;
856 let mut a_cnt = 0;
857 for j in 0..n {
858 if j != i && assignments[j] == ci {
859 a_sum += angular_dist3(&points[i], &points[j]);
860 a_cnt += 1;
861 }
862 }
863 let a = if a_cnt > 0 { a_sum / a_cnt as f64 } else { 0.0 };
864
865 let mut b = f64::MAX;
867 for ck in 0..k {
868 if ck == ci {
869 continue;
870 }
871 let mut b_sum = 0.0;
872 let mut b_cnt = 0;
873 for j in 0..n {
874 if assignments[j] == ck {
875 b_sum += angular_dist3(&points[i], &points[j]);
876 b_cnt += 1;
877 }
878 }
879 if b_cnt > 0 {
880 b = b.min(b_sum / b_cnt as f64);
881 }
882 }
883 if b == f64::MAX {
884 b = 0.0;
885 }
886
887 let denom = a.max(b);
888 total += if denom > 0.0 { (b - a) / denom } else { 0.0 };
889 }
890
891 total / n as f64
892}
893
894fn build_globs(
895 points: &[[f64; 3]],
896 ids: &[String],
897 assignments: &[usize],
898 k: usize,
899) -> Vec<ConceptGlob> {
900 let mut globs = Vec::with_capacity(k);
901
902 for cluster_id in 0..k {
903 let member_indices: Vec<usize> = assignments
904 .iter()
905 .enumerate()
906 .filter(|&(_, &a)| a == cluster_id)
907 .map(|(i, _)| i)
908 .collect();
909
910 if member_indices.is_empty() {
911 continue;
912 }
913
914 let mut centroid = [0.0; 3];
916 for &i in &member_indices {
917 let norm_p = normalize3(&points[i]);
918 for (d, c) in centroid.iter_mut().enumerate() {
919 *c += norm_p[d];
920 }
921 }
922 centroid = normalize3(¢roid);
923
924 let member_distances: Vec<f64> = member_indices
926 .iter()
927 .map(|&i| angular_dist3(&points[i], ¢roid))
928 .collect();
929
930 let radius = member_distances.iter().cloned().fold(0.0f64, f64::max);
931
932 let member_ids: Vec<String> = member_indices.iter().map(|&i| ids[i].clone()).collect();
933
934 globs.push(ConceptGlob {
935 id: cluster_id,
936 centroid,
937 member_ids,
938 member_distances,
939 radius,
940 });
941 }
942
943 globs
944}
945
/// Euclidean distance between two 3-D points.
fn dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y) * (x - y))
        .sum::<f64>()
        .sqrt()
}
952
/// Angle in radians between two 3-D vectors, in `[0, π]`.
///
/// Returns `0.0` when the product of their lengths is below `f64::EPSILON`, i.e.
/// when either vector is (numerically) zero.
fn angular_dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let norm = |v: &[f64; 3]| (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
    let scale = norm(a) * norm(b);
    if scale < f64::EPSILON {
        return 0.0;
    }
    let cosine = (a[0] * b[0] + a[1] * b[1] + a[2] * b[2]) / scale;
    // Rounding can push |cosine| slightly past 1, where `acos` would return NaN.
    cosine.clamp(-1.0, 1.0).acos()
}
965
/// Scales `v` to unit length. Vectors shorter than `f64::EPSILON` map to the zero
/// vector.
fn normalize3(v: &[f64; 3]) -> [f64; 3] {
    let length = v.iter().map(|c| c * c).sum::<f64>().sqrt();
    if length < f64::EPSILON {
        [0.0; 3]
    } else {
        v.map(|c| c / length)
    }
}
974
975fn cross_category_weight(
990 angular_dist: f64,
991 item_cats: &[Option<usize>],
992 i: usize,
993 j: usize,
994 bridge_strengths: &HashMap<(usize, usize), (f64, BridgeClassification)>,
995) -> (f64, Option<f64>) {
996 match (item_cats[i], item_cats[j]) {
997 (Some(ci), Some(cj)) if ci != cj => {
998 let (strength, classification) = bridge_strengths
999 .get(&(ci, cj))
1000 .or_else(|| bridge_strengths.get(&(cj, ci)))
1001 .copied()
1002 .unwrap_or((0.0, BridgeClassification::Weak));
1003 let weight = match classification {
1004 BridgeClassification::Genuine => angular_dist / (strength + 0.1),
1005 BridgeClassification::OverlapArtifact => angular_dist * 2.0,
1006 BridgeClassification::Weak => angular_dist / (strength + 0.01),
1007 };
1008 (weight, Some(strength))
1009 }
1010 _ => (angular_dist, None),
1011 }
1012}
1013
1014pub struct SemanticQuery;
1016
1017impl SemanticQuery {
1018 pub fn within_angle<P: Projection>(
1020 query: &Embedding,
1021 projection: &P,
1022 max_angular_distance: f64,
1023 ) -> Region {
1024 let point = projection.project(query);
1025 let half_angle = max_angular_distance.clamp(1e-10, std::f64::consts::PI);
1026 Region::Cap(
1027 Cap::new(
1028 SphericalPoint::new_unchecked(1.0, point.theta, point.phi),
1029 half_angle,
1030 )
1031 .unwrap(),
1032 )
1033 }
1034
1035 pub fn above_similarity<P: Projection>(
1038 query: &Embedding,
1039 projection: &P,
1040 min_similarity: f64,
1041 ) -> Region {
1042 let half_angle = min_similarity.clamp(-1.0, 1.0).acos();
1043 Self::within_angle(query, projection, half_angle)
1044 }
1045
1046 pub fn in_shell(inner: f64, outer: f64) -> Region {
1048 Region::Shell(Shell::new(inner, outer).expect("invalid shell bounds"))
1049 }
1050
1051 pub fn similar_in_shell<P: Projection>(
1054 query: &Embedding,
1055 projection: &P,
1056 min_similarity: f64,
1057 shell_inner: f64,
1058 shell_outer: f64,
1059 ) -> Region {
1060 Region::intersection(vec![
1061 Self::above_similarity(query, projection, min_similarity),
1062 Self::in_shell(shell_inner, shell_outer),
1063 ])
1064 }
1065}
1066
1067#[cfg(test)]
1068mod tests {
1069 use super::*;
1070 use crate::projection::{PcaProjection, RandomProjection};
1071 use crate::types::RadialStrategy;
1072 use sphereql_core::angular_distance;
1073
1074 fn emb(vals: &[f64]) -> Embedding {
1075 Embedding::new(vals.to_vec())
1076 }
1077
1078 fn test_corpus() -> Vec<Embedding> {
1079 vec![
1080 emb(&[1.0, 0.0, 0.0, 0.1, 0.0]),
1081 emb(&[0.0, 1.0, 0.0, 0.0, 0.1]),
1082 emb(&[0.0, 0.0, 1.0, 0.1, 0.0]),
1083 emb(&[1.0, 1.0, 0.0, 0.05, 0.05]),
1084 emb(&[-1.0, 0.0, 0.0, -0.1, 0.0]),
1085 emb(&[0.0, -1.0, 0.0, 0.0, -0.1]),
1086 ]
1087 }
1088
1089 #[test]
1092 fn insert_and_get() {
1093 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1094 let mut idx = EmbeddingIndex::builder(rp)
1095 .theta_divisions(4)
1096 .phi_divisions(3)
1097 .build();
1098
1099 idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1100 idx.insert("b", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
1101
1102 assert_eq!(idx.len(), 2);
1103 assert!(!idx.is_empty());
1104 assert!(idx.get("a").is_some());
1105 assert!(idx.get("b").is_some());
1106 assert!(idx.get("c").is_none());
1107 }
1108
1109 #[test]
1110 fn remove() {
1111 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1112 let mut idx = EmbeddingIndex::builder(rp).build();
1113
1114 idx.insert("a", &emb(&[1.0; 5]));
1115 assert_eq!(idx.len(), 1);
1116
1117 let removed = idx.remove("a");
1118 assert!(removed.is_some());
1119 assert_eq!(removed.unwrap().id, "a");
1120 assert_eq!(idx.len(), 0);
1121 assert!(idx.get("a").is_none());
1122 }
1123
1124 #[test]
1125 fn remove_nonexistent() {
1126 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1127 let mut idx = EmbeddingIndex::builder(rp).build();
1128 assert!(idx.remove("nope").is_none());
1129 }
1130
1131 #[test]
1132 fn search_nearest_returns_sorted() {
1133 let corpus = test_corpus();
1134 let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1135 let mut idx = EmbeddingIndex::builder(pca)
1136 .theta_divisions(4)
1137 .phi_divisions(3)
1138 .build();
1139
1140 for (i, e) in corpus.iter().enumerate() {
1141 idx.insert(format!("item-{i}"), e);
1142 }
1143
1144 let query = emb(&[0.95, 0.1, 0.0, 0.05, 0.0]);
1145 let results = idx.search_nearest(&query, 3);
1146
1147 assert_eq!(results.len(), 3);
1148 assert!(results[0].distance <= results[1].distance);
1149 assert!(results[1].distance <= results[2].distance);
1150 }
1151
1152 #[test]
1153 fn search_similar_respects_threshold() {
1154 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1155 let mut idx = EmbeddingIndex::builder(rp)
1156 .theta_divisions(4)
1157 .phi_divisions(3)
1158 .build();
1159
1160 idx.insert("close_a", &emb(&[1.0, 0.1, 0.0, 0.0, 0.0]));
1161 idx.insert("close_b", &emb(&[0.9, 0.2, 0.0, 0.0, 0.0]));
1162 idx.insert("far", &emb(&[-1.0, 0.0, 0.0, 0.0, 0.0]));
1163
1164 let query = emb(&[1.0, 0.0, 0.0, 0.0, 0.0]);
1165 let projected_query = idx.projection().project(&query);
1166 let result = idx.search_similar(&query, 0.5);
1167
1168 let max_angle = 0.5_f64.acos();
1169 for item in &result.items {
1170 let d = angular_distance(&projected_query, item.position());
1171 assert!(d <= max_angle + 1e-10, "item {} too far: {d}", item.id);
1172 }
1173 }
1174
1175 #[test]
1176 fn insert_with_radius_overrides() {
1177 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1178 let mut idx = EmbeddingIndex::builder(rp).build();
1179
1180 idx.insert_with_radius("custom", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]), 42.0);
1181 let item = idx.get("custom").unwrap();
1182 assert!((item.position.r - 42.0).abs() < 1e-12);
1183 }
1184
1185 #[test]
1186 fn original_magnitude_stored() {
1187 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1188 let mut idx = EmbeddingIndex::builder(rp).build();
1189
1190 let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0]);
1191 idx.insert("vec", &e);
1192 let item = idx.get("vec").unwrap();
1193 assert!((item.original_magnitude - 5.0).abs() < 1e-10);
1194 }
1195
1196 #[test]
1197 fn magnitude_radial_with_shell_query() {
1198 let corpus = test_corpus();
1199 let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude).unwrap();
1200 let mut idx = EmbeddingIndex::builder(pca)
1201 .uniform_shells(5, 10.0)
1202 .theta_divisions(4)
1203 .phi_divisions(3)
1204 .build();
1205
1206 idx.insert("small", &emb(&[0.1, 0.0, 0.0, 0.0, 0.0]));
1207 idx.insert("medium", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1208 idx.insert("large", &emb(&[5.0, 0.0, 0.0, 0.0, 0.0]));
1209
1210 let shell = Shell::new(0.5, 2.0).unwrap();
1211 let result = idx.search_region(&Region::Shell(shell));
1212
1213 let ids: Vec<&str> = result.items.iter().map(|i| i.id.as_str()).collect();
1214 assert!(
1215 ids.contains(&"medium"),
1216 "medium (mag=1.0) should be in [0.5, 2.0]"
1217 );
1218 assert!(
1219 !ids.contains(&"large"),
1220 "large (mag=5.0) should not be in [0.5, 2.0]"
1221 );
1222 }
1223
1224 #[test]
1225 fn empty_index() {
1226 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1227 let idx = EmbeddingIndex::builder(rp).build();
1228
1229 assert!(idx.is_empty());
1230 assert_eq!(idx.len(), 0);
1231 assert!(idx.get("x").is_none());
1232
1233 let results = idx.search_nearest(&emb(&[1.0; 5]), 5);
1234 assert!(results.is_empty());
1235 }
1236
1237 #[test]
1238 fn projection_accessor() {
1239 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1240 let idx = EmbeddingIndex::builder(rp).build();
1241 assert_eq!(idx.projection().dimensionality(), 5);
1242 }
1243
1244 #[test]
1247 fn above_similarity_creates_cap() {
1248 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1249 let region = SemanticQuery::above_similarity(&emb(&[1.0; 5]), &rp, 0.8);
1250 assert!(matches!(region, Region::Cap(_)));
1251 }
1252
1253 #[test]
1254 fn within_angle_creates_cap() {
1255 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1256 let region = SemanticQuery::within_angle(&emb(&[1.0; 5]), &rp, 0.5);
1257 assert!(matches!(region, Region::Cap(_)));
1258 }
1259
1260 #[test]
1261 fn in_shell_creates_shell() {
1262 let region = SemanticQuery::in_shell(1.0, 5.0);
1263 assert!(matches!(region, Region::Shell(_)));
1264 }
1265
1266 #[test]
1267 fn similar_in_shell_creates_intersection() {
1268 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1269 let region = SemanticQuery::similar_in_shell(&emb(&[1.0; 5]), &rp, 0.7, 1.0, 5.0);
1270
1271 match region {
1272 Region::Intersection(parts) => {
1273 assert_eq!(parts.len(), 2);
1274 assert!(matches!(parts[0], Region::Cap(_)));
1275 assert!(matches!(parts[1], Region::Shell(_)));
1276 }
1277 other => panic!("expected Intersection, got {other:?}"),
1278 }
1279 }
1280
1281 #[test]
1282 fn semantic_query_region_used_in_index() {
1283 let corpus = test_corpus();
1284 let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1285 let projection_clone = pca.clone();
1286 let mut idx = EmbeddingIndex::builder(pca)
1287 .theta_divisions(4)
1288 .phi_divisions(3)
1289 .build();
1290
1291 for (i, e) in corpus.iter().enumerate() {
1292 idx.insert(format!("item-{i}"), e);
1293 }
1294
1295 let query = emb(&[1.0, 0.0, 0.0, 0.05, 0.0]);
1296 let region = SemanticQuery::above_similarity(&query, &projection_clone, 0.5);
1297 let result = idx.search_region(®ion);
1298
1299 for item in &result.items {
1300 assert!(region.contains(item.position()));
1301 }
1302 }
1303
1304 #[test]
1307 fn concept_path_populates_hop_distance() {
1308 let corpus = test_corpus();
1309 let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1310 let mut idx = EmbeddingIndex::builder(pca)
1311 .theta_divisions(4)
1312 .phi_divisions(3)
1313 .build();
1314
1315 for (i, e) in corpus.iter().enumerate() {
1316 idx.insert(format!("item-{i}"), e);
1317 }
1318
1319 if let Some(path) = idx.concept_path("item-0", "item-4", 3) {
1320 assert!(path.steps[0].hop_distance == 0.0, "first step has no hop");
1321 for step in &path.steps[1..] {
1322 assert!(
1323 step.hop_distance > 0.0,
1324 "subsequent steps should have a hop distance"
1325 );
1326 }
1327 assert!(path.steps.iter().all(|s| s.category.is_none()));
1328 assert!(path.steps.iter().all(|s| s.bridge_strength.is_none()));
1329 }
1330 }
1331
1332 #[test]
1335 fn concept_path_bridged_same_category_equals_unbridged() {
1336 let corpus = test_corpus();
1337 let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1338 let mut idx = EmbeddingIndex::builder(pca)
1339 .theta_divisions(4)
1340 .phi_divisions(3)
1341 .build();
1342
1343 for (i, e) in corpus.iter().enumerate() {
1344 idx.insert(format!("item-{i}"), e);
1345 }
1346
1347 let categories: HashMap<&str, usize> = (0..6)
1349 .map(|i| {
1350 (
1351 ["item-0", "item-1", "item-2", "item-3", "item-4", "item-5"][i],
1352 0,
1353 )
1354 })
1355 .collect();
1356 let bridges = HashMap::new();
1357
1358 let unbridged = idx.concept_path("item-0", "item-3", 3);
1359 let bridged = idx.concept_path_bridged("item-0", "item-3", 3, &categories, &bridges);
1360
1361 match (unbridged, bridged) {
1362 (Some(u), Some(b)) => {
1363 assert_eq!(u.steps.len(), b.steps.len());
1364 assert!((u.total_distance - b.total_distance).abs() < 1e-10);
1365 for step in &b.steps {
1366 assert_eq!(step.category, Some(0));
1367 assert!(step.bridge_strength.is_none());
1368 }
1369 }
1370 (None, None) => {} _ => panic!("bridged and unbridged should agree on reachability"),
1372 }
1373 }
1374
1375 #[test]
1376 fn concept_path_bridged_penalizes_weak_bridges() {
1377 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1378 let mut idx = EmbeddingIndex::builder(rp)
1379 .theta_divisions(4)
1380 .phi_divisions(3)
1381 .build();
1382
1383 idx.insert("a0", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1386 idx.insert("a1", &emb(&[0.9, 0.1, 0.0, 0.0, 0.0]));
1387 idx.insert("b0", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
1389 idx.insert("b1", &emb(&[0.1, 0.9, 0.0, 0.0, 0.0]));
1390
1391 let mut categories: HashMap<&str, usize> = HashMap::new();
1392 categories.insert("a0", 0);
1393 categories.insert("a1", 0);
1394 categories.insert("b0", 1);
1395 categories.insert("b1", 1);
1396
1397 let mut weak_bridges = HashMap::new();
1399 weak_bridges.insert((0, 1), (0.05, BridgeClassification::Weak));
1400
1401 let mut strong_bridges = HashMap::new();
1403 strong_bridges.insert((0, 1), (0.95, BridgeClassification::Genuine));
1404
1405 let weak_path = idx.concept_path_bridged("a0", "b0", 3, &categories, &weak_bridges);
1406 let strong_path = idx.concept_path_bridged("a0", "b0", 3, &categories, &strong_bridges);
1407
1408 if let (Some(weak), Some(strong)) = (weak_path, strong_path) {
1411 assert!(
1412 weak.total_distance > strong.total_distance,
1413 "weak bridge ({:.4}) should produce higher cost than strong ({:.4})",
1414 weak.total_distance,
1415 strong.total_distance
1416 );
1417 }
1418 }
1419
1420 #[test]
1421 fn concept_path_bridged_populates_bridge_metadata() {
1422 let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1423 let mut idx = EmbeddingIndex::builder(rp)
1424 .theta_divisions(4)
1425 .phi_divisions(3)
1426 .build();
1427
1428 idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1429 idx.insert("b", &emb(&[0.5, 0.5, 0.0, 0.0, 0.0]));
1430 idx.insert("c", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
1431
1432 let mut categories: HashMap<&str, usize> = HashMap::new();
1433 categories.insert("a", 0);
1434 categories.insert("b", 0);
1435 categories.insert("c", 1);
1436
1437 let mut bridges = HashMap::new();
1438 bridges.insert((0, 1), (0.7, BridgeClassification::Genuine));
1439
1440 if let Some(path) = idx.concept_path_bridged("a", "c", 3, &categories, &bridges) {
1441 for step in &path.steps {
1443 assert!(step.category.is_some());
1444 }
1445 let has_bridge = path.steps.iter().any(|s| s.bridge_strength.is_some());
1447 assert!(
1448 has_bridge,
1449 "should record bridge strength on cross-category hop"
1450 );
1451 }
1452 }
1453
1454 #[test]
1457 fn cross_category_weight_same_category() {
1458 let cats = vec![Some(0), Some(0)];
1459 let bridges = HashMap::new();
1460 let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1461 assert!((weight - 0.5).abs() < 1e-10);
1462 assert!(bs.is_none());
1463 }
1464
1465 #[test]
1466 fn cross_category_weight_different_categories_no_bridge() {
1467 let cats = vec![Some(0), Some(1)];
1468 let bridges = HashMap::new();
1469 let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1470 assert!((weight - 50.0).abs() < 1e-10);
1472 assert_eq!(bs, Some(0.0));
1473 }
1474
1475 #[test]
1476 fn cross_category_weight_genuine_bridge() {
1477 let cats = vec![Some(0), Some(1)];
1478 let mut bridges = HashMap::new();
1479 bridges.insert((0, 1), (0.9, BridgeClassification::Genuine));
1480 let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1481 assert!((weight - 0.5).abs() < 1e-10);
1483 assert_eq!(bs, Some(0.9));
1484 }
1485
1486 #[test]
1487 fn cross_category_weight_weak_bridge() {
1488 let cats = vec![Some(0), Some(1)];
1489 let mut bridges = HashMap::new();
1490 bridges.insert((0, 1), (0.3, BridgeClassification::Weak));
1491 let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1492 assert!((weight - 0.5 / 0.31).abs() < 1e-10);
1494 assert_eq!(bs, Some(0.3));
1495 }
1496
1497 #[test]
1498 fn cross_category_weight_overlap_artifact_discouraged() {
1499 let cats = vec![Some(0), Some(1)];
1500 let mut bridges = HashMap::new();
1501 bridges.insert((0, 1), (0.8, BridgeClassification::OverlapArtifact));
1502 let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1503 assert!((weight - 1.0).abs() < 1e-10);
1505 assert_eq!(bs, Some(0.8));
1506 }
1507
1508 #[test]
1509 fn cross_category_weight_no_category_info() {
1510 let cats = vec![None, Some(1)];
1511 let bridges = HashMap::new();
1512 let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1513 assert!((weight - 0.5).abs() < 1e-10);
1514 assert!(bs.is_none());
1515 }
1516}