// scivex_core/spatial.rs

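//! Spatial data structures for nearest-neighbor search.
//!
//! Provides a KD-tree (`KdTree`) supporting k-nearest-neighbor, fixed-radius,
//! and all-pairs-within-distance queries in multi-dimensional space. (A ball
//! tree is not currently provided by this module.)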
5
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8
9use crate::dtype::Float;
10use crate::error::{CoreError, Result};
11use crate::tensor::Tensor;
12
13// ---------------------------------------------------------------------------
14// Internal helpers
15// ---------------------------------------------------------------------------
16
17/// Max-heap entry for KNN search. Stores squared distance and original index.
18struct HeapEntry<T> {
19    sq_dist: T,
20    index: usize,
21}
22
23impl<T: Float> PartialEq for HeapEntry<T> {
24    fn eq(&self, other: &Self) -> bool {
25        self.sq_dist.to_f64() == other.sq_dist.to_f64()
26    }
27}
28
29impl<T: Float> Eq for HeapEntry<T> {}
30
31impl<T: Float> PartialOrd for HeapEntry<T> {
32    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
33        Some(self.cmp(other))
34    }
35}
36
37impl<T: Float> Ord for HeapEntry<T> {
38    fn cmp(&self, other: &Self) -> Ordering {
39        self.sq_dist
40            .to_f64()
41            .partial_cmp(&other.sq_dist.to_f64())
42            .unwrap_or(Ordering::Equal)
43    }
44}
45
46// ---------------------------------------------------------------------------
47// KD-tree node
48// ---------------------------------------------------------------------------
49
/// Maximum number of points a leaf may hold; larger subsets are split further.
const LEAF_SIZE: usize = 10;

/// One node of a [`KdTree`]: either a bucket of point indices or an
/// axis-aligned split.
///
/// Variant and field order are kept stable because the optional serde derive
/// depends on them for some serialization formats.
#[derive(Clone, Debug)]
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
enum KdNode {
    /// Terminal bucket storing the input indices of its points.
    Leaf { indices: Vec<usize> },
    /// Split on coordinate `split_dim` at `split_value`. When built by the tree,
    /// points under `left` have that coordinate `<= split_value` and points
    /// under `right` have it `>= split_value`.
    Internal {
        split_dim: usize,
        split_value: f64,
        left: Box<KdNode>,
        right: Box<KdNode>,
    },
}
69
70// ---------------------------------------------------------------------------
71// KdTree
72// ---------------------------------------------------------------------------
73
74/// A KD-tree for efficient spatial queries.
75///
76/// Points are stored as a flat `Vec<T>` with known dimensionality. The tree
77/// structure is a recursive enum of internal split nodes and leaf nodes.
78///
79/// # Examples
80///
81/// ```
82/// use scivex_core::spatial::KdTree;
83///
84/// let points: Vec<Vec<f64>> = vec![
85///     vec![0.0, 0.0],
86///     vec![1.0, 0.0],
87///     vec![0.0, 1.0],
88///     vec![1.0, 1.0],
89/// ];
90/// let refs: Vec<&[f64]> = points.iter().map(|p| p.as_slice()).collect();
91/// let tree = KdTree::build(&refs).unwrap();
92///
93/// let (indices, dists) = tree.query(&[0.1, 0.1], 1).unwrap();
94/// assert_eq!(indices[0], 0);
95/// ```
96#[cfg_attr(
97    feature = "serde-support",
98    derive(serde::Serialize, serde::Deserialize)
99)]
100#[derive(Debug, Clone)]
101pub struct KdTree<T: Float> {
102    /// Flat storage: point `i` occupies `data[i*dim .. (i+1)*dim]`.
103    data: Vec<T>,
104    /// Dimensionality of each point.
105    dim: usize,
106    /// Number of points.
107    n_points: usize,
108    /// Root of the tree.
109    root: KdNode,
110}
111
112impl<T: Float> KdTree<T> {
113    /// Build a KD-tree from a set of points.
114    ///
115    /// `points` is a slice of point slices, each of length `dim`.
116    ///
117    /// # Errors
118    ///
119    /// Returns `CoreError::InvalidArgument` if `points` is empty or if the
120    /// point slices have inconsistent lengths.
121    pub fn build(points: &[&[T]]) -> Result<Self> {
122        if points.is_empty() {
123            return Err(CoreError::InvalidArgument {
124                reason: "cannot build KD-tree from empty point set",
125            });
126        }
127        let dim = points[0].len();
128        if dim == 0 {
129            return Err(CoreError::InvalidArgument {
130                reason: "point dimensionality must be at least 1",
131            });
132        }
133        for (i, p) in points.iter().enumerate() {
134            if p.len() != dim {
135                return Err(CoreError::InvalidArgument {
136                    reason: "all points must have the same dimensionality",
137                });
138            }
139            let _ = i; // suppress unused warning
140        }
141
142        let n_points = points.len();
143        let mut data = Vec::with_capacity(n_points * dim);
144        for p in points {
145            data.extend_from_slice(p);
146        }
147
148        let indices: Vec<usize> = (0..n_points).collect();
149        let root = Self::build_recursive(&data, dim, indices);
150
151        Ok(Self {
152            data,
153            dim,
154            n_points,
155            root,
156        })
157    }
158
159    /// Build from a 2D tensor where each row is a point.
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if the tensor is not 2-dimensional or is empty.
164    pub fn from_tensor(tensor: &Tensor<T>) -> Result<Self> {
165        let shape = tensor.shape();
166        if shape.len() != 2 {
167            return Err(CoreError::InvalidArgument {
168                reason: "tensor must be 2-dimensional (rows = points, cols = dims)",
169            });
170        }
171        let n = shape[0];
172        let dim = shape[1];
173        if n == 0 {
174            return Err(CoreError::InvalidArgument {
175                reason: "cannot build KD-tree from empty point set",
176            });
177        }
178
179        let slice = tensor.as_slice();
180        let refs: Vec<&[T]> = (0..n).map(|i| &slice[i * dim..(i + 1) * dim]).collect();
181        Self::build(&refs)
182    }
183
184    /// Find the k nearest neighbors to `query`.
185    ///
186    /// Returns `(indices, distances)` sorted by distance (ascending).
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if `k == 0` or the query dimension does not match.
191    pub fn query(&self, query: &[T], k: usize) -> Result<(Vec<usize>, Vec<T>)> {
192        if k == 0 {
193            return Err(CoreError::InvalidArgument {
194                reason: "k must be at least 1",
195            });
196        }
197        if query.len() != self.dim {
198            return Err(CoreError::InvalidArgument {
199                reason: "query dimensionality does not match tree",
200            });
201        }
202        let k = k.min(self.n_points);
203
204        let mut heap: BinaryHeap<HeapEntry<T>> = BinaryHeap::new();
205        self.knn_recursive(&self.root, query, k, &mut heap);
206
207        // Drain the max-heap into a vec and reverse so smallest distance first.
208        let mut results: Vec<(usize, T)> = heap
209            .into_sorted_vec()
210            .into_iter()
211            .map(|e| (e.index, e.sq_dist.sqrt()))
212            .collect();
213        results.sort_by(|a, b| {
214            a.1.to_f64()
215                .partial_cmp(&b.1.to_f64())
216                .unwrap_or(Ordering::Equal)
217        });
218        let indices = results.iter().map(|(i, _)| *i).collect();
219        let dists = results.iter().map(|(_, d)| *d).collect();
220        Ok((indices, dists))
221    }
222
223    /// Find all points within distance `radius` of `query`.
224    ///
225    /// Returns `(indices, distances)`.
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if the query dimension does not match.
230    pub fn query_radius(&self, query: &[T], radius: T) -> Result<(Vec<usize>, Vec<T>)> {
231        if query.len() != self.dim {
232            return Err(CoreError::InvalidArgument {
233                reason: "query dimensionality does not match tree",
234            });
235        }
236        let sq_radius = radius * radius;
237        let mut results: Vec<(usize, T)> = Vec::new();
238        self.range_recursive(&self.root, query, sq_radius, &mut results);
239
240        // Sort by distance.
241        results.sort_by(|a, b| {
242            a.1.to_f64()
243                .partial_cmp(&b.1.to_f64())
244                .unwrap_or(Ordering::Equal)
245        });
246        let indices = results.iter().map(|(i, _)| *i).collect();
247        let dists = results.into_iter().map(|(_, d)| d.sqrt()).collect();
248        Ok((indices, dists))
249    }
250
251    /// Find all pairs of points within distance `r` of each other.
252    ///
253    /// Returns pairs as `(i, j)` with `i < j`.
254    pub fn query_pairs(&self, r: T) -> Vec<(usize, usize)> {
255        let sq_r = r * r;
256        let mut pairs = Vec::new();
257        for i in 0..self.n_points {
258            let point = &self.data[i * self.dim..(i + 1) * self.dim];
259            let mut neighbors: Vec<(usize, T)> = Vec::new();
260            self.range_recursive(&self.root, point, sq_r, &mut neighbors);
261            for (j, _) in neighbors {
262                if i < j {
263                    pairs.push((i, j));
264                }
265            }
266        }
267        pairs.sort_unstable();
268        pairs.dedup();
269        pairs
270    }
271
272    /// Return the number of points.
273    #[inline]
274    pub fn len(&self) -> usize {
275        self.n_points
276    }
277
278    /// Whether the tree contains no points.
279    #[inline]
280    pub fn is_empty(&self) -> bool {
281        self.n_points == 0
282    }
283
284    /// Return the dimensionality.
285    #[inline]
286    pub fn dim(&self) -> usize {
287        self.dim
288    }
289
290    // -----------------------------------------------------------------------
291    // Internal build
292    // -----------------------------------------------------------------------
293
294    fn build_recursive(data: &[T], dim: usize, mut indices: Vec<usize>) -> KdNode {
295        if indices.len() <= LEAF_SIZE {
296            return KdNode::Leaf { indices };
297        }
298
299        // Find dimension with widest spread.
300        let split_dim = Self::widest_spread_dim(data, dim, &indices);
301
302        // Sort indices by the split dimension.
303        indices.sort_by(|&a, &b| {
304            let va = data[a * dim + split_dim].to_f64();
305            let vb = data[b * dim + split_dim].to_f64();
306            va.partial_cmp(&vb).unwrap_or(Ordering::Equal)
307        });
308
309        let median_idx = indices.len() / 2;
310        let split_value = data[indices[median_idx] * dim + split_dim].to_f64();
311
312        let right_indices = indices.split_off(median_idx);
313        let left_indices = indices;
314
315        let left = Box::new(Self::build_recursive(data, dim, left_indices));
316        let right = Box::new(Self::build_recursive(data, dim, right_indices));
317
318        KdNode::Internal {
319            split_dim,
320            split_value,
321            left,
322            right,
323        }
324    }
325
326    fn widest_spread_dim(data: &[T], dim: usize, indices: &[usize]) -> usize {
327        let mut best_dim = 0;
328        let mut best_spread = f64::NEG_INFINITY;
329        for d in 0..dim {
330            let mut lo = f64::INFINITY;
331            let mut hi = f64::NEG_INFINITY;
332            for &idx in indices {
333                let v = data[idx * dim + d].to_f64();
334                if v < lo {
335                    lo = v;
336                }
337                if v > hi {
338                    hi = v;
339                }
340            }
341            let spread = hi - lo;
342            if spread > best_spread {
343                best_spread = spread;
344                best_dim = d;
345            }
346        }
347        best_dim
348    }
349
350    // -----------------------------------------------------------------------
351    // KNN search
352    // -----------------------------------------------------------------------
353
354    fn squared_distance(&self, a: &[T], b_idx: usize) -> T {
355        let mut sum = T::zero();
356        let offset = b_idx * self.dim;
357        for (d, a_val) in a.iter().enumerate().take(self.dim) {
358            let diff = *a_val - self.data[offset + d];
359            sum += diff * diff;
360        }
361        sum
362    }
363
364    fn knn_recursive(
365        &self,
366        node: &KdNode,
367        query: &[T],
368        k: usize,
369        heap: &mut BinaryHeap<HeapEntry<T>>,
370    ) {
371        match node {
372            KdNode::Leaf { indices } => {
373                for &idx in indices {
374                    let sq_dist = self.squared_distance(query, idx);
375                    if heap.len() < k {
376                        heap.push(HeapEntry {
377                            sq_dist,
378                            index: idx,
379                        });
380                    } else if heap
381                        .peek()
382                        .is_some_and(|worst| sq_dist.to_f64() < worst.sq_dist.to_f64())
383                    {
384                        heap.pop();
385                        heap.push(HeapEntry {
386                            sq_dist,
387                            index: idx,
388                        });
389                    }
390                }
391            }
392            KdNode::Internal {
393                split_dim,
394                split_value,
395                left,
396                right,
397            } => {
398                let query_val = query[*split_dim].to_f64();
399                let diff = query_val - split_value;
400
401                let (first, second) = if diff <= 0.0 {
402                    (left, right)
403                } else {
404                    (right, left)
405                };
406
407                self.knn_recursive(first, query, k, heap);
408
409                // Prune: only visit second child if the split plane is closer
410                // than the worst candidate.
411                let should_visit =
412                    heap.len() < k || diff * diff < heap.peek().unwrap().sq_dist.to_f64();
413                if should_visit {
414                    self.knn_recursive(second, query, k, heap);
415                }
416            }
417        }
418    }
419
420    // -----------------------------------------------------------------------
421    // Range search
422    // -----------------------------------------------------------------------
423
424    fn range_recursive(
425        &self,
426        node: &KdNode,
427        query: &[T],
428        sq_radius: T,
429        results: &mut Vec<(usize, T)>,
430    ) {
431        match node {
432            KdNode::Leaf { indices } => {
433                for &idx in indices {
434                    let sq_dist = self.squared_distance(query, idx);
435                    if sq_dist.to_f64() <= sq_radius.to_f64() {
436                        results.push((idx, sq_dist));
437                    }
438                }
439            }
440            KdNode::Internal {
441                split_dim,
442                split_value,
443                left,
444                right,
445            } => {
446                let query_val = query[*split_dim].to_f64();
447                let diff = query_val - split_value;
448                let sq_diff = diff * diff;
449
450                let (first, second) = if diff <= 0.0 {
451                    (left, right)
452                } else {
453                    (right, left)
454                };
455
456                self.range_recursive(first, query, sq_radius, results);
457
458                if sq_diff <= sq_radius.to_f64() {
459                    self.range_recursive(second, query, sq_radius, results);
460                }
461            }
462        }
463    }
464}
465
466// ===========================================================================
467// Tests
468// ===========================================================================
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random points in `[0, 10)^dim` from a 64-bit LCG,
    /// so the larger tests need no external RNG crate.
    fn sample_points(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut state = seed;
        let mut points = Vec::with_capacity(n);
        for _ in 0..n {
            let mut p = Vec::with_capacity(dim);
            for _ in 0..dim {
                state = state
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                p.push((state >> 11) as f64 / (1_u64 << 53) as f64 * 10.0);
            }
            points.push(p);
        }
        points
    }

    /// Brute-force squared distance, accumulated in the same order as the tree.
    fn brute_sq_dist(a: &[f64], b: &[f64]) -> f64 {
        let mut sum = 0.0;
        for (x, y) in a.iter().zip(b) {
            let d = x - y;
            sum += d * d;
        }
        sum
    }

    #[test]
    fn test_kd_tree_knn_exact_match() {
        let pts: [[f64; 2]; 4] = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, dists) = tree.query(&[0.0, 0.0], 1).unwrap();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 0);
        assert!(dists[0].abs() < 1e-12);
    }

    #[test]
    fn test_kd_tree_knn_k3_sorted() {
        let pts: [[f64; 2]; 5] = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [5.0, 0.0], [10.0, 0.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, dists) = tree.query(&[0.5, 0.0], 3).unwrap();
        assert_eq!(indices.len(), 3);
        // Closest: pt0 (0.5), pt1 (0.5), pt2 (2.5)
        assert!(dists[0] <= dists[1]);
        assert!(dists[1] <= dists[2]);
        // The two closest are pts 0 and 1 (both at distance 0.5)
        assert!((dists[0] - 0.5).abs() < 1e-12);
        assert!((dists[1] - 0.5).abs() < 1e-12);
        assert!((dists[2] - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_kd_tree_range_query() {
        let pts: [[f64; 2]; 4] = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [10.0, 0.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, _dists) = tree.query_radius(&[0.0, 0.0], 1.5).unwrap();
        // Should find pts 0 and 1 (distances 0 and 1)
        assert_eq!(indices.len(), 2);
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
    }

    #[test]
    fn test_kd_tree_query_pairs() {
        let pts: [[f64; 2]; 3] = [[0.0, 0.0], [0.5, 0.0], [10.0, 0.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let pairs = tree.query_pairs(1.0);
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], (0, 1));
    }

    #[test]
    fn test_kd_tree_from_tensor() {
        let data = vec![0.0_f64, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let tensor = Tensor::from_vec(data, vec![4, 2]).unwrap();
        let tree = KdTree::from_tensor(&tensor).unwrap();
        assert_eq!(tree.len(), 4);
        assert_eq!(tree.dim(), 2);

        let (indices, _) = tree.query(&[0.0, 0.0], 1).unwrap();
        assert_eq!(indices[0], 0);
    }

    #[test]
    fn test_kd_tree_high_dimensional() {
        // 5-dimensional points
        let pts: [[f64; 5]; 3] = [
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [1.0, 1.0, 1.0, 1.0, 1.0],
            [2.0, 2.0, 2.0, 2.0, 2.0],
        ];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 5]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, dists) = tree.query(&[0.0, 0.0, 0.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(indices[0], 0);
        assert!(dists[0].abs() < 1e-12);

        // Distance from origin to (1,1,1,1,1) = sqrt(5)
        let (indices, dists) = tree.query(&[0.0, 0.0, 0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(indices.len(), 2);
        assert!((dists[1] - 5.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_kd_tree_single_point() {
        let pts: [[f64; 2]; 1] = [[42.0_f64, 7.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();
        assert_eq!(tree.len(), 1);

        let (indices, dists) = tree.query(&[42.0, 7.0], 1).unwrap();
        assert_eq!(indices[0], 0);
        assert!(dists[0].abs() < 1e-12);
    }

    #[test]
    fn test_kd_tree_error_empty() {
        let refs: Vec<&[f64]> = vec![];
        let result = KdTree::build(&refs);
        assert!(result.is_err());
    }

    #[test]
    fn test_kd_tree_error_k_zero() {
        let pts: [[f64; 2]; 1] = [[0.0_f64, 0.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 2]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let result = tree.query(&[0.0, 0.0], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_kd_tree_error_dimension_mismatch() {
        let a = [0.0_f64, 0.0];
        let b = [1.0_f64];
        let mixed: Vec<&[f64]> = vec![&a[..], &b[..]];
        assert!(KdTree::build(&mixed).is_err());

        let single: Vec<&[f64]> = vec![&a[..]];
        let tree = KdTree::build(&single).unwrap();
        assert!(tree.query(&[0.0], 1).is_err());
        assert!(tree.query_radius(&[0.0, 0.0, 0.0], 1.0).is_err());
    }

    #[test]
    fn test_kd_tree_k_larger_than_len_returns_all() {
        let pts: [[f64; 1]; 3] = [[2.0], [0.0], [1.0]];
        let refs: Vec<&[f64]> = pts.iter().map(<[f64; 1]>::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, dists) = tree.query(&[0.0], 10).unwrap();
        assert_eq!(indices, vec![1, 2, 0]);
        assert!((dists[0] - 0.0).abs() < 1e-12);
        assert!((dists[1] - 1.0).abs() < 1e-12);
        assert!((dists[2] - 2.0).abs() < 1e-12);
    }

    // The tests above all fit in a single leaf (<= LEAF_SIZE points). The tests
    // below use hundreds of points so that internal split nodes and the pruning
    // logic are exercised, and compare against brute force.

    #[test]
    fn test_kd_tree_knn_matches_brute_force() {
        let pts = sample_points(300, 3, 1);
        let refs: Vec<&[f64]> = pts.iter().map(Vec::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        for q in sample_points(25, 3, 2) {
            let (indices, dists) = tree.query(&q, 7).unwrap();
            let mut expected: Vec<f64> =
                pts.iter().map(|p| brute_sq_dist(&q, p).sqrt()).collect();
            expected.sort_by(f64::total_cmp);

            assert_eq!(indices.len(), 7);
            for (rank, (&i, &d)) in indices.iter().zip(&dists).enumerate() {
                assert!((d - expected[rank]).abs() < 1e-12);
                assert!((brute_sq_dist(&q, &pts[i]).sqrt() - d).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_kd_tree_range_matches_brute_force() {
        let pts = sample_points(300, 2, 3);
        let refs: Vec<&[f64]> = pts.iter().map(Vec::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();
        let radius = 1.25_f64;

        for q in sample_points(25, 2, 4) {
            let (mut indices, dists) = tree.query_radius(&q, radius).unwrap();
            assert_eq!(indices.len(), dists.len());
            assert!(dists.windows(2).all(|w| w[0] <= w[1]));

            indices.sort_unstable();
            let expected: Vec<usize> = (0..pts.len())
                .filter(|&i| brute_sq_dist(&q, &pts[i]) <= radius * radius)
                .collect();
            assert_eq!(indices, expected);
        }
    }

    #[test]
    fn test_kd_tree_query_pairs_matches_brute_force() {
        let pts = sample_points(120, 2, 5);
        let refs: Vec<&[f64]> = pts.iter().map(Vec::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();
        let r = 0.8_f64;

        let mut expected = Vec::new();
        for i in 0..pts.len() {
            for j in (i + 1)..pts.len() {
                if brute_sq_dist(&pts[i], &pts[j]) <= r * r {
                    expected.push((i, j));
                }
            }
        }
        assert!(!expected.is_empty());
        assert_eq!(tree.query_pairs(r), expected);
    }

    #[test]
    fn test_kd_tree_many_duplicates() {
        let mut pts = vec![vec![1.0_f64, 1.0]; 40];
        pts.extend(sample_points(40, 2, 6));
        let refs: Vec<&[f64]> = pts.iter().map(Vec::as_slice).collect();
        let tree = KdTree::build(&refs).unwrap();

        let (indices, dists) = tree.query(&[1.0, 1.0], 5).unwrap();
        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&i| i < 40));
        assert!(dists.iter().all(|d| d.abs() < 1e-12));
    }
}