scirs2_spatial/
balltree.rs

1//! Ball tree for efficient nearest neighbor searches
2//!
3//! Ball trees are spatial data structures that organize points in a metric space into a tree structure.
4//! Each node represents a hypersphere (ball) that contains a subset of the points.
5//! This implementation shares similarities with KD-tree, but can be more efficient for high-dimensional data
6//! or when using general distance metrics beyond Euclidean.
7//!
8//! ## Features
9//!
10//! * Fast construction of ball trees with customizable leaf size
11//! * Nearest neighbor queries with configurable k
12//! * Range queries to find all points within a distance
13//! * Support for all distance metrics defined in the distance module
14//! * Suitable for high-dimensional data where KD-trees become less efficient
15//!
16//! ## References
17//!
18//! * Omohundro, S.M. (1989) "Five Balltree Construction Algorithms"
19//! * Liu, T. et al. (2006) "An Investigation of Practical Approximate Nearest Neighbor Algorithms"
20//! * scikit-learn ball tree implementation
21
22use crate::distance::{Distance, EuclideanDistance};
23use crate::error::{SpatialError, SpatialResult};
24use ndarray::{Array1, Array2, ArrayView2};
25use num_traits::Float;
26use std::cmp::Ordering;
27use std::marker::PhantomData;
28
29/// A node in the ball tree
30#[derive(Clone, Debug)]
31struct BallTreeNode<T: Float> {
32    /// Index of the start of the points contained in this node
33    start_idx: usize,
34
35    /// Index of the end of the points contained in this node
36    end_idx: usize,
37
38    /// Centroid of the points in this node (center of the ball)
39    centroid: Vec<T>,
40
41    /// Radius of the ball that contains all points in this node
42    radius: T,
43
44    /// Index of the left child node
45    left_child: Option<usize>,
46
47    /// Index of the right child node
48    right_child: Option<usize>,
49}
50
51/// Ball tree for efficient nearest neighbor searches
52///
53/// The ball tree partitions data into a set of nested hyperspheres (balls), which allows
54/// for efficient nearest neighbor searches, especially in high-dimensional spaces.
55/// Each node in the tree represents a ball containing a subset of the points.
56///
57/// # Type Parameters
58///
59/// * `T`: Floating point type (f32 or f64)
60/// * `D`: Distance metric that implements the [`Distance`] trait
61#[derive(Clone, Debug)]
62pub struct BallTree<T: Float + Send + Sync, D: Distance<T>> {
63    /// Points stored in the ball tree
64    data: Array2<T>,
65
66    /// Indices of points in the original array, reordered during tree construction
67    indices: Array1<usize>,
68
69    /// Nodes in the ball tree
70    nodes: Vec<BallTreeNode<T>>,
71
72    /// Number of data points
73    n_samples: usize,
74
75    /// Dimension of data points
76    n_features: usize,
77
78    /// Maximum number of points in leaf nodes
79    leaf_size: usize,
80
81    /// Distance metric to use
82    distance: D,
83
84    /// Phantom data for the float type
85    _phantom: PhantomData<T>,
86}
87
88impl<T: Float + Send + Sync + 'static, D: Distance<T> + Send + Sync + 'static> BallTree<T, D> {
89    /// Create a new ball tree from the given data points
90    ///
91    /// # Arguments
92    ///
93    /// * `data` - 2D array of data points (n_samples x n_features)
94    /// * `leaf_size` - Maximum number of points in leaf nodes
95    /// * `distance` - Distance metric to use
96    ///
97    /// # Returns
98    ///
99    /// * `SpatialResult<BallTree<T, D>>` - A new ball tree
100    pub fn new(
101        data: &ArrayView2<T>,
102        leaf_size: usize,
103        distance: D,
104    ) -> SpatialResult<BallTree<T, D>> {
105        let n_samples = data.nrows();
106        let n_features = data.ncols();
107
108        if n_samples == 0 {
109            return Err(SpatialError::ValueError(
110                "Input data array is empty".to_string(),
111            ));
112        }
113
114        // Clone the data array and create an array of indices
115        let data = data.to_owned();
116        let indices = Array1::from_iter(0..n_samples);
117
118        // Initialize empty nodes vector (will be filled during build)
119        let nodes = Vec::new();
120
121        let mut ball_tree = BallTree {
122            data,
123            indices,
124            nodes,
125            n_samples,
126            n_features,
127            leaf_size,
128            distance,
129            _phantom: PhantomData,
130        };
131
132        // Build the tree
133        ball_tree.build_tree()?;
134
135        Ok(ball_tree)
136    }
137
138    /// Build the ball tree recursively
139    ///
140    /// This initializes the tree structure and builds the nodes.
141    fn build_tree(&mut self) -> SpatialResult<()> {
142        if self.n_samples == 0 {
143            return Ok(());
144        }
145
146        // Reserve space for the nodes (maximum nodes = 2*n_samples - 1)
147        self.nodes = Vec::with_capacity(2 * self.n_samples);
148
149        // Build the tree recursively
150        self.build_subtree(0, self.n_samples)?;
151
152        Ok(())
153    }
154
155    /// Build a subtree recursively
156    ///
157    /// # Arguments
158    ///
159    /// * `start_idx` - Start index of points for this subtree
160    /// * `end_idx` - End index of points for this subtree
161    ///
162    /// # Returns
163    ///
164    /// * `SpatialResult<usize>` - Index of the root node of the subtree
165    fn build_subtree(&mut self, start_idx: usize, end_idx: usize) -> SpatialResult<usize> {
166        let n_points = end_idx - start_idx;
167
168        // Calculate centroid of points in this node
169        let mut centroid = vec![T::zero(); self.n_features];
170        for i in start_idx..end_idx {
171            let point_idx = self.indices[i];
172            let point = self.data.row(point_idx);
173
174            for (j, &val) in point.iter().take(self.n_features).enumerate() {
175                centroid[j] = centroid[j] + val;
176            }
177        }
178
179        for val in centroid.iter_mut().take(self.n_features) {
180            *val = *val / T::from(n_points).unwrap();
181        }
182
183        // Calculate radius (maximum distance from centroid to any point)
184        let mut radius = T::zero();
185        for i in start_idx..end_idx {
186            let point_idx = self.indices[i];
187            let point = self.data.row(point_idx);
188
189            let dist = self.distance.distance(&centroid, point.as_slice().unwrap());
190
191            if dist > radius {
192                radius = dist;
193            }
194        }
195
196        // Create node
197        let node_idx = self.nodes.len();
198        let node = BallTreeNode {
199            start_idx,
200            end_idx,
201            centroid,
202            radius,
203            left_child: None,
204            right_child: None,
205        };
206
207        self.nodes.push(node);
208
209        // If this is a leaf node (n_points <= leaf_size), we're done
210        if n_points <= self.leaf_size {
211            return Ok(node_idx);
212        }
213
214        // Otherwise, split the points and recursively build subtrees
215        // We'll split along the direction of maximum variance
216        self.split_points(node_idx, start_idx, end_idx)?;
217
218        // Recursively build left and right subtrees
219        let mid_idx = start_idx + n_points / 2;
220
221        let left_idx = self.build_subtree(start_idx, mid_idx)?;
222        let right_idx = self.build_subtree(mid_idx, end_idx)?;
223
224        // Update node with child indices
225        self.nodes[node_idx].left_child = Some(left_idx);
226        self.nodes[node_idx].right_child = Some(right_idx);
227
228        Ok(node_idx)
229    }
230
231    /// Split the points in a node into two groups
232    ///
233    /// This method partitions the points in a node into two groups,
234    /// attempting to create a balanced split.
235    ///
236    /// # Arguments
237    ///
238    /// * `node_idx` - Index of the node to split
239    /// * `start_idx` - Start index of points in the node
240    /// * `end_idx` - End index of points in the node
241    ///
242    /// # Returns
243    ///
244    /// * `SpatialResult<()>` - Result of the split operation
245    fn split_points(
246        &mut self,
247        node_idx: usize,
248        start_idx: usize,
249        end_idx: usize,
250    ) -> SpatialResult<()> {
251        // Find the dimension with the largest variance
252        let node = &self.nodes[node_idx];
253        let centroid = &node.centroid;
254
255        // Calculate distances from centroid to all points
256        let mut distances: Vec<(usize, T)> = (start_idx..end_idx)
257            .map(|i| {
258                let point_idx = self.indices[i];
259                let point = self.data.row(point_idx);
260                let dist = self.distance.distance(centroid, point.as_slice().unwrap());
261                (i, dist)
262            })
263            .collect();
264
265        // Sort points by distance from centroid
266        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
267
268        // Reorder indices array based on sorted distances
269        // Midpoint is calculated but used implicitly when we reorder the indices
270        let _mid_idx = start_idx + (end_idx - start_idx) / 2;
271        let mut new_indices = Vec::with_capacity(end_idx - start_idx);
272
273        for (i, _) in distances {
274            new_indices.push(self.indices[i]);
275        }
276
277        for (i, idx) in new_indices.into_iter().enumerate() {
278            self.indices[start_idx + i] = idx;
279        }
280
281        Ok(())
282    }
283
284    /// Query the k nearest neighbors to the given point
285    ///
286    /// # Arguments
287    ///
288    /// * `point` - Query point
289    /// * `k` - Number of neighbors to find
290    /// * `return_distance` - Whether to return distances
291    ///
292    /// # Returns
293    ///
294    /// * `SpatialResult<(Vec<usize>, Option<Vec<T>>)>` - Indices and optionally distances of the k nearest neighbors
295    pub fn query(
296        &self,
297        point: &[T],
298        k: usize,
299        return_distance: bool,
300    ) -> SpatialResult<(Vec<usize>, Option<Vec<T>>)> {
301        if point.len() != self.n_features {
302            return Err(SpatialError::DimensionError(format!(
303                "Query point has {} dimensions, but tree has {} dimensions",
304                point.len(),
305                self.n_features
306            )));
307        }
308
309        if k > self.n_samples {
310            return Err(SpatialError::ValueError(format!(
311                "k ({}) cannot be greater than the number of samples ({})",
312                k, self.n_samples
313            )));
314        }
315
316        // Store up to k nearest neighbors and their distances
317        let mut nearest_neighbors = Vec::<(T, usize)>::with_capacity(k);
318        let mut max_dist = T::infinity();
319
320        // Perform the recursive search
321        self.query_recursive(0, point, k, &mut nearest_neighbors, &mut max_dist);
322
323        // Sort by distance
324        nearest_neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
325
326        // Extract indices and distances
327        let (distances, indices): (Vec<_>, Vec<_>) = nearest_neighbors.into_iter().unzip();
328
329        // Return only distances if requested
330        let distances_opt = if return_distance {
331            Some(distances)
332        } else {
333            None
334        };
335
336        Ok((indices, distances_opt))
337    }
338
339    /// Recursively search for k nearest neighbors
340    ///
341    /// # Arguments
342    ///
343    /// * `node_idx` - Index of the current node
344    /// * `point` - Query point
345    /// * `k` - Number of neighbors to find
346    /// * `nearest` - Vector of (distance, index) pairs for nearest neighbors
347    /// * `max_dist` - Maximum distance to consider
348    fn query_recursive(
349        &self,
350        node_idx: usize,
351        point: &[T],
352        k: usize,
353        nearest: &mut Vec<(T, usize)>,
354        max_dist: &mut T,
355    ) {
356        let node = &self.nodes[node_idx];
357
358        // If this node is further than max_dist, skip it
359        let dist_to_centroid = self.distance.distance(point, &node.centroid);
360        if dist_to_centroid > node.radius + *max_dist {
361            return;
362        }
363
364        // If this is a leaf node, check all points
365        if node.left_child.is_none() {
366            for i in node.start_idx..node.end_idx {
367                let idx = self.indices[i];
368                let dist = self
369                    .distance
370                    .distance(point, self.data.row(idx).as_slice().unwrap());
371
372                if dist < *max_dist || nearest.len() < k {
373                    // Add this point to nearest neighbors
374                    nearest.push((dist, idx));
375
376                    // If we have more than k points, remove the furthest
377                    if nearest.len() > k {
378                        // Find the index of the point with maximum distance
379                        let max_idx = nearest
380                            .iter()
381                            .enumerate()
382                            .max_by(|(_, a), (_, b)| {
383                                a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal)
384                            })
385                            .map(|(idx, _)| idx)
386                            .unwrap();
387
388                        // Remove that point
389                        nearest.swap_remove(max_idx);
390
391                        // Update max_dist to the new maximum distance
392                        *max_dist = nearest
393                            .iter()
394                            .map(|(dist, _)| *dist)
395                            .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
396                            .unwrap_or(T::infinity());
397                    }
398                }
399            }
400            return;
401        }
402
403        // Otherwise, recursively search child nodes
404        // Determine which child to search first (closest to the query point)
405        let left_idx = node.left_child.unwrap();
406        let right_idx = node.right_child.unwrap();
407
408        let left_node = &self.nodes[left_idx];
409        let right_node = &self.nodes[right_idx];
410
411        let dist_left = self.distance.distance(point, &left_node.centroid);
412        let dist_right = self.distance.distance(point, &right_node.centroid);
413
414        // Search the closest child first
415        if dist_left <= dist_right {
416            self.query_recursive(left_idx, point, k, nearest, max_dist);
417            self.query_recursive(right_idx, point, k, nearest, max_dist);
418        } else {
419            self.query_recursive(right_idx, point, k, nearest, max_dist);
420            self.query_recursive(left_idx, point, k, nearest, max_dist);
421        }
422    }
423
424    /// Find all points within a given radius of the query point
425    ///
426    /// # Arguments
427    ///
428    /// * `point` - Query point
429    /// * `radius` - Radius to search within
430    /// * `return_distance` - Whether to return distances
431    ///
432    /// # Returns
433    ///
434    /// * `SpatialResult<(Vec<usize>, Option<Vec<T>>)>` - Indices and optionally distances of points within radius
435    pub fn query_radius(
436        &self,
437        point: &[T],
438        radius: T,
439        return_distance: bool,
440    ) -> SpatialResult<(Vec<usize>, Option<Vec<T>>)> {
441        if point.len() != self.n_features {
442            return Err(SpatialError::DimensionError(format!(
443                "Query point has {} dimensions, but tree has {} dimensions",
444                point.len(),
445                self.n_features
446            )));
447        }
448
449        if radius < T::zero() {
450            return Err(SpatialError::ValueError(
451                "Radius must be non-negative".to_string(),
452            ));
453        }
454
455        // Collect points within radius
456        let mut result_indices = Vec::new();
457        let mut result_distances = Vec::new();
458
459        // Search the tree recursively
460        self.query_radius_recursive(0, point, radius, &mut result_indices, &mut result_distances);
461
462        // Sort by distance if needed
463        if !result_indices.is_empty() {
464            let mut idx_dist: Vec<(usize, T)> =
465                result_indices.into_iter().zip(result_distances).collect();
466
467            idx_dist.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
468
469            let (indices, distances): (Vec<_>, Vec<_>) = idx_dist.into_iter().unzip();
470
471            let distances_opt = if return_distance {
472                Some(distances)
473            } else {
474                None
475            };
476
477            Ok((indices, distances_opt))
478        } else {
479            Ok((
480                Vec::new(),
481                if return_distance {
482                    Some(Vec::new())
483                } else {
484                    None
485                },
486            ))
487        }
488    }
489
490    /// Recursively find all points within radius
491    fn query_radius_recursive(
492        &self,
493        node_idx: usize,
494        point: &[T],
495        radius: T,
496        indices: &mut Vec<usize>,
497        distances: &mut Vec<T>,
498    ) {
499        let node = &self.nodes[node_idx];
500
501        // If this node is too far, skip it
502        let dist_to_centroid = self.distance.distance(point, &node.centroid);
503        if dist_to_centroid > node.radius + radius {
504            return;
505        }
506
507        // If this is a leaf node, check all points
508        if node.left_child.is_none() {
509            for i in node.start_idx..node.end_idx {
510                let idx = self.indices[i];
511                let dist = self
512                    .distance
513                    .distance(point, self.data.row(idx).as_slice().unwrap());
514
515                if dist <= radius {
516                    indices.push(idx);
517                    distances.push(dist);
518                }
519            }
520            return;
521        }
522
523        // Otherwise, recursively search child nodes
524        let left_idx = node.left_child.unwrap();
525        let right_idx = node.right_child.unwrap();
526
527        self.query_radius_recursive(left_idx, point, radius, indices, distances);
528        self.query_radius_recursive(right_idx, point, radius, indices, distances);
529    }
530
531    /// Find all pairs of points from two trees that are within a given radius
532    ///
533    /// # Arguments
534    ///
535    /// * `other` - Another ball tree
536    /// * `radius` - Radius to search within
537    ///
538    /// # Returns
539    ///
540    /// * `SpatialResult<Vec<(usize, usize)>>` - Pairs of indices (self_idx, other_idx) within radius
541    pub fn query_radius_tree(
542        &self,
543        other: &BallTree<T, D>,
544        radius: T,
545    ) -> SpatialResult<Vec<(usize, usize)>> {
546        if self.n_features != other.n_features {
547            return Err(SpatialError::DimensionError(format!(
548                "Trees have different dimensions: {} and {}",
549                self.n_features, other.n_features
550            )));
551        }
552
553        if radius < T::zero() {
554            return Err(SpatialError::ValueError(
555                "Radius must be non-negative".to_string(),
556            ));
557        }
558
559        let mut pairs = Vec::new();
560
561        self.query_radius_tree_recursive(0, other, 0, radius, &mut pairs);
562
563        Ok(pairs)
564    }
565
566    /// Recursively find all pairs of points from two trees that are within radius
567    fn query_radius_tree_recursive(
568        &self,
569        self_node_idx: usize,
570        other: &BallTree<T, D>,
571        other_node_idx: usize,
572        radius: T,
573        pairs: &mut Vec<(usize, usize)>,
574    ) {
575        let self_node = &self.nodes[self_node_idx];
576        let other_node = &other.nodes[other_node_idx];
577
578        // Calculate minimum distance between nodes
579        let dist_between_centroids = self
580            .distance
581            .distance(&self_node.centroid, &other_node.centroid);
582
583        // If the minimum distance between nodes is greater than radius, we can skip
584        if dist_between_centroids > self_node.radius + other_node.radius + radius {
585            return;
586        }
587
588        // If both are leaf nodes, check all point pairs
589        if self_node.left_child.is_none() && other_node.left_child.is_none() {
590            for i in self_node.start_idx..self_node.end_idx {
591                let self_idx = self.indices[i];
592                let self_point = self.data.row(self_idx);
593
594                for j in other_node.start_idx..other_node.end_idx {
595                    let other_idx = other.indices[j];
596                    let other_point = other.data.row(other_idx);
597
598                    let dist = self.distance.distance(
599                        self_point.as_slice().unwrap(),
600                        other_point.as_slice().unwrap(),
601                    );
602
603                    if dist <= radius {
604                        pairs.push((self_idx, other_idx));
605                    }
606                }
607            }
608            return;
609        }
610
611        // Otherwise, recursively search child nodes
612        // Split the node with more points
613        if self_node.left_child.is_some()
614            && (other_node.left_child.is_none()
615                || (self_node.end_idx - self_node.start_idx)
616                    > (other_node.end_idx - other_node.start_idx))
617        {
618            let left_idx = self_node.left_child.unwrap();
619            let right_idx = self_node.right_child.unwrap();
620
621            self.query_radius_tree_recursive(left_idx, other, other_node_idx, radius, pairs);
622            self.query_radius_tree_recursive(right_idx, other, other_node_idx, radius, pairs);
623        } else if other_node.left_child.is_some() {
624            let left_idx = other_node.left_child.unwrap();
625            let right_idx = other_node.right_child.unwrap();
626
627            self.query_radius_tree_recursive(self_node_idx, other, left_idx, radius, pairs);
628            self.query_radius_tree_recursive(self_node_idx, other, right_idx, radius, pairs);
629        }
630    }
631
632    /// Get the original data points
633    pub fn get_data(&self) -> &Array2<T> {
634        &self.data
635    }
636
637    /// Get the number of data points
638    pub fn get_n_samples(&self) -> usize {
639        self.n_samples
640    }
641
642    /// Get the dimension of data points
643    pub fn get_n_features(&self) -> usize {
644        self.n_features
645    }
646
647    /// Get the leaf size
648    pub fn get_leaf_size(&self) -> usize {
649        self.leaf_size
650    }
651}
652
653// Implement constructor with default distance metric (Euclidean)
654impl<T: Float + Send + Sync + 'static> BallTree<T, EuclideanDistance<T>> {
655    /// Create a new ball tree with default Euclidean distance metric
656    ///
657    /// # Arguments
658    ///
659    /// * `data` - 2D array of data points (n_samples x n_features)
660    /// * `leaf_size` - Maximum number of points in leaf nodes
661    ///
662    /// # Returns
663    ///
664    /// * `SpatialResult<BallTree<T, EuclideanDistance<T>>>` - A new ball tree
665    pub fn with_euclidean_distance(
666        data: &ArrayView2<T>,
667        leaf_size: usize,
668    ) -> SpatialResult<BallTree<T, EuclideanDistance<T>>> {
669        BallTree::new(data, leaf_size, EuclideanDistance::new())
670    }
671}
672
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::euclidean;
    use approx::assert_relative_eq;
    use ndarray::{arr2, Array2};

    /// Five 2-D points on a line, used by the single-tree tests.
    fn five_points() -> Array2<f64> {
        arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
    }

    /// Three 2-D points on a line, used by the smaller tests.
    fn three_points() -> Array2<f64> {
        arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    }

    #[test]
    fn test_ball_tree_construction() {
        let points = five_points();
        let tree = BallTree::with_euclidean_distance(&points.view(), 2).unwrap();

        // The accessors must echo the construction parameters.
        assert_eq!(tree.get_n_samples(), 5);
        assert_eq!(tree.get_n_features(), 2);
        assert_eq!(tree.get_leaf_size(), 2);
    }

    #[test]
    fn test_ball_tree_nearest_neighbor() {
        let points = five_points();
        let tree = BallTree::with_euclidean_distance(&points.view(), 2).unwrap();
        let probe = [5.1, 5.9];

        // 1-NN: the closest stored point is [5.0, 6.0] at index 2.
        let (indices, distances) = tree.query(&probe, 1, true).unwrap();
        assert_eq!(indices, vec![2]);
        assert!(distances.is_some());
        assert_relative_eq!(distances.unwrap()[0], euclidean(&probe, &[5.0, 6.0]));

        // 3-NN: three results, including the closest point.
        let (indices, distances) = tree.query(&probe, 3, true).unwrap();
        assert_eq!(indices.len(), 3);
        assert!(indices.contains(&2));
        assert!(distances.is_some());
        assert_eq!(distances.unwrap().len(), 3);

        // Distances are omitted when not requested.
        let (indices, distances) = tree.query(&probe, 1, false).unwrap();
        assert_eq!(indices, vec![2]);
        assert!(distances.is_none());
    }

    #[test]
    fn test_ball_tree_radius_search() {
        let points = five_points();
        let tree = BallTree::with_euclidean_distance(&points.view(), 2).unwrap();
        let center = [5.0, 6.0];

        // Radius 1.0 reaches only the query point itself (index 2).
        let (indices, _distances) = tree.query_radius(&center, 1.0, true).unwrap();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 2);

        // Radius 3.0 also reaches neighboring points.
        let (indices, _distances) = tree.query_radius(&center, 3.0, true).unwrap();
        assert!(indices.len() > 1);

        // Distances are omitted when not requested.
        let (indices, distances) = tree.query_radius(&center, 3.0, false).unwrap();
        assert!(indices.len() > 1);
        assert!(distances.is_none());
    }

    #[test]
    fn test_ball_tree_dual_tree_search() {
        let first = three_points();
        let second = arr2(&[[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]]);

        let tree1 = BallTree::with_euclidean_distance(&first.view(), 2).unwrap();
        let tree2 = BallTree::with_euclidean_distance(&second.view(), 2).unwrap();

        // Radius 1.5: each point pairs only with its counterpart in the other set.
        let pairs = tree1.query_radius_tree(&tree2, 1.5).unwrap();
        assert_eq!(pairs.len(), 3);

        // Radius 10.0: every one of the 3 x 3 pairs qualifies.
        let pairs = tree1.query_radius_tree(&tree2, 10.0).unwrap();
        assert_eq!(pairs.len(), 9);
    }

    #[test]
    fn test_ball_tree_empty_input() {
        // An array with zero rows must be rejected.
        let empty = arr2(&[[0.0f64; 2]; 0]);
        let outcome = BallTree::with_euclidean_distance(&empty.view(), 2);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_ball_tree_dimension_mismatch() {
        let points = three_points();
        let tree = BallTree::with_euclidean_distance(&points.view(), 2).unwrap();

        // Too few coordinates.
        let outcome = tree.query(&[1.0], 1, false);
        assert!(outcome.is_err());

        // Too many coordinates.
        let outcome = tree.query_radius(&[1.0, 2.0, 3.0], 1.0, false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_ball_tree_invalid_parameters() {
        let points = three_points();
        let tree = BallTree::with_euclidean_distance(&points.view(), 2).unwrap();

        // k larger than the number of samples.
        let outcome = tree.query(&[1.0, 2.0], 4, false);
        assert!(outcome.is_err());

        // Negative radius.
        let outcome = tree.query_radius(&[1.0, 2.0], -1.0, false);
        assert!(outcome.is_err());
    }
}