1use crate::distance::{Distance, EuclideanDistance};
68use crate::error::{SpatialError, SpatialResult};
69use crate::safe_conversions::*;
70use scirs2_core::ndarray::Array2;
71use scirs2_core::numeric::Float;
72use std::cmp::Ordering;
73
/// Axis-aligned hyperrectangle described by a lower and an upper bound for
/// every dimension.
#[derive(Clone, Debug)]
pub struct Rectangle<T: Float> {
    /// Lower bound of each dimension (`mins[i] <= maxes[i]` is enforced by `new`).
    mins: Vec<T>,
    /// Upper bound of each dimension.
    maxes: Vec<T>,
}
85
86impl<T: Float> Rectangle<T> {
87 pub fn new(mins: Vec<T>, maxes: Vec<T>) -> Self {
103 assert_eq!(
104 mins.len(),
105 maxes.len(),
106 "mins and maxes must have the same length"
107 );
108
109 for i in 0..mins.len() {
110 assert!(
111 mins[i] <= maxes[i],
112 "min value must be less than or equal to max value"
113 );
114 }
115
116 Rectangle { mins, maxes }
117 }
118
119 pub fn mins(&self) -> &[T] {
135 &self.mins
136 }
137
138 pub fn maxes(&self) -> &[T] {
154 &self.maxes
155 }
156
157 pub fn split(&self, dim: usize, value: T) -> (Self, Self) {
168 let mut left_maxes = self.maxes.clone();
169 left_maxes[dim] = value;
170
171 let mut right_mins = self.mins.clone();
172 right_mins[dim] = value;
173
174 let left = Rectangle::new(self.mins.clone(), left_maxes);
175 let right = Rectangle::new(right_mins, self.maxes.clone());
176
177 (left, right)
178 }
179
180 pub fn contains(&self, point: &[T]) -> bool {
190 assert_eq!(
191 point.len(),
192 self.mins.len(),
193 "point must have the same dimension as the rectangle"
194 );
195
196 for (i, &p) in point.iter().enumerate() {
197 if p < self.mins[i] || p > self.maxes[i] {
198 return false;
199 }
200 }
201
202 true
203 }
204
205 pub fn min_distance<D: Distance<T>>(&self, point: &[T], metric: &D) -> T {
216 metric.min_distance_point_rectangle(point, &self.mins, &self.maxes)
217 }
218}
219
/// One node of the KD-tree.
///
/// Nodes live in `KDTree::nodes`, and children are referenced by their index
/// in that vector.
#[derive(Debug, Clone)]
struct KDNode<T: Float> {
    /// Row index into `KDTree::points` of the point stored at this node.
    idx: usize,
    /// Coordinate of that point along `axis`; used as the splitting value.
    value: T,
    /// Dimension this node splits on (`depth % ndim` during construction).
    axis: usize,
    /// Index in `KDTree::nodes` of the subtree holding points that sorted
    /// before this one along `axis`, if any.
    left: Option<usize>,
    /// Index in `KDTree::nodes` of the subtree holding points that sorted
    /// after this one along `axis`, if any.
    right: Option<usize>,
}
234
/// KD-tree over the rows of a 2-D point array, parameterized by a distance
/// metric `D`.
#[derive(Debug, Clone)]
pub struct KDTree<T: Float + Send + Sync + 'static, D: Distance<T> + 'static> {
    /// Owned copy of the input points; row `i` is point `i`.
    points: Array2<T>,
    /// Flat node storage, one node per point; children are referenced by index.
    nodes: Vec<KDNode<T>>,
    /// Number of dimensions (columns of `points`).
    ndim: usize,
    /// Index of the root node in `nodes`, or `None` if no tree was built.
    root: Option<usize>,
    /// Metric used for every distance computation.
    metric: D,
    /// Leaf size requested at construction (validated to be > 0).
    /// NOTE(review): the node-building code never reads this value, so every
    /// node holds a single point. Confirm whether bucketed leaves were intended.
    leafsize: usize,
    /// Axis-aligned bounding box of all points; radius and count queries use it
    /// to return early when the query ball cannot touch the data.
    bounds: Rectangle<T>,
}
258
259impl<T: Float + Send + Sync + 'static> KDTree<T, EuclideanDistance<T>> {
260 pub fn new(points: &Array2<T>) -> SpatialResult<Self> {
270 let metric = EuclideanDistance::new();
271 Self::with_metric(points, metric)
272 }
273
274 pub fn with_leaf_size(points: &Array2<T>, leafsize: usize) -> SpatialResult<Self> {
285 let metric = EuclideanDistance::new();
286 Self::with_options(points, metric, leafsize)
287 }
288}
289
290impl<T: Float + Send + Sync + 'static, D: Distance<T> + 'static> KDTree<T, D> {
291 pub fn with_metric(points: &Array2<T>, metric: D) -> SpatialResult<Self> {
302 Self::with_options(points, metric, 16) }
304
305 pub fn with_options(points: &Array2<T>, metric: D, leafsize: usize) -> SpatialResult<Self> {
317 let n = points.nrows();
318 let ndim = points.ncols();
319
320 if n == 0 {
321 return Err(SpatialError::ValueError("Empty point set".to_string()));
322 }
323
324 if leafsize == 0 {
325 return Err(SpatialError::ValueError(
326 "Leaf _size must be greater than 0".to_string(),
327 ));
328 }
329
330 let mut mins = vec![T::max_value(); ndim];
332 let mut maxes = vec![T::min_value(); ndim];
333
334 for i in 0..n {
335 for j in 0..ndim {
336 let val = points[[i, j]];
337 if val < mins[j] {
338 mins[j] = val;
339 }
340 if val > maxes[j] {
341 maxes[j] = val;
342 }
343 }
344 }
345
346 let bounds = Rectangle::new(mins, maxes);
347
348 let mut tree = KDTree {
349 points: points.clone(),
350 nodes: Vec::with_capacity(n),
351 ndim,
352 root: None,
353 metric,
354 leafsize,
355 bounds,
356 };
357
358 let mut indices: Vec<usize> = (0..n).collect();
360
361 if n > 0 {
363 let root = tree.build_tree(&mut indices, 0, 0, n)?;
364 tree.root = Some(root);
365 }
366
367 Ok(tree)
368 }
369
370 fn build_tree(
383 &mut self,
384 indices: &mut [usize],
385 depth: usize,
386 start: usize,
387 end: usize,
388 ) -> SpatialResult<usize> {
389 let n = end - start;
390
391 if n == 0 {
392 return Err(SpatialError::ValueError(
393 "Empty point set in build_tree".to_string(),
394 ));
395 }
396
397 let axis = depth % self.ndim;
399
400 let node_idx;
402 if n == 1 {
403 let idx = indices[start];
404 let value = self.points[[idx, axis]];
405
406 node_idx = self.nodes.len();
407 self.nodes.push(KDNode {
408 idx,
409 value,
410 axis,
411 left: None,
412 right: None,
413 });
414
415 return Ok(node_idx);
416 }
417
418 indices[start..end].sort_by(|&i, &j| {
420 let a = self.points[[i, axis]];
421 let b = self.points[[j, axis]];
422 a.partial_cmp(&b).unwrap_or(Ordering::Equal)
423 });
424
425 let mid = start + n / 2;
427 let idx = indices[mid];
428 let value = self.points[[idx, axis]];
429
430 node_idx = self.nodes.len();
432 self.nodes.push(KDNode {
433 idx,
434 value,
435 axis,
436 left: None,
437 right: None,
438 });
439
440 if mid > start {
442 let left_idx = self.build_tree(indices, depth + 1, start, mid)?;
443 self.nodes[node_idx].left = Some(left_idx);
444 }
445
446 if mid + 1 < end {
447 let right_idx = self.build_tree(indices, depth + 1, mid + 1, end)?;
448 self.nodes[node_idx].right = Some(right_idx);
449 }
450
451 Ok(node_idx)
452 }
453
454 pub fn query(&self, point: &[T], k: usize) -> SpatialResult<(Vec<usize>, Vec<T>)> {
481 if point.len() != self.ndim {
482 return Err(SpatialError::DimensionError(format!(
483 "Query point dimension ({}) does not match tree dimension ({})",
484 point.len(),
485 self.ndim
486 )));
487 }
488
489 if k == 0 {
490 return Ok((vec![], vec![]));
491 }
492
493 if self.points.nrows() == 0 {
494 return Ok((vec![], vec![]));
495 }
496
497 let mut neighbors: Vec<(T, usize)> = Vec::with_capacity(k + 1);
500
501 let mut max_dist = T::infinity();
503
504 if let Some(root) = self.root {
505 self.query_recursive(root, point, k, &mut neighbors, &mut max_dist);
507
508 neighbors.sort_by(|a, b| {
510 match safe_partial_cmp(&a.0, &b.0, "kdtree sort neighbors") {
511 Ok(std::cmp::Ordering::Equal) => a.1.cmp(&b.1), Ok(ord) => ord,
513 Err(_) => std::cmp::Ordering::Equal,
514 }
515 });
516
517 if neighbors.len() > k {
519 neighbors.truncate(k);
520 }
521
522 let mut indices = Vec::with_capacity(neighbors.len());
524 let mut distances = Vec::with_capacity(neighbors.len());
525
526 for (dist, idx) in neighbors {
527 indices.push(idx);
528 distances.push(dist);
529 }
530
531 Ok((indices, distances))
532 } else {
533 Err(SpatialError::ValueError("Empty tree".to_string()))
534 }
535 }
536
537 fn query_recursive(
539 &self,
540 node_idx: usize,
541 point: &[T],
542 k: usize,
543 neighbors: &mut Vec<(T, usize)>,
544 max_dist: &mut T,
545 ) {
546 let node = &self.nodes[node_idx];
547 let idx = node.idx;
548 let axis = node.axis;
549
550 let node_point = self.points.row(idx).to_vec();
552 let _dist = self.metric.distance(&node_point, point);
553
554 if neighbors.len() < k {
556 neighbors.push((_dist, idx));
557
558 if neighbors.len() == k {
560 neighbors.sort_by(|a, b| {
561 match safe_partial_cmp(&b.0, &a.0, "kdtree sort max-heap") {
562 Ok(std::cmp::Ordering::Equal) => b.1.cmp(&a.1), Ok(ord) => ord,
564 Err(_) => std::cmp::Ordering::Equal,
565 }
566 });
567 *max_dist = neighbors[0].0;
568 }
569 } else if &_dist < max_dist {
570 neighbors[0] = (_dist, idx);
572
573 neighbors.sort_by(|a, b| {
575 match safe_partial_cmp(&b.0, &a.0, "kdtree re-sort max-heap") {
576 Ok(std::cmp::Ordering::Equal) => b.1.cmp(&a.1), Ok(ord) => ord,
578 Err(_) => std::cmp::Ordering::Equal,
579 }
580 });
581 *max_dist = neighbors[0].0;
582 }
583
584 let diff = point[axis] - node.value;
586 let (first, second) = if diff < T::zero() {
587 (node.left, node.right)
588 } else {
589 (node.right, node.left)
590 };
591
592 if let Some(first_idx) = first {
594 self.query_recursive(first_idx, point, k, neighbors, max_dist);
595 }
596
597 let axis_dist = if diff < T::zero() {
599 T::zero() } else {
602 diff
604 };
605
606 if let Some(second_idx) = second {
607 if neighbors.len() < k || axis_dist < *max_dist {
609 self.query_recursive(second_idx, point, k, neighbors, max_dist);
610 }
611 }
612 }
613
614 pub fn query_radius(&self, point: &[T], radius: T) -> SpatialResult<(Vec<usize>, Vec<T>)> {
642 if point.len() != self.ndim {
643 return Err(SpatialError::DimensionError(format!(
644 "Query point dimension ({}) does not match tree dimension ({})",
645 point.len(),
646 self.ndim
647 )));
648 }
649
650 if radius < T::zero() {
651 return Err(SpatialError::ValueError(
652 "Radius must be non-negative".to_string(),
653 ));
654 }
655
656 let mut indices = Vec::new();
657 let mut distances = Vec::new();
658
659 if let Some(root) = self.root {
660 let bounds_dist = self.bounds.min_distance(point, &self.metric);
662 if bounds_dist > radius {
663 return Ok((indices, distances));
664 }
665
666 self.query_radius_recursive(root, point, radius, &mut indices, &mut distances);
668
669 if !indices.is_empty() {
671 let mut idx_dist: Vec<(usize, T)> = indices.into_iter().zip(distances).collect();
672 idx_dist.sort_by(|a, b| {
673 safe_partial_cmp(&a.1, &b.1, "kdtree sort radius results")
674 .unwrap_or(std::cmp::Ordering::Equal)
675 });
676
677 indices = idx_dist.iter().map(|(idx_, _)| *idx_).collect();
678 distances = idx_dist.iter().map(|(_, dist)| *dist).collect();
679 }
680 }
681
682 Ok((indices, distances))
683 }
684
685 fn query_radius_recursive(
687 &self,
688 node_idx: usize,
689 point: &[T],
690 radius: T,
691 indices: &mut Vec<usize>,
692 distances: &mut Vec<T>,
693 ) {
694 let node = &self.nodes[node_idx];
695 let idx = node.idx;
696 let axis = node.axis;
697
698 let node_point = self.points.row(idx).to_vec();
700 let dist = self.metric.distance(&node_point, point);
701
702 if dist <= radius {
704 indices.push(idx);
705 distances.push(dist);
706 }
707
708 let diff = point[axis] - node.value;
710
711 let (near, far) = if diff < T::zero() {
713 (node.left, node.right)
714 } else {
715 (node.right, node.left)
716 };
717
718 if let Some(near_idx) = near {
719 self.query_radius_recursive(near_idx, point, radius, indices, distances);
720 }
721
722 if diff.abs() <= radius {
724 if let Some(far_idx) = far {
725 self.query_radius_recursive(far_idx, point, radius, indices, distances);
726 }
727 }
728 }
729
730 pub fn count_neighbors(&self, point: &[T], radius: T) -> SpatialResult<usize> {
760 if point.len() != self.ndim {
761 return Err(SpatialError::DimensionError(format!(
762 "Query point dimension ({}) does not match tree dimension ({})",
763 point.len(),
764 self.ndim
765 )));
766 }
767
768 if radius < T::zero() {
769 return Err(SpatialError::ValueError(
770 "Radius must be non-negative".to_string(),
771 ));
772 }
773
774 let mut count = 0;
775
776 if let Some(root) = self.root {
777 let bounds_dist = self.bounds.min_distance(point, &self.metric);
779 if bounds_dist > radius {
780 return Ok(0);
781 }
782
783 self.count_neighbors_recursive(root, point, radius, &mut count);
785 }
786
787 Ok(count)
788 }
789
790 fn count_neighbors_recursive(
792 &self,
793 node_idx: usize,
794 point: &[T],
795 radius: T,
796 count: &mut usize,
797 ) {
798 let node = &self.nodes[node_idx];
799 let idx = node.idx;
800 let axis = node.axis;
801
802 let node_point = self.points.row(idx).to_vec();
804 let dist = self.metric.distance(&node_point, point);
805
806 if dist <= radius {
808 *count += 1;
809 }
810
811 let diff = point[axis] - node.value;
813
814 let (near, far) = if diff < T::zero() {
816 (node.left, node.right)
817 } else {
818 (node.right, node.left)
819 };
820
821 if let Some(near_idx) = near {
822 self.count_neighbors_recursive(near_idx, point, radius, count);
823 }
824
825 if diff.abs() <= radius {
827 if let Some(far_idx) = far {
828 self.count_neighbors_recursive(far_idx, point, radius, count);
829 }
830 }
831 }
832
833 pub fn shape(&self) -> (usize, usize) {
839 (self.points.nrows(), self.ndim)
840 }
841
842 pub fn npoints(&self) -> usize {
848 self.points.nrows()
849 }
850
851 pub fn ndim(&self) -> usize {
857 self.ndim
858 }
859
860 pub fn leafsize(&self) -> usize {
866 self.leafsize
867 }
868
869 pub fn bounds(&self) -> &Rectangle<T> {
875 &self.bounds
876 }
877}
878
#[cfg(test)]
mod tests {
    use super::{KDTree, Rectangle};
    use crate::distance::{
        ChebyshevDistance, Distance, EuclideanDistance, ManhattanDistance, MinkowskiDistance,
    };
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Query point shared by the tree tests.
    const QUERY: [f64; 2] = [3.0, 5.0];

    /// Six-point 2-D data set that every tree test is built from.
    fn sample_points() -> Array2<f64> {
        arr2(&[
            [2.0, 3.0],
            [5.0, 4.0],
            [9.0, 6.0],
            [4.0, 7.0],
            [8.0, 1.0],
            [7.0, 2.0],
        ])
    }

    /// Orders `(index, distance)` pairs by distance, then by index.
    fn by_distance_then_index(a: &(usize, f64), b: &(usize, f64)) -> std::cmp::Ordering {
        a.1.partial_cmp(&b.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.0.cmp(&b.0))
    }

    /// Brute-force `(index, distance)` pairs from `query` to every row of
    /// `points`, sorted by distance and then by index.
    fn brute_force<M: Distance<f64>>(
        points: &Array2<f64>,
        query: &[f64],
        metric: &M,
    ) -> Vec<(usize, f64)> {
        let mut pairs: Vec<(usize, f64)> = (0..points.nrows())
            .map(|i| (i, metric.distance(&points.row(i).to_vec(), query)))
            .collect();
        pairs.sort_by(by_distance_then_index);
        pairs
    }

    /// Asserts that a 1-NN query for `QUERY` returns a point at the
    /// brute-force minimum distance under `metric`.
    fn assert_nearest_matches_brute_force<M: Distance<f64> + 'static>(
        kdtree: &KDTree<f64, M>,
        points: &Array2<f64>,
        metric: &M,
    ) {
        let (indices, distances) = kdtree.query(&QUERY, 1).expect("Operation failed");
        assert_eq!(indices.len(), 1);
        assert_eq!(distances.len(), 1);

        let expected = brute_force(points, &QUERY, metric);
        let min_dist = expected[0].1;
        assert_relative_eq!(distances[0], min_dist, epsilon = 1e-6);

        // Several points can tie for nearest; any of them is acceptable.
        let valid_indices: Vec<usize> = expected
            .iter()
            .filter(|(_, d)| (d - min_dist).abs() < 1e-6)
            .map(|(i, _)| *i)
            .collect();
        assert!(
            valid_indices.contains(&indices[0]),
            "Expected one of {:?} but got {}",
            valid_indices,
            indices[0]
        );
    }

    #[test]
    fn test_rectangle() {
        let rect = Rectangle::new(vec![0.0, 0.0], vec![1.0, 1.0]);

        assert!(rect.contains(&[0.5, 0.5]));
        assert!(rect.contains(&[0.0, 0.0]));
        assert!(rect.contains(&[1.0, 1.0]));
        assert!(!rect.contains(&[1.5, 0.5]));
        assert!(!rect.contains(&[0.5, 1.5]));

        let (left, right) = rect.split(0, 0.5);
        assert!(left.contains(&[0.25, 0.5]));
        assert!(!left.contains(&[0.75, 0.5]));
        assert!(!right.contains(&[0.25, 0.5]));
        assert!(right.contains(&[0.75, 0.5]));

        let metric = EuclideanDistance::<f64>::new();
        assert_relative_eq!(rect.min_distance(&[0.5, 0.5], &metric), 0.0, epsilon = 1e-6);
        assert_relative_eq!(rect.min_distance(&[2.0, 0.5], &metric), 1.0, epsilon = 1e-6);
        assert_relative_eq!(
            rect.min_distance(&[2.0, 2.0], &metric),
            std::f64::consts::SQRT_2,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_kdtree_build() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).expect("Operation failed");

        // One node per point.
        assert_eq!(kdtree.nodes.len(), points.nrows());

        assert_eq!(kdtree.shape(), (6, 2));
        assert_eq!(kdtree.npoints(), 6);
        assert_eq!(kdtree.ndim(), 2);
        assert_eq!(kdtree.leafsize(), 16);

        assert_eq!(kdtree.bounds().mins(), &[2.0, 1.0]);
        assert_eq!(kdtree.bounds().maxes(), &[9.0, 7.0]);
    }

    #[test]
    fn test_kdtree_query() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).expect("Operation failed");
        assert_nearest_matches_brute_force(&kdtree, &points, &EuclideanDistance::<f64>::new());
    }

    #[test]
    fn test_kdtree_query_k() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).expect("Operation failed");

        let (indices, distances) = kdtree.query(&QUERY, 3).expect("Operation failed");
        assert_eq!(indices.len(), 3);
        assert_eq!(distances.len(), 3);

        let expected = brute_force(&points, &QUERY, &EuclideanDistance::<f64>::new());
        let expected_top3 = &expected[..3];

        for i in &indices {
            assert!(expected_top3.iter().any(|&(j, _)| j == *i));
        }

        assert!(distances.windows(2).all(|w| w[0] <= w[1]));

        for (actual, &(_, want)) in distances.iter().zip(expected_top3) {
            assert_relative_eq!(*actual, want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_kdtree_query_edge_cases() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).expect("Operation failed");

        // k == 0 yields nothing.
        let (indices, distances) = kdtree.query(&QUERY, 0).expect("Operation failed");
        assert!(indices.is_empty());
        assert!(distances.is_empty());

        // k larger than the data set yields every point, nearest first.
        let (indices, distances) = kdtree.query(&QUERY, 10).expect("Operation failed");
        let mut all = indices.clone();
        all.sort_unstable();
        assert_eq!(all, (0..points.nrows()).collect::<Vec<_>>());
        assert!(distances.windows(2).all(|w| w[0] <= w[1]));

        // Wrong dimensionality and negative radii are rejected.
        assert!(kdtree.query(&[1.0], 1).is_err());
        assert!(kdtree.query_radius(&QUERY, -1.0).is_err());
        assert!(kdtree.count_neighbors(&QUERY, -1.0).is_err());
    }

    #[test]
    fn test_kdtree_invalid_construction() {
        let empty = Array2::<f64>::zeros((0, 2));
        assert!(KDTree::new(&empty).is_err());
        assert!(KDTree::with_leaf_size(&sample_points(), 0).is_err());
    }

    #[test]
    fn test_kdtree_query_radius() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).expect("Operation failed");
        let radius = 3.0;

        let (indices, distances) = kdtree
            .query_radius(&QUERY, radius)
            .expect("Operation failed");

        let expected: Vec<(usize, f64)> =
            brute_force(&points, &QUERY, &EuclideanDistance::<f64>::new())
                .into_iter()
                .filter(|&(_, d)| d <= radius)
                .collect();

        assert_eq!(indices.len(), expected.len());
        assert!(distances.iter().all(|&d| d <= radius + 1e-6));

        let mut actual: Vec<(usize, f64)> = indices
            .iter()
            .copied()
            .zip(distances.iter().copied())
            .collect();
        actual.sort_by(by_distance_then_index);

        for (got, want) in actual.iter().zip(&expected) {
            assert_eq!(got.0, want.0);
            assert_relative_eq!(got.1, want.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_kdtree_count_neighbors() {
        let points = sample_points();
        let kdtree = KDTree::new(&points).expect("Operation failed");

        let count = kdtree
            .count_neighbors(&QUERY, 3.0)
            .expect("Operation failed");

        let expected_count = brute_force(&points, &QUERY, &EuclideanDistance::<f64>::new())
            .iter()
            .filter(|&&(_, d)| d <= 3.0)
            .count();

        assert_eq!(count, expected_count);
    }

    #[test]
    fn test_kdtree_with_manhattan_distance() {
        let points = sample_points();
        let kdtree = KDTree::with_metric(&points, ManhattanDistance::<f64>::new())
            .expect("Operation failed");
        assert_nearest_matches_brute_force(&kdtree, &points, &ManhattanDistance::<f64>::new());
    }

    #[test]
    fn test_kdtree_with_chebyshev_distance() {
        let points = sample_points();
        let kdtree = KDTree::with_metric(&points, ChebyshevDistance::<f64>::new())
            .expect("Operation failed");
        assert_nearest_matches_brute_force(&kdtree, &points, &ChebyshevDistance::<f64>::new());
    }

    #[test]
    fn test_kdtree_with_minkowski_distance() {
        let points = sample_points();
        let kdtree = KDTree::with_metric(&points, MinkowskiDistance::<f64>::new(3.0))
            .expect("Operation failed");
        assert_nearest_matches_brute_force(
            &kdtree,
            &points,
            &MinkowskiDistance::<f64>::new(3.0),
        );
    }

    #[test]
    fn test_kdtree_with_custom_leaf_size() {
        let points = sample_points();
        let kdtree = KDTree::with_leaf_size(&points, 1).expect("Operation failed");

        assert_eq!(kdtree.leafsize(), 1);
        assert_nearest_matches_brute_force(&kdtree, &points, &EuclideanDistance::<f64>::new());
    }
}