1use ordered_float::OrderedFloat;
15use scirs2_core::ndarray::{Array2, ArrayBase, ArrayView1, Data, Ix2};
16use scirs2_core::numeric::{Float, FromPrimitive};
17use std::cmp::Ordering;
18use std::fmt::Debug;
19use std::marker::PhantomData;
20
21use crate::error::{InterpolateError, InterpolateResult};
22
/// One node of the KD-tree node arena.
///
/// Children are referenced by their index in the owning tree's node vector rather than by
/// pointer. Trait bounds are left to the `impl` blocks that need them.
#[derive(Debug, Clone)]
struct KdNode<F> {
    /// Row index (into the tree's point array) of the point stored at this node.
    idx: usize,

    /// Axis of the splitting plane through this node.
    dim: usize,

    /// Position of the splitting plane along axis `dim`. It only affects searches at
    /// nodes that have children.
    value: F,

    /// Index of the left child node, if any.
    left: Option<usize>,

    /// Index of the right child node, if any.
    right: Option<usize>,
}
41
42#[derive(Debug, Clone)]
73pub struct KdTree<F>
74where
75 F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
76{
77 points: Array2<F>,
79
80 nodes: Vec<KdNode<F>>,
82
83 root: Option<usize>,
85
86 dim: usize,
88
89 leaf_size: usize,
91
92 _phantom: PhantomData<F>,
94}
95
96impl<F> KdTree<F>
97where
98 F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
99{
100 pub fn new<S>(points: ArrayBase<S, Ix2>) -> InterpolateResult<Self>
110 where
111 S: Data<Elem = F>,
112 {
113 Self::with_leaf_size(points, 10)
114 }
115
116 pub fn with_leaf_size<S>(
127 _points: ArrayBase<S, Ix2>,
128 leaf_size: usize,
129 ) -> InterpolateResult<Self>
130 where
131 S: Data<Elem = F>,
132 {
133 let points = _points.to_owned();
135 if points.is_empty() {
136 return Err(InterpolateError::InvalidValue(
137 "Points array cannot be empty".to_string(),
138 ));
139 }
140
141 let n_points = points.shape()[0];
142 let dim = points.shape()[1];
143
144 if n_points <= leaf_size {
146 let mut tree = Self {
147 points,
148 nodes: Vec::new(),
149 root: None,
150 dim,
151 leaf_size,
152 _phantom: PhantomData,
153 };
154
155 if n_points > 0 {
156 tree.nodes.push(KdNode {
158 idx: 0,
159 dim: 0,
160 value: F::zero(), left: None,
162 right: None,
163 });
164 tree.root = Some(0);
165 }
166
167 return Ok(tree);
168 }
169
170 let est_nodes = (2 * n_points / leaf_size).max(16);
172
173 let mut tree = Self {
174 points,
175 nodes: Vec::with_capacity(est_nodes),
176 root: None,
177 dim,
178 leaf_size,
179 _phantom: PhantomData,
180 };
181
182 let mut indices: Vec<usize> = (0..n_points).collect();
184 tree.root = tree.build_subtree(&mut indices, 0);
185
186 Ok(tree)
187 }
188
189 fn build_subtree(&mut self, indices: &mut [usize], depth: usize) -> Option<usize> {
191 let n_points = indices.len();
192
193 if n_points == 0 {
194 return None;
195 }
196
197 if n_points <= self.leaf_size {
199 let node_idx = self.nodes.len();
200 self.nodes.push(KdNode {
201 idx: indices[0], dim: 0, value: F::zero(), left: None,
205 right: None,
206 });
207 return Some(node_idx);
208 }
209
210 let dim = depth % self.dim;
212
213 self.find_median(indices, dim);
215 let median_idx = n_points / 2;
216
217 let split_point_idx = indices[median_idx];
219 let split_value = self.points[[split_point_idx, dim]];
220
221 let node_idx = self.nodes.len();
222 self.nodes.push(KdNode {
223 idx: split_point_idx,
224 dim,
225 value: split_value,
226 left: None,
227 right: None,
228 });
229
230 let (left_indices, right_indices) = indices.split_at_mut(median_idx);
232 let right_indices = &mut right_indices[1..]; let left_child = self.build_subtree(left_indices, depth + 1);
235 let right_child = self.build_subtree(right_indices, depth + 1);
236
237 self.nodes[node_idx].left = left_child;
239 self.nodes[node_idx].right = right_child;
240
241 Some(node_idx)
242 }
243
244 fn find_median(&self, indices: &mut [usize], dim: usize) {
247 let n = indices.len();
248 if n <= 1 {
249 return;
250 }
251
252 let median_idx = n / 2;
253 quickselect_by_key(indices, median_idx, |&idx| self.points[[idx, dim]]);
254 }
255
256 pub fn nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
266 if query.len() != self.dim {
268 return Err(InterpolateError::DimensionMismatch(format!(
269 "Query dimension {} doesn't match KD-tree dimension {}",
270 query.len(),
271 self.dim
272 )));
273 }
274
275 if self.root.is_none() {
277 return Err(InterpolateError::InvalidState(
278 "KD-tree is empty".to_string(),
279 ));
280 }
281
282 if self.points.shape()[0] <= self.leaf_size {
284 return self.linear_nearest_neighbor(query);
285 }
286
287 let mut best_dist = <F as scirs2_core::numeric::Float>::infinity();
289 let mut best_idx = 0;
290
291 self.search_nearest(self.root.unwrap(), query, &mut best_dist, &mut best_idx);
293
294 Ok((best_idx, best_dist))
295 }
296
297 pub fn k_nearest_neighbors(&self, query: &[F], k: usize) -> InterpolateResult<Vec<(usize, F)>> {
308 if query.len() != self.dim {
310 return Err(InterpolateError::DimensionMismatch(format!(
311 "Query dimension {} doesn't match KD-tree dimension {}",
312 query.len(),
313 self.dim
314 )));
315 }
316
317 if self.root.is_none() {
319 return Err(InterpolateError::InvalidState(
320 "KD-tree is empty".to_string(),
321 ));
322 }
323
324 let k = k.min(self.points.shape()[0]);
326
327 if k == 0 {
328 return Ok(Vec::new());
329 }
330
331 if self.points.shape()[0] <= self.leaf_size {
333 return self.linear_k_nearest_neighbors(query, k);
334 }
335
336 use ordered_float::OrderedFloat;
340 use std::collections::BinaryHeap;
341
342 let mut heap: BinaryHeap<(OrderedFloat<F>, usize)> = BinaryHeap::with_capacity(k + 1);
343
344 self.search_k_nearest(self.root.unwrap(), query, k, &mut heap);
346
347 let mut results: Vec<(usize, F)> = heap
349 .into_iter()
350 .map(|(dist, idx)| (idx, dist.into_inner()))
351 .collect();
352
353 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
355
356 Ok(results)
357 }
358
359 pub fn points_within_radius(
370 &self,
371 query: &[F],
372 radius: F,
373 ) -> InterpolateResult<Vec<(usize, F)>> {
374 if query.len() != self.dim {
376 return Err(InterpolateError::DimensionMismatch(format!(
377 "Query dimension {} doesn't match KD-tree dimension {}",
378 query.len(),
379 self.dim
380 )));
381 }
382
383 if self.root.is_none() {
385 return Err(InterpolateError::InvalidState(
386 "KD-tree is empty".to_string(),
387 ));
388 }
389
390 if radius <= F::zero() {
391 return Err(InterpolateError::InvalidValue(
392 "Radius must be positive".to_string(),
393 ));
394 }
395
396 if self.points.shape()[0] <= self.leaf_size {
398 return self.linear_points_within_radius(query, radius);
399 }
400
401 let mut results = Vec::new();
403
404 self.search_radius(self.root.unwrap(), query, radius, &mut results);
406
407 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
409
410 Ok(results)
411 }
412
413 fn search_nearest(
415 &self,
416 node_idx: usize,
417 query: &[F],
418 best_dist: &mut F,
419 best_idx: &mut usize,
420 ) {
421 let node = &self.nodes[node_idx];
422
423 let point_idx = node.idx;
425 let point = self.points.row(point_idx);
426 let _dist = self.distance(&point.to_vec(), query);
427
428 if _dist < *best_dist {
430 *best_dist = _dist;
431 *best_idx = point_idx;
432 }
433
434 if node.left.is_none() && node.right.is_none() {
436 return;
437 }
438
439 let dim = node.dim;
441 let query_val = query[dim];
442 let node_val = node.value;
443
444 let (first, second) = if query_val < node_val {
445 (node.left, node.right)
446 } else {
447 (node.right, node.left)
448 };
449
450 if let Some(first_idx) = first {
452 self.search_nearest(first_idx, query, best_dist, best_idx);
453 }
454
455 let plane_dist = scirs2_core::numeric::Float::abs(query_val - node_val);
457
458 if plane_dist < *best_dist {
460 if let Some(second_idx) = second {
461 self.search_nearest(second_idx, query, best_dist, best_idx);
462 }
463 }
464 }
465
466 #[allow(clippy::type_complexity)]
468 fn search_k_nearest(
469 &self,
470 node_idx: usize,
471 query: &[F],
472 k: usize,
473 heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
474 ) {
475 let node = &self.nodes[node_idx];
476
477 let point_idx = node.idx;
479 let point = self.points.row(point_idx);
480 let dist = self.distance(&point.to_vec(), query);
481
482 heap.push((OrderedFloat(dist), point_idx));
484
485 if heap.len() > k {
487 heap.pop();
488 }
489
490 if node.left.is_none() && node.right.is_none() {
492 return;
493 }
494
495 let farthest_dist = match heap.peek() {
497 Some(&(dist_, _)) => dist_.into_inner(),
498 None => <F as scirs2_core::numeric::Float>::infinity(),
499 };
500
501 let dim = node.dim;
503 let query_val = query[dim];
504 let node_val = node.value;
505
506 let (first, second) = if query_val < node_val {
507 (node.left, node.right)
508 } else {
509 (node.right, node.left)
510 };
511
512 if let Some(first_idx) = first {
514 self.search_k_nearest(first_idx, query, k, heap);
515 }
516
517 let plane_dist = scirs2_core::numeric::Float::abs(query_val - node_val);
519
520 if plane_dist < farthest_dist || heap.len() < k {
522 if let Some(second_idx) = second {
523 self.search_k_nearest(second_idx, query, k, heap);
524 }
525 }
526 }
527
528 fn search_radius(
530 &self,
531 node_idx: usize,
532 query: &[F],
533 radius: F,
534 results: &mut Vec<(usize, F)>,
535 ) {
536 let node = &self.nodes[node_idx];
537
538 let point_idx = node.idx;
540 let point = self.points.row(point_idx);
541 let dist = self.distance(&point.to_vec(), query);
542
543 if dist <= radius {
545 results.push((point_idx, dist));
546 }
547
548 if node.left.is_none() && node.right.is_none() {
550 return;
551 }
552
553 let dim = node.dim;
555 let query_val = query[dim];
556 let node_val = node.value;
557
558 let (first, second) = if query_val < node_val {
559 (node.left, node.right)
560 } else {
561 (node.right, node.left)
562 };
563
564 if let Some(first_idx) = first {
566 self.search_radius(first_idx, query, radius, results);
567 }
568
569 let plane_dist = scirs2_core::numeric::Float::abs(query_val - node_val);
571
572 if plane_dist <= radius {
574 if let Some(second_idx) = second {
575 self.search_radius(second_idx, query, radius, results);
576 }
577 }
578 }
579
580 fn linear_nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
582 let n_points = self.points.shape()[0];
583
584 let mut min_dist = <F as scirs2_core::numeric::Float>::infinity();
585 let mut min_idx = 0;
586
587 for i in 0..n_points {
588 let point = self.points.row(i);
589 let dist = self.distance(&point.to_vec(), query);
590
591 if dist < min_dist {
592 min_dist = dist;
593 min_idx = i;
594 }
595 }
596
597 Ok((min_idx, min_dist))
598 }
599
600 fn linear_k_nearest_neighbors(
602 &self,
603 query: &[F],
604 k: usize,
605 ) -> InterpolateResult<Vec<(usize, F)>> {
606 let n_points = self.points.shape()[0];
607 let k = k.min(n_points); let mut distances: Vec<(usize, F)> = (0..n_points)
611 .map(|i| {
612 let point = self.points.row(i);
613 let dist = self.distance(&point.to_vec(), query);
614 (i, dist)
615 })
616 .collect();
617
618 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
620
621 distances.truncate(k);
623 Ok(distances)
624 }
625
626 fn linear_points_within_radius(
628 &self,
629 query: &[F],
630 radius: F,
631 ) -> InterpolateResult<Vec<(usize, F)>> {
632 let n_points = self.points.shape()[0];
633
634 let mut results: Vec<(usize, F)> = (0..n_points)
636 .filter_map(|i| {
637 let point = self.points.row(i);
638 let dist = self.distance(&point.to_vec(), query);
639 if dist <= radius {
640 Some((i, dist))
641 } else {
642 None
643 }
644 })
645 .collect();
646
647 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
649
650 Ok(results)
651 }
652
653 fn distance(&self, a: &[F], b: &[F]) -> F {
655 let mut sum_sq = F::zero();
656
657 for i in 0..self.dim {
658 let diff = a[i] - b[i];
659 sum_sq = sum_sq + diff * diff;
660 }
661
662 sum_sq.sqrt()
663 }
664
665 pub fn len(&self) -> usize {
667 self.points.shape()[0]
668 }
669
670 pub fn is_empty(&self) -> bool {
672 self.len() == 0
673 }
674
675 pub fn dim(&self) -> usize {
677 self.dim
678 }
679
680 pub fn points(&self) -> &Array2<F> {
682 &self.points
683 }
684
685 pub fn radius_neighbors(&self, query: &[F], radius: F) -> InterpolateResult<Vec<(usize, F)>> {
696 self.points_within_radius(query, radius)
697 }
698
699 pub fn radius_neighbors_view(
710 &self,
711 query: &ArrayView1<F>,
712 radius: F,
713 ) -> InterpolateResult<Vec<(usize, F)>> {
714 let query_slice = query.as_slice().ok_or_else(|| {
715 InterpolateError::InvalidValue("Query must be contiguous".to_string())
716 })?;
717 self.points_within_radius(query_slice, radius)
718 }
719
720 pub fn k_nearest_neighbors_optimized(
735 &self,
736 query: &[F],
737 k: usize,
738 max_distance: Option<F>,
739 ) -> InterpolateResult<Vec<(usize, F)>> {
740 if query.len() != self.dim {
742 return Err(InterpolateError::DimensionMismatch(format!(
743 "Query dimension {} doesn't match KD-tree dimension {}",
744 query.len(),
745 self.dim
746 )));
747 }
748
749 if self.root.is_none() {
751 return Err(InterpolateError::InvalidState(
752 "KD-tree is empty".to_string(),
753 ));
754 }
755
756 let k = k.min(self.points.shape()[0]);
758
759 if k == 0 {
760 return Ok(Vec::new());
761 }
762
763 if self.points.shape()[0] <= self.leaf_size {
765 return self.linear_k_nearest_neighbors_optimized(query, k, max_distance);
766 }
767
768 use ordered_float::OrderedFloat;
769 use std::collections::BinaryHeap;
770
771 let mut heap: BinaryHeap<(OrderedFloat<F>, usize)> = BinaryHeap::with_capacity(k + 1);
772 let mut search_radius =
773 max_distance.unwrap_or(<F as scirs2_core::numeric::Float>::infinity());
774
775 self.search_k_nearest_optimized(
777 self.root.unwrap(),
778 query,
779 k,
780 &mut heap,
781 &mut search_radius,
782 );
783
784 let mut results: Vec<(usize, F)> = heap
786 .into_iter()
787 .map(|(dist, idx)| (idx, dist.into_inner()))
788 .collect();
789
790 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
792
793 Ok(results)
794 }
795
796 fn linear_k_nearest_neighbors_optimized(
798 &self,
799 query: &[F],
800 k: usize,
801 max_distance: Option<F>,
802 ) -> InterpolateResult<Vec<(usize, F)>> {
803 let n_points = self.points.shape()[0];
804 let k = k.min(n_points);
805 let max_dist = max_distance.unwrap_or(<F as scirs2_core::numeric::Float>::infinity());
806
807 let mut distances: Vec<(usize, F)> = Vec::with_capacity(n_points);
808
809 for i in 0..n_points {
810 let point = self.points.row(i);
811 let dist = self.distance(&point.to_vec(), query);
812
813 if dist <= max_dist {
815 distances.push((i, dist));
816 }
817 }
818
819 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
821
822 distances.truncate(k);
824 Ok(distances)
825 }
826
827 #[allow(clippy::type_complexity)]
829 fn search_k_nearest_optimized(
830 &self,
831 node_idx: usize,
832 query: &[F],
833 k: usize,
834 heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
835 search_radius: &mut F,
836 ) {
837 let node = &self.nodes[node_idx];
838
839 let point_idx = node.idx;
841 let point = self.points.row(point_idx);
842 let dist = self.distance(&point.to_vec(), query);
843
844 if dist <= *search_radius {
846 heap.push((OrderedFloat(dist), point_idx));
847
848 if heap.len() > k {
850 heap.pop();
851 }
852
853 if heap.len() == k {
855 if let Some(&(max_dist_, _)) = heap.peek() {
856 *search_radius = max_dist_.into_inner();
857 }
858 }
859 }
860
861 if node.left.is_none() && node.right.is_none() {
863 return;
864 }
865
866 let kth_dist = if heap.len() < k {
868 *search_radius
869 } else {
870 match heap.peek() {
871 Some(&(dist_, _)) => dist_.into_inner(),
872 None => *search_radius,
873 }
874 };
875
876 let dim = node.dim;
878 let query_val = query[dim];
879 let node_val = node.value;
880
881 let (first, second) = if query_val < node_val {
882 (node.left, node.right)
883 } else {
884 (node.right, node.left)
885 };
886
887 if let Some(first_idx) = first {
889 self.search_k_nearest_optimized(first_idx, query, k, heap, search_radius);
890 }
891
892 let plane_dist = scirs2_core::numeric::Float::abs(query_val - node_val);
894
895 if plane_dist <= kth_dist {
897 if let Some(second_idx) = second {
898 self.search_k_nearest_optimized(second_idx, query, k, heap, search_radius);
899 }
900 }
901 }
902
903 pub fn query_nearest(
914 &self,
915 query: &scirs2_core::ndarray::ArrayView1<F>,
916 k: usize,
917 ) -> InterpolateResult<scirs2_core::ndarray::Array1<usize>> {
918 use scirs2_core::ndarray::Array1;
919
920 let query_slice = query.as_slice().ok_or_else(|| {
922 InterpolateError::InvalidValue("Query must be contiguous".to_string())
923 })?;
924
925 let neighbors = self.k_nearest_neighbors(query_slice, k)?;
927
928 let indices = neighbors.iter().map(|(idx_, _)| *idx_).collect::<Vec<_>>();
930 Ok(Array1::from(indices))
931 }
932}
933
/// Partially reorders `items` so that `items[k]` holds the element whose key would be at
/// position `k` after sorting by `keyfn`. Every element before it has a key `<=` that
/// key, and every element after it has a key `>=` that key.
///
/// The selection is iterative and uses a three-way (less / equal / greater) partition.
/// Stack usage is constant, and keys equal to the pivot are never revisited, so inputs
/// with many duplicate keys stay fast. Keys that are incomparable with the pivot (e.g.
/// NaN) are grouped with the equal band, so the guarantee above only holds when all keys
/// are mutually comparable.
///
/// Does nothing when `k >= items.len()`.
#[allow(dead_code)]
fn quickselect_by_key<T, F, K>(items: &mut [T], k: usize, keyfn: F)
where
    F: Fn(&T) -> K,
    K: PartialOrd,
{
    // Invariant: the target position `k` lies inside the active window `lo..hi`.
    let mut lo = 0;
    let mut hi = items.len();
    if k >= hi {
        return;
    }

    while hi - lo > 1 {
        let window = &mut items[lo..hi];
        let pivot = keyfn(&window[window.len() / 2]);

        // Dutch-national-flag partition of the window:
        //   window[..lt]   < pivot
        //   window[lt..gt] == pivot (or incomparable with it)
        //   window[gt..]   > pivot
        let mut lt = 0;
        let mut cursor = 0;
        let mut gt = window.len();
        while cursor < gt {
            match keyfn(&window[cursor]).partial_cmp(&pivot) {
                Some(Ordering::Less) => {
                    window.swap(lt, cursor);
                    lt += 1;
                    cursor += 1;
                }
                Some(Ordering::Greater) => {
                    gt -= 1;
                    window.swap(cursor, gt);
                }
                _ => cursor += 1,
            }
        }

        let target = k - lo;
        if target < lt {
            hi = lo + lt;
        } else if target >= gt {
            lo += gt;
        } else {
            // `k` landed in the band of keys equal to the pivot: it is already in place.
            return;
        }
    }
}
973
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Corners of the unit square plus its centre (indices 0..=4).
    ///
    /// Five points is below the default leaf size of 10, so these trees answer queries
    /// with the linear-scan fallback.
    fn unit_square_with_centre() -> Array2<f64> {
        arr2(&[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
    }

    #[test]
    fn test_kdtree_creation() {
        let kdtree = KdTree::new(unit_square_with_centre()).unwrap();

        assert_eq!(kdtree.len(), 5);
        assert_eq!(kdtree.dim(), 2);
        assert!(!kdtree.is_empty());
    }

    #[test]
    fn test_empty_points_rejected() {
        assert!(KdTree::new(Array2::<f64>::zeros((0, 2))).is_err());
    }

    #[test]
    fn test_nearest_neighbor() {
        let kdtree = KdTree::new(unit_square_with_centre()).unwrap();

        // Every stored point is its own nearest neighbour at distance zero.
        for i in 0..kdtree.len() {
            let point = kdtree.points().row(i).to_vec();
            let (idx, dist) = kdtree.nearest_neighbor(&point).unwrap();
            assert_eq!(idx, i);
            assert!(dist < 1e-10);
        }

        let (idx, _) = kdtree.nearest_neighbor(&[0.6, 0.6]).unwrap();
        assert_eq!(idx, 4);

        let (idx, _) = kdtree.nearest_neighbor(&[0.9, 0.1]).unwrap();
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_query_dimension_mismatch() {
        let kdtree = KdTree::new(unit_square_with_centre()).unwrap();

        assert!(kdtree.nearest_neighbor(&[0.0]).is_err());
        assert!(kdtree.k_nearest_neighbors(&[0.0, 0.0, 0.0], 2).is_err());
        assert!(kdtree.points_within_radius(&[0.0], 1.0).is_err());
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let kdtree = KdTree::new(unit_square_with_centre()).unwrap();

        let neighbors = kdtree.k_nearest_neighbors(&[0.6, 0.6], 3).unwrap();

        assert_eq!(neighbors.len(), 3);
        // Centre (~0.141) first, then (1, 1) (~0.566); the third slot is a tie.
        assert_eq!(neighbors[0].0, 4);
        assert_eq!(neighbors[1].0, 3);
        assert!(neighbors.windows(2).all(|w| w[0].1 <= w[1].1));

        // Asking for more neighbours than points returns every point.
        assert_eq!(kdtree.k_nearest_neighbors(&[0.6, 0.6], 100).unwrap().len(), 5);
    }

    #[test]
    fn test_k_nearest_neighbors_optimized_max_distance() {
        let kdtree = KdTree::new(unit_square_with_centre()).unwrap();

        // Only the centre (~0.141) and (1, 1) (~0.566) lie within 0.6 of the query.
        let neighbors = kdtree
            .k_nearest_neighbors_optimized(&[0.6, 0.6], 3, Some(0.6))
            .unwrap();
        let indices: Vec<usize> = neighbors.iter().map(|&(idx, _)| idx).collect();

        assert_eq!(indices, vec![4, 3]);
    }

    #[test]
    fn test_points_within_radius() {
        let kdtree = KdTree::new(unit_square_with_centre()).unwrap();

        // The centre is sqrt(0.5) ~ 0.707 from the origin, just outside the radius.
        let results = kdtree.points_within_radius(&[0.0, 0.0], 0.7).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 < 1e-10);

        // Non-positive radii are rejected.
        assert!(kdtree.points_within_radius(&[0.0, 0.0], 0.0).is_err());
    }

    #[test]
    fn test_query_nearest_with_view() {
        let kdtree = KdTree::new(unit_square_with_centre()).unwrap();

        let indices = kdtree.query_nearest(&kdtree.points().row(1), 1).unwrap();

        assert_eq!(indices.to_vec(), vec![1]);
    }
}