scirs2_spatial/
kdtree.rs

1//! KD-Tree for efficient nearest neighbor searches
2//!
3//! This module provides a KD-Tree implementation for efficient
4//! nearest neighbor and range searches in multidimensional spaces.
5//!
6//! The KD-Tree (k-dimensional tree) is a space-partitioning data structure
7//! that organizes points in a k-dimensional space. It enables efficient range searches
8//! and nearest neighbor searches.
9//!
10//! # Features
11//!
12//! * Fast nearest neighbor queries with customizable `k`
13//! * Range queries with distance threshold
14//! * Support for different distance metrics (Euclidean, Manhattan, Chebyshev, etc.)
//! * Parallel query processing is tied to the `parallel` feature (not currently used by this module)
16//! * Customizable leaf size for performance tuning
17//!
18//! # Examples
19//!
20//! ```
21//! use scirs2_spatial::KDTree;
22//! use scirs2_core::ndarray::array;
23//!
24//! // Create a KD-Tree with points in 2D space
25//! let points = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
26//! let kdtree = KDTree::new(&points).unwrap();
27//!
28//! // Find the nearest neighbor to [4.0, 5.0]
29//! let (idx, dist) = kdtree.query(&[4.0, 5.0], 1).unwrap();
30//! assert_eq!(idx.len(), 1); // Should return exactly one neighbor
31//!
32//! // Find all points within radius 3.0 of [4.0, 5.0]
33//! let (indices, distances) = kdtree.query_radius(&[4.0, 5.0], 3.0).unwrap();
34//! ```
35//!
36//! # Advanced Usage
37//!
38//! Using custom distance metrics:
39//!
40//! ```
41//! use scirs2_spatial::KDTree;
42//! use scirs2_spatial::distance::ManhattanDistance;
43//! use scirs2_core::ndarray::array;
44//!
45//! // Create a KD-Tree with Manhattan distance metric
46//! let points = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
47//! let metric = ManhattanDistance::new();
48//! let kdtree = KDTree::with_metric(&points, metric).unwrap();
49//!
50//! // Find the nearest neighbor to [4.0, 5.0] using Manhattan distance
51//! let (idx, dist) = kdtree.query(&[4.0, 5.0], 1).unwrap();
52//! ```
53//!
54//! Using custom leaf size for performance tuning:
55//!
56//! ```
57//! use scirs2_spatial::KDTree;
58//! use scirs2_core::ndarray::array;
59//!
60//! // Create a KD-Tree with custom leaf size
61//! let points = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
62//!                      [9.0, 10.0], [11.0, 12.0], [13.0, 14.0], [15.0, 16.0]];
63//! let leafsize = 2; // Default is 16
64//! let kdtree = KDTree::with_leaf_size(&points, leafsize).unwrap();
65//! ```
66
67use crate::distance::{Distance, EuclideanDistance};
68use crate::error::{SpatialError, SpatialResult};
69use crate::safe_conversions::*;
70use scirs2_core::ndarray::Array2;
71use scirs2_core::numeric::Float;
72use std::cmp::Ordering;
73
74// Rayon parallel processing currently not used in this module
75#[cfg(feature = "parallel")]
76#[allow(unused_imports)]
77/// A rectangle representing a hyperrectangle in k-dimensional space
78///
79/// Used for efficient nearest-neighbor and range queries in KD-trees.
80#[derive(Clone, Debug)]
81pub struct Rectangle<T: Float> {
82    /// Minimum coordinates for each dimension
83    mins: Vec<T>,
84    /// Maximum coordinates for each dimension
85    maxes: Vec<T>,
86}
87
88impl<T: Float> Rectangle<T> {
89    /// Create a new hyperrectangle
90    ///
91    /// # Arguments
92    ///
93    /// * `mins` - Minimum coordinates for each dimension
94    /// * `maxes` - Maximum coordinates for each dimension
95    ///
96    /// # Returns
97    ///
98    /// * A new rectangle
99    ///
100    /// # Panics
101    ///
102    /// * If mins and maxes have different lengths
103    /// * If any min value is greater than the corresponding max value
104    pub fn new(mins: Vec<T>, maxes: Vec<T>) -> Self {
105        assert_eq!(
106            mins.len(),
107            maxes.len(),
108            "mins and maxes must have the same length"
109        );
110
111        for i in 0..mins.len() {
112            assert!(
113                mins[i] <= maxes[i],
114                "min value must be less than or equal to max value"
115            );
116        }
117
118        Rectangle { mins, maxes }
119    }
120
121    /// Get the minimum coordinates of the rectangle
122    ///
123    /// # Returns
124    ///
125    /// * A slice containing the minimum coordinate values for each dimension
126    ///
127    /// # Examples
128    ///
129    /// ```
130    /// use scirs2_spatial::kdtree::Rectangle;
131    ///
132    /// let rect = Rectangle::new(vec![0.0, 0.0], vec![1.0, 1.0]);
133    /// let mins = rect.mins();
134    /// assert_eq!(mins, &[0.0, 0.0]);
135    /// ```
136    pub fn mins(&self) -> &[T] {
137        &self.mins
138    }
139
140    /// Get the maximum coordinates of the rectangle
141    ///
142    /// # Returns
143    ///
144    /// * A slice containing the maximum coordinate values for each dimension
145    ///
146    /// # Examples
147    ///
148    /// ```
149    /// use scirs2_spatial::kdtree::Rectangle;
150    ///
151    /// let rect = Rectangle::new(vec![0.0, 0.0], vec![1.0, 1.0]);
152    /// let maxes = rect.maxes();
153    /// assert_eq!(maxes, &[1.0, 1.0]);
154    /// ```
155    pub fn maxes(&self) -> &[T] {
156        &self.maxes
157    }
158
159    /// Split the rectangle along a given dimension at a given value
160    ///
161    /// # Arguments
162    ///
163    /// * `dim` - The dimension to split on
164    /// * `value` - The value to split at
165    ///
166    /// # Returns
167    ///
168    /// * A tuple of (left, right) rectangles
169    pub fn split(&self, dim: usize, value: T) -> (Self, Self) {
170        let mut left_maxes = self.maxes.clone();
171        left_maxes[dim] = value;
172
173        let mut right_mins = self.mins.clone();
174        right_mins[dim] = value;
175
176        let left = Rectangle::new(self.mins.clone(), left_maxes);
177        let right = Rectangle::new(right_mins, self.maxes.clone());
178
179        (left, right)
180    }
181
182    /// Check if the rectangle contains a point
183    ///
184    /// # Arguments
185    ///
186    /// * `point` - The point to check
187    ///
188    /// # Returns
189    ///
190    /// * true if the rectangle contains the point, false otherwise
191    pub fn contains(&self, point: &[T]) -> bool {
192        assert_eq!(
193            point.len(),
194            self.mins.len(),
195            "point must have the same dimension as the rectangle"
196        );
197
198        for (i, &p) in point.iter().enumerate() {
199            if p < self.mins[i] || p > self.maxes[i] {
200                return false;
201            }
202        }
203
204        true
205    }
206
207    /// Calculate the minimum distance from a point to the rectangle
208    ///
209    /// # Arguments
210    ///
211    /// * `point` - The point to calculate distance to
212    /// * `metric` - The distance metric to use
213    ///
214    /// # Returns
215    ///
216    /// * The minimum distance from the point to any point in the rectangle
217    pub fn min_distance<D: Distance<T>>(&self, point: &[T], metric: &D) -> T {
218        metric.min_distance_point_rectangle(point, &self.mins, &self.maxes)
219    }
220}
221
/// A node in the KD-Tree
///
/// Each node holds exactly one point (referenced by its row index) and the
/// hyperplane it splits on. Children are stored as indices into the owning
/// tree's `nodes` vector rather than as pointers.
#[derive(Debug, Clone)]
struct KDNode<T: Float> {
    /// Index of the point in the original data array
    idx: usize,
    /// The value of the point along the splitting dimension
    value: T,
    /// The dimension used for splitting
    axis: usize,
    /// Left child node: built from the points that sort before the median
    /// along the splitting axis (coordinates <= `value`; ties with the median
    /// may end up here)
    left: Option<usize>,
    /// Right child node: built from the points that sort after the median
    /// along the splitting axis (coordinates >= `value`)
    right: Option<usize>,
}
236
/// A KD-Tree for efficient nearest neighbor searches
///
/// The tree keeps its own copy of the input points and stores its nodes in a
/// flat vector; `root` and the child links of each node are indices into that
/// vector.
///
/// # Type Parameters
///
/// * `T` - The floating point type for coordinates
/// * `D` - The distance metric type
#[derive(Debug, Clone)]
pub struct KDTree<T: Float + Send + Sync + 'static, D: Distance<T> + 'static> {
    /// The points stored in the KD-Tree (each row is one point)
    points: Array2<T>,
    /// The nodes of the KD-Tree
    nodes: Vec<KDNode<T>>,
    /// The dimensionality of the points
    ndim: usize,
    /// The root node index
    root: Option<usize>,
    /// The distance metric
    metric: D,
    /// The leaf size (maximum number of points in a leaf node)
    ///
    /// NOTE(review): the value is validated and reported by `leafsize()`, but
    /// `build_tree` in this file never reads it and puts one point in every
    /// node — confirm whether leaf buckets are meant to be implemented.
    leafsize: usize,
    /// Minimum bounding rectangle of the entire dataset
    bounds: Rectangle<T>,
}
260
261impl<T: Float + Send + Sync + 'static> KDTree<T, EuclideanDistance<T>> {
262    /// Create a new KD-Tree with default Euclidean distance metric
263    ///
264    /// # Arguments
265    ///
266    /// * `points` - Array of points, each row is a point
267    ///
268    /// # Returns
269    ///
270    /// * A new KD-Tree
271    pub fn new(points: &Array2<T>) -> SpatialResult<Self> {
272        let metric = EuclideanDistance::new();
273        Self::with_metric(points, metric)
274    }
275
276    /// Create a new KD-Tree with custom leaf size (using Euclidean distance)
277    ///
278    /// # Arguments
279    ///
280    /// * `points` - Array of points, each row is a point
281    /// * `leafsize` - The maximum number of points in a leaf node
282    ///
283    /// # Returns
284    ///
285    /// * A new KD-Tree
286    pub fn with_leaf_size(points: &Array2<T>, leafsize: usize) -> SpatialResult<Self> {
287        let metric = EuclideanDistance::new();
288        Self::with_options(points, metric, leafsize)
289    }
290}
291
292impl<T: Float + Send + Sync + 'static, D: Distance<T> + 'static> KDTree<T, D> {
293    /// Create a new KD-Tree with custom distance metric
294    ///
295    /// # Arguments
296    ///
297    /// * `points` - Array of points, each row is a point
298    /// * `metric` - The distance metric to use
299    ///
300    /// # Returns
301    ///
302    /// * A new KD-Tree
303    pub fn with_metric(points: &Array2<T>, metric: D) -> SpatialResult<Self> {
304        Self::with_options(points, metric, 16) // Default leaf size is 16
305    }
306
307    /// Create a new KD-Tree with custom distance metric and leaf size
308    ///
309    /// # Arguments
310    ///
311    /// * `points` - Array of points, each row is a point
312    /// * `metric` - The distance metric to use
313    /// * `leafsize` - The maximum number of points in a leaf node
314    ///
315    /// # Returns
316    ///
317    /// * A new KD-Tree
318    pub fn with_options(points: &Array2<T>, metric: D, leafsize: usize) -> SpatialResult<Self> {
319        let n = points.nrows();
320        let ndim = points.ncols();
321
322        if n == 0 {
323            return Err(SpatialError::ValueError("Empty point set".to_string()));
324        }
325
326        if leafsize == 0 {
327            return Err(SpatialError::ValueError(
328                "Leaf _size must be greater than 0".to_string(),
329            ));
330        }
331
332        // Calculate the bounds of the dataset
333        let mut mins = vec![T::max_value(); ndim];
334        let mut maxes = vec![T::min_value(); ndim];
335
336        for i in 0..n {
337            for j in 0..ndim {
338                let val = points[[i, j]];
339                if val < mins[j] {
340                    mins[j] = val;
341                }
342                if val > maxes[j] {
343                    maxes[j] = val;
344                }
345            }
346        }
347
348        let bounds = Rectangle::new(mins, maxes);
349
350        let mut tree = KDTree {
351            points: points.clone(),
352            nodes: Vec::with_capacity(n),
353            ndim,
354            root: None,
355            metric,
356            leafsize,
357            bounds,
358        };
359
360        // Create indices for the points
361        let mut indices: Vec<usize> = (0..n).collect();
362
363        // Build the tree recursively
364        if n > 0 {
365            let root = tree.build_tree(&mut indices, 0, 0, n)?;
366            tree.root = Some(root);
367        }
368
369        Ok(tree)
370    }
371
372    /// Build the KD-Tree recursively
373    ///
374    /// # Arguments
375    ///
376    /// * `indices` - Indices of the points to consider
377    /// * `depth` - Current depth in the tree
378    /// * `start` - Start index in the indices array
379    /// * `end` - End index in the indices array
380    ///
381    /// # Returns
382    ///
383    /// * Index of the root node of the subtree
384    fn build_tree(
385        &mut self,
386        indices: &mut [usize],
387        depth: usize,
388        start: usize,
389        end: usize,
390    ) -> SpatialResult<usize> {
391        let n = end - start;
392
393        if n == 0 {
394            return Err(SpatialError::ValueError(
395                "Empty point set in build_tree".to_string(),
396            ));
397        }
398
399        // Choose axis based on depth (cycle through axes)
400        let axis = depth % self.ndim;
401
402        // If we have only one point, create a leaf node
403        let node_idx;
404        if n == 1 {
405            let idx = indices[start];
406            let value = self.points[[idx, axis]];
407
408            node_idx = self.nodes.len();
409            self.nodes.push(KDNode {
410                idx,
411                value,
412                axis,
413                left: None,
414                right: None,
415            });
416
417            return Ok(node_idx);
418        }
419
420        // Sort indices based on the axis
421        indices[start..end].sort_by(|&i, &j| {
422            let a = self.points[[i, axis]];
423            let b = self.points[[j, axis]];
424            a.partial_cmp(&b).unwrap_or(Ordering::Equal)
425        });
426
427        // Get the median index
428        let mid = start + n / 2;
429        let idx = indices[mid];
430        let value = self.points[[idx, axis]];
431
432        // Create node
433        node_idx = self.nodes.len();
434        self.nodes.push(KDNode {
435            idx,
436            value,
437            axis,
438            left: None,
439            right: None,
440        });
441
442        // Build left and right subtrees
443        if mid > start {
444            let left_idx = self.build_tree(indices, depth + 1, start, mid)?;
445            self.nodes[node_idx].left = Some(left_idx);
446        }
447
448        if mid + 1 < end {
449            let right_idx = self.build_tree(indices, depth + 1, mid + 1, end)?;
450            self.nodes[node_idx].right = Some(right_idx);
451        }
452
453        Ok(node_idx)
454    }
455
456    /// Find the k nearest neighbors to a query point
457    ///
458    /// # Arguments
459    ///
460    /// * `point` - Query point
461    /// * `k` - Number of nearest neighbors to find
462    ///
463    /// # Returns
464    ///
465    /// * (indices, distances) of the k nearest neighbors
466    ///
467    /// # Examples
468    ///
469    /// ```
470    /// use scirs2_spatial::KDTree;
471    /// use scirs2_core::ndarray::array;
472    ///
473    /// // Create points for the KDTree - use the exact same points from test_kdtree_with_custom_leaf_size
474    /// let points = array![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0], [8.0, 1.0], [7.0, 2.0]];
475    /// let kdtree = KDTree::new(&points).unwrap();
476    ///
477    /// // Find the 2 nearest neighbors to [0.5, 0.5]
478    /// let (indices, distances) = kdtree.query(&[0.5, 0.5], 2).unwrap();
479    /// assert_eq!(indices.len(), 2);
480    /// assert_eq!(distances.len(), 2);
481    /// ```
482    pub fn query(&self, point: &[T], k: usize) -> SpatialResult<(Vec<usize>, Vec<T>)> {
483        if point.len() != self.ndim {
484            return Err(SpatialError::DimensionError(format!(
485                "Query point dimension ({}) does not match tree dimension ({})",
486                point.len(),
487                self.ndim
488            )));
489        }
490
491        if k == 0 {
492            return Ok((vec![], vec![]));
493        }
494
495        if self.points.nrows() == 0 {
496            return Ok((vec![], vec![]));
497        }
498
499        // Initialize priority queue for k nearest neighbors
500        // We use a max-heap so we can efficiently replace the furthest point when we find a closer one
501        let mut neighbors: Vec<(T, usize)> = Vec::with_capacity(k + 1);
502
503        // Keep track of the maximum distance in the heap, for early termination
504        let mut max_dist = T::infinity();
505
506        if let Some(root) = self.root {
507            // Search recursively
508            self.query_recursive(root, point, k, &mut neighbors, &mut max_dist);
509
510            // Sort by distance (ascending), with index as tiebreaker
511            neighbors.sort_by(|a, b| {
512                match safe_partial_cmp(&a.0, &b.0, "kdtree sort neighbors") {
513                    Ok(std::cmp::Ordering::Equal) => a.1.cmp(&b.1), // Use index as tiebreaker
514                    Ok(ord) => ord,
515                    Err(_) => std::cmp::Ordering::Equal,
516                }
517            });
518
519            // Trim to k elements if needed
520            if neighbors.len() > k {
521                neighbors.truncate(k);
522            }
523
524            // Convert to sorted lists of indices and distances
525            let mut indices = Vec::with_capacity(neighbors.len());
526            let mut distances = Vec::with_capacity(neighbors.len());
527
528            for (dist, idx) in neighbors {
529                indices.push(idx);
530                distances.push(dist);
531            }
532
533            Ok((indices, distances))
534        } else {
535            Err(SpatialError::ValueError("Empty tree".to_string()))
536        }
537    }
538
539    /// Recursive helper for query
540    fn query_recursive(
541        &self,
542        node_idx: usize,
543        point: &[T],
544        k: usize,
545        neighbors: &mut Vec<(T, usize)>,
546        max_dist: &mut T,
547    ) {
548        let node = &self.nodes[node_idx];
549        let idx = node.idx;
550        let axis = node.axis;
551
552        // Calculate distance to current point
553        let node_point = self.points.row(idx).to_vec();
554        let _dist = self.metric.distance(&node_point, point);
555
556        // Update neighbors if needed
557        if neighbors.len() < k {
558            neighbors.push((_dist, idx));
559
560            // Sort if we just filled to capacity to establish max-heap
561            if neighbors.len() == k {
562                neighbors.sort_by(|a, b| {
563                    match safe_partial_cmp(&b.0, &a.0, "kdtree sort max-heap") {
564                        Ok(std::cmp::Ordering::Equal) => b.1.cmp(&a.1), // Use index as tiebreaker
565                        Ok(ord) => ord,
566                        Err(_) => std::cmp::Ordering::Equal,
567                    }
568                });
569                *max_dist = neighbors[0].0;
570            }
571        } else if &_dist < max_dist {
572            // Replace the worst neighbor with this one
573            neighbors[0] = (_dist, idx);
574
575            // Re-sort to maintain max-heap property
576            neighbors.sort_by(|a, b| {
577                match safe_partial_cmp(&b.0, &a.0, "kdtree re-sort max-heap") {
578                    Ok(std::cmp::Ordering::Equal) => b.1.cmp(&a.1), // Use index as tiebreaker
579                    Ok(ord) => ord,
580                    Err(_) => std::cmp::Ordering::Equal,
581                }
582            });
583            *max_dist = neighbors[0].0;
584        }
585
586        // Determine which subtree to search first
587        let diff = point[axis] - node.value;
588        let (first, second) = if diff < T::zero() {
589            (node.left, node.right)
590        } else {
591            (node.right, node.left)
592        };
593
594        // Search the near subtree
595        if let Some(first_idx) = first {
596            self.query_recursive(first_idx, point, k, neighbors, max_dist);
597        }
598
599        // Only search the far subtree if it could contain closer points
600        let axis_dist = if diff < T::zero() {
601            // Point is to the left of the splitting hyperplane
602            T::zero() // No need to calculate distance if we're considering the left subtree next
603        } else {
604            // Point is to the right of the splitting hyperplane
605            diff
606        };
607
608        if let Some(second_idx) = second {
609            // Only search the second subtree if necessary
610            if neighbors.len() < k || axis_dist < *max_dist {
611                self.query_recursive(second_idx, point, k, neighbors, max_dist);
612            }
613        }
614    }
615
616    /// Find all points within a radius of a query point
617    ///
618    /// # Arguments
619    ///
620    /// * `point` - Query point
621    /// * `radius` - Search radius
622    ///
623    /// # Returns
624    ///
625    /// * (indices, distances) of points within the radius
626    ///
627    /// # Examples
628    ///
629    /// ```
630    /// use scirs2_spatial::KDTree;
631    /// use scirs2_core::ndarray::array;
632    ///
633    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
634    /// let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
635    /// let kdtree = KDTree::new(&points)?;
636    ///
637    /// // Find all points within radius 0.7 of [0.5, 0.5]
638    /// let (indices, distances) = kdtree.query_radius(&[0.5, 0.5], 0.7)?;
639    /// assert_eq!(indices.len(), 4); // All points are within 0.7 units of [0.5, 0.5]
640    /// # Ok(())
641    /// # }
642    /// ```
643    pub fn query_radius(&self, point: &[T], radius: T) -> SpatialResult<(Vec<usize>, Vec<T>)> {
644        if point.len() != self.ndim {
645            return Err(SpatialError::DimensionError(format!(
646                "Query point dimension ({}) does not match tree dimension ({})",
647                point.len(),
648                self.ndim
649            )));
650        }
651
652        if radius < T::zero() {
653            return Err(SpatialError::ValueError(
654                "Radius must be non-negative".to_string(),
655            ));
656        }
657
658        let mut indices = Vec::new();
659        let mut distances = Vec::new();
660
661        if let Some(root) = self.root {
662            // If the radius is outside the bounds of the entire dataset, just return an empty result
663            let bounds_dist = self.bounds.min_distance(point, &self.metric);
664            if bounds_dist > radius {
665                return Ok((indices, distances));
666            }
667
668            // Search recursively
669            self.query_radius_recursive(root, point, radius, &mut indices, &mut distances);
670
671            // Sort by distance
672            if !indices.is_empty() {
673                let mut idx_dist: Vec<(usize, T)> = indices.into_iter().zip(distances).collect();
674                idx_dist.sort_by(|a, b| {
675                    safe_partial_cmp(&a.1, &b.1, "kdtree sort radius results")
676                        .unwrap_or(std::cmp::Ordering::Equal)
677                });
678
679                indices = idx_dist.iter().map(|(idx_, _)| *idx_).collect();
680                distances = idx_dist.iter().map(|(_, dist)| *dist).collect();
681            }
682        }
683
684        Ok((indices, distances))
685    }
686
687    /// Recursive helper for query_radius
688    fn query_radius_recursive(
689        &self,
690        node_idx: usize,
691        point: &[T],
692        radius: T,
693        indices: &mut Vec<usize>,
694        distances: &mut Vec<T>,
695    ) {
696        let node = &self.nodes[node_idx];
697        let idx = node.idx;
698        let axis = node.axis;
699
700        // Calculate distance to current point
701        let node_point = self.points.row(idx).to_vec();
702        let dist = self.metric.distance(&node_point, point);
703
704        // If point is within radius, add it to results
705        if dist <= radius {
706            indices.push(idx);
707            distances.push(dist);
708        }
709
710        // Determine which subtrees need to be searched
711        let diff = point[axis] - node.value;
712
713        // Always search the near subtree
714        let (near, far) = if diff < T::zero() {
715            (node.left, node.right)
716        } else {
717            (node.right, node.left)
718        };
719
720        if let Some(near_idx) = near {
721            self.query_radius_recursive(near_idx, point, radius, indices, distances);
722        }
723
724        // Only search the far subtree if it could contain points within radius
725        if diff.abs() <= radius {
726            if let Some(far_idx) = far {
727                self.query_radius_recursive(far_idx, point, radius, indices, distances);
728            }
729        }
730    }
731
732    /// Count the number of points within a radius of a query point
733    ///
734    /// This method is more efficient than query_radius when only the count is needed.
735    ///
736    /// # Arguments
737    ///
738    /// * `point` - Query point
739    /// * `radius` - Search radius
740    ///
741    /// # Returns
742    ///
743    /// * Number of points within the radius
744    ///
745    /// # Examples
746    ///
747    /// ```
748    /// use scirs2_spatial::KDTree;
749    /// use scirs2_core::ndarray::array;
750    ///
751    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
752    /// let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
753    /// let kdtree = KDTree::new(&points)?;
754    ///
755    /// // Count points within radius 0.7 of [0.5, 0.5]
756    /// let count = kdtree.count_neighbors(&[0.5, 0.5], 0.7)?;
757    /// assert_eq!(count, 4); // All points are within 0.7 units of [0.5, 0.5]
758    /// # Ok(())
759    /// # }
760    /// ```
761    pub fn count_neighbors(&self, point: &[T], radius: T) -> SpatialResult<usize> {
762        if point.len() != self.ndim {
763            return Err(SpatialError::DimensionError(format!(
764                "Query point dimension ({}) does not match tree dimension ({})",
765                point.len(),
766                self.ndim
767            )));
768        }
769
770        if radius < T::zero() {
771            return Err(SpatialError::ValueError(
772                "Radius must be non-negative".to_string(),
773            ));
774        }
775
776        let mut count = 0;
777
778        if let Some(root) = self.root {
779            // If the radius is outside the bounds of the entire dataset, just return 0
780            let bounds_dist = self.bounds.min_distance(point, &self.metric);
781            if bounds_dist > radius {
782                return Ok(0);
783            }
784
785            // Search recursively
786            self.count_neighbors_recursive(root, point, radius, &mut count);
787        }
788
789        Ok(count)
790    }
791
792    /// Recursive helper for count_neighbors
793    fn count_neighbors_recursive(
794        &self,
795        node_idx: usize,
796        point: &[T],
797        radius: T,
798        count: &mut usize,
799    ) {
800        let node = &self.nodes[node_idx];
801        let idx = node.idx;
802        let axis = node.axis;
803
804        // Calculate distance to current point
805        let node_point = self.points.row(idx).to_vec();
806        let dist = self.metric.distance(&node_point, point);
807
808        // If point is within radius, increment count
809        if dist <= radius {
810            *count += 1;
811        }
812
813        // Determine which subtrees need to be searched
814        let diff = point[axis] - node.value;
815
816        // Always search the near subtree
817        let (near, far) = if diff < T::zero() {
818            (node.left, node.right)
819        } else {
820            (node.right, node.left)
821        };
822
823        if let Some(near_idx) = near {
824            self.count_neighbors_recursive(near_idx, point, radius, count);
825        }
826
827        // Only search the far subtree if it could contain points within radius
828        if diff.abs() <= radius {
829            if let Some(far_idx) = far {
830                self.count_neighbors_recursive(far_idx, point, radius, count);
831            }
832        }
833    }
834
835    /// Get the shape of the KD-Tree's point set
836    ///
837    /// # Returns
838    ///
839    /// * A tuple of (n_points, n_dimensions)
840    pub fn shape(&self) -> (usize, usize) {
841        (self.points.nrows(), self.ndim)
842    }
843
844    /// Get the number of points in the KD-Tree
845    ///
846    /// # Returns
847    ///
848    /// * The number of points
849    pub fn npoints(&self) -> usize {
850        self.points.nrows()
851    }
852
853    /// Get the dimensionality of the points in the KD-Tree
854    ///
855    /// # Returns
856    ///
857    /// * The dimensionality of the points
858    pub fn ndim(&self) -> usize {
859        self.ndim
860    }
861
862    /// Get the leaf size of the KD-Tree
863    ///
864    /// # Returns
865    ///
866    /// * The leaf size
867    pub fn leafsize(&self) -> usize {
868        self.leafsize
869    }
870
871    /// Get the bounds of the KD-Tree
872    ///
873    /// # Returns
874    ///
875    /// * The bounding rectangle of the entire dataset
876    pub fn bounds(&self) -> &Rectangle<T> {
877        &self.bounds
878    }
879}
880
881#[cfg(test)]
882mod tests {
883    use super::{KDTree, Rectangle};
884    use crate::distance::{
885        ChebyshevDistance, Distance, EuclideanDistance, ManhattanDistance, MinkowskiDistance,
886    };
887    use approx::assert_relative_eq;
888    use scirs2_core::ndarray::arr2;
889
890    #[test]
891    fn test_rectangle() {
892        let mins = vec![0.0, 0.0];
893        let maxes = vec![1.0, 1.0];
894        let rect = Rectangle::new(mins, maxes);
895
896        // Test contains
897        assert!(rect.contains(&[0.5, 0.5]));
898        assert!(rect.contains(&[0.0, 0.0]));
899        assert!(rect.contains(&[1.0, 1.0]));
900        assert!(!rect.contains(&[1.5, 0.5]));
901        assert!(!rect.contains(&[0.5, 1.5]));
902
903        // Test split
904        let (left, right) = rect.split(0, 0.5);
905        assert!(left.contains(&[0.25, 0.5]));
906        assert!(!left.contains(&[0.75, 0.5]));
907        assert!(!right.contains(&[0.25, 0.5]));
908        assert!(right.contains(&[0.75, 0.5]));
909
910        // Test min_distance
911        let metric = EuclideanDistance::<f64>::new();
912        assert_relative_eq!(rect.min_distance(&[0.5, 0.5], &metric), 0.0, epsilon = 1e-6);
913        assert_relative_eq!(rect.min_distance(&[2.0, 0.5], &metric), 1.0, epsilon = 1e-6);
914        assert_relative_eq!(
915            rect.min_distance(&[2.0, 2.0], &metric),
916            std::f64::consts::SQRT_2,
917            epsilon = 1e-6
918        );
919    }
920
921    #[test]
922    fn test_kdtree_build() {
923        let points = arr2(&[
924            [2.0, 3.0],
925            [5.0, 4.0],
926            [9.0, 6.0],
927            [4.0, 7.0],
928            [8.0, 1.0],
929            [7.0, 2.0],
930        ]);
931
932        let kdtree = KDTree::new(&points).unwrap();
933
934        // Check that the tree has the correct number of nodes
935        assert_eq!(kdtree.nodes.len(), points.nrows());
936
937        // Check tree properties
938        assert_eq!(kdtree.shape(), (6, 2));
939        assert_eq!(kdtree.npoints(), 6);
940        assert_eq!(kdtree.ndim(), 2);
941        assert_eq!(kdtree.leafsize(), 16);
942
943        // Check bounds
944        assert_eq!(kdtree.bounds().mins(), &[2.0, 1.0]);
945        assert_eq!(kdtree.bounds().maxes(), &[9.0, 7.0]);
946    }
947
948    #[test]
949    fn test_kdtree_query() {
950        let points = arr2(&[
951            [2.0, 3.0],
952            [5.0, 4.0],
953            [9.0, 6.0],
954            [4.0, 7.0],
955            [8.0, 1.0],
956            [7.0, 2.0],
957        ]);
958
959        let kdtree = KDTree::new(&points).unwrap();
960
961        // Query for nearest neighbor to [3.0, 5.0]
962        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).unwrap();
963        assert_eq!(indices.len(), 1);
964        assert_eq!(distances.len(), 1);
965
966        // Calculate actual distances
967        let query = [3.0, 5.0];
968        let mut expected_dists = vec![];
969        for i in 0..points.nrows() {
970            let p = points.row(i).to_vec();
971            let metric = EuclideanDistance::<f64>::new();
972            expected_dists.push((i, metric.distance(&p, &query)));
973        }
974        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
975
976        // Verify we got one of the actual nearest neighbors (there might be ties)
977        // Check that the distance matches the expected minimum distance
978        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
979
980        // Verify the returned index is one of the points with minimum distance
981        let min_dist = expected_dists[0].1;
982        let valid_indices: Vec<usize> = expected_dists
983            .iter()
984            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
985            .map(|(i, _)| *i)
986            .collect();
987        assert!(
988            valid_indices.contains(&indices[0]),
989            "Expected one of {:?} but got {}",
990            valid_indices,
991            indices[0]
992        );
993    }
994
995    #[test]
996    fn test_kdtree_query_k() {
997        let points = arr2(&[
998            [2.0, 3.0],
999            [5.0, 4.0],
1000            [9.0, 6.0],
1001            [4.0, 7.0],
1002            [8.0, 1.0],
1003            [7.0, 2.0],
1004        ]);
1005
1006        let kdtree = KDTree::new(&points).unwrap();
1007
1008        // Query for 3 nearest neighbors to [3.0, 5.0]
1009        let (indices, distances) = kdtree.query(&[3.0, 5.0], 3).unwrap();
1010        assert_eq!(indices.len(), 3);
1011        assert_eq!(distances.len(), 3);
1012
1013        // Calculate actual distances
1014        let query = [3.0, 5.0];
1015        let mut expected_dists = vec![];
1016        for i in 0..points.nrows() {
1017            let p = points.row(i).to_vec();
1018            let metric = EuclideanDistance::<f64>::new();
1019            expected_dists.push((i, metric.distance(&p, &query)));
1020        }
1021        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1022
1023        // Verify we got the 3 actual nearest neighbors (for now, just check distances)
1024        let expected_indices: Vec<usize> = expected_dists.iter().take(3).map(|&(i, _)| i).collect();
1025        let expected_distances: Vec<f64> = expected_dists.iter().take(3).map(|&(_, d)| d).collect();
1026
1027        // Check each returned index is in the expected set
1028        for i in &indices {
1029            assert!(expected_indices.contains(i));
1030        }
1031
1032        // Check that distances are sorted
1033        assert!(distances[0] <= distances[1]);
1034        assert!(distances[1] <= distances[2]);
1035
1036        // Check distance values match expected
1037        for i in 0..3 {
1038            assert_relative_eq!(distances[i], expected_distances[i], epsilon = 1e-6);
1039        }
1040    }
1041
1042    #[test]
1043    fn test_kdtree_query_radius() {
1044        let points = arr2(&[
1045            [2.0, 3.0],
1046            [5.0, 4.0],
1047            [9.0, 6.0],
1048            [4.0, 7.0],
1049            [8.0, 1.0],
1050            [7.0, 2.0],
1051        ]);
1052
1053        let kdtree = KDTree::new(&points).unwrap();
1054
1055        // Query for points within radius 3.0 of [3.0, 5.0]
1056        let (indices, distances) = kdtree.query_radius(&[3.0, 5.0], 3.0).unwrap();
1057
1058        // Calculate expected results
1059        let query = [3.0, 5.0];
1060        let radius = 3.0;
1061        let mut expected_results = vec![];
1062        for i in 0..points.nrows() {
1063            let p = points.row(i).to_vec();
1064            let metric = EuclideanDistance::<f64>::new();
1065            let dist = metric.distance(&p, &query);
1066            if dist <= radius {
1067                expected_results.push((i, dist));
1068            }
1069        }
1070        expected_results.sort_by(|a, b| {
1071            match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
1072                std::cmp::Ordering::Equal => a.0.cmp(&b.0), // Use index as tiebreaker
1073                ord => ord,
1074            }
1075        });
1076
1077        // Check that we got the expected number of points
1078        assert_eq!(indices.len(), expected_results.len());
1079
1080        // Check that all returned points are within radius
1081        for i in 0..indices.len() {
1082            assert!(distances[i] <= radius + 1e-6);
1083        }
1084
1085        // Check that the indices/distances pairs match expected results
1086        // Note: order might differ for equal distances
1087        let mut idx_dist_pairs: Vec<(usize, f64)> = indices
1088            .iter()
1089            .zip(distances.iter())
1090            .map(|(&i, &d)| (i, d))
1091            .collect();
1092        idx_dist_pairs.sort_by(|a, b| {
1093            match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
1094                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
1095                ord => ord,
1096            }
1097        });
1098
1099        for (actual, expected) in idx_dist_pairs.iter().zip(expected_results.iter()) {
1100            assert_eq!(actual.0, expected.0);
1101            assert_relative_eq!(actual.1, expected.1, epsilon = 1e-6);
1102        }
1103    }
1104
1105    #[test]
1106    fn test_kdtree_count_neighbors() {
1107        let points = arr2(&[
1108            [2.0, 3.0],
1109            [5.0, 4.0],
1110            [9.0, 6.0],
1111            [4.0, 7.0],
1112            [8.0, 1.0],
1113            [7.0, 2.0],
1114        ]);
1115
1116        let kdtree = KDTree::new(&points).unwrap();
1117
1118        // Count points within radius 3.0 of [3.0, 5.0]
1119        let count = kdtree.count_neighbors(&[3.0, 5.0], 3.0).unwrap();
1120
1121        // Calculate actual count
1122        let query = [3.0, 5.0];
1123        let mut expected_count = 0;
1124        for i in 0..points.nrows() {
1125            let p = points.row(i).to_vec();
1126            let metric = EuclideanDistance::<f64>::new();
1127            let dist = metric.distance(&p, &query);
1128            if dist <= 3.0 {
1129                expected_count += 1;
1130            }
1131        }
1132
1133        assert_eq!(count, expected_count);
1134    }
1135
1136    #[test]
1137    fn test_kdtree_with_manhattan_distance() {
1138        let points = arr2(&[
1139            [2.0, 3.0],
1140            [5.0, 4.0],
1141            [9.0, 6.0],
1142            [4.0, 7.0],
1143            [8.0, 1.0],
1144            [7.0, 2.0],
1145        ]);
1146
1147        let metric = ManhattanDistance::new();
1148        let kdtree = KDTree::with_metric(&points, metric).unwrap();
1149
1150        // Query for nearest neighbor to [3.0, 5.0] using Manhattan distance
1151        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).unwrap();
1152
1153        // Calculate actual distances using Manhattan distance
1154        let query = [3.0, 5.0];
1155        let mut expected_dists = vec![];
1156        for i in 0..points.nrows() {
1157            let p = points.row(i).to_vec();
1158            let m = ManhattanDistance::<f64>::new();
1159            expected_dists.push((i, m.distance(&p, &query)));
1160        }
1161        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1162
1163        // Check that the distance matches the expected minimum distance
1164        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
1165
1166        // Verify the returned index is one of the points with minimum distance
1167        let min_dist = expected_dists[0].1;
1168        let valid_indices: Vec<usize> = expected_dists
1169            .iter()
1170            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
1171            .map(|(i, _)| *i)
1172            .collect();
1173        assert!(
1174            valid_indices.contains(&indices[0]),
1175            "Expected one of {:?} but got {}",
1176            valid_indices,
1177            indices[0]
1178        );
1179    }
1180
1181    #[test]
1182    fn test_kdtree_with_chebyshev_distance() {
1183        let points = arr2(&[
1184            [2.0, 3.0],
1185            [5.0, 4.0],
1186            [9.0, 6.0],
1187            [4.0, 7.0],
1188            [8.0, 1.0],
1189            [7.0, 2.0],
1190        ]);
1191
1192        let metric = ChebyshevDistance::new();
1193        let kdtree = KDTree::with_metric(&points, metric).unwrap();
1194
1195        // Query for nearest neighbor to [3.0, 5.0] using Chebyshev distance
1196        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).unwrap();
1197
1198        // Calculate actual distances using Chebyshev distance
1199        let query = [3.0, 5.0];
1200        let mut expected_dists = vec![];
1201        for i in 0..points.nrows() {
1202            let p = points.row(i).to_vec();
1203            let m = ChebyshevDistance::<f64>::new();
1204            expected_dists.push((i, m.distance(&p, &query)));
1205        }
1206        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1207
1208        // Check that the distance matches the expected minimum distance
1209        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
1210
1211        // Verify the returned index is one of the points with minimum distance
1212        let min_dist = expected_dists[0].1;
1213        let valid_indices: Vec<usize> = expected_dists
1214            .iter()
1215            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
1216            .map(|(i, _)| *i)
1217            .collect();
1218        assert!(
1219            valid_indices.contains(&indices[0]),
1220            "Expected one of {:?} but got {}",
1221            valid_indices,
1222            indices[0]
1223        );
1224    }
1225
1226    #[test]
1227    fn test_kdtree_with_minkowski_distance() {
1228        let points = arr2(&[
1229            [2.0, 3.0],
1230            [5.0, 4.0],
1231            [9.0, 6.0],
1232            [4.0, 7.0],
1233            [8.0, 1.0],
1234            [7.0, 2.0],
1235        ]);
1236
1237        let metric = MinkowskiDistance::new(3.0);
1238        let kdtree = KDTree::with_metric(&points, metric).unwrap();
1239
1240        // Query for nearest neighbor to [3.0, 5.0] using Minkowski distance (p=3)
1241        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).unwrap();
1242
1243        // Calculate actual distances using Minkowski distance
1244        let query = [3.0, 5.0];
1245        let mut expected_dists = vec![];
1246        for i in 0..points.nrows() {
1247            let p = points.row(i).to_vec();
1248            let m = MinkowskiDistance::<f64>::new(3.0);
1249            expected_dists.push((i, m.distance(&p, &query)));
1250        }
1251        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1252
1253        // Check that the distance matches the expected minimum distance
1254        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
1255
1256        // Verify the returned index is one of the points with minimum distance
1257        let min_dist = expected_dists[0].1;
1258        let valid_indices: Vec<usize> = expected_dists
1259            .iter()
1260            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
1261            .map(|(i, _)| *i)
1262            .collect();
1263        assert!(
1264            valid_indices.contains(&indices[0]),
1265            "Expected one of {:?} but got {}",
1266            valid_indices,
1267            indices[0]
1268        );
1269    }
1270
1271    #[test]
1272    fn test_kdtree_with_custom_leaf_size() {
1273        let points = arr2(&[
1274            [2.0, 3.0],
1275            [5.0, 4.0],
1276            [9.0, 6.0],
1277            [4.0, 7.0],
1278            [8.0, 1.0],
1279            [7.0, 2.0],
1280        ]);
1281
1282        // Use a very small leaf size to test that it works
1283        let leafsize = 1;
1284        let kdtree = KDTree::with_leaf_size(&points, leafsize).unwrap();
1285
1286        assert_eq!(kdtree.leafsize(), 1);
1287
1288        // Query for nearest neighbor to [3.0, 5.0]
1289        let (indices, distances) = kdtree.query(&[3.0, 5.0], 1).unwrap();
1290
1291        // Calculate actual distances
1292        let query = [3.0, 5.0];
1293        let mut expected_dists = vec![];
1294        for i in 0..points.nrows() {
1295            let p = points.row(i).to_vec();
1296            let metric = EuclideanDistance::<f64>::new();
1297            expected_dists.push((i, metric.distance(&p, &query)));
1298        }
1299        expected_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1300
1301        // Verify we got one of the actual nearest neighbors (there might be ties)
1302        // Check that the distance matches the expected minimum distance
1303        assert_relative_eq!(distances[0], expected_dists[0].1, epsilon = 1e-6);
1304
1305        // Verify the returned index is one of the points with minimum distance
1306        let min_dist = expected_dists[0].1;
1307        let valid_indices: Vec<usize> = expected_dists
1308            .iter()
1309            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
1310            .map(|(i, _)| *i)
1311            .collect();
1312        assert!(
1313            valid_indices.contains(&indices[0]),
1314            "Expected one of {:?} but got {}",
1315            valid_indices,
1316            indices[0]
1317        );
1318    }
1319}