// scirs2_spatial/kdtree.rs
1//! KD-Tree for efficient nearest neighbor searches
2//!
3//! This module provides a KD-Tree implementation for efficient
4//! nearest neighbor and range searches in multidimensional spaces.
5//!
6//! The KD-Tree (k-dimensional tree) is a space-partitioning data structure
7//! that organizes points in a k-dimensional space. It enables efficient range searches
8//! and nearest neighbor searches.
9//!
10//! # Features
11//!
12//! * Fast nearest neighbor queries with customizable `k`
13//! * Range queries with distance threshold
14//! * Support for different distance metrics (Euclidean, Manhattan, Chebyshev, etc.)
15//! * Parallel query processing (when using the `parallel` feature)
16//! * Customizable leaf size for performance tuning
17//!
18//! # Examples
19//!
20//! ```
21//! use scirs2_spatial::KDTree;
22//! use scirs2_core::ndarray::array;
23//!
24//! // Create a KD-Tree with points in 2D space
25//! let points = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
26//! let kdtree = KDTree::new(&points).expect("Operation failed");
27//!
28//! // Find the nearest neighbor to [4.0, 5.0]
29//! let (idx, dist) = kdtree.query(&[4.0, 5.0], 1).expect("Operation failed");
30//! assert_eq!(idx.len(), 1); // Should return exactly one neighbor
31//!
32//! // Find all points within radius 3.0 of [4.0, 5.0]
33//! let (indices, distances) = kdtree.query_radius(&[4.0, 5.0], 3.0).expect("Operation failed");
34//! ```
35//!
36//! # Advanced Usage
37//!
38//! Using custom distance metrics:
39//!
40//! ```
41//! use scirs2_spatial::KDTree;
42//! use scirs2_spatial::distance::ManhattanDistance;
43//! use scirs2_core::ndarray::array;
44//!
45//! // Create a KD-Tree with Manhattan distance metric
46//! let points = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
47//! let metric = ManhattanDistance::new();
48//! let kdtree = KDTree::with_metric(&points, metric).expect("Operation failed");
49//!
50//! // Find the nearest neighbor to [4.0, 5.0] using Manhattan distance
51//! let (idx, dist) = kdtree.query(&[4.0, 5.0], 1).expect("Operation failed");
52//! ```
53//!
54//! Using custom leaf size for performance tuning:
55//!
56//! ```
57//! use scirs2_spatial::KDTree;
58//! use scirs2_core::ndarray::array;
59//!
60//! // Create a KD-Tree with custom leaf size
61//! let points = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
62//!                      [9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]];
63//! let leafsize = 2; // Default is 16
64//! let kdtree = KDTree::with_leaf_size(&points, leafsize).expect("Operation failed");
65//! ```
66
67use crate::distance::{Distance, EuclideanDistance};
68use crate::error::{SpatialError, SpatialResult};
69use crate::safe_conversions::*;
70use scirs2_core::ndarray::Array2;
71use scirs2_core::numeric::Float;
72use std::cmp::Ordering;
73
74// Rayon parallel processing currently not used in this module
/// A rectangle representing a hyperrectangle in k-dimensional space
///
/// Used for efficient nearest-neighbor and range queries in KD-trees.
///
/// Invariants (enforced by [`Rectangle::new`]): `mins` and `maxes` have the
/// same length, and `mins[i] <= maxes[i]` for every dimension `i`.
#[derive(Clone, Debug)]
pub struct Rectangle<T: Float> {
    /// Minimum coordinates for each dimension
    mins: Vec<T>,
    /// Maximum coordinates for each dimension
    maxes: Vec<T>,
}
85
86impl<T: Float> Rectangle<T> {
87    /// Create a new hyperrectangle
88    ///
89    /// # Arguments
90    ///
91    /// * `mins` - Minimum coordinates for each dimension
92    /// * `maxes` - Maximum coordinates for each dimension
93    ///
94    /// # Returns
95    ///
96    /// * A new rectangle
97    ///
98    /// # Panics
99    ///
100    /// * If mins and maxes have different lengths
101    /// * If any min value is greater than the corresponding max value
102    pub fn new(mins: Vec<T>, maxes: Vec<T>) -> Self {
103        assert_eq!(
104            mins.len(),
105            maxes.len(),
106            "mins and maxes must have the same length"
107        );
108
109        for i in 0..mins.len() {
110            assert!(
111                mins[i] <= maxes[i],
112                "min value must be less than or equal to max value"
113            );
114        }
115
116        Rectangle { mins, maxes }
117    }
118
119    /// Get the minimum coordinates of the rectangle
120    ///
121    /// # Returns
122    ///
123    /// * A slice containing the minimum coordinate values for each dimension
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// use scirs2_spatial::kdtree::Rectangle;
129    ///
130    /// let rect = Rectangle::new(vec![0.0, 0.0], vec![1.0, 1.0]);
131    /// let mins = rect.mins();
132    /// assert_eq!(mins, &[0.0, 0.0]);
133    /// ```
134    pub fn mins(&self) -> &[T] {
135        &self.mins
136    }
137
138    /// Get the maximum coordinates of the rectangle
139    ///
140    /// # Returns
141    ///
142    /// * A slice containing the maximum coordinate values for each dimension
143    ///
144    /// # Examples
145    ///
146    /// ```
147    /// use scirs2_spatial::kdtree::Rectangle;
148    ///
149    /// let rect = Rectangle::new(vec![0.0, 0.0], vec![1.0, 1.0]);
150    /// let maxes = rect.maxes();
151    /// assert_eq!(maxes, &[1.0, 1.0]);
152    /// ```
153    pub fn maxes(&self) -> &[T] {
154        &self.maxes
155    }
156
157    /// Split the rectangle along a given dimension at a given value
158    ///
159    /// # Arguments
160    ///
161    /// * `dim` - The dimension to split on
162    /// * `value` - The value to split at
163    ///
164    /// # Returns
165    ///
166    /// * A tuple of (left, right) rectangles
167    pub fn split(&self, dim: usize, value: T) -> (Self, Self) {
168        let mut left_maxes = self.maxes.clone();
169        left_maxes[dim] = value;
170
171        let mut right_mins = self.mins.clone();
172        right_mins[dim] = value;
173
174        let left = Rectangle::new(self.mins.clone(), left_maxes);
175        let right = Rectangle::new(right_mins, self.maxes.clone());
176
177        (left, right)
178    }
179
180    /// Check if the rectangle contains a point
181    ///
182    /// # Arguments
183    ///
184    /// * `point` - The point to check
185    ///
186    /// # Returns
187    ///
188    /// * true if the rectangle contains the point, false otherwise
189    pub fn contains(&self, point: &[T]) -> bool {
190        assert_eq!(
191            point.len(),
192            self.mins.len(),
193            "point must have the same dimension as the rectangle"
194        );
195
196        for (i, &p) in point.iter().enumerate() {
197            if p < self.mins[i] || p > self.maxes[i] {
198                return false;
199            }
200        }
201
202        true
203    }
204
205    /// Calculate the minimum distance from a point to the rectangle
206    ///
207    /// # Arguments
208    ///
209    /// * `point` - The point to calculate distance to
210    /// * `metric` - The distance metric to use
211    ///
212    /// # Returns
213    ///
214    /// * The minimum distance from the point to any point in the rectangle
215    pub fn min_distance<D: Distance<T>>(&self, point: &[T], metric: &D) -> T {
216        metric.min_distance_point_rectangle(point, &self.mins, &self.maxes)
217    }
218}
219
/// A node in the KD-Tree
///
/// Each node stores exactly one data point (by index into the tree's point
/// array) together with the axis and coordinate value of the splitting
/// hyperplane at this node. Children are referenced by index into the
/// tree's `nodes` vector rather than by pointer.
#[derive(Debug, Clone)]
struct KDNode<T: Float> {
    /// Index of the point in the original data array
    idx: usize,
    /// The value of the point along the splitting dimension
    value: T,
    /// The dimension used for splitting
    axis: usize,
    /// Left child node (values < median along splitting axis), as an index
    /// into the tree's `nodes` vector
    left: Option<usize>,
    /// Right child node (values >= median along splitting axis), as an index
    /// into the tree's `nodes` vector
    right: Option<usize>,
}
234
/// A KD-Tree for efficient nearest neighbor searches
///
/// The tree stores an owned copy of the input points and a flat vector of
/// nodes linked by indices; `root` is `None` only before construction
/// completes (constructors reject empty point sets).
///
/// # Type Parameters
///
/// * `T` - The floating point type for coordinates
/// * `D` - The distance metric type
#[derive(Debug, Clone)]
pub struct KDTree<T: Float + Send + Sync + 'static, D: Distance<T> + 'static> {
    /// The points stored in the KD-Tree
    points: Array2<T>,
    /// The nodes of the KD-Tree
    nodes: Vec<KDNode<T>>,
    /// The dimensionality of the points
    ndim: usize,
    /// The root node index
    root: Option<usize>,
    /// The distance metric
    metric: D,
    /// The leaf size (maximum number of points in a leaf node)
    // NOTE(review): stored and validated, but the current build always splits
    // down to single-point nodes — see `with_options`.
    leafsize: usize,
    /// Minimum bounding rectangle of the entire dataset
    bounds: Rectangle<T>,
}
258
259impl<T: Float + Send + Sync + 'static> KDTree<T, EuclideanDistance<T>> {
260    /// Create a new KD-Tree with default Euclidean distance metric
261    ///
262    /// # Arguments
263    ///
264    /// * `points` - Array of points, each row is a point
265    ///
266    /// # Returns
267    ///
268    /// * A new KD-Tree
269    pub fn new(points: &Array2<T>) -> SpatialResult<Self> {
270        let metric = EuclideanDistance::new();
271        Self::with_metric(points, metric)
272    }
273
274    /// Create a new KD-Tree with custom leaf size (using Euclidean distance)
275    ///
276    /// # Arguments
277    ///
278    /// * `points` - Array of points, each row is a point
279    /// * `leafsize` - The maximum number of points in a leaf node
280    ///
281    /// # Returns
282    ///
283    /// * A new KD-Tree
284    pub fn with_leaf_size(points: &Array2<T>, leafsize: usize) -> SpatialResult<Self> {
285        let metric = EuclideanDistance::new();
286        Self::with_options(points, metric, leafsize)
287    }
288}
289
290impl<T: Float + Send + Sync + 'static, D: Distance<T> + 'static> KDTree<T, D> {
291    /// Create a new KD-Tree with custom distance metric
292    ///
293    /// # Arguments
294    ///
295    /// * `points` - Array of points, each row is a point
296    /// * `metric` - The distance metric to use
297    ///
298    /// # Returns
299    ///
300    /// * A new KD-Tree
301    pub fn with_metric(points: &Array2<T>, metric: D) -> SpatialResult<Self> {
302        Self::with_options(points, metric, 16) // Default leaf size is 16
303    }
304
305    /// Create a new KD-Tree with custom distance metric and leaf size
306    ///
307    /// # Arguments
308    ///
309    /// * `points` - Array of points, each row is a point
310    /// * `metric` - The distance metric to use
311    /// * `leafsize` - The maximum number of points in a leaf node
312    ///
313    /// # Returns
314    ///
315    /// * A new KD-Tree
316    pub fn with_options(points: &Array2<T>, metric: D, leafsize: usize) -> SpatialResult<Self> {
317        let n = points.nrows();
318        let ndim = points.ncols();
319
320        if n == 0 {
321            return Err(SpatialError::ValueError("Empty point set".to_string()));
322        }
323
324        if leafsize == 0 {
325            return Err(SpatialError::ValueError(
326                "Leaf _size must be greater than 0".to_string(),
327            ));
328        }
329
330        // Calculate the bounds of the dataset
331        let mut mins = vec![T::max_value(); ndim];
332        let mut maxes = vec![T::min_value(); ndim];
333
334        for i in 0..n {
335            for j in 0..ndim {
336                let val = points[[i, j]];
337                if val < mins[j] {
338                    mins[j] = val;
339                }
340                if val > maxes[j] {
341                    maxes[j] = val;
342                }
343            }
344        }
345
346        let bounds = Rectangle::new(mins, maxes);
347
348        let mut tree = KDTree {
349            points: points.clone(),
350            nodes: Vec::with_capacity(n),
351            ndim,
352            root: None,
353            metric,
354            leafsize,
355            bounds,
356        };
357
358        // Create indices for the points
359        let mut indices: Vec<usize> = (0..n).collect();
360
361        // Build the tree recursively
362        if n > 0 {
363            let root = tree.build_tree(&mut indices, 0, 0, n)?;
364            tree.root = Some(root);
365        }
366
367        Ok(tree)
368    }
369
370    /// Build the KD-Tree recursively
371    ///
372    /// # Arguments
373    ///
374    /// * `indices` - Indices of the points to consider
375    /// * `depth` - Current depth in the tree
376    /// * `start` - Start index in the indices array
377    /// * `end` - End index in the indices array
378    ///
379    /// # Returns
380    ///
381    /// * Index of the root node of the subtree
382    fn build_tree(
383        &mut self,
384        indices: &mut [usize],
385        depth: usize,
386        start: usize,
387        end: usize,
388    ) -> SpatialResult<usize> {
389        let n = end - start;
390
391        if n == 0 {
392            return Err(SpatialError::ValueError(
393                "Empty point set in build_tree".to_string(),
394            ));
395        }
396
397        // Choose axis based on depth (cycle through axes)
398        let axis = depth % self.ndim;
399
400        // If we have only one point, create a leaf node
401        let node_idx;
402        if n == 1 {
403            let idx = indices[start];
404            let value = self.points[[idx, axis]];
405
406            node_idx = self.nodes.len();
407            self.nodes.push(KDNode {
408                idx,
409                value,
410                axis,
411                left: None,
412                right: None,
413            });
414
415            return Ok(node_idx);
416        }
417
418        // Sort indices based on the axis
419        indices[start..end].sort_by(|&i, &j| {
420            let a = self.points[[i, axis]];
421            let b = self.points[[j, axis]];
422            a.partial_cmp(&b).unwrap_or(Ordering::Equal)
423        });
424
425        // Get the median index
426        let mid = start + n / 2;
427        let idx = indices[mid];
428        let value = self.points[[idx, axis]];
429
430        // Create node
431        node_idx = self.nodes.len();
432        self.nodes.push(KDNode {
433            idx,
434            value,
435            axis,
436            left: None,
437            right: None,
438        });
439
440        // Build left and right subtrees
441        if mid > start {
442            let left_idx = self.build_tree(indices, depth + 1, start, mid)?;
443            self.nodes[node_idx].left = Some(left_idx);
444        }
445
446        if mid + 1 < end {
447            let right_idx = self.build_tree(indices, depth + 1, mid + 1, end)?;
448            self.nodes[node_idx].right = Some(right_idx);
449        }
450
451        Ok(node_idx)
452    }
453
454    /// Find the k nearest neighbors to a query point
455    ///
456    /// # Arguments
457    ///
458    /// * `point` - Query point
459    /// * `k` - Number of nearest neighbors to find
460    ///
461    /// # Returns
462    ///
463    /// * (indices, distances) of the k nearest neighbors
464    ///
465    /// # Examples
466    ///
467    /// ```
468    /// use scirs2_spatial::KDTree;
469    /// use scirs2_core::ndarray::array;
470    ///
471    /// // Create points for the KDTree - use the exact same points from test_kdtree_with_custom_leaf_size
472    /// let points = array![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0], [8.0, 1.0], [7.0, 2.0]];
473    /// let kdtree = KDTree::new(&points).expect("Operation failed");
474    ///
475    /// // Find the 2 nearest neighbors to [0.5, 0.5]
476    /// let (indices, distances) = kdtree.query(&[0.5, 0.5], 2).expect("Operation failed");
477    /// assert_eq!(indices.len(), 2);
478    /// assert_eq!(distances.len(), 2);
479    /// ```
480    pub fn query(&self, point: &[T], k: usize) -> SpatialResult<(Vec<usize>, Vec<T>)> {
481        if point.len() != self.ndim {
482            return Err(SpatialError::DimensionError(format!(
483                "Query point dimension ({}) does not match tree dimension ({})",
484                point.len(),
485                self.ndim
486            )));
487        }
488
489        if k == 0 {
490            return Ok((vec![], vec![]));
491        }
492
493        if self.points.nrows() == 0 {
494            return Ok((vec![], vec![]));
495        }
496
497        // Initialize priority queue for k nearest neighbors
498        // We use a max-heap so we can efficiently replace the furthest point when we find a closer one
499        let mut neighbors: Vec<(T, usize)> = Vec::with_capacity(k + 1);
500
501        // Keep track of the maximum distance in the heap, for early termination
502        let mut max_dist = T::infinity();
503
504        if let Some(root) = self.root {
505            // Search recursively
506            self.query_recursive(root, point, k, &mut neighbors, &mut max_dist);
507
508            // Sort by distance (ascending), with index as tiebreaker
509            neighbors.sort_by(|a, b| {
510                match safe_partial_cmp(&a.0, &b.0, "kdtree sort neighbors") {
511                    Ok(std::cmp::Ordering::Equal) => a.1.cmp(&b.1), // Use index as tiebreaker
512                    Ok(ord) => ord,
513                    Err(_) => std::cmp::Ordering::Equal,
514                }
515            });
516
517            // Trim to k elements if needed
518            if neighbors.len() > k {
519                neighbors.truncate(k);
520            }
521
522            // Convert to sorted lists of indices and distances
523            let mut indices = Vec::with_capacity(neighbors.len());
524            let mut distances = Vec::with_capacity(neighbors.len());
525
526            for (dist, idx) in neighbors {
527                indices.push(idx);
528                distances.push(dist);
529            }
530
531            Ok((indices, distances))
532        } else {
533            Err(SpatialError::ValueError("Empty tree".to_string()))
534        }
535    }
536
537    /// Recursive helper for query
538    fn query_recursive(
539        &self,
540        node_idx: usize,
541        point: &[T],
542        k: usize,
543        neighbors: &mut Vec<(T, usize)>,
544        max_dist: &mut T,
545    ) {
546        let node = &self.nodes[node_idx];
547        let idx = node.idx;
548        let axis = node.axis;
549
550        // Calculate distance to current point
551        let node_point = self.points.row(idx).to_vec();
552        let _dist = self.metric.distance(&node_point, point);
553
554        // Update neighbors if needed
555        if neighbors.len() < k {
556            neighbors.push((_dist, idx));
557
558            // Sort if we just filled to capacity to establish max-heap
559            if neighbors.len() == k {
560                neighbors.sort_by(|a, b| {
561                    match safe_partial_cmp(&b.0, &a.0, "kdtree sort max-heap") {
562                        Ok(std::cmp::Ordering::Equal) => b.1.cmp(&a.1), // Use index as tiebreaker
563                        Ok(ord) => ord,
564                        Err(_) => std::cmp::Ordering::Equal,
565                    }
566                });
567                *max_dist = neighbors[0].0;
568            }
569        } else if &_dist < max_dist {
570            // Replace the worst neighbor with this one
571            neighbors[0] = (_dist, idx);
572
573            // Re-sort to maintain max-heap property
574            neighbors.sort_by(|a, b| {
575                match safe_partial_cmp(&b.0, &a.0, "kdtree re-sort max-heap") {
576                    Ok(std::cmp::Ordering::Equal) => b.1.cmp(&a.1), // Use index as tiebreaker
577                    Ok(ord) => ord,
578                    Err(_) => std::cmp::Ordering::Equal,
579                }
580            });
581            *max_dist = neighbors[0].0;
582        }
583
584        // Determine which subtree to search first
585        let diff = point[axis] - node.value;
586        let (first, second) = if diff < T::zero() {
587            (node.left, node.right)
588        } else {
589            (node.right, node.left)
590        };
591
592        // Search the near subtree
593        if let Some(first_idx) = first {
594            self.query_recursive(first_idx, point, k, neighbors, max_dist);
595        }
596
597        // Only search the far subtree if it could contain closer points
598        let axis_dist = if diff < T::zero() {
599            // Point is to the left of the splitting hyperplane
600            T::zero() // No need to calculate distance if we're considering the left subtree next
601        } else {
602            // Point is to the right of the splitting hyperplane
603            diff
604        };
605
606        if let Some(second_idx) = second {
607            // Only search the second subtree if necessary
608            if neighbors.len() < k || axis_dist < *max_dist {
609                self.query_recursive(second_idx, point, k, neighbors, max_dist);
610            }
611        }
612    }
613
614    /// Find all points within a radius of a query point
615    ///
616    /// # Arguments
617    ///
618    /// * `point` - Query point
619    /// * `radius` - Search radius
620    ///
621    /// # Returns
622    ///
623    /// * (indices, distances) of points within the radius
624    ///
625    /// # Examples
626    ///
627    /// ```
628    /// use scirs2_spatial::KDTree;
629    /// use scirs2_core::ndarray::array;
630    ///
631    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
632    /// let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
633    /// let kdtree = KDTree::new(&points)?;
634    ///
635    /// // Find all points within radius 0.7 of [0.5, 0.5]
636    /// let (indices, distances) = kdtree.query_radius(&[0.5, 0.5], 0.7)?;
637    /// assert_eq!(indices.len(), 4); // All points are within 0.7 units of [0.5, 0.5]
638    /// # Ok(())
639    /// # }
640    /// ```
641    pub fn query_radius(&self, point: &[T], radius: T) -> SpatialResult<(Vec<usize>, Vec<T>)> {
642        if point.len() != self.ndim {
643            return Err(SpatialError::DimensionError(format!(
644                "Query point dimension ({}) does not match tree dimension ({})",
645                point.len(),
646                self.ndim
647            )));
648        }
649
650        if radius < T::zero() {
651            return Err(SpatialError::ValueError(
652                "Radius must be non-negative".to_string(),
653            ));
654        }
655
656        let mut indices = Vec::new();
657        let mut distances = Vec::new();
658
659        if let Some(root) = self.root {
660            // If the radius is outside the bounds of the entire dataset, just return an empty result
661            let bounds_dist = self.bounds.min_distance(point, &self.metric);
662            if bounds_dist > radius {
663                return Ok((indices, distances));
664            }
665
666            // Search recursively
667            self.query_radius_recursive(root, point, radius, &mut indices, &mut distances);
668
669            // Sort by distance
670            if !indices.is_empty() {
671                let mut idx_dist: Vec<(usize, T)> = indices.into_iter().zip(distances).collect();
672                idx_dist.sort_by(|a, b| {
673                    safe_partial_cmp(&a.1, &b.1, "kdtree sort radius results")
674                        .unwrap_or(std::cmp::Ordering::Equal)
675                });
676
677                indices = idx_dist.iter().map(|(idx_, _)| *idx_).collect();
678                distances = idx_dist.iter().map(|(_, dist)| *dist).collect();
679            }
680        }
681
682        Ok((indices, distances))
683    }
684
685    /// Recursive helper for query_radius
686    fn query_radius_recursive(
687        &self,
688        node_idx: usize,
689        point: &[T],
690        radius: T,
691        indices: &mut Vec<usize>,
692        distances: &mut Vec<T>,
693    ) {
694        let node = &self.nodes[node_idx];
695        let idx = node.idx;
696        let axis = node.axis;
697
698        // Calculate distance to current point
699        let node_point = self.points.row(idx).to_vec();
700        let dist = self.metric.distance(&node_point, point);
701
702        // If point is within radius, add it to results
703        if dist <= radius {
704            indices.push(idx);
705            distances.push(dist);
706        }
707
708        // Determine which subtrees need to be searched
709        let diff = point[axis] - node.value;
710
711        // Always search the near subtree
712        let (near, far) = if diff < T::zero() {
713            (node.left, node.right)
714        } else {
715            (node.right, node.left)
716        };
717
718        if let Some(near_idx) = near {
719            self.query_radius_recursive(near_idx, point, radius, indices, distances);
720        }
721
722        // Only search the far subtree if it could contain points within radius
723        if diff.abs() <= radius {
724            if let Some(far_idx) = far {
725                self.query_radius_recursive(far_idx, point, radius, indices, distances);
726            }
727        }
728    }
729
730    /// Count the number of points within a radius of a query point
731    ///
732    /// This method is more efficient than query_radius when only the count is needed.
733    ///
734    /// # Arguments
735    ///
736    /// * `point` - Query point
737    /// * `radius` - Search radius
738    ///
739    /// # Returns
740    ///
741    /// * Number of points within the radius
742    ///
743    /// # Examples
744    ///
745    /// ```
746    /// use scirs2_spatial::KDTree;
747    /// use scirs2_core::ndarray::array;
748    ///
749    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
750    /// let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
751    /// let kdtree = KDTree::new(&points)?;
752    ///
753    /// // Count points within radius 0.7 of [0.5, 0.5]
754    /// let count = kdtree.count_neighbors(&[0.5, 0.5], 0.7)?;
755    /// assert_eq!(count, 4); // All points are within 0.7 units of [0.5, 0.5]
756    /// # Ok(())
757    /// # }
758    /// ```
759    pub fn count_neighbors(&self, point: &[T], radius: T) -> SpatialResult<usize> {
760        if point.len() != self.ndim {
761            return Err(SpatialError::DimensionError(format!(
762                "Query point dimension ({}) does not match tree dimension ({})",
763                point.len(),
764                self.ndim
765            )));
766        }
767
768        if radius < T::zero() {
769            return Err(SpatialError::ValueError(
770                "Radius must be non-negative".to_string(),
771            ));
772        }
773
774        let mut count = 0;
775
776        if let Some(root) = self.root {
777            // If the radius is outside the bounds of the entire dataset, just return 0
778            let bounds_dist = self.bounds.min_distance(point, &self.metric);
779            if bounds_dist > radius {
780                return Ok(0);
781            }
782
783            // Search recursively
784            self.count_neighbors_recursive(root, point, radius, &mut count);
785        }
786
787        Ok(count)
788    }
789
790    /// Recursive helper for count_neighbors
791    fn count_neighbors_recursive(
792        &self,
793        node_idx: usize,
794        point: &[T],
795        radius: T,
796        count: &mut usize,
797    ) {
798        let node = &self.nodes[node_idx];
799        let idx = node.idx;
800        let axis = node.axis;
801
802        // Calculate distance to current point
803        let node_point = self.points.row(idx).to_vec();
804        let dist = self.metric.distance(&node_point, point);
805
806        // If point is within radius, increment count
807        if dist <= radius {
808            *count += 1;
809        }
810
811        // Determine which subtrees need to be searched
812        let diff = point[axis] - node.value;
813
814        // Always search the near subtree
815        let (near, far) = if diff < T::zero() {
816            (node.left, node.right)
817        } else {
818            (node.right, node.left)
819        };
820
821        if let Some(near_idx) = near {
822            self.count_neighbors_recursive(near_idx, point, radius, count);
823        }
824
825        // Only search the far subtree if it could contain points within radius
826        if diff.abs() <= radius {
827            if let Some(far_idx) = far {
828                self.count_neighbors_recursive(far_idx, point, radius, count);
829            }
830        }
831    }
832
833    /// Get the shape of the KD-Tree's point set
834    ///
835    /// # Returns
836    ///
837    /// * A tuple of (n_points, n_dimensions)
838    pub fn shape(&self) -> (usize, usize) {
839        (self.points.nrows(), self.ndim)
840    }
841
842    /// Get the number of points in the KD-Tree
843    ///
844    /// # Returns
845    ///
846    /// * The number of points
847    pub fn npoints(&self) -> usize {
848        self.points.nrows()
849    }
850
851    /// Get the dimensionality of the points in the KD-Tree
852    ///
853    /// # Returns
854    ///
855    /// * The dimensionality of the points
856    pub fn ndim(&self) -> usize {
857        self.ndim
858    }
859
860    /// Get the leaf size of the KD-Tree
861    ///
862    /// # Returns
863    ///
864    /// * The leaf size
865    pub fn leafsize(&self) -> usize {
866        self.leafsize
867    }
868
869    /// Get the bounds of the KD-Tree
870    ///
871    /// # Returns
872    ///
873    /// * The bounding rectangle of the entire dataset
874    pub fn bounds(&self) -> &Rectangle<T> {
875        &self.bounds
876    }
877}
878
879#[cfg(test)]
880mod tests {
881    use super::{KDTree, Rectangle};
882    use crate::distance::{
883        ChebyshevDistance, Distance, EuclideanDistance, ManhattanDistance, MinkowskiDistance,
884    };
885    use approx::assert_relative_eq;
886    use scirs2_core::ndarray::arr2;
887
    // Exercises Rectangle: containment (closed boundary), splitting, and
    // minimum point-to-rectangle distance under the Euclidean metric.
    #[test]
    fn test_rectangle() {
        // Unit square [0, 1] x [0, 1].
        let mins = vec![0.0, 0.0];
        let maxes = vec![1.0, 1.0];
        let rect = Rectangle::new(mins, maxes);

        // Test contains: interior and boundary points are inside,
        // points past either max are outside.
        assert!(rect.contains(&[0.5, 0.5]));
        assert!(rect.contains(&[0.0, 0.0]));
        assert!(rect.contains(&[1.0, 1.0]));
        assert!(!rect.contains(&[1.5, 0.5]));
        assert!(!rect.contains(&[0.5, 1.5]));

        // Test split: halves share the plane x = 0.5 along dimension 0.
        let (left, right) = rect.split(0, 0.5);
        assert!(left.contains(&[0.25, 0.5]));
        assert!(!left.contains(&[0.75, 0.5]));
        assert!(!right.contains(&[0.25, 0.5]));
        assert!(right.contains(&[0.75, 0.5]));

        // Test min_distance: 0 for an interior point, axis gap for a point
        // outside one dimension, diagonal (sqrt(2)) when outside both.
        let metric = EuclideanDistance::<f64>::new();
        assert_relative_eq!(rect.min_distance(&[0.5, 0.5], &metric), 0.0, epsilon = 1e-6);
        assert_relative_eq!(rect.min_distance(&[2.0, 0.5], &metric), 1.0, epsilon = 1e-6);
        assert_relative_eq!(
            rect.min_distance(&[2.0, 2.0], &metric),
            std::f64::consts::SQRT_2,
            epsilon = 1e-6
        );
    }
918
    // Verifies tree construction: node count, reported shape/size accessors,
    // the default leaf size, and the computed dataset bounding box.
    #[test]
    fn test_kdtree_build() {
        let points = arr2(&[
            [2.0, 3.0],
            [5.0, 4.0],
            [9.0, 6.0],
            [4.0, 7.0],
            [8.0, 1.0],
            [7.0, 2.0],
        ]);

        let kdtree = KDTree::new(&points).expect("Operation failed");

        // Check that the tree has the correct number of nodes
        // (one node per input point).
        assert_eq!(kdtree.nodes.len(), points.nrows());

        // Check tree properties
        assert_eq!(kdtree.shape(), (6, 2));
        assert_eq!(kdtree.npoints(), 6);
        assert_eq!(kdtree.ndim(), 2);
        assert_eq!(kdtree.leafsize(), 16);

        // Check bounds: per-dimension min/max over the input points.
        assert_eq!(kdtree.bounds().mins(), &[2.0, 1.0]);
        assert_eq!(kdtree.bounds().maxes(), &[9.0, 7.0]);
    }
945
    // Cross-checks a 1-nearest-neighbor query against a brute-force scan,
    // allowing for ties at the minimum distance.
    #[test]
    fn test_kdtree_query() {
        let points = arr2(&[
            [2.0, 3.0],
            [5.0, 4.0],
            [9.0, 6.0],
            [4.0, 7.0],
            [8.0, 1.0],
            [7.0, 2.0],
        ]);

        let kdtree = KDTree::new(&points).expect("Operation failed");

        // Query for nearest neighbor to [3.0, 5.0]
        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).expect("Operation failed");
        assert_eq!(indices.len(), 1);
        assert_eq!(distances.len(), 1);

        // Calculate actual distances by brute force over all points.
        let query = [3.0, 5.0];
        let mut expected_dists = vec![];
        for i in 0..points.nrows() {
            let p = points.row(i).to_vec();
            let metric = EuclideanDistance::<f64>::new();
            expected_dists.push((i, metric.distance(&p, &query)));
        }
        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // Verify we got one of the actual nearest neighbors (there might be ties)
        // Check that the distance matches the expected minimum distance
        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);

        // Verify the returned index is one of the points with minimum distance
        let min_dist = expected_dists[0].1;
        let valid_indices: Vec<usize> = expected_dists
            .iter()
            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
            .map(|(i, _)| *i)
            .collect();
        assert!(
            valid_indices.contains(&indices[0]),
            "Expected one of {:?} but got {}",
            valid_indices,
            indices[0]
        );
    }
992
993    #[test]
994    fn test_kdtree_query_k() {
995        let points = arr2(&[
996            [2.0, 3.0],
997            [5.0, 4.0],
998            [9.0, 6.0],
999            [4.0, 7.0],
1000            [8.0, 1.0],
1001            [7.0, 2.0],
1002        ]);
1003
1004        let kdtree = KDTree::new(&points).expect("Operation failed");
1005
1006        // Query for 3 nearest neighbors to [3.0, 5.0]
1007        let (indices, distances) = kdtree.query(&[3.0, 5.0], 3).expect("Operation failed");
1008        assert_eq!(indices.len(), 3);
1009        assert_eq!(distances.len(), 3);
1010
1011        // Calculate actual distances
1012        let query = [3.0, 5.0];
1013        let mut expected_dists = vec![];
1014        for i in 0..points.nrows() {
1015            let p = points.row(i).to_vec();
1016            let metric = EuclideanDistance::<f64>::new();
1017            expected_dists.push((i, metric.distance(&p, &query)));
1018        }
1019        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1020
1021        // Verify we got the 3 actual nearest neighbors (for now, just check distances)
1022        let expected_indices: Vec<usize> = expected_dists.iter().take(3).map(|&(i, _)| i).collect();
1023        let expected_distances: Vec<f64> = expected_dists.iter().take(3).map(|&(_, d)| d).collect();
1024
1025        // Check each returned index is in the expected set
1026        for i in &indices {
1027            assert!(expected_indices.contains(i));
1028        }
1029
1030        // Check that distances are sorted
1031        assert!(distances[0] <= distances[1]);
1032        assert!(distances[1] <= distances[2]);
1033
1034        // Check distance values match expected
1035        for i in 0..3 {
1036            assert_relative_eq!(distances[i], expected_distances[i], epsilon = 1e-6);
1037        }
1038    }
1039
1040    #[test]
1041    fn test_kdtree_query_radius() {
1042        let points = arr2(&[
1043            [2.0, 3.0],
1044            [5.0, 4.0],
1045            [9.0, 6.0],
1046            [4.0, 7.0],
1047            [8.0, 1.0],
1048            [7.0, 2.0],
1049        ]);
1050
1051        let kdtree = KDTree::new(&points).expect("Operation failed");
1052
1053        // Query for points within radius 3.0 of [3.0, 5.0]
1054        let (indices, distances) = kdtree
1055            .query_radius(&[3.0, 5.0], 3.0)
1056            .expect("Operation failed");
1057
1058        // Calculate expected results
1059        let query = [3.0, 5.0];
1060        let radius = 3.0;
1061        let mut expected_results = vec![];
1062        for i in 0..points.nrows() {
1063            let p = points.row(i).to_vec();
1064            let metric = EuclideanDistance::<f64>::new();
1065            let dist = metric.distance(&p, &query);
1066            if dist <= radius {
1067                expected_results.push((i, dist));
1068            }
1069        }
1070        expected_results.sort_by(|a, b| {
1071            match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
1072                std::cmp::Ordering::Equal => a.0.cmp(&b.0), // Use index as tiebreaker
1073                ord => ord,
1074            }
1075        });
1076
1077        // Check that we got the expected number of points
1078        assert_eq!(indices.len(), expected_results.len());
1079
1080        // Check that all returned points are within radius
1081        for i in 0..indices.len() {
1082            assert!(distances[i] <= radius + 1e-6);
1083        }
1084
1085        // Check that the indices/distances pairs match expected results
1086        // Note: order might differ for equal distances
1087        let mut idx_dist_pairs: Vec<(usize, f64)> = indices
1088            .iter()
1089            .zip(distances.iter())
1090            .map(|(&i, &d)| (i, d))
1091            .collect();
1092        idx_dist_pairs.sort_by(|a, b| {
1093            match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
1094                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
1095                ord => ord,
1096            }
1097        });
1098
1099        for (actual, expected) in idx_dist_pairs.iter().zip(expected_results.iter()) {
1100            assert_eq!(actual.0, expected.0);
1101            assert_relative_eq!(actual.1, expected.1, epsilon = 1e-6);
1102        }
1103    }
1104
1105    #[test]
1106    fn test_kdtree_count_neighbors() {
1107        let points = arr2(&[
1108            [2.0, 3.0],
1109            [5.0, 4.0],
1110            [9.0, 6.0],
1111            [4.0, 7.0],
1112            [8.0, 1.0],
1113            [7.0, 2.0],
1114        ]);
1115
1116        let kdtree = KDTree::new(&points).expect("Operation failed");
1117
1118        // Count points within radius 3.0 of [3.0, 5.0]
1119        let count = kdtree
1120            .count_neighbors(&[3.0, 5.0], 3.0)
1121            .expect("Operation failed");
1122
1123        // Calculate actual count
1124        let query = [3.0, 5.0];
1125        let mut expected_count = 0;
1126        for i in 0..points.nrows() {
1127            let p = points.row(i).to_vec();
1128            let metric = EuclideanDistance::<f64>::new();
1129            let dist = metric.distance(&p, &query);
1130            if dist <= 3.0 {
1131                expected_count += 1;
1132            }
1133        }
1134
1135        assert_eq!(count, expected_count);
1136    }
1137
1138    #[test]
1139    fn test_kdtree_with_manhattan_distance() {
1140        let points = arr2(&[
1141            [2.0, 3.0],
1142            [5.0, 4.0],
1143            [9.0, 6.0],
1144            [4.0, 7.0],
1145            [8.0, 1.0],
1146            [7.0, 2.0],
1147        ]);
1148
1149        let metric = ManhattanDistance::new();
1150        let kdtree = KDTree::with_metric(&points, metric).expect("Operation failed");
1151
1152        // Query for nearest neighbor to [3.0, 5.0] using Manhattan distance
1153        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).expect("Operation failed");
1154
1155        // Calculate actual distances using Manhattan distance
1156        let query = [3.0, 5.0];
1157        let mut expected_dists = vec![];
1158        for i in 0..points.nrows() {
1159            let p = points.row(i).to_vec();
1160            let m = ManhattanDistance::<f64>::new();
1161            expected_dists.push((i, m.distance(&p, &query)));
1162        }
1163        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1164
1165        // Check that the distance matches the expected minimum distance
1166        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
1167
1168        // Verify the returned index is one of the points with minimum distance
1169        let min_dist = expected_dists[0].1;
1170        let valid_indices: Vec<usize> = expected_dists
1171            .iter()
1172            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
1173            .map(|(i, _)| *i)
1174            .collect();
1175        assert!(
1176            valid_indices.contains(&indices[0]),
1177            "Expected one of {:?} but got {}",
1178            valid_indices,
1179            indices[0]
1180        );
1181    }
1182
1183    #[test]
1184    fn test_kdtree_with_chebyshev_distance() {
1185        let points = arr2(&[
1186            [2.0, 3.0],
1187            [5.0, 4.0],
1188            [9.0, 6.0],
1189            [4.0, 7.0],
1190            [8.0, 1.0],
1191            [7.0, 2.0],
1192        ]);
1193
1194        let metric = ChebyshevDistance::new();
1195        let kdtree = KDTree::with_metric(&points, metric).expect("Operation failed");
1196
1197        // Query for nearest neighbor to [3.0, 5.0] using Chebyshev distance
1198        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).expect("Operation failed");
1199
1200        // Calculate actual distances using Chebyshev distance
1201        let query = [3.0, 5.0];
1202        let mut expected_dists = vec![];
1203        for i in 0..points.nrows() {
1204            let p = points.row(i).to_vec();
1205            let m = ChebyshevDistance::<f64>::new();
1206            expected_dists.push((i, m.distance(&p, &query)));
1207        }
1208        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1209
1210        // Check that the distance matches the expected minimum distance
1211        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
1212
1213        // Verify the returned index is one of the points with minimum distance
1214        let min_dist = expected_dists[0].1;
1215        let valid_indices: Vec<usize> = expected_dists
1216            .iter()
1217            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
1218            .map(|(i, _)| *i)
1219            .collect();
1220        assert!(
1221            valid_indices.contains(&indices[0]),
1222            "Expected one of {:?} but got {}",
1223            valid_indices,
1224            indices[0]
1225        );
1226    }
1227
1228    #[test]
1229    fn test_kdtree_with_minkowski_distance() {
1230        let points = arr2(&[
1231            [2.0, 3.0],
1232            [5.0, 4.0],
1233            [9.0, 6.0],
1234            [4.0, 7.0],
1235            [8.0, 1.0],
1236            [7.0, 2.0],
1237        ]);
1238
1239        let metric = MinkowskiDistance::new(3.0);
1240        let kdtree = KDTree::with_metric(&points, metric).expect("Operation failed");
1241
1242        // Query for nearest neighbor to [3.0, 5.0] using Minkowski distance (p=3)
1243        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).expect("Operation failed");
1244
1245        // Calculate actual distances using Minkowski distance
1246        let query = [3.0, 5.0];
1247        let mut expected_dists = vec![];
1248        for i in 0..points.nrows() {
1249            let p = points.row(i).to_vec();
1250            let m = MinkowskiDistance::<f64>::new(3.0);
1251            expected_dists.push((i, m.distance(&p, &query)));
1252        }
1253        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1254
1255        // Check that the distance matches the expected minimum distance
1256        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
1257
1258        // Verify the returned index is one of the points with minimum distance
1259        let min_dist = expected_dists[0].1;
1260        let valid_indices: Vec<usize> = expected_dists
1261            .iter()
1262            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
1263            .map(|(i, _)| *i)
1264            .collect();
1265        assert!(
1266            valid_indices.contains(&indices[0]),
1267            "Expected one of {:?} but got {}",
1268            valid_indices,
1269            indices[0]
1270        );
1271    }
1272
1273    #[test]
1274    fn test_kdtree_with_custom_leaf_size() {
1275        let points = arr2(&[
1276            [2.0, 3.0],
1277            [5.0, 4.0],
1278            [9.0, 6.0],
1279            [4.0, 7.0],
1280            [8.0, 1.0],
1281            [7.0, 2.0],
1282        ]);
1283
1284        // Use a very small leaf size to test that it works
1285        let leafsize = 1;
1286        let kdtree = KDTree::with_leaf_size(&points, leafsize).expect("Operation failed");
1287
1288        assert_eq!(kdtree.leafsize(), 1);
1289
1290        // Query for nearest neighbor to [3.0, 5.0]
1291        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).expect("Operation failed");
1292
1293        // Calculate actual distances
1294        let query = [3.0, 5.0];
1295        let mut expected_dists = vec![];
1296        for i in 0..points.nrows() {
1297            let p = points.row(i).to_vec();
1298            let metric = EuclideanDistance::<f64>::new();
1299            expected_dists.push((i, metric.distance(&p, &query)));
1300        }
1301        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1302
1303        // Verify we got one of the actual nearest neighbors (there might be ties)
1304        // Check that the distance matches the expected minimum distance
1305        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
1306
1307        // Verify the returned index is one of the points with minimum distance
1308        let min_dist = expected_dists[0].1;
1309        let valid_indices: Vec<usize> = expected_dists
1310            .iter()
1311            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
1312            .map(|(i, _)| *i)
1313            .collect();
1314        assert!(
1315            valid_indices.contains(&indices[0]),
1316            "Expected one of {:?} but got {}",
1317            valid_indices,
1318            indices[0]
1319        );
1320    }
1321}