Skip to main content

scirs2_spatial/
simd_ops.rs

//! SIMD-accelerated operations for spatial algorithms
//!
//! This module provides high-performance SIMD implementations for critical spatial operations:
//! - Distance computations (Euclidean, Manhattan, Chebyshev, Minkowski, Cosine)
//! - KD-Tree operations (bounding box tests, point-to-box distance)
//! - Nearest neighbor search (batch distances, k-nearest-neighbor selection, radius search)
//!
//! All operations use `scirs2_core::simd_ops::SimdUnifiedOps` for optimal hardware utilization.
//!
//! # Architecture Support
//!
//! The SIMD operations are automatically optimized based on available hardware:
//! - AVX-512 (8x f64 vectors)
//! - AVX2 (4x f64 vectors)
//! - ARM NEON (2x f64 vectors)
//! - SSE (2x f64 vectors - fallback)
//!
//! # Examples
//!
//! ```
//! use scirs2_spatial::simd_ops::{simd_euclidean_distance, simd_batch_distances};
//! use scirs2_core::ndarray::array;
//!
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Single distance computation
//! let a = array![1.0, 2.0, 3.0];
//! let b = array![4.0, 5.0, 6.0];
//! let dist = simd_euclidean_distance(&a.view(), &b.view())?;
//!
//! // Batch distance computation
//! let points1 = array![[1.0, 2.0], [3.0, 4.0]];
//! let points2 = array![[2.0, 3.0], [4.0, 5.0]];
//! let distances = simd_batch_distances(&points1.view(), &points2.view())?;
//! # Ok(())
//! # }
//! ```
37
38use crate::error::{SpatialError, SpatialResult};
39use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
40use scirs2_core::simd_ops::SimdUnifiedOps;
41
42// ============================================================================
43// Distance Computations (SIMD-accelerated)
44// ============================================================================
45
46/// SIMD-accelerated Euclidean distance between two points
47///
48/// Computes: `sqrt(sum((a[i] - b[i])^2))`
49///
50/// # Arguments
51///
52/// * `a` - First point
53/// * `b` - Second point
54///
55/// # Returns
56///
57/// * Euclidean distance between the points
58///
59/// # Errors
60///
61/// Returns error if points have different dimensions
62///
63/// # Examples
64///
65/// ```
66/// use scirs2_spatial::simd_ops::simd_euclidean_distance;
67/// use scirs2_core::ndarray::array;
68///
69/// let a = array![1.0, 2.0, 3.0];
70/// let b = array![4.0, 5.0, 6.0];
71/// let dist = simd_euclidean_distance(&a.view(), &b.view()).unwrap();
72/// assert!((dist - 5.196152422706632).abs() < 1e-10);
73/// ```
74pub fn simd_euclidean_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
75    if a.len() != b.len() {
76        return Err(SpatialError::ValueError(
77            "Points must have the same dimension".to_string(),
78        ));
79    }
80
81    let diff = f64::simd_sub(a, b);
82    let squared = f64::simd_mul(&diff.view(), &diff.view());
83    let sum = f64::simd_sum(&squared.view());
84    Ok(sum.sqrt())
85}
86
87/// SIMD-accelerated squared Euclidean distance between two points
88///
89/// Computes: `sum((a[i] - b[i])^2)`
90/// Faster than full Euclidean distance as it avoids the square root operation.
91///
92/// # Arguments
93///
94/// * `a` - First point
95/// * `b` - Second point
96///
97/// # Returns
98///
99/// * Squared Euclidean distance between the points
100///
101/// # Errors
102///
103/// Returns error if points have different dimensions
104pub fn simd_squared_euclidean_distance(
105    a: &ArrayView1<f64>,
106    b: &ArrayView1<f64>,
107) -> SpatialResult<f64> {
108    if a.len() != b.len() {
109        return Err(SpatialError::ValueError(
110            "Points must have the same dimension".to_string(),
111        ));
112    }
113
114    let diff = f64::simd_sub(a, b);
115    let squared = f64::simd_mul(&diff.view(), &diff.view());
116    Ok(f64::simd_sum(&squared.view()))
117}
118
119/// SIMD-accelerated Manhattan distance between two points
120///
121/// Computes: `sum(|a[i] - b[i]|)`
122///
123/// # Arguments
124///
125/// * `a` - First point
126/// * `b` - Second point
127///
128/// # Returns
129///
130/// * Manhattan (L1) distance between the points
131///
132/// # Errors
133///
134/// Returns error if points have different dimensions
135///
136/// # Examples
137///
138/// ```
139/// use scirs2_spatial::simd_ops::simd_manhattan_distance;
140/// use scirs2_core::ndarray::array;
141///
142/// let a = array![1.0, 2.0, 3.0];
143/// let b = array![4.0, 5.0, 6.0];
144/// let dist = simd_manhattan_distance(&a.view(), &b.view()).unwrap();
145/// assert_eq!(dist, 9.0);
146/// ```
147pub fn simd_manhattan_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
148    if a.len() != b.len() {
149        return Err(SpatialError::ValueError(
150            "Points must have the same dimension".to_string(),
151        ));
152    }
153
154    let diff = f64::simd_sub(a, b);
155    let abs_diff = f64::simd_abs(&diff.view());
156    Ok(f64::simd_sum(&abs_diff.view()))
157}
158
159/// SIMD-accelerated Chebyshev distance between two points
160///
161/// Computes: `max(|a[i] - b[i]|)`
162///
163/// # Arguments
164///
165/// * `a` - First point
166/// * `b` - Second point
167///
168/// # Returns
169///
170/// * Chebyshev (L∞) distance between the points
171///
172/// # Errors
173///
174/// Returns error if points have different dimensions
175///
176/// # Examples
177///
178/// ```
179/// use scirs2_spatial::simd_ops::simd_chebyshev_distance;
180/// use scirs2_core::ndarray::array;
181///
182/// let a = array![1.0, 2.0, 3.0];
183/// let b = array![4.0, 6.0, 5.0];
184/// let dist = simd_chebyshev_distance(&a.view(), &b.view()).unwrap();
185/// assert_eq!(dist, 4.0); // max(|1-4|, |2-6|, |3-5|) = max(3, 4, 2) = 4
186/// ```
187pub fn simd_chebyshev_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
188    if a.len() != b.len() {
189        return Err(SpatialError::ValueError(
190            "Points must have the same dimension".to_string(),
191        ));
192    }
193
194    let diff = f64::simd_sub(a, b);
195    let abs_diff = f64::simd_abs(&diff.view());
196    Ok(f64::simd_max_element(&abs_diff.view()))
197}
198
199/// SIMD-accelerated Minkowski distance between two points
200///
201/// Computes: `(sum(|a[i] - b[i]|^p))^(1/p)`
202///
203/// # Arguments
204///
205/// * `a` - First point
206/// * `b` - Second point
207/// * `p` - Order of the norm (p >= 1.0)
208///
209/// # Returns
210///
211/// * Minkowski distance of order p
212///
213/// # Errors
214///
215/// Returns error if:
216/// - Points have different dimensions
217/// - p < 1.0
218///
219/// # Special Cases
220///
221/// - p = 1.0: Manhattan distance
222/// - p = 2.0: Euclidean distance
223/// - p → ∞: Chebyshev distance
224///
225/// # Examples
226///
227/// ```
228/// use scirs2_spatial::simd_ops::simd_minkowski_distance;
229/// use scirs2_core::ndarray::array;
230///
231/// let a = array![1.0, 2.0, 3.0];
232/// let b = array![4.0, 5.0, 6.0];
233/// let dist = simd_minkowski_distance(&a.view(), &b.view(), 3.0).unwrap();
234/// assert!((dist - 4.3267487109222245).abs() < 1e-10);
235/// ```
236pub fn simd_minkowski_distance(
237    a: &ArrayView1<f64>,
238    b: &ArrayView1<f64>,
239    p: f64,
240) -> SpatialResult<f64> {
241    if a.len() != b.len() {
242        return Err(SpatialError::ValueError(
243            "Points must have the same dimension".to_string(),
244        ));
245    }
246
247    if p < 1.0 {
248        return Err(SpatialError::ValueError(
249            "Minkowski p must be >= 1.0".to_string(),
250        ));
251    }
252
253    // Special cases for efficiency
254    if (p - 1.0).abs() < 1e-10 {
255        return simd_manhattan_distance(a, b);
256    }
257    if (p - 2.0).abs() < 1e-10 {
258        return simd_euclidean_distance(a, b);
259    }
260
261    let diff = f64::simd_sub(a, b);
262    let abs_diff = f64::simd_abs(&diff.view());
263    let powered = f64::simd_powf(&abs_diff.view(), p);
264    let sum = f64::simd_sum(&powered.view());
265    Ok(sum.powf(1.0 / p))
266}
267
268/// SIMD-accelerated Cosine distance between two points
269///
270/// Computes: 1 - (a · b) / (||a|| * ||b||)
271/// where · is dot product and ||·|| is L2 norm
272///
273/// # Arguments
274///
275/// * `a` - First point
276/// * `b` - Second point
277///
278/// # Returns
279///
280/// * Cosine distance (1 - cosine similarity)
281///
282/// # Errors
283///
284/// Returns error if:
285/// - Points have different dimensions
286/// - Either point is zero vector
287///
288/// # Examples
289///
290/// ```
291/// use scirs2_spatial::simd_ops::simd_cosine_distance;
292/// use scirs2_core::ndarray::array;
293///
294/// let a = array![1.0, 2.0, 3.0];
295/// let b = array![4.0, 5.0, 6.0];
296/// let dist = simd_cosine_distance(&a.view(), &b.view()).unwrap();
297/// assert!(dist < 0.03); // Very similar direction
298/// ```
299pub fn simd_cosine_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
300    if a.len() != b.len() {
301        return Err(SpatialError::ValueError(
302            "Points must have the same dimension".to_string(),
303        ));
304    }
305
306    let dot_product = f64::simd_dot(a, b);
307    let norm_a = f64::simd_norm(a);
308    let norm_b = f64::simd_norm(b);
309
310    if norm_a == 0.0 || norm_b == 0.0 {
311        return Err(SpatialError::ValueError(
312            "Cannot compute cosine distance for zero vectors".to_string(),
313        ));
314    }
315
316    let cosine_similarity = dot_product / (norm_a * norm_b);
317    Ok(1.0 - cosine_similarity)
318}
319
320// ============================================================================
321// KD-Tree Operations (SIMD-accelerated)
322// ============================================================================
323
324/// SIMD-accelerated point-to-axis-aligned-box minimum distance
325///
326/// Computes the minimum distance from a point to an axis-aligned bounding box.
327/// Used for efficient KD-Tree traversal.
328///
329/// # Arguments
330///
331/// * `point` - Query point
332/// * `box_min` - Minimum corner of the bounding box
333/// * `box_max` - Maximum corner of the bounding box
334///
335/// # Returns
336///
337/// * Squared minimum distance from point to box
338///
339/// # Errors
340///
341/// Returns error if dimensions don't match
342pub fn simd_point_to_box_min_distance_squared(
343    point: &ArrayView1<f64>,
344    box_min: &ArrayView1<f64>,
345    box_max: &ArrayView1<f64>,
346) -> SpatialResult<f64> {
347    if point.len() != box_min.len() || point.len() != box_max.len() {
348        return Err(SpatialError::ValueError(
349            "Point and box dimensions must match".to_string(),
350        ));
351    }
352
353    // For each dimension, compute the distance to the box
354    // If point is inside box in that dimension, distance is 0
355    // Otherwise, distance is to the nearest face
356
357    // Clamp point to box: closest_point = clamp(point, box_min, box_max)
358    let clamped = f64::simd_clamp(
359        point,
360        *box_min
361            .first()
362            .ok_or_else(|| SpatialError::ValueError("Empty array".to_string()))?,
363        *box_max
364            .first()
365            .ok_or_else(|| SpatialError::ValueError("Empty array".to_string()))?,
366    );
367
368    // Compute element-wise clamping manually for each dimension
369    let mut closest_point = Array1::zeros(point.len());
370    for i in 0..point.len() {
371        closest_point[i] = point[i].clamp(box_min[i], box_max[i]);
372    }
373
374    // Compute squared distance from point to closest point on box
375    let diff = f64::simd_sub(point, &closest_point.view());
376    let squared = f64::simd_mul(&diff.view(), &diff.view());
377    Ok(f64::simd_sum(&squared.view()))
378}
379
380/// SIMD-accelerated axis-aligned bounding box intersection test
381///
382/// Tests if two axis-aligned bounding boxes intersect.
383///
384/// # Arguments
385///
386/// * `box1_min` - Minimum corner of first box
387/// * `box1_max` - Maximum corner of first box
388/// * `box2_min` - Minimum corner of second box
389/// * `box2_max` - Maximum corner of second box
390///
391/// # Returns
392///
393/// * true if boxes intersect, false otherwise
394///
395/// # Errors
396///
397/// Returns error if dimensions don't match
398pub fn simd_box_box_intersection(
399    box1_min: &ArrayView1<f64>,
400    box1_max: &ArrayView1<f64>,
401    box2_min: &ArrayView1<f64>,
402    box2_max: &ArrayView1<f64>,
403) -> SpatialResult<bool> {
404    if box1_min.len() != box1_max.len()
405        || box1_min.len() != box2_min.len()
406        || box1_min.len() != box2_max.len()
407    {
408        return Err(SpatialError::ValueError(
409            "All box dimensions must match".to_string(),
410        ));
411    }
412
413    // Boxes intersect if they overlap in all dimensions
414    // They overlap in dimension i if: box1_max[i] >= box2_min[i] && box1_min[i] <= box2_max[i]
415
416    for i in 0..box1_min.len() {
417        if box1_max[i] < box2_min[i] || box1_min[i] > box2_max[i] {
418            return Ok(false);
419        }
420    }
421
422    Ok(true)
423}
424
425/// SIMD-accelerated batch distance computation for KD-Tree queries
426///
427/// Computes squared distances from a query point to multiple data points.
428/// Used for efficient k-NN search in KD-Trees.
429///
430/// # Arguments
431///
432/// * `query_point` - Query point
433/// * `data_points` - Matrix of data points (n_points x n_dims)
434///
435/// # Returns
436///
437/// * Array of squared distances
438///
439/// # Errors
440///
441/// Returns error if dimensions don't match
442pub fn simd_batch_squared_distances(
443    query_point: &ArrayView1<f64>,
444    data_points: &ArrayView2<f64>,
445) -> SpatialResult<Array1<f64>> {
446    if query_point.len() != data_points.ncols() {
447        return Err(SpatialError::ValueError(
448            "Query point dimension must match data points".to_string(),
449        ));
450    }
451
452    let n_points = data_points.nrows();
453    let mut distances = Array1::zeros(n_points);
454
455    for i in 0..n_points {
456        let data_point = data_points.row(i);
457        let diff = f64::simd_sub(query_point, &data_point);
458        let squared = f64::simd_mul(&diff.view(), &diff.view());
459        distances[i] = f64::simd_sum(&squared.view());
460    }
461
462    Ok(distances)
463}
464
465// ============================================================================
466// Nearest Neighbor Search Operations (SIMD-accelerated)
467// ============================================================================
468
469/// SIMD-accelerated batch distance computation between point sets
470///
471/// Computes distances between corresponding points in two arrays.
472///
473/// # Arguments
474///
475/// * `points1` - First set of points (n_points x n_dims)
476/// * `points2` - Second set of points (n_points x n_dims)
477///
478/// # Returns
479///
480/// * Array of distances (n_points)
481///
482/// # Errors
483///
484/// Returns error if shapes don't match
485pub fn simd_batch_distances(
486    points1: &ArrayView2<f64>,
487    points2: &ArrayView2<f64>,
488) -> SpatialResult<Array1<f64>> {
489    if points1.shape() != points2.shape() {
490        return Err(SpatialError::ValueError(
491            "Point arrays must have the same shape".to_string(),
492        ));
493    }
494
495    let n_points = points1.nrows();
496    let mut distances = Array1::zeros(n_points);
497
498    for i in 0..n_points {
499        let p1 = points1.row(i);
500        let p2 = points2.row(i);
501        let diff = f64::simd_sub(&p1, &p2);
502        let squared = f64::simd_mul(&diff.view(), &diff.view());
503        let sum = f64::simd_sum(&squared.view());
504        distances[i] = sum.sqrt();
505    }
506
507    Ok(distances)
508}
509
510/// SIMD-accelerated k-nearest neighbors distance computation
511///
512/// Finds k nearest neighbors and their distances using SIMD operations.
513///
514/// # Arguments
515///
516/// * `query_point` - Query point
517/// * `data_points` - Matrix of data points (n_points x n_dims)
518/// * `k` - Number of nearest neighbors to find
519///
520/// # Returns
521///
522/// * Tuple of (indices, distances) for k nearest neighbors
523///
524/// # Errors
525///
526/// Returns error if:
527/// - Dimensions don't match
528/// - k > number of data points
529/// - k == 0
530pub fn simd_knn_search(
531    query_point: &ArrayView1<f64>,
532    data_points: &ArrayView2<f64>,
533    k: usize,
534) -> SpatialResult<(Array1<usize>, Array1<f64>)> {
535    if query_point.len() != data_points.ncols() {
536        return Err(SpatialError::ValueError(
537            "Query point dimension must match data points".to_string(),
538        ));
539    }
540
541    let n_points = data_points.nrows();
542
543    if k == 0 {
544        return Err(SpatialError::ValueError(
545            "k must be greater than 0".to_string(),
546        ));
547    }
548
549    if k > n_points {
550        return Err(SpatialError::ValueError(format!(
551            "k ({}) cannot be larger than number of data points ({})",
552            k, n_points
553        )));
554    }
555
556    // Compute all distances using SIMD
557    let squared_distances = simd_batch_squared_distances(query_point, data_points)?;
558
559    // Convert to (distance, index) pairs and partial sort
560    let mut indexed_distances: Vec<(f64, usize)> = squared_distances
561        .iter()
562        .enumerate()
563        .map(|(idx, &dist)| (dist, idx))
564        .collect();
565
566    // Partial sort to get k smallest elements
567    indexed_distances.select_nth_unstable_by(k - 1, |a, b| {
568        a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
569    });
570
571    // Sort the k smallest for consistent ordering
572    indexed_distances[..k]
573        .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
574
575    // Extract indices and compute full distances
576    let mut indices = Array1::zeros(k);
577    let mut distances = Array1::zeros(k);
578
579    for (i, (dist_sq, idx)) in indexed_distances[..k].iter().enumerate() {
580        indices[i] = *idx;
581        distances[i] = dist_sq.sqrt();
582    }
583
584    Ok((indices, distances))
585}
586
587/// SIMD-accelerated radius search
588///
589/// Finds all points within a given radius of a query point.
590///
591/// # Arguments
592///
593/// * `query_point` - Query point
594/// * `data_points` - Matrix of data points (n_points x n_dims)
595/// * `radius` - Search radius
596///
597/// # Returns
598///
599/// * Tuple of (indices, distances) for points within radius
600///
601/// # Errors
602///
603/// Returns error if:
604/// - Dimensions don't match
605/// - radius < 0
606pub fn simd_radius_search(
607    query_point: &ArrayView1<f64>,
608    data_points: &ArrayView2<f64>,
609    radius: f64,
610) -> SpatialResult<(Array1<usize>, Array1<f64>)> {
611    if query_point.len() != data_points.ncols() {
612        return Err(SpatialError::ValueError(
613            "Query point dimension must match data points".to_string(),
614        ));
615    }
616
617    if radius < 0.0 {
618        return Err(SpatialError::ValueError(
619            "Radius must be non-negative".to_string(),
620        ));
621    }
622
623    // Compute all squared distances using SIMD
624    let squared_distances = simd_batch_squared_distances(query_point, data_points)?;
625    let radius_squared = radius * radius;
626
627    // Filter points within radius
628    let mut indices = Vec::new();
629    let mut distances = Vec::new();
630
631    for (idx, &dist_sq) in squared_distances.iter().enumerate() {
632        if dist_sq <= radius_squared {
633            indices.push(idx);
634            distances.push(dist_sq.sqrt());
635        }
636    }
637
638    Ok((Array1::from(indices), Array1::from(distances)))
639}
640
641/// SIMD-accelerated pairwise distance matrix computation
642///
643/// Computes all pairwise distances between points in a dataset.
644///
645/// # Arguments
646///
647/// * `points` - Matrix of points (n_points x n_dims)
648///
649/// # Returns
650///
651/// * Symmetric distance matrix (n_points x n_points)
652pub fn simd_pairwise_distance_matrix(points: &ArrayView2<f64>) -> SpatialResult<Array2<f64>> {
653    let n_points = points.nrows();
654    let mut distances = Array2::zeros((n_points, n_points));
655
656    // Only compute upper triangle (matrix is symmetric)
657    for i in 0..n_points {
658        let point_i = points.row(i);
659
660        for j in (i + 1)..n_points {
661            let point_j = points.row(j);
662            let diff = f64::simd_sub(&point_i, &point_j);
663            let squared = f64::simd_mul(&diff.view(), &diff.view());
664            let sum = f64::simd_sum(&squared.view());
665            let dist = sum.sqrt();
666
667            distances[[i, j]] = dist;
668            distances[[j, i]] = dist; // Symmetric
669        }
670    }
671
672    Ok(distances)
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// sqrt(3^2 + 3^2 + 3^2) = sqrt(27) for the (1,2,3) / (4,5,6) pair.
    #[test]
    fn test_simd_euclidean_distance() {
        let (p, q) = (array![1.0, 2.0, 3.0], array![4.0, 5.0, 6.0]);

        let got =
            simd_euclidean_distance(&p.view(), &q.view()).expect("Distance computation failed");

        assert_relative_eq!(got, 5.196152422706632, epsilon = 1e-10);
    }

    /// |1-4| + |2-5| + |3-6| = 9.
    #[test]
    fn test_simd_manhattan_distance() {
        let (p, q) = (array![1.0, 2.0, 3.0], array![4.0, 5.0, 6.0]);

        let got =
            simd_manhattan_distance(&p.view(), &q.view()).expect("Distance computation failed");

        assert_eq!(got, 9.0);
    }

    /// max(|1-4|, |2-6|, |3-5|) = max(3, 4, 2) = 4.
    #[test]
    fn test_simd_chebyshev_distance() {
        let (p, q) = (array![1.0, 2.0, 3.0], array![4.0, 6.0, 5.0]);

        let got =
            simd_chebyshev_distance(&p.view(), &q.view()).expect("Distance computation failed");

        assert_eq!(got, 4.0);
    }

    /// Orders 1 and 2 take the Manhattan / Euclidean shortcuts; order 3 uses
    /// the generic formula.
    #[test]
    fn test_simd_minkowski_distance() {
        let (p, q) = (array![1.0, 2.0, 3.0], array![4.0, 5.0, 6.0]);
        let minkowski = |order: f64| {
            simd_minkowski_distance(&p.view(), &q.view(), order)
                .expect("Distance computation failed")
        };

        assert_eq!(minkowski(1.0), 9.0);
        assert_relative_eq!(minkowski(2.0), 5.196152422706632, epsilon = 1e-10);
        assert_relative_eq!(minkowski(3.0), 4.3267487109222245, epsilon = 1e-10);
    }

    /// Nearly parallel vectors give a small, non-negative cosine distance.
    #[test]
    fn test_simd_cosine_distance() {
        let (p, q) = (array![1.0, 2.0, 3.0], array![4.0, 5.0, 6.0]);

        let got = simd_cosine_distance(&p.view(), &q.view()).expect("Distance computation failed");

        assert!(got < 0.03);
        assert!(got >= 0.0);
    }

    /// Every row pair differs by (1, 1), so every distance is sqrt(2).
    #[test]
    fn test_simd_batch_distances() {
        let lhs = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let rhs = array![[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]];

        let got = simd_batch_distances(&lhs.view(), &rhs.view())
            .expect("Batch distance computation failed");

        assert_eq!(got.len(), 3);
        got.iter().for_each(|&d| {
            assert_relative_eq!(d, std::f64::consts::SQRT_2, epsilon = 1e-10);
        });
    }

    /// k-NN returns exactly k results, ordered by non-decreasing distance.
    #[test]
    fn test_simd_knn_search() {
        let cloud = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]];
        let probe = array![0.5, 0.5];

        let (found, dists) =
            simd_knn_search(&probe.view(), &cloud.view(), 3).expect("k-NN search failed");

        assert_eq!(found.len(), 3);
        assert_eq!(dists.len(), 3);
        assert!(dists
            .iter()
            .zip(dists.iter().skip(1))
            .all(|(prev, next)| next >= prev));
    }

    /// The four unit-square corners are inside the radius; (5, 5) is not.
    #[test]
    fn test_simd_radius_search() {
        let cloud = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [5.0, 5.0]];
        let probe = array![0.5, 0.5];
        let limit = 1.0;

        let (found, dists) = simd_radius_search(&probe.view(), &cloud.view(), limit)
            .expect("Radius search failed");

        assert_eq!(found.len(), 4);
        assert!(dists.iter().all(|&d| d <= limit));
    }

    /// Point (2,2) against the unit box: nearest box point is (1,1), so the
    /// squared distance is (2-1)^2 + (2-1)^2 = 2.
    #[test]
    fn test_simd_point_to_box_distance() {
        let probe = array![2.0, 2.0];
        let (lo, hi) = (array![0.0, 0.0], array![1.0, 1.0]);

        let got = simd_point_to_box_min_distance_squared(&probe.view(), &lo.view(), &hi.view())
            .expect("Point-to-box distance failed");

        assert_relative_eq!(got, 2.0, epsilon = 1e-10);
    }

    /// Overlapping boxes intersect; far-apart boxes do not.
    #[test]
    fn test_simd_box_intersection() {
        let (a_lo, a_hi) = (array![0.0, 0.0], array![2.0, 2.0]);
        let (b_lo, b_hi) = (array![1.0, 1.0], array![3.0, 3.0]);
        let (c_lo, c_hi) = (array![10.0, 10.0], array![20.0, 20.0]);

        let overlapping =
            simd_box_box_intersection(&a_lo.view(), &a_hi.view(), &b_lo.view(), &b_hi.view())
                .expect("Box intersection test failed");
        assert!(overlapping);

        let disjoint =
            simd_box_box_intersection(&a_lo.view(), &a_hi.view(), &c_lo.view(), &c_hi.view())
                .expect("Box intersection test failed");
        assert!(!disjoint);
    }

    /// Every two-point metric rejects inputs of different lengths.
    #[test]
    fn test_dimension_mismatch_errors() {
        let short = array![1.0, 2.0];
        let long = array![1.0, 2.0, 3.0];
        let (s, l) = (short.view(), long.view());

        assert!(simd_euclidean_distance(&s, &l).is_err());
        assert!(simd_manhattan_distance(&s, &l).is_err());
        assert!(simd_chebyshev_distance(&s, &l).is_err());
        assert!(simd_cosine_distance(&s, &l).is_err());
    }

    /// Cosine distance is undefined for a zero vector.
    #[test]
    fn test_zero_vector_cosine() {
        let origin = array![0.0, 0.0, 0.0];
        let other = array![1.0, 2.0, 3.0];

        assert!(simd_cosine_distance(&origin.view(), &other.view()).is_err());
    }
}
867}