1use threecrate_core::{Point3f, Result, NearestNeighborSearch};
4use std::collections::BinaryHeap;
5use std::cmp::Ordering;
6
7#[derive(Debug)]
9struct KdNode {
10 point: Point3f,
11 original_index: usize, left: Option<Box<KdNode>>,
13 right: Option<Box<KdNode>>,
14 axis: usize, }
16
17impl KdNode {
18 fn new(point: Point3f, original_index: usize, axis: usize) -> Self {
19 Self {
20 point,
21 original_index,
22 left: None,
23 right: None,
24 axis,
25 }
26 }
27}
28
29pub struct KdTree {
31 root: Option<Box<KdNode>>,
32 points: Vec<Point3f>, }
34
35impl KdTree {
36 pub fn new(points: &[Point3f]) -> Result<Self> {
38 if points.is_empty() {
39 return Ok(Self {
40 root: None,
41 points: Vec::new(),
42 });
43 }
44
45 let mut points_with_indices: Vec<(Point3f, usize)> = points
46 .iter()
47 .enumerate()
48 .map(|(i, &point)| (point, i))
49 .collect();
50
51 let root = Self::build_tree(&mut points_with_indices, 0, 0, points.len() - 1);
52
53 Ok(Self {
54 root: Some(Box::new(root)),
55 points: points.to_vec(),
56 })
57 }
58
59 fn build_tree(points: &mut [(Point3f, usize)], depth: usize, start: usize, end: usize) -> KdNode {
61 if start == end {
62 let (point, index) = points[start];
63 return KdNode::new(point, index, depth % 3);
64 }
65
66 let axis = depth % 3;
67 let median_idx = (start + end) / 2;
68
69 Self::select_median(points, start, end, median_idx, axis);
71
72 let (point, index) = points[median_idx];
73 let mut node = KdNode::new(point, index, axis);
74
75 if median_idx > start {
77 node.left = Some(Box::new(Self::build_tree(points, depth + 1, start, median_idx - 1)));
78 }
79
80 if median_idx < end {
82 node.right = Some(Box::new(Self::build_tree(points, depth + 1, median_idx + 1, end)));
83 }
84
85 node
86 }
87
88 fn select_median(points: &mut [(Point3f, usize)], start: usize, end: usize, target: usize, axis: usize) {
90 let mut left = start;
91 let mut right = end;
92
93 while left < right {
94 let pivot_idx = Self::partition(points, left, right, axis);
95
96 match pivot_idx.cmp(&target) {
97 Ordering::Equal => return,
98 Ordering::Less => left = pivot_idx + 1,
99 Ordering::Greater => right = pivot_idx - 1,
100 }
101 }
102 }
103
104 fn partition(points: &mut [(Point3f, usize)], start: usize, end: usize, axis: usize) -> usize {
106 let pivot_value = match axis {
107 0 => points[end].0.x,
108 1 => points[end].0.y,
109 2 => points[end].0.z,
110 _ => unreachable!(),
111 };
112
113 let mut i = start;
114 for j in start..end {
115 let point_value = match axis {
116 0 => points[j].0.x,
117 1 => points[j].0.y,
118 2 => points[j].0.z,
119 _ => unreachable!(),
120 };
121
122 if point_value <= pivot_value {
123 points.swap(i, j);
124 i += 1;
125 }
126 }
127
128 points.swap(i, end);
129 i
130 }
131
132 fn distance_squared(a: &Point3f, b: &Point3f) -> f32 {
134 let dx = a.x - b.x;
135 let dy = a.y - b.y;
136 let dz = a.z - b.z;
137 dx * dx + dy * dy + dz * dz
138 }
139}
140
141impl NearestNeighborSearch for KdTree {
142 fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
149 if k == 0 || self.points.is_empty() {
150 return Vec::new();
151 }
152
153 let mut heap: BinaryHeap<Neighbor> = BinaryHeap::with_capacity(k + 1);
156 let mut stack: Vec<&KdNode> = Vec::new();
157
158 if let Some(ref root) = self.root {
159 stack.push(root);
160 }
161
162 while let Some(node) = stack.pop() {
163 let dist = Self::distance_squared(&node.point, query).sqrt();
164
165 if heap.len() < k {
166 heap.push(Neighbor { distance: dist, index: node.original_index });
167 } else if let Some(farthest) = heap.peek() {
168 if dist < farthest.distance {
169 heap.pop();
170 heap.push(Neighbor { distance: dist, index: node.original_index });
171 }
172 }
173
174 let query_val = query.coords[node.axis];
175 let node_val = node.point.coords[node.axis];
176 let axis_dist = (query_val - node_val).abs();
177
178 let (near, far) = if query_val <= node_val {
182 (&node.left, &node.right)
183 } else {
184 (&node.right, &node.left)
185 };
186
187 let search_far = if let Some(farthest) = heap.peek() {
191 heap.len() < k || axis_dist < farthest.distance
192 } else {
193 true
194 };
195 if search_far {
196 if let Some(ref far_node) = far {
197 stack.push(far_node);
198 }
199 }
200 if let Some(ref near_node) = near {
201 stack.push(near_node);
202 }
203 }
204
205 heap.into_sorted_vec()
207 .into_iter()
208 .map(|n| (n.index, n.distance))
209 .collect()
210 }
211
212 fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
214 if radius <= 0.0 || self.points.is_empty() {
215 return Vec::new();
216 }
217
218 let radius_sq = radius * radius;
219 let mut result: Vec<(usize, f32)> = Vec::new();
220 let mut stack: Vec<&KdNode> = Vec::new();
221
222 if let Some(ref root) = self.root {
223 stack.push(root);
224 }
225
226 while let Some(node) = stack.pop() {
227 let dist_sq = Self::distance_squared(&node.point, query);
228 if dist_sq <= radius_sq {
229 result.push((node.original_index, dist_sq.sqrt()));
230 }
231
232 let query_val = query.coords[node.axis];
233 let node_val = node.point.coords[node.axis];
234 let axis_dist = query_val - node_val;
235
236 let (near, far) = if query_val <= node_val {
237 (&node.left, &node.right)
238 } else {
239 (&node.right, &node.left)
240 };
241
242 if axis_dist * axis_dist <= radius_sq {
245 if let Some(ref far_node) = far {
246 stack.push(far_node);
247 }
248 }
249 if let Some(ref near_node) = near {
250 stack.push(near_node);
251 }
252 }
253
254 result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
255 result
256 }
257}
258
/// A candidate entry in the max-heap used by `KdTree::find_k_nearest`.
///
/// Entries order by `distance` using `f32::total_cmp`, so a NaN distance sorts after every
/// other value instead of comparing "equal" to everything. Ties are broken by `index`.
/// Equality is defined through the same ordering, so `PartialEq`, `Eq`, `PartialOrd` and
/// `Ord` all agree, as `BinaryHeap` requires.
#[derive(Debug, Clone, Copy)]
struct Neighbor {
    distance: f32,
    index: usize,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.index.cmp(&other.index))
    }
}
281
282pub struct BruteForceSearch {
284 points: Vec<Point3f>,
285}
286
287impl BruteForceSearch {
288 pub fn new(points: &[Point3f]) -> Self {
289 Self {
290 points: points.to_vec(),
291 }
292 }
293}
294
295impl NearestNeighborSearch for BruteForceSearch {
296 fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
297 if k == 0 || self.points.is_empty() {
298 return Vec::new();
299 }
300
301 let mut distances: Vec<(usize, f32)> = self.points
302 .iter()
303 .enumerate()
304 .map(|(idx, point)| {
305 let dx = point.x - query.x;
306 let dy = point.y - query.y;
307 let dz = point.z - query.z;
308 let distance = (dx * dx + dy * dy + dz * dz).sqrt();
309 (idx, distance)
310 })
311 .collect();
312
313 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
315 distances.truncate(k);
316 distances
317 }
318
319 fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
320 if radius <= 0.0 || self.points.is_empty() {
321 return Vec::new();
322 }
323
324 let radius_squared = radius * radius;
325 self.points
326 .iter()
327 .enumerate()
328 .filter_map(|(idx, point)| {
329 let dx = point.x - query.x;
330 let dy = point.y - query.y;
331 let dz = point.z - query.z;
332 let distance_squared = dx * dx + dy * dy + dz * dz;
333
334 if distance_squared <= radius_squared {
335 Some((idx, distance_squared.sqrt()))
336 } else {
337 None
338 }
339 })
340 .collect()
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Fixed RNG seed so a failing randomized test can be reproduced exactly.
    const SEED: u64 = 0x3C0F_FEE5;

    /// The eight corners of the unit cube.
    fn create_test_points() -> Vec<Point3f> {
        vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 0.0),
            Point3f::new(1.0, 0.0, 1.0),
            Point3f::new(0.0, 1.0, 1.0),
            Point3f::new(1.0, 1.0, 1.0),
        ]
    }

    /// `count` points drawn uniformly from the cube `[-10, 10)^3`.
    fn random_points(rng: &mut StdRng, count: usize) -> Vec<Point3f> {
        (0..count)
            .map(|_| {
                Point3f::new(
                    rng.gen_range(-10.0..10.0),
                    rng.gen_range(-10.0..10.0),
                    rng.gen_range(-10.0..10.0),
                )
            })
            .collect()
    }

    /// Orders results by ascending distance, then ascending index, so the outputs of
    /// different search structures can be compared element by element.
    fn sort_by_distance_then_index(results: &mut [(usize, f32)]) {
        results.sort_by(|a, b| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0)));
    }

    /// The indices of `results`, sorted ascending.
    fn sorted_indices(results: &[(usize, f32)]) -> Vec<usize> {
        let mut ids: Vec<usize> = results.iter().map(|&(idx, _)| idx).collect();
        ids.sort_unstable();
        ids
    }

    /// Asserts that `results` is ordered by non-decreasing distance, as returned.
    fn assert_sorted_by_distance(results: &[(usize, f32)]) {
        for pair in results.windows(2) {
            assert!(pair[0].1 <= pair[1].1, "not sorted by distance: {:?}", results);
        }
    }

    /// Asserts two result lists have equal length and element-wise distances within `1e-6`.
    ///
    /// Indices are deliberately not compared: when several points tie at the k-th distance,
    /// either structure may legitimately return a different one of them.
    fn assert_same_distances(kdtree: &[(usize, f32)], brute: &[(usize, f32)]) {
        assert_eq!(kdtree.len(), brute.len(), "kd={:?} bf={:?}", kdtree, brute);
        for (kd, bf) in kdtree.iter().zip(brute) {
            assert!(
                (kd.1 - bf.1).abs() < 1e-6,
                "distance mismatch: kd={}, bf={}",
                kd.1,
                bf.1
            );
        }
    }

    /// Runs k-NN and radius queries on both structures and checks that they agree.
    fn assert_queries_agree(
        kdtree: &KdTree,
        brute: &BruteForceSearch,
        query: &Point3f,
        k: usize,
        radius: f32,
    ) {
        let mut kd_knn = kdtree.find_k_nearest(query, k);
        let mut bf_knn = brute.find_k_nearest(query, k);
        sort_by_distance_then_index(&mut kd_knn);
        sort_by_distance_then_index(&mut bf_knn);
        assert_same_distances(&kd_knn, &bf_knn);

        let mut kd_radius = kdtree.find_radius_neighbors(query, radius);
        let mut bf_radius = brute.find_radius_neighbors(query, radius);
        // Every point inside the radius must be reported, so the index sets match exactly.
        assert_eq!(sorted_indices(&kd_radius), sorted_indices(&bf_radius));
        sort_by_distance_then_index(&mut kd_radius);
        sort_by_distance_then_index(&mut bf_radius);
        assert_same_distances(&kd_radius, &bf_radius);
        assert!(kd_radius.iter().all(|&(_, d)| d <= radius));
    }

    #[test]
    fn test_kd_tree_construction() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();

        assert_eq!(kdtree.points.len(), points.len());
        assert!(kdtree.root.is_some());
    }

    #[test]
    fn test_empty_kd_tree() {
        let kdtree = KdTree::new(&[]).unwrap();
        assert!(kdtree.root.is_none());
        assert!(kdtree.points.is_empty());

        let query = Point3f::new(0.0, 0.0, 0.0);
        assert!(kdtree.find_k_nearest(&query, 5).is_empty());
        assert!(kdtree.find_radius_neighbors(&query, 1.0).is_empty());
    }

    #[test]
    fn test_k_nearest_neighbors_consistency() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;

        let mut kdtree_result = kdtree.find_k_nearest(&query, k);
        let mut brute_force_result = brute_force.find_k_nearest(&query, k);

        // Both implementations promise distance-sorted output; check it before re-sorting.
        assert_sorted_by_distance(&kdtree_result);
        assert_sorted_by_distance(&brute_force_result);

        sort_by_distance_then_index(&mut kdtree_result);
        sort_by_distance_then_index(&mut brute_force_result);
        assert_eq!(kdtree_result.len(), k);
        assert_same_distances(&kdtree_result, &brute_force_result);
    }

    #[test]
    fn test_radius_neighbors_consistency() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.5, 0.5, 0.5);
        let radius = 1.5;

        assert_sorted_by_distance(&kdtree.find_radius_neighbors(&query, radius));
        assert_queries_agree(&kdtree, &brute_force, &query, points.len(), radius);
    }

    #[test]
    fn test_edge_cases() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let query = Point3f::new(0.0, 0.0, 0.0);

        assert!(kdtree.find_k_nearest(&query, 0).is_empty());

        let all = kdtree.find_k_nearest(&query, 20);
        assert_eq!(all.len(), points.len());
        assert_sorted_by_distance(&all);

        assert!(kdtree.find_radius_neighbors(&query, 0.0).is_empty());
        assert!(kdtree.find_radius_neighbors(&query, -1.0).is_empty());
    }

    #[test]
    fn test_random_points() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let points = random_points(&mut rng, 100);
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        for _ in 0..10 {
            let query = Point3f::new(
                rng.gen_range(-5.0..5.0),
                rng.gen_range(-5.0..5.0),
                rng.gen_range(-5.0..5.0),
            );
            let k = rng.gen_range(1..=10);
            let radius = rng.gen_range(1.0..5.0);

            assert_eq!(kdtree.find_k_nearest(&query, k).len(), k.min(points.len()));
            assert_queries_agree(&kdtree, &brute_force, &query, k, radius);
        }
    }

    #[test]
    fn test_duplicate_coordinates() {
        // A planar 10x10 grid: every point has z = 0 and each x/y value repeats ten times,
        // which stresses median selection and the `<=` / `>=` split invariant.
        let points: Vec<Point3f> = (0..10)
            .flat_map(|i| (0..10).map(move |j| Point3f::new(i as f32, j as f32, 0.0)))
            .collect();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let queries = [
            Point3f::new(4.5, 4.5, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(9.2, 3.1, -0.5),
        ];
        for query in &queries {
            assert_queries_agree(&kdtree, &brute_force, query, 5, 1.5);
        }
    }

    #[test]
    fn test_performance_comparison() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let points = random_points(&mut rng, 1000);
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.0, 0.0, 0.0);
        let k = 10;

        let start = std::time::Instant::now();
        let mut kdtree_result = kdtree.find_k_nearest(&query, k);
        let kdtree_time = start.elapsed();

        let start = std::time::Instant::now();
        let mut brute_result = brute_force.find_k_nearest(&query, k);
        let brute_time = start.elapsed();

        println!("KD-tree time: {:?}", kdtree_time);
        println!("Brute force time: {:?}", brute_time);

        // Timings are informational only; what must hold is that both answers agree.
        sort_by_distance_then_index(&mut kdtree_result);
        sort_by_distance_then_index(&mut brute_result);
        assert_eq!(kdtree_result.len(), k);
        assert_same_distances(&kdtree_result, &brute_result);
    }

    #[test]
    fn test_debug_k_nearest() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);

        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;

        let mut kdtree_result = kdtree.find_k_nearest(&query, k);
        let mut brute_force_result = brute_force.find_k_nearest(&query, k);
        sort_by_distance_then_index(&mut kdtree_result);
        sort_by_distance_then_index(&mut brute_force_result);

        assert_eq!(kdtree_result.len(), k);
        assert_same_distances(&kdtree_result, &brute_force_result);
    }
}