1use ordered_float::OrderedFloat;
19use scirs2_core::ndarray::Array2;
20use scirs2_core::numeric::{Float, FromPrimitive};
21use std::cmp::Ordering;
22use std::fmt::Debug;
23use std::marker::PhantomData;
24
25use crate::error::{InterpolateError, InterpolateResult};
26use scirs2_core::ndarray::ArrayView1;
27
/// One node of a `BallTree`, stored in the tree's flat `nodes` vector.
///
/// The node describes a ball (`center`, `radius`) enclosing every point listed
/// in `indices`. Children are referenced by their position in the tree's
/// `nodes` vector; a node with neither child is treated as a leaf by the
/// search routines.
#[derive(Debug, Clone)]
struct BallNode<F: Float + ordered_float::FloatCore> {
    /// Row indices (into the tree's point array) of every point under this
    /// node. Internal nodes keep the full list too, not only leaves.
    indices: Vec<usize>,

    /// Coordinate-wise mean of the points listed in `indices`.
    center: Vec<F>,

    /// Largest Euclidean distance from `center` to any point in `indices`.
    radius: F,

    /// Position of the left child in the tree's `nodes` vector; `None` for a leaf.
    left: Option<usize>,

    /// Position of the right child in the tree's `nodes` vector; `None` for a leaf.
    right: Option<usize>,
}
46
/// Ball Tree spatial index supporting Euclidean nearest-neighbour, k-nearest
/// and fixed-radius queries.
///
/// Points are stored one per row. The tree itself is a vector of `BallNode`s
/// built by recursively splitting the point set, with children referenced by
/// index into that vector.
#[derive(Debug, Clone)]
pub struct BallTree<F>
where
    F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
{
    /// The indexed points: `shape()[0]` points of `shape()[1]` coordinates each.
    points: Array2<F>,

    /// Every node of the tree; nodes refer to their children by index into this vector.
    nodes: Vec<BallNode<F>>,

    /// Index of the root node in `nodes`, or `None` if no node has been built.
    root: Option<usize>,

    /// Number of coordinates per point; query slices must have exactly this length.
    dim: usize,

    /// Subsets with at most this many points are stored as a single leaf, and
    /// trees holding at most this many points in total are queried by linear scan.
    leafsize: usize,

    /// Type marker; redundant because `points` already mentions `F`.
    _phantom: PhantomData<F>,
}
97
98impl<F> BallTree<F>
99where
100 F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
101{
102 pub fn new(points: Array2<F>) -> InterpolateResult<Self> {
112 Self::with_leafsize(points, 10)
113 }
114
115 pub fn with_leafsize(points: Array2<F>, leafsize: usize) -> InterpolateResult<Self> {
126 if points.is_empty() {
127 return Err(InterpolateError::InvalidValue(
128 "Points array cannot be empty".to_string(),
129 ));
130 }
131
132 let n_points = points.shape()[0];
133 let dim = points.shape()[1];
134
135 if n_points <= leafsize {
137 let indices: Vec<usize> = (0..n_points).collect();
138 let center = compute_centroid(&points, &indices);
139 let radius = compute_radius(&points, &indices, ¢er);
140
141 let mut tree = Self {
142 points,
143 nodes: Vec::new(),
144 root: None,
145 dim,
146 leafsize,
147 _phantom: PhantomData,
148 };
149
150 if n_points > 0 {
151 tree.nodes.push(BallNode {
153 indices,
154 center,
155 radius,
156 left: None,
157 right: None,
158 });
159 tree.root = Some(0);
160 }
161
162 return Ok(tree);
163 }
164
165 let est_nodes = (2 * n_points / leafsize).max(16);
167
168 let mut tree = Self {
169 points,
170 nodes: Vec::with_capacity(est_nodes),
171 root: None,
172 dim,
173 leafsize,
174 _phantom: PhantomData,
175 };
176
177 let indices: Vec<usize> = (0..n_points).collect();
179 tree.root = Some(tree.build_subtree(&indices));
180
181 Ok(tree)
182 }
183
184 fn build_subtree(&mut self, indices: &[usize]) -> usize {
186 let n_points = indices.len();
187
188 let center = compute_centroid(&self.points, indices);
190 let radius = compute_radius(&self.points, indices, ¢er);
191
192 if n_points <= self.leafsize {
194 let node_idx = self.nodes.len();
195 self.nodes.push(BallNode {
196 indices: indices.to_vec(),
197 center,
198 radius,
199 left: None,
200 right: None,
201 });
202 return node_idx;
203 }
204
205 let (split_dim, _) = find_max_spread_dimension(&self.points, indices);
207
208 let (seed1, seed2) = find_distant_points(&self.points, indices, split_dim);
210
211 let (left_indices, right_indices) = partition_by_seeds(&self.points, indices, seed1, seed2);
213
214 let node_idx = self.nodes.len();
216 self.nodes.push(BallNode {
217 indices: indices.to_vec(),
218 center,
219 radius,
220 left: None,
221 right: None,
222 });
223
224 let left_idx = self.build_subtree(&left_indices);
226 let right_idx = self.build_subtree(&right_indices);
227
228 self.nodes[node_idx].left = Some(left_idx);
230 self.nodes[node_idx].right = Some(right_idx);
231
232 node_idx
233 }
234
235 pub fn nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
245 if query.len() != self.dim {
247 return Err(InterpolateError::DimensionMismatch(format!(
248 "Query dimension {} doesn't match Ball Tree dimension {}",
249 query.len(),
250 self.dim
251 )));
252 }
253
254 if self.root.is_none() {
256 return Err(InterpolateError::InvalidState(
257 "Ball Tree is empty".to_string(),
258 ));
259 }
260
261 if self.points.shape()[0] <= self.leafsize {
263 return self.linear_nearest_neighbor(query);
264 }
265
266 let mut best_dist = <F as scirs2_core::numeric::Float>::infinity();
268 let mut best_idx = 0;
269
270 self.search_nearest(self.root.unwrap(), query, &mut best_dist, &mut best_idx);
272
273 Ok((best_idx, best_dist))
274 }
275
276 pub fn k_nearest_neighbors(&self, query: &[F], k: usize) -> InterpolateResult<Vec<(usize, F)>> {
287 if query.len() != self.dim {
289 return Err(InterpolateError::DimensionMismatch(format!(
290 "Query dimension {} doesn't match Ball Tree dimension {}",
291 query.len(),
292 self.dim
293 )));
294 }
295
296 if self.root.is_none() {
298 return Err(InterpolateError::InvalidState(
299 "Ball Tree is empty".to_string(),
300 ));
301 }
302
303 let k = k.min(self.points.shape()[0]);
305
306 if k == 0 {
307 return Ok(Vec::new());
308 }
309
310 if self.points.shape()[0] <= self.leafsize {
312 return self.linear_k_nearest_neighbors(query, k);
313 }
314
315 use std::collections::BinaryHeap;
317
318 let mut heap = BinaryHeap::with_capacity(k + 1);
319
320 self.search_k_nearest(self.root.unwrap(), query, k, &mut heap);
322
323 let mut results: Vec<(usize, F)> = heap
325 .into_iter()
326 .map(|(dist, idx)| (idx, dist.into_inner()))
327 .collect();
328
329 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
331
332 Ok(results)
333 }
334
335 pub fn points_within_radius(
346 &self,
347 query: &[F],
348 radius: F,
349 ) -> InterpolateResult<Vec<(usize, F)>> {
350 if query.len() != self.dim {
352 return Err(InterpolateError::DimensionMismatch(format!(
353 "Query dimension {} doesn't match Ball Tree dimension {}",
354 query.len(),
355 self.dim
356 )));
357 }
358
359 if self.root.is_none() {
361 return Err(InterpolateError::InvalidState(
362 "Ball Tree is empty".to_string(),
363 ));
364 }
365
366 if radius <= F::zero() {
367 return Err(InterpolateError::InvalidValue(
368 "Radius must be positive".to_string(),
369 ));
370 }
371
372 if self.points.shape()[0] <= self.leafsize {
374 return self.linear_points_within_radius(query, radius);
375 }
376
377 let mut results = Vec::new();
379
380 self.search_radius(self.root.unwrap(), query, radius, &mut results);
382
383 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
385
386 Ok(results)
387 }
388
389 fn search_nearest(
391 &self,
392 node_idx: usize,
393 query: &[F],
394 best_dist: &mut F,
395 best_idx: &mut usize,
396 ) {
397 let node = &self.nodes[node_idx];
398
399 let center_dist = euclidean_distance(query, &node.center);
401
402 if center_dist > node.radius + *best_dist {
404 return;
405 }
406
407 if node.left.is_none() && node.right.is_none() {
409 for &idx in &node.indices {
410 let point = self.points.row(idx);
411 let dist = euclidean_distance(query, &point.to_vec());
412
413 if dist < *best_dist {
414 *best_dist = dist;
415 *best_idx = idx;
416 }
417 }
418 return;
419 }
420
421 let left_idx = node.left.unwrap();
424 let right_idx = node.right.unwrap();
425
426 let left_node = &self.nodes[left_idx];
427 let right_node = &self.nodes[right_idx];
428
429 let left_dist = euclidean_distance(query, &left_node.center);
430 let right_dist = euclidean_distance(query, &right_node.center);
431
432 if left_dist < right_dist {
433 self.search_nearest(left_idx, query, best_dist, best_idx);
435 self.search_nearest(right_idx, query, best_dist, best_idx);
436 } else {
437 self.search_nearest(right_idx, query, best_dist, best_idx);
439 self.search_nearest(left_idx, query, best_dist, best_idx);
440 }
441 }
442
443 #[allow(clippy::type_complexity)]
445 fn search_k_nearest(
446 &self,
447 node_idx: usize,
448 query: &[F],
449 k: usize,
450 heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
451 ) {
452 let node = &self.nodes[node_idx];
453
454 let center_dist = euclidean_distance(query, &node.center);
456
457 let kth_dist = if heap.len() < k {
459 <F as scirs2_core::numeric::Float>::infinity()
460 } else {
461 match heap.peek() {
463 Some(&(dist_, _)) => dist_.into_inner(),
464 None => <F as scirs2_core::numeric::Float>::infinity(),
465 }
466 };
467
468 if center_dist > node.radius + kth_dist {
470 return;
471 }
472
473 if node.left.is_none() && node.right.is_none() {
475 for &idx in &node.indices {
476 let point = self.points.row(idx);
477 let dist = euclidean_distance(query, &point.to_vec());
478
479 heap.push((OrderedFloat(dist), idx));
481
482 if heap.len() > k {
484 heap.pop();
485 }
486 }
487 return;
488 }
489
490 let left_idx = node.left.unwrap();
493 let right_idx = node.right.unwrap();
494
495 let left_node = &self.nodes[left_idx];
496 let right_node = &self.nodes[right_idx];
497
498 let left_dist = euclidean_distance(query, &left_node.center);
499 let right_dist = euclidean_distance(query, &right_node.center);
500
501 if left_dist < right_dist {
502 self.search_k_nearest(left_idx, query, k, heap);
504 self.search_k_nearest(right_idx, query, k, heap);
505 } else {
506 self.search_k_nearest(right_idx, query, k, heap);
508 self.search_k_nearest(left_idx, query, k, heap);
509 }
510 }
511
512 fn search_radius(
514 &self,
515 node_idx: usize,
516 query: &[F],
517 radius: F,
518 results: &mut Vec<(usize, F)>,
519 ) {
520 let node = &self.nodes[node_idx];
521
522 let center_dist = euclidean_distance(query, &node.center);
524
525 if center_dist > node.radius + radius {
527 return;
528 }
529
530 if node.left.is_none() && node.right.is_none() {
532 for &_idx in &node.indices {
533 let point = self.points.row(_idx);
534 let dist = euclidean_distance(query, &point.to_vec());
535
536 if dist <= radius {
537 results.push((_idx, dist));
538 }
539 }
540 return;
541 }
542
543 if let Some(left_idx) = node.left {
545 self.search_radius(left_idx, query, radius, results);
546 }
547
548 if let Some(right_idx) = node.right {
549 self.search_radius(right_idx, query, radius, results);
550 }
551 }
552
553 fn linear_nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
555 let n_points = self.points.shape()[0];
556
557 let mut min_dist = <F as scirs2_core::numeric::Float>::infinity();
558 let mut min_idx = 0;
559
560 for i in 0..n_points {
561 let point = self.points.row(i);
562 let dist = euclidean_distance(query, &point.to_vec());
563
564 if dist < min_dist {
565 min_dist = dist;
566 min_idx = i;
567 }
568 }
569
570 Ok((min_idx, min_dist))
571 }
572
573 fn linear_k_nearest_neighbors(
575 &self,
576 query: &[F],
577 k: usize,
578 ) -> InterpolateResult<Vec<(usize, F)>> {
579 let n_points = self.points.shape()[0];
580 let k = k.min(n_points); let mut distances: Vec<(usize, F)> = (0..n_points)
584 .map(|i| {
585 let point = self.points.row(i);
586 let dist = euclidean_distance(query, &point.to_vec());
587 (i, dist)
588 })
589 .collect();
590
591 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
593
594 distances.truncate(k);
596 Ok(distances)
597 }
598
599 fn linear_points_within_radius(
601 &self,
602 query: &[F],
603 radius: F,
604 ) -> InterpolateResult<Vec<(usize, F)>> {
605 let n_points = self.points.shape()[0];
606
607 let mut results: Vec<(usize, F)> = (0..n_points)
609 .filter_map(|i| {
610 let point = self.points.row(i);
611 let dist = euclidean_distance(query, &point.to_vec());
612 if dist <= radius {
613 Some((i, dist))
614 } else {
615 None
616 }
617 })
618 .collect();
619
620 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
622
623 Ok(results)
624 }
625
626 pub fn len(&self) -> usize {
628 self.points.shape()[0]
629 }
630
631 pub fn is_empty(&self) -> bool {
633 self.len() == 0
634 }
635
636 pub fn dim(&self) -> usize {
638 self.dim
639 }
640
641 pub fn points(&self) -> &Array2<F> {
643 &self.points
644 }
645
646 pub fn radius_neighbors(&self, query: &[F], radius: F) -> InterpolateResult<Vec<(usize, F)>> {
657 self.points_within_radius(query, radius)
658 }
659
660 pub fn radius_neighbors_view(
671 &self,
672 query: &scirs2_core::ndarray::ArrayView1<F>,
673 radius: F,
674 ) -> InterpolateResult<Vec<(usize, F)>> {
675 let query_slice = query.as_slice().ok_or_else(|| {
676 InterpolateError::InvalidValue("Query must be contiguous".to_string())
677 })?;
678 self.points_within_radius(query_slice, radius)
679 }
680
681 pub fn k_nearest_neighbors_optimized(
696 &self,
697 query: &[F],
698 k: usize,
699 max_distance: Option<F>,
700 ) -> InterpolateResult<Vec<(usize, F)>> {
701 if query.len() != self.dim {
703 return Err(InterpolateError::DimensionMismatch(format!(
704 "Query dimension {} doesn't match Ball Tree dimension {}",
705 query.len(),
706 self.dim
707 )));
708 }
709
710 if self.root.is_none() {
712 return Err(InterpolateError::InvalidState(
713 "Ball Tree is empty".to_string(),
714 ));
715 }
716
717 let k = k.min(self.points.shape()[0]);
719
720 if k == 0 {
721 return Ok(Vec::new());
722 }
723
724 if self.points.shape()[0] <= self.leafsize {
726 return self.linear_k_nearest_neighbors_optimized(query, k, max_distance);
727 }
728
729 use std::collections::BinaryHeap;
730
731 let mut heap = BinaryHeap::with_capacity(k + 1);
732 let mut search_radius =
733 max_distance.unwrap_or(<F as scirs2_core::numeric::Float>::infinity());
734
735 self.search_k_nearest_optimized(
737 self.root.unwrap(),
738 query,
739 k,
740 &mut heap,
741 &mut search_radius,
742 );
743
744 let mut results: Vec<(usize, F)> = heap
746 .into_iter()
747 .map(|(dist, idx)| (idx, dist.into_inner()))
748 .collect();
749
750 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
752
753 Ok(results)
754 }
755
756 fn linear_k_nearest_neighbors_optimized(
758 &self,
759 query: &[F],
760 k: usize,
761 max_distance: Option<F>,
762 ) -> InterpolateResult<Vec<(usize, F)>> {
763 let n_points = self.points.shape()[0];
764 let k = k.min(n_points);
765 let max_dist = max_distance.unwrap_or(<F as scirs2_core::numeric::Float>::infinity());
766
767 let mut distances: Vec<(usize, F)> = Vec::with_capacity(n_points);
768
769 for i in 0..n_points {
770 let point = self.points.row(i);
771 let dist = euclidean_distance(query, &point.to_vec());
772
773 if dist <= max_dist {
775 distances.push((i, dist));
776 }
777 }
778
779 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
781
782 distances.truncate(k);
784 Ok(distances)
785 }
786
787 #[allow(clippy::type_complexity)]
789 fn search_k_nearest_optimized(
790 &self,
791 node_idx: usize,
792 query: &[F],
793 k: usize,
794 heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
795 search_radius: &mut F,
796 ) {
797 let node = &self.nodes[node_idx];
798
799 let center_dist = euclidean_distance(query, &node.center);
801
802 let min_possible_dist = if center_dist > node.radius {
804 center_dist - node.radius
805 } else {
806 F::zero()
807 };
808
809 let kth_dist = if heap.len() < k {
811 *search_radius
812 } else {
813 match heap.peek() {
814 Some(&(dist_, _)) => dist_.into_inner(),
815 None => *search_radius,
816 }
817 };
818
819 if min_possible_dist > kth_dist {
821 return;
822 }
823
824 if node.left.is_none() && node.right.is_none() {
826 for &idx in &node.indices {
827 let point = self.points.row(idx);
828 let dist = euclidean_distance(query, &point.to_vec());
829
830 if dist <= *search_radius {
832 heap.push((OrderedFloat(dist), idx));
833
834 if heap.len() > k {
836 heap.pop();
837 }
838
839 if heap.len() == k {
841 if let Some(&(max_dist_, _)) = heap.peek() {
842 *search_radius = max_dist_.into_inner();
843 }
844 }
845 }
846 }
847 return;
848 }
849
850 let left_idx = node.left.unwrap();
852 let right_idx = node.right.unwrap();
853
854 let left_node = &self.nodes[left_idx];
855 let right_node = &self.nodes[right_idx];
856
857 let left_center_dist = euclidean_distance(query, &left_node.center);
859 let right_center_dist = euclidean_distance(query, &right_node.center);
860
861 let left_min_dist = if left_center_dist > left_node.radius {
862 left_center_dist - left_node.radius
863 } else {
864 F::zero()
865 };
866
867 let right_min_dist = if right_center_dist > right_node.radius {
868 right_center_dist - right_node.radius
869 } else {
870 F::zero()
871 };
872
873 let (first_idx, second_idx, second_min_dist) = if left_min_dist < right_min_dist {
875 (left_idx, right_idx, right_min_dist)
876 } else {
877 (right_idx, left_idx, left_min_dist)
878 };
879
880 self.search_k_nearest_optimized(first_idx, query, k, heap, search_radius);
882
883 let updated_kth_dist = if heap.len() < k {
885 *search_radius
886 } else {
887 match heap.peek() {
888 Some(&(dist_, _)) => dist_.into_inner(),
889 None => *search_radius,
890 }
891 };
892
893 if second_min_dist <= updated_kth_dist {
895 self.search_k_nearest_optimized(second_idx, query, k, heap, search_radius);
896 }
897 }
898
899 pub fn approximate_k_nearest_neighbors(
914 &self,
915 query: &[F],
916 k: usize,
917 max_checks: usize,
918 ) -> InterpolateResult<Vec<(usize, F)>> {
919 if query.len() != self.dim {
921 return Err(InterpolateError::DimensionMismatch(format!(
922 "Query dimension {} doesn't match Ball Tree dimension {}",
923 query.len(),
924 self.dim
925 )));
926 }
927
928 if self.root.is_none() {
930 return Err(InterpolateError::InvalidState(
931 "Ball Tree is empty".to_string(),
932 ));
933 }
934
935 let k = k.min(self.points.shape()[0]);
937
938 if k == 0 {
939 return Ok(Vec::new());
940 }
941
942 if self.points.shape()[0] <= self.leafsize || max_checks >= self.points.shape()[0] {
944 return self.k_nearest_neighbors(query, k);
945 }
946
947 use std::collections::{BinaryHeap, VecDeque};
948
949 let mut heap = BinaryHeap::with_capacity(k + 1);
950 let mut checks_performed = 0;
951 let mut nodes_to_visit = VecDeque::new();
952
953 nodes_to_visit.push_back((self.root.unwrap(), F::zero()));
955
956 while let Some((node_idx, _min_dist)) = nodes_to_visit.pop_front() {
957 if checks_performed >= max_checks {
958 break;
959 }
960
961 let node = &self.nodes[node_idx];
962
963 let _center_dist = euclidean_distance(query, &node.center);
965
966 if node.left.is_none() && node.right.is_none() {
968 for &idx in &node.indices {
969 if checks_performed >= max_checks {
970 break;
971 }
972
973 let point = self.points.row(idx);
974 let dist = euclidean_distance(query, &point.to_vec());
975 checks_performed += 1;
976
977 heap.push((OrderedFloat(dist), idx));
978
979 if heap.len() > k {
981 heap.pop();
982 }
983 }
984 } else {
985 if let Some(left_idx) = node.left {
987 let left_node = &self.nodes[left_idx];
988 let left_center_dist = euclidean_distance(query, &left_node.center);
989 let left_min_dist = if left_center_dist > left_node.radius {
990 left_center_dist - left_node.radius
991 } else {
992 F::zero()
993 };
994
995 nodes_to_visit.push_back((left_idx, left_min_dist));
996 }
997
998 if let Some(right_idx) = node.right {
999 let right_node = &self.nodes[right_idx];
1000 let right_center_dist = euclidean_distance(query, &right_node.center);
1001 let right_min_dist = if right_center_dist > right_node.radius {
1002 right_center_dist - right_node.radius
1003 } else {
1004 F::zero()
1005 };
1006
1007 nodes_to_visit.push_back((right_idx, right_min_dist));
1008 }
1009
1010 nodes_to_visit
1012 .make_contiguous()
1013 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1014 }
1015 }
1016
1017 let mut results: Vec<(usize, F)> = heap
1019 .into_iter()
1020 .map(|(dist, idx)| (idx, dist.into_inner()))
1021 .collect();
1022
1023 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1025
1026 Ok(results)
1027 }
1028}
1029
1030#[allow(dead_code)]
1032fn compute_centroid<F: Float + FromPrimitive>(points: &Array2<F>, indices: &[usize]) -> Vec<F> {
1033 let n_points = indices.len();
1034 let n_dims = points.shape()[1];
1035
1036 if n_points == 0 {
1037 return vec![F::zero(); n_dims];
1038 }
1039
1040 let mut center = vec![F::zero(); n_dims];
1041
1042 for &idx in indices {
1044 let point = points.row(idx);
1045 for d in 0..n_dims {
1046 center[d] = center[d] + point[d];
1047 }
1048 }
1049
1050 let n = F::from_usize(n_points).unwrap();
1052 for val in center.iter_mut() {
1053 *val = *val / n;
1054 }
1055
1056 center
1057}
1058
1059#[allow(dead_code)]
1061fn compute_radius<F: Float>(points: &Array2<F>, indices: &[usize], center: &[F]) -> F {
1062 let n_points = indices.len();
1063
1064 if n_points == 0 {
1065 return F::zero();
1066 }
1067
1068 let mut max_dist = F::zero();
1069
1070 for &idx in indices {
1072 let point = points.row(idx);
1073 let dist = euclidean_distance(&point.to_vec(), center);
1074
1075 if dist > max_dist {
1076 max_dist = dist;
1077 }
1078 }
1079
1080 max_dist
1081}
1082
1083#[allow(dead_code)]
1085fn find_max_spread_dimension<F: Float>(points: &Array2<F>, indices: &[usize]) -> (usize, F) {
1086 let n_points = indices.len();
1087 let n_dims = points.shape()[1];
1088
1089 if n_points <= 1 {
1090 return (0, F::zero());
1091 }
1092
1093 let mut max_dim = 0;
1094 let mut max_spread = F::neg_infinity();
1095
1096 for d in 0..n_dims {
1098 let mut min_val = F::infinity();
1099 let mut max_val = F::neg_infinity();
1100
1101 for &idx in indices {
1102 let val = points[[idx, d]];
1103
1104 if val < min_val {
1105 min_val = val;
1106 }
1107
1108 if val > max_val {
1109 max_val = val;
1110 }
1111 }
1112
1113 let spread = max_val - min_val;
1114
1115 if spread > max_spread {
1116 max_spread = spread;
1117 max_dim = d;
1118 }
1119 }
1120
1121 (max_dim, max_spread)
1122}
1123
1124#[allow(dead_code)]
1126fn find_distant_points<F: Float>(
1127 points: &Array2<F>,
1128 indices: &[usize],
1129 dim: usize,
1130) -> (usize, usize) {
1131 let n_points = indices.len();
1132
1133 if n_points <= 1 {
1134 return (indices[0], indices[0]);
1135 }
1136
1137 let mut min_idx = indices[0];
1139 let mut max_idx = indices[0];
1140 let mut min_val = points[[min_idx, dim]];
1141 let mut max_val = min_val;
1142
1143 for &idx in indices.iter().skip(1) {
1144 let val = points[[idx, dim]];
1145
1146 if val < min_val {
1147 min_val = val;
1148 min_idx = idx;
1149 }
1150
1151 if val > max_val {
1152 max_val = val;
1153 max_idx = idx;
1154 }
1155 }
1156
1157 (min_idx, max_idx)
1158}
1159
1160#[allow(dead_code)]
1162fn partition_by_seeds<F: Float>(
1163 points: &Array2<F>,
1164 indices: &[usize],
1165 seed1: usize,
1166 seed2: usize,
1167) -> (Vec<usize>, Vec<usize>) {
1168 let seed1_point = points.row(seed1).to_vec();
1169 let seed2_point = points.row(seed2).to_vec();
1170
1171 let mut left_indices = Vec::new();
1172 let mut right_indices = Vec::new();
1173
1174 left_indices.push(seed1);
1176 right_indices.push(seed2);
1177
1178 for &idx in indices {
1180 if idx == seed1 || idx == seed2 {
1181 continue; }
1183
1184 let point = points.row(idx).to_vec();
1185 let dist1 = euclidean_distance(&point, &seed1_point);
1186 let dist2 = euclidean_distance(&point, &seed2_point);
1187
1188 if dist1 <= dist2 {
1189 left_indices.push(idx);
1190 } else {
1191 right_indices.push(idx);
1192 }
1193 }
1194
1195 if left_indices.is_empty() && right_indices.len() >= 2 {
1197 left_indices.push(right_indices.pop().unwrap());
1198 } else if right_indices.is_empty() && left_indices.len() >= 2 {
1199 right_indices.push(left_indices.pop().unwrap());
1200 }
1201
1202 (left_indices, right_indices)
1203}
1204
1205#[allow(dead_code)]
1207fn euclidean_distance<F: Float>(a: &[F], b: &[F]) -> F {
1208 debug_assert_eq!(a.len(), b.len());
1209
1210 let mut sum_sq = F::zero();
1211
1212 for i in 0..a.len() {
1213 let diff = a[i] - b[i];
1214 sum_sq = sum_sq + diff * diff;
1215 }
1216
1217 sum_sq.sqrt()
1218}
1219
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Five 3-D points shared by the small-input tests. Five is below the
    /// default leaf size (10), so these tests exercise the linear-scan paths.
    fn five_points() -> Array2<f64> {
        arr2(&[
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.5, 0.5, 0.5],
        ])
    }

    /// `n` deterministic pseudo-random points in the unit square, generated by
    /// a 64-bit linear congruential generator so no RNG crate is needed.
    fn scattered_points(n: usize) -> Array2<f64> {
        let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
        Array2::from_shape_fn((n, 2), |_| {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // Top 53 bits -> uniform f64 in [0, 1).
            (state >> 11) as f64 / (1u64 << 53) as f64
        })
    }

    /// Distances from `query` to every row of `points`, sorted ascending.
    fn brute_force_distances(points: &Array2<f64>, query: &[f64]) -> Vec<f64> {
        let mut dists: Vec<f64> = points
            .outer_iter()
            .map(|row| {
                row.iter()
                    .zip(query)
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum::<f64>()
                    .sqrt()
            })
            .collect();
        dists.sort_by(|a, b| a.partial_cmp(b).unwrap());
        dists
    }

    #[test]
    fn test_balltree_creation() {
        let balltree = BallTree::new(five_points()).unwrap();

        assert_eq!(balltree.len(), 5);
        assert_eq!(balltree.dim(), 3);
        assert!(!balltree.is_empty());
    }

    #[test]
    fn test_nearest_neighbor() {
        let balltree = BallTree::new(five_points()).unwrap();

        // Every stored point is its own nearest neighbour at distance ~0.
        for i in 0..5 {
            let point = balltree.points().row(i).to_vec();
            let (idx, dist) = balltree.nearest_neighbor(&point).unwrap();
            assert_eq!(idx, i);
            assert!(dist < 1e-10);
        }

        let (idx, _) = balltree.nearest_neighbor(&[0.6, 0.6, 0.6]).unwrap();
        assert_eq!(idx, 4);

        let (idx, _) = balltree.nearest_neighbor(&[0.9, 0.1, 0.1]).unwrap();
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let balltree = BallTree::new(five_points()).unwrap();

        let neighbors = balltree.k_nearest_neighbors(&[0.6, 0.6, 0.6], 3).unwrap();

        assert_eq!(neighbors.len(), 3);
        assert_eq!(neighbors[0].0, 4);
    }

    #[test]
    fn test_points_within_radius() {
        let balltree = BallTree::new(five_points()).unwrap();

        let results = balltree.points_within_radius(&[0.0, 0.0, 0.0], 0.7).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_tree_queries_match_brute_force() {
        // 200 points with leaf size 4 force a multi-level tree, so the
        // recursive search paths run instead of the linear fallbacks.
        let points = scattered_points(200);
        let tree = BallTree::with_leafsize(points.clone(), 4).unwrap();
        let tol = 1e-12;
        let queries: [[f64; 2]; 4] = [[0.1, 0.9], [0.5, 0.5], [0.73, 0.21], [1.5, -0.2]];

        for q in queries.iter() {
            let q = &q[..];
            let expected = brute_force_distances(&points, q);

            let (_, nn_dist) = tree.nearest_neighbor(q).unwrap();
            assert!((nn_dist - expected[0]).abs() < tol);

            let knn = tree.k_nearest_neighbors(q, 7).unwrap();
            assert_eq!(knn.len(), 7);
            for (got, want) in knn.iter().zip(&expected) {
                assert!((got.1 - want).abs() < tol);
            }

            let knn_opt = tree.k_nearest_neighbors_optimized(q, 7, None).unwrap();
            assert_eq!(knn_opt.len(), 7);
            for (got, want) in knn_opt.iter().zip(&expected) {
                assert!((got.1 - want).abs() < tol);
            }

            let radius = 0.15;
            let within = tree.points_within_radius(q, radius).unwrap();
            let expected_count = expected.iter().filter(|&&d| d <= radius).count();
            assert_eq!(within.len(), expected_count);
            assert!(within.iter().all(|&(_, d)| d <= radius));

            // A small budget may miss neighbours, but the i-th returned
            // distance can never beat the true i-th nearest distance.
            let approx = tree.approximate_k_nearest_neighbors(q, 5, 20).unwrap();
            assert_eq!(approx.len(), 5);
            for (i, &(_, d)) in approx.iter().enumerate() {
                assert!(d >= expected[i] - tol);
            }
        }
    }

    #[test]
    fn test_query_dimension_mismatch_is_rejected() {
        let tree = BallTree::new(five_points()).unwrap();

        assert!(matches!(
            tree.nearest_neighbor(&[0.0, 0.0]),
            Err(InterpolateError::DimensionMismatch(_))
        ));
        assert!(tree.k_nearest_neighbors(&[0.0], 1).is_err());
        assert!(tree.points_within_radius(&[0.0, 0.0, 0.0, 0.0], 1.0).is_err());
    }
}