scirs2_cluster/
neighbor_search.rs

1//! Efficient neighbor search algorithms for clustering
2//!
3//! This module provides various algorithms for fast nearest neighbor search,
4//! which are crucial for density-based clustering and other algorithms that
5//! require neighborhood computations.
6
7use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::collections::BinaryHeap;
10use std::fmt::Debug;
11
12use crate::error::{ClusteringError, Result};
13
/// Configuration for neighbor search algorithms.
///
/// Only `algorithm` and `leaf_size` are read by [`create_neighbor_searcher`];
/// the LSH and parallelism settings are accepted but not consulted by any code
/// in this module yet.
#[derive(Debug, Clone)]
pub struct NeighborSearchConfig {
    /// Algorithm to use for neighbor search
    pub algorithm: NeighborSearchAlgorithm,
    /// Leaf size for tree-based algorithms (maximum number of points per leaf)
    pub leaf_size: usize,
    /// Number of hash tables for LSH (unused: LSH currently falls back to a KD-tree)
    pub n_hash_tables: usize,
    /// Number of hash functions per table for LSH (unused for the same reason)
    pub n_hash_functions: usize,
    /// Whether to use parallel processing (not read anywhere in this module)
    pub parallel: bool,
}
28
29impl Default for NeighborSearchConfig {
30    fn default() -> Self {
31        Self {
32            algorithm: NeighborSearchAlgorithm::Auto,
33            leaf_size: 30,
34            n_hash_tables: 10,
35            n_hash_functions: 4,
36            parallel: true,
37        }
38    }
39}
40
/// Available neighbor search algorithms
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NeighborSearchAlgorithm {
    /// Automatically choose an algorithm; [`create_neighbor_searcher`] currently
    /// always builds a KD-tree for this variant
    Auto,
    /// Brute force search (exact). Each query computes the distance to every
    /// stored point, i.e. O(n²) distance evaluations for an all-pairs workload
    BruteForce,
    /// KD-Tree search (good for low dimensions)
    KDTree,
    /// Ball Tree search (intended for higher dimensions)
    BallTree,
    /// Locality Sensitive Hashing (approximate). Not implemented yet:
    /// [`create_neighbor_searcher`] falls back to a KD-tree
    LSH,
}
55
/// Neighbor search result.
///
/// `indices[i]` is a row index into the data passed to `fit`, and
/// `distances[i]` is that row's Euclidean distance to the query; both vectors
/// always have the same length. The `kneighbors` implementations in this module
/// return entries nearest-first, while `radius_neighbors` returns them in the
/// order the search visits them (no sorting is applied).
#[derive(Debug, Clone)]
pub struct NeighborResult {
    /// Indices of the neighbors
    pub indices: Vec<usize>,
    /// Distances to the neighbors
    pub distances: Vec<f64>,
}
64
65/// Trait for neighbor search implementations
66pub trait NeighborSearcher<F: Float> {
67    /// Build the search structure from data
68    fn fit(&mut self, data: ArrayView2<F>) -> Result<()>;
69
70    /// Find k nearest neighbors for a query point
71    fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult>;
72
73    /// Find all neighbors within radius
74    fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult>;
75
76    /// Find k nearest neighbors for multiple query points
77    fn kneighbors_batch(&self, queries: ArrayView2<F>, k: usize) -> Result<Vec<NeighborResult>> {
78        let mut results = Vec::new();
79        for query in queries.outer_iter() {
80            results.push(self.kneighbors(query, k)?);
81        }
82        Ok(results)
83    }
84}
85
/// KD-Tree implementation for fast nearest neighbor search
///
/// Works best for low-dimensional data (typically < 20 dimensions).
/// The points are recursively split at the median of one coordinate, cycling
/// through the dimensions by depth; queries skip subtrees whose splitting
/// hyperplane lies farther away than the current search bound.
#[derive(Debug)]
pub struct KDTree<F: Float> {
    /// Copy of the fitted data (`None` until `fit` is called)
    data: Option<Array2<F>>,
    /// Root of the tree (`None` until `fit` is called)
    tree: Option<KDNode>,
    /// Maximum number of points stored in a leaf
    leaf_size: usize,
}
96
/// Node of a [`KDTree`]: either a leaf holding point indices or an internal
/// node holding a splitting hyperplane and two children.
#[derive(Debug, Clone)]
struct KDNode {
    /// Indices of points in this node (empty for internal nodes)
    indices: Vec<usize>,
    /// Splitting dimension (unused for leaves)
    split_dim: usize,
    /// Splitting value: the median coordinate along `split_dim` (unused for leaves)
    split_val: f64,
    /// Left child: points whose `split_dim` coordinate is <= `split_val`
    /// (points equal to the median may end up on either side)
    left: Option<Box<KDNode>>,
    /// Right child: points whose `split_dim` coordinate is >= `split_val`
    right: Option<Box<KDNode>>,
    /// Whether this is a leaf node
    is_leaf: bool,
}
112
113impl<F: Float + FromPrimitive + Debug> KDTree<F> {
114    /// Create a new KD-Tree
115    pub fn new(leaf_size: usize) -> Self {
116        Self {
117            data: None,
118            tree: None,
119            leaf_size,
120        }
121    }
122}
123
124impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for KDTree<F> {
125    fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
126        let n_samples = data.shape()[0];
127        let n_features = data.shape()[1];
128
129        if n_samples == 0 {
130            return Err(ClusteringError::InvalidInput(
131                "Cannot fit on empty data".into(),
132            ));
133        }
134
135        // Store the data
136        self.data = Some(data.to_owned());
137
138        // Build the tree
139        let indices: Vec<usize> = (0..n_samples).collect();
140        self.tree = Some(self.build_tree(indices, 0, n_features)?);
141
142        Ok(())
143    }
144
145    fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
146        let data = self
147            .data
148            .as_ref()
149            .ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
150
151        let tree = self
152            .tree
153            .as_ref()
154            .ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
155
156        if k == 0 {
157            return Ok(NeighborResult {
158                indices: vec![],
159                distances: vec![],
160            });
161        }
162
163        let mut heap = BinaryHeap::new();
164        self.search_knn(tree, query, k, data.view(), &mut heap);
165
166        // Extract results from heap (in reverse order)
167        let mut indices = Vec::new();
168        let mut distances = Vec::new();
169
170        while let Some(neighbor) = heap.pop() {
171            indices.push(neighbor.index);
172            distances.push(neighbor.distance);
173        }
174
175        // Reverse to get nearest first
176        indices.reverse();
177        distances.reverse();
178
179        Ok(NeighborResult { indices, distances })
180    }
181
182    fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
183        let data = self
184            .data
185            .as_ref()
186            .ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
187
188        let tree = self
189            .tree
190            .as_ref()
191            .ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
192
193        let mut result = NeighborResult {
194            indices: Vec::new(),
195            distances: Vec::new(),
196        };
197
198        let radius_f64 = radius.to_f64().unwrap_or(0.0);
199        self.search_radius(tree, query, radius_f64, data.view(), &mut result);
200
201        Ok(result)
202    }
203}
204
205impl<F: Float + FromPrimitive + Debug> KDTree<F> {
206    fn build_tree(
207        &self,
208        mut indices: Vec<usize>,
209        depth: usize,
210        n_features: usize,
211    ) -> Result<KDNode> {
212        if indices.len() <= self.leaf_size {
213            return Ok(KDNode {
214                indices,
215                split_dim: 0,
216                split_val: 0.0,
217                left: None,
218                right: None,
219                is_leaf: true,
220            });
221        }
222
223        let data = self.data.as_ref().unwrap();
224
225        // Choose splitting dimension (cycling through dimensions)
226        let split_dim = depth % n_features;
227
228        // Sort indices by the splitting dimension
229        indices.sort_by(|&a, &b| {
230            let val_a = data[[a, split_dim]].to_f64().unwrap_or(0.0);
231            let val_b = data[[b, split_dim]].to_f64().unwrap_or(0.0);
232            val_a
233                .partial_cmp(&val_b)
234                .unwrap_or(std::cmp::Ordering::Equal)
235        });
236
237        // Find median
238        let median_idx = indices.len() / 2;
239        let split_val = data[[indices[median_idx], split_dim]]
240            .to_f64()
241            .unwrap_or(0.0);
242
243        // Split indices
244        let left_indices = indices[..median_idx].to_vec();
245        let right_indices = indices[median_idx..].to_vec();
246
247        // Recursively build children
248        let left = if !left_indices.is_empty() {
249            Some(Box::new(self.build_tree(
250                left_indices,
251                depth + 1,
252                n_features,
253            )?))
254        } else {
255            None
256        };
257
258        let right = if !right_indices.is_empty() {
259            Some(Box::new(self.build_tree(
260                right_indices,
261                depth + 1,
262                n_features,
263            )?))
264        } else {
265            None
266        };
267
268        Ok(KDNode {
269            indices: vec![], // Internal nodes don't store indices
270            split_dim,
271            split_val,
272            left,
273            right,
274            is_leaf: false,
275        })
276    }
277
278    #[allow(clippy::only_used_in_recursion)]
279    fn search_knn(
280        &self,
281        node: &KDNode,
282        query: ArrayView1<F>,
283        k: usize,
284        data: ArrayView2<F>,
285        heap: &mut BinaryHeap<NeighborCandidate>,
286    ) {
287        if node.is_leaf {
288            // Check all points in this leaf
289            for &idx in &node.indices {
290                let dist = euclidean_distance(query, data.row(idx));
291
292                if heap.len() < k {
293                    heap.push(NeighborCandidate {
294                        distance: dist,
295                        index: idx,
296                    });
297                } else if let Some(top) = heap.peek() {
298                    if dist < top.distance {
299                        heap.pop();
300                        heap.push(NeighborCandidate {
301                            distance: dist,
302                            index: idx,
303                        });
304                    }
305                }
306            }
307        } else {
308            // Determine which child to visit first
309            let query_val = query[node.split_dim].to_f64().unwrap_or(0.0);
310            let (first_child, second_child) = if query_val < node.split_val {
311                (&node.left, &node.right)
312            } else {
313                (&node.right, &node.left)
314            };
315
316            // Search the first child
317            if let Some(child) = first_child {
318                self.search_knn(child, query, k, data, heap);
319            }
320
321            // Check if we need to search the second child
322            let split_dist = (query_val - node.split_val).abs();
323            if heap.len() < k || heap.peek().is_none_or(|top| split_dist < top.distance) {
324                if let Some(child) = second_child {
325                    self.search_knn(child, query, k, data, heap);
326                }
327            }
328        }
329    }
330
331    #[allow(clippy::only_used_in_recursion)]
332    fn search_radius(
333        &self,
334        node: &KDNode,
335        query: ArrayView1<F>,
336        radius: f64,
337        data: ArrayView2<F>,
338        result: &mut NeighborResult,
339    ) {
340        if node.is_leaf {
341            // Check all points in this leaf
342            for &idx in &node.indices {
343                let dist = euclidean_distance(query, data.row(idx));
344
345                if dist <= radius {
346                    result.indices.push(idx);
347                    result.distances.push(dist);
348                }
349            }
350        } else {
351            // Check if the splitting hyperplane intersects the query sphere
352            let query_val = query[node.split_dim].to_f64().unwrap_or(0.0);
353            let split_dist = (query_val - node.split_val).abs();
354
355            // Search children that might contain neighbors
356            if query_val < node.split_val {
357                if let Some(child) = &node.left {
358                    self.search_radius(child, query, radius, data, result);
359                }
360                if split_dist <= radius {
361                    if let Some(child) = &node.right {
362                        self.search_radius(child, query, radius, data, result);
363                    }
364                }
365            } else {
366                if let Some(child) = &node.right {
367                    self.search_radius(child, query, radius, data, result);
368                }
369                if split_dist <= radius {
370                    if let Some(child) = &node.left {
371                        self.search_radius(child, query, radius, data, result);
372                    }
373                }
374            }
375        }
376    }
377}
378
/// Neighbor candidate for the max-heaps used by the k-NN searches.
///
/// Candidates are ordered by distance, then by index. A NaN distance sorts
/// after every number. This makes the order total and consistent with `Eq`,
/// which is required by `BinaryHeap` and by the standard sort routines. The
/// index tie-break means that, among equally distant points, the one with the
/// higher index sits at the top of a max-heap.
#[derive(Debug, Clone)]
struct NeighborCandidate {
    distance: f64,
    index: usize,
}

impl PartialEq for NeighborCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for NeighborCandidate {}

impl Ord for NeighborCandidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `is_nan` first so every NaN (whatever its sign bit) ranks as the
        // farthest candidate; `total_cmp` then orders the remaining values.
        self.distance
            .is_nan()
            .cmp(&other.distance.is_nan())
            .then_with(|| self.distance.total_cmp(&other.distance))
            .then_with(|| self.index.cmp(&other.index))
    }
}

impl PartialOrd for NeighborCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
401
/// Brute force neighbor search (exact but slow)
///
/// Computes the distance from the query to every stored point, so each query
/// costs n distance evaluations (O(n²) for an all-pairs workload).
/// Guaranteed to find the exact neighbors.
#[derive(Debug)]
pub struct BruteForceSearch<F: Float> {
    /// Copy of the fitted data (`None` until `fit` is called)
    data: Option<Array2<F>>,
}
410
411impl<F: Float + FromPrimitive + Debug> BruteForceSearch<F> {
412    /// Create a new brute force searcher
413    pub fn new() -> Self {
414        Self { data: None }
415    }
416}
417
418impl<F: Float + FromPrimitive + Debug> Default for BruteForceSearch<F> {
419    fn default() -> Self {
420        Self::new()
421    }
422}
423
424impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for BruteForceSearch<F> {
425    fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
426        if data.shape()[0] == 0 {
427            return Err(ClusteringError::InvalidInput(
428                "Cannot fit on empty data".into(),
429            ));
430        }
431
432        self.data = Some(data.to_owned());
433        Ok(())
434    }
435
436    fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
437        let data = self
438            .data
439            .as_ref()
440            .ok_or_else(|| ClusteringError::InvalidInput("Searcher not fitted yet".into()))?;
441
442        if k == 0 {
443            return Ok(NeighborResult {
444                indices: vec![],
445                distances: vec![],
446            });
447        }
448
449        let n_samples = data.shape()[0];
450        let k_actual = k.min(n_samples);
451
452        // Calculate all distances
453        let mut candidates: Vec<NeighborCandidate> = (0..n_samples)
454            .map(|i| {
455                let dist = euclidean_distance(query, data.row(i));
456                NeighborCandidate {
457                    distance: dist,
458                    index: i,
459                }
460            })
461            .collect();
462
463        // Sort by distance and take k nearest
464        candidates.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
465        candidates.truncate(k_actual);
466
467        let indices = candidates.iter().map(|c| c.index).collect();
468        let distances = candidates.iter().map(|c| c.distance).collect();
469
470        Ok(NeighborResult { indices, distances })
471    }
472
473    fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
474        let data = self
475            .data
476            .as_ref()
477            .ok_or_else(|| ClusteringError::InvalidInput("Searcher not fitted yet".into()))?;
478
479        let radius_f64 = radius.to_f64().unwrap_or(0.0);
480        let n_samples = data.shape()[0];
481
482        let mut indices = Vec::new();
483        let mut distances = Vec::new();
484
485        for i in 0..n_samples {
486            let dist = euclidean_distance(query, data.row(i));
487            if dist <= radius_f64 {
488                indices.push(i);
489                distances.push(dist);
490            }
491        }
492
493        Ok(NeighborResult { indices, distances })
494    }
495}
496
/// Ball Tree implementation for high-dimensional nearest neighbor search
///
/// Intended for higher-dimensional data than [`KDTree`]. Every node stores a
/// bounding hypersphere (the centroid of its points plus the largest distance
/// from it). Children are produced by a median split along the coordinate with
/// the largest spread.
#[derive(Debug)]
pub struct BallTree<F: Float> {
    /// Copy of the fitted data (`None` until `fit` is called)
    data: Option<Array2<F>>,
    /// Root of the tree (`None` until `fit` is called)
    tree: Option<BallNode>,
    /// Maximum number of points stored in a leaf
    leaf_size: usize,
}
507
/// Node of a [`BallTree`]: a bounding ball plus either leaf indices or two children.
#[derive(Debug, Clone)]
struct BallNode {
    /// Center of the ball: centroid of the node's points, converted to `f64`
    center: Vec<f64>,
    /// Radius of the ball: largest distance from `center` to any of the node's points
    radius: f64,
    /// Indices of points in this node (empty for internal nodes)
    indices: Vec<usize>,
    /// Left child (lower half of the median split)
    left: Option<Box<BallNode>>,
    /// Right child (upper half of the median split)
    right: Option<Box<BallNode>>,
    /// Whether this is a leaf node
    is_leaf: bool,
}
523
524impl<F: Float + FromPrimitive + Debug> BallTree<F> {
525    /// Create a new Ball Tree
526    pub fn new(leaf_size: usize) -> Self {
527        Self {
528            data: None,
529            tree: None,
530            leaf_size,
531        }
532    }
533}
534
535impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for BallTree<F> {
536    fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
537        let n_samples = data.shape()[0];
538
539        if n_samples == 0 {
540            return Err(ClusteringError::InvalidInput(
541                "Cannot fit on empty data".into(),
542            ));
543        }
544
545        self.data = Some(data.to_owned());
546
547        let indices: Vec<usize> = (0..n_samples).collect();
548        self.tree = Some(self.build_ball_tree(indices, data.view())?);
549
550        Ok(())
551    }
552
553    fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
554        let data = self
555            .data
556            .as_ref()
557            .ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
558
559        let tree = self
560            .tree
561            .as_ref()
562            .ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
563
564        if k == 0 {
565            return Ok(NeighborResult {
566                indices: vec![],
567                distances: vec![],
568            });
569        }
570
571        let mut heap = BinaryHeap::new();
572        self.search_ball_knn(tree, query, k, data.view(), &mut heap);
573
574        let mut indices = Vec::new();
575        let mut distances = Vec::new();
576
577        while let Some(neighbor) = heap.pop() {
578            indices.push(neighbor.index);
579            distances.push(neighbor.distance);
580        }
581
582        indices.reverse();
583        distances.reverse();
584
585        Ok(NeighborResult { indices, distances })
586    }
587
588    fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
589        let data = self
590            .data
591            .as_ref()
592            .ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
593
594        let tree = self
595            .tree
596            .as_ref()
597            .ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
598
599        let mut result = NeighborResult {
600            indices: Vec::new(),
601            distances: Vec::new(),
602        };
603
604        let radius_f64 = radius.to_f64().unwrap_or(0.0);
605        self.search_ball_radius(tree, query, radius_f64, data.view(), &mut result);
606
607        Ok(result)
608    }
609}
610
611impl<F: Float + FromPrimitive + Debug> BallTree<F> {
612    fn build_ball_tree(&self, indices: Vec<usize>, data: ArrayView2<F>) -> Result<BallNode> {
613        if indices.len() <= self.leaf_size {
614            let (center, radius) = self.compute_ball(&indices, data);
615            return Ok(BallNode {
616                center,
617                radius,
618                indices,
619                left: None,
620                right: None,
621                is_leaf: true,
622            });
623        }
624
625        // Find the dimension with the largest spread
626        let n_features = data.shape()[1];
627        let mut best_dim = 0;
628        let mut best_spread = 0.0;
629
630        for dim in 0..n_features {
631            let mut min_val = f64::INFINITY;
632            let mut max_val = f64::NEG_INFINITY;
633
634            for &idx in &indices {
635                let val = data[[idx, dim]].to_f64().unwrap_or(0.0);
636                min_val = min_val.min(val);
637                max_val = max_val.max(val);
638            }
639
640            let spread = max_val - min_val;
641            if spread > best_spread {
642                best_spread = spread;
643                best_dim = dim;
644            }
645        }
646
647        // Sort indices by the best dimension
648        let mut sorted_indices = indices;
649        sorted_indices.sort_by(|&a, &b| {
650            let val_a = data[[a, best_dim]].to_f64().unwrap_or(0.0);
651            let val_b = data[[b, best_dim]].to_f64().unwrap_or(0.0);
652            val_a
653                .partial_cmp(&val_b)
654                .unwrap_or(std::cmp::Ordering::Equal)
655        });
656
657        // Split at median
658        let split_idx = sorted_indices.len() / 2;
659        let left_indices = sorted_indices[..split_idx].to_vec();
660        let right_indices = sorted_indices[split_idx..].to_vec();
661
662        // Recursively build children
663        let left = if !left_indices.is_empty() {
664            Some(Box::new(self.build_ball_tree(left_indices, data)?))
665        } else {
666            None
667        };
668
669        let right = if !right_indices.is_empty() {
670            Some(Box::new(self.build_ball_tree(right_indices, data)?))
671        } else {
672            None
673        };
674
675        // Compute ball for this node
676        let (center, radius) = self.compute_ball(&sorted_indices, data);
677
678        Ok(BallNode {
679            center,
680            radius,
681            indices: vec![], // Internal nodes don't store indices
682            left,
683            right,
684            is_leaf: false,
685        })
686    }
687
688    fn compute_ball(&self, indices: &[usize], data: ArrayView2<F>) -> (Vec<f64>, f64) {
689        if indices.is_empty() {
690            return (vec![], 0.0);
691        }
692
693        let n_features = data.shape()[1];
694        let mut center = vec![0.0; n_features];
695
696        // Compute centroid
697        for &idx in indices {
698            for j in 0..n_features {
699                center[j] += data[[idx, j]].to_f64().unwrap_or(0.0);
700            }
701        }
702
703        let n_points = indices.len() as f64;
704        for val in &mut center {
705            *val /= n_points;
706        }
707
708        // Compute radius (maximum distance from center)
709        let mut radius = 0.0;
710        for &idx in indices {
711            let mut dist_sq = 0.0;
712            for j in 0..n_features {
713                let diff = data[[idx, j]].to_f64().unwrap_or(0.0) - center[j];
714                dist_sq += diff * diff;
715            }
716            radius = radius.max(dist_sq.sqrt());
717        }
718
719        (center, radius)
720    }
721
722    #[allow(clippy::only_used_in_recursion)]
723    fn search_ball_knn(
724        &self,
725        node: &BallNode,
726        query: ArrayView1<F>,
727        k: usize,
728        data: ArrayView2<F>,
729        heap: &mut BinaryHeap<NeighborCandidate>,
730    ) {
731        if node.is_leaf {
732            // Check all points in this leaf
733            for &idx in &node.indices {
734                let dist = euclidean_distance(query, data.row(idx));
735
736                if heap.len() < k {
737                    heap.push(NeighborCandidate {
738                        distance: dist,
739                        index: idx,
740                    });
741                } else if let Some(top) = heap.peek() {
742                    if dist < top.distance {
743                        heap.pop();
744                        heap.push(NeighborCandidate {
745                            distance: dist,
746                            index: idx,
747                        });
748                    }
749                }
750            }
751        } else {
752            // Check if this ball can contain better neighbors
753            let query_vec: Vec<f64> = query.iter().map(|&x| x.to_f64().unwrap_or(0.0)).collect();
754
755            let dist_to_center = euclidean_distance_vec(&query_vec, &node.center);
756            let min_dist_to_ball = (dist_to_center - node.radius).max(0.0);
757
758            if heap.len() < k
759                || heap
760                    .peek()
761                    .is_none_or(|top| min_dist_to_ball < top.distance)
762            {
763                // Search children (closer one first)
764                if let (Some(left), Some(right)) = (&node.left, &node.right) {
765                    let left_dist = euclidean_distance_vec(&query_vec, &left.center);
766                    let right_dist = euclidean_distance_vec(&query_vec, &right.center);
767
768                    if left_dist <= right_dist {
769                        self.search_ball_knn(left, query, k, data, heap);
770                        self.search_ball_knn(right, query, k, data, heap);
771                    } else {
772                        self.search_ball_knn(right, query, k, data, heap);
773                        self.search_ball_knn(left, query, k, data, heap);
774                    }
775                } else if let Some(child) = &node.left {
776                    self.search_ball_knn(child, query, k, data, heap);
777                } else if let Some(child) = &node.right {
778                    self.search_ball_knn(child, query, k, data, heap);
779                }
780            }
781        }
782    }
783
784    #[allow(clippy::only_used_in_recursion)]
785    fn search_ball_radius(
786        &self,
787        node: &BallNode,
788        query: ArrayView1<F>,
789        radius: f64,
790        data: ArrayView2<F>,
791        result: &mut NeighborResult,
792    ) {
793        if node.is_leaf {
794            // Check all points in this leaf
795            for &idx in &node.indices {
796                let dist = euclidean_distance(query, data.row(idx));
797
798                if dist <= radius {
799                    result.indices.push(idx);
800                    result.distances.push(dist);
801                }
802            }
803        } else {
804            // Check if this ball intersects the query sphere
805            let query_vec: Vec<f64> = query.iter().map(|&x| x.to_f64().unwrap_or(0.0)).collect();
806
807            let dist_to_center = euclidean_distance_vec(&query_vec, &node.center);
808
809            if dist_to_center <= radius + node.radius {
810                // Ball intersects query sphere, search children
811                if let Some(child) = &node.left {
812                    self.search_ball_radius(child, query, radius, data, result);
813                }
814                if let Some(child) = &node.right {
815                    self.search_ball_radius(child, query, radius, data, result);
816                }
817            }
818        }
819    }
820}
821
822/// Create the appropriate neighbor searcher based on configuration
823#[allow(dead_code)]
824pub fn create_neighbor_searcher<F: Float + FromPrimitive + Debug + 'static>(
825    config: NeighborSearchConfig,
826) -> Box<dyn NeighborSearcher<F>> {
827    match config.algorithm {
828        NeighborSearchAlgorithm::BruteForce => Box::new(BruteForceSearch::new()),
829        NeighborSearchAlgorithm::KDTree => Box::new(KDTree::new(config.leaf_size)),
830        NeighborSearchAlgorithm::BallTree => Box::new(BallTree::new(config.leaf_size)),
831        NeighborSearchAlgorithm::Auto => {
832            // Use KD-Tree by default (could be made smarter based on data characteristics)
833            Box::new(KDTree::new(config.leaf_size))
834        }
835        NeighborSearchAlgorithm::LSH => {
836            // LSH not implemented yet, fall back to KD-Tree
837            Box::new(KDTree::new(config.leaf_size))
838        }
839    }
840}
841
842/// Calculate Euclidean distance between two points
843#[allow(dead_code)]
844fn euclidean_distance<F: Float + FromPrimitive>(a: ArrayView1<F>, b: ArrayView1<F>) -> f64 {
845    let mut sum = 0.0;
846    for (x, y) in a.iter().zip(b.iter()) {
847        let diff = x.to_f64().unwrap_or(0.0) - y.to_f64().unwrap_or(0.0);
848        sum += diff * diff;
849    }
850    sum.sqrt()
851}
852
/// Euclidean (L2) distance between two `f64` slices.
///
/// Coordinates are paired with `zip`, so only the first
/// `min(a.len(), b.len())` coordinates contribute.
#[allow(dead_code)]
fn euclidean_distance_vec(a: &[f64], b: &[f64]) -> f64 {
    let squared_sum = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let delta = x - y;
            delta * delta
        })
        .fold(0.0, |acc, sq| acc + sq);
    squared_sum.sqrt()
}
863
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    fn create_test_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, // Point 0
                1.0, 0.0, // Point 1
                0.0, 1.0, // Point 2
                10.0, 10.0, // Point 3
                11.0, 10.0, // Point 4
                10.0, 11.0, // Point 5
            ],
        )
        .unwrap()
    }

    /// Deterministic pseudo-random points in `[0, 10)^d` from a 64-bit LCG, so
    /// the tree searches can be compared against brute force without extra crates.
    fn pseudo_random_data(n: usize, d: usize, seed: u64) -> Array2<f64> {
        let mut state = seed;
        let values: Vec<f64> = (0..n * d)
            .map(|_| {
                state = state
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                (state >> 33) as f64 / (1u64 << 31) as f64 * 10.0
            })
            .collect();
        Array2::from_shape_vec((n, d), values).unwrap()
    }

    #[test]
    fn test_brute_force_search() {
        let data = create_test_data();
        let mut searcher = BruteForceSearch::new();

        searcher.fit(data.view()).unwrap();

        // Query at origin - should find points 0, 1, 2 as nearest
        let query = Array1::from_vec(vec![0.0, 0.0]);
        let result = searcher.kneighbors(query.view(), 3).unwrap();

        assert_eq!(result.indices.len(), 3);
        assert_eq!(result.distances.len(), 3);

        // First neighbor should be point 0 (distance 0)
        assert_eq!(result.indices[0], 0);
        assert!(result.distances[0] < 1e-10);

        // Test radius search
        let radius_result = searcher.radius_neighbors(query.view(), 1.5).unwrap();
        assert!(radius_result.indices.len() >= 3); // Should find at least points 0, 1, 2
    }

    #[test]
    fn test_kdtree_search() {
        let data = create_test_data();
        let mut searcher = KDTree::new(2);

        searcher.fit(data.view()).unwrap();

        // Query at origin
        let query = Array1::from_vec(vec![0.0, 0.0]);
        let result = searcher.kneighbors(query.view(), 3).unwrap();

        assert_eq!(result.indices.len(), 3);
        assert_eq!(result.distances.len(), 3);

        // Should find the same nearest neighbors as brute force
        assert_eq!(result.indices[0], 0);
        assert!(result.distances[0] < 1e-10);
    }

    #[test]
    fn test_ball_tree_search() {
        let data = create_test_data();
        let mut searcher = BallTree::new(2);

        searcher.fit(data.view()).unwrap();

        // Query at origin
        let query = Array1::from_vec(vec![0.0, 0.0]);
        let result = searcher.kneighbors(query.view(), 3).unwrap();

        assert_eq!(result.indices.len(), 3);
        assert_eq!(result.distances.len(), 3);

        // Should find the same nearest neighbors as brute force
        assert_eq!(result.indices[0], 0);
        assert!(result.distances[0] < 1e-10);
    }

    #[test]
    fn test_trees_match_brute_force() {
        let data = pseudo_random_data(60, 3, 0x2545_F491_4F6C_DD1D);
        let queries = pseudo_random_data(10, 3, 0x9E37_79B9_7F4A_7C15);

        let mut brute = BruteForceSearch::new();
        brute.fit(data.view()).unwrap();
        let mut kd = KDTree::new(4);
        kd.fit(data.view()).unwrap();
        let mut ball = BallTree::new(4);
        ball.fit(data.view()).unwrap();

        for query in queries.outer_iter() {
            // k-NN: the sorted distance lists must agree (indices may differ on ties).
            let expected = brute.kneighbors(query, 7).unwrap();
            for result in [
                kd.kneighbors(query, 7).unwrap(),
                ball.kneighbors(query, 7).unwrap(),
            ] {
                assert_eq!(result.distances.len(), expected.distances.len());
                for (got, want) in result.distances.iter().zip(&expected.distances) {
                    assert!((got - want).abs() < 1e-12);
                }
            }

            // Radius search: the sets of indices must agree.
            let mut expected_idx = brute.radius_neighbors(query, 3.0).unwrap().indices;
            expected_idx.sort_unstable();
            for mut found in [
                kd.radius_neighbors(query, 3.0).unwrap().indices,
                ball.radius_neighbors(query, 3.0).unwrap().indices,
            ] {
                found.sort_unstable();
                assert_eq!(found, expected_idx);
            }
        }
    }

    #[test]
    fn test_k_larger_than_data() {
        let data = create_test_data();
        let query = Array1::from_vec(vec![0.0, 0.0]);

        let mut kd = KDTree::new(2);
        kd.fit(data.view()).unwrap();
        let kd_result = kd.kneighbors(query.view(), 10).unwrap();
        assert_eq!(kd_result.indices.len(), 6);

        let mut ball = BallTree::new(2);
        ball.fit(data.view()).unwrap();
        let ball_result = ball.kneighbors(query.view(), 10).unwrap();
        assert_eq!(ball_result.indices.len(), 6);

        let mut brute = BruteForceSearch::new();
        brute.fit(data.view()).unwrap();
        let brute_result = brute.kneighbors(query.view(), 10).unwrap();
        assert_eq!(brute_result.indices.len(), 6);
    }

    #[test]
    fn test_neighbor_searcher_factory() {
        let data = create_test_data();

        let algorithms = vec![
            NeighborSearchAlgorithm::BruteForce,
            NeighborSearchAlgorithm::KDTree,
            NeighborSearchAlgorithm::BallTree,
            NeighborSearchAlgorithm::Auto,
            NeighborSearchAlgorithm::LSH,
        ];

        for algorithm in algorithms {
            let config = NeighborSearchConfig {
                algorithm,
                ..Default::default()
            };

            let mut searcher = create_neighbor_searcher(config);
            searcher.fit(data.view()).unwrap();

            let query = Array1::from_vec(vec![0.0, 0.0]);
            let result = searcher.kneighbors(query.view(), 2).unwrap();

            assert_eq!(result.indices.len(), 2);
            assert_eq!(result.distances.len(), 2);
        }
    }

    #[test]
    fn test_empty_data_error() {
        let empty_data: Array2<f64> = Array2::zeros((0, 2));
        let mut searcher = BruteForceSearch::new();

        let result = searcher.fit(empty_data.view());
        assert!(result.is_err());

        let mut kd = KDTree::new(2);
        assert!(kd.fit(empty_data.view()).is_err());

        let mut ball = BallTree::new(2);
        assert!(ball.fit(empty_data.view()).is_err());
    }

    #[test]
    fn test_query_before_fit_errors() {
        let query = Array1::from_vec(vec![0.0, 0.0]);

        let brute: BruteForceSearch<f64> = BruteForceSearch::new();
        assert!(brute.kneighbors(query.view(), 1).is_err());

        let kd: KDTree<f64> = KDTree::new(2);
        assert!(kd.kneighbors(query.view(), 1).is_err());
        assert!(kd.radius_neighbors(query.view(), 1.0).is_err());

        let ball: BallTree<f64> = BallTree::new(2);
        assert!(ball.kneighbors(query.view(), 1).is_err());
        assert!(ball.radius_neighbors(query.view(), 1.0).is_err());
    }

    #[test]
    fn test_k_zero() {
        let data = create_test_data();
        let mut searcher = BruteForceSearch::new();
        searcher.fit(data.view()).unwrap();

        let query = Array1::from_vec(vec![0.0, 0.0]);
        let result = searcher.kneighbors(query.view(), 0).unwrap();

        assert_eq!(result.indices.len(), 0);
        assert_eq!(result.distances.len(), 0);
    }

    #[test]
    fn test_batch_queries() {
        let data = create_test_data();
        let mut searcher = BruteForceSearch::new();
        searcher.fit(data.view()).unwrap();

        let queries = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 10.0, 10.0]).unwrap();

        let results = searcher.kneighbors_batch(queries.view(), 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].indices.len(), 2);
        assert_eq!(results[1].indices.len(), 2);

        // First query should find point 0 first
        assert_eq!(results[0].indices[0], 0);
        // Second query should find point 3 first
        assert_eq!(results[1].indices[0], 3);
    }
}