1use crate::error::{SpatialError, SpatialResult};
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
17use std::cmp::Ordering;
18use std::collections::{BinaryHeap, VecDeque};
19
/// Number of points at or below which `Quadtree::build_tree` stops splitting and
/// creates a leaf (leaves created because of `MAX_DEPTH` may hold more).
const MAX_POINTS_PER_NODE: usize = 8;
/// Depth at which `Quadtree::build_tree` always creates a leaf, regardless of how
/// many points the node holds. Without this cap, more than `MAX_POINTS_PER_NODE`
/// identical points would always land in the same quadrant and never be separated.
const MAX_DEPTH: usize = 20;
24
/// Axis-aligned bounding box in two dimensions.
///
/// `BoundingBox2D::new` and `BoundingBox2D::from_points` construct boxes whose
/// corners have exactly two elements; the methods index coordinates `0` and `1`.
#[derive(Debug, Clone)]
pub struct BoundingBox2D {
    /// Lower corner, `[x_min, y_min]`.
    pub min: Array1<f64>,
    /// Upper corner, `[x_max, y_max]`.
    pub max: Array1<f64>,
}
33
34impl BoundingBox2D {
35 pub fn new(min: &ArrayView1<f64>, max: &ArrayView1<f64>) -> SpatialResult<Self> {
51 if min.len() != 2 || max.len() != 2 {
52 return Err(SpatialError::DimensionError(format!(
53 "Min and max must have 2 elements, got {} and {}",
54 min.len(),
55 max.len()
56 )));
57 }
58
59 for i in 0..2 {
61 if min[i] > max[i] {
62 return Err(SpatialError::ValueError(format!(
63 "Min must be <= max for all dimensions, got min[{}]={} > max[{}]={}",
64 i, min[i], i, max[i]
65 )));
66 }
67 }
68
69 Ok(BoundingBox2D {
70 min: min.to_owned(),
71 max: max.to_owned(),
72 })
73 }
74
75 pub fn from_points(points: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
89 if points.is_empty() {
90 return Err(SpatialError::ValueError(
91 "Cannot create bounding box from empty point set".into(),
92 ));
93 }
94
95 if points.ncols() != 2 {
96 return Err(SpatialError::DimensionError(format!(
97 "Points must have 2 columns, got {}",
98 points.ncols()
99 )));
100 }
101
102 let mut min = Array1::from_vec(vec![f64::INFINITY, f64::INFINITY]);
104 let mut max = Array1::from_vec(vec![f64::NEG_INFINITY, f64::NEG_INFINITY]);
105
106 for row in points.rows() {
107 for d in 0..2 {
108 if row[d] < min[d] {
109 min[d] = row[d];
110 }
111 if row[d] > max[d] {
112 max[d] = row[d];
113 }
114 }
115 }
116
117 Ok(BoundingBox2D { min, max })
118 }
119
120 pub fn contains(&self, point: &ArrayView1<f64>) -> SpatialResult<bool> {
134 if point.len() != 2 {
135 return Err(SpatialError::DimensionError(format!(
136 "Point must have 2 elements, got {}",
137 point.len()
138 )));
139 }
140
141 for d in 0..2 {
142 if point[d] < self.min[d] || point[d] > self.max[d] {
143 return Ok(false);
144 }
145 }
146
147 Ok(true)
148 }
149
150 pub fn center(&self) -> Array1<f64> {
156 let mut center = Array1::zeros(2);
157 for d in 0..2 {
158 center[d] = (self.min[d] + self.max[d]) / 2.0;
159 }
160 center
161 }
162
163 pub fn dimensions(&self) -> Array1<f64> {
169 let mut dims = Array1::zeros(2);
170 for d in 0..2 {
171 dims[d] = self.max[d] - self.min[d];
172 }
173 dims
174 }
175
176 pub fn overlaps(&self, other: &BoundingBox2D) -> bool {
186 for d in 0..2 {
187 if self.max[d] < other.min[d] || self.min[d] > other.max[d] {
188 return false;
189 }
190 }
191 true
192 }
193
194 pub fn squared_distance_to_point(&self, point: &ArrayView1<f64>) -> SpatialResult<f64> {
208 if point.len() != 2 {
209 return Err(SpatialError::DimensionError(format!(
210 "Point must have 2 elements, got {}",
211 point.len()
212 )));
213 }
214
215 let mut squared_dist = 0.0;
216
217 for d in 0..2 {
218 let v = point[d];
219
220 if v < self.min[d] {
221 squared_dist += (v - self.min[d]) * (v - self.min[d]);
223 } else if v > self.max[d] {
224 squared_dist += (v - self.max[d]) * (v - self.max[d]);
226 }
227 }
229
230 Ok(squared_dist)
231 }
232
233 pub fn split_into_quadrants(&self) -> [BoundingBox2D; 4] {
239 let center = self.center();
240
241 [
248 BoundingBox2D {
250 min: self.min.clone(),
251 max: center.clone(),
252 },
253 BoundingBox2D {
255 min: Array1::from_vec(vec![center[0], self.min[1]]),
256 max: Array1::from_vec(vec![self.max[0], center[1]]),
257 },
258 BoundingBox2D {
260 min: Array1::from_vec(vec![self.min[0], center[1]]),
261 max: Array1::from_vec(vec![center[0], self.max[1]]),
262 },
263 BoundingBox2D {
265 min: center,
266 max: self.max.clone(),
267 },
268 ]
269 }
270}
271
/// A node of the quadtree: either an internal node with up to four children or
/// a leaf holding point indices directly.
#[derive(Debug)]
enum QuadtreeNode {
    /// A node whose region has been split into four quadrants.
    Internal {
        /// Region covered by this node.
        bounds: BoundingBox2D,
        /// One slot per quadrant, in the order produced by
        /// `BoundingBox2D::split_into_quadrants`; a slot is `None` when no point
        /// fell into that quadrant.
        children: Box<[Option<QuadtreeNode>; 4]>,
    },
    /// A terminal node.
    Leaf {
        /// Region covered by this node.
        bounds: BoundingBox2D,
        /// Row indices (into the point set the tree was built from) of the points
        /// stored in this leaf.
        points: Vec<usize>,
        /// Point coordinates read when this leaf is queried; how its rows relate
        /// to `points` is decided by `Quadtree::build_tree` and the query methods.
        point_data: Array2<f64>,
    },
}
292
/// Candidate point in a k-nearest-neighbour search.
///
/// The ordering is *reversed* on `distance_sq`: a nearer point compares as
/// greater. Equal distances are tie-broken so that a lower `index` also compares
/// as greater, which makes the order total and deterministic.
#[derive(Debug, Clone)]
struct DistancePoint {
    /// Row index of the candidate point.
    index: usize,
    /// Squared Euclidean distance from the query to the point.
    distance_sq: f64,
}

impl Ord for DistancePoint {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` keeps this a genuine total order even for NaN distances; the
        // previous `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to every
        // value, which is not transitive. The index tie-break lets `Ord` agree with
        // the `PartialEq` implementation below.
        other
            .distance_sq
            .total_cmp(&self.distance_sq)
            .then_with(|| other.index.cmp(&self.index))
    }
}

impl PartialOrd for DistancePoint {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

// Equality is defined through `cmp` so that `PartialEq`, `Eq` and `Ord` agree,
// as the `Ord` contract requires (a derived impl compared fields `Ord` ignored).
impl PartialEq for DistancePoint {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DistancePoint {}
320
321#[derive(Debug, Clone, PartialEq)]
323struct DistanceNode {
324 node: *const QuadtreeNode,
326 min_distance_sq: f64,
328}
329
330impl Ord for DistanceNode {
333 fn cmp(&self, other: &Self) -> Ordering {
334 other
335 .min_distance_sq
336 .partial_cmp(&self.min_distance_sq)
337 .unwrap_or(Ordering::Equal)
338 }
339}
340
341impl PartialOrd for DistanceNode {
342 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
343 Some(self.cmp(other))
344 }
345}
346
347impl Eq for DistanceNode {}
348
/// Quadtree index over a fixed set of 2-D points, supporting k-nearest-neighbour,
/// radius and axis-aligned region queries.
#[derive(Debug)]
pub struct Quadtree {
    /// Root node. `Quadtree::new` always sets this to `Some`; the query methods
    /// still handle `None` by returning empty results.
    root: Option<QuadtreeNode>,
    /// Number of indexed points (rows of `points`).
    size: usize,
    /// Owned copy of the input points, one row per point.
    points: Array2<f64>,
}
359
360impl Quadtree {
361 pub fn new(points: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
375 if points.is_empty() {
376 return Err(SpatialError::ValueError(
377 "Cannot create quadtree from empty point set".into(),
378 ));
379 }
380
381 if points.ncols() != 2 {
382 return Err(SpatialError::DimensionError(format!(
383 "Points must have 2 columns, got {}",
384 points.ncols()
385 )));
386 }
387
388 let size = points.nrows();
389 let bounds = BoundingBox2D::from_points(points)?;
390 let points_owned = points.to_owned();
391
392 let indices: Vec<usize> = (0..size).collect();
394
395 let root = Some(Self::build_tree(indices, bounds, &points_owned, 0)?);
397
398 Ok(Quadtree {
399 root,
400 size,
401 points: points_owned,
402 })
403 }
404
405 fn build_tree(
407 indices: Vec<usize>,
408 bounds: BoundingBox2D,
409 points: &Array2<f64>,
410 depth: usize,
411 ) -> SpatialResult<QuadtreeNode> {
412 if depth >= MAX_DEPTH || indices.len() <= MAX_POINTS_PER_NODE {
414 return Ok(QuadtreeNode::Leaf {
415 bounds,
416 points: indices,
417 point_data: points.to_owned(),
418 });
419 }
420
421 let quadrants = bounds.split_into_quadrants();
423
424 let mut quadrant_points: [Vec<usize>; 4] = Default::default();
426
427 for &idx in &indices {
429 let point = points.row(idx);
430 let center = bounds.center();
431
432 let mut quadrant_idx = 0;
434 if point[0] >= center[0] {
435 quadrant_idx |= 1;
436 } if point[1] >= center[1] {
438 quadrant_idx |= 2;
439 } quadrant_points[quadrant_idx].push(idx);
442 }
443
444 let mut children: [Option<QuadtreeNode>; 4] = Default::default();
446
447 for i in 0..4 {
448 if !quadrant_points[i].is_empty() {
449 children[i] = Some(Self::build_tree(
450 quadrant_points[i].clone(),
451 quadrants[i].clone(),
452 points,
453 depth + 1,
454 )?);
455 }
456 }
457
458 Ok(QuadtreeNode::Internal {
459 bounds,
460 children: Box::new(children),
461 })
462 }
463
464 pub fn query_nearest(
481 &self,
482 query: &ArrayView1<f64>,
483 k: usize,
484 ) -> SpatialResult<(Vec<usize>, Vec<f64>)> {
485 if query.len() != 2 {
486 return Err(SpatialError::DimensionError(format!(
487 "Query point must have 2 dimensions, got {}",
488 query.len()
489 )));
490 }
491
492 if k == 0 {
493 return Err(SpatialError::ValueError("k must be > 0".into()));
494 }
495
496 if self.root.is_none() {
497 return Ok((Vec::new(), Vec::new()));
498 }
499
500 let mut node_queue = BinaryHeap::new();
502
503 let mut result_queue = BinaryHeap::new();
505 let mut worst_dist = f64::INFINITY;
506
507 let root_ref = self.root.as_ref().unwrap() as *const QuadtreeNode;
509 let root_dist = match self.root.as_ref().unwrap() {
510 QuadtreeNode::Internal { bounds, .. } => bounds.squared_distance_to_point(query)?,
511 QuadtreeNode::Leaf { bounds, .. } => bounds.squared_distance_to_point(query)?,
512 };
513
514 node_queue.push(DistanceNode {
515 node: root_ref,
516 min_distance_sq: root_dist,
517 });
518
519 while let Some(dist_node) = node_queue.pop() {
521 if dist_node.min_distance_sq > worst_dist && result_queue.len() >= k {
523 continue;
524 }
525
526 let node = unsafe { &*dist_node.node };
529
530 match node {
531 QuadtreeNode::Leaf {
532 points, point_data, ..
533 } => {
534 for &idx in points {
536 let point = point_data.row(idx);
537 let dist_sq = squared_distance(query, &point);
538
539 if result_queue.len() < k || dist_sq < worst_dist {
541 result_queue.push(DistancePoint {
542 index: idx,
543 distance_sq: dist_sq,
544 });
545
546 if result_queue.len() > k {
548 result_queue.pop();
549 if let Some(worst) = result_queue.peek() {
551 worst_dist = worst.distance_sq;
552 }
553 }
554 }
555 }
556 }
557 QuadtreeNode::Internal { children, .. } => {
558 for child in children.iter().flatten() {
560 let child_ref = child as *const QuadtreeNode;
561
562 let min_dist = match child {
563 QuadtreeNode::Internal { bounds, .. } => {
564 bounds.squared_distance_to_point(query)?
565 }
566 QuadtreeNode::Leaf { bounds, .. } => {
567 bounds.squared_distance_to_point(query)?
568 }
569 };
570
571 node_queue.push(DistanceNode {
572 node: child_ref,
573 min_distance_sq: min_dist,
574 });
575 }
576 }
577 }
578 }
579
580 let mut result_indices = Vec::with_capacity(result_queue.len());
582 let mut result_distances = Vec::with_capacity(result_queue.len());
583
584 let mut temp_results = Vec::new();
586 while let Some(result) = result_queue.pop() {
587 temp_results.push(result);
588 }
589
590 for result in temp_results.iter().rev() {
592 result_indices.push(result.index);
593 result_distances.push(result.distance_sq);
594 }
595
596 Ok((result_indices, result_distances))
597 }
598
599 pub fn query_radius(
616 &self,
617 query: &ArrayView1<f64>,
618 radius: f64,
619 ) -> SpatialResult<(Vec<usize>, Vec<f64>)> {
620 if query.len() != 2 {
621 return Err(SpatialError::DimensionError(format!(
622 "Query point must have 2 dimensions, got {}",
623 query.len()
624 )));
625 }
626
627 if radius < 0.0 {
628 return Err(SpatialError::ValueError(
629 "Radius must be non-negative".into(),
630 ));
631 }
632
633 let radius_sq = radius * radius;
634
635 if self.root.is_none() {
636 return Ok((Vec::new(), Vec::new()));
637 }
638
639 let mut result_indices = Vec::new();
640 let mut result_distances = Vec::new();
641
642 let mut node_queue = VecDeque::new();
644 node_queue.push_back(self.root.as_ref().unwrap());
645
646 while let Some(node) = node_queue.pop_front() {
647 match node {
648 QuadtreeNode::Leaf {
649 points,
650 point_data,
651 bounds,
652 ..
653 } => {
654 if bounds.squared_distance_to_point(query)? > radius_sq {
656 continue;
657 }
658
659 for &idx in points {
661 let point = point_data.row(idx);
662 let dist_sq = squared_distance(query, &point);
663
664 if dist_sq <= radius_sq {
665 result_indices.push(idx);
666 result_distances.push(dist_sq);
667 }
668 }
669 }
670 QuadtreeNode::Internal {
671 children, bounds, ..
672 } => {
673 if bounds.squared_distance_to_point(query)? > radius_sq {
675 continue;
676 }
677
678 for child in children.iter().flatten() {
680 node_queue.push_back(child);
681 }
682 }
683 }
684 }
685
686 Ok((result_indices, result_distances))
687 }
688
689 pub fn points_in_region(&self, region: &BoundingBox2D) -> bool {
699 if self.root.is_none() {
700 return false;
701 }
702
703 let mut node_stack = Vec::new();
705 node_stack.push(self.root.as_ref().unwrap());
706
707 while let Some(node) = node_stack.pop() {
708 match node {
709 QuadtreeNode::Leaf {
710 points,
711 point_data,
712 bounds,
713 ..
714 } => {
715 if !bounds.overlaps(region) {
717 continue;
718 }
719
720 for &idx in points {
722 let point = point_data.row(idx);
723 let point_in_region = region.contains(&point.view()).unwrap_or(false);
724
725 if point_in_region {
726 return true;
727 }
728 }
729 }
730 QuadtreeNode::Internal {
731 children, bounds, ..
732 } => {
733 if !bounds.overlaps(region) {
735 continue;
736 }
737
738 for child in children.iter().flatten() {
740 node_stack.push(child);
741 }
742 }
743 }
744 }
745
746 false
747 }
748
749 pub fn get_points_in_region(&self, region: &BoundingBox2D) -> Vec<usize> {
759 if self.root.is_none() {
760 return Vec::new();
761 }
762
763 let mut result_indices = Vec::new();
764
765 let mut node_stack = Vec::new();
767 node_stack.push(self.root.as_ref().unwrap());
768
769 while let Some(node) = node_stack.pop() {
770 match node {
771 QuadtreeNode::Leaf {
772 points,
773 point_data,
774 bounds,
775 ..
776 } => {
777 if !bounds.overlaps(region) {
779 continue;
780 }
781
782 for &idx in points {
784 let point = point_data.row(idx);
785 let point_in_region = region.contains(&point.view()).unwrap_or(false);
786
787 if point_in_region {
788 result_indices.push(idx);
789 }
790 }
791 }
792 QuadtreeNode::Internal {
793 children, bounds, ..
794 } => {
795 if !bounds.overlaps(region) {
797 continue;
798 }
799
800 for child in children.iter().flatten() {
802 node_stack.push(child);
803 }
804 }
805 }
806 }
807
808 result_indices
809 }
810
811 pub fn get_point(&self, index: usize) -> Option<Array1<f64>> {
821 if index < self.size {
822 Some(self.points.row(index).to_owned())
823 } else {
824 None
825 }
826 }
827
828 pub fn size(&self) -> usize {
834 self.size
835 }
836
837 pub fn bounds(&self) -> Option<BoundingBox2D> {
843 match &self.root {
844 Some(QuadtreeNode::Internal { bounds, .. }) => Some(bounds.clone()),
845 Some(QuadtreeNode::Leaf { bounds, .. }) => Some(bounds.clone()),
846 None => None,
847 }
848 }
849
850 pub fn max_depth(&self) -> usize {
856 Quadtree::compute_max_depth(self.root.as_ref())
857 }
858
859 #[allow(clippy::only_used_in_recursion)]
861 fn compute_max_depth(node: Option<&QuadtreeNode>) -> usize {
862 match node {
863 None => 0,
864 Some(QuadtreeNode::Leaf { .. }) => 1,
865 Some(QuadtreeNode::Internal { children, .. }) => {
866 let mut max_child_depth = 0;
867 for child in children.iter().flatten() {
868 let child_depth = Self::compute_max_depth(Some(child));
869 max_child_depth = max_child_depth.max(child_depth);
870 }
871 1 + max_child_depth
872 }
873 }
874 }
875}
876
877#[allow(dead_code)]
888fn squared_distance(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
889 let mut sum_sq = 0.0;
890 for i in 0..p1.len().min(p2.len()) {
891 let diff = p1[i] - p2[i];
892 sum_sq += diff * diff;
893 }
894 sum_sq
895}
896
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Deterministic 12 x 12 grid (144 points): x spacing 0.5, y spacing 0.25.
    /// Large enough to force the tree to split into internal nodes.
    fn grid_points() -> Array2<f64> {
        const SIDE: usize = 12;
        Array2::from_shape_fn((SIDE * SIDE, 2), |(i, axis)| {
            if axis == 0 {
                (i % SIDE) as f64 * 0.5
            } else {
                (i / SIDE) as f64 * 0.25
            }
        })
    }

    #[test]
    fn test_bounding_box_creation() {
        let min = array![0.0, 0.0];
        let max = array![1.0, 1.0];
        let bbox = BoundingBox2D::new(&min.view(), &max.view()).unwrap();

        assert_eq!(bbox.min, min);
        assert_eq!(bbox.max, max);

        let points = array![[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]];
        let bbox = BoundingBox2D::from_points(&points.view()).unwrap();

        assert_eq!(bbox.min, min);
        assert_eq!(bbox.max, max);

        let bad_min = array![0.0];
        let result = BoundingBox2D::new(&bad_min.view(), &max.view());
        assert!(result.is_err());

        let bad_minmax = array![2.0, 0.0];
        let result = BoundingBox2D::new(&bad_minmax.view(), &max.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_bounding_box_operations() {
        let min = array![0.0, 0.0];
        let max = array![2.0, 4.0];
        let bbox = BoundingBox2D::new(&min.view(), &max.view()).unwrap();

        assert_eq!(bbox.center(), array![1.0, 2.0]);
        assert_eq!(bbox.dimensions(), array![2.0, 4.0]);

        let inside_point = array![1.0, 1.0];
        assert!(bbox.contains(&inside_point.view()).unwrap());

        let outside_point = array![3.0, 3.0];
        assert!(!bbox.contains(&outside_point.view()).unwrap());

        let edge_point = array![0.0, 4.0];
        assert!(bbox.contains(&edge_point.view()).unwrap());

        let overlapping_box =
            BoundingBox2D::new(&array![1.0, 1.0].view(), &array![3.0, 3.0].view()).unwrap();
        assert!(bbox.overlaps(&overlapping_box));

        let non_overlapping_box =
            BoundingBox2D::new(&array![3.0, 5.0].view(), &array![4.0, 6.0].view()).unwrap();
        assert!(!bbox.overlaps(&non_overlapping_box));

        let inside_dist = bbox
            .squared_distance_to_point(&inside_point.view())
            .unwrap();
        assert_eq!(inside_dist, 0.0);

        let outside_dist = bbox
            .squared_distance_to_point(&array![3.0, 5.0].view())
            .unwrap();
        assert_eq!(outside_dist, 1.0 + 1.0);
    }

    #[test]
    fn test_quadtree_creation() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];

        let quadtree = Quadtree::new(&points.view()).unwrap();

        assert_eq!(quadtree.size(), 5);

        let bounds = quadtree.bounds().unwrap();
        assert_eq!(bounds.min, array![0.0, 0.0]);
        assert_eq!(bounds.max, array![1.0, 1.0]);

        assert!(quadtree.max_depth() > 0);
    }

    #[test]
    fn test_nearest_neighbor_search() {
        let points = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [2.0, 2.0],
        ];

        let quadtree = Quadtree::new(&points.view()).unwrap();

        let query = array![0.1, 0.1];
        let (indices, distances) = quadtree.query_nearest(&query.view(), 1).unwrap();

        assert_eq!(indices.len(), 1);
        assert!(indices[0] < points.shape()[0]);
        assert!(distances[0] >= 0.0);

        let (indices, distances) = quadtree.query_nearest(&query.view(), 3).unwrap();

        assert!(!indices.is_empty());
        for d in distances.iter() {
            assert!(*d >= 0.0);
        }

        // Asking for more neighbours than there are points returns all of them.
        let (indices, distances) = quadtree.query_nearest(&query.view(), 10).unwrap();

        assert_eq!(indices.len(), 6);
        assert_eq!(distances.len(), 6);

        assert!(quadtree.query_nearest(&query.view(), 0).is_err());
        assert!(quadtree
            .query_nearest(&array![0.1, 0.1, 0.1].view(), 1)
            .is_err());
    }

    #[test]
    fn test_radius_search() {
        let points = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [2.0, 2.0],
        ];

        let quadtree = Quadtree::new(&points.view()).unwrap();

        let query = array![0.0, 0.0];
        let (indices, _) = quadtree.query_radius(&query.view(), 0.5).unwrap();

        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 0);

        let radius = 1.5;
        let (indices, distances) = quadtree.query_radius(&query.view(), radius).unwrap();

        assert!(indices.len() >= 4);
        for &dist in &distances {
            assert!(dist <= radius * radius);
        }

        let (indices, _) = quadtree.query_radius(&query.view(), 4.0).unwrap();
        assert_eq!(indices.len(), 6);

        assert!(quadtree.query_radius(&query.view(), -1.0).is_err());
    }

    #[test]
    fn test_radius_search_matches_brute_force() {
        let points = grid_points();
        let quadtree = Quadtree::new(&points.view()).unwrap();
        assert!(quadtree.max_depth() > 1);

        let query = array![2.1, 1.3];
        let radius = 1.2;
        let (mut indices, distances) = quadtree.query_radius(&query.view(), radius).unwrap();
        assert_eq!(indices.len(), distances.len());
        indices.sort_unstable();

        let expected: Vec<usize> = (0..points.nrows())
            .filter(|&i| {
                let dx = points[[i, 0]] - query[0];
                let dy = points[[i, 1]] - query[1];
                dx * dx + dy * dy <= radius * radius
            })
            .collect();
        assert!(!expected.is_empty());
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_region_queries() {
        let points = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [2.0, 2.0],
        ];

        let quadtree = Quadtree::new(&points.view()).unwrap();

        let region =
            BoundingBox2D::new(&array![0.25, 0.25].view(), &array![0.75, 0.75].view()).unwrap();

        assert!(quadtree.points_in_region(&region));

        let indices = quadtree.get_points_in_region(&region);
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 4);

        let large_region =
            BoundingBox2D::new(&array![0.0, 0.0].view(), &array![1.0, 1.0].view()).unwrap();

        let indices = quadtree.get_points_in_region(&large_region);
        assert_eq!(indices.len(), 5);

        let empty_region =
            BoundingBox2D::new(&array![1.5, 1.5].view(), &array![1.9, 1.9].view()).unwrap();

        assert!(!quadtree.points_in_region(&empty_region));
        let indices = quadtree.get_points_in_region(&empty_region);
        assert_eq!(indices.len(), 0);
    }

    #[test]
    fn test_region_query_matches_brute_force() {
        let points = grid_points();
        let quadtree = Quadtree::new(&points.view()).unwrap();

        let region =
            BoundingBox2D::new(&array![1.0, 0.5].view(), &array![3.0, 2.0].view()).unwrap();
        let mut indices = quadtree.get_points_in_region(&region);
        indices.sort_unstable();

        let expected: Vec<usize> = (0..points.nrows())
            .filter(|&i| {
                let (x, y) = (points[[i, 0]], points[[i, 1]]);
                (1.0..=3.0).contains(&x) && (0.5..=2.0).contains(&y)
            })
            .collect();
        assert!(!expected.is_empty());
        assert_eq!(indices, expected);
        assert!(quadtree.points_in_region(&region));
    }

    #[test]
    fn test_get_point() {
        let points = array![[0.0, 0.0], [1.0, 2.0]];
        let quadtree = Quadtree::new(&points.view()).unwrap();

        assert_eq!(quadtree.get_point(1), Some(array![1.0, 2.0]));
        assert_eq!(quadtree.get_point(2), None);
    }
}