1use crate::distance::{Distance, EuclideanDistance};
68use crate::error::{SpatialError, SpatialResult};
69use crate::safe_conversions::*;
70use scirs2_core::ndarray::Array2;
71use scirs2_core::numeric::Float;
72use std::cmp::Ordering;
73
74#[cfg(feature = "parallel")]
76#[allow(unused_imports)]
77#[derive(Clone, Debug)]
81pub struct Rectangle<T: Float> {
82 mins: Vec<T>,
84 maxes: Vec<T>,
86}
87
88impl<T: Float> Rectangle<T> {
89 pub fn new(mins: Vec<T>, maxes: Vec<T>) -> Self {
105 assert_eq!(
106 mins.len(),
107 maxes.len(),
108 "mins and maxes must have the same length"
109 );
110
111 for i in 0..mins.len() {
112 assert!(
113 mins[i] <= maxes[i],
114 "min value must be less than or equal to max value"
115 );
116 }
117
118 Rectangle { mins, maxes }
119 }
120
121 pub fn mins(&self) -> &[T] {
137 &self.mins
138 }
139
140 pub fn maxes(&self) -> &[T] {
156 &self.maxes
157 }
158
159 pub fn split(&self, dim: usize, value: T) -> (Self, Self) {
170 let mut left_maxes = self.maxes.clone();
171 left_maxes[dim] = value;
172
173 let mut right_mins = self.mins.clone();
174 right_mins[dim] = value;
175
176 let left = Rectangle::new(self.mins.clone(), left_maxes);
177 let right = Rectangle::new(right_mins, self.maxes.clone());
178
179 (left, right)
180 }
181
182 pub fn contains(&self, point: &[T]) -> bool {
192 assert_eq!(
193 point.len(),
194 self.mins.len(),
195 "point must have the same dimension as the rectangle"
196 );
197
198 for (i, &p) in point.iter().enumerate() {
199 if p < self.mins[i] || p > self.maxes[i] {
200 return false;
201 }
202 }
203
204 true
205 }
206
207 pub fn min_distance<D: Distance<T>>(&self, point: &[T], metric: &D) -> T {
218 metric.min_distance_point_rectangle(point, &self.mins, &self.maxes)
219 }
220}
221
222#[derive(Debug, Clone)]
224struct KDNode<T: Float> {
225 idx: usize,
227 value: T,
229 axis: usize,
231 left: Option<usize>,
233 right: Option<usize>,
235}
236
237#[derive(Debug, Clone)]
244pub struct KDTree<T: Float + Send + Sync + 'static, D: Distance<T> + 'static> {
245 points: Array2<T>,
247 nodes: Vec<KDNode<T>>,
249 ndim: usize,
251 root: Option<usize>,
253 metric: D,
255 leafsize: usize,
257 bounds: Rectangle<T>,
259}
260
261impl<T: Float + Send + Sync + 'static> KDTree<T, EuclideanDistance<T>> {
262 pub fn new(points: &Array2<T>) -> SpatialResult<Self> {
272 let metric = EuclideanDistance::new();
273 Self::with_metric(points, metric)
274 }
275
276 pub fn with_leaf_size(points: &Array2<T>, leafsize: usize) -> SpatialResult<Self> {
287 let metric = EuclideanDistance::new();
288 Self::with_options(points, metric, leafsize)
289 }
290}
291
292impl<T: Float + Send + Sync + 'static, D: Distance<T> + 'static> KDTree<T, D> {
293 pub fn with_metric(points: &Array2<T>, metric: D) -> SpatialResult<Self> {
304 Self::with_options(points, metric, 16) }
306
307 pub fn with_options(points: &Array2<T>, metric: D, leafsize: usize) -> SpatialResult<Self> {
319 let n = points.nrows();
320 let ndim = points.ncols();
321
322 if n == 0 {
323 return Err(SpatialError::ValueError("Empty point set".to_string()));
324 }
325
326 if leafsize == 0 {
327 return Err(SpatialError::ValueError(
328 "Leaf _size must be greater than 0".to_string(),
329 ));
330 }
331
332 let mut mins = vec![T::max_value(); ndim];
334 let mut maxes = vec![T::min_value(); ndim];
335
336 for i in 0..n {
337 for j in 0..ndim {
338 let val = points[[i, j]];
339 if val < mins[j] {
340 mins[j] = val;
341 }
342 if val > maxes[j] {
343 maxes[j] = val;
344 }
345 }
346 }
347
348 let bounds = Rectangle::new(mins, maxes);
349
350 let mut tree = KDTree {
351 points: points.clone(),
352 nodes: Vec::with_capacity(n),
353 ndim,
354 root: None,
355 metric,
356 leafsize,
357 bounds,
358 };
359
360 let mut indices: Vec<usize> = (0..n).collect();
362
363 if n > 0 {
365 let root = tree.build_tree(&mut indices, 0, 0, n)?;
366 tree.root = Some(root);
367 }
368
369 Ok(tree)
370 }
371
372 fn build_tree(
385 &mut self,
386 indices: &mut [usize],
387 depth: usize,
388 start: usize,
389 end: usize,
390 ) -> SpatialResult<usize> {
391 let n = end - start;
392
393 if n == 0 {
394 return Err(SpatialError::ValueError(
395 "Empty point set in build_tree".to_string(),
396 ));
397 }
398
399 let axis = depth % self.ndim;
401
402 let node_idx;
404 if n == 1 {
405 let idx = indices[start];
406 let value = self.points[[idx, axis]];
407
408 node_idx = self.nodes.len();
409 self.nodes.push(KDNode {
410 idx,
411 value,
412 axis,
413 left: None,
414 right: None,
415 });
416
417 return Ok(node_idx);
418 }
419
420 indices[start..end].sort_by(|&i, &j| {
422 let a = self.points[[i, axis]];
423 let b = self.points[[j, axis]];
424 a.partial_cmp(&b).unwrap_or(Ordering::Equal)
425 });
426
427 let mid = start + n / 2;
429 let idx = indices[mid];
430 let value = self.points[[idx, axis]];
431
432 node_idx = self.nodes.len();
434 self.nodes.push(KDNode {
435 idx,
436 value,
437 axis,
438 left: None,
439 right: None,
440 });
441
442 if mid > start {
444 let left_idx = self.build_tree(indices, depth + 1, start, mid)?;
445 self.nodes[node_idx].left = Some(left_idx);
446 }
447
448 if mid + 1 < end {
449 let right_idx = self.build_tree(indices, depth + 1, mid + 1, end)?;
450 self.nodes[node_idx].right = Some(right_idx);
451 }
452
453 Ok(node_idx)
454 }
455
456 pub fn query(&self, point: &[T], k: usize) -> SpatialResult<(Vec<usize>, Vec<T>)> {
483 if point.len() != self.ndim {
484 return Err(SpatialError::DimensionError(format!(
485 "Query point dimension ({}) does not match tree dimension ({})",
486 point.len(),
487 self.ndim
488 )));
489 }
490
491 if k == 0 {
492 return Ok((vec![], vec![]));
493 }
494
495 if self.points.nrows() == 0 {
496 return Ok((vec![], vec![]));
497 }
498
499 let mut neighbors: Vec<(T, usize)> = Vec::with_capacity(k + 1);
502
503 let mut max_dist = T::infinity();
505
506 if let Some(root) = self.root {
507 self.query_recursive(root, point, k, &mut neighbors, &mut max_dist);
509
510 neighbors.sort_by(|a, b| {
512 match safe_partial_cmp(&a.0, &b.0, "kdtree sort neighbors") {
513 Ok(std::cmp::Ordering::Equal) => a.1.cmp(&b.1), Ok(ord) => ord,
515 Err(_) => std::cmp::Ordering::Equal,
516 }
517 });
518
519 if neighbors.len() > k {
521 neighbors.truncate(k);
522 }
523
524 let mut indices = Vec::with_capacity(neighbors.len());
526 let mut distances = Vec::with_capacity(neighbors.len());
527
528 for (dist, idx) in neighbors {
529 indices.push(idx);
530 distances.push(dist);
531 }
532
533 Ok((indices, distances))
534 } else {
535 Err(SpatialError::ValueError("Empty tree".to_string()))
536 }
537 }
538
539 fn query_recursive(
541 &self,
542 node_idx: usize,
543 point: &[T],
544 k: usize,
545 neighbors: &mut Vec<(T, usize)>,
546 max_dist: &mut T,
547 ) {
548 let node = &self.nodes[node_idx];
549 let idx = node.idx;
550 let axis = node.axis;
551
552 let node_point = self.points.row(idx).to_vec();
554 let _dist = self.metric.distance(&node_point, point);
555
556 if neighbors.len() < k {
558 neighbors.push((_dist, idx));
559
560 if neighbors.len() == k {
562 neighbors.sort_by(|a, b| {
563 match safe_partial_cmp(&b.0, &a.0, "kdtree sort max-heap") {
564 Ok(std::cmp::Ordering::Equal) => b.1.cmp(&a.1), Ok(ord) => ord,
566 Err(_) => std::cmp::Ordering::Equal,
567 }
568 });
569 *max_dist = neighbors[0].0;
570 }
571 } else if &_dist < max_dist {
572 neighbors[0] = (_dist, idx);
574
575 neighbors.sort_by(|a, b| {
577 match safe_partial_cmp(&b.0, &a.0, "kdtree re-sort max-heap") {
578 Ok(std::cmp::Ordering::Equal) => b.1.cmp(&a.1), Ok(ord) => ord,
580 Err(_) => std::cmp::Ordering::Equal,
581 }
582 });
583 *max_dist = neighbors[0].0;
584 }
585
586 let diff = point[axis] - node.value;
588 let (first, second) = if diff < T::zero() {
589 (node.left, node.right)
590 } else {
591 (node.right, node.left)
592 };
593
594 if let Some(first_idx) = first {
596 self.query_recursive(first_idx, point, k, neighbors, max_dist);
597 }
598
599 let axis_dist = if diff < T::zero() {
601 T::zero() } else {
604 diff
606 };
607
608 if let Some(second_idx) = second {
609 if neighbors.len() < k || axis_dist < *max_dist {
611 self.query_recursive(second_idx, point, k, neighbors, max_dist);
612 }
613 }
614 }
615
616 pub fn query_radius(&self, point: &[T], radius: T) -> SpatialResult<(Vec<usize>, Vec<T>)> {
644 if point.len() != self.ndim {
645 return Err(SpatialError::DimensionError(format!(
646 "Query point dimension ({}) does not match tree dimension ({})",
647 point.len(),
648 self.ndim
649 )));
650 }
651
652 if radius < T::zero() {
653 return Err(SpatialError::ValueError(
654 "Radius must be non-negative".to_string(),
655 ));
656 }
657
658 let mut indices = Vec::new();
659 let mut distances = Vec::new();
660
661 if let Some(root) = self.root {
662 let bounds_dist = self.bounds.min_distance(point, &self.metric);
664 if bounds_dist > radius {
665 return Ok((indices, distances));
666 }
667
668 self.query_radius_recursive(root, point, radius, &mut indices, &mut distances);
670
671 if !indices.is_empty() {
673 let mut idx_dist: Vec<(usize, T)> = indices.into_iter().zip(distances).collect();
674 idx_dist.sort_by(|a, b| {
675 safe_partial_cmp(&a.1, &b.1, "kdtree sort radius results")
676 .unwrap_or(std::cmp::Ordering::Equal)
677 });
678
679 indices = idx_dist.iter().map(|(idx_, _)| *idx_).collect();
680 distances = idx_dist.iter().map(|(_, dist)| *dist).collect();
681 }
682 }
683
684 Ok((indices, distances))
685 }
686
687 fn query_radius_recursive(
689 &self,
690 node_idx: usize,
691 point: &[T],
692 radius: T,
693 indices: &mut Vec<usize>,
694 distances: &mut Vec<T>,
695 ) {
696 let node = &self.nodes[node_idx];
697 let idx = node.idx;
698 let axis = node.axis;
699
700 let node_point = self.points.row(idx).to_vec();
702 let dist = self.metric.distance(&node_point, point);
703
704 if dist <= radius {
706 indices.push(idx);
707 distances.push(dist);
708 }
709
710 let diff = point[axis] - node.value;
712
713 let (near, far) = if diff < T::zero() {
715 (node.left, node.right)
716 } else {
717 (node.right, node.left)
718 };
719
720 if let Some(near_idx) = near {
721 self.query_radius_recursive(near_idx, point, radius, indices, distances);
722 }
723
724 if diff.abs() <= radius {
726 if let Some(far_idx) = far {
727 self.query_radius_recursive(far_idx, point, radius, indices, distances);
728 }
729 }
730 }
731
732 pub fn count_neighbors(&self, point: &[T], radius: T) -> SpatialResult<usize> {
762 if point.len() != self.ndim {
763 return Err(SpatialError::DimensionError(format!(
764 "Query point dimension ({}) does not match tree dimension ({})",
765 point.len(),
766 self.ndim
767 )));
768 }
769
770 if radius < T::zero() {
771 return Err(SpatialError::ValueError(
772 "Radius must be non-negative".to_string(),
773 ));
774 }
775
776 let mut count = 0;
777
778 if let Some(root) = self.root {
779 let bounds_dist = self.bounds.min_distance(point, &self.metric);
781 if bounds_dist > radius {
782 return Ok(0);
783 }
784
785 self.count_neighbors_recursive(root, point, radius, &mut count);
787 }
788
789 Ok(count)
790 }
791
792 fn count_neighbors_recursive(
794 &self,
795 node_idx: usize,
796 point: &[T],
797 radius: T,
798 count: &mut usize,
799 ) {
800 let node = &self.nodes[node_idx];
801 let idx = node.idx;
802 let axis = node.axis;
803
804 let node_point = self.points.row(idx).to_vec();
806 let dist = self.metric.distance(&node_point, point);
807
808 if dist <= radius {
810 *count += 1;
811 }
812
813 let diff = point[axis] - node.value;
815
816 let (near, far) = if diff < T::zero() {
818 (node.left, node.right)
819 } else {
820 (node.right, node.left)
821 };
822
823 if let Some(near_idx) = near {
824 self.count_neighbors_recursive(near_idx, point, radius, count);
825 }
826
827 if diff.abs() <= radius {
829 if let Some(far_idx) = far {
830 self.count_neighbors_recursive(far_idx, point, radius, count);
831 }
832 }
833 }
834
835 pub fn shape(&self) -> (usize, usize) {
841 (self.points.nrows(), self.ndim)
842 }
843
844 pub fn npoints(&self) -> usize {
850 self.points.nrows()
851 }
852
    /// Returns the number of coordinates per point.
    pub fn ndim(&self) -> usize {
        self.ndim
    }
861
    /// Returns the leaf size the tree was constructed with.
    pub fn leafsize(&self) -> usize {
        self.leafsize
    }
870
    /// Returns the bounding rectangle of all points in the tree.
    pub fn bounds(&self) -> &Rectangle<T> {
        &self.bounds
    }
879}
880
#[cfg(test)]
mod tests {
    use super::{KDTree, Rectangle};
    use crate::distance::{
        ChebyshevDistance, Distance, EuclideanDistance, ManhattanDistance, MinkowskiDistance,
    };
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Query point shared by every search test.
    const QUERY: [f64; 2] = [3.0, 5.0];

    /// Search radius shared by the radius-based tests.
    const RADIUS: f64 = 3.0;

    /// The six 2-D points every KD-tree test is built from.
    fn sample_points() -> Array2<f64> {
        arr2(&[
            [2.0, 3.0],
            [5.0, 4.0],
            [9.0, 6.0],
            [4.0, 7.0],
            [8.0, 1.0],
            [7.0, 2.0],
        ])
    }

    /// Brute-force `(index, distance)` from every point to `QUERY`, sorted by
    /// ascending distance and then ascending index.
    fn brute_force<D: Distance<f64>>(points: &Array2<f64>, metric: &D) -> Vec<(usize, f64)> {
        let mut pairs: Vec<(usize, f64)> = (0..points.nrows())
            .map(|i| {
                let p = points.row(i).to_vec();
                (i, metric.distance(&p, &QUERY))
            })
            .collect();
        pairs.sort_by(|a, b| {
            match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
                ord => ord,
            }
        });
        pairs
    }

    /// Asserts that a `k = 1` query on `tree` returns the brute-force minimum
    /// distance and the index of one of the points at that distance.
    fn assert_nearest_matches_brute_force<D>(
        tree: &KDTree<f64, D>,
        points: &Array2<f64>,
        metric: &D,
    ) where
        D: Distance<f64> + 'static,
    {
        let (indices, distances) = tree.query(&QUERY, 1).unwrap();
        assert_eq!(indices.len(), 1);
        assert_eq!(distances.len(), 1);

        let expected = brute_force(points, metric);
        let min_dist = expected[0].1;
        assert_relative_eq!(distances[0], min_dist, epsilon = 1e-6);

        // Several points may tie for nearest; any of them is acceptable.
        let valid_indices: Vec<usize> = expected
            .iter()
            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
            .map(|(i, _)| *i)
            .collect();
        assert!(
            valid_indices.contains(&indices[0]),
            "Expected one of {:?} but got {}",
            valid_indices,
            indices[0]
        );
    }

    #[test]
    fn test_rectangle() {
        let mins = vec![0.0, 0.0];
        let maxes = vec![1.0, 1.0];
        let rect = Rectangle::new(mins, maxes);

        // Containment, boundaries included.
        assert!(rect.contains(&[0.5, 0.5]));
        assert!(rect.contains(&[0.0, 0.0]));
        assert!(rect.contains(&[1.0, 1.0]));
        assert!(!rect.contains(&[1.5, 0.5]));
        assert!(!rect.contains(&[0.5, 1.5]));

        // Splitting along the first dimension.
        let (left, right) = rect.split(0, 0.5);
        assert!(left.contains(&[0.25, 0.5]));
        assert!(!left.contains(&[0.75, 0.5]));
        assert!(!right.contains(&[0.25, 0.5]));
        assert!(right.contains(&[0.75, 0.5]));

        // Minimum distance from a point to the rectangle.
        let metric = EuclideanDistance::<f64>::new();
        assert_relative_eq!(rect.min_distance(&[0.5, 0.5], &metric), 0.0, epsilon = 1e-6);
        assert_relative_eq!(rect.min_distance(&[2.0, 0.5], &metric), 1.0, epsilon = 1e-6);
        assert_relative_eq!(
            rect.min_distance(&[2.0, 2.0], &metric),
            std::f64::consts::SQRT_2,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_kdtree_build() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).unwrap();

        // One node per point.
        assert_eq!(kdtree.nodes.len(), points.nrows());

        assert_eq!(kdtree.shape(), (6, 2));
        assert_eq!(kdtree.npoints(), 6);
        assert_eq!(kdtree.ndim(), 2);
        assert_eq!(kdtree.leafsize(), 16);

        assert_eq!(kdtree.bounds().mins(), &[2.0, 1.0]);
        assert_eq!(kdtree.bounds().maxes(), &[9.0, 7.0]);
    }

    #[test]
    fn test_kdtree_query() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).unwrap();

        assert_nearest_matches_brute_force(&kdtree, &points, &EuclideanDistance::<f64>::new());
    }

    #[test]
    fn test_kdtree_query_k() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).unwrap();

        let (indices, distances) = kdtree.query(&QUERY, 3).unwrap();
        assert_eq!(indices.len(), 3);
        assert_eq!(distances.len(), 3);

        let expected = brute_force(&points, &EuclideanDistance::<f64>::new());
        let nearest = &expected[..3];

        for i in &indices {
            assert!(nearest.iter().any(|&(e, _)| e == *i));
        }

        // Results must come back nearest-first.
        assert!(distances[0] <= distances[1]);
        assert!(distances[1] <= distances[2]);

        for (got, want) in distances.iter().zip(nearest.iter()) {
            assert_relative_eq!(*got, want.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_kdtree_query_radius() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).unwrap();

        let (indices, distances) = kdtree.query_radius(&QUERY, RADIUS).unwrap();

        let expected: Vec<(usize, f64)> = brute_force(&points, &EuclideanDistance::<f64>::new())
            .into_iter()
            .filter(|&(_, d)| d <= RADIUS)
            .collect();

        assert_eq!(indices.len(), expected.len());

        for &d in &distances {
            assert!(d <= RADIUS + 1e-6);
        }

        // Normalise tie order before comparing against the brute-force list.
        let mut actual: Vec<(usize, f64)> = indices
            .iter()
            .zip(distances.iter())
            .map(|(&i, &d)| (i, d))
            .collect();
        actual.sort_by(|a, b| {
            match a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) {
                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
                ord => ord,
            }
        });

        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_eq!(got.0, want.0);
            assert_relative_eq!(got.1, want.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_kdtree_count_neighbors() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).unwrap();

        let count = kdtree.count_neighbors(&QUERY, RADIUS).unwrap();

        let expected_count = brute_force(&points, &EuclideanDistance::<f64>::new())
            .iter()
            .filter(|(_, d)| *d <= RADIUS)
            .count();

        assert_eq!(count, expected_count);
    }

    #[test]
    fn test_kdtree_with_manhattan_distance() {
        let points = sample_points();
        let kdtree = KDTree::with_metric(&points, ManhattanDistance::<f64>::new()).unwrap();

        assert_nearest_matches_brute_force(&kdtree, &points, &ManhattanDistance::<f64>::new());
    }

    #[test]
    fn test_kdtree_with_chebyshev_distance() {
        let points = sample_points();
        let kdtree = KDTree::with_metric(&points, ChebyshevDistance::<f64>::new()).unwrap();

        assert_nearest_matches_brute_force(&kdtree, &points, &ChebyshevDistance::<f64>::new());
    }

    #[test]
    fn test_kdtree_with_minkowski_distance() {
        let points = sample_points();
        let kdtree = KDTree::with_metric(&points, MinkowskiDistance::<f64>::new(3.0)).unwrap();

        assert_nearest_matches_brute_force(&kdtree, &points, &MinkowskiDistance::<f64>::new(3.0));
    }

    #[test]
    fn test_kdtree_with_custom_leaf_size() {
        let points = sample_points();
        let kdtree = KDTree::with_leaf_size(&points, 1).unwrap();

        assert_eq!(kdtree.leafsize(), 1);

        assert_nearest_matches_brute_force(&kdtree, &points, &EuclideanDistance::<f64>::new());
    }
}