1use crate::projection::{SplitMix64, dot, normalize_vec};
20
/// Seed that `AnnConfig::default()` hands to the tree-building RNG.
const DEFAULT_ANN_SEED: u64 = 0xA00F_0E57;
23
/// Tunable parameters for building an [`AnnIndex`].
#[derive(Clone, Debug)]
pub struct AnnConfig {
    /// Number of random-projection trees to build. Must be greater than zero.
    pub n_trees: usize,
    /// A partition holding at most this many points becomes a leaf instead of
    /// being split further. Must be greater than zero.
    pub max_leaf_size: usize,
    /// Seed for the `SplitMix64` generator that drives tree construction.
    pub seed: u64,
}
36
37impl Default for AnnConfig {
38 fn default() -> Self {
39 Self {
40 n_trees: 8,
41 max_leaf_size: 40,
42 seed: DEFAULT_ANN_SEED,
43 }
44 }
45}
46
47pub struct AnnIndex {
49 trees: Vec<RpTree>,
50 normalized: Vec<Vec<f64>>,
54 dim: usize,
55}
56
/// One node of a random-projection tree.
enum RpNode {
    /// Interior node: a point whose dot product with `normal` is strictly
    /// below `offset` belongs to `left`, every other point to `right`.
    Split {
        /// Direction the points are projected onto.
        normal: Vec<f64>,
        /// Projection threshold separating the two children.
        offset: f64,
        /// Subtree for projections below `offset`.
        left: Box<Self>,
        /// Subtree for projections at or above `offset`.
        right: Box<Self>,
    },
    /// Terminal node holding the indices of the points that landed here.
    Leaf {
        /// Positions into the index's vector list.
        indices: Vec<usize>,
    },
}
72
73struct RpTree {
74 root: RpNode,
75}
76
77impl AnnIndex {
78 pub fn build(data: &[Vec<f64>], config: &AnnConfig) -> Self {
83 assert!(
84 !data.is_empty(),
85 "AnnIndex::build requires at least one vector"
86 );
87 let dim = data[0].len();
88 for (i, v) in data.iter().enumerate() {
89 assert_eq!(
90 v.len(),
91 dim,
92 "AnnIndex::build: vector {i} has dim {}, expected {dim}",
93 v.len()
94 );
95 }
96
97 let normalized: Vec<Vec<f64>> = data
98 .iter()
99 .map(|v| {
100 let mut n = v.clone();
101 normalize_vec(&mut n);
102 n
103 })
104 .collect();
105
106 Self::build_from_normalized(normalized, dim, config)
107 }
108
109 pub fn build_normalized(normalized: Vec<Vec<f64>>, config: &AnnConfig) -> Self {
114 assert!(
115 !normalized.is_empty(),
116 "AnnIndex::build_normalized requires at least one vector"
117 );
118 let dim = normalized[0].len();
119 for (i, v) in normalized.iter().enumerate() {
120 assert_eq!(
121 v.len(),
122 dim,
123 "AnnIndex::build_normalized: vector {i} has dim {}, expected {dim}",
124 v.len()
125 );
126 }
127 Self::build_from_normalized(normalized, dim, config)
128 }
129
130 fn build_from_normalized(normalized: Vec<Vec<f64>>, dim: usize, config: &AnnConfig) -> Self {
131 assert!(
132 config.n_trees > 0,
133 "AnnConfig.n_trees must be > 0 (zero trees yields an index that returns no neighbors)"
134 );
135 assert!(
136 config.max_leaf_size > 0,
137 "AnnConfig.max_leaf_size must be > 0 (zero recurses forever on singleton partitions)"
138 );
139
140 let all_indices: Vec<usize> = (0..normalized.len()).collect();
141 let mut rng = SplitMix64::new(config.seed);
142
143 let trees: Vec<RpTree> = (0..config.n_trees)
144 .map(|_| {
145 let root = build_tree(
146 &normalized,
147 &all_indices,
148 dim,
149 config.max_leaf_size,
150 &mut rng,
151 );
152 RpTree { root }
153 })
154 .collect();
155
156 Self {
157 trees,
158 normalized,
159 dim,
160 }
161 }
162
163 pub fn query(&self, query: &[f64], k: usize) -> Vec<(usize, f64)> {
167 assert_eq!(query.len(), self.dim);
168 let mut q = query.to_vec();
169 normalize_vec(&mut q);
170
171 let mut candidates = Vec::new();
172 for tree in &self.trees {
173 collect_leaf(&tree.root, &q, &mut candidates);
174 }
175 candidates.sort_unstable();
176 candidates.dedup();
177
178 let mut scored: Vec<(usize, f64)> = candidates
179 .iter()
180 .map(|&i| (i, dot(&q, &self.normalized[i])))
181 .collect();
182 scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
183 scored.truncate(k);
184 scored
185 }
186
187 pub fn query_by_index(&self, index: usize, k: usize) -> Vec<(usize, f64)> {
190 let q = &self.normalized[index];
191 let mut candidates = Vec::new();
192 for tree in &self.trees {
193 collect_leaf(&tree.root, q, &mut candidates);
194 }
195 candidates.sort_unstable();
196 candidates.dedup();
197
198 let mut scored: Vec<(usize, f64)> = candidates
199 .iter()
200 .filter(|&&i| i != index)
201 .map(|&i| (i, dot(q, &self.normalized[i])))
202 .collect();
203 scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
204 scored.truncate(k);
205 scored
206 }
207
208 pub fn knn_graph(&self, k: usize) -> Vec<Vec<usize>> {
212 self.knn_graph_with_sims(k)
213 .into_iter()
214 .map(|row| row.into_iter().map(|(j, _)| j).collect())
215 .collect()
216 }
217
218 pub fn knn_graph_with_sims(&self, k: usize) -> Vec<Vec<(usize, f64)>> {
223 (0..self.normalized.len())
224 .map(|i| self.query_by_index(i, k))
225 .collect()
226 }
227
228 pub fn len(&self) -> usize {
230 self.normalized.len()
231 }
232
233 pub fn is_empty(&self) -> bool {
235 self.normalized.is_empty()
236 }
237}
238
239fn build_tree(
242 data: &[Vec<f64>],
243 indices: &[usize],
244 dim: usize,
245 max_leaf: usize,
246 rng: &mut SplitMix64,
247) -> RpNode {
248 if indices.len() <= max_leaf {
249 return RpNode::Leaf {
250 indices: indices.to_vec(),
251 };
252 }
253
254 let a = indices[(rng.next_u64() as usize) % indices.len()];
258 let mut b = indices[(rng.next_u64() as usize) % indices.len()];
259 let mut attempts = 0;
260 while b == a && attempts < 10 {
261 b = indices[(rng.next_u64() as usize) % indices.len()];
262 attempts += 1;
263 }
264
265 let mut normal: Vec<f64> = data[a]
266 .iter()
267 .zip(data[b].iter())
268 .map(|(&ai, &bi)| ai - bi)
269 .collect();
270 let mag = normalize_vec(&mut normal);
271 if mag < f64::EPSILON {
272 normal = (0..dim).map(|_| rng.normal()).collect();
273 normalize_vec(&mut normal);
274 }
275
276 let mut projections: Vec<f64> = indices.iter().map(|&i| dot(&data[i], &normal)).collect();
278 projections.sort_unstable_by(|a, b| a.total_cmp(b));
279 let offset = projections[projections.len() / 2];
280
281 let mut left_idx = Vec::new();
282 let mut right_idx = Vec::new();
283 for &i in indices {
284 if dot(&data[i], &normal) < offset {
285 left_idx.push(i);
286 } else {
287 right_idx.push(i);
288 }
289 }
290
291 if left_idx.is_empty() || right_idx.is_empty() {
293 let mid = indices.len() / 2;
294 left_idx = indices[..mid].to_vec();
295 right_idx = indices[mid..].to_vec();
296 }
297
298 let left = build_tree(data, &left_idx, dim, max_leaf, rng);
299 let right = build_tree(data, &right_idx, dim, max_leaf, rng);
300
301 RpNode::Split {
302 normal,
303 offset,
304 left: Box::new(left),
305 right: Box::new(right),
306 }
307}
308
309fn collect_leaf(node: &RpNode, query: &[f64], out: &mut Vec<usize>) {
310 match node {
311 RpNode::Leaf { indices } => {
312 out.extend_from_slice(indices);
313 }
314 RpNode::Split {
315 normal,
316 offset,
317 left,
318 right,
319 } => {
320 if dot(query, normal) < *offset {
321 collect_leaf(left, query, out);
322 } else {
323 collect_leaf(right, query, out);
324 }
325 }
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `n` vectors of length `dim` from `rng.normal()` draws of a
    /// `SplitMix64` seeded with `seed`.
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut rng = SplitMix64::new(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.normal()).collect())
            .collect()
    }

    #[test]
    fn build_and_query_smoke() {
        let data = random_vectors(200, 32, 42);
        let index = AnnIndex::build(&data, &AnnConfig::default());
        assert_eq!(index.len(), 200);
        assert!(!index.is_empty());

        let results = index.query(&data[0], 5);
        assert_eq!(results.len(), 5);
        // Results must be ordered by descending similarity.
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
        // Querying with a stored vector must return that vector first.
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn query_by_index_excludes_self() {
        let data = random_vectors(100, 16, 7);
        let index = AnnIndex::build(&data, &AnnConfig::default());
        let results = index.query_by_index(0, 5);
        assert!(results.iter().all(|(i, _)| *i != 0));
    }

    #[test]
    fn knn_graph_shape() {
        let data = random_vectors(50, 16, 99);
        let index = AnnIndex::build(&data, &AnnConfig::default());
        let knn = index.knn_graph(5);
        assert_eq!(knn.len(), 50);
        for neighbors in &knn {
            assert_eq!(neighbors.len(), 5);
        }
    }

    #[test]
    fn knn_graph_with_sims_matches_knn_graph() {
        let data = random_vectors(50, 16, 99);
        let index = AnnIndex::build(&data, &AnnConfig::default());
        let plain = index.knn_graph(5);
        let with_sims = index.knn_graph_with_sims(5);
        assert_eq!(plain.len(), with_sims.len());
        for (row, srow) in plain.iter().zip(&with_sims) {
            assert_eq!(row.len(), srow.len());
            for (j, (sj, sim)) in row.iter().zip(srow) {
                assert_eq!(j, sj);
                assert!(sim.is_finite() && *sim <= 1.0 + 1e-12);
            }
            for w in srow.windows(2) {
                assert!(w[0].1 >= w[1].1);
            }
        }
    }

    #[test]
    fn deterministic_with_same_seed() {
        let data = random_vectors(100, 16, 42);
        let cfg = AnnConfig {
            seed: 0xBEEF,
            ..AnnConfig::default()
        };
        let index1 = AnnIndex::build(&data, &cfg);
        let index2 = AnnIndex::build(&data, &cfg);
        let r1 = index1.query(&data[5], 10);
        let r2 = index2.query(&data[5], 10);
        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.0, b.0);
            assert!((a.1 - b.1).abs() < 1e-12);
        }
    }

    #[test]
    fn finds_true_nearest_in_top_results() {
        let mut data = Vec::new();
        let mut rng = SplitMix64::new(42);
        // Cluster A (indices 0..50): concentrated around the first axis.
        for _ in 0..50 {
            let mut v = vec![0.0; 16];
            v[0] = 1.0 + rng.normal() * 0.05;
            v[1] = 0.0 + rng.normal() * 0.05;
            data.push(v);
        }
        // Cluster B (indices 50..100): concentrated around the second axis.
        for _ in 0..50 {
            let mut v = vec![0.0; 16];
            v[0] = 0.0 + rng.normal() * 0.05;
            v[1] = 1.0 + rng.normal() * 0.05;
            data.push(v);
        }

        let index = AnnIndex::build(&data, &AnnConfig::default());
        let results = index.query_by_index(0, 10);
        for (idx, _) in &results {
            assert!(*idx < 50, "expected cluster A member, got index {idx}");
        }
    }

    // The expected substring ties this test to the empty-input assertion, so
    // an unrelated panic cannot make it pass.
    #[test]
    #[should_panic(expected = "requires at least one vector")]
    fn empty_panics() {
        let _ = AnnIndex::build(&[], &AnnConfig::default());
    }

    // The second vector has length 2 while the first has length 3; the
    // expected substring is the tail of the dimension-mismatch message.
    #[test]
    #[should_panic(expected = "vector 1 has dim 2, expected 3")]
    fn build_normalized_ragged_input_panics() {
        let _ = AnnIndex::build_normalized(
            vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0]],
            &AnnConfig::default(),
        );
    }
}