Skip to main content

sphereql_embed/
ann.rs

1//! Approximate nearest neighbors via random projection forest.
2//!
3//! Builds `n_trees` binary trees by recursively splitting the data at
4//! random hyperplanes. At query time, each tree nominates a candidate
5//! leaf; the union of candidates is re-scored exactly to produce the
6//! final k-NN list.
7//!
8//! Deterministic for a given seed (uses [`SplitMix64`]).
9//!
10//! Complexity:
11//! - Build: O(N · d · trees · log N)
12//! - Query: O(trees · log N · d + |candidates| · d)
13//!
14//! Designed for cosine similarity (all vectors are L2-normalized
15//! internally). Reusable across UMAP graph construction,
16//! `GraphModularity` scoring, and downstream consumers (globetrot
17//! similarity lookup).
18
19use crate::projection::{SplitMix64, dot, normalize_vec};
20
/// Seed that [`AnnConfig::default`] hands to the forest's PRNG.
///
/// It is a fixed constant rather than an entropy-derived value, so an index
/// built with the default configuration always starts from the same seed.
const DEFAULT_ANN_SEED: u64 = 0xA00F_0E57;
23
/// Configuration for the RP-forest index.
///
/// `n_trees` and `max_leaf_size` must both be non-zero; the index builders
/// assert this and panic otherwise.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnnConfig {
    /// Number of random projection trees. More trees = better recall,
    /// slower build + query. 8 is a good default for N < 1M.
    pub n_trees: usize,
    /// Maximum leaf size. A node holding at most this many items becomes a
    /// leaf; only larger nodes are split further. Smaller values give
    /// deeper trees and fewer candidates per tree to re-score.
    pub max_leaf_size: usize,
    /// PRNG seed for reproducibility.
    pub seed: u64,
}
36
37impl Default for AnnConfig {
38    fn default() -> Self {
39        Self {
40            n_trees: 8,
41            max_leaf_size: 40,
42            seed: DEFAULT_ANN_SEED,
43        }
44    }
45}
46
/// A built RP-forest index over L2-normalized vectors.
///
/// Construct one with [`AnnIndex::build`] (raw vectors) or
/// [`AnnIndex::build_normalized`] (vectors the caller already normalized).
pub struct AnnIndex {
    /// The forest: `AnnConfig::n_trees` trees, each built over the full set
    /// of item indices.
    trees: Vec<RpTree>,
    /// L2-normalized copies of the input vectors. Stored so that exact
    /// re-scoring at query time uses the same normalization as the
    /// tree splits.
    normalized: Vec<Vec<f64>>,
    /// Dimensionality shared by every stored vector; queries must match it.
    dim: usize,
}
56
/// One node in an RP-tree. Either a split (hyperplane + children) or a
/// leaf (list of item indices).
enum RpNode {
    /// Interior node: routes a vector to one child by comparing its
    /// projection onto `normal` against `offset`.
    Split {
        /// Unit normal of the splitting hyperplane.
        normal: Vec<f64>,
        /// Projection threshold, taken as the median projection of the
        /// node's items at build time. Items with
        /// `dot(x, normal) >= offset` go right, the rest go left. (When a
        /// split would leave one side empty the builder instead halves the
        /// items by position, so this rule does not describe that case.)
        offset: f64,
        /// Subtree for projections strictly below `offset`.
        left: Box<RpNode>,
        /// Subtree for projections at or above `offset`.
        right: Box<RpNode>,
    },
    /// Terminal node: the item indices nominated as candidates when a query
    /// lands here.
    Leaf {
        indices: Vec<usize>,
    },
}
72
/// A single random-projection tree of the forest, held by its root node.
struct RpTree {
    /// Root of the tree; a lone `Leaf` when the data set fits in one leaf.
    root: RpNode,
}
76
77impl AnnIndex {
78    /// Build the index from raw (un-normalized) vectors. Each vector is
79    /// L2-normalized internally. All vectors must have the same
80    /// dimensionality. Panics if `data` is empty or dimensions
81    /// disagree.
82    pub fn build(data: &[Vec<f64>], config: &AnnConfig) -> Self {
83        assert!(
84            !data.is_empty(),
85            "AnnIndex::build requires at least one vector"
86        );
87        let dim = data[0].len();
88        for (i, v) in data.iter().enumerate() {
89            assert_eq!(
90                v.len(),
91                dim,
92                "AnnIndex::build: vector {i} has dim {}, expected {dim}",
93                v.len()
94            );
95        }
96
97        let normalized: Vec<Vec<f64>> = data
98            .iter()
99            .map(|v| {
100                let mut n = v.clone();
101                normalize_vec(&mut n);
102                n
103            })
104            .collect();
105
106        Self::build_from_normalized(normalized, dim, config)
107    }
108
109    /// Build from pre-normalized vectors (avoids a redundant
110    /// normalization pass when the caller already has unit vectors).
111    /// Enforces the same contract as [`Self::build`]: panics if
112    /// `normalized` is empty or dimensions disagree.
113    pub fn build_normalized(normalized: Vec<Vec<f64>>, config: &AnnConfig) -> Self {
114        assert!(
115            !normalized.is_empty(),
116            "AnnIndex::build_normalized requires at least one vector"
117        );
118        let dim = normalized[0].len();
119        for (i, v) in normalized.iter().enumerate() {
120            assert_eq!(
121                v.len(),
122                dim,
123                "AnnIndex::build_normalized: vector {i} has dim {}, expected {dim}",
124                v.len()
125            );
126        }
127        Self::build_from_normalized(normalized, dim, config)
128    }
129
130    fn build_from_normalized(normalized: Vec<Vec<f64>>, dim: usize, config: &AnnConfig) -> Self {
131        assert!(
132            config.n_trees > 0,
133            "AnnConfig.n_trees must be > 0 (zero trees yields an index that returns no neighbors)"
134        );
135        assert!(
136            config.max_leaf_size > 0,
137            "AnnConfig.max_leaf_size must be > 0 (zero recurses forever on singleton partitions)"
138        );
139
140        let all_indices: Vec<usize> = (0..normalized.len()).collect();
141        let mut rng = SplitMix64::new(config.seed);
142
143        let trees: Vec<RpTree> = (0..config.n_trees)
144            .map(|_| {
145                let root = build_tree(
146                    &normalized,
147                    &all_indices,
148                    dim,
149                    config.max_leaf_size,
150                    &mut rng,
151                );
152                RpTree { root }
153            })
154            .collect();
155
156        Self {
157            trees,
158            normalized,
159            dim,
160        }
161    }
162
163    /// Find the `k` approximate nearest neighbors of `query` by cosine
164    /// similarity. Returns `(index, similarity)` pairs sorted by
165    /// descending similarity. `query` is L2-normalized internally.
166    pub fn query(&self, query: &[f64], k: usize) -> Vec<(usize, f64)> {
167        assert_eq!(query.len(), self.dim);
168        let mut q = query.to_vec();
169        normalize_vec(&mut q);
170
171        let mut candidates = Vec::new();
172        for tree in &self.trees {
173            collect_leaf(&tree.root, &q, &mut candidates);
174        }
175        candidates.sort_unstable();
176        candidates.dedup();
177
178        let mut scored: Vec<(usize, f64)> = candidates
179            .iter()
180            .map(|&i| (i, dot(&q, &self.normalized[i])))
181            .collect();
182        scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
183        scored.truncate(k);
184        scored
185    }
186
187    /// Find the `k` approximate nearest neighbors of the item at
188    /// `index` (excludes self from results).
189    pub fn query_by_index(&self, index: usize, k: usize) -> Vec<(usize, f64)> {
190        let q = &self.normalized[index];
191        let mut candidates = Vec::new();
192        for tree in &self.trees {
193            collect_leaf(&tree.root, q, &mut candidates);
194        }
195        candidates.sort_unstable();
196        candidates.dedup();
197
198        let mut scored: Vec<(usize, f64)> = candidates
199            .iter()
200            .filter(|&&i| i != index)
201            .map(|&i| (i, dot(q, &self.normalized[i])))
202            .collect();
203        scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
204        scored.truncate(k);
205        scored
206    }
207
208    /// Build a full k-NN adjacency list for all items. Returns
209    /// `knn[i]` = indices of the k nearest neighbors of item `i`
210    /// (excluding self), sorted by descending similarity.
211    pub fn knn_graph(&self, k: usize) -> Vec<Vec<usize>> {
212        self.knn_graph_with_sims(k)
213            .into_iter()
214            .map(|row| row.into_iter().map(|(j, _)| j).collect())
215            .collect()
216    }
217
218    /// Like [`Self::knn_graph`], but keeps each neighbor's cosine
219    /// similarity: `graph[i]` = `(index, similarity)` pairs for the k
220    /// nearest neighbors of item `i` (excluding self), sorted by
221    /// descending similarity.
222    pub fn knn_graph_with_sims(&self, k: usize) -> Vec<Vec<(usize, f64)>> {
223        (0..self.normalized.len())
224            .map(|i| self.query_by_index(i, k))
225            .collect()
226    }
227
228    /// Number of indexed items.
229    pub fn len(&self) -> usize {
230        self.normalized.len()
231    }
232
233    /// True if the index contains no items.
234    pub fn is_empty(&self) -> bool {
235        self.normalized.is_empty()
236    }
237}
238
239// ── Tree construction ──────────────────────────────────────────────────
240
241fn build_tree(
242    data: &[Vec<f64>],
243    indices: &[usize],
244    dim: usize,
245    max_leaf: usize,
246    rng: &mut SplitMix64,
247) -> RpNode {
248    if indices.len() <= max_leaf {
249        return RpNode::Leaf {
250            indices: indices.to_vec(),
251        };
252    }
253
254    // Annoy-style RP split: pick two items, take their difference as the
255    // hyperplane normal. Falls back to a random Gaussian normal if the
256    // two items collide.
257    let a = indices[(rng.next_u64() as usize) % indices.len()];
258    let mut b = indices[(rng.next_u64() as usize) % indices.len()];
259    let mut attempts = 0;
260    while b == a && attempts < 10 {
261        b = indices[(rng.next_u64() as usize) % indices.len()];
262        attempts += 1;
263    }
264
265    let mut normal: Vec<f64> = data[a]
266        .iter()
267        .zip(data[b].iter())
268        .map(|(&ai, &bi)| ai - bi)
269        .collect();
270    let mag = normalize_vec(&mut normal);
271    if mag < f64::EPSILON {
272        normal = (0..dim).map(|_| rng.normal()).collect();
273        normalize_vec(&mut normal);
274    }
275
276    // Median projection gives a balanced split.
277    let mut projections: Vec<f64> = indices.iter().map(|&i| dot(&data[i], &normal)).collect();
278    projections.sort_unstable_by(|a, b| a.total_cmp(b));
279    let offset = projections[projections.len() / 2];
280
281    let mut left_idx = Vec::new();
282    let mut right_idx = Vec::new();
283    for &i in indices {
284        if dot(&data[i], &normal) < offset {
285            left_idx.push(i);
286        } else {
287            right_idx.push(i);
288        }
289    }
290
291    // Guard against degenerate splits where every item lands on one side.
292    if left_idx.is_empty() || right_idx.is_empty() {
293        let mid = indices.len() / 2;
294        left_idx = indices[..mid].to_vec();
295        right_idx = indices[mid..].to_vec();
296    }
297
298    let left = build_tree(data, &left_idx, dim, max_leaf, rng);
299    let right = build_tree(data, &right_idx, dim, max_leaf, rng);
300
301    RpNode::Split {
302        normal,
303        offset,
304        left: Box::new(left),
305        right: Box::new(right),
306    }
307}
308
309fn collect_leaf(node: &RpNode, query: &[f64], out: &mut Vec<usize>) {
310    match node {
311        RpNode::Leaf { indices } => {
312            out.extend_from_slice(indices);
313        }
314        RpNode::Split {
315            normal,
316            offset,
317            left,
318            right,
319        } => {
320            if dot(query, normal) < *offset {
321                collect_leaf(left, query, out);
322            } else {
323                collect_leaf(right, query, out);
324            }
325        }
326    }
327}
328
329// ── Tests ────────────────────────────────────────────────────────────
330
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` vectors of `dim` Gaussian components, reproducible for `seed`.
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut rng = SplitMix64::new(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.normal()).collect())
            .collect()
    }

    #[test]
    fn build_and_query_smoke() {
        let data = random_vectors(200, 32, 42);
        let index = AnnIndex::build(&data, &AnnConfig::default());
        assert_eq!(index.len(), 200);
        assert!(!index.is_empty());

        let results = index.query(&data[0], 5);
        assert_eq!(results.len(), 5);
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
        // An indexed vector is its own best match.
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn query_by_index_excludes_self() {
        let data = random_vectors(100, 16, 7);
        let index = AnnIndex::build(&data, &AnnConfig::default());
        let results = index.query_by_index(0, 5);
        assert!(results.iter().all(|(i, _)| *i != 0));
    }

    #[test]
    fn knn_graph_shape() {
        let data = random_vectors(50, 16, 99);
        let index = AnnIndex::build(&data, &AnnConfig::default());
        let knn = index.knn_graph(5);
        assert_eq!(knn.len(), 50);
        for neighbors in &knn {
            assert_eq!(neighbors.len(), 5);
        }
    }

    #[test]
    fn knn_graph_with_sims_matches_knn_graph() {
        let data = random_vectors(50, 16, 99);
        let index = AnnIndex::build(&data, &AnnConfig::default());
        let plain = index.knn_graph(5);
        let with_sims = index.knn_graph_with_sims(5);
        assert_eq!(plain.len(), with_sims.len());
        for (row, srow) in plain.iter().zip(&with_sims) {
            assert_eq!(row.len(), srow.len());
            for (j, (sj, sim)) in row.iter().zip(srow) {
                assert_eq!(j, sj);
                assert!(sim.is_finite() && *sim <= 1.0 + 1e-12);
            }
            for w in srow.windows(2) {
                assert!(w[0].1 >= w[1].1);
            }
        }
    }

    #[test]
    fn deterministic_with_same_seed() {
        let data = random_vectors(100, 16, 42);
        let cfg = AnnConfig {
            seed: 0xBEEF,
            ..AnnConfig::default()
        };
        let index1 = AnnIndex::build(&data, &cfg);
        let index2 = AnnIndex::build(&data, &cfg);
        let r1 = index1.query(&data[5], 10);
        let r2 = index2.query(&data[5], 10);
        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.0, b.0);
            assert!((a.1 - b.1).abs() < 1e-12);
        }
    }

    #[test]
    fn finds_true_nearest_in_top_results() {
        // Two tight clusters; query a member of cluster A. True nearest
        // neighbors should all be from cluster A.
        let mut data = Vec::new();
        let mut rng = SplitMix64::new(42);
        for _ in 0..50 {
            let mut v = vec![0.0; 16];
            v[0] = 1.0 + rng.normal() * 0.05;
            v[1] = 0.0 + rng.normal() * 0.05;
            data.push(v);
        }
        for _ in 0..50 {
            let mut v = vec![0.0; 16];
            v[0] = 0.0 + rng.normal() * 0.05;
            v[1] = 1.0 + rng.normal() * 0.05;
            data.push(v);
        }

        let index = AnnIndex::build(&data, &AnnConfig::default());
        let results = index.query_by_index(0, 10);
        for (idx, _) in &results {
            assert!(*idx < 50, "expected cluster A member, got index {idx}");
        }
    }

    // The panic tests pin a fragment of the expected message so that an
    // unrelated panic (e.g. an out-of-bounds index) cannot satisfy them.

    #[test]
    #[should_panic(expected = "requires at least one vector")]
    fn empty_panics() {
        AnnIndex::build(&[], &AnnConfig::default());
    }

    #[test]
    #[should_panic(expected = "vector 1 has dim 2, expected 3")]
    fn build_normalized_ragged_input_panics() {
        AnnIndex::build_normalized(
            vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0]],
            &AnnConfig::default(),
        );
    }

    #[test]
    #[should_panic(expected = "n_trees must be > 0")]
    fn zero_trees_panics() {
        let cfg = AnnConfig {
            n_trees: 0,
            ..AnnConfig::default()
        };
        AnnIndex::build(&[vec![1.0, 0.0]], &cfg);
    }

    #[test]
    #[should_panic(expected = "max_leaf_size must be > 0")]
    fn zero_max_leaf_size_panics() {
        let cfg = AnnConfig {
            max_leaf_size: 0,
            ..AnnConfig::default()
        };
        AnnIndex::build(&[vec![1.0, 0.0]], &cfg);
    }
}
457}