scirs2_spatial/
quadtree.rs

1//! Quadtree data structure for 2D space
2//!
3//! This module provides a Quadtree implementation for efficient spatial queries
4//! in 2D space. Quadtrees recursively subdivide space into four equal quadrants,
5//! allowing for efficient nearest neighbor searches, range queries, and
6//! point-in-region operations.
7//!
8//! The implementation supports:
9//! - Quadtree construction from 2D point data
10//! - Nearest neighbor searches
11//! - Range queries for finding points within a specified distance
12//! - Point-in-region queries
13//! - Dynamic insertion and removal of points
14
15use crate::error::{SpatialError, SpatialResult};
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
17use std::cmp::Ordering;
18use std::collections::{BinaryHeap, VecDeque};
19
/// Leaf capacity: a node holding more points than this is subdivided
/// (unless the depth cap below has already been reached).
const MAX_POINTS_PER_NODE: usize = 8;
/// Depth cap for subdivision; nodes created at this depth always become leaves.
const MAX_DEPTH: usize = 20;
24
25/// A 2D bounding box defined by its minimum and maximum corners
26#[derive(Debug, Clone)]
27pub struct BoundingBox2D {
28    /// Minimum coordinates of the box (lower left corner)
29    pub min: Array1<f64>,
30    /// Maximum coordinates of the box (upper right corner)
31    pub max: Array1<f64>,
32}
33
34impl BoundingBox2D {
35    /// Create a new bounding box from min and max corners
36    ///
37    /// # Arguments
38    ///
39    /// * `min` - Minimum coordinates (lower left corner)
40    /// * `max` - Maximum coordinates (upper right corner)
41    ///
42    /// # Returns
43    ///
44    /// A new BoundingBox2D
45    ///
46    /// # Errors
47    ///
48    /// Returns an error if the min or max arrays don't have 2 elements,
49    /// or if min > max for any dimension
50    pub fn new(min: &ArrayView1<f64>, max: &ArrayView1<f64>) -> SpatialResult<Self> {
51        if min.len() != 2 || max.len() != 2 {
52            return Err(SpatialError::DimensionError(format!(
53                "Min and max must have 2 elements, got {} and {}",
54                min.len(),
55                max.len()
56            )));
57        }
58
59        // Check that _min <= max for all dimensions
60        for i in 0..2 {
61            if min[i] > max[i] {
62                return Err(SpatialError::ValueError(format!(
63                    "Min must be <= max for all dimensions, got min[{}]={} > max[{}]={}",
64                    i, min[i], i, max[i]
65                )));
66            }
67        }
68
69        Ok(BoundingBox2D {
70            min: min.to_owned(),
71            max: max.to_owned(),
72        })
73    }
74
75    /// Create a bounding box that encompasses a set of points
76    ///
77    /// # Arguments
78    ///
79    /// * `points` - An array of 2D points
80    ///
81    /// # Returns
82    ///
83    /// A bounding box that contains all the points
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if the points array is empty or if points don't have 2 dimensions
88    pub fn from_points(points: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
89        if points.is_empty() {
90            return Err(SpatialError::ValueError(
91                "Cannot create bounding box from empty point set".into(),
92            ));
93        }
94
95        if points.ncols() != 2 {
96            return Err(SpatialError::DimensionError(format!(
97                "Points must have 2 columns, got {}",
98                points.ncols()
99            )));
100        }
101
102        // Find min and max coordinates
103        let mut min = Array1::from_vec(vec![f64::INFINITY, f64::INFINITY]);
104        let mut max = Array1::from_vec(vec![f64::NEG_INFINITY, f64::NEG_INFINITY]);
105
106        for row in points.rows() {
107            for d in 0..2 {
108                if row[d] < min[d] {
109                    min[d] = row[d];
110                }
111                if row[d] > max[d] {
112                    max[d] = row[d];
113                }
114            }
115        }
116
117        Ok(BoundingBox2D { min, max })
118    }
119
120    /// Check if a point is inside the bounding box
121    ///
122    /// # Arguments
123    ///
124    /// * `point` - A 2D point to check
125    ///
126    /// # Returns
127    ///
128    /// True if the point is inside or on the boundary of the box, false otherwise
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the point doesn't have exactly 2 elements
133    pub fn contains(&self, point: &ArrayView1<f64>) -> SpatialResult<bool> {
134        if point.len() != 2 {
135            return Err(SpatialError::DimensionError(format!(
136                "Point must have 2 elements, got {}",
137                point.len()
138            )));
139        }
140
141        for d in 0..2 {
142            if point[d] < self.min[d] || point[d] > self.max[d] {
143                return Ok(false);
144            }
145        }
146
147        Ok(true)
148    }
149
150    /// Get the center point of the bounding box
151    ///
152    /// # Returns
153    ///
154    /// The center point of the box
155    pub fn center(&self) -> Array1<f64> {
156        let mut center = Array1::zeros(2);
157        for d in 0..2 {
158            center[d] = (self.min[d] + self.max[d]) / 2.0;
159        }
160        center
161    }
162
163    /// Get the dimensions (width, height) of the bounding box
164    ///
165    /// # Returns
166    ///
167    /// An array containing the dimensions of the box
168    pub fn dimensions(&self) -> Array1<f64> {
169        let mut dims = Array1::zeros(2);
170        for d in 0..2 {
171            dims[d] = self.max[d] - self.min[d];
172        }
173        dims
174    }
175
176    /// Check if this bounding box overlaps with another one
177    ///
178    /// # Arguments
179    ///
180    /// * `other` - Another bounding box to check against
181    ///
182    /// # Returns
183    ///
184    /// True if the boxes overlap, false otherwise
185    pub fn overlaps(&self, other: &BoundingBox2D) -> bool {
186        for d in 0..2 {
187            if self.max[d] < other.min[d] || self.min[d] > other.max[d] {
188                return false;
189            }
190        }
191        true
192    }
193
194    /// Calculate the squared distance from a point to the nearest point on the bounding box
195    ///
196    /// # Arguments
197    ///
198    /// * `point` - A 2D point
199    ///
200    /// # Returns
201    ///
202    /// The squared distance to the nearest point on the box boundary or 0 if the point is inside
203    ///
204    /// # Errors
205    ///
206    /// Returns an error if the point doesn't have exactly 2 elements
207    pub fn squared_distance_to_point(&self, point: &ArrayView1<f64>) -> SpatialResult<f64> {
208        if point.len() != 2 {
209            return Err(SpatialError::DimensionError(format!(
210                "Point must have 2 elements, got {}",
211                point.len()
212            )));
213        }
214
215        let mut squared_dist = 0.0;
216
217        for d in 0..2 {
218            let v = point[d];
219
220            if v < self.min[d] {
221                // Point is below minimum bound
222                squared_dist += (v - self.min[d]) * (v - self.min[d]);
223            } else if v > self.max[d] {
224                // Point is above maximum bound
225                squared_dist += (v - self.max[d]) * (v - self.max[d]);
226            }
227            // If within bounds in this dimension, contribution is 0
228        }
229
230        Ok(squared_dist)
231    }
232
233    /// Split the bounding box into 4 equal quadrants
234    ///
235    /// # Returns
236    ///
237    /// An array of 4 bounding boxes representing the quadrants
238    pub fn split_into_quadrants(&self) -> [BoundingBox2D; 4] {
239        let center = self.center();
240
241        // Create quadrants in this order:
242        // 0: SW (bottom-left)
243        // 1: SE (bottom-right)
244        // 2: NW (top-left)
245        // 3: NE (top-right)
246
247        [
248            // 0: SW (bottom-left)
249            BoundingBox2D {
250                min: self.min.clone(),
251                max: center.clone(),
252            },
253            // 1: SE (bottom-right)
254            BoundingBox2D {
255                min: Array1::from_vec(vec![center[0], self.min[1]]),
256                max: Array1::from_vec(vec![self.max[0], center[1]]),
257            },
258            // 2: NW (top-left)
259            BoundingBox2D {
260                min: Array1::from_vec(vec![self.min[0], center[1]]),
261                max: Array1::from_vec(vec![center[0], self.max[1]]),
262            },
263            // 3: NE (top-right)
264            BoundingBox2D {
265                min: center,
266                max: self.max.clone(),
267            },
268        ]
269    }
270}
271
272/// A node in the quadtree
273#[derive(Debug)]
274enum QuadtreeNode {
275    /// An internal node with 4 children
276    Internal {
277        /// Bounding box of this node
278        bounds: BoundingBox2D,
279        /// Children nodes (exactly 4)
280        children: Box<[Option<QuadtreeNode>; 4]>,
281    },
282    /// A leaf node containing points
283    Leaf {
284        /// Bounding box of this node
285        bounds: BoundingBox2D,
286        /// Points in this node
287        points: Vec<usize>,
288        /// Actual point coordinates (reference to input data)
289        point_data: Array2<f64>,
290    },
291}
292
/// A candidate point in a k-nearest-neighbour search.
///
/// `Ord` ranks candidates by squared distance with the *farthest* candidate
/// greatest. This makes a `BinaryHeap<DistancePoint>` a max-heap whose
/// `peek`/`pop` address the current worst candidate. That is what a bounded
/// "keep the k best" set needs: push a new candidate, then pop the top to
/// evict the farthest one.
///
/// Equal distances are ordered by index, so results are deterministic.
/// `f64::total_cmp` keeps the order total even if a distance is NaN; a
/// positive NaN ranks as farthest.
#[derive(Debug, Clone)]
struct DistancePoint {
    /// Index of the point in the original data
    index: usize,
    /// Squared distance to the query point
    distance_sq: f64,
}

impl Ord for DistancePoint {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance_sq
            .total_cmp(&other.distance_sq)
            .then_with(|| self.index.cmp(&other.index))
    }
}

impl PartialOrd for DistancePoint {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

// Equality is defined through `Ord` so the two can never disagree.
impl PartialEq for DistancePoint {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DistancePoint {}
317}
318
319impl Eq for DistancePoint {}
320
321/// A node with a distance for priority queue in nearest neighbor search
322#[derive(Debug, Clone, PartialEq)]
323struct DistanceNode {
324    /// Reference to the node
325    node: *const QuadtreeNode,
326    /// Minimum squared distance to the query point
327    min_distance_sq: f64,
328}
329
330/// For binary heap, we want max heap, but we want to extract the minimum distance,
331/// so we reverse the ordering
332impl Ord for DistanceNode {
333    fn cmp(&self, other: &Self) -> Ordering {
334        other
335            .min_distance_sq
336            .partial_cmp(&self.min_distance_sq)
337            .unwrap_or(Ordering::Equal)
338    }
339}
340
341impl PartialOrd for DistanceNode {
342    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
343        Some(self.cmp(other))
344    }
345}
346
347impl Eq for DistanceNode {}
348
349/// The Quadtree data structure for 2D spatial searches
350#[derive(Debug)]
351pub struct Quadtree {
352    /// Root node of the quadtree
353    root: Option<QuadtreeNode>,
354    /// Number of points in the quadtree
355    size: usize,
356    /// Original point data
357    points: Array2<f64>,
358}
359
360impl Quadtree {
361    /// Create a new quadtree from a set of 2D points
362    ///
363    /// # Arguments
364    ///
365    /// * `points` - An array of 2D points
366    ///
367    /// # Returns
368    ///
369    /// A new Quadtree containing the points
370    ///
371    /// # Errors
372    ///
373    /// Returns an error if the points array is empty or if points don't have 2 dimensions
374    pub fn new(points: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
375        if points.is_empty() {
376            return Err(SpatialError::ValueError(
377                "Cannot create quadtree from empty point set".into(),
378            ));
379        }
380
381        if points.ncols() != 2 {
382            return Err(SpatialError::DimensionError(format!(
383                "Points must have 2 columns, got {}",
384                points.ncols()
385            )));
386        }
387
388        let size = points.nrows();
389        let bounds = BoundingBox2D::from_points(points)?;
390        let points_owned = points.to_owned();
391
392        // Create initial indices (0 to size-1)
393        let indices: Vec<usize> = (0..size).collect();
394
395        // Build the tree recursively
396        let root = Some(Self::build_tree(indices, bounds, &points_owned, 0)?);
397
398        Ok(Quadtree {
399            root,
400            size,
401            points: points_owned,
402        })
403    }
404
405    /// Recursive function to build the quadtree
406    fn build_tree(
407        indices: Vec<usize>,
408        bounds: BoundingBox2D,
409        points: &Array2<f64>,
410        depth: usize,
411    ) -> SpatialResult<QuadtreeNode> {
412        // If we've reached the maximum depth or have few enough points, create a leaf node
413        if depth >= MAX_DEPTH || indices.len() <= MAX_POINTS_PER_NODE {
414            return Ok(QuadtreeNode::Leaf {
415                bounds,
416                points: indices,
417                point_data: points.to_owned(),
418            });
419        }
420
421        // Split the bounding box into quadrants
422        let quadrants = bounds.split_into_quadrants();
423
424        // Create a vector to hold points for each quadrant
425        let mut quadrant_points: [Vec<usize>; 4] = Default::default();
426
427        // Assign each point to a quadrant
428        for &idx in &indices {
429            let point = points.row(idx);
430            let center = bounds.center();
431
432            // Determine which quadrant the point belongs to
433            let mut quadrant_idx = 0;
434            if point[0] >= center[0] {
435                quadrant_idx |= 1;
436            } // right half
437            if point[1] >= center[1] {
438                quadrant_idx |= 2;
439            } // top half
440
441            quadrant_points[quadrant_idx].push(idx);
442        }
443
444        // Create children nodes recursively
445        let mut children: [Option<QuadtreeNode>; 4] = Default::default();
446
447        for i in 0..4 {
448            if !quadrant_points[i].is_empty() {
449                children[i] = Some(Self::build_tree(
450                    quadrant_points[i].clone(),
451                    quadrants[i].clone(),
452                    points,
453                    depth + 1,
454                )?);
455            }
456        }
457
458        Ok(QuadtreeNode::Internal {
459            bounds,
460            children: Box::new(children),
461        })
462    }
463
464    /// Query the k nearest neighbors to a given point
465    ///
466    /// # Arguments
467    ///
468    /// * `query` - The query point
469    /// * `k` - The number of nearest neighbors to find
470    ///
471    /// # Returns
472    ///
473    /// A tuple of (indices, distances) where:
474    /// - indices: Indices of the k nearest points in the original data
475    /// - distances: Squared distances to those points
476    ///
477    /// # Errors
478    ///
479    /// Returns an error if the query point doesn't have 2 dimensions or if k is 0
480    pub fn query_nearest(
481        &self,
482        query: &ArrayView1<f64>,
483        k: usize,
484    ) -> SpatialResult<(Vec<usize>, Vec<f64>)> {
485        if query.len() != 2 {
486            return Err(SpatialError::DimensionError(format!(
487                "Query point must have 2 dimensions, got {}",
488                query.len()
489            )));
490        }
491
492        if k == 0 {
493            return Err(SpatialError::ValueError("k must be > 0".into()));
494        }
495
496        if self.root.is_none() {
497            return Ok((Vec::new(), Vec::new()));
498        }
499
500        // Priority queue for nearest nodes to explore
501        let mut node_queue = BinaryHeap::new();
502
503        // Priority queue for nearest points found so far
504        let mut result_queue = BinaryHeap::new();
505        let mut worst_dist = f64::INFINITY;
506
507        // Add the root node to the queue
508        let root_ref = self.root.as_ref().unwrap() as *const QuadtreeNode;
509        let root_dist = match self.root.as_ref().unwrap() {
510            QuadtreeNode::Internal { bounds, .. } => bounds.squared_distance_to_point(query)?,
511            QuadtreeNode::Leaf { bounds, .. } => bounds.squared_distance_to_point(query)?,
512        };
513
514        node_queue.push(DistanceNode {
515            node: root_ref,
516            min_distance_sq: root_dist,
517        });
518
519        // Search until we've found all nearest neighbors or exhausted the tree
520        while let Some(dist_node) = node_queue.pop() {
521            // If this node is farther than our worst nearest neighbor, we're done
522            if dist_node.min_distance_sq > worst_dist && result_queue.len() >= k {
523                continue;
524            }
525
526            // Now we need to safely convert the raw pointer back to a reference
527            // This is safe because we know the tree structure is stable during the search
528            let node = unsafe { &*dist_node.node };
529
530            match node {
531                QuadtreeNode::Leaf {
532                    points, point_data, ..
533                } => {
534                    // Check each point in this leaf
535                    for &idx in points {
536                        let point = point_data.row(idx);
537                        let dist_sq = squared_distance(query, &point);
538
539                        // If we haven't found k points yet, or this point is closer than our worst point
540                        if result_queue.len() < k || dist_sq < worst_dist {
541                            result_queue.push(DistancePoint {
542                                index: idx,
543                                distance_sq: dist_sq,
544                            });
545
546                            // If we have more than k points, remove the worst one
547                            if result_queue.len() > k {
548                                result_queue.pop();
549                                // Update worst distance
550                                if let Some(worst) = result_queue.peek() {
551                                    worst_dist = worst.distance_sq;
552                                }
553                            }
554                        }
555                    }
556                }
557                QuadtreeNode::Internal { children, .. } => {
558                    // Add all non-empty children to the queue
559                    for child in children.iter().flatten() {
560                        let child_ref = child as *const QuadtreeNode;
561
562                        let min_dist = match child {
563                            QuadtreeNode::Internal { bounds, .. } => {
564                                bounds.squared_distance_to_point(query)?
565                            }
566                            QuadtreeNode::Leaf { bounds, .. } => {
567                                bounds.squared_distance_to_point(query)?
568                            }
569                        };
570
571                        node_queue.push(DistanceNode {
572                            node: child_ref,
573                            min_distance_sq: min_dist,
574                        });
575                    }
576                }
577            }
578        }
579
580        // Convert the result queue to vectors of indices and distances
581        let mut result_indices = Vec::with_capacity(result_queue.len());
582        let mut result_distances = Vec::with_capacity(result_queue.len());
583
584        // The queue is a max heap, so we need to extract elements in reverse
585        let mut temp_results = Vec::new();
586        while let Some(result) = result_queue.pop() {
587            temp_results.push(result);
588        }
589
590        // Add results in increasing distance order
591        for result in temp_results.iter().rev() {
592            result_indices.push(result.index);
593            result_distances.push(result.distance_sq);
594        }
595
596        Ok((result_indices, result_distances))
597    }
598
599    /// Query all points within a given radius of a point
600    ///
601    /// # Arguments
602    ///
603    /// * `query` - The query point
604    /// * `radius` - The search radius
605    ///
606    /// # Returns
607    ///
608    /// A tuple of (indices, distances) where:
609    /// - indices: Indices of the points within the radius in the original data
610    /// - distances: Squared distances to those points
611    ///
612    /// # Errors
613    ///
614    /// Returns an error if the query point doesn't have 2 dimensions or if radius is negative
615    pub fn query_radius(
616        &self,
617        query: &ArrayView1<f64>,
618        radius: f64,
619    ) -> SpatialResult<(Vec<usize>, Vec<f64>)> {
620        if query.len() != 2 {
621            return Err(SpatialError::DimensionError(format!(
622                "Query point must have 2 dimensions, got {}",
623                query.len()
624            )));
625        }
626
627        if radius < 0.0 {
628            return Err(SpatialError::ValueError(
629                "Radius must be non-negative".into(),
630            ));
631        }
632
633        let radius_sq = radius * radius;
634
635        if self.root.is_none() {
636            return Ok((Vec::new(), Vec::new()));
637        }
638
639        let mut result_indices = Vec::new();
640        let mut result_distances = Vec::new();
641
642        // Use a queue for breadth-first search
643        let mut node_queue = VecDeque::new();
644        node_queue.push_back(self.root.as_ref().unwrap());
645
646        while let Some(node) = node_queue.pop_front() {
647            match node {
648                QuadtreeNode::Leaf {
649                    points,
650                    point_data,
651                    bounds,
652                    ..
653                } => {
654                    // Check if this node is within radius of the query
655                    if bounds.squared_distance_to_point(query)? > radius_sq {
656                        continue;
657                    }
658
659                    // Check each point in this leaf
660                    for &idx in points {
661                        let point = point_data.row(idx);
662                        let dist_sq = squared_distance(query, &point);
663
664                        if dist_sq <= radius_sq {
665                            result_indices.push(idx);
666                            result_distances.push(dist_sq);
667                        }
668                    }
669                }
670                QuadtreeNode::Internal {
671                    children, bounds, ..
672                } => {
673                    // Check if this node is within radius of the query
674                    if bounds.squared_distance_to_point(query)? > radius_sq {
675                        continue;
676                    }
677
678                    // Add all non-empty children to the queue
679                    for child in children.iter().flatten() {
680                        node_queue.push_back(child);
681                    }
682                }
683            }
684        }
685
686        Ok((result_indices, result_distances))
687    }
688
689    /// Check if any points lie within a given region
690    ///
691    /// # Arguments
692    ///
693    /// * `region` - A bounding box defining the region
694    ///
695    /// # Returns
696    ///
697    /// True if any points are in the region, false otherwise
698    pub fn points_in_region(&self, region: &BoundingBox2D) -> bool {
699        if self.root.is_none() {
700            return false;
701        }
702
703        // Use a stack for depth-first search
704        let mut node_stack = Vec::new();
705        node_stack.push(self.root.as_ref().unwrap());
706
707        while let Some(node) = node_stack.pop() {
708            match node {
709                QuadtreeNode::Leaf {
710                    points,
711                    point_data,
712                    bounds,
713                    ..
714                } => {
715                    // If this node's bounds don't overlap the region, skip it
716                    if !bounds.overlaps(region) {
717                        continue;
718                    }
719
720                    // Check each point in this leaf
721                    for &idx in points {
722                        let point = point_data.row(idx);
723                        let point_in_region = region.contains(&point.view()).unwrap_or(false);
724
725                        if point_in_region {
726                            return true;
727                        }
728                    }
729                }
730                QuadtreeNode::Internal {
731                    children, bounds, ..
732                } => {
733                    // If this node's bounds don't overlap the region, skip it
734                    if !bounds.overlaps(region) {
735                        continue;
736                    }
737
738                    // Add all non-empty children to the stack
739                    for child in children.iter().flatten() {
740                        node_stack.push(child);
741                    }
742                }
743            }
744        }
745
746        false
747    }
748
749    /// Get all points that lie within a given region
750    ///
751    /// # Arguments
752    ///
753    /// * `region` - A bounding box defining the region
754    ///
755    /// # Returns
756    ///
757    /// Indices of points that lie inside the region
758    pub fn get_points_in_region(&self, region: &BoundingBox2D) -> Vec<usize> {
759        if self.root.is_none() {
760            return Vec::new();
761        }
762
763        let mut result_indices = Vec::new();
764
765        // Use a stack for depth-first search
766        let mut node_stack = Vec::new();
767        node_stack.push(self.root.as_ref().unwrap());
768
769        while let Some(node) = node_stack.pop() {
770            match node {
771                QuadtreeNode::Leaf {
772                    points,
773                    point_data,
774                    bounds,
775                    ..
776                } => {
777                    // If this node's bounds don't overlap the region, skip it
778                    if !bounds.overlaps(region) {
779                        continue;
780                    }
781
782                    // Check each point in this leaf
783                    for &idx in points {
784                        let point = point_data.row(idx);
785                        let point_in_region = region.contains(&point.view()).unwrap_or(false);
786
787                        if point_in_region {
788                            result_indices.push(idx);
789                        }
790                    }
791                }
792                QuadtreeNode::Internal {
793                    children, bounds, ..
794                } => {
795                    // If this node's bounds don't overlap the region, skip it
796                    if !bounds.overlaps(region) {
797                        continue;
798                    }
799
800                    // Add all non-empty children to the stack
801                    for child in children.iter().flatten() {
802                        node_stack.push(child);
803                    }
804                }
805            }
806        }
807
808        result_indices
809    }
810
811    /// Retrieve the original coordinates of a point by its index
812    ///
813    /// # Arguments
814    ///
815    /// * `index` - The index of the point in the original data
816    ///
817    /// # Returns
818    ///
819    /// The point coordinates, or None if the index is invalid
820    pub fn get_point(&self, index: usize) -> Option<Array1<f64>> {
821        if index < self.size {
822            Some(self.points.row(index).to_owned())
823        } else {
824            None
825        }
826    }
827
828    /// Get the total number of points in the quadtree
829    ///
830    /// # Returns
831    ///
832    /// The number of points
833    pub fn size(&self) -> usize {
834        self.size
835    }
836
837    /// Get the bounding box of the quadtree
838    ///
839    /// # Returns
840    ///
841    /// The bounding box of the entire quadtree, or None if the tree is empty
842    pub fn bounds(&self) -> Option<BoundingBox2D> {
843        match &self.root {
844            Some(QuadtreeNode::Internal { bounds, .. }) => Some(bounds.clone()),
845            Some(QuadtreeNode::Leaf { bounds, .. }) => Some(bounds.clone()),
846            None => None,
847        }
848    }
849
850    /// Get the maximum depth of the quadtree
851    ///
852    /// # Returns
853    ///
854    /// The maximum depth of the tree
855    pub fn max_depth(&self) -> usize {
856        Quadtree::compute_max_depth(self.root.as_ref())
857    }
858
859    /// Helper method to compute the maximum depth
860    #[allow(clippy::only_used_in_recursion)]
861    fn compute_max_depth(node: Option<&QuadtreeNode>) -> usize {
862        match node {
863            None => 0,
864            Some(QuadtreeNode::Leaf { .. }) => 1,
865            Some(QuadtreeNode::Internal { children, .. }) => {
866                let mut max_child_depth = 0;
867                for child in children.iter().flatten() {
868                    let child_depth = Self::compute_max_depth(Some(child));
869                    max_child_depth = max_child_depth.max(child_depth);
870                }
871                1 + max_child_depth
872            }
873        }
874    }
875}
876
877/// Calculate the squared Euclidean distance between two points
878///
879/// # Arguments
880///
881/// * `p1` - First point
882/// * `p2` - Second point
883///
884/// # Returns
885///
886/// The squared Euclidean distance
887#[allow(dead_code)]
888fn squared_distance(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
889    let mut sum_sq = 0.0;
890    for i in 0..p1.len().min(p2.len()) {
891        let diff = p1[i] - p2[i];
892        sum_sq += diff * diff;
893    }
894    sum_sq
895}
896
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_bounding_box_creation() {
        // Test creating from min/max
        let min = array![0.0, 0.0];
        let max = array![1.0, 1.0];
        let bbox = BoundingBox2D::new(&min.view(), &max.view()).unwrap();

        assert_eq!(bbox.min, min);
        assert_eq!(bbox.max, max);

        // Test creating from points
        let points = array![[0.0, 0.0], [1.0, 1.0], [0.5, 0.5],];
        let bbox = BoundingBox2D::from_points(&points.view()).unwrap();

        assert_eq!(bbox.min, min);
        assert_eq!(bbox.max, max);

        // Test error on invalid inputs: wrong dimensionality
        let bad_min = array![0.0];
        let result = BoundingBox2D::new(&bad_min.view(), &max.view());
        assert!(result.is_err());

        // ...and min > max in the first dimension
        let bad_minmax = array![2.0, 0.0];
        let result = BoundingBox2D::new(&bad_minmax.view(), &max.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_bounding_box_operations() {
        let min = array![0.0, 0.0];
        let max = array![2.0, 4.0];
        let bbox = BoundingBox2D::new(&min.view(), &max.view()).unwrap();

        // Test center
        let center = bbox.center();
        assert_eq!(center, array![1.0, 2.0]);

        // Test dimensions
        let dims = bbox.dimensions();
        assert_eq!(dims, array![2.0, 4.0]);

        // Test contains
        let inside_point = array![1.0, 1.0];
        assert!(bbox.contains(&inside_point.view()).unwrap());

        let outside_point = array![3.0, 3.0];
        assert!(!bbox.contains(&outside_point.view()).unwrap());

        // Points on the boundary count as contained
        let edge_point = array![0.0, 4.0];
        assert!(bbox.contains(&edge_point.view()).unwrap());

        // Test overlaps
        let overlapping_box =
            BoundingBox2D::new(&array![1.0, 1.0].view(), &array![3.0, 3.0].view()).unwrap();
        assert!(bbox.overlaps(&overlapping_box));

        let non_overlapping_box =
            BoundingBox2D::new(&array![3.0, 5.0].view(), &array![4.0, 6.0].view()).unwrap();
        assert!(!bbox.overlaps(&non_overlapping_box));

        // Test distance to point
        let inside_dist = bbox
            .squared_distance_to_point(&inside_point.view())
            .unwrap();
        assert_eq!(inside_dist, 0.0);

        let outside_dist = bbox
            .squared_distance_to_point(&array![3.0, 5.0].view())
            .unwrap();
        assert_eq!(outside_dist, 1.0 + 1.0); // (3-2)² + (5-4)²
    }

    #[test]
    fn test_quadtree_creation() {
        // Create a simple set of points
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5],];

        let quadtree = Quadtree::new(&points.view()).unwrap();

        // Check basic properties
        assert_eq!(quadtree.size(), 5);

        let bounds = quadtree.bounds().unwrap();
        assert_eq!(bounds.min, array![0.0, 0.0]);
        assert_eq!(bounds.max, array![1.0, 1.0]);

        // Make sure the tree has some depth
        assert!(quadtree.max_depth() > 0);
    }

    #[test]
    fn test_nearest_neighbor_search() {
        // Create a set of points
        let points = array![
            [0.0, 0.0], // 0: origin
            [1.0, 0.0], // 1: right
            [0.0, 1.0], // 2: up
            [1.0, 1.0], // 3: up-right
            [0.5, 0.5], // 4: center
            [2.0, 2.0], // 5: far corner
        ];

        let quadtree = Quadtree::new(&points.view()).unwrap();

        // Test single nearest neighbor
        let query = array![0.1, 0.1];
        let (indices, distances) = quadtree.query_nearest(&query.view(), 1).unwrap();

        assert_eq!(indices.len(), 1);
        assert_eq!(distances.len(), indices.len());
        // Only validity is checked here: the index must refer to an input point
        // and the reported distance must be non-negative.
        assert!(indices[0] < points.shape()[0]);
        assert!(distances[0] >= 0.0);

        // Test multiple nearest neighbors
        let (indices, distances) = quadtree.query_nearest(&query.view(), 3).unwrap();

        // Just check that we have at least one result, with one distance per index
        assert!(!indices.is_empty());
        assert_eq!(distances.len(), indices.len());

        // Check that all distances are non-negative
        for d in distances.iter() {
            assert!(*d >= 0.0);
        }

        // Test with k > number of points
        let (indices, distances) = quadtree.query_nearest(&query.view(), 10).unwrap();

        assert_eq!(indices.len(), 6); // Should return all 6 points
        assert_eq!(distances.len(), 6);
    }

    #[test]
    fn test_radius_search() {
        // Create a set of points
        let points = array![
            [0.0, 0.0], // 0: origin
            [1.0, 0.0], // 1: right
            [0.0, 1.0], // 2: up
            [1.0, 1.0], // 3: up-right
            [0.5, 0.5], // 4: center
            [2.0, 2.0], // 5: far corner
        ];

        let quadtree = Quadtree::new(&points.view()).unwrap();

        // Test radius search with small radius
        let query = array![0.0, 0.0];
        let radius = 0.5;
        let (indices, distances) = quadtree.query_radius(&query.view(), radius).unwrap();

        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 0); // Only origin is within 0.5 units
        assert_eq!(distances.len(), 1);
        // The query coincides with the origin, so its distance is zero whether
        // plain or squared distances are reported.
        assert_eq!(distances[0], 0.0);

        // Test with larger radius
        let radius = 1.5;
        let (indices, distances) = quadtree.query_radius(&query.view(), radius).unwrap();

        assert!(indices.len() >= 4); // Should find at least origin, right, up, center
        assert_eq!(distances.len(), indices.len());

        // Check all distances are within radius (distances are compared as squared values)
        for &dist in &distances {
            assert!(dist <= radius * radius);
        }

        // Test with radius covering all points
        let radius = 4.0;
        let (indices, distances) = quadtree.query_radius(&query.view(), radius).unwrap();

        assert_eq!(indices.len(), 6); // Should find all points
        assert_eq!(distances.len(), 6);
    }

    #[test]
    fn test_region_queries() {
        // Create a set of points
        let points = array![
            [0.0, 0.0], // 0: origin
            [1.0, 0.0], // 1: right
            [0.0, 1.0], // 2: up
            [1.0, 1.0], // 3: up-right
            [0.5, 0.5], // 4: center
            [2.0, 2.0], // 5: far corner
        ];

        let quadtree = Quadtree::new(&points.view()).unwrap();

        // Define a region (bounding box)
        let region =
            BoundingBox2D::new(&array![0.25, 0.25].view(), &array![0.75, 0.75].view()).unwrap();

        // Check if any points are in the region
        assert!(quadtree.points_in_region(&region));

        // Get points in region
        let indices = quadtree.get_points_in_region(&region);
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 4); // Should find center point

        // Try with larger region
        let large_region =
            BoundingBox2D::new(&array![0.0, 0.0].view(), &array![1.0, 1.0].view()).unwrap();

        let indices = quadtree.get_points_in_region(&large_region);
        assert_eq!(indices.len(), 5); // Should find all points except far corner

        // Try with region containing no points
        let empty_region =
            BoundingBox2D::new(&array![1.5, 1.5].view(), &array![1.9, 1.9].view()).unwrap();

        assert!(!quadtree.points_in_region(&empty_region));
        let indices = quadtree.get_points_in_region(&empty_region);
        assert_eq!(indices.len(), 0);
    }
}