1use crate::error::{SpatialError, SpatialResult};
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
17use std::cmp::Ordering;
18use std::collections::{BinaryHeap, VecDeque};
19
/// Deepest level the tree may reach. A node at this depth becomes a leaf no matter
/// how many points it holds, which keeps recursion finite when many points coincide.
const MAX_DEPTH: usize = 20;
/// Largest number of points a node may hold before it is split into eight octants
/// (unless `MAX_DEPTH` has already been reached).
const MAX_POINTS_PER_NODE: usize = 8;
24
25#[derive(Debug, Clone)]
27pub struct BoundingBox {
28 pub min: Array1<f64>,
30 pub max: Array1<f64>,
32}
33
34impl BoundingBox {
35 pub fn new(min: &ArrayView1<f64>, max: &ArrayView1<f64>) -> SpatialResult<Self> {
51 if min.len() != 3 || max.len() != 3 {
52 return Err(SpatialError::DimensionError(format!(
53 "Min and max must have 3 elements, got {} and {}",
54 min.len(),
55 max.len()
56 )));
57 }
58
59 for i in 0..3 {
61 if min[i] > max[i] {
62 return Err(SpatialError::ValueError(format!(
63 "Min must be <= max for all dimensions, got min[{}]={} > max[{}]={}",
64 i, min[i], i, max[i]
65 )));
66 }
67 }
68
69 Ok(BoundingBox {
70 min: min.to_owned(),
71 max: max.to_owned(),
72 })
73 }
74
75 pub fn from_points(points: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
89 if points.is_empty() {
90 return Err(SpatialError::ValueError(
91 "Cannot create bounding box from empty point set".into(),
92 ));
93 }
94
95 if points.ncols() != 3 {
96 return Err(SpatialError::DimensionError(format!(
97 "Points must have 3 columns, got {}",
98 points.ncols()
99 )));
100 }
101
102 let mut min = Array1::from_vec(vec![f64::INFINITY, f64::INFINITY, f64::INFINITY]);
104 let mut max = Array1::from_vec(vec![
105 f64::NEG_INFINITY,
106 f64::NEG_INFINITY,
107 f64::NEG_INFINITY,
108 ]);
109
110 for row in points.rows() {
111 for d in 0..3 {
112 if row[d] < min[d] {
113 min[d] = row[d];
114 }
115 if row[d] > max[d] {
116 max[d] = row[d];
117 }
118 }
119 }
120
121 Ok(BoundingBox { min, max })
122 }
123
124 pub fn contains(&self, point: &ArrayView1<f64>) -> SpatialResult<bool> {
138 if point.len() != 3 {
139 return Err(SpatialError::DimensionError(format!(
140 "Point must have 3 elements, got {}",
141 point.len()
142 )));
143 }
144
145 for d in 0..3 {
146 if point[d] < self.min[d] || point[d] > self.max[d] {
147 return Ok(false);
148 }
149 }
150
151 Ok(true)
152 }
153
154 pub fn center(&self) -> Array1<f64> {
160 let mut center = Array1::zeros(3);
161 for d in 0..3 {
162 center[d] = (self.min[d] + self.max[d]) / 2.0;
163 }
164 center
165 }
166
167 pub fn dimensions(&self) -> Array1<f64> {
173 let mut dims = Array1::zeros(3);
174 for d in 0..3 {
175 dims[d] = self.max[d] - self.min[d];
176 }
177 dims
178 }
179
180 pub fn overlaps(&self, other: &BoundingBox) -> bool {
190 for d in 0..3 {
191 if self.max[d] < other.min[d] || self.min[d] > other.max[d] {
192 return false;
193 }
194 }
195 true
196 }
197
198 pub fn squared_distance_to_point(&self, point: &ArrayView1<f64>) -> SpatialResult<f64> {
212 if point.len() != 3 {
213 return Err(SpatialError::DimensionError(format!(
214 "Point must have 3 elements, got {}",
215 point.len()
216 )));
217 }
218
219 let mut squared_dist = 0.0;
220
221 for d in 0..3 {
222 let v = point[d];
223
224 if v < self.min[d] {
225 squared_dist += (v - self.min[d]) * (v - self.min[d]);
227 } else if v > self.max[d] {
228 squared_dist += (v - self.max[d]) * (v - self.max[d]);
230 }
231 }
233
234 Ok(squared_dist)
235 }
236
237 pub fn split_into_octants(&self) -> [BoundingBox; 8] {
243 let center = self.center();
244
245 [
256 BoundingBox {
258 min: self.min.clone(),
259 max: center.clone(),
260 },
261 BoundingBox {
263 min: Array1::from_vec(vec![center[0], self.min[1], self.min[2]]),
264 max: Array1::from_vec(vec![self.max[0], center[1], center[2]]),
265 },
266 BoundingBox {
268 min: Array1::from_vec(vec![self.min[0], center[1], self.min[2]]),
269 max: Array1::from_vec(vec![center[0], self.max[1], center[2]]),
270 },
271 BoundingBox {
273 min: Array1::from_vec(vec![center[0], center[1], self.min[2]]),
274 max: Array1::from_vec(vec![self.max[0], self.max[1], center[2]]),
275 },
276 BoundingBox {
278 min: Array1::from_vec(vec![self.min[0], self.min[1], center[2]]),
279 max: Array1::from_vec(vec![center[0], center[1], self.max[2]]),
280 },
281 BoundingBox {
283 min: Array1::from_vec(vec![center[0], self.min[1], center[2]]),
284 max: Array1::from_vec(vec![self.max[0], center[1], self.max[2]]),
285 },
286 BoundingBox {
288 min: Array1::from_vec(vec![self.min[0], center[1], center[2]]),
289 max: Array1::from_vec(vec![center[0], self.max[1], self.max[2]]),
290 },
291 BoundingBox {
293 min: center,
294 max: self.max.clone(),
295 },
296 ]
297 }
298}
299
300#[derive(Debug)]
302enum OctreeNode {
303 Internal {
305 bounds: BoundingBox,
307 children: Box<[Option<OctreeNode>; 8]>,
309 },
310 Leaf {
312 bounds: BoundingBox,
314 points: Vec<usize>,
316 point_data: Array2<f64>,
318 },
319}
320
/// A candidate result of a nearest-neighbour search: a point index and its squared
/// distance to the query.
///
/// Ordering compares `distance_sq` ascending (via `f64::total_cmp`, so every value,
/// NaN included, has a consistent position) and breaks ties by `index`. In a
/// `BinaryHeap` the *farthest* candidate therefore sits on top, which is what a
/// bounded k-nearest set needs in order to evict its worst member.
#[derive(Debug, Clone, PartialEq)]
struct DistancePoint {
    /// Row index of the point in the indexed point set.
    index: usize,
    /// Squared Euclidean distance from the query point.
    distance_sq: f64,
}

impl Ord for DistancePoint {
    fn cmp(&self, other: &Self) -> Ordering {
        // Larger distance must compare greater so that `BinaryHeap::pop`/`peek` yield
        // the current worst candidate; a reversed comparison makes eviction discard
        // the nearest point instead.
        self.distance_sq
            .total_cmp(&other.distance_sq)
            .then_with(|| self.index.cmp(&other.index))
    }
}

impl PartialOrd for DistancePoint {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Eq for DistancePoint {}
348
349#[derive(Debug, Clone, PartialEq)]
351struct DistanceNode {
352 node: *const OctreeNode,
354 min_distance_sq: f64,
356}
357
358impl Ord for DistanceNode {
361 fn cmp(&self, other: &Self) -> Ordering {
362 other
363 .min_distance_sq
364 .partial_cmp(&self.min_distance_sq)
365 .unwrap_or(Ordering::Equal)
366 }
367}
368
369impl PartialOrd for DistanceNode {
370 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
371 Some(self.cmp(other))
372 }
373}
374
375impl Eq for DistanceNode {}
376
377#[derive(Debug)]
379pub struct Octree {
380 root: Option<OctreeNode>,
382 size: usize,
384 _points: Array2<f64>,
386}
387
388impl Octree {
389 pub fn new(points: &ArrayView2<'_, f64>) -> SpatialResult<Self> {
403 if points.is_empty() {
404 return Err(SpatialError::ValueError(
405 "Cannot create octree from empty point set".into(),
406 ));
407 }
408
409 if points.ncols() != 3 {
410 return Err(SpatialError::DimensionError(format!(
411 "Points must have 3 columns, got {}",
412 points.ncols()
413 )));
414 }
415
416 let size = points.nrows();
417 let bounds = BoundingBox::from_points(points)?;
418 let points_owned = points.to_owned();
419
420 let indices: Vec<usize> = (0..size).collect();
422
423 let root = Some(Self::build_tree(indices, bounds, &points_owned, 0)?);
425
426 Ok(Octree {
427 root,
428 size,
429 _points: points_owned,
430 })
431 }
432
433 fn build_tree(
435 indices: Vec<usize>,
436 bounds: BoundingBox,
437 points: &Array2<f64>,
438 depth: usize,
439 ) -> SpatialResult<OctreeNode> {
440 if depth >= MAX_DEPTH || indices.len() <= MAX_POINTS_PER_NODE {
442 return Ok(OctreeNode::Leaf {
443 bounds,
444 points: indices,
445 point_data: points.to_owned(),
446 });
447 }
448
449 let octants = bounds.split_into_octants();
451
452 let mut octant_points: [Vec<usize>; 8] = Default::default();
454
455 for &idx in &indices {
457 let point = points.row(idx);
458 let center = bounds.center();
459
460 let mut octant_idx = 0;
462 if point[0] >= center[0] {
463 octant_idx |= 1;
464 } if point[1] >= center[1] {
466 octant_idx |= 2;
467 } if point[2] >= center[2] {
469 octant_idx |= 4;
470 } octant_points[octant_idx].push(idx);
473 }
474
475 let mut children: [Option<OctreeNode>; 8] = Default::default();
477
478 for i in 0..8 {
479 if !octant_points[i].is_empty() {
480 children[i] = Some(Self::build_tree(
481 octant_points[i].clone(),
482 octants[i].clone(),
483 points,
484 depth + 1,
485 )?);
486 }
487 }
488
489 Ok(OctreeNode::Internal {
490 bounds,
491 children: Box::new(children),
492 })
493 }
494
495 pub fn query_nearest(
512 &self,
513 query: &ArrayView1<f64>,
514 k: usize,
515 ) -> SpatialResult<(Vec<usize>, Vec<f64>)> {
516 if query.len() != 3 {
517 return Err(SpatialError::DimensionError(format!(
518 "Query point must have 3 dimensions, got {}",
519 query.len()
520 )));
521 }
522
523 if k == 0 {
524 return Err(SpatialError::ValueError("k must be > 0".into()));
525 }
526
527 if self.root.is_none() {
528 return Ok((Vec::new(), Vec::new()));
529 }
530
531 let mut node_queue = BinaryHeap::new();
533
534 let mut result_queue = BinaryHeap::new();
536 let mut worst_dist = f64::INFINITY;
537
538 let root_ref = self.root.as_ref().unwrap() as *const OctreeNode;
540 let root_dist = match self.root.as_ref().unwrap() {
541 OctreeNode::Internal { bounds, .. } => bounds.squared_distance_to_point(query)?,
542 OctreeNode::Leaf { bounds, .. } => bounds.squared_distance_to_point(query)?,
543 };
544
545 node_queue.push(DistanceNode {
546 node: root_ref,
547 min_distance_sq: root_dist,
548 });
549
550 while let Some(dist_node) = node_queue.pop() {
552 if dist_node.min_distance_sq > worst_dist && result_queue.len() >= k {
554 continue;
555 }
556
557 let node = unsafe { &*dist_node.node };
560
561 match node {
562 OctreeNode::Leaf {
563 points, point_data, ..
564 } => {
565 for &idx in points {
567 let point = point_data.row(idx);
568 let dist_sq = squared_distance(query, &point);
569
570 if result_queue.len() < k || dist_sq < worst_dist {
572 result_queue.push(DistancePoint {
573 index: idx,
574 distance_sq: dist_sq,
575 });
576
577 if result_queue.len() > k {
579 result_queue.pop();
580 if let Some(worst) = result_queue.peek() {
582 worst_dist = worst.distance_sq;
583 }
584 }
585 }
586 }
587 }
588 OctreeNode::Internal { children, .. } => {
589 for child in children.iter().flatten() {
591 let child_ref = child as *const OctreeNode;
592
593 let min_dist = match child {
594 OctreeNode::Internal { bounds, .. } => {
595 bounds.squared_distance_to_point(query)?
596 }
597 OctreeNode::Leaf { bounds, .. } => {
598 bounds.squared_distance_to_point(query)?
599 }
600 };
601
602 node_queue.push(DistanceNode {
603 node: child_ref,
604 min_distance_sq: min_dist,
605 });
606 }
607 }
608 }
609 }
610
611 let mut result_indices = Vec::with_capacity(result_queue.len());
613 let mut result_distances = Vec::with_capacity(result_queue.len());
614
615 let mut temp_results = Vec::new();
617 while let Some(result) = result_queue.pop() {
618 temp_results.push(result);
619 }
620
621 for result in temp_results.iter().rev() {
623 result_indices.push(result.index);
624 result_distances.push(result.distance_sq);
625 }
626
627 Ok((result_indices, result_distances))
628 }
629
630 pub fn query_radius(
647 &self,
648 query: &ArrayView1<f64>,
649 radius: f64,
650 ) -> SpatialResult<(Vec<usize>, Vec<f64>)> {
651 if query.len() != 3 {
652 return Err(SpatialError::DimensionError(format!(
653 "Query point must have 3 dimensions, got {}",
654 query.len()
655 )));
656 }
657
658 if radius < 0.0 {
659 return Err(SpatialError::ValueError(
660 "Radius must be non-negative".into(),
661 ));
662 }
663
664 let radius_sq = radius * radius;
665
666 if self.root.is_none() {
667 return Ok((Vec::new(), Vec::new()));
668 }
669
670 let mut result_indices = Vec::new();
671 let mut result_distances = Vec::new();
672
673 let mut node_queue = VecDeque::new();
675 node_queue.push_back(self.root.as_ref().unwrap());
676
677 while let Some(node) = node_queue.pop_front() {
678 match node {
679 OctreeNode::Leaf {
680 points,
681 point_data,
682 bounds,
683 ..
684 } => {
685 if bounds.squared_distance_to_point(query)? > radius_sq {
687 continue;
688 }
689
690 for &idx in points {
692 let point = point_data.row(idx);
693 let dist_sq = squared_distance(query, &point);
694
695 if dist_sq <= radius_sq {
696 result_indices.push(idx);
697 result_distances.push(dist_sq);
698 }
699 }
700 }
701 OctreeNode::Internal {
702 children, bounds, ..
703 } => {
704 if bounds.squared_distance_to_point(query)? > radius_sq {
706 continue;
707 }
708
709 for child in children.iter().flatten() {
711 node_queue.push_back(child);
712 }
713 }
714 }
715 }
716
717 Ok((result_indices, result_distances))
718 }
719
720 pub fn check_collision(
735 &self,
736 other_points: &ArrayView2<'_, f64>,
737 collision_threshold: f64,
738 ) -> SpatialResult<bool> {
739 if other_points.ncols() != 3 {
740 return Err(SpatialError::DimensionError(format!(
741 "Points must have 3 columns, got {}",
742 other_points.ncols()
743 )));
744 }
745
746 if collision_threshold < 0.0 {
747 return Err(SpatialError::ValueError(
748 "Collision _threshold must be non-negative".into(),
749 ));
750 }
751
752 let threshold_sq = collision_threshold * collision_threshold;
753
754 for row in other_points.rows() {
756 let (_, distances) = self.query_nearest(&row, 1)?;
757 if !distances.is_empty() && distances[0] <= threshold_sq {
758 return Ok(true);
759 }
760 }
761
762 Ok(false)
763 }
764
765 pub fn size(&self) -> usize {
771 self.size
772 }
773
774 pub fn bounds(&self) -> Option<BoundingBox> {
780 match &self.root {
781 Some(OctreeNode::Internal { bounds, .. }) => Some(bounds.clone()),
782 Some(OctreeNode::Leaf { bounds, .. }) => Some(bounds.clone()),
783 None => None,
784 }
785 }
786
787 pub fn max_depth(&self) -> usize {
793 Octree::compute_max_depth(self.root.as_ref())
794 }
795
796 #[allow(clippy::only_used_in_recursion)]
798 fn compute_max_depth(node: Option<&OctreeNode>) -> usize {
799 match node {
800 None => 0,
801 Some(OctreeNode::Leaf { .. }) => 1,
802 Some(OctreeNode::Internal { children, .. }) => {
803 let mut max_child_depth = 0;
804 for child in children.iter().flatten() {
805 let child_depth = Self::compute_max_depth(Some(child));
806 max_child_depth = max_child_depth.max(child_depth);
807 }
808 1 + max_child_depth
809 }
810 }
811 }
812}
813
814#[allow(dead_code)]
825fn squared_distance(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
826 let mut sum_sq = 0.0;
827 for i in 0..p1.len().min(p2.len()) {
828 let diff = p1[i] - p2[i];
829 sum_sq += diff * diff;
830 }
831 sum_sq
832}
833
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use scirs2_core::random::Rng;

    /// Construction of bounding boxes, both explicit and from a point set, plus
    /// rejection of malformed corners.
    #[test]
    fn test_bounding_box_creation() {
        // Explicit corners are stored unchanged.
        let min = array![0.0, 0.0, 0.0];
        let max = array![1.0, 1.0, 1.0];
        let bbox = BoundingBox::new(&min.view(), &max.view()).unwrap();

        assert_eq!(bbox.min, min);
        assert_eq!(bbox.max, max);

        // from_points yields the per-axis min/max of the rows: the unit cube here.
        let points = array![[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.5, 0.5, 0.5],];
        let bbox = BoundingBox::from_points(&points.view()).unwrap();

        assert_eq!(bbox.min, min);
        assert_eq!(bbox.max, max);

        // A 2-element corner is rejected.
        let bad_min = array![0.0, 0.0];
        let result = BoundingBox::new(&bad_min.view(), &max.view());
        assert!(result.is_err());

        // min[0] = 2.0 exceeds max[0] = 1.0 and is rejected.
        let bad_minmax = array![2.0, 0.0, 0.0];
        let result = BoundingBox::new(&bad_minmax.view(), &max.view());
        assert!(result.is_err());
    }

    /// Geometric queries on a single box spanning [0,2] x [0,4] x [0,6].
    #[test]
    fn test_bounding_box_operations() {
        let min = array![0.0, 0.0, 0.0];
        let max = array![2.0, 4.0, 6.0];
        let bbox = BoundingBox::new(&min.view(), &max.view()).unwrap();

        let center = bbox.center();
        assert_eq!(center, array![1.0, 2.0, 3.0]);

        let dims = bbox.dimensions();
        assert_eq!(dims, array![2.0, 4.0, 6.0]);

        let inside_point = array![1.0, 1.0, 1.0];
        assert!(bbox.contains(&inside_point.view()).unwrap());

        let outside_point = array![3.0, 3.0, 3.0];
        assert!(!bbox.contains(&outside_point.view()).unwrap());

        // Boundary points count as contained.
        let edge_point = array![0.0, 4.0, 6.0];
        assert!(bbox.contains(&edge_point.view()).unwrap());

        let overlapping_box =
            BoundingBox::new(&array![1.0, 1.0, 1.0].view(), &array![3.0, 3.0, 3.0].view()).unwrap();
        assert!(bbox.overlaps(&overlapping_box));

        let non_overlapping_box =
            BoundingBox::new(&array![3.0, 5.0, 7.0].view(), &array![4.0, 6.0, 8.0].view()).unwrap();
        assert!(!bbox.overlaps(&non_overlapping_box));

        // A point inside the box is at squared distance 0.
        let inside_dist = bbox
            .squared_distance_to_point(&inside_point.view())
            .unwrap();
        assert_eq!(inside_dist, 0.0);

        // (3, 5, 7) is 1 unit beyond the box on each axis: squared distance 3.
        let outside_dist = bbox
            .squared_distance_to_point(&array![3.0, 5.0, 7.0].view())
            .unwrap();
        assert_eq!(outside_dist, 1.0 + 1.0 + 1.0);
    }

    /// Basic octree construction: size, root bounds and a non-zero depth.
    #[test]
    fn test_octree_creation() {
        let points = array![
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
        ];

        let octree = Octree::new(&points.view()).unwrap();

        assert_eq!(octree.size(), 5);

        let bounds = octree.bounds().unwrap();
        assert_eq!(bounds.min, array![0.0, 0.0, 0.0]);
        assert_eq!(bounds.max, array![1.0, 1.0, 1.0]);

        assert!(octree.max_depth() > 0);
    }

    /// k-nearest-neighbour queries around (0.1, 0.1, 0.1).
    #[test]
    fn test_nearest_neighbor_search() {
        let points = array![
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
            [2.0, 2.0, 2.0],
        ];

        let octree = Octree::new(&points.view()).unwrap();

        let query = array![0.1, 0.1, 0.1];
        let (indices, distances) = octree.query_nearest(&query.view(), 1).unwrap();

        // NOTE(review): this does not check WHICH point is returned. The nearest point
        // is index 0 at squared distance 0.03; consider asserting
        // `indices[0] == 0` and `distances[0] ≈ 0.03`.
        assert_eq!(indices.len(), 1);
        assert!(distances[0].is_finite() && distances[0] > 0.0);

        let (indices, distances) = octree.query_nearest(&query.view(), 3).unwrap();

        // NOTE(review): only non-emptiness and ascending order are checked, not the
        // expected members (indices 0 plus two of 1/2/3).
        assert!(!indices.is_empty());

        for i in 1..distances.len() {
            assert!(distances[i] >= distances[i - 1]);
        }

        // Asking for more neighbours than points returns every point.
        let (indices, distances) = octree.query_nearest(&query.view(), 10).unwrap();

        assert_eq!(indices.len(), 6);
        assert_eq!(distances.len(), 6);
    }

    /// Radius queries around the origin; returned distances are squared.
    #[test]
    fn test_radius_search() {
        let points = array![
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
            [2.0, 2.0, 2.0],
        ];

        let octree = Octree::new(&points.view()).unwrap();

        let query = array![0.0, 0.0, 0.0];
        let radius = 0.5;
        let (indices, distances) = octree.query_radius(&query.view(), radius).unwrap();

        // Only the origin itself lies within 0.5.
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 0);

        // Squared distances 0, 1, 1, 1 fall within 1.5^2 = 2.25; (1,1,1) at 3 does not.
        let radius = 1.5;
        let (indices, distances) = octree.query_radius(&query.view(), radius).unwrap();

        assert!(indices.len() >= 4);
        for &dist in &distances {
            assert!(dist <= radius * radius);
        }

        // 4.0^2 = 16 covers even (2,2,2) at squared distance 12.
        let radius = 4.0;
        let (indices, distances) = octree.query_radius(&query.view(), radius).unwrap();

        assert_eq!(indices.len(), 6);
    }

    /// Collision checks between the tree and external point sets.
    #[test]
    fn test_collision_detection() {
        let points = array![
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
        ];

        let octree = Octree::new(&points.view()).unwrap();

        // Closest pair is (2,2,2) vs (1,1,1), sqrt(3) ≈ 1.73 apart.
        let other_points = array![[2.0, 2.0, 2.0], [3.0, 3.0, 3.0],];

        let collision = octree.check_collision(&other_points.view(), 0.5).unwrap();
        assert!(!collision);

        // NOTE(review): not asserted; with threshold 1.5 < sqrt(3) this is expected
        // to be false.
        let _collision = octree.check_collision(&other_points.view(), 1.5).unwrap();

        // (1.1,1.1,1.1) is sqrt(0.03) ≈ 0.17 from (1,1,1).
        let colliding_points = array![
            [1.1, 1.1, 1.1],
        ];

        // NOTE(review): not asserted; with threshold 0.2 > 0.17 this is expected to
        // be true.
        let _collision = octree
            .check_collision(&colliding_points.view(), 0.2)
            .unwrap();
    }

    /// Smoke test on 10 000 random points; only runs in release builds
    /// (`debug_assertions` off) and prints timings.
    #[test]
    fn test_performance_with_larger_dataset() {
        if !cfg!(debug_assertions) {
            let n_points = 10000;
            let mut rng = scirs2_core::random::rng();

            let mut points = Array2::zeros((n_points, 3));
            for i in 0..n_points {
                for j in 0..3 {
                    points[[i, j]] = rng.gen_range(-100.0..100.0);
                }
            }

            let start = std::time::Instant::now();
            let octree = Octree::new(&points.view()).unwrap();
            let build_time = start.elapsed();

            println!("Built octree with {n_points} points in {build_time:?}");

            let query = array![0.0, 0.0, 0.0];
            let start = std::time::Instant::now();
            let (indices, _distances) = octree.query_nearest(&query.view(), 10).unwrap();
            let query_time = start.elapsed();

            println!("Found 10 nearest neighbors in {query_time:?}");
            assert_eq!(indices.len(), 10);

            let start = std::time::Instant::now();
            let (indices, _distances) = octree.query_radius(&query.view(), 10.0).unwrap();
            let radius_time = start.elapsed();

            println!(
                "Found {} points within radius 10.0 in {:?}",
                indices.len(),
                radius_time
            );
        }
    }
}