1use crate::distance::{Distance, EuclideanDistance};
23use crate::error::{SpatialError, SpatialResult};
24use crate::safe_conversions::*;
25use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
26use scirs2_core::numeric::Float;
27use std::cmp::Ordering;
28use std::marker::PhantomData;
29
30#[derive(Clone, Debug)]
32struct BallTreeNode<T: Float> {
33 start_idx: usize,
35
36 endidx: usize,
38
39 centroid: Vec<T>,
41
42 radius: T,
44
45 left_child: Option<usize>,
47
48 right_child: Option<usize>,
50}
51
52#[derive(Clone, Debug)]
63pub struct BallTree<T: Float + Send + Sync, D: Distance<T>> {
64 data: Array2<T>,
66
67 indices: Array1<usize>,
69
70 nodes: Vec<BallTreeNode<T>>,
72
73 n_samples: usize,
75
76 n_features: usize,
78
79 leaf_size: usize,
81
82 distance: D,
84
85 _phantom: PhantomData<T>,
87}
88
89impl<T: Float + Send + Sync + 'static, D: Distance<T> + Send + Sync + 'static> BallTree<T, D> {
90 pub fn new(
102 data: &ArrayView2<T>,
103 leaf_size: usize,
104 distance: D,
105 ) -> SpatialResult<BallTree<T, D>> {
106 let n_samples = data.nrows();
107 let n_features = data.ncols();
108
109 if n_samples == 0 {
110 return Err(SpatialError::ValueError(
111 "Input data array is empty".to_string(),
112 ));
113 }
114
115 let data = if data.is_standard_layout() {
118 data.to_owned()
119 } else {
120 data.as_standard_layout().to_owned()
121 };
122 let indices = Array1::from_iter(0..n_samples);
123
124 let nodes = Vec::new();
126
127 let mut ball_tree = BallTree {
128 data,
129 indices,
130 nodes,
131 n_samples,
132 n_features,
133 leaf_size,
134 distance,
135 _phantom: PhantomData,
136 };
137
138 ball_tree.build_tree()?;
140
141 Ok(ball_tree)
142 }
143
144 fn build_tree(&mut self) -> SpatialResult<()> {
148 if self.n_samples == 0 {
149 return Ok(());
150 }
151
152 self.nodes = Vec::with_capacity(2 * self.n_samples);
154
155 self.build_subtree(0, self.n_samples)?;
157
158 Ok(())
159 }
160
161 fn build_subtree(&mut self, start_idx: usize, endidx: usize) -> SpatialResult<usize> {
172 let n_points = endidx - start_idx;
173
174 let mut centroid = vec![T::zero(); self.n_features];
176 for i in start_idx..endidx {
177 let point_idx = self.indices[i];
178 let point = self.data.row(point_idx);
179
180 for (j, &val) in point.iter().take(self.n_features).enumerate() {
181 centroid[j] = centroid[j] + val;
182 }
183 }
184
185 for val in centroid.iter_mut().take(self.n_features) {
186 *val = *val / safe_from_usize::<T>(n_points, "balltree centroid calculation")?;
187 }
188
189 let mut radius = T::zero();
191 for i in start_idx..endidx {
192 let point_idx = self.indices[i];
193 let point = self.data.row(point_idx);
194
195 let dist = self.distance.distance(¢roid, point.to_vec().as_slice());
196
197 if dist > radius {
198 radius = dist;
199 }
200 }
201
202 let node_idx = self.nodes.len();
204 let node = BallTreeNode {
205 start_idx,
206 endidx,
207 centroid,
208 radius,
209 left_child: None,
210 right_child: None,
211 };
212
213 self.nodes.push(node);
214
215 if n_points <= self.leaf_size {
217 return Ok(node_idx);
218 }
219
220 self.split_points(node_idx, start_idx, endidx)?;
223
224 let mid_idx = start_idx + n_points / 2;
226
227 let left_idx = self.build_subtree(start_idx, mid_idx)?;
228 let right_idx = self.build_subtree(mid_idx, endidx)?;
229
230 self.nodes[node_idx].left_child = Some(left_idx);
232 self.nodes[node_idx].right_child = Some(right_idx);
233
234 Ok(node_idx)
235 }
236
237 fn split_points(
252 &mut self,
253 node_idx: usize,
254 start_idx: usize,
255 endidx: usize,
256 ) -> SpatialResult<()> {
257 let node = &self.nodes[node_idx];
259 let centroid = &node.centroid;
260
261 let mut distances: Vec<(usize, T)> = (start_idx..endidx)
263 .map(|i| {
264 let point_idx = self.indices[i];
265 let point = self.data.row(point_idx);
266 let dist = self.distance.distance(centroid, point.to_vec().as_slice());
267 (i, dist)
268 })
269 .collect();
270
271 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
273
274 let _mid_idx = start_idx + (endidx - start_idx) / 2;
277 let mut new_indices = Vec::with_capacity(endidx - start_idx);
278
279 for (i_, _) in distances {
280 new_indices.push(self.indices[i_]);
281 }
282
283 for (i, idx) in new_indices.into_iter().enumerate() {
284 self.indices[start_idx + i] = idx;
285 }
286
287 Ok(())
288 }
289
290 pub fn query(
302 &self,
303 point: &[T],
304 k: usize,
305 return_distance: bool,
306 ) -> SpatialResult<(Vec<usize>, Option<Vec<T>>)> {
307 if point.len() != self.n_features {
308 return Err(SpatialError::DimensionError(format!(
309 "Query point has {} dimensions, but tree has {} dimensions",
310 point.len(),
311 self.n_features
312 )));
313 }
314
315 if k > self.n_samples {
316 return Err(SpatialError::ValueError(format!(
317 "k ({}) cannot be greater than the number of samples ({})",
318 k, self.n_samples
319 )));
320 }
321
322 let mut nearest_neighbors = Vec::<(T, usize)>::with_capacity(k);
324 let mut max_dist = T::infinity();
325
326 self.query_recursive(0, point, k, &mut nearest_neighbors, &mut max_dist);
328
329 nearest_neighbors.sort_by(|a, b| {
331 safe_partial_cmp(&a.0, &b.0, "balltree sort results").unwrap_or(Ordering::Equal)
332 });
333
334 let (distances, indices): (Vec<_>, Vec<_>) = nearest_neighbors.into_iter().unzip();
336
337 let distances_opt = if return_distance {
339 Some(distances)
340 } else {
341 None
342 };
343
344 Ok((indices, distances_opt))
345 }
346
347 fn query_recursive(
357 &self,
358 node_idx: usize,
359 point: &[T],
360 k: usize,
361 nearest: &mut Vec<(T, usize)>,
362 max_dist: &mut T,
363 ) {
364 let node = &self.nodes[node_idx];
365
366 let dist_to_centroid = self.distance.distance(point, &node.centroid);
368 if dist_to_centroid > node.radius + *max_dist {
369 return;
370 }
371
372 if node.left_child.is_none() {
374 for i in node.start_idx..node.endidx {
375 let idx = self.indices[i];
376 let row_vec = self.data.row(idx).to_vec();
377 let _dist = self.distance.distance(point, row_vec.as_slice());
378
379 if _dist < *max_dist || nearest.len() < k {
380 nearest.push((_dist, idx));
382
383 if nearest.len() > k {
385 let max_idx = nearest
387 .iter()
388 .enumerate()
389 .max_by(|(_, a), (_, b)| {
390 safe_partial_cmp(&a.0, &b.0, "balltree max distance")
391 .unwrap_or(Ordering::Equal)
392 })
393 .map(|(idx_, _)| idx_)
394 .unwrap_or(0);
395
396 nearest.swap_remove(max_idx);
398
399 *max_dist = nearest
401 .iter()
402 .map(|(dist_, _)| *dist_)
403 .max_by(|a, b| {
404 safe_partial_cmp(a, b, "balltree update max_dist")
405 .unwrap_or(Ordering::Equal)
406 })
407 .unwrap_or(T::infinity());
408 }
409 }
410 }
411 return;
412 }
413
414 let left_idx = match node.left_child {
418 Some(idx) => idx,
419 None => return, };
421 let right_idx = match node.right_child {
422 Some(idx) => idx,
423 None => return, };
425
426 let left_node = &self.nodes[left_idx];
427 let right_node = &self.nodes[right_idx];
428
429 let dist_left = self.distance.distance(point, &left_node.centroid);
430 let dist_right = self.distance.distance(point, &right_node.centroid);
431
432 if dist_left <= dist_right {
434 self.query_recursive(left_idx, point, k, nearest, max_dist);
435 self.query_recursive(right_idx, point, k, nearest, max_dist);
436 } else {
437 self.query_recursive(right_idx, point, k, nearest, max_dist);
438 self.query_recursive(left_idx, point, k, nearest, max_dist);
439 }
440 }
441
442 pub fn query_radius(
454 &self,
455 point: &[T],
456 radius: T,
457 return_distance: bool,
458 ) -> SpatialResult<(Vec<usize>, Option<Vec<T>>)> {
459 if point.len() != self.n_features {
460 return Err(SpatialError::DimensionError(format!(
461 "Query point has {} dimensions, but tree has {} dimensions",
462 point.len(),
463 self.n_features
464 )));
465 }
466
467 if radius < T::zero() {
468 return Err(SpatialError::ValueError(
469 "Radius must be non-negative".to_string(),
470 ));
471 }
472
473 let mut result_indices = Vec::new();
475 let mut result_distances = Vec::new();
476
477 self.query_radius_recursive(0, point, radius, &mut result_indices, &mut result_distances);
479
480 if !result_indices.is_empty() {
482 let mut idx_dist: Vec<(usize, T)> =
483 result_indices.into_iter().zip(result_distances).collect();
484
485 idx_dist.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
486
487 let (indices, distances): (Vec<_>, Vec<_>) = idx_dist.into_iter().unzip();
488
489 let distances_opt = if return_distance {
490 Some(distances)
491 } else {
492 None
493 };
494
495 Ok((indices, distances_opt))
496 } else {
497 Ok((
498 Vec::new(),
499 if return_distance {
500 Some(Vec::new())
501 } else {
502 None
503 },
504 ))
505 }
506 }
507
508 fn query_radius_recursive(
510 &self,
511 node_idx: usize,
512 point: &[T],
513 radius: T,
514 indices: &mut Vec<usize>,
515 distances: &mut Vec<T>,
516 ) {
517 let node = &self.nodes[node_idx];
518
519 let dist_to_centroid = self.distance.distance(point, &node.centroid);
521 if dist_to_centroid > node.radius + radius {
522 return;
523 }
524
525 if node.left_child.is_none() {
527 for i in node.start_idx..node.endidx {
528 let _idx = self.indices[i];
529 let row_vec = self.data.row(_idx).to_vec();
530 let dist = self.distance.distance(point, row_vec.as_slice());
531
532 if dist <= radius {
533 indices.push(_idx);
534 distances.push(dist);
535 }
536 }
537 return;
538 }
539
540 let left_idx = match node.left_child {
542 Some(idx) => idx,
543 None => return, };
545 let right_idx = match node.right_child {
546 Some(idx) => idx,
547 None => return, };
549
550 self.query_radius_recursive(left_idx, point, radius, indices, distances);
551 self.query_radius_recursive(right_idx, point, radius, indices, distances);
552 }
553
554 pub fn query_radius_tree(
565 &self,
566 other: &BallTree<T, D>,
567 radius: T,
568 ) -> SpatialResult<Vec<(usize, usize)>> {
569 if self.n_features != other.n_features {
570 return Err(SpatialError::DimensionError(format!(
571 "Trees have different dimensions: {} and {}",
572 self.n_features, other.n_features
573 )));
574 }
575
576 if radius < T::zero() {
577 return Err(SpatialError::ValueError(
578 "Radius must be non-negative".to_string(),
579 ));
580 }
581
582 let mut pairs = Vec::new();
583
584 self.query_radius_tree_recursive(0, other, 0, radius, &mut pairs);
585
586 Ok(pairs)
587 }
588
589 fn query_radius_tree_recursive(
591 &self,
592 self_node_idx: usize,
593 other: &BallTree<T, D>,
594 other_node_idx: usize,
595 radius: T,
596 pairs: &mut Vec<(usize, usize)>,
597 ) {
598 let self_node = &self.nodes[self_node_idx];
599 let other_node = &other.nodes[other_node_idx];
600
601 let dist_between_centroids = self
603 .distance
604 .distance(&self_node.centroid, &other_node.centroid);
605
606 if dist_between_centroids > self_node.radius + other_node.radius + radius {
608 return;
609 }
610
611 if self_node.left_child.is_none() && other_node.left_child.is_none() {
613 for i in self_node.start_idx..self_node.endidx {
614 let self_idx = self.indices[i];
615 let self_point = self.data.row(self_idx);
616
617 for j in other_node.start_idx..other_node.endidx {
618 let other_idx = other.indices[j];
619 let other_point = other.data.row(other_idx);
620
621 let self_vec = self_point.to_vec();
622 let other_vec = other_point.to_vec();
623 let dist = self
624 .distance
625 .distance(self_vec.as_slice(), other_vec.as_slice());
626
627 if dist <= radius {
628 pairs.push((self_idx, other_idx));
629 }
630 }
631 }
632 return;
633 }
634
635 if self_node.left_child.is_some()
638 && (other_node.left_child.is_none()
639 || (self_node.endidx - self_node.start_idx)
640 > (other_node.endidx - other_node.start_idx))
641 {
642 let left_idx = match self_node.left_child {
643 Some(idx) => idx,
644 None => return, };
646 let right_idx = match self_node.right_child {
647 Some(idx) => idx,
648 None => return, };
650
651 self.query_radius_tree_recursive(left_idx, other, other_node_idx, radius, pairs);
652 self.query_radius_tree_recursive(right_idx, other, other_node_idx, radius, pairs);
653 } else if other_node.left_child.is_some() {
654 let left_idx = match other_node.left_child {
655 Some(idx) => idx,
656 None => return, };
658 let right_idx = match other_node.right_child {
659 Some(idx) => idx,
660 None => return, };
662
663 self.query_radius_tree_recursive(self_node_idx, other, left_idx, radius, pairs);
664 self.query_radius_tree_recursive(self_node_idx, other, right_idx, radius, pairs);
665 }
666 }
667
668 pub fn get_data(&self) -> &Array2<T> {
670 &self.data
671 }
672
673 pub fn get_n_samples(&self) -> usize {
675 self.n_samples
676 }
677
678 pub fn get_n_features(&self) -> usize {
680 self.n_features
681 }
682
683 pub fn get_leaf_size(&self) -> usize {
685 self.leaf_size
686 }
687}
688
689impl<T: Float + Send + Sync + 'static> BallTree<T, EuclideanDistance<T>> {
691 pub fn with_euclidean_distance(
702 data: &ArrayView2<T>,
703 leaf_size: usize,
704 ) -> SpatialResult<BallTree<T, EuclideanDistance<T>>> {
705 BallTree::new(data, leaf_size, EuclideanDistance::new())
706 }
707}
708
#[cfg(test)]
mod tests {
    use super::BallTree;
    use crate::distance::{euclidean, Distance, EuclideanDistance};
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Deterministic pseudo-random points in `[0, 10)`, generated with a 64-bit LCG
    /// so the brute-force cross-checks are reproducible without an RNG dependency.
    fn pseudo_random_points(n_samples: usize, n_features: usize, seed: u64) -> Array2<f64> {
        let mut state = seed;
        Array2::from_shape_fn((n_samples, n_features), |_| {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // The top 53 bits give an evenly spaced value in [0, 1).
            (state >> 11) as f64 / (1u64 << 53) as f64 * 10.0
        })
    }

    /// `(row, distance)` from `query` to every row of `data`.
    ///
    /// Uses the same metric and argument order as the tree, so results compare
    /// exactly at radius boundaries.
    fn brute_force_distances(data: &Array2<f64>, query: &[f64]) -> Vec<(usize, f64)> {
        let metric = EuclideanDistance::<f64>::new();
        (0..data.nrows())
            .map(|i| (i, metric.distance(query, &data.row(i).to_vec())))
            .collect()
    }

    #[test]
    fn test_ball_tree_construction() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        assert_eq!(tree.get_n_samples(), 5);
        assert_eq!(tree.get_n_features(), 2);
        assert_eq!(tree.get_leaf_size(), 2);
    }

    #[test]
    fn test_ball_tree_nearest_neighbor() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        let (indices, distances) = tree.query(&[5.1, 5.9], 1, true).unwrap();
        assert_eq!(indices, vec![2]);
        assert!(distances.is_some());
        assert_relative_eq!(distances.unwrap()[0], euclidean(&[5.1, 5.9], &[5.0, 6.0]));

        let (indices, distances) = tree.query(&[5.1, 5.9], 3, true).unwrap();
        assert_eq!(indices.len(), 3);
        assert!(indices.contains(&2));
        assert!(distances.is_some());
        assert_eq!(distances.unwrap().len(), 3);

        let (indices, distances) = tree.query(&[5.1, 5.9], 1, false).unwrap();
        assert_eq!(indices, vec![2]);
        assert!(distances.is_none());
    }

    #[test]
    fn test_ball_tree_radius_search() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        let (indices, _distances) = tree.query_radius(&[5.0, 6.0], 1.0, true).unwrap();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 2);

        let (indices, _distances) = tree.query_radius(&[5.0, 6.0], 3.0, true).unwrap();
        assert!(indices.len() > 1);

        let (indices, distances) = tree.query_radius(&[5.0, 6.0], 3.0, false).unwrap();
        assert!(indices.len() > 1);
        assert!(distances.is_none());
    }

    #[test]
    fn test_ball_tree_dual_tree_search() {
        let data1 = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let data2 = arr2(&[[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]]);

        let tree1 = BallTree::with_euclidean_distance(&data1.view(), 2).unwrap();
        let tree2 = BallTree::with_euclidean_distance(&data2.view(), 2).unwrap();

        let pairs = tree1.query_radius_tree(&tree2, 1.5).unwrap();
        assert_eq!(pairs.len(), 3);

        let pairs = tree1.query_radius_tree(&tree2, 10.0).unwrap();
        assert_eq!(pairs.len(), 9);
    }

    #[test]
    fn test_ball_tree_matches_brute_force() {
        let data = pseudo_random_points(200, 3, 42);
        let queries = pseudo_random_points(10, 3, 7);
        let k = 5;
        let radius = 2.0;

        for &leaf_size in &[1usize, 4, 16, 300] {
            let tree = BallTree::with_euclidean_distance(&data.view(), leaf_size).unwrap();

            for q in 0..queries.nrows() {
                let query = queries.row(q).to_vec();
                let mut expected = brute_force_distances(&data, &query);
                expected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

                // k-NN: the reported distances must be the k smallest, in order, and
                // every returned index must really lie at its reported distance.
                let (indices, distances) = tree.query(&query, k, true).unwrap();
                let distances = distances.unwrap();
                assert_eq!(indices.len(), k);
                assert_eq!(distances.len(), k);
                for (got, want) in distances.iter().zip(expected.iter()) {
                    assert_relative_eq!(*got, want.1);
                }
                for (idx, dist) in indices.iter().zip(distances.iter()) {
                    let actual = expected.iter().find(|e| e.0 == *idx).unwrap().1;
                    assert_relative_eq!(*dist, actual);
                }

                // Radius query: exactly the points within `radius`.
                let (mut within, _) = tree.query_radius(&query, radius, false).unwrap();
                within.sort_unstable();
                let mut expected_within: Vec<usize> = expected
                    .iter()
                    .filter(|e| e.1 <= radius)
                    .map(|e| e.0)
                    .collect();
                expected_within.sort_unstable();
                assert_eq!(within, expected_within);
            }
        }
    }

    #[test]
    fn test_ball_tree_dual_tree_matches_brute_force() {
        let data1 = pseudo_random_points(60, 2, 3);
        let data2 = pseudo_random_points(45, 2, 11);
        let radius = 1.5;

        let metric = EuclideanDistance::<f64>::new();
        let mut expected = Vec::new();
        for i in 0..data1.nrows() {
            let a = data1.row(i).to_vec();
            for j in 0..data2.nrows() {
                if metric.distance(&a, &data2.row(j).to_vec()) <= radius {
                    expected.push((i, j));
                }
            }
        }
        expected.sort_unstable();

        for &leaf_size in &[1usize, 3, 100] {
            let tree1 = BallTree::with_euclidean_distance(&data1.view(), leaf_size).unwrap();
            let tree2 = BallTree::with_euclidean_distance(&data2.view(), leaf_size).unwrap();

            let mut pairs = tree1.query_radius_tree(&tree2, radius).unwrap();
            pairs.sort_unstable();
            assert_eq!(pairs, expected);
        }
    }

    #[test]
    fn test_ball_tree_empty_input() {
        let data = arr2(&[[0.0f64; 2]; 0]);
        let result = BallTree::with_euclidean_distance(&data.view(), 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_ball_tree_dimension_mismatch() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        let result = tree.query(&[1.0], 1, false);
        assert!(result.is_err());

        let result = tree.query_radius(&[1.0, 2.0, 3.0], 1.0, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_ball_tree_invalid_parameters() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        let result = tree.query(&[1.0, 2.0], 4, false);
        assert!(result.is_err());

        let result = tree.query_radius(&[1.0, 2.0], -1.0, false);
        assert!(result.is_err());
    }
}