1use std::cmp::Ordering;
35use threecrate_core::{NearestNeighborSearch, Point3f};
36
37#[derive(Debug, Clone)]
51pub struct SoaPoints {
52 xs: Vec<f32>,
53 ys: Vec<f32>,
54 zs: Vec<f32>,
55}
56
57impl SoaPoints {
58 pub fn from_points(points: &[Point3f]) -> Self {
60 let mut xs = Vec::with_capacity(points.len());
61 let mut ys = Vec::with_capacity(points.len());
62 let mut zs = Vec::with_capacity(points.len());
63 for p in points {
64 xs.push(p.x);
65 ys.push(p.y);
66 zs.push(p.z);
67 }
68 Self { xs, ys, zs }
69 }
70
71 #[inline]
73 pub fn len(&self) -> usize {
74 self.xs.len()
75 }
76
77 #[inline]
79 pub fn is_empty(&self) -> bool {
80 self.xs.is_empty()
81 }
82
83 #[inline]
85 pub fn xs(&self) -> &[f32] { &self.xs }
86
87 #[inline]
89 pub fn ys(&self) -> &[f32] { &self.ys }
90
91 #[inline]
93 pub fn zs(&self) -> &[f32] { &self.zs }
94}
95
96pub fn batch_distances_squared(query: &Point3f, pts: &SoaPoints, out: &mut [f32]) {
105 debug_assert_eq!(out.len(), pts.len());
106
107 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
108 {
109 if is_x86_feature_detected!("avx2") {
110 return unsafe { avx2_distances_squared(query, pts, out) };
112 }
113 if is_x86_feature_detected!("sse2") {
114 return unsafe { sse2_distances_squared(query, pts, out) };
115 }
116 }
117
118 scalar_distances_squared(query, pts, out);
119}
120
121#[inline]
125pub fn scalar_distances_squared(query: &Point3f, pts: &SoaPoints, out: &mut [f32]) {
126 let (qx, qy, qz) = (query.x, query.y, query.z);
127 let n = pts.len();
128 let xs = pts.xs();
129 let ys = pts.ys();
130 let zs = pts.zs();
131 for i in 0..n {
132 let dx = xs[i] - qx;
133 let dy = ys[i] - qy;
134 let dz = zs[i] - qz;
135 out[i] = dx * dx + dy * dy + dz * dz;
136 }
137}
138
139#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
148#[target_feature(enable = "sse2")]
149unsafe fn sse2_distances_squared(query: &Point3f, pts: &SoaPoints, out: &mut [f32]) {
150 #[cfg(target_arch = "x86")]
151 use std::arch::x86::*;
152 #[cfg(target_arch = "x86_64")]
153 use std::arch::x86_64::*;
154
155 let n = pts.len();
156 let xs = pts.xs();
157 let ys = pts.ys();
158 let zs = pts.zs();
159
160 let qx_v = _mm_set1_ps(query.x);
161 let qy_v = _mm_set1_ps(query.y);
162 let qz_v = _mm_set1_ps(query.z);
163
164 let chunks = n / 4;
165 let remainder = n % 4;
166
167 for c in 0..chunks {
168 let base = c * 4;
169 let xs_v = _mm_loadu_ps(xs.as_ptr().add(base));
170 let ys_v = _mm_loadu_ps(ys.as_ptr().add(base));
171 let zs_v = _mm_loadu_ps(zs.as_ptr().add(base));
172
173 let dx = _mm_sub_ps(xs_v, qx_v);
174 let dy = _mm_sub_ps(ys_v, qy_v);
175 let dz = _mm_sub_ps(zs_v, qz_v);
176
177 let d2 = _mm_add_ps(
178 _mm_add_ps(_mm_mul_ps(dx, dx), _mm_mul_ps(dy, dy)),
179 _mm_mul_ps(dz, dz),
180 );
181
182 _mm_storeu_ps(out.as_mut_ptr().add(base), d2);
183 }
184
185 let rem_start = chunks * 4;
187 scalar_remainder(query, xs, ys, zs, out, rem_start, remainder);
188}
189
190#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
199#[target_feature(enable = "avx2")]
200unsafe fn avx2_distances_squared(query: &Point3f, pts: &SoaPoints, out: &mut [f32]) {
201 #[cfg(target_arch = "x86")]
202 use std::arch::x86::*;
203 #[cfg(target_arch = "x86_64")]
204 use std::arch::x86_64::*;
205
206 let n = pts.len();
207 let xs = pts.xs();
208 let ys = pts.ys();
209 let zs = pts.zs();
210
211 let qx_v = _mm256_set1_ps(query.x);
212 let qy_v = _mm256_set1_ps(query.y);
213 let qz_v = _mm256_set1_ps(query.z);
214
215 let chunks = n / 8;
216 let remainder_start = chunks * 8;
217 let remainder = n - remainder_start;
218
219 for c in 0..chunks {
220 let base = c * 8;
221 let xs_v = _mm256_loadu_ps(xs.as_ptr().add(base));
222 let ys_v = _mm256_loadu_ps(ys.as_ptr().add(base));
223 let zs_v = _mm256_loadu_ps(zs.as_ptr().add(base));
224
225 let dx = _mm256_sub_ps(xs_v, qx_v);
226 let dy = _mm256_sub_ps(ys_v, qy_v);
227 let dz = _mm256_sub_ps(zs_v, qz_v);
228
229 let d2 = _mm256_add_ps(
230 _mm256_add_ps(_mm256_mul_ps(dx, dx), _mm256_mul_ps(dy, dy)),
231 _mm256_mul_ps(dz, dz),
232 );
233
234 _mm256_storeu_ps(out.as_mut_ptr().add(base), d2);
235 }
236
237 let mut rem = remainder;
239 let mut rem_base = remainder_start;
240
241 if rem >= 4 {
242 #[cfg(target_arch = "x86_64")]
244 use std::arch::x86_64::*;
245 let qx_s = _mm_set1_ps(query.x);
246 let qy_s = _mm_set1_ps(query.y);
247 let qz_s = _mm_set1_ps(query.z);
248
249 let xs_v = _mm_loadu_ps(xs.as_ptr().add(rem_base));
250 let ys_v = _mm_loadu_ps(ys.as_ptr().add(rem_base));
251 let zs_v = _mm_loadu_ps(zs.as_ptr().add(rem_base));
252
253 let dx = _mm_sub_ps(xs_v, qx_s);
254 let dy = _mm_sub_ps(ys_v, qy_s);
255 let dz = _mm_sub_ps(zs_v, qz_s);
256
257 let d2 = _mm_add_ps(
258 _mm_add_ps(_mm_mul_ps(dx, dx), _mm_mul_ps(dy, dy)),
259 _mm_mul_ps(dz, dz),
260 );
261 _mm_storeu_ps(out.as_mut_ptr().add(rem_base), d2);
262
263 rem_base += 4;
264 rem -= 4;
265 }
266
267 scalar_remainder(query, xs, ys, zs, out, rem_base, rem);
268}
269
270#[cfg_attr(not(any(target_arch = "x86", target_arch = "x86_64")), allow(dead_code))]
272#[inline(always)]
273fn scalar_remainder(
274 query: &Point3f,
275 xs: &[f32],
276 ys: &[f32],
277 zs: &[f32],
278 out: &mut [f32],
279 start: usize,
280 count: usize,
281) {
282 let (qx, qy, qz) = (query.x, query.y, query.z);
283 for i in 0..count {
284 let idx = start + i;
285 let dx = xs[idx] - qx;
286 let dy = ys[idx] - qy;
287 let dz = zs[idx] - qz;
288 out[idx] = dx * dx + dy * dy + dz * dz;
289 }
290}
291
292pub struct SimdBruteForceSearch {
317 soa: SoaPoints,
318}
319
320impl SimdBruteForceSearch {
321 pub fn new(points: &[Point3f]) -> Self {
323 Self { soa: SoaPoints::from_points(points) }
324 }
325
326 pub fn len(&self) -> usize { self.soa.len() }
328
329 pub fn is_empty(&self) -> bool { self.soa.is_empty() }
331
332 pub fn soa(&self) -> &SoaPoints { &self.soa }
334}
335
336impl NearestNeighborSearch for SimdBruteForceSearch {
337 fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
338 if k == 0 || self.soa.is_empty() {
339 return Vec::new();
340 }
341
342 let n = self.soa.len();
343 let k = k.min(n);
344
345 let mut dist_sq = vec![0.0f32; n];
347 batch_distances_squared(query, &self.soa, &mut dist_sq);
348
349 let mut heap: std::collections::BinaryHeap<DistEntry> =
351 std::collections::BinaryHeap::with_capacity(k + 1);
352
353 for (idx, &d2) in dist_sq.iter().enumerate() {
354 if heap.len() < k {
355 heap.push(DistEntry { dist_sq: d2, index: idx });
356 } else if let Some(farthest) = heap.peek() {
357 if d2 < farthest.dist_sq {
358 heap.pop();
359 heap.push(DistEntry { dist_sq: d2, index: idx });
360 }
361 }
362 }
363
364 let mut result: Vec<(usize, f32)> = heap
366 .into_iter()
367 .map(|e| (e.index, e.dist_sq.sqrt()))
368 .collect();
369 result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
370 result
371 }
372
373 fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
374 if radius <= 0.0 || self.soa.is_empty() {
375 return Vec::new();
376 }
377
378 let n = self.soa.len();
379 let radius_sq = radius * radius;
380
381 let mut dist_sq = vec![0.0f32; n];
383 batch_distances_squared(query, &self.soa, &mut dist_sq);
384
385 let mut result: Vec<(usize, f32)> = dist_sq
387 .iter()
388 .enumerate()
389 .filter_map(|(idx, &d2)| {
390 if d2 <= radius_sq { Some((idx, d2.sqrt())) } else { None }
391 })
392 .collect();
393
394 result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
395 result
396 }
397}
398
/// Heap entry pairing a squared distance with the index of its point.
///
/// Entries are ordered by `dist_sq` using [`f32::total_cmp`], then by `index`.
/// This is a total order even when distances are NaN. Equality is derived from
/// that same comparison, so `PartialEq`, `Eq` and `Ord` always agree. A derived
/// `PartialEq` would report `NaN != NaN` and break `Eq`'s reflexivity.
#[derive(Debug, Clone, Copy)]
struct DistEntry {
    dist_sq: f32,
    index: usize,
}

impl PartialEq for DistEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DistEntry {}

impl PartialOrd for DistEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DistEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist_sq
            .total_cmp(&other.dist_sq)
            .then(self.index.cmp(&other.index))
    }
}
426
#[cfg(test)]
mod tests {
    use super::*;
    use threecrate_core::Point3f;

    /// Corners of the unit cube; position `i` is the expected point index.
    fn cube_points() -> Vec<Point3f> {
        vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 0.0),
            Point3f::new(1.0, 0.0, 1.0),
            Point3f::new(0.0, 1.0, 1.0),
            Point3f::new(1.0, 1.0, 1.0),
        ]
    }

    /// Sorts search results by distance, then index, so result lists from two
    /// backends can be compared element-wise regardless of internal ordering.
    fn sort_results(results: &mut [(usize, f32)]) {
        results.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .expect("search returned a NaN distance")
                .then(a.0.cmp(&b.0))
        });
    }

    #[test]
    fn test_soa_layout() {
        let pts = cube_points();
        let soa = SoaPoints::from_points(&pts);
        assert_eq!(soa.len(), pts.len());
        for (i, p) in pts.iter().enumerate() {
            assert_eq!(soa.xs()[i], p.x);
            assert_eq!(soa.ys()[i], p.y);
            assert_eq!(soa.zs()[i], p.z);
        }
    }

    /// Straightforward array-of-structs reference for squared distances.
    fn reference_dist_sq(query: &Point3f, pts: &[Point3f]) -> Vec<f32> {
        pts.iter()
            .map(|p| {
                let dx = p.x - query.x;
                let dy = p.y - query.y;
                let dz = p.z - query.z;
                dx * dx + dy * dy + dz * dz
            })
            .collect()
    }

    #[test]
    fn test_scalar_distances_match_reference() {
        let pts = cube_points();
        let soa = SoaPoints::from_points(&pts);
        let query = Point3f::new(0.5, 0.5, 0.5);
        let reference = reference_dist_sq(&query, &pts);
        let mut out = vec![0.0f32; pts.len()];
        scalar_distances_squared(&query, &soa, &mut out);
        for (got, expected) in out.iter().zip(reference.iter()) {
            assert!((got - expected).abs() < 1e-6, "got={got}, expected={expected}");
        }
    }

    #[test]
    fn test_batch_distances_match_scalar() {
        let pts = cube_points();
        let soa = SoaPoints::from_points(&pts);
        let query = Point3f::new(0.3, 0.7, 0.2);

        let mut scalar_out = vec![0.0f32; pts.len()];
        scalar_distances_squared(&query, &soa, &mut scalar_out);

        let mut simd_out = vec![0.0f32; pts.len()];
        batch_distances_squared(&query, &soa, &mut simd_out);

        for (got, expected) in simd_out.iter().zip(scalar_out.iter()) {
            assert!(
                (got - expected).abs() < 1e-5,
                "SIMD={got}, scalar={expected}"
            );
        }
    }

    /// Sizes straddle the 4- and 8-lane boundaries so every tail path runs.
    #[test]
    fn test_batch_distances_various_sizes() {
        for n in [1, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 100] {
            let pts: Vec<Point3f> = (0..n)
                .map(|i| Point3f::new(i as f32, (i * 2) as f32, (i * 3) as f32))
                .collect();
            let soa = SoaPoints::from_points(&pts);
            let query = Point3f::new(5.0, 10.0, 15.0);
            let reference = reference_dist_sq(&query, &pts);

            let mut simd_out = vec![0.0f32; n];
            batch_distances_squared(&query, &soa, &mut simd_out);

            for (i, (got, expected)) in simd_out.iter().zip(reference.iter()).enumerate() {
                assert!(
                    (got - expected).abs() < 1e-4,
                    "n={n} i={i}: SIMD={got}, ref={expected}"
                );
            }
        }
    }

    #[test]
    fn test_simd_knn_matches_brute_force() {
        use crate::nearest_neighbor::BruteForceSearch;
        let pts = cube_points();
        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;

        let simd = SimdBruteForceSearch::new(&pts);
        let scalar = BruteForceSearch::new(&pts);

        let mut simd_res = simd.find_k_nearest(&query, k);
        let mut scalar_res = scalar.find_k_nearest(&query, k);
        sort_results(&mut simd_res);
        sort_results(&mut scalar_res);

        assert_eq!(simd_res.len(), k);
        // Every corner is equidistant from the cube centre, so the backends may
        // legitimately pick different indices; only distances are compared.
        for ((_, sd), (_, bd)) in simd_res.iter().zip(scalar_res.iter()) {
            assert!((sd - bd).abs() < 1e-5, "dist mismatch: simd={sd} scalar={bd}");
        }
    }

    #[test]
    fn test_knn_asymmetric_query_indices() {
        // Every corner is at a distinct distance from this query, so the three
        // nearest are unambiguous: (0,0,0), (0,0,1), (0,1,0).
        let pts = cube_points();
        let simd = SimdBruteForceSearch::new(&pts);
        let query = Point3f::new(0.1, 0.2, 0.3);

        let res = simd.find_k_nearest(&query, 3);
        let indices: Vec<usize> = res.iter().map(|&(i, _)| i).collect();
        assert_eq!(indices, vec![0, 3, 2]);
        assert!((res[0].1 - 0.14f32.sqrt()).abs() < 1e-5);
    }

    #[test]
    fn test_simd_radius_matches_brute_force() {
        use crate::nearest_neighbor::BruteForceSearch;
        let pts = cube_points();
        let query = Point3f::new(0.5, 0.5, 0.5);
        let radius = 1.0;

        let simd = SimdBruteForceSearch::new(&pts);
        let scalar = BruteForceSearch::new(&pts);

        let mut simd_res = simd.find_radius_neighbors(&query, radius);
        let mut scalar_res = scalar.find_radius_neighbors(&query, radius);
        sort_results(&mut simd_res);
        sort_results(&mut scalar_res);

        assert_eq!(simd_res.len(), scalar_res.len(), "result count mismatch");
        for ((_, sd), (_, bd)) in simd_res.iter().zip(scalar_res.iter()) {
            assert!((sd - bd).abs() < 1e-5);
        }
    }

    #[test]
    fn test_radius_boundary_is_inclusive() {
        // From the origin the axis corners sit exactly at distance 1.0 and must
        // be included; the face/body diagonals are farther away.
        let pts = cube_points();
        let simd = SimdBruteForceSearch::new(&pts);
        let res = simd.find_radius_neighbors(&Point3f::new(0.0, 0.0, 0.0), 1.0);

        assert_eq!(res.first().map(|&(i, _)| i), Some(0));
        let mut indices: Vec<usize> = res.iter().map(|&(i, _)| i).collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_empty_cloud() {
        let simd = SimdBruteForceSearch::new(&[]);
        let q = Point3f::new(0.0, 0.0, 0.0);
        assert!(simd.find_k_nearest(&q, 5).is_empty());
        assert!(simd.find_radius_neighbors(&q, 10.0).is_empty());
    }

    #[test]
    fn test_k_larger_than_cloud() {
        let pts = cube_points();
        let simd = SimdBruteForceSearch::new(&pts);
        let q = Point3f::new(0.0, 0.0, 0.0);
        let result = simd.find_k_nearest(&q, 100);
        assert_eq!(result.len(), pts.len());
    }

    #[test]
    fn test_exact_origin_distance() {
        let pts = vec![Point3f::new(3.0, 4.0, 0.0)];
        let soa = SoaPoints::from_points(&pts);
        let query = Point3f::origin();
        let mut out = vec![0.0f32; 1];
        batch_distances_squared(&query, &soa, &mut out);
        assert!((out[0] - 25.0).abs() < 1e-6);
    }
}