1use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::collections::BinaryHeap;
10use std::fmt::Debug;
11
12use crate::error::{ClusteringError, Result};
13
/// Options used by `create_neighbor_searcher` to pick and configure a searcher.
#[derive(Debug, Clone)]
pub struct NeighborSearchConfig {
    /// Which search structure to build.
    pub algorithm: NeighborSearchAlgorithm,
    /// Leaf size passed to the KD-tree and ball tree constructors.
    pub leaf_size: usize,
    /// Number of hash tables. Not read by any code visible in this file
    /// (presumably reserved for an LSH implementation — TODO confirm).
    pub n_hash_tables: usize,
    /// Number of hash functions. Not read by any code visible in this file
    /// (presumably reserved for an LSH implementation — TODO confirm).
    pub n_hash_functions: usize,
    /// Parallel-execution flag. Not read by any code visible in this file.
    pub parallel: bool,
}
28
impl Default for NeighborSearchConfig {
    /// Returns the default configuration: automatic algorithm selection, a leaf
    /// size of 30, 10 hash tables with 4 hash functions, and `parallel` enabled.
    fn default() -> Self {
        Self {
            algorithm: NeighborSearchAlgorithm::Auto,
            leaf_size: 30,
            n_hash_tables: 10,
            n_hash_functions: 4,
            parallel: true,
        }
    }
}
40
/// Search strategies selectable through `NeighborSearchConfig::algorithm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NeighborSearchAlgorithm {
    /// Let the factory choose; `create_neighbor_searcher` currently builds a KD-tree.
    Auto,
    /// Exhaustive scan over every fitted point.
    BruteForce,
    /// Axis-aligned binary space partitioning tree.
    KDTree,
    /// Tree of nested bounding balls.
    BallTree,
    /// Locality-sensitive hashing. No LSH searcher is defined in this file;
    /// `create_neighbor_searcher` substitutes a KD-tree for this variant.
    LSH,
}
55
/// Outcome of a single neighbor query.
///
/// `indices[i]` and `distances[i]` describe the same neighbor.
#[derive(Debug, Clone)]
pub struct NeighborResult {
    /// Row indices of the neighbors in the data passed to `fit`.
    pub indices: Vec<usize>,
    /// Euclidean distance from the query to each neighbor, always as `f64`.
    pub distances: Vec<f64>,
}
64
65pub trait NeighborSearcher<F: Float> {
67 fn fit(&mut self, data: ArrayView2<F>) -> Result<()>;
69
70 fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult>;
72
73 fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult>;
75
76 fn kneighbors_batch(&self, queries: ArrayView2<F>, k: usize) -> Result<Vec<NeighborResult>> {
78 let mut results = Vec::new();
79 for query in queries.outer_iter() {
80 results.push(self.kneighbors(query, k)?);
81 }
82 Ok(results)
83 }
84}
85
/// KD-tree index for Euclidean nearest-neighbor and radius queries.
#[derive(Debug)]
pub struct KDTree<F: Float> {
    /// Owned copy of the fitted samples (one per row); `None` until `fit` succeeds.
    data: Option<Array2<F>>,
    /// Root of the built tree; `None` until `fit` succeeds.
    tree: Option<KDNode>,
    /// Node size at or below which a node is kept as a leaf instead of being split.
    leaf_size: usize,
}
96
/// One node of the KD-tree; either a leaf holding point indices or an internal split.
#[derive(Debug, Clone)]
struct KDNode {
    /// Row indices stored in this node; populated for leaves, empty for internal nodes.
    indices: Vec<usize>,
    /// Coordinate axis this node splits on (left at 0 for leaves).
    split_dim: usize,
    /// Coordinate of the median point along `split_dim` (left at 0.0 for leaves).
    split_val: f64,
    /// Subtree built from the points sorted before the median.
    left: Option<Box<KDNode>>,
    /// Subtree built from the median point onward.
    right: Option<Box<KDNode>>,
    /// `true` when this node stores points directly and has no children.
    is_leaf: bool,
}
112
impl<F: Float + FromPrimitive + Debug> KDTree<F> {
    /// Creates an unfitted KD-tree.
    ///
    /// `leaf_size` is the node size at or below which nodes are not split further.
    /// Queries return an error until `fit` has been called.
    pub fn new(leaf_size: usize) -> Self {
        Self {
            data: None,
            tree: None,
            leaf_size,
        }
    }
}
123
124impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for KDTree<F> {
125 fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
126 let n_samples = data.shape()[0];
127 let n_features = data.shape()[1];
128
129 if n_samples == 0 {
130 return Err(ClusteringError::InvalidInput(
131 "Cannot fit on empty data".into(),
132 ));
133 }
134
135 self.data = Some(data.to_owned());
137
138 let indices: Vec<usize> = (0..n_samples).collect();
140 self.tree = Some(self.build_tree(indices, 0, n_features)?);
141
142 Ok(())
143 }
144
145 fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
146 let data = self
147 .data
148 .as_ref()
149 .ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
150
151 let tree = self
152 .tree
153 .as_ref()
154 .ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
155
156 if k == 0 {
157 return Ok(NeighborResult {
158 indices: vec![],
159 distances: vec![],
160 });
161 }
162
163 let mut heap = BinaryHeap::new();
164 self.search_knn(tree, query, k, data.view(), &mut heap);
165
166 let mut indices = Vec::new();
168 let mut distances = Vec::new();
169
170 while let Some(neighbor) = heap.pop() {
171 indices.push(neighbor.index);
172 distances.push(neighbor.distance);
173 }
174
175 indices.reverse();
177 distances.reverse();
178
179 Ok(NeighborResult { indices, distances })
180 }
181
182 fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
183 let data = self
184 .data
185 .as_ref()
186 .ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
187
188 let tree = self
189 .tree
190 .as_ref()
191 .ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
192
193 let mut result = NeighborResult {
194 indices: Vec::new(),
195 distances: Vec::new(),
196 };
197
198 let radius_f64 = radius.to_f64().unwrap_or(0.0);
199 self.search_radius(tree, query, radius_f64, data.view(), &mut result);
200
201 Ok(result)
202 }
203}
204
205impl<F: Float + FromPrimitive + Debug> KDTree<F> {
206 fn build_tree(
207 &self,
208 mut indices: Vec<usize>,
209 depth: usize,
210 n_features: usize,
211 ) -> Result<KDNode> {
212 if indices.len() <= self.leaf_size {
213 return Ok(KDNode {
214 indices,
215 split_dim: 0,
216 split_val: 0.0,
217 left: None,
218 right: None,
219 is_leaf: true,
220 });
221 }
222
223 let data = self.data.as_ref().unwrap();
224
225 let split_dim = depth % n_features;
227
228 indices.sort_by(|&a, &b| {
230 let val_a = data[[a, split_dim]].to_f64().unwrap_or(0.0);
231 let val_b = data[[b, split_dim]].to_f64().unwrap_or(0.0);
232 val_a
233 .partial_cmp(&val_b)
234 .unwrap_or(std::cmp::Ordering::Equal)
235 });
236
237 let median_idx = indices.len() / 2;
239 let split_val = data[[indices[median_idx], split_dim]]
240 .to_f64()
241 .unwrap_or(0.0);
242
243 let left_indices = indices[..median_idx].to_vec();
245 let right_indices = indices[median_idx..].to_vec();
246
247 let left = if !left_indices.is_empty() {
249 Some(Box::new(self.build_tree(
250 left_indices,
251 depth + 1,
252 n_features,
253 )?))
254 } else {
255 None
256 };
257
258 let right = if !right_indices.is_empty() {
259 Some(Box::new(self.build_tree(
260 right_indices,
261 depth + 1,
262 n_features,
263 )?))
264 } else {
265 None
266 };
267
268 Ok(KDNode {
269 indices: vec![], split_dim,
271 split_val,
272 left,
273 right,
274 is_leaf: false,
275 })
276 }
277
278 #[allow(clippy::only_used_in_recursion)]
279 fn search_knn(
280 &self,
281 node: &KDNode,
282 query: ArrayView1<F>,
283 k: usize,
284 data: ArrayView2<F>,
285 heap: &mut BinaryHeap<NeighborCandidate>,
286 ) {
287 if node.is_leaf {
288 for &idx in &node.indices {
290 let dist = euclidean_distance(query, data.row(idx));
291
292 if heap.len() < k {
293 heap.push(NeighborCandidate {
294 distance: dist,
295 index: idx,
296 });
297 } else if let Some(top) = heap.peek() {
298 if dist < top.distance {
299 heap.pop();
300 heap.push(NeighborCandidate {
301 distance: dist,
302 index: idx,
303 });
304 }
305 }
306 }
307 } else {
308 let query_val = query[node.split_dim].to_f64().unwrap_or(0.0);
310 let (first_child, second_child) = if query_val < node.split_val {
311 (&node.left, &node.right)
312 } else {
313 (&node.right, &node.left)
314 };
315
316 if let Some(child) = first_child {
318 self.search_knn(child, query, k, data, heap);
319 }
320
321 let split_dist = (query_val - node.split_val).abs();
323 if heap.len() < k || heap.peek().is_none_or(|top| split_dist < top.distance) {
324 if let Some(child) = second_child {
325 self.search_knn(child, query, k, data, heap);
326 }
327 }
328 }
329 }
330
331 #[allow(clippy::only_used_in_recursion)]
332 fn search_radius(
333 &self,
334 node: &KDNode,
335 query: ArrayView1<F>,
336 radius: f64,
337 data: ArrayView2<F>,
338 result: &mut NeighborResult,
339 ) {
340 if node.is_leaf {
341 for &idx in &node.indices {
343 let dist = euclidean_distance(query, data.row(idx));
344
345 if dist <= radius {
346 result.indices.push(idx);
347 result.distances.push(dist);
348 }
349 }
350 } else {
351 let query_val = query[node.split_dim].to_f64().unwrap_or(0.0);
353 let split_dist = (query_val - node.split_val).abs();
354
355 if query_val < node.split_val {
357 if let Some(child) = &node.left {
358 self.search_radius(child, query, radius, data, result);
359 }
360 if split_dist <= radius {
361 if let Some(child) = &node.right {
362 self.search_radius(child, query, radius, data, result);
363 }
364 }
365 } else {
366 if let Some(child) = &node.right {
367 self.search_radius(child, query, radius, data, result);
368 }
369 if split_dist <= radius {
370 if let Some(child) = &node.left {
371 self.search_radius(child, query, radius, data, result);
372 }
373 }
374 }
375 }
376 }
377}
378
/// A fitted point paired with its distance to the query, ordered for use in a
/// max-heap of current best neighbors.
///
/// Candidates are ordered by distance, then by index, with NaN distances
/// sorting after every real distance. Equality is defined through the same
/// ordering so that `PartialEq`, `Eq`, `PartialOrd` and `Ord` all agree.
#[derive(Debug, Clone)]
struct NeighborCandidate {
    /// Euclidean distance from the query to the point.
    distance: f64,
    /// Row index of the point in the fitted data.
    index: usize,
}

impl PartialEq for NeighborCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for NeighborCandidate {}

impl Ord for NeighborCandidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance
            .partial_cmp(&other.distance)
            // `partial_cmp` is `None` only when a NaN is involved: rank NaN as
            // the largest distance so the ordering stays total.
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
            // Break distance ties by index so equal-distance results are deterministic.
            .then_with(|| self.index.cmp(&other.index))
    }
}

impl PartialOrd for NeighborCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
401
/// Exhaustive neighbor search that compares the query against every fitted point.
#[derive(Debug)]
pub struct BruteForceSearch<F: Float> {
    /// Owned copy of the fitted samples (one per row); `None` until `fit` succeeds.
    data: Option<Array2<F>>,
}
410
impl<F: Float + FromPrimitive + Debug> BruteForceSearch<F> {
    /// Creates an unfitted searcher; queries return an error until `fit` is called.
    pub fn new() -> Self {
        Self { data: None }
    }
}
417
impl<F: Float + FromPrimitive + Debug> Default for BruteForceSearch<F> {
    /// Equivalent to [`BruteForceSearch::new`].
    fn default() -> Self {
        Self::new()
    }
}
423
424impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for BruteForceSearch<F> {
425 fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
426 if data.shape()[0] == 0 {
427 return Err(ClusteringError::InvalidInput(
428 "Cannot fit on empty data".into(),
429 ));
430 }
431
432 self.data = Some(data.to_owned());
433 Ok(())
434 }
435
436 fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
437 let data = self
438 .data
439 .as_ref()
440 .ok_or_else(|| ClusteringError::InvalidInput("Searcher not fitted yet".into()))?;
441
442 if k == 0 {
443 return Ok(NeighborResult {
444 indices: vec![],
445 distances: vec![],
446 });
447 }
448
449 let n_samples = data.shape()[0];
450 let k_actual = k.min(n_samples);
451
452 let mut candidates: Vec<NeighborCandidate> = (0..n_samples)
454 .map(|i| {
455 let dist = euclidean_distance(query, data.row(i));
456 NeighborCandidate {
457 distance: dist,
458 index: i,
459 }
460 })
461 .collect();
462
463 candidates.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
465 candidates.truncate(k_actual);
466
467 let indices = candidates.iter().map(|c| c.index).collect();
468 let distances = candidates.iter().map(|c| c.distance).collect();
469
470 Ok(NeighborResult { indices, distances })
471 }
472
473 fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
474 let data = self
475 .data
476 .as_ref()
477 .ok_or_else(|| ClusteringError::InvalidInput("Searcher not fitted yet".into()))?;
478
479 let radius_f64 = radius.to_f64().unwrap_or(0.0);
480 let n_samples = data.shape()[0];
481
482 let mut indices = Vec::new();
483 let mut distances = Vec::new();
484
485 for i in 0..n_samples {
486 let dist = euclidean_distance(query, data.row(i));
487 if dist <= radius_f64 {
488 indices.push(i);
489 distances.push(dist);
490 }
491 }
492
493 Ok(NeighborResult { indices, distances })
494 }
495}
496
/// Ball tree index for Euclidean nearest-neighbor and radius queries.
#[derive(Debug)]
pub struct BallTree<F: Float> {
    /// Owned copy of the fitted samples (one per row); `None` until `fit` succeeds.
    data: Option<Array2<F>>,
    /// Root of the built tree; `None` until `fit` succeeds.
    tree: Option<BallNode>,
    /// Node size at or below which a node is kept as a leaf instead of being split.
    leaf_size: usize,
}
507
/// One node of the ball tree: a bounding ball plus either points (leaf) or children.
#[derive(Debug, Clone)]
struct BallNode {
    /// Coordinate-wise mean of the points covered by this node.
    center: Vec<f64>,
    /// Largest distance from `center` to any point covered by this node.
    radius: f64,
    /// Row indices stored in this node; populated for leaves, empty for internal nodes.
    indices: Vec<usize>,
    /// Subtree for the first half of the points along the split dimension.
    left: Option<Box<BallNode>>,
    /// Subtree for the second half of the points along the split dimension.
    right: Option<Box<BallNode>>,
    /// `true` when this node stores points directly and has no children.
    is_leaf: bool,
}
523
impl<F: Float + FromPrimitive + Debug> BallTree<F> {
    /// Creates an unfitted ball tree.
    ///
    /// `leaf_size` is the node size at or below which nodes are not split further.
    /// Queries return an error until `fit` has been called.
    pub fn new(leaf_size: usize) -> Self {
        Self {
            data: None,
            tree: None,
            leaf_size,
        }
    }
}
534
535impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for BallTree<F> {
536 fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
537 let n_samples = data.shape()[0];
538
539 if n_samples == 0 {
540 return Err(ClusteringError::InvalidInput(
541 "Cannot fit on empty data".into(),
542 ));
543 }
544
545 self.data = Some(data.to_owned());
546
547 let indices: Vec<usize> = (0..n_samples).collect();
548 self.tree = Some(self.build_ball_tree(indices, data.view())?);
549
550 Ok(())
551 }
552
553 fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
554 let data = self
555 .data
556 .as_ref()
557 .ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
558
559 let tree = self
560 .tree
561 .as_ref()
562 .ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
563
564 if k == 0 {
565 return Ok(NeighborResult {
566 indices: vec![],
567 distances: vec![],
568 });
569 }
570
571 let mut heap = BinaryHeap::new();
572 self.search_ball_knn(tree, query, k, data.view(), &mut heap);
573
574 let mut indices = Vec::new();
575 let mut distances = Vec::new();
576
577 while let Some(neighbor) = heap.pop() {
578 indices.push(neighbor.index);
579 distances.push(neighbor.distance);
580 }
581
582 indices.reverse();
583 distances.reverse();
584
585 Ok(NeighborResult { indices, distances })
586 }
587
588 fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
589 let data = self
590 .data
591 .as_ref()
592 .ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
593
594 let tree = self
595 .tree
596 .as_ref()
597 .ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
598
599 let mut result = NeighborResult {
600 indices: Vec::new(),
601 distances: Vec::new(),
602 };
603
604 let radius_f64 = radius.to_f64().unwrap_or(0.0);
605 self.search_ball_radius(tree, query, radius_f64, data.view(), &mut result);
606
607 Ok(result)
608 }
609}
610
611impl<F: Float + FromPrimitive + Debug> BallTree<F> {
612 fn build_ball_tree(&self, indices: Vec<usize>, data: ArrayView2<F>) -> Result<BallNode> {
613 if indices.len() <= self.leaf_size {
614 let (center, radius) = self.compute_ball(&indices, data);
615 return Ok(BallNode {
616 center,
617 radius,
618 indices,
619 left: None,
620 right: None,
621 is_leaf: true,
622 });
623 }
624
625 let n_features = data.shape()[1];
627 let mut best_dim = 0;
628 let mut best_spread = 0.0;
629
630 for dim in 0..n_features {
631 let mut min_val = f64::INFINITY;
632 let mut max_val = f64::NEG_INFINITY;
633
634 for &idx in &indices {
635 let val = data[[idx, dim]].to_f64().unwrap_or(0.0);
636 min_val = min_val.min(val);
637 max_val = max_val.max(val);
638 }
639
640 let spread = max_val - min_val;
641 if spread > best_spread {
642 best_spread = spread;
643 best_dim = dim;
644 }
645 }
646
647 let mut sorted_indices = indices;
649 sorted_indices.sort_by(|&a, &b| {
650 let val_a = data[[a, best_dim]].to_f64().unwrap_or(0.0);
651 let val_b = data[[b, best_dim]].to_f64().unwrap_or(0.0);
652 val_a
653 .partial_cmp(&val_b)
654 .unwrap_or(std::cmp::Ordering::Equal)
655 });
656
657 let split_idx = sorted_indices.len() / 2;
659 let left_indices = sorted_indices[..split_idx].to_vec();
660 let right_indices = sorted_indices[split_idx..].to_vec();
661
662 let left = if !left_indices.is_empty() {
664 Some(Box::new(self.build_ball_tree(left_indices, data)?))
665 } else {
666 None
667 };
668
669 let right = if !right_indices.is_empty() {
670 Some(Box::new(self.build_ball_tree(right_indices, data)?))
671 } else {
672 None
673 };
674
675 let (center, radius) = self.compute_ball(&sorted_indices, data);
677
678 Ok(BallNode {
679 center,
680 radius,
681 indices: vec![], left,
683 right,
684 is_leaf: false,
685 })
686 }
687
688 fn compute_ball(&self, indices: &[usize], data: ArrayView2<F>) -> (Vec<f64>, f64) {
689 if indices.is_empty() {
690 return (vec![], 0.0);
691 }
692
693 let n_features = data.shape()[1];
694 let mut center = vec![0.0; n_features];
695
696 for &idx in indices {
698 for j in 0..n_features {
699 center[j] += data[[idx, j]].to_f64().unwrap_or(0.0);
700 }
701 }
702
703 let n_points = indices.len() as f64;
704 for val in &mut center {
705 *val /= n_points;
706 }
707
708 let mut radius = 0.0;
710 for &idx in indices {
711 let mut dist_sq = 0.0;
712 for j in 0..n_features {
713 let diff = data[[idx, j]].to_f64().unwrap_or(0.0) - center[j];
714 dist_sq += diff * diff;
715 }
716 radius = radius.max(dist_sq.sqrt());
717 }
718
719 (center, radius)
720 }
721
722 #[allow(clippy::only_used_in_recursion)]
723 fn search_ball_knn(
724 &self,
725 node: &BallNode,
726 query: ArrayView1<F>,
727 k: usize,
728 data: ArrayView2<F>,
729 heap: &mut BinaryHeap<NeighborCandidate>,
730 ) {
731 if node.is_leaf {
732 for &idx in &node.indices {
734 let dist = euclidean_distance(query, data.row(idx));
735
736 if heap.len() < k {
737 heap.push(NeighborCandidate {
738 distance: dist,
739 index: idx,
740 });
741 } else if let Some(top) = heap.peek() {
742 if dist < top.distance {
743 heap.pop();
744 heap.push(NeighborCandidate {
745 distance: dist,
746 index: idx,
747 });
748 }
749 }
750 }
751 } else {
752 let query_vec: Vec<f64> = query.iter().map(|&x| x.to_f64().unwrap_or(0.0)).collect();
754
755 let dist_to_center = euclidean_distance_vec(&query_vec, &node.center);
756 let min_dist_to_ball = (dist_to_center - node.radius).max(0.0);
757
758 if heap.len() < k
759 || heap
760 .peek()
761 .is_none_or(|top| min_dist_to_ball < top.distance)
762 {
763 if let (Some(left), Some(right)) = (&node.left, &node.right) {
765 let left_dist = euclidean_distance_vec(&query_vec, &left.center);
766 let right_dist = euclidean_distance_vec(&query_vec, &right.center);
767
768 if left_dist <= right_dist {
769 self.search_ball_knn(left, query, k, data, heap);
770 self.search_ball_knn(right, query, k, data, heap);
771 } else {
772 self.search_ball_knn(right, query, k, data, heap);
773 self.search_ball_knn(left, query, k, data, heap);
774 }
775 } else if let Some(child) = &node.left {
776 self.search_ball_knn(child, query, k, data, heap);
777 } else if let Some(child) = &node.right {
778 self.search_ball_knn(child, query, k, data, heap);
779 }
780 }
781 }
782 }
783
784 #[allow(clippy::only_used_in_recursion)]
785 fn search_ball_radius(
786 &self,
787 node: &BallNode,
788 query: ArrayView1<F>,
789 radius: f64,
790 data: ArrayView2<F>,
791 result: &mut NeighborResult,
792 ) {
793 if node.is_leaf {
794 for &idx in &node.indices {
796 let dist = euclidean_distance(query, data.row(idx));
797
798 if dist <= radius {
799 result.indices.push(idx);
800 result.distances.push(dist);
801 }
802 }
803 } else {
804 let query_vec: Vec<f64> = query.iter().map(|&x| x.to_f64().unwrap_or(0.0)).collect();
806
807 let dist_to_center = euclidean_distance_vec(&query_vec, &node.center);
808
809 if dist_to_center <= radius + node.radius {
810 if let Some(child) = &node.left {
812 self.search_ball_radius(child, query, radius, data, result);
813 }
814 if let Some(child) = &node.right {
815 self.search_ball_radius(child, query, radius, data, result);
816 }
817 }
818 }
819 }
820}
821
822#[allow(dead_code)]
824pub fn create_neighbor_searcher<F: Float + FromPrimitive + Debug + 'static>(
825 config: NeighborSearchConfig,
826) -> Box<dyn NeighborSearcher<F>> {
827 match config.algorithm {
828 NeighborSearchAlgorithm::BruteForce => Box::new(BruteForceSearch::new()),
829 NeighborSearchAlgorithm::KDTree => Box::new(KDTree::new(config.leaf_size)),
830 NeighborSearchAlgorithm::BallTree => Box::new(BallTree::new(config.leaf_size)),
831 NeighborSearchAlgorithm::Auto => {
832 Box::new(KDTree::new(config.leaf_size))
834 }
835 NeighborSearchAlgorithm::LSH => {
836 Box::new(KDTree::new(config.leaf_size))
838 }
839 }
840}
841
842#[allow(dead_code)]
844fn euclidean_distance<F: Float + FromPrimitive>(a: ArrayView1<F>, b: ArrayView1<F>) -> f64 {
845 let mut sum = 0.0;
846 for (x, y) in a.iter().zip(b.iter()) {
847 let diff = x.to_f64().unwrap_or(0.0) - y.to_f64().unwrap_or(0.0);
848 sum += diff * diff;
849 }
850 sum.sqrt()
851}
852
/// Euclidean distance between two `f64` slices.
///
/// Coordinates are paired positionally; if the slices differ in length the
/// extra coordinates of the longer one are ignored.
#[allow(dead_code)]
fn euclidean_distance_vec(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .fold(0.0_f64, |acc, (p, q)| {
            let delta = p - q;
            acc + delta * delta
        })
        .sqrt()
}
863
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2, ArrayView1};

    /// Six 2-D points in two tight groups: rows 0-2 around the origin and
    /// rows 3-5 around (10, 10).
    fn create_test_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0,
            ],
        )
        .unwrap()
    }

    /// Returns the result's indices in ascending order, so radius results can
    /// be compared regardless of the order in which a searcher visits points.
    fn sorted_indices(mut result: NeighborResult) -> Vec<usize> {
        result.indices.sort_unstable();
        result.indices
    }

    #[test]
    fn test_brute_force_search() {
        let data = create_test_data();
        let mut searcher = BruteForceSearch::new();

        searcher.fit(data.view()).unwrap();

        let query = Array1::from_vec(vec![0.0, 0.0]);
        let result = searcher.kneighbors(query.view(), 3).unwrap();

        assert_eq!(result.indices.len(), 3);
        assert_eq!(result.distances.len(), 3);

        // The query coincides with row 0.
        assert_eq!(result.indices[0], 0);
        assert!(result.distances[0] < 1e-10);

        // Exactly the three points around the origin lie within 1.5.
        let radius_result = searcher.radius_neighbors(query.view(), 1.5).unwrap();
        assert_eq!(radius_result.indices.len(), radius_result.distances.len());
        assert_eq!(sorted_indices(radius_result), vec![0, 1, 2]);
    }

    #[test]
    fn test_kdtree_search() {
        let data = create_test_data();
        let mut searcher = KDTree::new(2);

        searcher.fit(data.view()).unwrap();

        let query = Array1::from_vec(vec![0.0, 0.0]);
        let result = searcher.kneighbors(query.view(), 3).unwrap();

        assert_eq!(result.indices.len(), 3);
        assert_eq!(result.distances.len(), 3);

        assert_eq!(result.indices[0], 0);
        assert!(result.distances[0] < 1e-10);

        let radius_result = searcher.radius_neighbors(query.view(), 1.5).unwrap();
        assert_eq!(radius_result.indices.len(), radius_result.distances.len());
        assert_eq!(sorted_indices(radius_result), vec![0, 1, 2]);
    }

    #[test]
    fn test_ball_tree_search() {
        let data = create_test_data();
        let mut searcher = BallTree::new(2);

        searcher.fit(data.view()).unwrap();

        let query = Array1::from_vec(vec![0.0, 0.0]);
        let result = searcher.kneighbors(query.view(), 3).unwrap();

        assert_eq!(result.indices.len(), 3);
        assert_eq!(result.distances.len(), 3);

        assert_eq!(result.indices[0], 0);
        assert!(result.distances[0] < 1e-10);

        let radius_result = searcher.radius_neighbors(query.view(), 1.5).unwrap();
        assert_eq!(radius_result.indices.len(), radius_result.distances.len());
        assert_eq!(sorted_indices(radius_result), vec![0, 1, 2]);
    }

    #[test]
    fn test_neighbor_searcher_factory() {
        let data = create_test_data();

        // Every variant must yield a usable searcher, including the ones the
        // factory maps onto another implementation.
        let algorithms = [
            NeighborSearchAlgorithm::BruteForce,
            NeighborSearchAlgorithm::KDTree,
            NeighborSearchAlgorithm::BallTree,
            NeighborSearchAlgorithm::Auto,
            NeighborSearchAlgorithm::LSH,
        ];

        for algorithm in algorithms {
            let config = NeighborSearchConfig {
                algorithm,
                ..Default::default()
            };

            let mut searcher = create_neighbor_searcher(config);
            searcher.fit(data.view()).unwrap();

            let query = Array1::from_vec(vec![0.0, 0.0]);
            let result = searcher.kneighbors(query.view(), 2).unwrap();

            assert_eq!(result.indices.len(), 2);
            assert_eq!(result.distances.len(), 2);
        }
    }

    #[test]
    fn test_empty_data_error() {
        let empty_data: Array2<f64> = Array2::zeros((0, 2));
        let mut searcher = BruteForceSearch::new();

        let result = searcher.fit(empty_data.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_k_zero() {
        let data = create_test_data();
        let mut searcher = BruteForceSearch::new();
        searcher.fit(data.view()).unwrap();

        let query = Array1::from_vec(vec![0.0, 0.0]);
        let result = searcher.kneighbors(query.view(), 0).unwrap();

        assert_eq!(result.indices.len(), 0);
        assert_eq!(result.distances.len(), 0);
    }

    #[test]
    fn test_batch_queries() {
        let data = create_test_data();
        let mut searcher = BruteForceSearch::new();
        searcher.fit(data.view()).unwrap();

        let queries = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 10.0, 10.0]).unwrap();

        let results = searcher.kneighbors_batch(queries.view(), 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].indices.len(), 2);
        assert_eq!(results[1].indices.len(), 2);

        // Each query coincides with a fitted row, which must come back first.
        assert_eq!(results[0].indices[0], 0);
        assert_eq!(results[1].indices[0], 3);
    }
}