1use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8
9use crate::dtype::Float;
10use crate::error::{CoreError, Result};
11use crate::tensor::Tensor;
12
13struct HeapEntry<T> {
19 sq_dist: T,
20 index: usize,
21}
22
23impl<T: Float> PartialEq for HeapEntry<T> {
24 fn eq(&self, other: &Self) -> bool {
25 self.sq_dist.to_f64() == other.sq_dist.to_f64()
26 }
27}
28
29impl<T: Float> Eq for HeapEntry<T> {}
30
31impl<T: Float> PartialOrd for HeapEntry<T> {
32 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
33 Some(self.cmp(other))
34 }
35}
36
37impl<T: Float> Ord for HeapEntry<T> {
38 fn cmp(&self, other: &Self) -> Ordering {
39 self.sq_dist
40 .to_f64()
41 .partial_cmp(&other.sq_dist.to_f64())
42 .unwrap_or(Ordering::Equal)
43 }
44}
45
/// Maximum number of points kept in a leaf: `KdTree::build_recursive` stops
/// splitting once a subset contains at most this many points.
const LEAF_SIZE: usize = 10;
52
/// A node of the KD-tree.
///
/// Nodes never hold coordinates themselves; they refer to points by their
/// row index into the owning `KdTree`'s flat `data` buffer.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
enum KdNode {
    /// Bucket of points. `build_recursive` emits a leaf once a subset has at
    /// most `LEAF_SIZE` points.
    Leaf {
        /// Row indices of the points in this bucket.
        indices: Vec<usize>,
    },
    /// Split of the point subset along one coordinate axis.
    Internal {
        /// Axis the subset was split on (chosen as the axis of widest spread).
        split_dim: usize,
        /// Median coordinate along `split_dim`, converted to `f64`. Points in
        /// `left` have a coordinate `<=` this value and points in `right` a
        /// coordinate `>=` it; ties may land on either side.
        split_value: f64,
        /// Subtree holding the lower half of the subset along `split_dim`.
        left: Box<KdNode>,
        /// Subtree holding the upper half of the subset along `split_dim`.
        right: Box<KdNode>,
    },
}
69
/// KD-tree over points of a fixed dimensionality, answering k-nearest-neighbour,
/// fixed-radius and all-pairs queries under Euclidean distance.
///
/// Points are identified by the row index they had when the tree was built.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct KdTree<T: Float> {
    /// Coordinates of every point, row-major: point `i` occupies
    /// `data[i * dim..(i + 1) * dim]`.
    data: Vec<T>,
    /// Number of coordinates per point.
    dim: usize,
    /// Number of points stored in the tree.
    n_points: usize,
    /// Root of the node hierarchy.
    root: KdNode,
}
111
112impl<T: Float> KdTree<T> {
113 pub fn build(points: &[&[T]]) -> Result<Self> {
122 if points.is_empty() {
123 return Err(CoreError::InvalidArgument {
124 reason: "cannot build KD-tree from empty point set",
125 });
126 }
127 let dim = points[0].len();
128 if dim == 0 {
129 return Err(CoreError::InvalidArgument {
130 reason: "point dimensionality must be at least 1",
131 });
132 }
133 for (i, p) in points.iter().enumerate() {
134 if p.len() != dim {
135 return Err(CoreError::InvalidArgument {
136 reason: "all points must have the same dimensionality",
137 });
138 }
139 let _ = i; }
141
142 let n_points = points.len();
143 let mut data = Vec::with_capacity(n_points * dim);
144 for p in points {
145 data.extend_from_slice(p);
146 }
147
148 let indices: Vec<usize> = (0..n_points).collect();
149 let root = Self::build_recursive(&data, dim, indices);
150
151 Ok(Self {
152 data,
153 dim,
154 n_points,
155 root,
156 })
157 }
158
159 pub fn from_tensor(tensor: &Tensor<T>) -> Result<Self> {
165 let shape = tensor.shape();
166 if shape.len() != 2 {
167 return Err(CoreError::InvalidArgument {
168 reason: "tensor must be 2-dimensional (rows = points, cols = dims)",
169 });
170 }
171 let n = shape[0];
172 let dim = shape[1];
173 if n == 0 {
174 return Err(CoreError::InvalidArgument {
175 reason: "cannot build KD-tree from empty point set",
176 });
177 }
178
179 let slice = tensor.as_slice();
180 let refs: Vec<&[T]> = (0..n).map(|i| &slice[i * dim..(i + 1) * dim]).collect();
181 Self::build(&refs)
182 }
183
184 pub fn query(&self, query: &[T], k: usize) -> Result<(Vec<usize>, Vec<T>)> {
192 if k == 0 {
193 return Err(CoreError::InvalidArgument {
194 reason: "k must be at least 1",
195 });
196 }
197 if query.len() != self.dim {
198 return Err(CoreError::InvalidArgument {
199 reason: "query dimensionality does not match tree",
200 });
201 }
202 let k = k.min(self.n_points);
203
204 let mut heap: BinaryHeap<HeapEntry<T>> = BinaryHeap::new();
205 self.knn_recursive(&self.root, query, k, &mut heap);
206
207 let mut results: Vec<(usize, T)> = heap
209 .into_sorted_vec()
210 .into_iter()
211 .map(|e| (e.index, e.sq_dist.sqrt()))
212 .collect();
213 results.sort_by(|a, b| {
214 a.1.to_f64()
215 .partial_cmp(&b.1.to_f64())
216 .unwrap_or(Ordering::Equal)
217 });
218 let indices = results.iter().map(|(i, _)| *i).collect();
219 let dists = results.iter().map(|(_, d)| *d).collect();
220 Ok((indices, dists))
221 }
222
223 pub fn query_radius(&self, query: &[T], radius: T) -> Result<(Vec<usize>, Vec<T>)> {
231 if query.len() != self.dim {
232 return Err(CoreError::InvalidArgument {
233 reason: "query dimensionality does not match tree",
234 });
235 }
236 let sq_radius = radius * radius;
237 let mut results: Vec<(usize, T)> = Vec::new();
238 self.range_recursive(&self.root, query, sq_radius, &mut results);
239
240 results.sort_by(|a, b| {
242 a.1.to_f64()
243 .partial_cmp(&b.1.to_f64())
244 .unwrap_or(Ordering::Equal)
245 });
246 let indices = results.iter().map(|(i, _)| *i).collect();
247 let dists = results.into_iter().map(|(_, d)| d.sqrt()).collect();
248 Ok((indices, dists))
249 }
250
251 pub fn query_pairs(&self, r: T) -> Vec<(usize, usize)> {
255 let sq_r = r * r;
256 let mut pairs = Vec::new();
257 for i in 0..self.n_points {
258 let point = &self.data[i * self.dim..(i + 1) * self.dim];
259 let mut neighbors: Vec<(usize, T)> = Vec::new();
260 self.range_recursive(&self.root, point, sq_r, &mut neighbors);
261 for (j, _) in neighbors {
262 if i < j {
263 pairs.push((i, j));
264 }
265 }
266 }
267 pairs.sort_unstable();
268 pairs.dedup();
269 pairs
270 }
271
272 #[inline]
274 pub fn len(&self) -> usize {
275 self.n_points
276 }
277
278 #[inline]
280 pub fn is_empty(&self) -> bool {
281 self.n_points == 0
282 }
283
284 #[inline]
286 pub fn dim(&self) -> usize {
287 self.dim
288 }
289
290 fn build_recursive(data: &[T], dim: usize, mut indices: Vec<usize>) -> KdNode {
295 if indices.len() <= LEAF_SIZE {
296 return KdNode::Leaf { indices };
297 }
298
299 let split_dim = Self::widest_spread_dim(data, dim, &indices);
301
302 indices.sort_by(|&a, &b| {
304 let va = data[a * dim + split_dim].to_f64();
305 let vb = data[b * dim + split_dim].to_f64();
306 va.partial_cmp(&vb).unwrap_or(Ordering::Equal)
307 });
308
309 let median_idx = indices.len() / 2;
310 let split_value = data[indices[median_idx] * dim + split_dim].to_f64();
311
312 let right_indices = indices.split_off(median_idx);
313 let left_indices = indices;
314
315 let left = Box::new(Self::build_recursive(data, dim, left_indices));
316 let right = Box::new(Self::build_recursive(data, dim, right_indices));
317
318 KdNode::Internal {
319 split_dim,
320 split_value,
321 left,
322 right,
323 }
324 }
325
326 fn widest_spread_dim(data: &[T], dim: usize, indices: &[usize]) -> usize {
327 let mut best_dim = 0;
328 let mut best_spread = f64::NEG_INFINITY;
329 for d in 0..dim {
330 let mut lo = f64::INFINITY;
331 let mut hi = f64::NEG_INFINITY;
332 for &idx in indices {
333 let v = data[idx * dim + d].to_f64();
334 if v < lo {
335 lo = v;
336 }
337 if v > hi {
338 hi = v;
339 }
340 }
341 let spread = hi - lo;
342 if spread > best_spread {
343 best_spread = spread;
344 best_dim = d;
345 }
346 }
347 best_dim
348 }
349
350 fn squared_distance(&self, a: &[T], b_idx: usize) -> T {
355 let mut sum = T::zero();
356 let offset = b_idx * self.dim;
357 for (d, a_val) in a.iter().enumerate().take(self.dim) {
358 let diff = *a_val - self.data[offset + d];
359 sum += diff * diff;
360 }
361 sum
362 }
363
364 fn knn_recursive(
365 &self,
366 node: &KdNode,
367 query: &[T],
368 k: usize,
369 heap: &mut BinaryHeap<HeapEntry<T>>,
370 ) {
371 match node {
372 KdNode::Leaf { indices } => {
373 for &idx in indices {
374 let sq_dist = self.squared_distance(query, idx);
375 if heap.len() < k {
376 heap.push(HeapEntry {
377 sq_dist,
378 index: idx,
379 });
380 } else if heap
381 .peek()
382 .is_some_and(|worst| sq_dist.to_f64() < worst.sq_dist.to_f64())
383 {
384 heap.pop();
385 heap.push(HeapEntry {
386 sq_dist,
387 index: idx,
388 });
389 }
390 }
391 }
392 KdNode::Internal {
393 split_dim,
394 split_value,
395 left,
396 right,
397 } => {
398 let query_val = query[*split_dim].to_f64();
399 let diff = query_val - split_value;
400
401 let (first, second) = if diff <= 0.0 {
402 (left, right)
403 } else {
404 (right, left)
405 };
406
407 self.knn_recursive(first, query, k, heap);
408
409 let should_visit =
412 heap.len() < k || diff * diff < heap.peek().unwrap().sq_dist.to_f64();
413 if should_visit {
414 self.knn_recursive(second, query, k, heap);
415 }
416 }
417 }
418 }
419
420 fn range_recursive(
425 &self,
426 node: &KdNode,
427 query: &[T],
428 sq_radius: T,
429 results: &mut Vec<(usize, T)>,
430 ) {
431 match node {
432 KdNode::Leaf { indices } => {
433 for &idx in indices {
434 let sq_dist = self.squared_distance(query, idx);
435 if sq_dist.to_f64() <= sq_radius.to_f64() {
436 results.push((idx, sq_dist));
437 }
438 }
439 }
440 KdNode::Internal {
441 split_dim,
442 split_value,
443 left,
444 right,
445 } => {
446 let query_val = query[*split_dim].to_f64();
447 let diff = query_val - split_value;
448 let sq_diff = diff * diff;
449
450 let (first, second) = if diff <= 0.0 {
451 (left, right)
452 } else {
453 (right, left)
454 };
455
456 self.range_recursive(first, query, sq_radius, results);
457
458 if sq_diff <= sq_radius.to_f64() {
459 self.range_recursive(second, query, sq_radius, results);
460 }
461 }
462 }
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random points in `[0, 10)^dim` from a 64-bit LCG,
    /// so the large-tree tests are reproducible without an RNG dependency.
    fn lcg_points(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut state = seed;
        let mut points = Vec::with_capacity(n);
        for _ in 0..n {
            let mut p = Vec::with_capacity(dim);
            for _ in 0..dim {
                state = state
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                p.push((state >> 11) as f64 / (1_u64 << 53) as f64 * 10.0);
            }
            points.push(p);
        }
        points
    }

    /// Reference squared distance, accumulated in the same order as the tree.
    fn brute_sq_dist(a: &[f64], b: &[f64]) -> f64 {
        let mut sum = 0.0;
        for (x, y) in a.iter().zip(b) {
            let d = x - y;
            sum += d * d;
        }
        sum
    }

    #[test]
    fn test_kd_tree_knn_exact_match() {
        let pts: [[f64; 2]; 4] = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, dists) = tree.query(&[0.0, 0.0], 1).unwrap();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 0);
        assert!(dists[0].abs() < 1e-12);
    }

    #[test]
    fn test_kd_tree_knn_k3_sorted() {
        let pts: [[f64; 2]; 5] = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [5.0, 0.0], [10.0, 0.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, dists) = tree.query(&[0.5, 0.0], 3).unwrap();
        assert_eq!(indices.len(), 3);
        assert!(dists[0] <= dists[1]);
        assert!(dists[1] <= dists[2]);
        assert!((dists[0] - 0.5).abs() < 1e-12);
        assert!((dists[1] - 0.5).abs() < 1e-12);
        assert!((dists[2] - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_kd_tree_range_query() {
        let pts: [[f64; 2]; 4] = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [10.0, 0.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, _dists) = tree.query_radius(&[0.0, 0.0], 1.5).unwrap();
        assert_eq!(indices.len(), 2);
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
    }

    #[test]
    fn test_kd_tree_query_pairs() {
        let pts: [[f64; 2]; 3] = [[0.0, 0.0], [0.5, 0.0], [10.0, 0.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let pairs = tree.query_pairs(1.0);
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], (0, 1));
    }

    #[test]
    fn test_kd_tree_from_tensor() {
        let data = vec![0.0_f64, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let tensor = Tensor::from_vec(data, vec![4, 2]).unwrap();
        let tree = KdTree::from_tensor(&tensor).unwrap();
        assert_eq!(tree.len(), 4);
        assert_eq!(tree.dim(), 2);

        let (indices, _) = tree.query(&[0.0, 0.0], 1).unwrap();
        assert_eq!(indices[0], 0);
    }

    #[test]
    fn test_kd_tree_high_dimensional() {
        let pts: [[f64; 5]; 3] = [
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [1.0, 1.0, 1.0, 1.0, 1.0],
            [2.0, 2.0, 2.0, 2.0, 2.0],
        ];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 5]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, dists) = tree.query(&[0.0, 0.0, 0.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(indices[0], 0);
        assert!(dists[0].abs() < 1e-12);

        let (indices, dists) = tree.query(&[0.0, 0.0, 0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(indices.len(), 2);
        assert!((dists[1] - 5.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_kd_tree_single_point() {
        let pts: [[f64; 2]; 1] = [[42.0_f64, 7.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();
        assert_eq!(tree.len(), 1);

        let (indices, dists) = tree.query(&[42.0, 7.0], 1).unwrap();
        assert_eq!(indices[0], 0);
        assert!(dists[0].abs() < 1e-12);
    }

    #[test]
    fn test_kd_tree_error_empty() {
        let refs: Vec<&[f64]> = vec![];
        let result = KdTree::build(&refs);
        assert!(result.is_err());
    }

    #[test]
    fn test_kd_tree_error_k_zero() {
        let pts: [[f64; 2]; 1] = [[0.0_f64, 0.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let result = tree.query(&[0.0, 0.0], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_kd_tree_knn_matches_brute_force() {
        // Far above LEAF_SIZE, so internal nodes and far-side pruning in
        // `knn_recursive` are exercised (the small fixtures above never are).
        let pts = lcg_points(200, 3, 7);
        let refs: Vec<&[f64]> = pts.iter().map(Vec::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        for q in lcg_points(25, 3, 99) {
            let mut brute: Vec<f64> = pts.iter().map(|p| brute_sq_dist(p, &q).sqrt()).collect();
            brute.sort_by(f64::total_cmp);

            for &k in &[1_usize, 7, 40, 500] {
                let (indices, dists) = tree.query(&q, k).unwrap();
                let expected = k.min(pts.len());
                assert_eq!(indices.len(), expected);
                assert_eq!(dists.len(), expected);
                for (rank, (&idx, &d)) in indices.iter().zip(&dists).enumerate() {
                    assert!((d - brute[rank]).abs() < 1e-12);
                    assert!((d - brute_sq_dist(&pts[idx], &q).sqrt()).abs() < 1e-12);
                }
                if k >= pts.len() {
                    let mut all = indices.clone();
                    all.sort_unstable();
                    assert_eq!(all, (0..pts.len()).collect::<Vec<usize>>());
                }
            }
        }
    }

    #[test]
    fn test_kd_tree_range_matches_brute_force() {
        let pts = lcg_points(150, 2, 3);
        let refs: Vec<&[f64]> = pts.iter().map(Vec::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        for q in lcg_points(10, 2, 42) {
            for &radius in &[0.5_f64, 2.0, 4.0] {
                let (mut indices, dists) = tree.query_radius(&q, radius).unwrap();
                assert_eq!(indices.len(), dists.len());
                assert!(dists.windows(2).all(|w| w[0] <= w[1]));
                indices.sort_unstable();
                let expected: Vec<usize> = (0..pts.len())
                    .filter(|&i| brute_sq_dist(&pts[i], &q) <= radius * radius)
                    .collect();
                assert_eq!(indices, expected);
            }
        }
    }

    #[test]
    fn test_kd_tree_query_pairs_matches_brute_force() {
        let pts = lcg_points(60, 2, 11);
        let refs: Vec<&[f64]> = pts.iter().map(Vec::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let r = 1.5_f64;
        let mut expected = Vec::new();
        for i in 0..pts.len() {
            for j in (i + 1)..pts.len() {
                if brute_sq_dist(&pts[i], &pts[j]) <= r * r {
                    expected.push((i, j));
                }
            }
        }
        assert_eq!(tree.query_pairs(r), expected);
    }

    #[test]
    fn test_kd_tree_many_duplicates() {
        // Zero spread on every axis: the median split must still shrink each
        // subset so construction terminates.
        let pts = vec![[1.0_f64, 2.0]; 25];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();
        assert_eq!(tree.len(), 25);

        let (indices, dists) = tree.query(&[1.0, 2.0], 5).unwrap();
        assert_eq!(indices.len(), 5);
        assert!(dists.iter().all(|d| d.abs() < 1e-12));

        let (all, _) = tree.query_radius(&[1.0, 2.0], 0.0).unwrap();
        assert_eq!(all.len(), 25);
    }

    #[test]
    fn test_kd_tree_error_dimension_mismatch() {
        let pts: [[f64; 2]; 2] = [[0.0, 0.0], [1.0, 1.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();
        assert!(tree.query(&[0.0], 1).is_err());
        assert!(tree.query_radius(&[0.0, 0.0, 0.0], 1.0).is_err());

        let a: &[f64] = &[0.0, 0.0];
        let b: &[f64] = &[1.0];
        assert!(KdTree::build(&[a, b]).is_err());

        let empty_point: &[f64] = &[];
        assert!(KdTree::build(&[empty_point]).is_err());
    }
}