1use threecrate_core::{Point3f, Result, NearestNeighborSearch};
4use std::collections::BinaryHeap;
5use std::cmp::Ordering;
6
7#[derive(Debug)]
9struct KdNode {
10 point: Point3f,
11 original_index: usize, left: Option<Box<KdNode>>,
13 right: Option<Box<KdNode>>,
14 axis: usize, }
16
17impl KdNode {
18 fn new(point: Point3f, original_index: usize, axis: usize) -> Self {
19 Self {
20 point,
21 original_index,
22 left: None,
23 right: None,
24 axis,
25 }
26 }
27}
28
29pub struct KdTree {
31 root: Option<Box<KdNode>>,
32 points: Vec<Point3f>, }
34
35impl KdTree {
36 pub fn new(points: &[Point3f]) -> Result<Self> {
38 if points.is_empty() {
39 return Ok(Self {
40 root: None,
41 points: Vec::new(),
42 });
43 }
44
45 let mut points_with_indices: Vec<(Point3f, usize)> = points
46 .iter()
47 .enumerate()
48 .map(|(i, &point)| (point, i))
49 .collect();
50
51 let root = Self::build_tree(&mut points_with_indices, 0, 0, points.len() - 1);
52
53 Ok(Self {
54 root: Some(Box::new(root)),
55 points: points.to_vec(),
56 })
57 }
58
59 fn build_tree(points: &mut [(Point3f, usize)], depth: usize, start: usize, end: usize) -> KdNode {
61 if start == end {
62 let (point, index) = points[start];
63 return KdNode::new(point, index, depth % 3);
64 }
65
66 let axis = depth % 3;
67 let median_idx = (start + end) / 2;
68
69 Self::select_median(points, start, end, median_idx, axis);
71
72 let (point, index) = points[median_idx];
73 let mut node = KdNode::new(point, index, axis);
74
75 if median_idx > start {
77 node.left = Some(Box::new(Self::build_tree(points, depth + 1, start, median_idx - 1)));
78 }
79
80 if median_idx < end {
82 node.right = Some(Box::new(Self::build_tree(points, depth + 1, median_idx + 1, end)));
83 }
84
85 node
86 }
87
88 fn select_median(points: &mut [(Point3f, usize)], start: usize, end: usize, target: usize, axis: usize) {
90 let mut left = start;
91 let mut right = end;
92
93 while left < right {
94 let pivot_idx = Self::partition(points, left, right, axis);
95
96 if pivot_idx == target {
97 return;
98 } else if pivot_idx < target {
99 left = pivot_idx + 1;
100 } else {
101 right = pivot_idx - 1;
102 }
103 }
104 }
105
106 fn partition(points: &mut [(Point3f, usize)], start: usize, end: usize, axis: usize) -> usize {
108 let pivot_value = match axis {
109 0 => points[end].0.x,
110 1 => points[end].0.y,
111 2 => points[end].0.z,
112 _ => unreachable!(),
113 };
114
115 let mut i = start;
116 for j in start..end {
117 let point_value = match axis {
118 0 => points[j].0.x,
119 1 => points[j].0.y,
120 2 => points[j].0.z,
121 _ => unreachable!(),
122 };
123
124 if point_value <= pivot_value {
125 points.swap(i, j);
126 i += 1;
127 }
128 }
129
130 points.swap(i, end);
131 i
132 }
133
134 fn distance_squared(a: &Point3f, b: &Point3f) -> f32 {
136 let dx = a.x - b.x;
137 let dy = a.y - b.y;
138 let dz = a.z - b.z;
139 dx * dx + dy * dy + dz * dz
140 }
141}
142
143impl NearestNeighborSearch for KdTree {
144 fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
145 if k == 0 || self.points.is_empty() {
146 return Vec::new();
147 }
148
149 let mut heap = BinaryHeap::new();
150 let mut result = Vec::new();
151
152 if let Some(ref root) = self.root {
153 self.search_k_nearest(root, query, k, &mut heap, 0);
154 }
155
156 while let Some(Neighbor { distance, index }) = heap.pop() {
158 result.push((index, distance));
159 }
160
161 result.reverse(); result
163 }
164
165 fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
166 if radius <= 0.0 || self.points.is_empty() {
167 return Vec::new();
168 }
169
170 let radius_squared = radius * radius;
171 let mut result = Vec::new();
172
173 if let Some(ref root) = self.root {
174 self.search_radius_neighbors(root, query, radius_squared, &mut result, 0);
175 }
176
177 result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
178 result
179 }
180}
181
182impl KdTree {
183 fn search_k_nearest(
185 &self,
186 node: &KdNode,
187 query: &Point3f,
188 k: usize,
189 heap: &mut BinaryHeap<Neighbor>,
190 depth: usize,
191 ) {
192 let distance_squared = Self::distance_squared(&node.point, query);
193 let distance = distance_squared.sqrt();
194
195 if heap.len() < k {
197 heap.push(Neighbor {
198 distance,
199 index: node.original_index,
200 });
201 } else if let Some(farthest) = heap.peek() {
202 if distance < farthest.distance {
203 heap.pop();
204 heap.push(Neighbor {
205 distance,
206 index: node.original_index,
207 });
208 }
209 }
210
211 let query_value = match node.axis {
212 0 => query.x,
213 1 => query.y,
214 2 => query.z,
215 _ => unreachable!(),
216 };
217 let node_value = match node.axis {
218 0 => node.point.x,
219 1 => node.point.y,
220 2 => node.point.z,
221 _ => unreachable!(),
222 };
223
224 let (near_subtree, far_subtree) = if query_value <= node_value {
226 (&node.left, &node.right)
227 } else {
228 (&node.right, &node.left)
229 };
230
231 if let Some(ref near) = near_subtree {
233 self.search_k_nearest(near, query, k, heap, depth + 1);
234 }
235
236 let axis_distance = (query_value - node_value).abs();
238 let should_search_far = if let Some(farthest) = heap.peek() {
239 heap.len() < k || axis_distance < farthest.distance
240 } else {
241 true
242 };
243
244 if should_search_far {
245 if let Some(ref far) = far_subtree {
246 self.search_k_nearest(far, query, k, heap, depth + 1);
247 }
248 }
249 }
250
251 fn search_radius_neighbors(
253 &self,
254 node: &KdNode,
255 query: &Point3f,
256 radius_squared: f32,
257 result: &mut Vec<(usize, f32)>,
258 depth: usize,
259 ) {
260 let distance_squared = Self::distance_squared(&node.point, query);
261
262 if distance_squared <= radius_squared {
263 let distance = distance_squared.sqrt();
264 result.push((node.original_index, distance));
265 }
266
267 let query_value = match node.axis {
268 0 => query.x,
269 1 => query.y,
270 2 => query.z,
271 _ => unreachable!(),
272 };
273 let node_value = match node.axis {
274 0 => node.point.x,
275 1 => node.point.y,
276 2 => node.point.z,
277 _ => unreachable!(),
278 };
279
280 let (near_subtree, far_subtree) = if query_value <= node_value {
282 (&node.left, &node.right)
283 } else {
284 (&node.right, &node.left)
285 };
286
287 if let Some(ref near) = near_subtree {
289 self.search_radius_neighbors(near, query, radius_squared, result, depth + 1);
290 }
291
292 let axis_distance = (query_value - node_value).abs();
294 if axis_distance * axis_distance <= radius_squared {
295 if let Some(ref far) = far_subtree {
296 self.search_radius_neighbors(far, query, radius_squared, result, depth + 1);
297 }
298 }
299 }
300}
301
/// A candidate held in the k-nearest-neighbor max-heap.
///
/// Ordered by `distance`, then by `index`. A `BinaryHeap<Neighbor>` therefore keeps the
/// farthest candidate on top, and among equally distant candidates the one with the
/// largest index.
#[derive(Debug)]
struct Neighbor {
    /// Euclidean distance from the query point.
    distance: f32,
    /// Index of the point in the slice the tree was built from.
    index: usize,
}

// Equality is defined through `cmp` so that `PartialEq`, `PartialOrd` and `Ord` always
// agree. A derived `PartialEq` would compare the index while the old `Ord` ignored it.
impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

// Sound because `cmp` is a total order (`f32::total_cmp` plus the index tie-break), so
// equality is reflexive even for NaN distances.
impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    /// Compares by distance using `f32::total_cmp`, so NaN gets a fixed position instead of
    /// being "equal" to everything. Ties are broken by index, so distinct neighbors at the
    /// same distance never compare equal.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.index.cmp(&other.index))
    }
}
322
323pub struct BruteForceSearch {
325 points: Vec<Point3f>,
326}
327
328impl BruteForceSearch {
329 pub fn new(points: &[Point3f]) -> Self {
330 Self {
331 points: points.to_vec(),
332 }
333 }
334}
335
336impl NearestNeighborSearch for BruteForceSearch {
337 fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
338 if k == 0 || self.points.is_empty() {
339 return Vec::new();
340 }
341
342 let mut distances: Vec<(usize, f32)> = self.points
343 .iter()
344 .enumerate()
345 .map(|(idx, point)| {
346 let dx = point.x - query.x;
347 let dy = point.y - query.y;
348 let dz = point.z - query.z;
349 let distance = (dx * dx + dy * dy + dz * dz).sqrt();
350 (idx, distance)
351 })
352 .collect();
353
354 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
356 distances.truncate(k);
357 distances
358 }
359
360 fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
361 if radius <= 0.0 || self.points.is_empty() {
362 return Vec::new();
363 }
364
365 let radius_squared = radius * radius;
366 self.points
367 .iter()
368 .enumerate()
369 .filter_map(|(idx, point)| {
370 let dx = point.x - query.x;
371 let dy = point.y - query.y;
372 let dz = point.z - query.z;
373 let distance_squared = dx * dx + dy * dy + dz * dz;
374
375 if distance_squared <= radius_squared {
376 Some((idx, distance_squared.sqrt()))
377 } else {
378 None
379 }
380 })
381 .collect()
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use threecrate_core::Point3f;
    use rand::Rng;

    /// The eight corners of the unit cube.
    fn create_test_points() -> Vec<Point3f> {
        vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 0.0),
            Point3f::new(1.0, 0.0, 1.0),
            Point3f::new(0.0, 1.0, 1.0),
            Point3f::new(1.0, 1.0, 1.0),
        ]
    }

    /// Sorts `(index, distance)` pairs by distance, then index, so results from different
    /// search strategies (which may order ties differently) can be compared pairwise.
    fn sort_by_distance_then_index(pairs: &mut [(usize, f32)]) {
        pairs.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(Ordering::Equal)
                .then(a.0.cmp(&b.0))
        });
    }

    /// Asserts that the distances in `pairs` never decrease.
    fn assert_ascending(pairs: &[(usize, f32)]) {
        for window in pairs.windows(2) {
            assert!(window[0].1 <= window[1].1, "not sorted by distance: {:?}", pairs);
        }
    }

    /// Asserts both sorted result lists have the same length and pairwise-equal distances.
    ///
    /// Indices are deliberately not compared: equally distant points may legitimately be
    /// returned in a different order or, at the k-th position, be a different point.
    fn assert_same_distances(actual: &[(usize, f32)], expected: &[(usize, f32)]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "actual: {:?}\nexpected: {:?}",
            actual,
            expected
        );
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a.1 - e.1).abs() < 1e-6,
                "actual: {:?}\nexpected: {:?}",
                actual,
                expected
            );
        }
    }

    /// Runs the same k-NN and radius queries against both indices and asserts they agree.
    fn assert_matches_brute_force(
        kdtree: &KdTree,
        brute_force: &BruteForceSearch,
        query: &Point3f,
        k: usize,
        radius: f32,
    ) {
        let mut kdtree_knn = kdtree.find_k_nearest(query, k);
        let mut brute_knn = brute_force.find_k_nearest(query, k);
        sort_by_distance_then_index(&mut kdtree_knn);
        sort_by_distance_then_index(&mut brute_knn);
        assert_eq!(kdtree_knn.len(), k.min(kdtree.points.len()));
        assert_same_distances(&kdtree_knn, &brute_knn);

        let mut kdtree_radius = kdtree.find_radius_neighbors(query, radius);
        let mut brute_radius = brute_force.find_radius_neighbors(query, radius);
        sort_by_distance_then_index(&mut kdtree_radius);
        sort_by_distance_then_index(&mut brute_radius);
        assert_same_distances(&kdtree_radius, &brute_radius);
    }

    #[test]
    fn test_kd_tree_construction() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();

        assert_eq!(kdtree.points.len(), points.len());
        assert!(kdtree.root.is_some());
    }

    #[test]
    fn test_empty_kd_tree() {
        let kdtree = KdTree::new(&[]).unwrap();
        assert!(kdtree.root.is_none());
        assert!(kdtree.points.is_empty());

        let query = Point3f::new(0.0, 0.0, 0.0);
        assert!(kdtree.find_k_nearest(&query, 5).is_empty());
        assert!(kdtree.find_radius_neighbors(&query, 1.0).is_empty());
    }

    #[test]
    fn test_k_nearest_neighbors_consistency() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;

        let mut kdtree_result = kdtree.find_k_nearest(&query, k);
        let mut brute_force_result = brute_force.find_k_nearest(&query, k);
        sort_by_distance_then_index(&mut kdtree_result);
        sort_by_distance_then_index(&mut brute_force_result);

        assert_eq!(kdtree_result.len(), k);
        assert_ascending(&kdtree_result);
        assert_ascending(&brute_force_result);
        assert_same_distances(&kdtree_result, &brute_force_result);
    }

    #[test]
    fn test_radius_neighbors_consistency() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.5, 0.5, 0.5);
        let radius = 1.5;

        let mut kdtree_result = kdtree.find_radius_neighbors(&query, radius);
        let mut brute_force_result = brute_force.find_radius_neighbors(&query, radius);
        sort_by_distance_then_index(&mut kdtree_result);
        sort_by_distance_then_index(&mut brute_force_result);

        assert_ascending(&kdtree_result);
        assert_ascending(&brute_force_result);
        assert!(kdtree_result
            .iter()
            .chain(&brute_force_result)
            .all(|&(_, distance)| distance <= radius));
        assert_same_distances(&kdtree_result, &brute_force_result);
    }

    #[test]
    fn test_edge_cases() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let query = Point3f::new(0.0, 0.0, 0.0);

        // k == 0 yields nothing; k larger than the point count yields every point.
        assert!(kdtree.find_k_nearest(&query, 0).is_empty());
        assert_eq!(kdtree.find_k_nearest(&query, 20).len(), points.len());

        // Non-positive radii yield nothing.
        assert!(kdtree.find_radius_neighbors(&query, 0.0).is_empty());
        assert!(kdtree.find_radius_neighbors(&query, -1.0).is_empty());
    }

    #[test]
    fn test_random_points() {
        let mut rng = rand::thread_rng();
        let points: Vec<Point3f> = (0..100)
            .map(|_| {
                Point3f::new(
                    rng.gen_range(-10.0..10.0),
                    rng.gen_range(-10.0..10.0),
                    rng.gen_range(-10.0..10.0),
                )
            })
            .collect();

        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        for _ in 0..10 {
            let query = Point3f::new(
                rng.gen_range(-5.0..5.0),
                rng.gen_range(-5.0..5.0),
                rng.gen_range(-5.0..5.0),
            );
            let k = rng.gen_range(1..=10);
            let radius = rng.gen_range(1.0..5.0);

            assert_matches_brute_force(&kdtree, &brute_force, &query, k, radius);
        }
    }

    /// Input sorted along an axis, a constant z coordinate, and exact duplicate points are
    /// the worst cases for median selection; searches must still match brute force.
    #[test]
    fn test_sorted_and_duplicate_points() {
        let mut points: Vec<Point3f> = (0..20)
            .flat_map(|i| (0..20).map(move |j| Point3f::new(i as f32, j as f32, 0.0)))
            .collect();
        points.extend((0..50).map(|_| Point3f::new(5.0, 5.0, 0.0)));

        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let queries = [
            Point3f::new(5.0, 5.0, 0.0),
            Point3f::new(9.5, 3.2, 1.0),
            Point3f::new(-1.0, 25.0, 0.0),
        ];
        for query in queries.iter() {
            for &k in [1usize, 7, 60].iter() {
                assert_matches_brute_force(&kdtree, &brute_force, query, k, 2.5);
            }
        }
    }

    #[test]
    fn test_performance_comparison() {
        let mut rng = rand::thread_rng();
        let points: Vec<Point3f> = (0..1000)
            .map(|_| {
                Point3f::new(
                    rng.gen_range(-10.0..10.0),
                    rng.gen_range(-10.0..10.0),
                    rng.gen_range(-10.0..10.0),
                )
            })
            .collect();

        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.0, 0.0, 0.0);
        let k = 10;

        let start = std::time::Instant::now();
        let mut kdtree_result = kdtree.find_k_nearest(&query, k);
        let kdtree_time = start.elapsed();

        let start = std::time::Instant::now();
        let mut brute_result = brute_force.find_k_nearest(&query, k);
        let brute_time = start.elapsed();

        // Timings are informational only (too noisy to assert on); the answers must agree.
        println!("KD-tree time: {:?}", kdtree_time);
        println!("Brute force time: {:?}", brute_time);

        sort_by_distance_then_index(&mut kdtree_result);
        sort_by_distance_then_index(&mut brute_result);
        assert_eq!(kdtree_result.len(), k);
        assert_same_distances(&kdtree_result, &brute_result);
    }

    #[test]
    fn test_debug_k_nearest() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;

        let mut kdtree_result = kdtree.find_k_nearest(&query, k);
        let mut brute_force_result = brute_force.find_k_nearest(&query, k);

        // Independent reference: compute every distance by hand and keep the k smallest.
        let mut manual_distances: Vec<(usize, f32)> = points
            .iter()
            .enumerate()
            .map(|(i, point)| {
                let dx = point.x - query.x;
                let dy = point.y - query.y;
                let dz = point.z - query.z;
                (i, (dx * dx + dy * dy + dz * dz).sqrt())
            })
            .collect();
        sort_by_distance_then_index(&mut manual_distances);
        manual_distances.truncate(k);

        sort_by_distance_then_index(&mut kdtree_result);
        sort_by_distance_then_index(&mut brute_force_result);

        assert_eq!(kdtree_result.len(), k);
        assert_same_distances(&kdtree_result, &manual_distances);
        assert_same_distances(&brute_force_result, &manual_distances);
    }
}
706}