scirs2_spatial/
balltree.rs

1//! Ball tree for efficient nearest neighbor searches
2//!
3//! Ball trees are spatial data structures that organize points in a metric space into a tree structure.
4//! Each node represents a hypersphere (ball) that contains a subset of the points.
5//! This implementation shares similarities with KD-tree, but can be more efficient for high-dimensional data
6//! or when using general distance metrics beyond Euclidean.
7//!
8//! ## Features
9//!
10//! * Fast construction of ball trees with customizable leaf size
11//! * Nearest neighbor queries with configurable k
12//! * Range queries to find all points within a distance
13//! * Support for all distance metrics defined in the distance module
14//! * Suitable for high-dimensional data where KD-trees become less efficient
15//!
16//! ## References
17//!
18//! * Omohundro, S.M. (1989) "Five Balltree Construction Algorithms"
19//! * Liu, T. et al. (2006) "An Investigation of Practical Approximate Nearest Neighbor Algorithms"
20//! * scikit-learn ball tree implementation
21
22use crate::distance::{Distance, EuclideanDistance};
23use crate::error::{SpatialError, SpatialResult};
24use crate::safe_conversions::*;
25use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
26use scirs2_core::numeric::Float;
27use std::cmp::Ordering;
28use std::marker::PhantomData;
29
30/// A node in the ball tree
31#[derive(Clone, Debug)]
32struct BallTreeNode<T: Float> {
33    /// Index of the start of the points contained in this node
34    start_idx: usize,
35
36    /// Index of the end of the points contained in this node
37    endidx: usize,
38
39    /// Centroid of the points in this node (center of the ball)
40    centroid: Vec<T>,
41
42    /// Radius of the ball that contains all points in this node
43    radius: T,
44
45    /// Index of the left child node
46    left_child: Option<usize>,
47
48    /// Index of the right child node
49    right_child: Option<usize>,
50}
51
52/// Ball tree for efficient nearest neighbor searches
53///
54/// The ball tree partitions data into a set of nested hyperspheres (balls), which allows
55/// for efficient nearest neighbor searches, especially in high-dimensional spaces.
56/// Each node in the tree represents a ball containing a subset of the points.
57///
58/// # Type Parameters
59///
60/// * `T`: Floating point type (f32 or f64)
61/// * `D`: Distance metric that implements the [`Distance`] trait
62#[derive(Clone, Debug)]
63pub struct BallTree<T: Float + Send + Sync, D: Distance<T>> {
64    /// Points stored in the ball tree
65    data: Array2<T>,
66
67    /// Indices of points in the original array, reordered during tree construction
68    indices: Array1<usize>,
69
70    /// Nodes in the ball tree
71    nodes: Vec<BallTreeNode<T>>,
72
73    /// Number of data points
74    n_samples: usize,
75
76    /// Dimension of data points
77    n_features: usize,
78
79    /// Maximum number of points in leaf nodes
80    leaf_size: usize,
81
82    /// Distance metric to use
83    distance: D,
84
85    /// Phantom data for the float type
86    _phantom: PhantomData<T>,
87}
88
89impl<T: Float + Send + Sync + 'static, D: Distance<T> + Send + Sync + 'static> BallTree<T, D> {
90    /// Create a new ball tree from the given data points
91    ///
92    /// # Arguments
93    ///
94    /// * `data` - 2D array of data points (n_samples x n_features)
95    /// * `leaf_size` - Maximum number of points in leaf nodes
96    /// * `distance` - Distance metric to use
97    ///
98    /// # Returns
99    ///
100    /// * `SpatialResult<BallTree<T, D>>` - A new ball tree
101    pub fn new(
102        data: &ArrayView2<T>,
103        leaf_size: usize,
104        distance: D,
105    ) -> SpatialResult<BallTree<T, D>> {
106        let n_samples = data.nrows();
107        let n_features = data.ncols();
108
109        if n_samples == 0 {
110            return Err(SpatialError::ValueError(
111                "Input data array is empty".to_string(),
112            ));
113        }
114
115        // Clone the data array and create an array of indices
116        // Ensure data is in standard memory layout for as_slice to work
117        let data = if data.is_standard_layout() {
118            data.to_owned()
119        } else {
120            data.as_standard_layout().to_owned()
121        };
122        let indices = Array1::from_iter(0..n_samples);
123
124        // Initialize empty nodes vector (will be filled during build)
125        let nodes = Vec::new();
126
127        let mut ball_tree = BallTree {
128            data,
129            indices,
130            nodes,
131            n_samples,
132            n_features,
133            leaf_size,
134            distance,
135            _phantom: PhantomData,
136        };
137
138        // Build the tree
139        ball_tree.build_tree()?;
140
141        Ok(ball_tree)
142    }
143
144    /// Build the ball tree recursively
145    ///
146    /// This initializes the tree structure and builds the nodes.
147    fn build_tree(&mut self) -> SpatialResult<()> {
148        if self.n_samples == 0 {
149            return Ok(());
150        }
151
152        // Reserve space for the nodes (maximum nodes = 2*n_samples - 1)
153        self.nodes = Vec::with_capacity(2 * self.n_samples);
154
155        // Build the tree recursively
156        self.build_subtree(0, self.n_samples)?;
157
158        Ok(())
159    }
160
161    /// Build a subtree recursively
162    ///
163    /// # Arguments
164    ///
165    /// * `start_idx` - Start index of points for this subtree
166    /// * `endidx` - End index of points for this subtree
167    ///
168    /// # Returns
169    ///
170    /// * `SpatialResult<usize>` - Index of the root node of the subtree
171    fn build_subtree(&mut self, start_idx: usize, endidx: usize) -> SpatialResult<usize> {
172        let n_points = endidx - start_idx;
173
174        // Calculate centroid of points in this node
175        let mut centroid = vec![T::zero(); self.n_features];
176        for i in start_idx..endidx {
177            let point_idx = self.indices[i];
178            let point = self.data.row(point_idx);
179
180            for (j, &val) in point.iter().take(self.n_features).enumerate() {
181                centroid[j] = centroid[j] + val;
182            }
183        }
184
185        for val in centroid.iter_mut().take(self.n_features) {
186            *val = *val / safe_from_usize::<T>(n_points, "balltree centroid calculation")?;
187        }
188
189        // Calculate radius (maximum distance from centroid to any point)
190        let mut radius = T::zero();
191        for i in start_idx..endidx {
192            let point_idx = self.indices[i];
193            let point = self.data.row(point_idx);
194
195            let dist = self.distance.distance(&centroid, point.to_vec().as_slice());
196
197            if dist > radius {
198                radius = dist;
199            }
200        }
201
202        // Create node
203        let node_idx = self.nodes.len();
204        let node = BallTreeNode {
205            start_idx,
206            endidx,
207            centroid,
208            radius,
209            left_child: None,
210            right_child: None,
211        };
212
213        self.nodes.push(node);
214
215        // If this is a leaf node (n_points <= leaf_size), we're done
216        if n_points <= self.leaf_size {
217            return Ok(node_idx);
218        }
219
220        // Otherwise, split the points and recursively build subtrees
221        // We'll split along the direction of maximum variance
222        self.split_points(node_idx, start_idx, endidx)?;
223
224        // Recursively build left and right subtrees
225        let mid_idx = start_idx + n_points / 2;
226
227        let left_idx = self.build_subtree(start_idx, mid_idx)?;
228        let right_idx = self.build_subtree(mid_idx, endidx)?;
229
230        // Update node with child indices
231        self.nodes[node_idx].left_child = Some(left_idx);
232        self.nodes[node_idx].right_child = Some(right_idx);
233
234        Ok(node_idx)
235    }
236
237    /// Split the points in a node into two groups
238    ///
239    /// This method partitions the points in a node into two groups,
240    /// attempting to create a balanced split.
241    ///
242    /// # Arguments
243    ///
244    /// * `node_idx` - Index of the node to split
245    /// * `start_idx` - Start index of points in the node
246    /// * `endidx` - End index of points in the node
247    ///
248    /// # Returns
249    ///
250    /// * `SpatialResult<()>` - Result of the split operation
251    fn split_points(
252        &mut self,
253        node_idx: usize,
254        start_idx: usize,
255        endidx: usize,
256    ) -> SpatialResult<()> {
257        // Find the dimension with the largest variance
258        let node = &self.nodes[node_idx];
259        let centroid = &node.centroid;
260
261        // Calculate distances from centroid to all points
262        let mut distances: Vec<(usize, T)> = (start_idx..endidx)
263            .map(|i| {
264                let point_idx = self.indices[i];
265                let point = self.data.row(point_idx);
266                let dist = self.distance.distance(centroid, point.to_vec().as_slice());
267                (i, dist)
268            })
269            .collect();
270
271        // Sort points by distance from centroid
272        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
273
274        // Reorder indices array based on sorted distances
275        // Midpoint is calculated but used implicitly when we reorder the indices
276        let _mid_idx = start_idx + (endidx - start_idx) / 2;
277        let mut new_indices = Vec::with_capacity(endidx - start_idx);
278
279        for (i_, _) in distances {
280            new_indices.push(self.indices[i_]);
281        }
282
283        for (i, idx) in new_indices.into_iter().enumerate() {
284            self.indices[start_idx + i] = idx;
285        }
286
287        Ok(())
288    }
289
290    /// Query the k nearest neighbors to the given point
291    ///
292    /// # Arguments
293    ///
294    /// * `point` - Query point
295    /// * `k` - Number of neighbors to find
296    /// * `return_distance` - Whether to return distances
297    ///
298    /// # Returns
299    ///
300    /// * `SpatialResult<(Vec<usize>, Option<Vec<T>>)>` - Indices and optionally distances of the k nearest neighbors
301    pub fn query(
302        &self,
303        point: &[T],
304        k: usize,
305        return_distance: bool,
306    ) -> SpatialResult<(Vec<usize>, Option<Vec<T>>)> {
307        if point.len() != self.n_features {
308            return Err(SpatialError::DimensionError(format!(
309                "Query point has {} dimensions, but tree has {} dimensions",
310                point.len(),
311                self.n_features
312            )));
313        }
314
315        if k > self.n_samples {
316            return Err(SpatialError::ValueError(format!(
317                "k ({}) cannot be greater than the number of samples ({})",
318                k, self.n_samples
319            )));
320        }
321
322        // Store up to k nearest neighbors and their distances
323        let mut nearest_neighbors = Vec::<(T, usize)>::with_capacity(k);
324        let mut max_dist = T::infinity();
325
326        // Perform the recursive search
327        self.query_recursive(0, point, k, &mut nearest_neighbors, &mut max_dist);
328
329        // Sort by _distance
330        nearest_neighbors.sort_by(|a, b| {
331            safe_partial_cmp(&a.0, &b.0, "balltree sort results").unwrap_or(Ordering::Equal)
332        });
333
334        // Extract indices and distances
335        let (distances, indices): (Vec<_>, Vec<_>) = nearest_neighbors.into_iter().unzip();
336
337        // Return only distances if requested
338        let distances_opt = if return_distance {
339            Some(distances)
340        } else {
341            None
342        };
343
344        Ok((indices, distances_opt))
345    }
346
347    /// Recursively search for k nearest neighbors
348    ///
349    /// # Arguments
350    ///
351    /// * `node_idx` - Index of the current node
352    /// * `point` - Query point
353    /// * `k` - Number of neighbors to find
354    /// * `nearest` - Vector of (distance, index) pairs for nearest neighbors
355    /// * `max_dist` - Maximum distance to consider
356    fn query_recursive(
357        &self,
358        node_idx: usize,
359        point: &[T],
360        k: usize,
361        nearest: &mut Vec<(T, usize)>,
362        max_dist: &mut T,
363    ) {
364        let node = &self.nodes[node_idx];
365
366        // If this node is further than max_dist, skip it
367        let dist_to_centroid = self.distance.distance(point, &node.centroid);
368        if dist_to_centroid > node.radius + *max_dist {
369            return;
370        }
371
372        // If this is a leaf node, check all points
373        if node.left_child.is_none() {
374            for i in node.start_idx..node.endidx {
375                let idx = self.indices[i];
376                let row_vec = self.data.row(idx).to_vec();
377                let _dist = self.distance.distance(point, row_vec.as_slice());
378
379                if _dist < *max_dist || nearest.len() < k {
380                    // Add this point to nearest neighbors
381                    nearest.push((_dist, idx));
382
383                    // If we have more than k points, remove the furthest
384                    if nearest.len() > k {
385                        // Find the index of the point with maximum distance
386                        let max_idx = nearest
387                            .iter()
388                            .enumerate()
389                            .max_by(|(_, a), (_, b)| {
390                                safe_partial_cmp(&a.0, &b.0, "balltree max distance")
391                                    .unwrap_or(Ordering::Equal)
392                            })
393                            .map(|(idx_, _)| idx_)
394                            .unwrap_or(0);
395
396                        // Remove that point
397                        nearest.swap_remove(max_idx);
398
399                        // Update max_dist to the new maximum distance
400                        *max_dist = nearest
401                            .iter()
402                            .map(|(dist_, _)| *dist_)
403                            .max_by(|a, b| {
404                                safe_partial_cmp(a, b, "balltree update max_dist")
405                                    .unwrap_or(Ordering::Equal)
406                            })
407                            .unwrap_or(T::infinity());
408                    }
409                }
410            }
411            return;
412        }
413
414        // Otherwise, recursively search child nodes
415        // Determine which child to search first (closest to the query point)
416        // Get child indices - we know they exist because this is not a leaf node
417        let left_idx = match node.left_child {
418            Some(idx) => idx,
419            None => return, // Should not happen if tree is properly built
420        };
421        let right_idx = match node.right_child {
422            Some(idx) => idx,
423            None => return, // Should not happen if tree is properly built
424        };
425
426        let left_node = &self.nodes[left_idx];
427        let right_node = &self.nodes[right_idx];
428
429        let dist_left = self.distance.distance(point, &left_node.centroid);
430        let dist_right = self.distance.distance(point, &right_node.centroid);
431
432        // Search the closest child first
433        if dist_left <= dist_right {
434            self.query_recursive(left_idx, point, k, nearest, max_dist);
435            self.query_recursive(right_idx, point, k, nearest, max_dist);
436        } else {
437            self.query_recursive(right_idx, point, k, nearest, max_dist);
438            self.query_recursive(left_idx, point, k, nearest, max_dist);
439        }
440    }
441
442    /// Find all points within a given radius of the query point
443    ///
444    /// # Arguments
445    ///
446    /// * `point` - Query point
447    /// * `radius` - Radius to search within
448    /// * `return_distance` - Whether to return distances
449    ///
450    /// # Returns
451    ///
452    /// * `SpatialResult<(Vec<usize>, Option<Vec<T>>)>` - Indices and optionally distances of points within radius
453    pub fn query_radius(
454        &self,
455        point: &[T],
456        radius: T,
457        return_distance: bool,
458    ) -> SpatialResult<(Vec<usize>, Option<Vec<T>>)> {
459        if point.len() != self.n_features {
460            return Err(SpatialError::DimensionError(format!(
461                "Query point has {} dimensions, but tree has {} dimensions",
462                point.len(),
463                self.n_features
464            )));
465        }
466
467        if radius < T::zero() {
468            return Err(SpatialError::ValueError(
469                "Radius must be non-negative".to_string(),
470            ));
471        }
472
473        // Collect points within radius
474        let mut result_indices = Vec::new();
475        let mut result_distances = Vec::new();
476
477        // Search the tree recursively
478        self.query_radius_recursive(0, point, radius, &mut result_indices, &mut result_distances);
479
480        // Sort by _distance if needed
481        if !result_indices.is_empty() {
482            let mut idx_dist: Vec<(usize, T)> =
483                result_indices.into_iter().zip(result_distances).collect();
484
485            idx_dist.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
486
487            let (indices, distances): (Vec<_>, Vec<_>) = idx_dist.into_iter().unzip();
488
489            let distances_opt = if return_distance {
490                Some(distances)
491            } else {
492                None
493            };
494
495            Ok((indices, distances_opt))
496        } else {
497            Ok((
498                Vec::new(),
499                if return_distance {
500                    Some(Vec::new())
501                } else {
502                    None
503                },
504            ))
505        }
506    }
507
508    /// Recursively find all points within radius
509    fn query_radius_recursive(
510        &self,
511        node_idx: usize,
512        point: &[T],
513        radius: T,
514        indices: &mut Vec<usize>,
515        distances: &mut Vec<T>,
516    ) {
517        let node = &self.nodes[node_idx];
518
519        // If this node is too far, skip it
520        let dist_to_centroid = self.distance.distance(point, &node.centroid);
521        if dist_to_centroid > node.radius + radius {
522            return;
523        }
524
525        // If this is a leaf node, check all points
526        if node.left_child.is_none() {
527            for i in node.start_idx..node.endidx {
528                let _idx = self.indices[i];
529                let row_vec = self.data.row(_idx).to_vec();
530                let dist = self.distance.distance(point, row_vec.as_slice());
531
532                if dist <= radius {
533                    indices.push(_idx);
534                    distances.push(dist);
535                }
536            }
537            return;
538        }
539
540        // Otherwise, recursively search child nodes
541        let left_idx = match node.left_child {
542            Some(idx) => idx,
543            None => return, // Should not happen if tree is properly built
544        };
545        let right_idx = match node.right_child {
546            Some(idx) => idx,
547            None => return, // Should not happen if tree is properly built
548        };
549
550        self.query_radius_recursive(left_idx, point, radius, indices, distances);
551        self.query_radius_recursive(right_idx, point, radius, indices, distances);
552    }
553
554    /// Find all pairs of points from two trees that are within a given radius
555    ///
556    /// # Arguments
557    ///
558    /// * `other` - Another ball tree
559    /// * `radius` - Radius to search within
560    ///
561    /// # Returns
562    ///
563    /// * `SpatialResult<Vec<(usize, usize)>>` - Pairs of indices (self_idx, other_idx) within radius
564    pub fn query_radius_tree(
565        &self,
566        other: &BallTree<T, D>,
567        radius: T,
568    ) -> SpatialResult<Vec<(usize, usize)>> {
569        if self.n_features != other.n_features {
570            return Err(SpatialError::DimensionError(format!(
571                "Trees have different dimensions: {} and {}",
572                self.n_features, other.n_features
573            )));
574        }
575
576        if radius < T::zero() {
577            return Err(SpatialError::ValueError(
578                "Radius must be non-negative".to_string(),
579            ));
580        }
581
582        let mut pairs = Vec::new();
583
584        self.query_radius_tree_recursive(0, other, 0, radius, &mut pairs);
585
586        Ok(pairs)
587    }
588
589    /// Recursively find all pairs of points from two trees that are within radius
590    fn query_radius_tree_recursive(
591        &self,
592        self_node_idx: usize,
593        other: &BallTree<T, D>,
594        other_node_idx: usize,
595        radius: T,
596        pairs: &mut Vec<(usize, usize)>,
597    ) {
598        let self_node = &self.nodes[self_node_idx];
599        let other_node = &other.nodes[other_node_idx];
600
601        // Calculate minimum distance between nodes
602        let dist_between_centroids = self
603            .distance
604            .distance(&self_node.centroid, &other_node.centroid);
605
606        // If the minimum distance between nodes is greater than radius, we can skip
607        if dist_between_centroids > self_node.radius + other_node.radius + radius {
608            return;
609        }
610
611        // If both are leaf nodes, check all point pairs
612        if self_node.left_child.is_none() && other_node.left_child.is_none() {
613            for i in self_node.start_idx..self_node.endidx {
614                let self_idx = self.indices[i];
615                let self_point = self.data.row(self_idx);
616
617                for j in other_node.start_idx..other_node.endidx {
618                    let other_idx = other.indices[j];
619                    let other_point = other.data.row(other_idx);
620
621                    let self_vec = self_point.to_vec();
622                    let other_vec = other_point.to_vec();
623                    let dist = self
624                        .distance
625                        .distance(self_vec.as_slice(), other_vec.as_slice());
626
627                    if dist <= radius {
628                        pairs.push((self_idx, other_idx));
629                    }
630                }
631            }
632            return;
633        }
634
635        // Otherwise, recursively search child nodes
636        // Split the node with more points
637        if self_node.left_child.is_some()
638            && (other_node.left_child.is_none()
639                || (self_node.endidx - self_node.start_idx)
640                    > (other_node.endidx - other_node.start_idx))
641        {
642            let left_idx = match self_node.left_child {
643                Some(idx) => idx,
644                None => return, // Should not happen
645            };
646            let right_idx = match self_node.right_child {
647                Some(idx) => idx,
648                None => return, // Should not happen
649            };
650
651            self.query_radius_tree_recursive(left_idx, other, other_node_idx, radius, pairs);
652            self.query_radius_tree_recursive(right_idx, other, other_node_idx, radius, pairs);
653        } else if other_node.left_child.is_some() {
654            let left_idx = match other_node.left_child {
655                Some(idx) => idx,
656                None => return, // Should not happen
657            };
658            let right_idx = match other_node.right_child {
659                Some(idx) => idx,
660                None => return, // Should not happen
661            };
662
663            self.query_radius_tree_recursive(self_node_idx, other, left_idx, radius, pairs);
664            self.query_radius_tree_recursive(self_node_idx, other, right_idx, radius, pairs);
665        }
666    }
667
668    /// Get the original data points
669    pub fn get_data(&self) -> &Array2<T> {
670        &self.data
671    }
672
673    /// Get the number of data points
674    pub fn get_n_samples(&self) -> usize {
675        self.n_samples
676    }
677
678    /// Get the dimension of data points
679    pub fn get_n_features(&self) -> usize {
680        self.n_features
681    }
682
683    /// Get the leaf size
684    pub fn get_leaf_size(&self) -> usize {
685        self.leaf_size
686    }
687}
688
689// Implement constructor with default distance metric (Euclidean)
690impl<T: Float + Send + Sync + 'static> BallTree<T, EuclideanDistance<T>> {
691    /// Create a new ball tree with default Euclidean distance metric
692    ///
693    /// # Arguments
694    ///
695    /// * `data` - 2D array of data points (n_samples x n_features)
696    /// * `leaf_size` - Maximum number of points in leaf nodes
697    ///
698    /// # Returns
699    ///
700    /// * `SpatialResult<BallTree<T, EuclideanDistance<T>>>` - A new ball tree
701    pub fn with_euclidean_distance(
702        data: &ArrayView2<T>,
703        leaf_size: usize,
704    ) -> SpatialResult<BallTree<T, EuclideanDistance<T>>> {
705        BallTree::new(data, leaf_size, EuclideanDistance::new())
706    }
707}
708
#[cfg(test)]
mod tests {
    use super::BallTree;
    use crate::distance::euclidean;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Deterministic pseudo-random points in `[-10, 10)^d` from a 64-bit LCG.
    ///
    /// This keeps the brute-force cross-check tests reproducible without an RNG crate.
    fn pseudo_random_points(n: usize, d: usize, seed: u64) -> Array2<f64> {
        let mut state = seed;
        let values: Vec<f64> = (0..n * d)
            .map(|_| {
                state = state
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                // Use the top 53 bits as a uniform sample in [0, 1).
                let unit = (state >> 11) as f64 / (1u64 << 53) as f64;
                unit * 20.0 - 10.0
            })
            .collect();
        Array2::from_shape_vec((n, d), values).unwrap()
    }

    /// Euclidean distance from `query` to row `row` of `data`, computed without the tree.
    fn brute_distance(data: &Array2<f64>, row: usize, query: &[f64]) -> f64 {
        euclidean(query, data.row(row).to_vec().as_slice())
    }

    #[test]
    fn test_ball_tree_construction() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        assert_eq!(tree.get_n_samples(), 5);
        assert_eq!(tree.get_n_features(), 2);
        assert_eq!(tree.get_leaf_size(), 2);
    }

    #[test]
    fn test_ball_tree_nearest_neighbor() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        // Test 1-NN
        let (indices, distances) = tree.query(&[5.1, 5.9], 1, true).unwrap();
        assert_eq!(indices, vec![2]); // Index of [5.0, 6.0]
        assert!(distances.is_some());
        assert_relative_eq!(distances.unwrap()[0], euclidean(&[5.1, 5.9], &[5.0, 6.0]));

        // Test 3-NN
        let (indices, distances) = tree.query(&[5.1, 5.9], 3, true).unwrap();
        assert_eq!(indices.len(), 3);
        assert!(indices.contains(&2)); // Should contain index of [5.0, 6.0]
        assert!(distances.is_some());
        assert_eq!(distances.unwrap().len(), 3);

        // Test without distances
        let (indices, distances) = tree.query(&[5.1, 5.9], 1, false).unwrap();
        assert_eq!(indices, vec![2]);
        assert!(distances.is_none());
    }

    #[test]
    fn test_ball_tree_radius_search() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        // Search with small radius
        let (indices, _distances) = tree.query_radius(&[5.0, 6.0], 1.0, true).unwrap();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 2); // Only [5.0, 6.0] itself should be within radius 1.0

        // Search with larger radius
        let (indices, _distances) = tree.query_radius(&[5.0, 6.0], 3.0, true).unwrap();
        assert!(indices.len() > 1); // Should include neighbors

        // Test without distances
        let (indices, distances) = tree.query_radius(&[5.0, 6.0], 3.0, false).unwrap();
        assert!(indices.len() > 1);
        assert!(distances.is_none());
    }

    #[test]
    fn test_ball_tree_dual_tree_search() {
        let data1 = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        let data2 = arr2(&[[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]]);

        let tree1 = BallTree::with_euclidean_distance(&data1.view(), 2).unwrap();
        let tree2 = BallTree::with_euclidean_distance(&data2.view(), 2).unwrap();

        // Test dual tree search with small radius
        let pairs = tree1.query_radius_tree(&tree2, 1.5).unwrap();
        assert_eq!(pairs.len(), 3); // Each point in data1 should be close to its corresponding point in data2

        // Test dual tree search with large radius
        let pairs = tree1.query_radius_tree(&tree2, 10.0).unwrap();
        assert_eq!(pairs.len(), 9); // All pairs should be within radius 10.0
    }

    #[test]
    fn test_ball_tree_empty_input() {
        let data = arr2(&[[0.0f64; 2]; 0]);
        let result = BallTree::with_euclidean_distance(&data.view(), 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_ball_tree_dimension_mismatch() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        // Query with wrong dimension
        let result = tree.query(&[1.0], 1, false);
        assert!(result.is_err());

        let result = tree.query_radius(&[1.0, 2.0, 3.0], 1.0, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_ball_tree_invalid_parameters() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

        let tree = BallTree::with_euclidean_distance(&data.view(), 2).unwrap();

        // Query with k > n_samples
        let result = tree.query(&[1.0, 2.0], 4, false);
        assert!(result.is_err());

        // Query with negative radius
        let result = tree.query_radius(&[1.0, 2.0], -1.0, false);
        assert!(result.is_err());
    }

    /// k-NN results must match an exhaustive scan for several leaf sizes and k.
    #[test]
    fn test_ball_tree_query_matches_brute_force() {
        let data = pseudo_random_points(200, 3, 42);
        let queries = pseudo_random_points(10, 3, 7);

        for &leaf_size in &[1usize, 3, 16] {
            let tree = BallTree::with_euclidean_distance(&data.view(), leaf_size).unwrap();

            for q_row in 0..queries.nrows() {
                let query = queries.row(q_row).to_vec();
                let mut expected: Vec<f64> = (0..data.nrows())
                    .map(|row| brute_distance(&data, row, &query))
                    .collect();
                expected.sort_by(|a, b| a.partial_cmp(b).unwrap());

                for &k in &[1usize, 5, 17] {
                    let (indices, distances) = tree.query(&query, k, true).unwrap();
                    let distances = distances.unwrap();
                    assert_eq!(indices.len(), k);
                    assert_eq!(distances.len(), k);

                    for (rank, (&idx, &dist)) in indices.iter().zip(&distances).enumerate() {
                        // Sorted distances agree with the exhaustive scan...
                        assert_relative_eq!(dist, expected[rank], max_relative = 1e-10);
                        // ...and each reported index really is at the reported distance.
                        assert_relative_eq!(
                            dist,
                            brute_distance(&data, idx, &query),
                            max_relative = 1e-10
                        );
                    }
                }
            }
        }
    }

    /// Radius queries must return exactly the points an exhaustive scan finds, nearest first.
    #[test]
    fn test_ball_tree_radius_matches_brute_force() {
        let data = pseudo_random_points(150, 4, 3);
        let queries = pseudo_random_points(8, 4, 11);
        let tree = BallTree::with_euclidean_distance(&data.view(), 4).unwrap();

        for q_row in 0..queries.nrows() {
            let query = queries.row(q_row).to_vec();

            for &radius in &[0.5, 5.0, 12.0] {
                let expected: Vec<usize> = (0..data.nrows())
                    .filter(|&row| brute_distance(&data, row, &query) <= radius)
                    .collect();

                let (mut found, distances) = tree.query_radius(&query, radius, true).unwrap();
                let distances = distances.unwrap();
                assert_eq!(found.len(), distances.len());
                assert!(distances.windows(2).all(|w| w[0] <= w[1]));

                found.sort_unstable();
                assert_eq!(found, expected);
            }
        }
    }

    /// Dual-tree pair search must report every close pair exactly once.
    #[test]
    fn test_ball_tree_dual_tree_matches_brute_force() {
        let data1 = pseudo_random_points(60, 2, 5);
        let data2 = pseudo_random_points(45, 2, 9);
        let tree1 = BallTree::with_euclidean_distance(&data1.view(), 3).unwrap();
        let tree2 = BallTree::with_euclidean_distance(&data2.view(), 5).unwrap();
        let radius = 2.5;

        let mut expected = Vec::new();
        for i in 0..data1.nrows() {
            let p = data1.row(i).to_vec();
            for j in 0..data2.nrows() {
                if brute_distance(&data2, j, &p) <= radius {
                    expected.push((i, j));
                }
            }
        }

        let mut pairs = tree1.query_radius_tree(&tree2, radius).unwrap();
        pairs.sort_unstable();
        assert_eq!(pairs, expected);
    }
}