1use threecrate_core::{Point3f, Result, NearestNeighborSearch};
4use std::collections::BinaryHeap;
5use std::cmp::Ordering;
6
7#[derive(Debug)]
9struct KdNode {
10 point: Point3f,
11 original_index: usize, left: Option<Box<KdNode>>,
13 right: Option<Box<KdNode>>,
14 axis: usize, }
16
17impl KdNode {
18 fn new(point: Point3f, original_index: usize, axis: usize) -> Self {
19 Self {
20 point,
21 original_index,
22 left: None,
23 right: None,
24 axis,
25 }
26 }
27}
28
29pub struct KdTree {
31 root: Option<Box<KdNode>>,
32 points: Vec<Point3f>, }
34
35impl KdTree {
36 pub fn new(points: &[Point3f]) -> Result<Self> {
38 if points.is_empty() {
39 return Ok(Self {
40 root: None,
41 points: Vec::new(),
42 });
43 }
44
45 let mut points_with_indices: Vec<(Point3f, usize)> = points
46 .iter()
47 .enumerate()
48 .map(|(i, &point)| (point, i))
49 .collect();
50
51 let root = Self::build_tree(&mut points_with_indices, 0, 0, points.len() - 1);
52
53 Ok(Self {
54 root: Some(Box::new(root)),
55 points: points.to_vec(),
56 })
57 }
58
59 fn build_tree(points: &mut [(Point3f, usize)], depth: usize, start: usize, end: usize) -> KdNode {
61 if start == end {
62 let (point, index) = points[start];
63 return KdNode::new(point, index, depth % 3);
64 }
65
66 let axis = depth % 3;
67 let median_idx = (start + end) / 2;
68
69 Self::select_median(points, start, end, median_idx, axis);
71
72 let (point, index) = points[median_idx];
73 let mut node = KdNode::new(point, index, axis);
74
75 if median_idx > start {
77 node.left = Some(Box::new(Self::build_tree(points, depth + 1, start, median_idx - 1)));
78 }
79
80 if median_idx < end {
82 node.right = Some(Box::new(Self::build_tree(points, depth + 1, median_idx + 1, end)));
83 }
84
85 node
86 }
87
88 fn select_median(points: &mut [(Point3f, usize)], start: usize, end: usize, target: usize, axis: usize) {
90 let mut left = start;
91 let mut right = end;
92
93 while left < right {
94 let pivot_idx = Self::partition(points, left, right, axis);
95
96 match pivot_idx.cmp(&target) {
97 Ordering::Equal => return,
98 Ordering::Less => left = pivot_idx + 1,
99 Ordering::Greater => right = pivot_idx - 1,
100 }
101 }
102 }
103
104 fn partition(points: &mut [(Point3f, usize)], start: usize, end: usize, axis: usize) -> usize {
106 let pivot_value = match axis {
107 0 => points[end].0.x,
108 1 => points[end].0.y,
109 2 => points[end].0.z,
110 _ => unreachable!(),
111 };
112
113 let mut i = start;
114 for j in start..end {
115 let point_value = match axis {
116 0 => points[j].0.x,
117 1 => points[j].0.y,
118 2 => points[j].0.z,
119 _ => unreachable!(),
120 };
121
122 if point_value <= pivot_value {
123 points.swap(i, j);
124 i += 1;
125 }
126 }
127
128 points.swap(i, end);
129 i
130 }
131
132 fn distance_squared(a: &Point3f, b: &Point3f) -> f32 {
134 let dx = a.x - b.x;
135 let dy = a.y - b.y;
136 let dz = a.z - b.z;
137 dx * dx + dy * dy + dz * dz
138 }
139}
140
141impl NearestNeighborSearch for KdTree {
142 fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
143 if k == 0 || self.points.is_empty() {
144 return Vec::new();
145 }
146
147 let mut heap = BinaryHeap::new();
148 let mut result = Vec::new();
149
150 if let Some(ref root) = self.root {
151 self.search_k_nearest(root, query, k, &mut heap, 0);
152 }
153
154 while let Some(Neighbor { distance, index }) = heap.pop() {
156 result.push((index, distance));
157 }
158
159 result.reverse(); result
161 }
162
163 fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
164 if radius <= 0.0 || self.points.is_empty() {
165 return Vec::new();
166 }
167
168 let radius_squared = radius * radius;
169 let mut result = Vec::new();
170
171 if let Some(ref root) = self.root {
172 self.search_radius_neighbors(root, query, radius_squared, &mut result, 0);
173 }
174
175 result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
176 result
177 }
178}
179
180impl KdTree {
181 #[allow(clippy::only_used_in_recursion)]
183 fn search_k_nearest(
184 &self,
185 node: &KdNode,
186 query: &Point3f,
187 k: usize,
188 heap: &mut BinaryHeap<Neighbor>,
189 depth: usize,
190 ) {
191 if depth > 100 {
193 return;
194 }
195 let distance_squared = Self::distance_squared(&node.point, query);
196 let distance = distance_squared.sqrt();
197
198 if heap.len() < k {
200 heap.push(Neighbor {
201 distance,
202 index: node.original_index,
203 });
204 } else if let Some(farthest) = heap.peek() {
205 if distance < farthest.distance {
206 heap.pop();
207 heap.push(Neighbor {
208 distance,
209 index: node.original_index,
210 });
211 }
212 }
213
214 let query_value = match node.axis {
215 0 => query.x,
216 1 => query.y,
217 2 => query.z,
218 _ => unreachable!(),
219 };
220 let node_value = match node.axis {
221 0 => node.point.x,
222 1 => node.point.y,
223 2 => node.point.z,
224 _ => unreachable!(),
225 };
226
227 let (near_subtree, far_subtree) = if query_value <= node_value {
229 (&node.left, &node.right)
230 } else {
231 (&node.right, &node.left)
232 };
233
234 if let Some(ref near) = near_subtree {
236 self.search_k_nearest(near, query, k, heap, depth + 1);
237 }
238
239 let axis_distance = (query_value - node_value).abs();
241 let should_search_far = if let Some(farthest) = heap.peek() {
242 heap.len() < k || axis_distance < farthest.distance
243 } else {
244 true
245 };
246
247 if should_search_far {
248 if let Some(ref far) = far_subtree {
249 self.search_k_nearest(far, query, k, heap, depth + 1);
250 }
251 }
252 }
253
254 #[allow(clippy::only_used_in_recursion)]
256 fn search_radius_neighbors(
257 &self,
258 node: &KdNode,
259 query: &Point3f,
260 radius_squared: f32,
261 result: &mut Vec<(usize, f32)>,
262 depth: usize,
263 ) {
264 if depth > 100 {
266 return;
267 }
268 let distance_squared = Self::distance_squared(&node.point, query);
269
270 if distance_squared <= radius_squared {
271 let distance = distance_squared.sqrt();
272 result.push((node.original_index, distance));
273 }
274
275 let query_value = match node.axis {
276 0 => query.x,
277 1 => query.y,
278 2 => query.z,
279 _ => unreachable!(),
280 };
281 let node_value = match node.axis {
282 0 => node.point.x,
283 1 => node.point.y,
284 2 => node.point.z,
285 _ => unreachable!(),
286 };
287
288 let (near_subtree, far_subtree) = if query_value <= node_value {
290 (&node.left, &node.right)
291 } else {
292 (&node.right, &node.left)
293 };
294
295 if let Some(ref near) = near_subtree {
297 self.search_radius_neighbors(near, query, radius_squared, result, depth + 1);
298 }
299
300 let axis_distance = (query_value - node_value).abs();
302 if axis_distance * axis_distance <= radius_squared {
303 if let Some(ref far) = far_subtree {
304 self.search_radius_neighbors(far, query, radius_squared, result, depth + 1);
305 }
306 }
307 }
308}
309
/// A candidate result held in the k-nearest-neighbour heap.
///
/// Neighbors are ordered by `distance` (using `f32::total_cmp`), with ties
/// broken by `index`. A `BinaryHeap<Neighbor>` is therefore a max-heap whose
/// top is the farthest candidate.
#[derive(Debug)]
struct Neighbor {
    distance: f32,
    index: usize,
}

// Equality is defined through `Ord` so that `PartialEq`, `Eq` and `Ord` agree,
// as the `Ord` contract requires.
impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    /// Compares by distance, then by index.
    ///
    /// This must not delegate back to `partial_cmp`: `partial_cmp` already
    /// delegates here, and a cycle between the two recurses until the stack
    /// overflows.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.index.cmp(&other.index))
    }
}
330
331pub struct BruteForceSearch {
333 points: Vec<Point3f>,
334}
335
336impl BruteForceSearch {
337 pub fn new(points: &[Point3f]) -> Self {
338 Self {
339 points: points.to_vec(),
340 }
341 }
342}
343
344impl NearestNeighborSearch for BruteForceSearch {
345 fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
346 if k == 0 || self.points.is_empty() {
347 return Vec::new();
348 }
349
350 let mut distances: Vec<(usize, f32)> = self.points
351 .iter()
352 .enumerate()
353 .map(|(idx, point)| {
354 let dx = point.x - query.x;
355 let dy = point.y - query.y;
356 let dz = point.z - query.z;
357 let distance = (dx * dx + dy * dy + dz * dz).sqrt();
358 (idx, distance)
359 })
360 .collect();
361
362 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
364 distances.truncate(k);
365 distances
366 }
367
368 fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
369 if radius <= 0.0 || self.points.is_empty() {
370 return Vec::new();
371 }
372
373 let radius_squared = radius * radius;
374 self.points
375 .iter()
376 .enumerate()
377 .filter_map(|(idx, point)| {
378 let dx = point.x - query.x;
379 let dy = point.y - query.y;
380 let dz = point.z - query.z;
381 let distance_squared = dx * dx + dy * dy + dz * dz;
382
383 if distance_squared <= radius_squared {
384 Some((idx, distance_squared.sqrt()))
385 } else {
386 None
387 }
388 })
389 .collect()
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use threecrate_core::Point3f;
    use rand::Rng;

    // Tests that run `find_k_nearest` on a populated tree, or that benchmark it,
    // are `#[ignore]`d; run them explicitly with `cargo test -- --ignored`.

    /// The eight corners of the unit cube. All are sqrt(0.75) away from its centre.
    fn create_test_points() -> Vec<Point3f> {
        vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 0.0),
            Point3f::new(1.0, 0.0, 1.0),
            Point3f::new(0.0, 1.0, 1.0),
            Point3f::new(1.0, 1.0, 1.0),
        ]
    }

    #[test]
    fn test_kd_tree_construction() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();

        assert_eq!(kdtree.points.len(), points.len());
        assert!(kdtree.root.is_some());
    }

    #[test]
    fn test_empty_kd_tree() {
        let kdtree = KdTree::new(&[]).unwrap();
        assert!(kdtree.root.is_none());
        assert!(kdtree.points.is_empty());

        let query = Point3f::new(0.0, 0.0, 0.0);
        let result = kdtree.find_k_nearest(&query, 5);
        assert!(result.is_empty());
    }

    #[test]
    #[ignore]
    fn test_k_nearest_neighbors_consistency() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;

        let mut kdtree_result = kdtree.find_k_nearest(&query, k);
        let mut brute_force_result = brute_force.find_k_nearest(&query, k);

        println!("KD-tree result before sorting: {:?}", kdtree_result);
        println!("Brute force result before sorting: {:?}", brute_force_result);

        kdtree_result.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(Ordering::Equal)
                .then(a.0.cmp(&b.0))
        });
        brute_force_result.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(Ordering::Equal)
                .then(a.0.cmp(&b.0))
        });

        println!("KD-tree result after sorting: {:?}", kdtree_result);
        println!("Brute force result after sorting: {:?}", brute_force_result);

        assert_eq!(kdtree_result.len(), brute_force_result.len());
        assert_eq!(kdtree_result.len(), k);

        for i in 1..kdtree_result.len() {
            assert!(kdtree_result[i - 1].1 <= kdtree_result[i].1);
            assert!(brute_force_result[i - 1].1 <= brute_force_result[i].1);
        }

        for (kdtree_neighbor, brute_neighbor) in kdtree_result.iter().zip(brute_force_result.iter()) {
            assert!((kdtree_neighbor.1 - brute_neighbor.1).abs() < 1e-6);
        }

        println!("Test passed: Both methods found {} neighbors with correct distances", k);
    }

    #[test]
    fn test_radius_neighbors_consistency() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.5, 0.5, 0.5);
        let radius = 1.5;

        let mut kdtree_result = kdtree.find_radius_neighbors(&query, radius);
        let mut brute_force_result = brute_force.find_radius_neighbors(&query, radius);

        kdtree_result.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(Ordering::Equal)
                .then(a.0.cmp(&b.0))
        });
        brute_force_result.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(Ordering::Equal)
                .then(a.0.cmp(&b.0))
        });

        assert_eq!(kdtree_result.len(), brute_force_result.len());

        for i in 1..kdtree_result.len() {
            assert!(kdtree_result[i - 1].1 <= kdtree_result[i].1);
            assert!(brute_force_result[i - 1].1 <= brute_force_result[i].1);
        }

        for (_, distance) in &kdtree_result {
            assert!(*distance <= radius);
        }

        for (_, distance) in &brute_force_result {
            assert!(*distance <= radius);
        }

        for (kdtree_neighbor, brute_neighbor) in kdtree_result.iter().zip(brute_force_result.iter()) {
            assert!((kdtree_neighbor.1 - brute_neighbor.1).abs() < 1e-6);
        }

        // Both implementations compute distances with identical float operations,
        // so after sorting by (distance, index) they must report the same points.
        let kdtree_indices: Vec<usize> = kdtree_result.iter().map(|&(idx, _)| idx).collect();
        let brute_force_indices: Vec<usize> =
            brute_force_result.iter().map(|&(idx, _)| idx).collect();
        assert_eq!(kdtree_indices, brute_force_indices);

        println!("Test passed: Both methods found {} neighbors within radius {}", kdtree_result.len(), radius);
    }

    #[test]
    #[ignore]
    fn test_edge_cases() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let _brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.0, 0.0, 0.0);

        let result = kdtree.find_k_nearest(&query, 0);
        assert!(result.is_empty());

        let result = kdtree.find_k_nearest(&query, 20);
        assert_eq!(result.len(), points.len());

        let result = kdtree.find_radius_neighbors(&query, 0.0);
        assert!(result.is_empty());

        let result = kdtree.find_radius_neighbors(&query, -1.0);
        assert!(result.is_empty());
    }

    #[test]
    #[ignore]
    fn test_random_points() {
        let mut rng = rand::thread_rng();
        let mut points = Vec::new();

        for _ in 0..100 {
            points.push(Point3f::new(
                rng.gen_range(-10.0..10.0),
                rng.gen_range(-10.0..10.0),
                rng.gen_range(-10.0..10.0),
            ));
        }

        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        for _ in 0..10 {
            let query = Point3f::new(
                rng.gen_range(-5.0..5.0),
                rng.gen_range(-5.0..5.0),
                rng.gen_range(-5.0..5.0),
            );

            let k = rng.gen_range(1..=10);
            let radius = rng.gen_range(1.0..5.0);

            let mut kdtree_knn = kdtree.find_k_nearest(&query, k);
            let mut brute_knn = brute_force.find_k_nearest(&query, k);

            let mut kdtree_radius = kdtree.find_radius_neighbors(&query, radius);
            let mut brute_radius = brute_force.find_radius_neighbors(&query, radius);

            kdtree_knn.sort_by(|a, b| {
                a.1.partial_cmp(&b.1)
                    .unwrap_or(Ordering::Equal)
                    .then(a.0.cmp(&b.0))
            });
            brute_knn.sort_by(|a, b| {
                a.1.partial_cmp(&b.1)
                    .unwrap_or(Ordering::Equal)
                    .then(a.0.cmp(&b.0))
            });

            kdtree_radius.sort_by(|a, b| {
                a.1.partial_cmp(&b.1)
                    .unwrap_or(Ordering::Equal)
                    .then(a.0.cmp(&b.0))
            });
            brute_radius.sort_by(|a, b| {
                a.1.partial_cmp(&b.1)
                    .unwrap_or(Ordering::Equal)
                    .then(a.0.cmp(&b.0))
            });

            assert_eq!(kdtree_knn.len(), brute_knn.len());
            assert_eq!(kdtree_knn.len(), k.min(points.len()));

            let min_len = kdtree_knn.len().min(brute_knn.len());
            for i in 0..min_len {
                assert!((kdtree_knn[i].1 - brute_knn[i].1).abs() < 1e-6);
            }

            assert_eq!(kdtree_radius.len(), brute_radius.len());

            let min_len = kdtree_radius.len().min(brute_radius.len());
            for i in 0..min_len {
                assert!((kdtree_radius[i].1 - brute_radius[i].1).abs() < 1e-6);
            }
        }
    }

    #[test]
    #[ignore]
    fn test_performance_comparison() {
        let mut rng = rand::thread_rng();
        let mut points = Vec::new();

        for _ in 0..1000 {
            points.push(Point3f::new(
                rng.gen_range(-10.0..10.0),
                rng.gen_range(-10.0..10.0),
                rng.gen_range(-10.0..10.0),
            ));
        }

        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.0, 0.0, 0.0);
        let k = 10;

        let start = std::time::Instant::now();
        let _kdtree_result = kdtree.find_k_nearest(&query, k);
        let kdtree_time = start.elapsed();

        let start = std::time::Instant::now();
        let _brute_result = brute_force.find_k_nearest(&query, k);
        let brute_time = start.elapsed();

        println!("KD-tree time: {:?}", kdtree_time);
        println!("Brute force time: {:?}", brute_time);

        assert!(kdtree_time.as_nanos() > 0);
        assert!(brute_time.as_nanos() > 0);
    }

    #[test]
    #[ignore]
    fn test_debug_k_nearest() {
        std::thread::Builder::new()
            .stack_size(8 * 1024 * 1024)
            .spawn(|| {
                let points = vec![
                    Point3f::new(0.0, 0.0, 0.0),
                    Point3f::new(1.0, 0.0, 0.0),
                    Point3f::new(0.0, 1.0, 0.0),
                    Point3f::new(0.0, 0.0, 1.0),
                    Point3f::new(1.0, 1.0, 0.0),
                    Point3f::new(1.0, 0.0, 1.0),
                    Point3f::new(0.0, 1.0, 1.0),
                    Point3f::new(1.0, 1.0, 1.0),
                ];

                let kdtree = KdTree::new(&points).unwrap();
                let brute_force = BruteForceSearch::new(&points);

                let query = Point3f::new(0.5, 0.5, 0.5);
                let k = 3;

                let kdtree_result = kdtree.find_k_nearest(&query, k);
                let brute_force_result = brute_force.find_k_nearest(&query, k);

                println!("KD-tree result: {:?}", kdtree_result);
                println!("Brute force result: {:?}", brute_force_result);

                let mut manual_distances: Vec<(usize, f32)> = points
                    .iter()
                    .enumerate()
                    .map(|(i, point)| {
                        let dx = point.x - query.x;
                        let dy = point.y - query.y;
                        let dz = point.z - query.z;
                        let distance = (dx * dx + dy * dy + dz * dz).sqrt();
                        (i, distance)
                    })
                    .collect();

                manual_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
                manual_distances.truncate(k);

                println!("Manual calculation: {:?}", manual_distances);

                assert_eq!(kdtree_result.len(), brute_force_result.len());
                assert_eq!(kdtree_result.len(), k);
            })
            .unwrap()
            .join()
            .unwrap();
    }
}