1use crate::distance::{Distance, EuclideanDistance};
23use crate::error::{SpatialError, SpatialResult};
24use ndarray::{Array1, Array2, ArrayView2};
25use num_traits::Float;
26use std::cmp::Ordering;
27use std::marker::PhantomData;
28
29#[derive(Clone, Debug)]
31struct BallTreeNode<T: Float> {
32 start_idx: usize,
34
35 end_idx: usize,
37
38 centroid: Vec<T>,
40
41 radius: T,
43
44 left_child: Option<usize>,
46
47 right_child: Option<usize>,
49}
50
51#[derive(Clone, Debug)]
62pub struct BallTree<T: Float + Send + Sync, D: Distance<T>> {
63 data: Array2<T>,
65
66 indices: Array1<usize>,
68
69 nodes: Vec<BallTreeNode<T>>,
71
72 n_samples: usize,
74
75 n_features: usize,
77
78 leaf_size: usize,
80
81 distance: D,
83
84 _phantom: PhantomData<T>,
86}
87
88impl<T: Float + Send + Sync + 'static, D: Distance<T> + Send + Sync + 'static> BallTree<T, D> {
89 pub fn new(
101 data: &ArrayView2<T>,
102 leaf_size: usize,
103 distance: D,
104 ) -> SpatialResult<BallTree<T, D>> {
105 let n_samples = data.nrows();
106 let n_features = data.ncols();
107
108 if n_samples == 0 {
109 return Err(SpatialError::ValueError(
110 "Input data array is empty".to_string(),
111 ));
112 }
113
114 let data = data.to_owned();
116 let indices = Array1::from_iter(0..n_samples);
117
118 let nodes = Vec::new();
120
121 let mut ball_tree = BallTree {
122 data,
123 indices,
124 nodes,
125 n_samples,
126 n_features,
127 leaf_size,
128 distance,
129 _phantom: PhantomData,
130 };
131
132 ball_tree.build_tree()?;
134
135 Ok(ball_tree)
136 }
137
138 fn build_tree(&mut self) -> SpatialResult<()> {
142 if self.n_samples == 0 {
143 return Ok(());
144 }
145
146 self.nodes = Vec::with_capacity(2 * self.n_samples);
148
149 self.build_subtree(0, self.n_samples)?;
151
152 Ok(())
153 }
154
155 fn build_subtree(&mut self, start_idx: usize, end_idx: usize) -> SpatialResult<usize> {
166 let n_points = end_idx - start_idx;
167
168 let mut centroid = vec![T::zero(); self.n_features];
170 for i in start_idx..end_idx {
171 let point_idx = self.indices[i];
172 let point = self.data.row(point_idx);
173
174 for (j, &val) in point.iter().take(self.n_features).enumerate() {
175 centroid[j] = centroid[j] + val;
176 }
177 }
178
179 for val in centroid.iter_mut().take(self.n_features) {
180 *val = *val / T::from(n_points).unwrap();
181 }
182
183 let mut radius = T::zero();
185 for i in start_idx..end_idx {
186 let point_idx = self.indices[i];
187 let point = self.data.row(point_idx);
188
189 let dist = self.distance.distance(¢roid, point.as_slice().unwrap());
190
191 if dist > radius {
192 radius = dist;
193 }
194 }
195
196 let node_idx = self.nodes.len();
198 let node = BallTreeNode {
199 start_idx,
200 end_idx,
201 centroid,
202 radius,
203 left_child: None,
204 right_child: None,
205 };
206
207 self.nodes.push(node);
208
209 if n_points <= self.leaf_size {
211 return Ok(node_idx);
212 }
213
214 self.split_points(node_idx, start_idx, end_idx)?;
217
218 let mid_idx = start_idx + n_points / 2;
220
221 let left_idx = self.build_subtree(start_idx, mid_idx)?;
222 let right_idx = self.build_subtree(mid_idx, end_idx)?;
223
224 self.nodes[node_idx].left_child = Some(left_idx);
226 self.nodes[node_idx].right_child = Some(right_idx);
227
228 Ok(node_idx)
229 }
230
231 fn split_points(
246 &mut self,
247 node_idx: usize,
248 start_idx: usize,
249 end_idx: usize,
250 ) -> SpatialResult<()> {
251 let node = &self.nodes[node_idx];
253 let centroid = &node.centroid;
254
255 let mut distances: Vec<(usize, T)> = (start_idx..end_idx)
257 .map(|i| {
258 let point_idx = self.indices[i];
259 let point = self.data.row(point_idx);
260 let dist = self.distance.distance(centroid, point.as_slice().unwrap());
261 (i, dist)
262 })
263 .collect();
264
265 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
267
268 let _mid_idx = start_idx + (end_idx - start_idx) / 2;
271 let mut new_indices = Vec::with_capacity(end_idx - start_idx);
272
273 for (i, _) in distances {
274 new_indices.push(self.indices[i]);
275 }
276
277 for (i, idx) in new_indices.into_iter().enumerate() {
278 self.indices[start_idx + i] = idx;
279 }
280
281 Ok(())
282 }
283
284 pub fn query(
296 &self,
297 point: &[T],
298 k: usize,
299 return_distance: bool,
300 ) -> SpatialResult<(Vec<usize>, Option<Vec<T>>)> {
301 if point.len() != self.n_features {
302 return Err(SpatialError::DimensionError(format!(
303 "Query point has {} dimensions, but tree has {} dimensions",
304 point.len(),
305 self.n_features
306 )));
307 }
308
309 if k > self.n_samples {
310 return Err(SpatialError::ValueError(format!(
311 "k ({}) cannot be greater than the number of samples ({})",
312 k, self.n_samples
313 )));
314 }
315
316 let mut nearest_neighbors = Vec::<(T, usize)>::with_capacity(k);
318 let mut max_dist = T::infinity();
319
320 self.query_recursive(0, point, k, &mut nearest_neighbors, &mut max_dist);
322
323 nearest_neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
325
326 let (distances, indices): (Vec<_>, Vec<_>) = nearest_neighbors.into_iter().unzip();
328
329 let distances_opt = if return_distance {
331 Some(distances)
332 } else {
333 None
334 };
335
336 Ok((indices, distances_opt))
337 }
338
339 fn query_recursive(
349 &self,
350 node_idx: usize,
351 point: &[T],
352 k: usize,
353 nearest: &mut Vec<(T, usize)>,
354 max_dist: &mut T,
355 ) {
356 let node = &self.nodes[node_idx];
357
358 let dist_to_centroid = self.distance.distance(point, &node.centroid);
360 if dist_to_centroid > node.radius + *max_dist {
361 return;
362 }
363
364 if node.left_child.is_none() {
366 for i in node.start_idx..node.end_idx {
367 let idx = self.indices[i];
368 let dist = self
369 .distance
370 .distance(point, self.data.row(idx).as_slice().unwrap());
371
372 if dist < *max_dist || nearest.len() < k {
373 nearest.push((dist, idx));
375
376 if nearest.len() > k {
378 let max_idx = nearest
380 .iter()
381 .enumerate()
382 .max_by(|(_, a), (_, b)| {
383 a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal)
384 })
385 .map(|(idx, _)| idx)
386 .unwrap();
387
388 nearest.swap_remove(max_idx);
390
391 *max_dist = nearest
393 .iter()
394 .map(|(dist, _)| *dist)
395 .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
396 .unwrap_or(T::infinity());
397 }
398 }
399 }
400 return;
401 }
402
403 let left_idx = node.left_child.unwrap();
406 let right_idx = node.right_child.unwrap();
407
408 let left_node = &self.nodes[left_idx];
409 let right_node = &self.nodes[right_idx];
410
411 let dist_left = self.distance.distance(point, &left_node.centroid);
412 let dist_right = self.distance.distance(point, &right_node.centroid);
413
414 if dist_left <= dist_right {
416 self.query_recursive(left_idx, point, k, nearest, max_dist);
417 self.query_recursive(right_idx, point, k, nearest, max_dist);
418 } else {
419 self.query_recursive(right_idx, point, k, nearest, max_dist);
420 self.query_recursive(left_idx, point, k, nearest, max_dist);
421 }
422 }
423
424 pub fn query_radius(
436 &self,
437 point: &[T],
438 radius: T,
439 return_distance: bool,
440 ) -> SpatialResult<(Vec<usize>, Option<Vec<T>>)> {
441 if point.len() != self.n_features {
442 return Err(SpatialError::DimensionError(format!(
443 "Query point has {} dimensions, but tree has {} dimensions",
444 point.len(),
445 self.n_features
446 )));
447 }
448
449 if radius < T::zero() {
450 return Err(SpatialError::ValueError(
451 "Radius must be non-negative".to_string(),
452 ));
453 }
454
455 let mut result_indices = Vec::new();
457 let mut result_distances = Vec::new();
458
459 self.query_radius_recursive(0, point, radius, &mut result_indices, &mut result_distances);
461
462 if !result_indices.is_empty() {
464 let mut idx_dist: Vec<(usize, T)> =
465 result_indices.into_iter().zip(result_distances).collect();
466
467 idx_dist.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
468
469 let (indices, distances): (Vec<_>, Vec<_>) = idx_dist.into_iter().unzip();
470
471 let distances_opt = if return_distance {
472 Some(distances)
473 } else {
474 None
475 };
476
477 Ok((indices, distances_opt))
478 } else {
479 Ok((
480 Vec::new(),
481 if return_distance {
482 Some(Vec::new())
483 } else {
484 None
485 },
486 ))
487 }
488 }
489
490 fn query_radius_recursive(
492 &self,
493 node_idx: usize,
494 point: &[T],
495 radius: T,
496 indices: &mut Vec<usize>,
497 distances: &mut Vec<T>,
498 ) {
499 let node = &self.nodes[node_idx];
500
501 let dist_to_centroid = self.distance.distance(point, &node.centroid);
503 if dist_to_centroid > node.radius + radius {
504 return;
505 }
506
507 if node.left_child.is_none() {
509 for i in node.start_idx..node.end_idx {
510 let idx = self.indices[i];
511 let dist = self
512 .distance
513 .distance(point, self.data.row(idx).as_slice().unwrap());
514
515 if dist <= radius {
516 indices.push(idx);
517 distances.push(dist);
518 }
519 }
520 return;
521 }
522
523 let left_idx = node.left_child.unwrap();
525 let right_idx = node.right_child.unwrap();
526
527 self.query_radius_recursive(left_idx, point, radius, indices, distances);
528 self.query_radius_recursive(right_idx, point, radius, indices, distances);
529 }
530
531 pub fn query_radius_tree(
542 &self,
543 other: &BallTree<T, D>,
544 radius: T,
545 ) -> SpatialResult<Vec<(usize, usize)>> {
546 if self.n_features != other.n_features {
547 return Err(SpatialError::DimensionError(format!(
548 "Trees have different dimensions: {} and {}",
549 self.n_features, other.n_features
550 )));
551 }
552
553 if radius < T::zero() {
554 return Err(SpatialError::ValueError(
555 "Radius must be non-negative".to_string(),
556 ));
557 }
558
559 let mut pairs = Vec::new();
560
561 self.query_radius_tree_recursive(0, other, 0, radius, &mut pairs);
562
563 Ok(pairs)
564 }
565
566 fn query_radius_tree_recursive(
568 &self,
569 self_node_idx: usize,
570 other: &BallTree<T, D>,
571 other_node_idx: usize,
572 radius: T,
573 pairs: &mut Vec<(usize, usize)>,
574 ) {
575 let self_node = &self.nodes[self_node_idx];
576 let other_node = &other.nodes[other_node_idx];
577
578 let dist_between_centroids = self
580 .distance
581 .distance(&self_node.centroid, &other_node.centroid);
582
583 if dist_between_centroids > self_node.radius + other_node.radius + radius {
585 return;
586 }
587
588 if self_node.left_child.is_none() && other_node.left_child.is_none() {
590 for i in self_node.start_idx..self_node.end_idx {
591 let self_idx = self.indices[i];
592 let self_point = self.data.row(self_idx);
593
594 for j in other_node.start_idx..other_node.end_idx {
595 let other_idx = other.indices[j];
596 let other_point = other.data.row(other_idx);
597
598 let dist = self.distance.distance(
599 self_point.as_slice().unwrap(),
600 other_point.as_slice().unwrap(),
601 );
602
603 if dist <= radius {
604 pairs.push((self_idx, other_idx));
605 }
606 }
607 }
608 return;
609 }
610
611 if self_node.left_child.is_some()
614 && (other_node.left_child.is_none()
615 || (self_node.end_idx - self_node.start_idx)
616 > (other_node.end_idx - other_node.start_idx))
617 {
618 let left_idx = self_node.left_child.unwrap();
619 let right_idx = self_node.right_child.unwrap();
620
621 self.query_radius_tree_recursive(left_idx, other, other_node_idx, radius, pairs);
622 self.query_radius_tree_recursive(right_idx, other, other_node_idx, radius, pairs);
623 } else if other_node.left_child.is_some() {
624 let left_idx = other_node.left_child.unwrap();
625 let right_idx = other_node.right_child.unwrap();
626
627 self.query_radius_tree_recursive(self_node_idx, other, left_idx, radius, pairs);
628 self.query_radius_tree_recursive(self_node_idx, other, right_idx, radius, pairs);
629 }
630 }
631
632 pub fn get_data(&self) -> &Array2<T> {
634 &self.data
635 }
636
637 pub fn get_n_samples(&self) -> usize {
639 self.n_samples
640 }
641
642 pub fn get_n_features(&self) -> usize {
644 self.n_features
645 }
646
647 pub fn get_leaf_size(&self) -> usize {
649 self.leaf_size
650 }
651}
652
653impl<T: Float + Send + Sync + 'static> BallTree<T, EuclideanDistance<T>> {
655 pub fn with_euclidean_distance(
666 data: &ArrayView2<T>,
667 leaf_size: usize,
668 ) -> SpatialResult<BallTree<T, EuclideanDistance<T>>> {
669 BallTree::new(data, leaf_size, EuclideanDistance::new())
670 }
671}
672
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::euclidean;
    use approx::assert_relative_eq;
    use ndarray::arr2;

    /// Five evenly spaced 2-D samples shared by several tests.
    fn five_points() -> Array2<f64> {
        arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
    }

    /// Three 2-D samples used by the input-validation tests.
    fn three_points() -> Array2<f64> {
        arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    }

    #[test]
    fn test_ball_tree_construction() {
        let data = five_points();
        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        assert_eq!(tree.get_n_samples(), 5);
        assert_eq!(tree.get_n_features(), 2);
        assert_eq!(tree.get_leaf_size(), 2);
    }

    #[test]
    fn test_ball_tree_nearest_neighbor() {
        let data = five_points();
        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        // Single nearest neighbour, distances requested.
        let (indices, distances) = tree.query(&[5.1, 5.9], 1, true).unwrap();
        assert_eq!(indices, vec![2]);
        assert!(distances.is_some());
        assert_relative_eq!(distances.unwrap()[0], euclidean(&[5.1, 5.9], &[5.0, 6.0]));

        // Three nearest neighbours, distances requested.
        let (indices, distances) = tree.query(&[5.1, 5.9], 3, true).unwrap();
        assert_eq!(indices.len(), 3);
        assert!(indices.contains(&2));
        assert!(distances.is_some());
        assert_eq!(distances.unwrap().len(), 3);

        // Distances not requested.
        let (indices, distances) = tree.query(&[5.1, 5.9], 1, false).unwrap();
        assert_eq!(indices, vec![2]);
        assert!(distances.is_none());
    }

    #[test]
    fn test_ball_tree_radius_search() {
        let data = five_points();
        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        // Only the query sample itself lies within radius 1.
        let (indices, _distances) = tree.query_radius(&[5.0, 6.0], 1.0, true).unwrap();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 2);

        // A wider radius reaches the neighbouring samples too.
        let (indices, _distances) = tree.query_radius(&[5.0, 6.0], 3.0, true).unwrap();
        assert!(indices.len() > 1);

        let (indices, distances) = tree.query_radius(&[5.0, 6.0], 3.0, false).unwrap();
        assert!(indices.len() > 1);
        assert!(distances.is_none());
    }

    #[test]
    fn test_ball_tree_dual_tree_search() {
        let left_data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let right_data = arr2(&[[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]]);

        let left_tree = BallTree::with_euclidean_distance(&left_data.view(), 2).unwrap();
        let right_tree = BallTree::with_euclidean_distance(&right_data.view(), 2).unwrap();

        // Each sample has exactly one partner at distance 1.
        let pairs = left_tree.query_radius_tree(&right_tree, 1.5).unwrap();
        assert_eq!(pairs.len(), 3);

        // A huge radius pairs everything with everything.
        let pairs = left_tree.query_radius_tree(&right_tree, 10.0).unwrap();
        assert_eq!(pairs.len(), 9);
    }

    #[test]
    fn test_ball_tree_empty_input() {
        let data = arr2(&[[0.0f64; 2]; 0]);
        let outcome = BallTree::with_euclidean_distance(&data.view(), 2);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_ball_tree_dimension_mismatch() {
        let data = three_points();
        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        assert!(tree.query(&[1.0], 1, false).is_err());
        assert!(tree.query_radius(&[1.0, 2.0, 3.0], 1.0, false).is_err());
    }

    #[test]
    fn test_ball_tree_invalid_parameters() {
        let data = three_points();
        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        // k larger than the number of samples.
        assert!(tree.query(&[1.0, 2.0], 4, false).is_err());

        // Negative radius.
        assert!(tree.query_radius(&[1.0, 2.0], -1.0, false).is_err());
    }
}