1use crate::error::{SpatialError, SpatialResult};
39use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
40use scirs2_core::simd_ops::SimdUnifiedOps;
41
42pub fn simd_euclidean_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
75 if a.len() != b.len() {
76 return Err(SpatialError::ValueError(
77 "Points must have the same dimension".to_string(),
78 ));
79 }
80
81 let diff = f64::simd_sub(a, b);
82 let squared = f64::simd_mul(&diff.view(), &diff.view());
83 let sum = f64::simd_sum(&squared.view());
84 Ok(sum.sqrt())
85}
86
87pub fn simd_squared_euclidean_distance(
105 a: &ArrayView1<f64>,
106 b: &ArrayView1<f64>,
107) -> SpatialResult<f64> {
108 if a.len() != b.len() {
109 return Err(SpatialError::ValueError(
110 "Points must have the same dimension".to_string(),
111 ));
112 }
113
114 let diff = f64::simd_sub(a, b);
115 let squared = f64::simd_mul(&diff.view(), &diff.view());
116 Ok(f64::simd_sum(&squared.view()))
117}
118
119pub fn simd_manhattan_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
148 if a.len() != b.len() {
149 return Err(SpatialError::ValueError(
150 "Points must have the same dimension".to_string(),
151 ));
152 }
153
154 let diff = f64::simd_sub(a, b);
155 let abs_diff = f64::simd_abs(&diff.view());
156 Ok(f64::simd_sum(&abs_diff.view()))
157}
158
159pub fn simd_chebyshev_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
188 if a.len() != b.len() {
189 return Err(SpatialError::ValueError(
190 "Points must have the same dimension".to_string(),
191 ));
192 }
193
194 let diff = f64::simd_sub(a, b);
195 let abs_diff = f64::simd_abs(&diff.view());
196 Ok(f64::simd_max_element(&abs_diff.view()))
197}
198
199pub fn simd_minkowski_distance(
237 a: &ArrayView1<f64>,
238 b: &ArrayView1<f64>,
239 p: f64,
240) -> SpatialResult<f64> {
241 if a.len() != b.len() {
242 return Err(SpatialError::ValueError(
243 "Points must have the same dimension".to_string(),
244 ));
245 }
246
247 if p < 1.0 {
248 return Err(SpatialError::ValueError(
249 "Minkowski p must be >= 1.0".to_string(),
250 ));
251 }
252
253 if (p - 1.0).abs() < 1e-10 {
255 return simd_manhattan_distance(a, b);
256 }
257 if (p - 2.0).abs() < 1e-10 {
258 return simd_euclidean_distance(a, b);
259 }
260
261 let diff = f64::simd_sub(a, b);
262 let abs_diff = f64::simd_abs(&diff.view());
263 let powered = f64::simd_powf(&abs_diff.view(), p);
264 let sum = f64::simd_sum(&powered.view());
265 Ok(sum.powf(1.0 / p))
266}
267
268pub fn simd_cosine_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
300 if a.len() != b.len() {
301 return Err(SpatialError::ValueError(
302 "Points must have the same dimension".to_string(),
303 ));
304 }
305
306 let dot_product = f64::simd_dot(a, b);
307 let norm_a = f64::simd_norm(a);
308 let norm_b = f64::simd_norm(b);
309
310 if norm_a == 0.0 || norm_b == 0.0 {
311 return Err(SpatialError::ValueError(
312 "Cannot compute cosine distance for zero vectors".to_string(),
313 ));
314 }
315
316 let cosine_similarity = dot_product / (norm_a * norm_b);
317 Ok(1.0 - cosine_similarity)
318}
319
320pub fn simd_point_to_box_min_distance_squared(
343 point: &ArrayView1<f64>,
344 box_min: &ArrayView1<f64>,
345 box_max: &ArrayView1<f64>,
346) -> SpatialResult<f64> {
347 if point.len() != box_min.len() || point.len() != box_max.len() {
348 return Err(SpatialError::ValueError(
349 "Point and box dimensions must match".to_string(),
350 ));
351 }
352
353 let clamped = f64::simd_clamp(
359 point,
360 *box_min
361 .first()
362 .ok_or_else(|| SpatialError::ValueError("Empty array".to_string()))?,
363 *box_max
364 .first()
365 .ok_or_else(|| SpatialError::ValueError("Empty array".to_string()))?,
366 );
367
368 let mut closest_point = Array1::zeros(point.len());
370 for i in 0..point.len() {
371 closest_point[i] = point[i].clamp(box_min[i], box_max[i]);
372 }
373
374 let diff = f64::simd_sub(point, &closest_point.view());
376 let squared = f64::simd_mul(&diff.view(), &diff.view());
377 Ok(f64::simd_sum(&squared.view()))
378}
379
380pub fn simd_box_box_intersection(
399 box1_min: &ArrayView1<f64>,
400 box1_max: &ArrayView1<f64>,
401 box2_min: &ArrayView1<f64>,
402 box2_max: &ArrayView1<f64>,
403) -> SpatialResult<bool> {
404 if box1_min.len() != box1_max.len()
405 || box1_min.len() != box2_min.len()
406 || box1_min.len() != box2_max.len()
407 {
408 return Err(SpatialError::ValueError(
409 "All box dimensions must match".to_string(),
410 ));
411 }
412
413 for i in 0..box1_min.len() {
417 if box1_max[i] < box2_min[i] || box1_min[i] > box2_max[i] {
418 return Ok(false);
419 }
420 }
421
422 Ok(true)
423}
424
425pub fn simd_batch_squared_distances(
443 query_point: &ArrayView1<f64>,
444 data_points: &ArrayView2<f64>,
445) -> SpatialResult<Array1<f64>> {
446 if query_point.len() != data_points.ncols() {
447 return Err(SpatialError::ValueError(
448 "Query point dimension must match data points".to_string(),
449 ));
450 }
451
452 let n_points = data_points.nrows();
453 let mut distances = Array1::zeros(n_points);
454
455 for i in 0..n_points {
456 let data_point = data_points.row(i);
457 let diff = f64::simd_sub(query_point, &data_point);
458 let squared = f64::simd_mul(&diff.view(), &diff.view());
459 distances[i] = f64::simd_sum(&squared.view());
460 }
461
462 Ok(distances)
463}
464
465pub fn simd_batch_distances(
486 points1: &ArrayView2<f64>,
487 points2: &ArrayView2<f64>,
488) -> SpatialResult<Array1<f64>> {
489 if points1.shape() != points2.shape() {
490 return Err(SpatialError::ValueError(
491 "Point arrays must have the same shape".to_string(),
492 ));
493 }
494
495 let n_points = points1.nrows();
496 let mut distances = Array1::zeros(n_points);
497
498 for i in 0..n_points {
499 let p1 = points1.row(i);
500 let p2 = points2.row(i);
501 let diff = f64::simd_sub(&p1, &p2);
502 let squared = f64::simd_mul(&diff.view(), &diff.view());
503 let sum = f64::simd_sum(&squared.view());
504 distances[i] = sum.sqrt();
505 }
506
507 Ok(distances)
508}
509
510pub fn simd_knn_search(
531 query_point: &ArrayView1<f64>,
532 data_points: &ArrayView2<f64>,
533 k: usize,
534) -> SpatialResult<(Array1<usize>, Array1<f64>)> {
535 if query_point.len() != data_points.ncols() {
536 return Err(SpatialError::ValueError(
537 "Query point dimension must match data points".to_string(),
538 ));
539 }
540
541 let n_points = data_points.nrows();
542
543 if k == 0 {
544 return Err(SpatialError::ValueError(
545 "k must be greater than 0".to_string(),
546 ));
547 }
548
549 if k > n_points {
550 return Err(SpatialError::ValueError(format!(
551 "k ({}) cannot be larger than number of data points ({})",
552 k, n_points
553 )));
554 }
555
556 let squared_distances = simd_batch_squared_distances(query_point, data_points)?;
558
559 let mut indexed_distances: Vec<(f64, usize)> = squared_distances
561 .iter()
562 .enumerate()
563 .map(|(idx, &dist)| (dist, idx))
564 .collect();
565
566 indexed_distances.select_nth_unstable_by(k - 1, |a, b| {
568 a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
569 });
570
571 indexed_distances[..k]
573 .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
574
575 let mut indices = Array1::zeros(k);
577 let mut distances = Array1::zeros(k);
578
579 for (i, (dist_sq, idx)) in indexed_distances[..k].iter().enumerate() {
580 indices[i] = *idx;
581 distances[i] = dist_sq.sqrt();
582 }
583
584 Ok((indices, distances))
585}
586
587pub fn simd_radius_search(
607 query_point: &ArrayView1<f64>,
608 data_points: &ArrayView2<f64>,
609 radius: f64,
610) -> SpatialResult<(Array1<usize>, Array1<f64>)> {
611 if query_point.len() != data_points.ncols() {
612 return Err(SpatialError::ValueError(
613 "Query point dimension must match data points".to_string(),
614 ));
615 }
616
617 if radius < 0.0 {
618 return Err(SpatialError::ValueError(
619 "Radius must be non-negative".to_string(),
620 ));
621 }
622
623 let squared_distances = simd_batch_squared_distances(query_point, data_points)?;
625 let radius_squared = radius * radius;
626
627 let mut indices = Vec::new();
629 let mut distances = Vec::new();
630
631 for (idx, &dist_sq) in squared_distances.iter().enumerate() {
632 if dist_sq <= radius_squared {
633 indices.push(idx);
634 distances.push(dist_sq.sqrt());
635 }
636 }
637
638 Ok((Array1::from(indices), Array1::from(distances)))
639}
640
641pub fn simd_pairwise_distance_matrix(points: &ArrayView2<f64>) -> SpatialResult<Array2<f64>> {
653 let n_points = points.nrows();
654 let mut distances = Array2::zeros((n_points, n_points));
655
656 for i in 0..n_points {
658 let point_i = points.row(i);
659
660 for j in (i + 1)..n_points {
661 let point_j = points.row(j);
662 let diff = f64::simd_sub(&point_i, &point_j);
663 let squared = f64::simd_mul(&diff.view(), &diff.view());
664 let sum = f64::simd_sum(&squared.view());
665 let dist = sum.sqrt();
666
667 distances[[i, j]] = dist;
668 distances[[j, i]] = dist; }
670 }
671
672 Ok(distances)
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_simd_euclidean_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];

        let dist =
            simd_euclidean_distance(&a.view(), &b.view()).expect("Distance computation failed");

        // sqrt(3^2 + 3^2 + 3^2) = sqrt(27)
        assert_relative_eq!(dist, 5.196152422706632, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_squared_euclidean_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];

        let dist_sq = simd_squared_euclidean_distance(&a.view(), &b.view())
            .expect("Distance computation failed");

        assert_relative_eq!(dist_sq, 27.0, epsilon = 1e-12);
    }

    #[test]
    fn test_simd_manhattan_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];

        let dist =
            simd_manhattan_distance(&a.view(), &b.view()).expect("Distance computation failed");

        assert_eq!(dist, 9.0);
    }

    #[test]
    fn test_simd_chebyshev_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 6.0, 5.0];

        let dist =
            simd_chebyshev_distance(&a.view(), &b.view()).expect("Distance computation failed");

        assert_eq!(dist, 4.0);
    }

    #[test]
    fn test_simd_minkowski_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];

        let dist_p1 = simd_minkowski_distance(&a.view(), &b.view(), 1.0)
            .expect("Distance computation failed");
        assert_eq!(dist_p1, 9.0);

        let dist_p2 = simd_minkowski_distance(&a.view(), &b.view(), 2.0)
            .expect("Distance computation failed");
        assert_relative_eq!(dist_p2, 5.196152422706632, epsilon = 1e-10);

        let dist_p3 = simd_minkowski_distance(&a.view(), &b.view(), 3.0)
            .expect("Distance computation failed");
        assert_relative_eq!(dist_p3, 4.3267487109222245, epsilon = 1e-10);

        // Orders below 1 do not define a metric and must be rejected.
        assert!(simd_minkowski_distance(&a.view(), &b.view(), 0.5).is_err());
    }

    #[test]
    fn test_simd_cosine_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];

        let dist = simd_cosine_distance(&a.view(), &b.view()).expect("Distance computation failed");

        // a·b = 32, |a| = sqrt(14), |b| = sqrt(77)
        let expected = 1.0 - 32.0 / (14.0_f64.sqrt() * 77.0_f64.sqrt());
        assert_relative_eq!(dist, expected, epsilon = 1e-12);
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_simd_batch_distances() {
        let points1 = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let points2 = array![[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]];

        let distances = simd_batch_distances(&points1.view(), &points2.view())
            .expect("Batch distance computation failed");

        assert_eq!(distances.len(), 3);

        // Every pair differs by (1, 1).
        for &dist in distances.iter() {
            assert_relative_eq!(dist, std::f64::consts::SQRT_2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_simd_batch_squared_distances() {
        let data_points = array![[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]];
        let query = array![0.0, 0.0];

        let dists = simd_batch_squared_distances(&query.view(), &data_points.view())
            .expect("Batch squared distance computation failed");

        assert_eq!(dists.to_vec(), vec![1.0, 4.0, 25.0]);
    }

    #[test]
    fn test_simd_pairwise_distance_matrix() {
        let points = array![[0.0, 0.0], [3.0, 4.0], [0.0, 4.0]];

        let matrix = simd_pairwise_distance_matrix(&points.view())
            .expect("Pairwise distance computation failed");

        assert_eq!(matrix.dim(), (3, 3));
        for i in 0..3 {
            assert_eq!(matrix[[i, i]], 0.0);
            for j in 0..3 {
                assert_eq!(matrix[[i, j]], matrix[[j, i]]);
            }
        }
        assert_relative_eq!(matrix[[0, 1]], 5.0, epsilon = 1e-12);
        assert_relative_eq!(matrix[[0, 2]], 4.0, epsilon = 1e-12);
        assert_relative_eq!(matrix[[1, 2]], 3.0, epsilon = 1e-12);
    }

    #[test]
    fn test_simd_knn_search() {
        let data_points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]];
        let query = array![0.5, 0.5];

        let (indices, distances) =
            simd_knn_search(&query.view(), &data_points.view(), 3).expect("k-NN search failed");

        assert_eq!(indices.len(), 3);
        assert_eq!(distances.len(), 3);

        // Results are sorted by ascending distance.
        for i in 1..distances.len() {
            assert!(distances[i] >= distances[i - 1]);
        }

        // Rows 0..=3 are all at sqrt(0.5) from the query; row 4 is farther.
        for (&idx, &dist) in indices.iter().zip(distances.iter()) {
            assert!(idx < 4);
            assert_relative_eq!(dist, std::f64::consts::FRAC_1_SQRT_2, epsilon = 1e-10);
        }
        let mut unique = indices.to_vec();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_simd_knn_search_invalid_k() {
        let data_points = array![[0.0, 0.0], [1.0, 1.0]];
        let query = array![0.0, 0.0];

        assert!(simd_knn_search(&query.view(), &data_points.view(), 0).is_err());
        assert!(simd_knn_search(&query.view(), &data_points.view(), 3).is_err());
    }

    #[test]
    fn test_simd_radius_search() {
        let data_points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [5.0, 5.0]];
        let query = array![0.5, 0.5];
        let radius = 1.0;

        let (indices, distances) = simd_radius_search(&query.view(), &data_points.view(), radius)
            .expect("Radius search failed");

        // The four unit-square corners match, reported in row order.
        assert_eq!(indices.to_vec(), vec![0, 1, 2, 3]);

        for &dist in distances.iter() {
            assert!(dist <= radius);
        }

        assert!(simd_radius_search(&query.view(), &data_points.view(), -1.0).is_err());
    }

    #[test]
    fn test_simd_point_to_box_distance() {
        let point = array![2.0, 2.0];
        let box_min = array![0.0, 0.0];
        let box_max = array![1.0, 1.0];

        let dist_sq =
            simd_point_to_box_min_distance_squared(&point.view(), &box_min.view(), &box_max.view())
                .expect("Point-to-box distance failed");

        // Nearest box point is (1, 1): 1^2 + 1^2.
        assert_relative_eq!(dist_sq, 2.0, epsilon = 1e-10);

        let inside = array![0.5, 0.25];
        let inside_sq = simd_point_to_box_min_distance_squared(
            &inside.view(),
            &box_min.view(),
            &box_max.view(),
        )
        .expect("Point-to-box distance failed");
        assert_eq!(inside_sq, 0.0);
    }

    #[test]
    fn test_simd_box_intersection() {
        let box1_min = array![0.0, 0.0];
        let box1_max = array![2.0, 2.0];
        let box2_min = array![1.0, 1.0];
        let box2_max = array![3.0, 3.0];

        let intersects = simd_box_box_intersection(
            &box1_min.view(),
            &box1_max.view(),
            &box2_min.view(),
            &box2_max.view(),
        )
        .expect("Box intersection test failed");

        assert!(intersects);

        // Disjoint boxes.
        let box3_min = array![10.0, 10.0];
        let box3_max = array![20.0, 20.0];

        let no_intersect = simd_box_box_intersection(
            &box1_min.view(),
            &box1_max.view(),
            &box3_min.view(),
            &box3_max.view(),
        )
        .expect("Box intersection test failed");

        assert!(!no_intersect);

        // Boxes sharing only a face count as intersecting (closed boxes).
        let box4_min = array![2.0, 0.0];
        let box4_max = array![3.0, 2.0];

        let touching = simd_box_box_intersection(
            &box1_min.view(),
            &box1_max.view(),
            &box4_min.view(),
            &box4_max.view(),
        )
        .expect("Box intersection test failed");

        assert!(touching);
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let a = array![1.0, 2.0];
        let b = array![1.0, 2.0, 3.0];

        assert!(simd_euclidean_distance(&a.view(), &b.view()).is_err());
        assert!(simd_squared_euclidean_distance(&a.view(), &b.view()).is_err());
        assert!(simd_manhattan_distance(&a.view(), &b.view()).is_err());
        assert!(simd_chebyshev_distance(&a.view(), &b.view()).is_err());
        assert!(simd_minkowski_distance(&a.view(), &b.view(), 3.0).is_err());
        assert!(simd_cosine_distance(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn test_zero_vector_cosine() {
        let a = array![0.0, 0.0, 0.0];
        let b = array![1.0, 2.0, 3.0];

        assert!(simd_cosine_distance(&a.view(), &b.view()).is_err());
    }
}