1#[cfg(feature = "no-std")]
23use alloc::vec;
24#[cfg(feature = "no-std")]
25use alloc::vec::Vec;
26#[cfg(not(feature = "no-std"))]
27use std::vec::Vec;
28
29#[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
31use std::arch::is_aarch64_feature_detected;
32
33pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
59 assert_eq!(a.len(), b.len(), "Vectors must have the same length");
60
61 if a.is_empty() {
62 return 0.0;
63 }
64
65 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
66 {
67 if crate::simd_feature_detected!("avx512f") {
68 return unsafe { dot_product_avx512(a, b) };
69 } else if crate::simd_feature_detected!("avx2") {
70 return unsafe { dot_product_avx2(a, b) };
71 } else if crate::simd_feature_detected!("sse2") {
72 return unsafe { dot_product_sse2(a, b) };
73 }
74 }
75
76 #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
77 {
78 if is_aarch64_feature_detected!("neon") {
79 return unsafe { dot_product_neon(a, b) };
80 }
81 }
82
83 dot_product_scalar(a, b)
84}
85
86pub fn norm_l2(x: &[f32]) -> f32 {
105 dot_product(x, x).sqrt()
106}
107
108pub fn norm_l1(x: &[f32]) -> f32 {
128 if x.is_empty() {
129 return 0.0;
130 }
131
132 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
133 {
134 if crate::simd_feature_detected!("avx2") {
135 return unsafe { norm_l1_avx2(x) };
136 } else if crate::simd_feature_detected!("sse2") {
137 return unsafe { norm_l1_sse2(x) };
138 }
139 }
140
141 #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
142 {
143 if is_aarch64_feature_detected!("neon") {
144 return unsafe { norm_l1_neon(x) };
145 }
146 }
147
148 norm_l1_scalar(x)
149}
150
151pub fn norm_inf(x: &[f32]) -> f32 {
171 if x.is_empty() {
172 return 0.0;
173 }
174
175 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
176 {
177 if crate::simd_feature_detected!("avx2") {
178 return unsafe { norm_inf_avx2(x) };
179 } else if crate::simd_feature_detected!("sse2") {
180 return unsafe { norm_inf_sse2(x) };
181 }
182 }
183
184 norm_inf_scalar(x)
185}
186
187pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
213 assert_eq!(a.len(), b.len(), "Vectors must have the same length");
214
215 if a.is_empty() {
216 return 0.0;
217 }
218
219 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
220 {
221 if crate::simd_feature_detected!("avx2") {
222 return unsafe { euclidean_distance_avx2(a, b) };
223 } else if crate::simd_feature_detected!("sse2") {
224 return unsafe { euclidean_distance_sse2(a, b) };
225 }
226 }
227
228 euclidean_distance_scalar(a, b)
229}
230
231pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
256 assert_eq!(a.len(), b.len(), "Vectors must have the same length");
257
258 if a.is_empty() {
259 return 1.0; }
261
262 let dot_ab = dot_product(a, b);
263 let norm_a = norm_l2(a);
264 let norm_b = norm_l2(b);
265
266 if norm_a == 0.0 || norm_b == 0.0 {
267 return 0.0; }
269
270 dot_ab / (norm_a * norm_b)
271}
272
273pub fn cross_product(a: &[f32], b: &[f32]) -> Result<Vec<f32>, &'static str> {
295 if a.len() != 3 || b.len() != 3 {
296 return Err("Cross product requires exactly 3-dimensional vectors");
297 }
298
299 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
300 {
301 if crate::simd_feature_detected!("sse2") {
302 return Ok(unsafe { cross_product_sse2(a, b) });
303 }
304 }
305
306 Ok(cross_product_scalar(a, b))
307}
308
309pub fn outer_product(a: &[f32], b: &[f32]) -> Vec<Vec<f32>> {
333 let m = a.len();
334 let n = b.len();
335
336 if m == 0 || n == 0 {
337 return vec![];
338 }
339
340 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
341 {
342 if crate::simd_feature_detected!("avx2") {
343 return unsafe { outer_product_avx2(a, b) };
344 } else if crate::simd_feature_detected!("sse2") {
345 return unsafe { outer_product_sse2(a, b) };
346 }
347 }
348
349 outer_product_scalar(a, b)
350}
351
/// Portable dot product used when no SIMD kernel is available.
///
/// Elements are paired with `zip`, so any surplus in the longer slice is
/// ignored. Length equality is enforced by the public caller.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .copied()
        .zip(b.iter().copied())
        .map(|(lhs, rhs)| lhs * rhs)
        .sum()
}
359
/// Portable L1 norm: sum of absolute values.
fn norm_l1_scalar(x: &[f32]) -> f32 {
    x.iter().copied().map(f32::abs).sum()
}
363
/// Portable infinity norm: largest absolute value, starting from `0.0`.
///
/// `f32::max` returns the non-NaN operand, so NaN elements are skipped.
fn norm_inf_scalar(x: &[f32]) -> f32 {
    x.iter().copied().map(f32::abs).fold(0.0_f32, f32::max)
}
367
/// Portable Euclidean distance: square root of the summed squared
/// element-wise differences (pairs formed with `zip`).
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&p, &q)| (p - q) * (p - q))
        .sum();
    squared_sum.sqrt()
}
378
/// Portable 3-D cross product.
///
/// Reads indices 0..=2 of both slices; the public caller has already checked
/// that each has exactly three elements.
fn cross_product_scalar(a: &[f32], b: &[f32]) -> Vec<f32> {
    // Each component is a 2x2 minor of the (a, b) matrix.
    let x = a[1] * b[2] - a[2] * b[1];
    let y = a[2] * b[0] - a[0] * b[2];
    let z = a[0] * b[1] - a[1] * b[0];
    vec![x, y, z]
}
386
/// Portable outer product: `result[i][j] = a[i] * b[j]`.
///
/// Produces one row per element of `a`, each holding `b.len()` products.
/// An empty `a` yields no rows. An empty `b` yields `a.len()` empty rows.
fn outer_product_scalar(a: &[f32], b: &[f32]) -> Vec<Vec<f32>> {
    // Building each row directly from iterators avoids the zero-filled
    // pre-allocation and the index loop (whose `.take(b.len())` was a no-op).
    a.iter()
        .map(|&a_val| b.iter().map(|&b_val| a_val * b_val).collect())
        .collect()
}
400
/// SSE2 dot product: four lanes per iteration, then a scalar tail for the
/// remaining `a.len() % 4` elements.
///
/// # Safety
///
/// * The CPU must support SSE2 (baseline on `x86_64`, optional on 32-bit `x86`).
/// * `b.len()` must be at least `a.len()`: the vector loop reads `b` through
///   raw pointers at offsets derived from `a.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn dot_product_sse2(a: &[f32], b: &[f32]) -> f32 {
    // The intrinsics live in `core::arch::x86` on 32-bit targets and in
    // `core::arch::x86_64` on 64-bit ones. Importing only the latter broke the
    // 32-bit build that the `cfg` above explicitly allows.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    debug_assert!(b.len() >= a.len(), "b must be at least as long as a");

    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    // `i + 4 <= a.len()` keeps every unaligned 4-lane load in bounds.
    while i + 4 <= a.len() {
        let a_vec = _mm_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm_loadu_ps(b.as_ptr().add(i));
        sum = _mm_add_ps(sum, _mm_mul_ps(a_vec, b_vec));
        i += 4;
    }

    // Horizontal reduction of the four partial sums.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Scalar tail.
    while i < a.len() {
        total += a[i] * b[i];
        i += 1;
    }

    total
}
438
/// SSE2 L1 norm (`Σ |x[i]|`): four lanes per iteration plus a scalar tail.
///
/// The absolute value is taken by AND-ing each lane with `0x7FFF_FFFF`, which
/// clears the IEEE-754 sign bit.
///
/// # Safety
///
/// The CPU must support SSE2 (baseline on `x86_64`, optional on 32-bit `x86`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn norm_l1_sse2(x: &[f32]) -> f32 {
    // Select the intrinsics module for the actual target. `core::arch::x86_64`
    // does not exist on 32-bit x86, which the `cfg` above also admits.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFF_FFFF));
    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= x.len() {
        let x_vec = _mm_loadu_ps(x.as_ptr().add(i));
        sum = _mm_add_ps(sum, _mm_and_ps(x_vec, abs_mask));
        i += 4;
    }

    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    while i < x.len() {
        total += x[i].abs();
        i += 1;
    }

    total
}
473
/// SSE2 infinity norm (`max_i |x[i]|`): four lanes per iteration plus a
/// scalar tail. NaN elements are ignored, matching the scalar fallback, which
/// uses `f32::max`.
///
/// # Safety
///
/// The CPU must support SSE2 (baseline on `x86_64`, optional on 32-bit `x86`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn norm_inf_sse2(x: &[f32]) -> f32 {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Clearing the sign bit yields |v| for every lane.
    let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFF_FFFF));
    let mut max_vec = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= x.len() {
        let x_vec = _mm_loadu_ps(x.as_ptr().add(i));
        let abs_vec = _mm_and_ps(x_vec, abs_mask);
        // MAXPS returns its SECOND operand whenever either lane is NaN. With
        // the running maximum second, a NaN element leaves it unchanged.
        // Putting it first (as before) let a NaN overwrite the lane, so
        // e.g. [9,0,0,0, NaN,0,0,0] produced 0 instead of 9.
        max_vec = _mm_max_ps(abs_vec, max_vec);
        i += 4;
    }

    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), max_vec);
    let mut max_val = lanes[0].max(lanes[1]).max(lanes[2]).max(lanes[3]);

    while i < x.len() {
        max_val = max_val.max(x[i].abs());
        i += 1;
    }

    max_val
}
506
/// SSE2 Euclidean distance: accumulates squared differences four lanes at a
/// time, finishes the tail in scalar code, then takes the square root.
///
/// # Safety
///
/// * The CPU must support SSE2 (baseline on `x86_64`, optional on 32-bit `x86`).
/// * `b.len()` must be at least `a.len()`: `b` is read through raw pointers
///   at offsets derived from `a.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn euclidean_distance_sse2(a: &[f32], b: &[f32]) -> f32 {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    debug_assert!(b.len() >= a.len(), "b must be at least as long as a");

    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= a.len() {
        let a_vec = _mm_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm_loadu_ps(b.as_ptr().add(i));
        let diff = _mm_sub_ps(a_vec, b_vec);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        i += 4;
    }

    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    while i < a.len() {
        let diff = a[i] - b[i];
        total += diff * diff;
        i += 1;
    }

    total.sqrt()
}
541
/// SSE2 3-D cross product using the shuffle formulation
/// `a × b = a.yzx * b.zxy - a.zxy * b.yzx`.
///
/// # Safety
///
/// * The CPU must support SSE2 (baseline on `x86_64`, optional on 32-bit `x86`).
/// * `a` and `b` must each have at least three elements. Indices 0..=2 are
///   read with bounds-checked indexing, so a shorter slice panics rather than
///   causing UB.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn cross_product_sse2(a: &[f32], b: &[f32]) -> Vec<f32> {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // `_mm_set_ps` takes lanes highest-first, so lanes are [x, y, z, 0].
    let a_vec = _mm_set_ps(0.0, a[2], a[1], a[0]);
    let b_vec = _mm_set_ps(0.0, b[2], b[1], b[0]);

    // 0xC9 = 0b11_00_10_01 selects lanes (1, 2, 0, 3): [y, z, x, 0].
    // 0xD2 = 0b11_01_00_10 selects lanes (2, 0, 1, 3): [z, x, y, 0].
    let a_yzx = _mm_shuffle_ps(a_vec, a_vec, 0xC9);
    let b_zxy = _mm_shuffle_ps(b_vec, b_vec, 0xD2);
    let a_zxy = _mm_shuffle_ps(a_vec, a_vec, 0xD2);
    let b_yzx = _mm_shuffle_ps(b_vec, b_vec, 0xC9);

    let result_vec = _mm_sub_ps(_mm_mul_ps(a_yzx, b_zxy), _mm_mul_ps(a_zxy, b_yzx));

    let mut output = [0.0f32; 4];
    _mm_storeu_ps(output.as_mut_ptr(), result_vec);

    // Lane 3 is the zero padding and is dropped.
    vec![output[0], output[1], output[2]]
}
576
/// SSE2 outer product: for each `a[i]`, broadcasts it and multiplies four
/// elements of `b` at a time into row `i`; leftover columns are done in
/// scalar code.
///
/// # Safety
///
/// The CPU must support SSE2 (baseline on `x86_64`, optional on 32-bit `x86`).
/// Every row is allocated with exactly `b.len()` elements, so the unaligned
/// stores at `j + 4 <= n` stay in bounds.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn outer_product_sse2(a: &[f32], b: &[f32]) -> Vec<Vec<f32>> {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n = b.len();
    let mut result = vec![vec![0.0; n]; a.len()];

    for (row, &a_val) in result.iter_mut().zip(a) {
        let a_broadcast = _mm_set1_ps(a_val);
        let mut j = 0;

        while j + 4 <= n {
            let b_vec = _mm_loadu_ps(b.as_ptr().add(j));
            _mm_storeu_ps(row.as_mut_ptr().add(j), _mm_mul_ps(a_broadcast, b_vec));
            j += 4;
        }

        while j < n {
            row[j] = a_val * b[j];
            j += 1;
        }
    }

    result
}
609
/// AVX2 dot product: eight lanes per iteration, then a scalar tail for the
/// remaining `a.len() % 8` elements.
///
/// # Safety
///
/// * The CPU must support AVX2.
/// * `b.len()` must be at least `a.len()`: `b` is read through raw pointers
///   at offsets derived from `a.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    debug_assert!(b.len() >= a.len(), "b must be at least as long as a");

    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= a.len() {
        let a_vec = _mm256_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm256_loadu_ps(b.as_ptr().add(i));
        sum = _mm256_add_ps(sum, _mm256_mul_ps(a_vec, b_vec));
        i += 8;
    }

    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut total = lanes.iter().sum::<f32>();

    while i < a.len() {
        total += a[i] * b[i];
        i += 1;
    }

    total
}
647
/// AVX2 L1 norm (`Σ |x[i]|`): eight lanes per iteration plus a scalar tail.
/// |v| is computed by clearing the IEEE-754 sign bit with a `0x7FFF_FFFF` mask.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn norm_l1_avx2(x: &[f32]) -> f32 {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let abs_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= x.len() {
        let x_vec = _mm256_loadu_ps(x.as_ptr().add(i));
        sum = _mm256_add_ps(sum, _mm256_and_ps(x_vec, abs_mask));
        i += 8;
    }

    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut total = lanes.iter().sum::<f32>();

    while i < x.len() {
        total += x[i].abs();
        i += 1;
    }

    total
}
680
/// AVX2 infinity norm (`max_i |x[i]|`): eight lanes per iteration plus a
/// scalar tail. NaN elements are ignored, matching the scalar fallback, which
/// uses `f32::max`.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn norm_inf_avx2(x: &[f32]) -> f32 {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let abs_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
    let mut max_vec = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= x.len() {
        let x_vec = _mm256_loadu_ps(x.as_ptr().add(i));
        let abs_vec = _mm256_and_ps(x_vec, abs_mask);
        // VMAXPS returns its SECOND operand when either lane is NaN. Keeping
        // the running maximum second means a NaN element cannot erase it.
        // With the old operand order, [9, 0×7, NaN, 0×7] returned 0.
        max_vec = _mm256_max_ps(abs_vec, max_vec);
        i += 8;
    }

    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), max_vec);
    let mut max_val = lanes.iter().fold(0.0f32, |acc, &v| acc.max(v));

    while i < x.len() {
        max_val = max_val.max(x[i].abs());
        i += 1;
    }

    max_val
}
713
/// AVX2 Euclidean distance: accumulates squared differences eight lanes at a
/// time, finishes the tail in scalar code, then takes the square root.
///
/// # Safety
///
/// * The CPU must support AVX2.
/// * `b.len()` must be at least `a.len()`: `b` is read through raw pointers
///   at offsets derived from `a.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    debug_assert!(b.len() >= a.len(), "b must be at least as long as a");

    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= a.len() {
        let a_vec = _mm256_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm256_loadu_ps(b.as_ptr().add(i));
        let diff = _mm256_sub_ps(a_vec, b_vec);
        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
        i += 8;
    }

    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut total = lanes.iter().sum::<f32>();

    while i < a.len() {
        let diff = a[i] - b[i];
        total += diff * diff;
        i += 1;
    }

    total.sqrt()
}
748
/// AVX2 outer product: for each `a[i]`, broadcasts it and multiplies eight
/// elements of `b` at a time into row `i`; leftover columns are done in
/// scalar code.
///
/// # Safety
///
/// The CPU must support AVX2. Every row is allocated with exactly `b.len()`
/// elements, so the unaligned stores at `j + 8 <= n` stay in bounds.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn outer_product_avx2(a: &[f32], b: &[f32]) -> Vec<Vec<f32>> {
    // `core::arch::x86_64` is unavailable on 32-bit x86; pick per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let n = b.len();
    let mut result = vec![vec![0.0; n]; a.len()];

    for (row, &a_val) in result.iter_mut().zip(a) {
        let a_broadcast = _mm256_set1_ps(a_val);
        let mut j = 0;

        while j + 8 <= n {
            let b_vec = _mm256_loadu_ps(b.as_ptr().add(j));
            _mm256_storeu_ps(row.as_mut_ptr().add(j), _mm256_mul_ps(a_broadcast, b_vec));
            j += 8;
        }

        while j < n {
            row[j] = a_val * b[j];
            j += 1;
        }
    }

    result
}
781
782#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
787#[target_feature(enable = "avx512f")]
788unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
789 #[cfg(feature = "no-std")]
790 use core::arch::x86_64::*;
791 #[cfg(not(feature = "no-std"))]
792 use core::arch::x86_64::*;
793
794 let mut sum = _mm512_setzero_ps();
795 let mut i = 0;
796
797 while i + 16 <= a.len() {
799 let a_vec = _mm512_loadu_ps(a.as_ptr().add(i));
800 let b_vec = _mm512_loadu_ps(b.as_ptr().add(i));
801 sum = _mm512_fmadd_ps(a_vec, b_vec, sum); i += 16;
803 }
804
805 let scalar_sum = _mm512_reduce_add_ps(sum);
807
808 let mut remaining_sum = 0.0f32;
810 while i < a.len() {
811 remaining_sum += a[i] * b[i];
812 i += 1;
813 }
814
815 scalar_sum + remaining_sum
816}
817
/// NEON dot product: four lanes per iteration with fused multiply-add
/// (`vfmaq_f32`), a pairwise horizontal reduction, then a scalar tail.
///
/// # Safety
///
/// * The CPU must support NEON.
/// * `b.len()` must be at least `a.len()`: `b` is read through raw pointers
///   at offsets derived from `a.len()`.
#[cfg(target_arch = "aarch64")]
// Only the `std` build calls this (the dispatcher's NEON branch is gated on
// `not(feature = "no-std")`), so it would be dead code under `no-std`.
#[allow(dead_code)]
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let len = a.len();
    // Largest multiple of 4 not exceeding `len`: end of the vectorised region.
    let vector_end = len - len % 4;
    let mut acc = vdupq_n_f32(0.0);

    for offset in (0..vector_end).step_by(4) {
        let lhs = vld1q_f32(a.as_ptr().add(offset));
        let rhs = vld1q_f32(b.as_ptr().add(offset));
        acc = vfmaq_f32(acc, lhs, rhs);
    }

    // [s0, s1, s2, s3] -> [s0+s1, s2+s3] -> (s0+s1)+(s2+s3) in lane 0.
    let halves = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
    let mut total = vget_lane_f32(vpadd_f32(halves, halves), 0);

    for k in vector_end..len {
        total += a[k] * b[k];
    }

    total
}
852
/// NEON L1 norm (`Σ |x[i]|`): four lanes per iteration using `vabsq_f32`, a
/// pairwise horizontal reduction, then a scalar tail.
///
/// # Safety
///
/// The CPU must support NEON.
#[cfg(target_arch = "aarch64")]
// Only the `std` build calls this (the dispatcher's NEON branch is gated on
// `not(feature = "no-std")`), so it would be dead code under `no-std`.
#[allow(dead_code)]
#[target_feature(enable = "neon")]
unsafe fn norm_l1_neon(x: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let len = x.len();
    let vector_end = len - len % 4;
    let mut acc = vdupq_n_f32(0.0);

    for offset in (0..vector_end).step_by(4) {
        let magnitudes = vabsq_f32(vld1q_f32(x.as_ptr().add(offset)));
        acc = vaddq_f32(acc, magnitudes);
    }

    let halves = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
    let vector_total = vget_lane_f32(vpadd_f32(halves, halves), 0);

    // Scalar tail, added in order onto the vector total.
    x[vector_end..]
        .iter()
        .fold(vector_total, |total, v| total + v.abs())
}
882
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Deterministic inputs whose products and partial sums are small integers.
    /// They are exactly representable in `f32`, so every summation order
    /// (scalar, SSE2, AVX2, AVX-512, NEON) must produce bit-identical results.
    fn integer_vectors(len: usize) -> (Vec<f32>, Vec<f32>) {
        let a = (0..len).map(|i| (i % 7) as f32 - 3.0).collect();
        let b = (0..len).map(|i| (i % 5) as f32 - 2.0).collect();
        (a, b)
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        assert_eq!(dot_product(&a, &b), 70.0);

        let empty_a: Vec<f32> = vec![];
        let empty_b: Vec<f32> = vec![];
        assert_eq!(dot_product(&empty_a, &empty_b), 0.0);

        let single_a = vec![3.0];
        let single_b = vec![4.0];
        assert_eq!(dot_product(&single_a, &single_b), 12.0);
    }

    #[test]
    fn test_norms() {
        let x = vec![3.0, 4.0];
        assert_eq!(norm_l2(&x), 5.0);
        assert_eq!(norm_l1(&x), 7.0);
        assert_eq!(norm_inf(&x), 4.0);

        let y = vec![-3.0, 4.0, -5.0];
        assert_eq!(norm_l1(&y), 12.0);
        assert_eq!(norm_inf(&y), 5.0);

        let empty: Vec<f32> = vec![];
        assert_eq!(norm_l2(&empty), 0.0);
        assert_eq!(norm_l1(&empty), 0.0);
        assert_eq!(norm_inf(&empty), 0.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = euclidean_distance(&a, &b);
        assert!((result - 5.196).abs() < 0.01);

        assert_eq!(euclidean_distance(&a, &a), 0.0);

        let empty_a: Vec<f32> = vec![];
        let empty_b: Vec<f32> = vec![];
        assert_eq!(euclidean_distance(&empty_a, &empty_b), 0.0);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 0.0).abs() < f32::EPSILON);

        assert!((cosine_similarity(&a, &a) - 1.0).abs() < f32::EPSILON);

        let opposite = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &opposite) - (-1.0)).abs() < f32::EPSILON);

        let zero = vec![0.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &zero), 0.0);

        let empty_a: Vec<f32> = vec![];
        let empty_b: Vec<f32> = vec![];
        assert_eq!(cosine_similarity(&empty_a, &empty_b), 1.0);
    }

    #[test]
    fn test_cross_product() {
        let i = vec![1.0, 0.0, 0.0];
        let j = vec![0.0, 1.0, 0.0];
        let result = cross_product(&i, &j).expect("operation should succeed");
        assert_eq!(result, vec![0.0, 0.0, 1.0]);

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let cross = cross_product(&a, &b).expect("operation should succeed");
        assert_eq!(cross, vec![-3.0, 6.0, -3.0]);

        let wrong_dim = vec![1.0, 2.0];
        assert!(cross_product(&wrong_dim, &j).is_err());
    }

    #[test]
    fn test_outer_product() {
        let a = vec![1.0, 2.0];
        let b = vec![3.0, 4.0, 5.0];
        let result = outer_product(&a, &b);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 3);
        assert_eq!(result[0], vec![3.0, 4.0, 5.0]);
        assert_eq!(result[1], vec![6.0, 8.0, 10.0]);

        let empty_a: Vec<f32> = vec![];
        assert!(outer_product(&empty_a, &b).is_empty());

        let empty_b: Vec<f32> = vec![];
        assert!(outer_product(&a, &empty_b).is_empty());
    }

    /// The inputs above are at most four elements long, so the AVX2/AVX-512
    /// main loops never run on them. This test sweeps lengths around every
    /// SIMD width (4, 8, 16) and checks the dispatched result against the
    /// scalar reference.
    #[test]
    fn test_simd_paths_match_scalar() {
        for len in [1usize, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 33, 64, 67] {
            let (a, b) = integer_vectors(len);
            assert_eq!(dot_product(&a, &b), dot_product_scalar(&a, &b), "dot_product, len {}", len);
            assert_eq!(norm_l1(&a), norm_l1_scalar(&a), "norm_l1, len {}", len);
            assert_eq!(norm_inf(&a), norm_inf_scalar(&a), "norm_inf, len {}", len);
            assert_eq!(
                euclidean_distance(&a, &b),
                euclidean_distance_scalar(&a, &b),
                "euclidean_distance, len {}",
                len
            );
            assert_eq!(outer_product(&a, &b), outer_product_scalar(&a, &b), "outer_product, len {}", len);
        }

        let (a3, b3) = integer_vectors(3);
        assert_eq!(
            cross_product(&a3, &b3).expect("3-D inputs"),
            cross_product_scalar(&a3, &b3)
        );
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_dot_product_dimension_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];
        dot_product(&a, &b);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_euclidean_distance_dimension_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];
        euclidean_distance(&a, &b);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];
        cosine_similarity(&a, &b);
    }
}