Skip to main content

sphereql_embed/
query.rs

1use std::collections::{BinaryHeap, HashMap};
2use std::sync::{Arc, Mutex};
3
4use sphereql_core::*;
5use sphereql_index::*;
6
7use crate::category::BridgeClassification;
8use crate::projection::Projection;
9use crate::types::{Embedding, ProjectedPoint};
10
/// k-NN adjacency snapshot cached between calls to
/// [`EmbeddingIndex::concept_path`] and its bridged variant.
///
/// The graph is undirected and built for a specific `k`. Different `k`
/// invalidates; so does any `&mut self` mutation on the index.
struct KnnCache {
    /// Neighbour count the snapshot was built for; compared against the
    /// requested `k` before the snapshot is reused.
    k: usize,
    /// Adjacency rows: `adj[i]` lists `(neighbour_index, distance)` pairs
    /// for node `i`. Held in an `Arc` so a reader can clone the handle and
    /// release the cache lock before walking the graph.
    adj: Arc<Vec<Vec<(usize, f64)>>>,
}
20
/// One indexed embedding: its identifier, its position on the sphere, and
/// the projection metadata it was created from.
#[derive(Debug, Clone)]
pub struct EmbeddingItem {
    /// Identifier exposed through `SpatialItem::id`.
    pub id: String,
    /// Spherical position exposed through `SpatialItem::position`.
    pub position: SphericalPoint,
    /// Magnitude of the source embedding as reported by `Embedding::magnitude`;
    /// used as the fallback for `intensity()` when `projected` is `None`.
    pub original_magnitude: f64,
    /// Rich projection metadata. The insert methods visible in this file
    /// always store `Some`; `None` is only possible for items constructed
    /// directly without projection metadata, in which case the accessors
    /// on this type fall back to defaults.
    pub projected: Option<ProjectedPoint>,
}
29
30impl SpatialItem for EmbeddingItem {
31    type Id = String;
32    fn id(&self) -> &String {
33        &self.id
34    }
35    fn position(&self) -> &SphericalPoint {
36        &self.position
37    }
38}
39
40impl EmbeddingItem {
41    /// Certainty of this point's projection. Falls back to 1.0 if no rich metadata.
42    pub fn certainty(&self) -> f64 {
43        self.projected.map_or(1.0, |p| p.certainty)
44    }
45
46    /// Intensity (pre-normalization magnitude) of the original embedding.
47    pub fn intensity(&self) -> f64 {
48        self.projected
49            .map_or(self.original_magnitude, |p| p.intensity)
50    }
51
52    /// PCA projection magnitude — how strongly this point projects onto
53    /// the 3 principal components. Low values indicate ambiguous points.
54    pub fn projection_magnitude(&self) -> f64 {
55        self.projected.map_or(1.0, |p| p.projection_magnitude)
56    }
57}
58
/// Fluent builder for [`EmbeddingIndex`]: pairs a projection with the
/// configuration of the underlying `SpatialIndexBuilder`.
pub struct EmbeddingIndexBuilder<P> {
    /// Projection handed to the finished index.
    projection: P,
    /// Spatial-index builder that receives the shell and division settings.
    inner: SpatialIndexBuilder,
}
63
64impl<P: Projection> EmbeddingIndexBuilder<P> {
65    pub fn new(projection: P) -> Self {
66        Self {
67            projection,
68            inner: SpatialIndexBuilder::new(),
69        }
70    }
71
72    pub fn shell_boundary(mut self, r: f64) -> Self {
73        self.inner = self.inner.shell_boundary(r);
74        self
75    }
76
77    pub fn uniform_shells(mut self, count: usize, max_r: f64) -> Self {
78        self.inner = self.inner.uniform_shells(count, max_r);
79        self
80    }
81
82    pub fn theta_divisions(mut self, n: usize) -> Self {
83        self.inner = self.inner.theta_divisions(n);
84        self
85    }
86
87    pub fn phi_divisions(mut self, n: usize) -> Self {
88        self.inner = self.inner.phi_divisions(n);
89        self
90    }
91
92    pub fn build(self) -> EmbeddingIndex<P> {
93        EmbeddingIndex {
94            projection: self.projection,
95            index: self.inner.build(),
96            knn_cache: Mutex::new(None),
97        }
98    }
99}
100
/// Spatial index over projected embeddings.
///
/// Embeddings are run through the projection `P` and stored as
/// [`EmbeddingItem`]s inside a `SpatialIndex`.
pub struct EmbeddingIndex<P> {
    /// Projection applied to embeddings on insert and to queries on search.
    projection: P,
    /// Underlying spatial index holding the projected items.
    index: SpatialIndex<EmbeddingItem>,
    /// k-NN adjacency cache for `concept_path` and friends. Shared
    /// behind `Mutex<Option<_>>` so the `&self` query methods can
    /// memoize across calls; `&mut self` mutations clear it via
    /// `get_mut`, bypassing the lock.
    knn_cache: Mutex<Option<KnnCache>>,
}
110
111impl<P: Projection> EmbeddingIndex<P> {
    /// Convenience entry point: equivalent to
    /// `EmbeddingIndexBuilder::new(projection)`.
    pub fn builder(projection: P) -> EmbeddingIndexBuilder<P> {
        EmbeddingIndexBuilder::new(projection)
    }
115
116    pub fn insert(&mut self, id: impl Into<String>, embedding: &Embedding) {
117        let rich = self.projection.project_rich(embedding);
118        self.index.insert(EmbeddingItem {
119            id: id.into(),
120            position: rich.position,
121            original_magnitude: embedding.magnitude(),
122            projected: Some(rich),
123        });
124        self.invalidate_knn_cache();
125    }
126
127    /// Insert with an explicit radial value, overriding the projection's RadialStrategy.
128    /// The angular coordinates (theta, phi) are still determined by the projection.
129    /// Use this for metadata-driven radius: recency scores, importance weights, etc.
130    pub fn insert_with_radius(&mut self, id: impl Into<String>, embedding: &Embedding, r: f64) {
131        let rich = self.projection.project_rich(embedding);
132        let position = SphericalPoint::new_unchecked(r, rich.position.theta, rich.position.phi);
133        self.index.insert(EmbeddingItem {
134            id: id.into(),
135            position,
136            original_magnitude: embedding.magnitude(),
137            projected: Some(ProjectedPoint { position, ..rich }),
138        });
139        self.invalidate_knn_cache();
140    }
141
142    /// Drop any cached k-NN adjacency. Called by every `&mut self`
143    /// mutation that could change the graph. Uses `get_mut` to skip
144    /// the lock when we already hold `&mut self`.
145    fn invalidate_knn_cache(&mut self) {
146        if let Ok(slot) = self.knn_cache.get_mut() {
147            *slot = None;
148        }
149    }
150
151    /// Return the k-NN adjacency snapshot for the given `k`, rebuilding
152    /// only on cache miss.
153    ///
154    /// The shared `Arc` is cheap to clone, so callers can drop the lock
155    /// while they run Dijkstra. Previously `concept_path` rebuilt the
156    /// entire graph on every call — O(n² · k) per query.
157    fn knn_adjacency(&self, items: &[&EmbeddingItem], k: usize) -> Arc<Vec<Vec<(usize, f64)>>> {
158        {
159            let cache = self.knn_cache.lock().expect("knn cache mutex poisoned");
160            if let Some(cached) = cache.as_ref()
161                && cached.k == k
162                && cached.adj.len() == items.len()
163            {
164                return Arc::clone(&cached.adj);
165            }
166        }
167
168        // Miss — build a fresh undirected adjacency. Symmetrize in one
169        // O(E) pass using a HashSet instead of the previous O(n · k²)
170        // linear scan.
171        let n = items.len();
172        let id_to_idx: HashMap<&str, usize> = items
173            .iter()
174            .enumerate()
175            .map(|(i, item)| (item.id.as_str(), i))
176            .collect();
177        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::with_capacity(k); n];
178        let mut seen: std::collections::HashSet<(usize, usize)> =
179            std::collections::HashSet::with_capacity(n * k);
180        for (i, item) in items.iter().enumerate() {
181            let nearest = self.index.nearest(item.position(), k + 1);
182            for result in &nearest {
183                let Some(&j) = id_to_idx.get(result.item.id.as_str()) else {
184                    continue;
185                };
186                if i == j {
187                    continue;
188                }
189                let key = if i < j { (i, j) } else { (j, i) };
190                if seen.insert(key) {
191                    adj[i].push((j, result.distance));
192                    adj[j].push((i, result.distance));
193                }
194            }
195        }
196
197        let adj = Arc::new(adj);
198        let mut cache = self.knn_cache.lock().expect("knn cache mutex poisoned");
199        *cache = Some(KnnCache {
200            k,
201            adj: Arc::clone(&adj),
202        });
203        adj
204    }
205
206    /// Find the k embeddings whose projected directions are closest to the query.
207    pub fn search_nearest(&self, query: &Embedding, k: usize) -> Vec<NearestResult<EmbeddingItem>> {
208        let projected = self.projection.project(query);
209        self.index.nearest(&projected, k)
210    }
211
212    /// Find all embeddings whose projected cosine similarity to the query
213    /// is at least `min_cosine_similarity`.
214    ///
215    /// Internally maps cos(sim) → angular distance and uses `within_distance`.
216    pub fn search_similar(
217        &self,
218        query: &Embedding,
219        min_cosine_similarity: f64,
220    ) -> SpatialQueryResult<EmbeddingItem> {
221        let projected = self.projection.project(query);
222        let max_angle = min_cosine_similarity.clamp(-1.0, 1.0).acos();
223        self.index.within_distance(&projected, max_angle)
224    }
225
    /// Return the items the spatial index reports for `region`
    /// (delegates to `SpatialIndex::query_region`).
    pub fn search_region(&self, region: &Region) -> SpatialQueryResult<EmbeddingItem> {
        self.index.query_region(region)
    }
229
230    pub fn remove(&mut self, id: &str) -> Option<EmbeddingItem> {
231        let removed = self.index.remove(&id.to_string());
232        if removed.is_some() {
233            self.invalidate_knn_cache();
234        }
235        removed
236    }
237
238    pub fn get(&self, id: &str) -> Option<&EmbeddingItem> {
239        self.index.get(&id.to_string())
240    }
241
    /// Number of items, as reported by the underlying spatial index.
    pub fn len(&self) -> usize {
        self.index.len()
    }
245
    /// Whether the underlying spatial index reports itself as empty.
    pub fn is_empty(&self) -> bool {
        self.index.is_empty()
    }
249
    /// Borrow the projection this index was built with.
    pub fn projection(&self) -> &P {
        &self.projection
    }
253
    /// Borrow every stored item, in whatever order the underlying spatial
    /// index yields them.
    pub fn all_items(&self) -> Vec<&EmbeddingItem> {
        self.index.all_items()
    }
257
258    /// Find the shortest semantic path between two items through a k-NN graph.
259    ///
260    /// Builds a k-nearest-neighbor graph over all indexed embeddings, then
261    /// runs Dijkstra's algorithm weighted by angular distance. The resulting
262    /// path traces the chain of closest intermediate concepts connecting
263    /// the source to the target.
264    ///
265    /// The k-NN graph is memoized per `k`: the first call at a given `k`
266    /// builds it in O(n · log n · k) (index-assisted) and every
267    /// subsequent call reuses the snapshot until the index mutates.
268    /// Dijkstra itself is O((n + E) · log n) via a binary heap.
269    pub fn concept_path(&self, source_id: &str, target_id: &str, k: usize) -> Option<ConceptPath> {
270        let items = self.index.all_items();
271        let n = items.len();
272        if n < 2 {
273            return None;
274        }
275
276        let id_to_idx: HashMap<&str, usize> = items
277            .iter()
278            .enumerate()
279            .map(|(i, item)| (item.id.as_str(), i))
280            .collect();
281
282        let source_idx = *id_to_idx.get(source_id)?;
283        let target_idx = *id_to_idx.get(target_id)?;
284
285        let adj = self.knn_adjacency(&items, k);
286
287        // Dijkstra (min-heap via reversed Ord)
288        let mut dist = vec![f64::INFINITY; n];
289        let mut prev: Vec<Option<usize>> = vec![None; n];
290        let mut heap = BinaryHeap::new();
291
292        dist[source_idx] = 0.0;
293        heap.push(DijkstraEntry {
294            dist: 0.0,
295            node: source_idx,
296        });
297
298        while let Some(entry) = heap.pop() {
299            let u = entry.node;
300            if entry.dist > dist[u] {
301                continue;
302            }
303            if u == target_idx {
304                break;
305            }
306            for &(v, w) in &adj[u] {
307                let nd = dist[u] + w;
308                if nd < dist[v] {
309                    dist[v] = nd;
310                    prev[v] = Some(u);
311                    heap.push(DijkstraEntry { dist: nd, node: v });
312                }
313            }
314        }
315
316        if dist[target_idx].is_infinite() {
317            return None;
318        }
319
320        // Reconstruct
321        let mut path = Vec::new();
322        let mut cur = target_idx;
323        loop {
324            let hop_distance = prev[cur]
325                .and_then(|p| adj[p].iter().find(|&&(v, _)| v == cur).map(|&(_, d)| d))
326                .unwrap_or(0.0);
327            path.push(PathStep {
328                id: items[cur].id.clone(),
329                cumulative_distance: dist[cur],
330                hop_distance,
331                category: None,
332                bridge_strength: None,
333            });
334            match prev[cur] {
335                Some(p) => cur = p,
336                None => break,
337            }
338        }
339        path.reverse();
340
341        Some(ConceptPath {
342            total_distance: dist[target_idx],
343            steps: path,
344        })
345    }
346
347    /// Find a semantic path that prefers hops with strong conceptual bridges.
348    ///
349    /// Like [`concept_path`](Self::concept_path), but when a hop crosses a
350    /// category boundary, the edge weight is penalized based on the bridge's
351    /// classification:
352    /// - [`BridgeClassification::Genuine`]: `angular_dist / (strength + 0.1)`
353    /// - [`BridgeClassification::Weak`]: `angular_dist / (strength + 0.01)`
354    /// - [`BridgeClassification::OverlapArtifact`]: `angular_dist * 2.0`
355    ///   (shared-territory bridges are actively discouraged — they aren't
356    ///   real connectors).
357    ///
358    /// - `categories`: maps item ID → category index.
359    /// - `bridge_strengths`: maps `(cat_a, cat_b) → (max_bridge_strength, classification)`.
360    ///   Missing entries are treated as a weak no-bridge (strength 0, Weak).
361    ///
362    /// Same-category hops use raw angular distance.
363    pub fn concept_path_bridged(
364        &self,
365        source_id: &str,
366        target_id: &str,
367        k: usize,
368        categories: &HashMap<&str, usize>,
369        bridge_strengths: &HashMap<(usize, usize), (f64, BridgeClassification)>,
370    ) -> Option<ConceptPath> {
371        let items = self.index.all_items();
372        let n = items.len();
373        if n < 2 {
374            return None;
375        }
376
377        let id_to_idx: HashMap<&str, usize> = items
378            .iter()
379            .enumerate()
380            .map(|(i, item)| (item.id.as_str(), i))
381            .collect();
382
383        let source_idx = *id_to_idx.get(source_id)?;
384        let target_idx = *id_to_idx.get(target_id)?;
385
386        // Look up category for each item index
387        let item_cats: Vec<Option<usize>> = items
388            .iter()
389            .map(|item| categories.get(item.id.as_str()).copied())
390            .collect();
391
392        // Reuse the cached raw-angular k-NN adjacency; bridge-aware
393        // weights are derived per edge at Dijkstra time. The previous
394        // implementation materialized a second n-row Vec<Vec<...>>
395        // that duplicated the neighborhood structure — this version
396        // shares the angular graph with `concept_path`.
397        let adj = self.knn_adjacency(&items, k);
398
399        // Dijkstra on effective weights
400        let mut dist = vec![f64::INFINITY; n];
401        let mut prev: Vec<Option<usize>> = vec![None; n];
402        let mut heap = BinaryHeap::new();
403
404        dist[source_idx] = 0.0;
405        heap.push(DijkstraEntry {
406            dist: 0.0,
407            node: source_idx,
408        });
409
410        while let Some(entry) = heap.pop() {
411            let u = entry.node;
412            if entry.dist > dist[u] {
413                continue;
414            }
415            if u == target_idx {
416                break;
417            }
418            for &(v, raw_d) in &adj[u] {
419                let (w, _) = cross_category_weight(raw_d, &item_cats, u, v, bridge_strengths);
420                let nd = dist[u] + w;
421                if nd < dist[v] {
422                    dist[v] = nd;
423                    prev[v] = Some(u);
424                    heap.push(DijkstraEntry { dist: nd, node: v });
425                }
426            }
427        }
428
429        if dist[target_idx].is_infinite() {
430            return None;
431        }
432
433        // Reconstruct with bridge metadata
434        let mut path = Vec::new();
435        let mut cur = target_idx;
436        loop {
437            let edge_info = prev[cur].and_then(|p| {
438                adj[p].iter().find(|&&(v, _)| v == cur).map(|&(_, raw_d)| {
439                    let (_, bs) =
440                        cross_category_weight(raw_d, &item_cats, p, cur, bridge_strengths);
441                    (raw_d, bs)
442                })
443            });
444            let hop_distance = edge_info.map_or(0.0, |(d, _)| d);
445            let bridge_str = edge_info.and_then(|(_, bs)| bs);
446
447            path.push(PathStep {
448                id: items[cur].id.clone(),
449                cumulative_distance: dist[cur],
450                hop_distance,
451                category: item_cats[cur],
452                bridge_strength: bridge_str,
453            });
454            match prev[cur] {
455                Some(p) => cur = p,
456                None => break,
457            }
458        }
459        path.reverse();
460
461        Some(ConceptPath {
462            total_distance: dist[target_idx],
463            steps: path,
464        })
465    }
466}
467
468// --- Concept path types ---
469
/// Result of a concept-path query: the ordered steps from source to target
/// and the accumulated cost of reaching the target.
#[derive(Debug, Clone)]
pub struct ConceptPath {
    /// Steps in travel order; the first entry is the source item and the
    /// last is the target.
    pub steps: Vec<PathStep>,
    /// Dijkstra distance of the target, i.e. the sum of the edge weights
    /// used by the search along the path.
    pub total_distance: f64,
}
475
/// One node on a [`ConceptPath`].
#[derive(Debug, Clone)]
pub struct PathStep {
    /// Identifier of the item at this step.
    pub id: String,
    /// Dijkstra distance from the source to this step.
    pub cumulative_distance: f64,
    /// Angular distance of this hop (0.0 for the first step).
    pub hop_distance: f64,
    /// Category index of this item (None if no category info was provided).
    pub category: Option<usize>,
    /// Bridge strength used on the hop *to* this step (None for same-category
    /// hops or the first step). Present only when `concept_path_bridged` is used.
    pub bridge_strength: Option<f64>,
}
488
/// Frontier entry for Dijkstra: a tentative distance and the node it
/// belongs to.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the entry with
/// the smallest `dist` compares greatest and is popped first. Ties on
/// `dist` are broken by node index (smaller index first) so the pop order
/// is deterministic.
///
/// Distances are compared with `f64::total_cmp`, which is a genuine total
/// order even when NaN is present (a positive NaN sorts after every finite
/// distance and is therefore popped last). Equality is defined through the
/// same comparison, keeping `PartialEq`, `Eq`, `PartialOrd` and `Ord`
/// mutually consistent as their contracts require.
struct DijkstraEntry {
    dist: f64,
    node: usize,
}

impl PartialEq for DijkstraEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for DijkstraEntry {}

impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed on both keys: smaller dist (then smaller node) = higher priority.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}
514
515// --- Slicing manifold ---
516
/// A 2D plane fitted through the 3D projected point cloud that captures
/// the maximum variance. Found by PCA on the Cartesian coordinates of
/// the projected embeddings.
///
/// The plane is defined by:
/// - `centroid`: the mean of all 3D points
/// - `basis_u`, `basis_v`: orthonormal vectors spanning the plane (directions of max variance)
/// - `normal`: vector perpendicular to the plane (direction of minimum variance)
#[derive(Debug, Clone)]
pub struct SlicingManifold {
    /// Mean of the fitted points; origin of the plane's coordinate frame.
    pub centroid: [f64; 3],
    /// Eigenvector paired with the smallest covariance eigenvalue.
    pub normal: [f64; 3],
    /// Eigenvector paired with the largest covariance eigenvalue.
    pub basis_u: [f64; 3],
    /// Eigenvector paired with the second-largest covariance eigenvalue.
    pub basis_v: [f64; 3],
    /// Share of total variance captured by the plane:
    /// `(λ₀ + λ₁) / (λ₀ + λ₁ + λ₂)`, or `1.0` when the total is not positive.
    pub variance_ratio: f64,
}
533
534impl SlicingManifold {
535    /// Fit the optimal slicing plane to a set of 3D points.
536    /// Each point is (x, y, z) in Cartesian coordinates.
537    pub fn fit(points: &[[f64; 3]]) -> Self {
538        let n = points.len() as f64;
539        assert!(n >= 3.0, "need at least 3 points to fit a plane");
540
541        // Centroid
542        let mut c = [0.0; 3];
543        for p in points {
544            for i in 0..3 {
545                c[i] += p[i];
546            }
547        }
548        for ci in &mut c {
549            *ci /= n;
550        }
551
552        // 3×3 covariance matrix (symmetric)
553        let mut cov = [[0.0f64; 3]; 3];
554        for p in points {
555            let d = [p[0] - c[0], p[1] - c[1], p[2] - c[2]];
556            for i in 0..3 {
557                for j in 0..3 {
558                    cov[i][j] += d[i] * d[j];
559                }
560            }
561        }
562        for row in &mut cov {
563            for v in row.iter_mut() {
564                *v /= n;
565            }
566        }
567
568        // Eigendecomposition of 3×3 symmetric matrix via Jacobi iteration
569        let (eigenvalues, eigenvectors) = eigen_symmetric_3x3(&cov);
570
571        // eigenvalues are sorted descending: λ₀ ≥ λ₁ ≥ λ₂
572        // basis_u = eigenvector of λ₀, basis_v = eigenvector of λ₁, normal = eigenvector of λ₂
573        let total_var = eigenvalues[0] + eigenvalues[1] + eigenvalues[2];
574        let variance_ratio = if total_var > 0.0 {
575            (eigenvalues[0] + eigenvalues[1]) / total_var
576        } else {
577            1.0
578        };
579
580        Self {
581            centroid: c,
582            normal: eigenvectors[2],
583            basis_u: eigenvectors[0],
584            basis_v: eigenvectors[1],
585            variance_ratio,
586        }
587    }
588
589    /// Project a 3D point onto the plane, returning (u, v) coordinates.
590    pub fn project_2d(&self, point: &[f64; 3]) -> (f64, f64) {
591        let d = [
592            point[0] - self.centroid[0],
593            point[1] - self.centroid[1],
594            point[2] - self.centroid[2],
595        ];
596        let u = d[0] * self.basis_u[0] + d[1] * self.basis_u[1] + d[2] * self.basis_u[2];
597        let v = d[0] * self.basis_v[0] + d[1] * self.basis_v[1] + d[2] * self.basis_v[2];
598        (u, v)
599    }
600
601    /// Signed distance from the plane (positive = same side as normal).
602    pub fn distance(&self, point: &[f64; 3]) -> f64 {
603        let d = [
604            point[0] - self.centroid[0],
605            point[1] - self.centroid[1],
606            point[2] - self.centroid[2],
607        ];
608        d[0] * self.normal[0] + d[1] * self.normal[1] + d[2] * self.normal[2]
609    }
610
611    /// Fit a local manifold around a query point using its k nearest neighbors.
612    ///
613    /// The local plane captures the shape of the semantic neighborhood:
614    /// - If variance_ratio ≈ 1.0, the neighborhood is flat (concepts spread in a plane)
615    /// - If variance_ratio ≈ 0.67, concepts are uniformly distributed (spherical)
616    /// - The normal direction reveals which semantic axis is least relevant locally
617    ///
618    /// This enables directional search narrowing: once you know the local geometry,
619    /// you can restrict subsequent queries to the dominant plane, cutting the
620    /// effective search dimensionality from 3D to 2D in that region.
621    pub fn fit_local(query: &[f64; 3], all_points: &[[f64; 3]], k: usize) -> Self {
622        let mut dists: Vec<(usize, f64)> = all_points
623            .iter()
624            .enumerate()
625            .map(|(i, p)| (i, dist3(query, p)))
626            .collect();
627        // `total_cmp` gives a total order over all f64 including NaN
628        // (which sorts to the end). Previously `.partial_cmp().unwrap()`
629        // panicked on NaN — and NaN is reachable whenever `all_points`
630        // contains a degenerate entry from a lossy projection, making
631        // this one of the few query-path panic sites in the crate.
632        dists.sort_by(|a, b| a.1.total_cmp(&b.1));
633
634        let neighborhood: Vec<[f64; 3]> = dists
635            .iter()
636            .take(k.max(3))
637            .map(|&(i, _)| all_points[i])
638            .collect();
639
640        Self::fit(&neighborhood)
641    }
642}
643
/// Eigendecomposition of a 3×3 symmetric matrix via Jacobi rotations.
///
/// Returns `(eigenvalues, eigenvectors)` sorted by decreasing eigenvalue;
/// `eigenvectors[i]` is the vector paired with `eigenvalues[i]`.
///
/// At most 50 rotations are applied; iteration stops early once the largest
/// off-diagonal magnitude falls below `1e-15`.
///
/// Non-finite input does not panic: the final ordering uses
/// `f64::total_cmp`, which orders NaN as well, so a degenerate covariance
/// matrix yields NaN eigenvalues instead of aborting the caller.
fn eigen_symmetric_3x3(m: &[[f64; 3]; 3]) -> ([f64; 3], [[f64; 3]; 3]) {
    let mut a = *m;
    // Accumulated rotations; its columns become the eigenvectors.
    let mut v = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    // Index loops are kept on purpose: they mirror the a[i][p] / a[i][q]
    // matrix notation of the Jacobi update.
    #[allow(clippy::needless_range_loop)]
    for _ in 0..50 {
        // Pivot on the largest off-diagonal element.
        let mut p = 0;
        let mut q = 1;
        let mut max_val = a[0][1].abs();
        for i in 0..3 {
            for j in (i + 1)..3 {
                if a[i][j].abs() > max_val {
                    max_val = a[i][j].abs();
                    p = i;
                    q = j;
                }
            }
        }
        if max_val < 1e-15 {
            break;
        }

        // Rotation angle that zeroes a[p][q]; equal diagonal entries make
        // the tangent formula divide by zero, and the angle is then π/4.
        let theta = if (a[p][p] - a[q][q]).abs() < 1e-30 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * a[p][q] / (a[p][p] - a[q][q])).atan()
        };
        let c = theta.cos();
        let s = theta.sin();

        // Rotate a ← GᵀaG: columns first, then rows on the column result.
        let mut new_a = a;
        for i in 0..3 {
            new_a[i][p] = c * a[i][p] + s * a[i][q];
            new_a[i][q] = -s * a[i][p] + c * a[i][q];
        }
        let snapshot = new_a;
        for j in 0..3 {
            new_a[p][j] = c * snapshot[p][j] + s * snapshot[q][j];
            new_a[q][j] = -s * snapshot[p][j] + c * snapshot[q][j];
        }
        // Force the pivot to exactly zero to stop rounding residue from
        // being picked as the next pivot.
        new_a[p][q] = 0.0;
        new_a[q][p] = 0.0;
        a = new_a;

        // Rotate eigenvectors: V ← VG
        let mut new_v = v;
        for i in 0..3 {
            new_v[i][p] = c * v[i][p] + s * v[i][q];
            new_v[i][q] = -s * v[i][p] + c * v[i][q];
        }
        v = new_v;
    }

    let eigenvalues = [a[0][0], a[1][1], a[2][2]];

    // Sort by descending eigenvalue. `total_cmp` never fails, unlike
    // `partial_cmp(..).unwrap()`, which panics as soon as one value is NaN.
    let mut order = [0usize, 1, 2];
    order.sort_by(|&lhs, &rhs| eigenvalues[rhs].total_cmp(&eigenvalues[lhs]));

    let sorted_vals = [
        eigenvalues[order[0]],
        eigenvalues[order[1]],
        eigenvalues[order[2]],
    ];
    // Eigenvectors are columns of v
    let sorted_vecs = [
        [v[0][order[0]], v[1][order[0]], v[2][order[0]]],
        [v[0][order[1]], v[1][order[1]], v[2][order[1]]],
        [v[0][order[2]], v[1][order[2]], v[2][order[2]]],
    ];

    (sorted_vals, sorted_vecs)
}
722
723// --- Concept Globs (spherical k-means + silhouette auto-k) ---
724
/// A cluster of semantically related embeddings in the projected 3D space.
#[derive(Debug, Clone)]
pub struct ConceptGlob {
    /// Cluster index assigned by k-means. Empty clusters are skipped when
    /// globs are built, so ids are not necessarily contiguous.
    pub id: usize,
    /// Normalized mean direction of the members.
    pub centroid: [f64; 3],
    /// Ids of the member points.
    pub member_ids: Vec<String>,
    /// Angular distance of each member from `centroid`, in the same order
    /// as `member_ids`.
    pub member_distances: Vec<f64>,
    /// Largest entry of `member_distances` (at least 0.0).
    pub radius: f64,
}
734
/// Result of glob detection: the set of all globs plus quality metrics.
#[derive(Debug, Clone)]
pub struct GlobResult {
    /// Non-empty clusters found; may be fewer than `k` because empty
    /// clusters are skipped when globs are built.
    pub globs: Vec<ConceptGlob>,
    /// Cluster count k-means was run with (requested or auto-selected).
    pub k: usize,
    /// Silhouette score of the chosen clustering.
    pub silhouette: f64,
}
742
743impl GlobResult {
744    /// Detect concept globs from 3D projected points.
745    ///
746    /// If `k` is `Some`, uses that many clusters.
747    /// If `None`, auto-selects k ∈ [2, max_k] by maximizing the silhouette score.
748    pub fn detect(points: &[[f64; 3]], ids: &[String], k: Option<usize>, max_k: usize) -> Self {
749        let n = points.len();
750        assert_eq!(n, ids.len());
751        assert!(n >= 2, "need at least 2 points for clustering");
752
753        let max_k = max_k.min(n);
754
755        if let Some(k) = k {
756            let k = k.clamp(2, max_k);
757            let (assignments, silhouette) = kmeans_3d(points, k);
758            let globs = build_globs(points, ids, &assignments, k);
759            return Self {
760                globs,
761                k,
762                silhouette,
763            };
764        }
765
766        // Auto-detect: try k = 2..=max_k, pick best silhouette
767        let mut best_k = 2;
768        let mut best_sil = f64::NEG_INFINITY;
769        let mut best_assignments = vec![0usize; n];
770
771        for trial_k in 2..=max_k {
772            let (assignments, sil) = kmeans_3d(points, trial_k);
773            if sil > best_sil {
774                best_sil = sil;
775                best_k = trial_k;
776                best_assignments = assignments;
777            }
778        }
779
780        let globs = build_globs(points, ids, &best_assignments, best_k);
781        Self {
782            globs,
783            k: best_k,
784            silhouette: best_sil,
785        }
786    }
787}
788
789fn kmeans_3d(points: &[[f64; 3]], k: usize) -> (Vec<usize>, f64) {
790    let n = points.len();
791    let max_iter = 50;
792
793    // Init: spread initial centers evenly across the point set
794    let mut centers: Vec<[f64; 3]> = (0..k).map(|i| points[i * n / k]).collect();
795
796    let mut assignments = vec![0usize; n];
797
798    for _ in 0..max_iter {
799        let mut changed = false;
800
801        // Assign by angular distance (direction, not position)
802        for (i, p) in points.iter().enumerate() {
803            let mut best = 0;
804            let mut best_d = f64::MAX;
805            for (j, c) in centers.iter().enumerate() {
806                let d = angular_dist3(p, c);
807                if d < best_d {
808                    best_d = d;
809                    best = j;
810                }
811            }
812            if assignments[i] != best {
813                assignments[i] = best;
814                changed = true;
815            }
816        }
817
818        if !changed {
819            break;
820        }
821
822        // Update centers: mean direction (Euclidean mean of unit vectors, then normalize).
823        // This is the Fréchet mean on S² for concentrated clusters.
824        let mut sums = vec![[0.0f64; 3]; k];
825        let mut counts = vec![0usize; k];
826        for (i, &a) in assignments.iter().enumerate() {
827            let norm_p = normalize3(&points[i]);
828            for (d, &np) in norm_p.iter().enumerate() {
829                sums[a][d] += np;
830            }
831            counts[a] += 1;
832        }
833        for j in 0..k {
834            if counts[j] > 0 {
835                centers[j] = normalize3(&sums[j]);
836            }
837        }
838    }
839
840    let sil = silhouette_score(points, &assignments, k);
841    (assignments, sil)
842}
843
844fn silhouette_score(points: &[[f64; 3]], assignments: &[usize], k: usize) -> f64 {
845    let n = points.len();
846    if n <= 1 || k <= 1 {
847        return 0.0;
848    }
849
850    let mut total = 0.0;
851    for i in 0..n {
852        let ci = assignments[i];
853
854        // a(i): mean angular dist to same-cluster members
855        let mut a_sum = 0.0;
856        let mut a_cnt = 0;
857        for j in 0..n {
858            if j != i && assignments[j] == ci {
859                a_sum += angular_dist3(&points[i], &points[j]);
860                a_cnt += 1;
861            }
862        }
863        let a = if a_cnt > 0 { a_sum / a_cnt as f64 } else { 0.0 };
864
865        // b(i): min mean angular dist to any other cluster
866        let mut b = f64::MAX;
867        for ck in 0..k {
868            if ck == ci {
869                continue;
870            }
871            let mut b_sum = 0.0;
872            let mut b_cnt = 0;
873            for j in 0..n {
874                if assignments[j] == ck {
875                    b_sum += angular_dist3(&points[i], &points[j]);
876                    b_cnt += 1;
877                }
878            }
879            if b_cnt > 0 {
880                b = b.min(b_sum / b_cnt as f64);
881            }
882        }
883        if b == f64::MAX {
884            b = 0.0;
885        }
886
887        let denom = a.max(b);
888        total += if denom > 0.0 { (b - a) / denom } else { 0.0 };
889    }
890
891    total / n as f64
892}
893
894fn build_globs(
895    points: &[[f64; 3]],
896    ids: &[String],
897    assignments: &[usize],
898    k: usize,
899) -> Vec<ConceptGlob> {
900    let mut globs = Vec::with_capacity(k);
901
902    for cluster_id in 0..k {
903        let member_indices: Vec<usize> = assignments
904            .iter()
905            .enumerate()
906            .filter(|&(_, &a)| a == cluster_id)
907            .map(|(i, _)| i)
908            .collect();
909
910        if member_indices.is_empty() {
911            continue;
912        }
913
914        // Centroid: mean direction (normalize to get angular centroid)
915        let mut centroid = [0.0; 3];
916        for &i in &member_indices {
917            let norm_p = normalize3(&points[i]);
918            for (d, c) in centroid.iter_mut().enumerate() {
919                *c += norm_p[d];
920            }
921        }
922        centroid = normalize3(&centroid);
923
924        // Member angular distances from centroid
925        let member_distances: Vec<f64> = member_indices
926            .iter()
927            .map(|&i| angular_dist3(&points[i], &centroid))
928            .collect();
929
930        let radius = member_distances.iter().cloned().fold(0.0f64, f64::max);
931
932        let member_ids: Vec<String> = member_indices.iter().map(|&i| ids[i].clone()).collect();
933
934        globs.push(ConceptGlob {
935            id: cluster_id,
936            centroid,
937            member_ids,
938            member_distances,
939            radius,
940        });
941    }
942
943    globs
944}
945
/// Euclidean (straight-line) distance between two points in 3D.
fn dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(p, q)| {
            let delta = p - q;
            delta * delta
        })
        // Accumulate x, then y, then z.
        .fold(0.0_f64, |acc, sq| acc + sq)
        .sqrt()
}
952
/// Angle, in radians within `[0, π]`, between two 3D vectors interpreted as
/// directions; their lengths do not affect the result.
///
/// When the product of the two lengths is below `f64::EPSILON` (for example
/// when either vector is zero) the direction is treated as undefined and the
/// distance is reported as `0.0`.
fn angular_dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;

    let inner = ax * bx + ay * by + az * bz;
    let len_a = (ax * ax + ay * ay + az * az).sqrt();
    let len_b = (bx * bx + by * by + bz * bz).sqrt();
    let scale = len_a * len_b;

    if scale < f64::EPSILON {
        0.0
    } else {
        // Clamp guards against rounding pushing the cosine just outside [-1, 1].
        let cosine = (inner / scale).clamp(-1.0, 1.0);
        cosine.acos()
    }
}
965
/// Scale a 3D vector to unit length.
///
/// A vector whose length is below `f64::EPSILON` has no usable direction and
/// is returned as the zero vector.
fn normalize3(v: &[f64; 3]) -> [f64; 3] {
    let [x, y, z] = *v;
    let length = (x * x + y * y + z * z).sqrt();

    if length < f64::EPSILON {
        [0.0; 3]
    } else {
        v.map(|component| component / length)
    }
}
974
975/// Compute the effective edge weight for a hop between two items.
976///
977/// Same-category hops use raw angular distance. Cross-category hops are
978/// weighted by the bridge's quality classification:
979/// - `Genuine`:         `angular_dist / (strength + 0.1)`
980/// - `Weak`:            `angular_dist / (strength + 0.01)`  (harsher penalty)
981/// - `OverlapArtifact`: `angular_dist * 2.0`  (actively discouraged — the
982///   two categories share territory, so the "bridge" isn't real)
983///
984/// Unknown cross-category pairs (no entry in `bridge_strengths`) are treated
985/// as `Weak` with strength 0.
986///
987/// Returns (effective_weight, Option<bridge_strength>).
988/// The bridge_strength is None for same-category hops.
989fn cross_category_weight(
990    angular_dist: f64,
991    item_cats: &[Option<usize>],
992    i: usize,
993    j: usize,
994    bridge_strengths: &HashMap<(usize, usize), (f64, BridgeClassification)>,
995) -> (f64, Option<f64>) {
996    match (item_cats[i], item_cats[j]) {
997        (Some(ci), Some(cj)) if ci != cj => {
998            let (strength, classification) = bridge_strengths
999                .get(&(ci, cj))
1000                .or_else(|| bridge_strengths.get(&(cj, ci)))
1001                .copied()
1002                .unwrap_or((0.0, BridgeClassification::Weak));
1003            let weight = match classification {
1004                BridgeClassification::Genuine => angular_dist / (strength + 0.1),
1005                BridgeClassification::OverlapArtifact => angular_dist * 2.0,
1006                BridgeClassification::Weak => angular_dist / (strength + 0.01),
1007            };
1008            (weight, Some(strength))
1009        }
1010        _ => (angular_dist, None),
1011    }
1012}
1013
1014/// Builds SphereQL [`Region`]s from semantic constraints on embeddings.
1015pub struct SemanticQuery;
1016
1017impl SemanticQuery {
1018    /// Spherical cap: all points within `max_angular_distance` radians of the query.
1019    pub fn within_angle<P: Projection>(
1020        query: &Embedding,
1021        projection: &P,
1022        max_angular_distance: f64,
1023    ) -> Region {
1024        let point = projection.project(query);
1025        let half_angle = max_angular_distance.clamp(1e-10, std::f64::consts::PI);
1026        Region::Cap(
1027            Cap::new(
1028                SphericalPoint::new_unchecked(1.0, point.theta, point.phi),
1029                half_angle,
1030            )
1031            .unwrap(),
1032        )
1033    }
1034
1035    /// Spherical cap from a cosine similarity threshold.
1036    /// cos_sim >= threshold ↔ angular_distance <= arccos(threshold).
1037    pub fn above_similarity<P: Projection>(
1038        query: &Embedding,
1039        projection: &P,
1040        min_similarity: f64,
1041    ) -> Region {
1042        let half_angle = min_similarity.clamp(-1.0, 1.0).acos();
1043        Self::within_angle(query, projection, half_angle)
1044    }
1045
1046    /// Radial shell: embeddings whose projected radius falls in [inner, outer].
1047    pub fn in_shell(inner: f64, outer: f64) -> Region {
1048        Region::Shell(Shell::new(inner, outer).expect("invalid shell bounds"))
1049    }
1050
1051    /// Intersection of a similarity cap with a radial shell.
1052    /// "Semantically similar AND within a magnitude/metadata range."
1053    pub fn similar_in_shell<P: Projection>(
1054        query: &Embedding,
1055        projection: &P,
1056        min_similarity: f64,
1057        shell_inner: f64,
1058        shell_outer: f64,
1059    ) -> Region {
1060        Region::intersection(vec![
1061            Self::above_similarity(query, projection, min_similarity),
1062            Self::in_shell(shell_inner, shell_outer),
1063        ])
1064    }
1065}
1066
// Unit tests for the embedding index, the `SemanticQuery` region builders,
// concept-path search (plain and bridged) and `cross_category_weight`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::projection::{PcaProjection, RandomProjection};
    use crate::types::RadialStrategy;
    use sphereql_core::angular_distance;

    // Shorthand: build an `Embedding` from a slice of components.
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    // Six 5-dimensional embeddings: three near the positive axes, one diagonal
    // between the first two, and two near the negative first/second axes.
    fn test_corpus() -> Vec<Embedding> {
        vec![
            emb(&[1.0, 0.0, 0.0, 0.1, 0.0]),
            emb(&[0.0, 1.0, 0.0, 0.0, 0.1]),
            emb(&[0.0, 0.0, 1.0, 0.1, 0.0]),
            emb(&[1.0, 1.0, 0.0, 0.05, 0.05]),
            emb(&[-1.0, 0.0, 0.0, -0.1, 0.0]),
            emb(&[0.0, -1.0, 0.0, 0.0, -0.1]),
        ]
    }

    // --- EmbeddingIndex ---

    // Inserted ids are retrievable; unknown ids are not; `len`/`is_empty` agree.
    #[test]
    fn insert_and_get() {
        // NOTE(review): the arguments are presumably (dimensionality, radial
        // strategy, RNG seed) — confirm against `RandomProjection::new`.
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
        idx.insert("b", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));

        assert_eq!(idx.len(), 2);
        assert!(!idx.is_empty());
        assert!(idx.get("a").is_some());
        assert!(idx.get("b").is_some());
        assert!(idx.get("c").is_none());
    }

    // Removing an existing id returns the item and leaves the index empty.
    #[test]
    fn remove() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp).build();

        idx.insert("a", &emb(&[1.0; 5]));
        assert_eq!(idx.len(), 1);

        let removed = idx.remove("a");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().id, "a");
        assert_eq!(idx.len(), 0);
        assert!(idx.get("a").is_none());
    }

    // Removing an id that was never inserted yields `None`.
    #[test]
    fn remove_nonexistent() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp).build();
        assert!(idx.remove("nope").is_none());
    }

    // Nearest-neighbour results come back in non-decreasing distance order.
    #[test]
    fn search_nearest_returns_sorted() {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let mut idx = EmbeddingIndex::builder(pca)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        for (i, e) in corpus.iter().enumerate() {
            idx.insert(format!("item-{i}"), e);
        }

        let query = emb(&[0.95, 0.1, 0.0, 0.05, 0.0]);
        let results = idx.search_nearest(&query, 3);

        assert_eq!(results.len(), 3);
        assert!(results[0].distance <= results[1].distance);
        assert!(results[1].distance <= results[2].distance);
    }

    // Every item returned for similarity >= 0.5 lies within arccos(0.5) of the
    // projected query. Only an upper bound is checked: the test does not
    // require any particular item to be returned.
    #[test]
    fn search_similar_respects_threshold() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        idx.insert("close_a", &emb(&[1.0, 0.1, 0.0, 0.0, 0.0]));
        idx.insert("close_b", &emb(&[0.9, 0.2, 0.0, 0.0, 0.0]));
        idx.insert("far", &emb(&[-1.0, 0.0, 0.0, 0.0, 0.0]));

        let query = emb(&[1.0, 0.0, 0.0, 0.0, 0.0]);
        let projected_query = idx.projection().project(&query);
        let result = idx.search_similar(&query, 0.5);

        let max_angle = 0.5_f64.acos();
        for item in &result.items {
            let d = angular_distance(&projected_query, item.position());
            assert!(d <= max_angle + 1e-10, "item {} too far: {d}", item.id);
        }
    }

    // `insert_with_radius` stores the caller-supplied radius on the item.
    #[test]
    fn insert_with_radius_overrides() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp).build();

        idx.insert_with_radius("custom", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]), 42.0);
        let item = idx.get("custom").unwrap();
        assert!((item.position.r - 42.0).abs() < 1e-12);
    }

    // The stored `original_magnitude` is the embedding's norm (3-4-5 triangle).
    #[test]
    fn original_magnitude_stored() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp).build();

        let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0]);
        idx.insert("vec", &e);
        let item = idx.get("vec").unwrap();
        assert!((item.original_magnitude - 5.0).abs() < 1e-10);
    }

    // With the `Magnitude` radial strategy, a shell query selects by magnitude.
    // NOTE(review): "small" is inserted but never asserted on.
    #[test]
    fn magnitude_radial_with_shell_query() {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude).unwrap();
        let mut idx = EmbeddingIndex::builder(pca)
            .uniform_shells(5, 10.0)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        idx.insert("small", &emb(&[0.1, 0.0, 0.0, 0.0, 0.0]));
        idx.insert("medium", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
        idx.insert("large", &emb(&[5.0, 0.0, 0.0, 0.0, 0.0]));

        let shell = Shell::new(0.5, 2.0).unwrap();
        let result = idx.search_region(&Region::Shell(shell));

        let ids: Vec<&str> = result.items.iter().map(|i| i.id.as_str()).collect();
        assert!(
            ids.contains(&"medium"),
            "medium (mag=1.0) should be in [0.5, 2.0]"
        );
        assert!(
            !ids.contains(&"large"),
            "large (mag=5.0) should not be in [0.5, 2.0]"
        );
    }

    // A freshly built index is empty and every lookup/search returns nothing.
    #[test]
    fn empty_index() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let idx = EmbeddingIndex::builder(rp).build();

        assert!(idx.is_empty());
        assert_eq!(idx.len(), 0);
        assert!(idx.get("x").is_none());

        let results = idx.search_nearest(&emb(&[1.0; 5]), 5);
        assert!(results.is_empty());
    }

    // `projection()` exposes the projection the index was built with.
    #[test]
    fn projection_accessor() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let idx = EmbeddingIndex::builder(rp).build();
        assert_eq!(idx.projection().dimensionality(), 5);
    }

    // --- SemanticQuery ---

    // A similarity threshold produces a `Region::Cap`.
    #[test]
    fn above_similarity_creates_cap() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let region = SemanticQuery::above_similarity(&emb(&[1.0; 5]), &rp, 0.8);
        assert!(matches!(region, Region::Cap(_)));
    }

    // An angular bound produces a `Region::Cap`.
    #[test]
    fn within_angle_creates_cap() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let region = SemanticQuery::within_angle(&emb(&[1.0; 5]), &rp, 0.5);
        assert!(matches!(region, Region::Cap(_)));
    }

    // Radial bounds produce a `Region::Shell`.
    #[test]
    fn in_shell_creates_shell() {
        let region = SemanticQuery::in_shell(1.0, 5.0);
        assert!(matches!(region, Region::Shell(_)));
    }

    // The combined query is an intersection of exactly [cap, shell], in that order.
    #[test]
    fn similar_in_shell_creates_intersection() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let region = SemanticQuery::similar_in_shell(&emb(&[1.0; 5]), &rp, 0.7, 1.0, 5.0);

        match region {
            Region::Intersection(parts) => {
                assert_eq!(parts.len(), 2);
                assert!(matches!(parts[0], Region::Cap(_)));
                assert!(matches!(parts[1], Region::Shell(_)));
            }
            other => panic!("expected Intersection, got {other:?}"),
        }
    }

    // A region built by `SemanticQuery` can be fed to `search_region`, and
    // every returned item is contained in that region.
    #[test]
    fn semantic_query_region_used_in_index() {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        // The builder takes `pca` by value; keep a clone to build the query region.
        let projection_clone = pca.clone();
        let mut idx = EmbeddingIndex::builder(pca)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        for (i, e) in corpus.iter().enumerate() {
            idx.insert(format!("item-{i}"), e);
        }

        let query = emb(&[1.0, 0.0, 0.0, 0.05, 0.0]);
        let region = SemanticQuery::above_similarity(&query, &projection_clone, 0.5);
        let result = idx.search_region(&region);

        for item in &result.items {
            assert!(region.contains(item.position()));
        }
    }

    // --- concept_path PathStep fields ---

    // On an unbridged path the first step has zero hop distance, later steps
    // have a positive one, and no step carries category or bridge metadata.
    // NOTE(review): if `concept_path` returns `None`, the `if let` is skipped
    // and the test passes without asserting anything.
    #[test]
    fn concept_path_populates_hop_distance() {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let mut idx = EmbeddingIndex::builder(pca)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        for (i, e) in corpus.iter().enumerate() {
            idx.insert(format!("item-{i}"), e);
        }

        if let Some(path) = idx.concept_path("item-0", "item-4", 3) {
            assert!(path.steps[0].hop_distance == 0.0, "first step has no hop");
            for step in &path.steps[1..] {
                assert!(
                    step.hop_distance > 0.0,
                    "subsequent steps should have a hop distance"
                );
            }
            assert!(path.steps.iter().all(|s| s.category.is_none()));
            assert!(path.steps.iter().all(|s| s.bridge_strength.is_none()));
        }
    }

    // --- concept_path_bridged ---

    // When every item shares one category there are no cross-category hops,
    // so the bridged search must match the plain one in length and cost.
    #[test]
    fn concept_path_bridged_same_category_equals_unbridged() {
        let corpus = test_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let mut idx = EmbeddingIndex::builder(pca)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        for (i, e) in corpus.iter().enumerate() {
            idx.insert(format!("item-{i}"), e);
        }

        // All items in the same category — bridged path should equal unbridged
        let categories: HashMap<&str, usize> = (0..6)
            .map(|i| {
                (
                    ["item-0", "item-1", "item-2", "item-3", "item-4", "item-5"][i],
                    0,
                )
            })
            .collect();
        let bridges = HashMap::new();

        let unbridged = idx.concept_path("item-0", "item-3", 3);
        let bridged = idx.concept_path_bridged("item-0", "item-3", 3, &categories, &bridges);

        match (unbridged, bridged) {
            (Some(u), Some(b)) => {
                assert_eq!(u.steps.len(), b.steps.len());
                assert!((u.total_distance - b.total_distance).abs() < 1e-10);
                for step in &b.steps {
                    assert_eq!(step.category, Some(0));
                    assert!(step.bridge_strength.is_none());
                }
            }
            (None, None) => {} // both unreachable is fine
            _ => panic!("bridged and unbridged should agree on reachability"),
        }
    }

    // The same cross-category route must cost more over a weak bridge than
    // over a strong, genuine one.
    // NOTE(review): if either search returns `None`, nothing is asserted.
    #[test]
    fn concept_path_bridged_penalizes_weak_bridges() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        // Create two clusters in different categories
        // Category 0: items close to [1, 0, 0, 0, 0]
        idx.insert("a0", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
        idx.insert("a1", &emb(&[0.9, 0.1, 0.0, 0.0, 0.0]));
        // Category 1: items close to [0, 1, 0, 0, 0]
        idx.insert("b0", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
        idx.insert("b1", &emb(&[0.1, 0.9, 0.0, 0.0, 0.0]));

        let mut categories: HashMap<&str, usize> = HashMap::new();
        categories.insert("a0", 0);
        categories.insert("a1", 0);
        categories.insert("b0", 1);
        categories.insert("b1", 1);

        // Weak bridge between categories
        let mut weak_bridges = HashMap::new();
        weak_bridges.insert((0, 1), (0.05, BridgeClassification::Weak));

        // Strong bridge between categories
        let mut strong_bridges = HashMap::new();
        strong_bridges.insert((0, 1), (0.95, BridgeClassification::Genuine));

        let weak_path = idx.concept_path_bridged("a0", "b0", 3, &categories, &weak_bridges);
        let strong_path = idx.concept_path_bridged("a0", "b0", 3, &categories, &strong_bridges);

        // Both should find a path (same topology)
        // But weak bridge should have higher total_distance
        if let (Some(weak), Some(strong)) = (weak_path, strong_path) {
            assert!(
                weak.total_distance > strong.total_distance,
                "weak bridge ({:.4}) should produce higher cost than strong ({:.4})",
                weak.total_distance,
                strong.total_distance
            );
        }
    }

    // A bridged path annotates every step with its category and records a
    // bridge strength on at least one step when it crosses categories.
    // NOTE(review): if no path is found, nothing is asserted.
    #[test]
    fn concept_path_bridged_populates_bridge_metadata() {
        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
        let mut idx = EmbeddingIndex::builder(rp)
            .theta_divisions(4)
            .phi_divisions(3)
            .build();

        idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
        idx.insert("b", &emb(&[0.5, 0.5, 0.0, 0.0, 0.0]));
        idx.insert("c", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));

        let mut categories: HashMap<&str, usize> = HashMap::new();
        categories.insert("a", 0);
        categories.insert("b", 0);
        categories.insert("c", 1);

        let mut bridges = HashMap::new();
        bridges.insert((0, 1), (0.7, BridgeClassification::Genuine));

        if let Some(path) = idx.concept_path_bridged("a", "c", 3, &categories, &bridges) {
            // Each step should have category metadata
            for step in &path.steps {
                assert!(step.category.is_some());
            }
            // At least one cross-category hop should have bridge_strength
            let has_bridge = path.steps.iter().any(|s| s.bridge_strength.is_some());
            assert!(
                has_bridge,
                "should record bridge strength on cross-category hop"
            );
        }
    }

    // --- cross_category_weight ---

    // Same category on both ends: raw distance, no bridge strength.
    #[test]
    fn cross_category_weight_same_category() {
        let cats = vec![Some(0), Some(0)];
        let bridges = HashMap::new();
        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
        assert!((weight - 0.5).abs() < 1e-10);
        assert!(bs.is_none());
    }

    // Different categories with no recorded bridge fall back to Weak / 0.0.
    #[test]
    fn cross_category_weight_different_categories_no_bridge() {
        let cats = vec![Some(0), Some(1)];
        let bridges = HashMap::new();
        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
        // Missing entry → treated as Weak with strength 0: 0.5 / (0 + 0.01) = 50.0
        assert!((weight - 50.0).abs() < 1e-10);
        assert_eq!(bs, Some(0.0));
    }

    // Genuine bridges divide the distance by (strength + 0.1).
    #[test]
    fn cross_category_weight_genuine_bridge() {
        let cats = vec![Some(0), Some(1)];
        let mut bridges = HashMap::new();
        bridges.insert((0, 1), (0.9, BridgeClassification::Genuine));
        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
        // Genuine: 0.5 / (0.9 + 0.1) = 0.5
        assert!((weight - 0.5).abs() < 1e-10);
        assert_eq!(bs, Some(0.9));
    }

    // Weak bridges divide the distance by (strength + 0.01).
    #[test]
    fn cross_category_weight_weak_bridge() {
        let cats = vec![Some(0), Some(1)];
        let mut bridges = HashMap::new();
        bridges.insert((0, 1), (0.3, BridgeClassification::Weak));
        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
        // Weak: 0.5 / (0.3 + 0.01) ≈ 1.6129
        assert!((weight - 0.5 / 0.31).abs() < 1e-10);
        assert_eq!(bs, Some(0.3));
    }

    // Overlap artifacts double the distance regardless of the stored strength.
    #[test]
    fn cross_category_weight_overlap_artifact_discouraged() {
        let cats = vec![Some(0), Some(1)];
        let mut bridges = HashMap::new();
        bridges.insert((0, 1), (0.8, BridgeClassification::OverlapArtifact));
        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
        // OverlapArtifact: 0.5 * 2.0 = 1.0 (penalty, not reward)
        assert!((weight - 1.0).abs() < 1e-10);
        assert_eq!(bs, Some(0.8));
    }

    // An uncategorized endpoint is treated like a same-category hop.
    #[test]
    fn cross_category_weight_no_category_info() {
        let cats = vec![None, Some(1)];
        let bridges = HashMap::new();
        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
        assert!((weight - 0.5).abs() < 1e-10);
        assert!(bs.is_none());
    }
}