Skip to main content

threecrate_algorithms/
simd_distance.rs

1//! SIMD-accelerated distance computations for nearest-neighbor search.
2//!
3//! # Strategy
4//!
5//! Points are stored in **Structure-of-Arrays (SoA)** layout — three separate
6//! `f32` slices for X, Y, and Z — so that SIMD lanes map directly onto
7//! multiple points at once.
8//!
9//! ```text
10//! AoS (cache-unfriendly for distance):   [(x0,y0,z0), (x1,y1,z1), …]
11//! SoA (SIMD-friendly):                   xs=[x0,x1,…]  ys=[y0,y1,…]  zs=[z0,z1,…]
12//! ```
13//!
14//! Given a query `q = (qx, qy, qz)` the squared distance to point `i` is:
15//!
16//! ```text
17//! d²ᵢ = (xᵢ−qx)² + (yᵢ−qy)² + (zᵢ−qz)²
18//! ```
19//!
//! SIMD processes 4 (SSE2) or 8 (AVX2) distances per instruction.
21//!
22//! # Dispatch
23//!
24//! ```text
25//! has AVX2?  → avx2_distances_squared  (8-wide f32)
26//! else SSE2? → sse2_distances_squared  (4-wide f32, always present on x86-64)
27//! else       → scalar_distances_squared (portable fallback)
28//! ```
29//!
//! On x86/x86-64 the dispatch is resolved **at runtime** using
//! `is_x86_feature_detected!`, so the same binary runs correctly on any x86 CPU
//! while using the widest SIMD available. Other architectures always take the
//! scalar path, selected at compile time via `cfg`.
33
34use std::cmp::Ordering;
35use threecrate_core::{NearestNeighborSearch, Point3f};
36
37// ---------------------------------------------------------------------------
38// SoA point store
39// ---------------------------------------------------------------------------
40
41/// Point cloud stored in Structure-of-Arrays format for SIMD-friendly access.
42///
43/// ```
44/// # use threecrate_algorithms::SoaPoints;
45/// # use threecrate_core::Point3f;
46/// let pts = vec![Point3f::new(1.0, 2.0, 3.0), Point3f::new(4.0, 5.0, 6.0)];
47/// let soa = SoaPoints::from_points(&pts);
48/// assert_eq!(soa.xs(), &[1.0, 4.0]);
49/// ```
50#[derive(Debug, Clone)]
51pub struct SoaPoints {
52    xs: Vec<f32>,
53    ys: Vec<f32>,
54    zs: Vec<f32>,
55}
56
57impl SoaPoints {
58    /// Build an SoA store from an AoS slice.
59    pub fn from_points(points: &[Point3f]) -> Self {
60        let mut xs = Vec::with_capacity(points.len());
61        let mut ys = Vec::with_capacity(points.len());
62        let mut zs = Vec::with_capacity(points.len());
63        for p in points {
64            xs.push(p.x);
65            ys.push(p.y);
66            zs.push(p.z);
67        }
68        Self { xs, ys, zs }
69    }
70
71    /// Number of stored points.
72    #[inline]
73    pub fn len(&self) -> usize {
74        self.xs.len()
75    }
76
77    /// Returns `true` if there are no stored points.
78    #[inline]
79    pub fn is_empty(&self) -> bool {
80        self.xs.is_empty()
81    }
82
83    /// X coordinates slice.
84    #[inline]
85    pub fn xs(&self) -> &[f32] { &self.xs }
86
87    /// Y coordinates slice.
88    #[inline]
89    pub fn ys(&self) -> &[f32] { &self.ys }
90
91    /// Z coordinates slice.
92    #[inline]
93    pub fn zs(&self) -> &[f32] { &self.zs }
94}
95
96// ---------------------------------------------------------------------------
97// Public batch-distance API
98// ---------------------------------------------------------------------------
99
100/// Compute squared Euclidean distances from `query` to every point in `pts`.
101///
102/// Results are written into `out`, which must have the same length as `pts`.
103/// Uses the widest SIMD available (AVX2 → SSE2 → scalar).
104pub fn batch_distances_squared(query: &Point3f, pts: &SoaPoints, out: &mut [f32]) {
105    debug_assert_eq!(out.len(), pts.len());
106
107    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
108    {
109        if is_x86_feature_detected!("avx2") {
110            // SAFETY: feature was detected at runtime.
111            return unsafe { avx2_distances_squared(query, pts, out) };
112        }
113        if is_x86_feature_detected!("sse2") {
114            return unsafe { sse2_distances_squared(query, pts, out) };
115        }
116    }
117
118    scalar_distances_squared(query, pts, out);
119}
120
121/// Portable, vectorisation-friendly fallback (also the reference implementation).
122///
123/// This loop is written so that LLVM can auto-vectorise it on any target.
124#[inline]
125pub fn scalar_distances_squared(query: &Point3f, pts: &SoaPoints, out: &mut [f32]) {
126    let (qx, qy, qz) = (query.x, query.y, query.z);
127    let n = pts.len();
128    let xs = pts.xs();
129    let ys = pts.ys();
130    let zs = pts.zs();
131    for i in 0..n {
132        let dx = xs[i] - qx;
133        let dy = ys[i] - qy;
134        let dz = zs[i] - qz;
135        out[i] = dx * dx + dy * dy + dz * dz;
136    }
137}
138
139// ---------------------------------------------------------------------------
140// SSE2 implementation (4-wide, always available on x86-64)
141// ---------------------------------------------------------------------------
142
143/// Compute distances using SSE2 (4 × f32 per cycle).
144///
145/// # Safety
146/// Caller must ensure the `sse2` CPU feature is present.
147#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
148#[target_feature(enable = "sse2")]
149unsafe fn sse2_distances_squared(query: &Point3f, pts: &SoaPoints, out: &mut [f32]) {
150    #[cfg(target_arch = "x86")]
151    use std::arch::x86::*;
152    #[cfg(target_arch = "x86_64")]
153    use std::arch::x86_64::*;
154
155    let n = pts.len();
156    let xs = pts.xs();
157    let ys = pts.ys();
158    let zs = pts.zs();
159
160    let qx_v = _mm_set1_ps(query.x);
161    let qy_v = _mm_set1_ps(query.y);
162    let qz_v = _mm_set1_ps(query.z);
163
164    let chunks = n / 4;
165    let remainder = n % 4;
166
167    for c in 0..chunks {
168        let base = c * 4;
169        let xs_v = _mm_loadu_ps(xs.as_ptr().add(base));
170        let ys_v = _mm_loadu_ps(ys.as_ptr().add(base));
171        let zs_v = _mm_loadu_ps(zs.as_ptr().add(base));
172
173        let dx = _mm_sub_ps(xs_v, qx_v);
174        let dy = _mm_sub_ps(ys_v, qy_v);
175        let dz = _mm_sub_ps(zs_v, qz_v);
176
177        let d2 = _mm_add_ps(
178            _mm_add_ps(_mm_mul_ps(dx, dx), _mm_mul_ps(dy, dy)),
179            _mm_mul_ps(dz, dz),
180        );
181
182        _mm_storeu_ps(out.as_mut_ptr().add(base), d2);
183    }
184
185    // Handle remainder with scalar code.
186    let rem_start = chunks * 4;
187    scalar_remainder(query, xs, ys, zs, out, rem_start, remainder);
188}
189
190// ---------------------------------------------------------------------------
191// AVX2 implementation (8-wide)
192// ---------------------------------------------------------------------------
193
194/// Compute distances using AVX2 (8 × f32 per cycle).
195///
196/// # Safety
197/// Caller must ensure the `avx2` CPU feature is present.
198#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
199#[target_feature(enable = "avx2")]
200unsafe fn avx2_distances_squared(query: &Point3f, pts: &SoaPoints, out: &mut [f32]) {
201    #[cfg(target_arch = "x86")]
202    use std::arch::x86::*;
203    #[cfg(target_arch = "x86_64")]
204    use std::arch::x86_64::*;
205
206    let n = pts.len();
207    let xs = pts.xs();
208    let ys = pts.ys();
209    let zs = pts.zs();
210
211    let qx_v = _mm256_set1_ps(query.x);
212    let qy_v = _mm256_set1_ps(query.y);
213    let qz_v = _mm256_set1_ps(query.z);
214
215    let chunks = n / 8;
216    let remainder_start = chunks * 8;
217    let remainder = n - remainder_start;
218
219    for c in 0..chunks {
220        let base = c * 8;
221        let xs_v = _mm256_loadu_ps(xs.as_ptr().add(base));
222        let ys_v = _mm256_loadu_ps(ys.as_ptr().add(base));
223        let zs_v = _mm256_loadu_ps(zs.as_ptr().add(base));
224
225        let dx = _mm256_sub_ps(xs_v, qx_v);
226        let dy = _mm256_sub_ps(ys_v, qy_v);
227        let dz = _mm256_sub_ps(zs_v, qz_v);
228
229        let d2 = _mm256_add_ps(
230            _mm256_add_ps(_mm256_mul_ps(dx, dx), _mm256_mul_ps(dy, dy)),
231            _mm256_mul_ps(dz, dz),
232        );
233
234        _mm256_storeu_ps(out.as_mut_ptr().add(base), d2);
235    }
236
237    // Use SSE2 for the leftover 4-element block (if any), then scalar for the rest.
238    let mut rem = remainder;
239    let mut rem_base = remainder_start;
240
241    if rem >= 4 {
242        // We already ensured avx2 implies sse2 on x86-64, so this is safe.
243        #[cfg(target_arch = "x86_64")]
244        use std::arch::x86_64::*;
245        let qx_s = _mm_set1_ps(query.x);
246        let qy_s = _mm_set1_ps(query.y);
247        let qz_s = _mm_set1_ps(query.z);
248
249        let xs_v = _mm_loadu_ps(xs.as_ptr().add(rem_base));
250        let ys_v = _mm_loadu_ps(ys.as_ptr().add(rem_base));
251        let zs_v = _mm_loadu_ps(zs.as_ptr().add(rem_base));
252
253        let dx = _mm_sub_ps(xs_v, qx_s);
254        let dy = _mm_sub_ps(ys_v, qy_s);
255        let dz = _mm_sub_ps(zs_v, qz_s);
256
257        let d2 = _mm_add_ps(
258            _mm_add_ps(_mm_mul_ps(dx, dx), _mm_mul_ps(dy, dy)),
259            _mm_mul_ps(dz, dz),
260        );
261        _mm_storeu_ps(out.as_mut_ptr().add(rem_base), d2);
262
263        rem_base += 4;
264        rem -= 4;
265    }
266
267    scalar_remainder(query, xs, ys, zs, out, rem_base, rem);
268}
269
270/// Scalar tail processing for SIMD functions.
271#[cfg_attr(not(any(target_arch = "x86", target_arch = "x86_64")), allow(dead_code))]
272#[inline(always)]
273fn scalar_remainder(
274    query: &Point3f,
275    xs: &[f32],
276    ys: &[f32],
277    zs: &[f32],
278    out: &mut [f32],
279    start: usize,
280    count: usize,
281) {
282    let (qx, qy, qz) = (query.x, query.y, query.z);
283    for i in 0..count {
284        let idx = start + i;
285        let dx = xs[idx] - qx;
286        let dy = ys[idx] - qy;
287        let dz = zs[idx] - qz;
288        out[idx] = dx * dx + dy * dy + dz * dz;
289    }
290}
291
292// ---------------------------------------------------------------------------
293// SimdBruteForceSearch
294// ---------------------------------------------------------------------------
295
296/// Brute-force nearest-neighbor search with SIMD-accelerated distance computation.
297///
298/// All N squared distances are computed in a single SIMD-vectorised pass before
299/// selection, which is more cache-friendly than repeated point-by-point comparison
300/// and exploits the full width of the CPU's SIMD units.
301///
302/// | Method            | Complexity | Notes                                |
303/// |-------------------|------------|--------------------------------------|
304/// | `find_k_nearest`  | O(N + k log k) | compute-all-then-partial-select  |
305/// | `find_radius_neighbors` | O(N) | compute-all-then-filter           |
306///
307/// # Example
308/// ```
309/// # use threecrate_algorithms::SimdBruteForceSearch;
310/// # use threecrate_core::{Point3f, NearestNeighborSearch};
311/// let pts = vec![Point3f::new(0.0, 0.0, 0.0), Point3f::new(1.0, 0.0, 0.0)];
312/// let searcher = SimdBruteForceSearch::new(&pts);
313/// let result = searcher.find_k_nearest(&Point3f::new(0.1, 0.0, 0.0), 1);
314/// assert_eq!(result[0].0, 0); // index of nearest point
315/// ```
316pub struct SimdBruteForceSearch {
317    soa: SoaPoints,
318}
319
320impl SimdBruteForceSearch {
321    /// Construct a new searcher from a point slice (O(N) time and space).
322    pub fn new(points: &[Point3f]) -> Self {
323        Self { soa: SoaPoints::from_points(points) }
324    }
325
326    /// Number of indexed points.
327    pub fn len(&self) -> usize { self.soa.len() }
328
329    /// Returns `true` if the index is empty.
330    pub fn is_empty(&self) -> bool { self.soa.is_empty() }
331
332    /// Return the SoA representation (useful for benchmarks / inspection).
333    pub fn soa(&self) -> &SoaPoints { &self.soa }
334}
335
336impl NearestNeighborSearch for SimdBruteForceSearch {
337    fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
338        if k == 0 || self.soa.is_empty() {
339            return Vec::new();
340        }
341
342        let n = self.soa.len();
343        let k = k.min(n);
344
345        // ---- 1. Compute all squared distances in one SIMD pass ----
346        let mut dist_sq = vec![0.0f32; n];
347        batch_distances_squared(query, &self.soa, &mut dist_sq);
348
349        // ---- 2. Partial sort: find the k smallest using a max-heap ----
350        let mut heap: std::collections::BinaryHeap<DistEntry> =
351            std::collections::BinaryHeap::with_capacity(k + 1);
352
353        for (idx, &d2) in dist_sq.iter().enumerate() {
354            if heap.len() < k {
355                heap.push(DistEntry { dist_sq: d2, index: idx });
356            } else if let Some(farthest) = heap.peek() {
357                if d2 < farthest.dist_sq {
358                    heap.pop();
359                    heap.push(DistEntry { dist_sq: d2, index: idx });
360                }
361            }
362        }
363
364        // ---- 3. Extract and sort ascending by distance ----
365        let mut result: Vec<(usize, f32)> = heap
366            .into_iter()
367            .map(|e| (e.index, e.dist_sq.sqrt()))
368            .collect();
369        result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
370        result
371    }
372
373    fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
374        if radius <= 0.0 || self.soa.is_empty() {
375            return Vec::new();
376        }
377
378        let n = self.soa.len();
379        let radius_sq = radius * radius;
380
381        // ---- 1. Compute all squared distances ----
382        let mut dist_sq = vec![0.0f32; n];
383        batch_distances_squared(query, &self.soa, &mut dist_sq);
384
385        // ---- 2. Filter and convert ----
386        let mut result: Vec<(usize, f32)> = dist_sq
387            .iter()
388            .enumerate()
389            .filter_map(|(idx, &d2)| {
390                if d2 <= radius_sq { Some((idx, d2.sqrt())) } else { None }
391            })
392            .collect();
393
394        result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
395        result
396    }
397}
398
399// ---------------------------------------------------------------------------
400// Internal helper: max-heap entry ordered by dist_sq (largest = top of heap)
401// ---------------------------------------------------------------------------
402
/// Heap entry for k-NN selection, ordered by `dist_sq`, then `index`.
///
/// In a `BinaryHeap` (a max-heap) the root is the farthest retained candidate,
/// which is the one to evict when a closer point is found.
#[derive(Debug, Clone, Copy)]
struct DistEntry {
    dist_sq: f32,
    index: usize,
}

// Equality is defined through `cmp` so that `PartialEq`, `Eq` and `Ord` agree. A
// derived `PartialEq` would use `f32 ==`, where `NaN != NaN`. That breaks the
// reflexivity `Eq` promises and the consistency `Ord` requires.
impl PartialEq for DistEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DistEntry {}

impl PartialOrd for DistEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DistEntry {
    /// Total order: every NaN distance (of either sign) ranks above all numbers.
    /// Otherwise entries are ordered by `f32::total_cmp`, with ties broken by `index`.
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN first-class: rank it as the farthest so it is evicted before any real
        // distance. `total_cmp` alone would place a negative-sign NaN (for example
        // the result of inf - inf on x86) below -inf, i.e. "nearest".
        self.dist_sq
            .is_nan()
            .cmp(&other.dist_sq.is_nan())
            .then_with(|| self.dist_sq.total_cmp(&other.dist_sq))
            .then_with(|| self.index.cmp(&other.index))
    }
}
426
427// ---------------------------------------------------------------------------
428// Tests
429// ---------------------------------------------------------------------------
430
#[cfg(test)]
mod tests {
    use super::*;
    use threecrate_core::Point3f;

    /// Corners of the unit cube: eight points, i.e. exactly one AVX2 register's worth.
    fn cube_points() -> Vec<Point3f> {
        vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 0.0),
            Point3f::new(1.0, 0.0, 1.0),
            Point3f::new(0.0, 1.0, 1.0),
            Point3f::new(1.0, 1.0, 1.0),
        ]
    }

    /// Reference squared distances computed directly on the AoS input.
    fn reference_dist_sq(query: &Point3f, pts: &[Point3f]) -> Vec<f32> {
        pts.iter()
            .map(|p| {
                let dx = p.x - query.x;
                let dy = p.y - query.y;
                let dz = p.z - query.z;
                dx * dx + dy * dy + dz * dz
            })
            .collect()
    }

    /// Orders `(index, distance)` pairs by distance, then index, so results from
    /// different searchers can be compared position by position.
    fn sort_by_dist_then_index(res: &mut [(usize, f32)]) {
        res.sort_by(|a, b| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0)));
    }

    // ---- SoaPoints -------------------------------------------------------

    #[test]
    fn test_soa_layout() {
        let pts = cube_points();
        let soa = SoaPoints::from_points(&pts);
        assert_eq!(soa.len(), pts.len());
        for (i, p) in pts.iter().enumerate() {
            assert_eq!(soa.xs()[i], p.x);
            assert_eq!(soa.ys()[i], p.y);
            assert_eq!(soa.zs()[i], p.z);
        }
    }

    // ---- batch_distances_squared -----------------------------------------

    #[test]
    fn test_scalar_distances_match_reference() {
        let pts = cube_points();
        let soa = SoaPoints::from_points(&pts);
        let query = Point3f::new(0.5, 0.5, 0.5);
        let reference = reference_dist_sq(&query, &pts);
        let mut out = vec![0.0f32; pts.len()];
        scalar_distances_squared(&query, &soa, &mut out);
        for (got, expected) in out.iter().zip(reference.iter()) {
            assert!((got - expected).abs() < 1e-6, "got={got}, expected={expected}");
        }
    }

    #[test]
    fn test_batch_distances_match_scalar() {
        let pts = cube_points();
        let soa = SoaPoints::from_points(&pts);
        let query = Point3f::new(0.3, 0.7, 0.2);

        let mut scalar_out = vec![0.0f32; pts.len()];
        scalar_distances_squared(&query, &soa, &mut scalar_out);

        let mut simd_out = vec![0.0f32; pts.len()];
        batch_distances_squared(&query, &soa, &mut simd_out);

        for (got, expected) in simd_out.iter().zip(scalar_out.iter()) {
            assert!(
                (got - expected).abs() < 1e-5,
                "SIMD={got}, scalar={expected}"
            );
        }
    }

    /// Sweep point counts that hit every tail path: empty input, pure scalar (< 4),
    /// SSE-only tails, AVX2 + one SSE block (12), and AVX2 + SSE + scalar (15).
    #[test]
    fn test_batch_distances_various_sizes() {
        for n in [0, 1, 2, 3, 4, 5, 7, 8, 9, 12, 15, 16, 17, 31, 32, 33, 100] {
            let pts: Vec<Point3f> = (0..n)
                .map(|i| Point3f::new(i as f32, (i * 2) as f32, (i * 3) as f32))
                .collect();
            let soa = SoaPoints::from_points(&pts);
            let query = Point3f::new(5.0, 10.0, 15.0);
            let reference = reference_dist_sq(&query, &pts);

            let mut simd_out = vec![0.0f32; n];
            batch_distances_squared(&query, &soa, &mut simd_out);

            for (i, (got, expected)) in simd_out.iter().zip(reference.iter()).enumerate() {
                assert!(
                    (got - expected).abs() < 1e-4,
                    "n={n} i={i}: SIMD={got}, ref={expected}"
                );
            }
        }
    }

    #[test]
    fn test_exact_origin_distance() {
        let pts = vec![Point3f::new(3.0, 4.0, 0.0)]; // dist from origin = 5
        let soa = SoaPoints::from_points(&pts);
        let query = Point3f::origin();
        let mut out = vec![0.0f32; 1];
        batch_distances_squared(&query, &soa, &mut out);
        assert!((out[0] - 25.0).abs() < 1e-6);
    }

    // ---- SimdBruteForceSearch --------------------------------------------

    #[test]
    fn test_simd_knn_matches_brute_force() {
        use crate::nearest_neighbor::BruteForceSearch;
        let pts = cube_points();
        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;

        let simd = SimdBruteForceSearch::new(&pts);
        let scalar = BruteForceSearch::new(&pts);

        let mut simd_res = simd.find_k_nearest(&query, k);
        let mut scalar_res = scalar.find_k_nearest(&query, k);
        sort_by_dist_then_index(&mut simd_res);
        sort_by_dist_then_index(&mut scalar_res);

        assert_eq!(simd_res.len(), k);
        // Every cube corner is equidistant from the centre, so the two searchers
        // may legitimately pick different indices. Only distances are compared.
        for ((_, simd_d), (_, scalar_d)) in simd_res.iter().zip(&scalar_res) {
            assert!(
                (simd_d - scalar_d).abs() < 1e-5,
                "dist mismatch: simd={simd_d} scalar={scalar_d}"
            );
        }
    }

    #[test]
    fn test_simd_radius_matches_brute_force() {
        use crate::nearest_neighbor::BruteForceSearch;
        let pts = cube_points();
        let query = Point3f::new(0.5, 0.5, 0.5);
        let radius = 1.0;

        let simd = SimdBruteForceSearch::new(&pts);
        let scalar = BruteForceSearch::new(&pts);

        let mut simd_res = simd.find_radius_neighbors(&query, radius);
        let mut scalar_res = scalar.find_radius_neighbors(&query, radius);
        sort_by_dist_then_index(&mut simd_res);
        sort_by_dist_then_index(&mut scalar_res);

        assert_eq!(simd_res.len(), scalar_res.len(), "result count mismatch");
        for ((_, simd_d), (_, scalar_d)) in simd_res.iter().zip(&scalar_res) {
            assert!((simd_d - scalar_d).abs() < 1e-5);
        }
    }

    #[test]
    fn test_knn_results_sorted_ascending() {
        let pts = cube_points();
        let simd = SimdBruteForceSearch::new(&pts);
        let res = simd.find_k_nearest(&Point3f::new(0.0, 0.0, 0.0), pts.len());
        assert_eq!(res.len(), pts.len());
        assert_eq!(res[0], (0, 0.0));
        assert!(res.windows(2).all(|w| w[0].1 <= w[1].1), "not sorted: {res:?}");
    }

    #[test]
    fn test_empty_cloud() {
        let simd = SimdBruteForceSearch::new(&[]);
        let q = Point3f::new(0.0, 0.0, 0.0);
        assert!(simd.find_k_nearest(&q, 5).is_empty());
        assert!(simd.find_radius_neighbors(&q, 10.0).is_empty());
    }

    #[test]
    fn test_k_zero_returns_empty() {
        let simd = SimdBruteForceSearch::new(&cube_points());
        assert!(simd.find_k_nearest(&Point3f::new(0.0, 0.0, 0.0), 0).is_empty());
    }

    #[test]
    fn test_nonpositive_radius_returns_empty() {
        let simd = SimdBruteForceSearch::new(&cube_points());
        let q = Point3f::new(0.0, 0.0, 0.0);
        assert!(simd.find_radius_neighbors(&q, 0.0).is_empty());
        assert!(simd.find_radius_neighbors(&q, -1.0).is_empty());
    }

    #[test]
    fn test_k_larger_than_cloud() {
        let pts = cube_points();
        let simd = SimdBruteForceSearch::new(&pts);
        let q = Point3f::new(0.0, 0.0, 0.0);
        let result = simd.find_k_nearest(&q, 100);
        assert_eq!(result.len(), pts.len());
    }
}