1#![doc = include_str!("../README.md")]
2
3mod fallback;
4
5#[cfg(target_arch = "aarch64")]
6mod neon;
7
8#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
9mod sse2;
10
11#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
12mod avx2;
13
/// A SIMD implementation that can service this crate's vector math routines.
///
/// Only the variants usable on the compilation target are compiled in; the
/// `Scalar` fallback exists on every architecture.
///
/// NOTE(review): the dispatch methods on this enum call `unsafe`
/// target-feature code, so a variant should only be used on a CPU where
/// `available_backends()` / `detect_backend()` would report it. The variants
/// are publicly constructible, so that invariant is by convention only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdBackend {
    /// Portable scalar implementation; always available.
    Scalar,
    /// 128-bit SSE2 implementation (x86/x86_64 only).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    Sse2,
    /// 256-bit AVX2 + FMA implementation (x86/x86_64 only).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    Avx2,
    /// 128-bit NEON implementation (aarch64 only).
    #[cfg(target_arch = "aarch64")]
    Neon,
}
29
impl SimdBackend {
    // Dispatch note: every non-`Scalar` arm below calls into an `unsafe`
    // target-feature implementation module. Soundness relies on `self` only
    // existing for a CPU that supports that feature — obtain values via
    // `detect_backend()` / `available_backends()` rather than constructing
    // variants directly.
    //
    // All length preconditions are `debug_assert`s only; release builds do
    // not check them before entering the unsafe kernels.

    /// Short static identifier for this backend (used in logs and tests).
    pub fn name(self) -> &'static str {
        match self {
            Self::Scalar => "scalar",
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => "sse2",
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => "avx2+fma",
            #[cfg(target_arch = "aarch64")]
            Self::Neon => "neon",
        }
    }

    /// Dot product of `a` and `b`: `sum(a[i] * b[i])`.
    ///
    /// Both slices must have the same length (debug-checked only).
    pub fn dot_product(self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        match self {
            Self::Scalar => fallback::dot_product(a, b),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::dot_product(a, b) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::dot_product(a, b) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::dot_product(a, b) },
        }
    }

    /// Computes `(input · k1, input · k2)` — two dot products against the
    /// same input in a single pass.
    ///
    /// All three slices must have the same length (debug-checked only).
    pub fn dual_dot_product(self, input: &[f32], k1: &[f32], k2: &[f32]) -> (f32, f32) {
        debug_assert_eq!(input.len(), k1.len());
        debug_assert_eq!(input.len(), k2.len());
        match self {
            Self::Scalar => fallback::dual_dot_product(input, k1, k2),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::dual_dot_product(input, k1, k2) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::dual_dot_product(input, k1, k2) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::dual_dot_product(input, k1, k2) },
        }
    }

    /// Sinc-kernel convolution of `input` against the two kernel slices.
    ///
    /// NOTE(review): presumably blends the `k1`/`k2` taps by
    /// `kernel_interpolation_factor` (fractional-delay interpolation) — the
    /// semantics live in the backend modules; confirm there. No unit test
    /// covers this method.
    pub fn convolve_sinc(
        self,
        input: &[f32],
        k1: &[f32],
        k2: &[f32],
        kernel_interpolation_factor: f64,
    ) -> f32 {
        debug_assert_eq!(input.len(), k1.len());
        debug_assert_eq!(input.len(), k2.len());
        match self {
            Self::Scalar => fallback::convolve_sinc(input, k1, k2, kernel_interpolation_factor),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe {
                sse2::convolve_sinc(input, k1, k2, kernel_interpolation_factor)
            },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe {
                avx2::convolve_sinc(input, k1, k2, kernel_interpolation_factor)
            },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe {
                neon::convolve_sinc(input, k1, k2, kernel_interpolation_factor)
            },
        }
    }

    /// In-place multiply-add: `acc[i] += a[i] * b[i]` for every index.
    pub fn multiply_accumulate(self, acc: &mut [f32], a: &[f32], b: &[f32]) {
        debug_assert_eq!(acc.len(), a.len());
        debug_assert_eq!(acc.len(), b.len());
        match self {
            Self::Scalar => fallback::multiply_accumulate(acc, a, b),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::multiply_accumulate(acc, a, b) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::multiply_accumulate(acc, a, b) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::multiply_accumulate(acc, a, b) },
        }
    }

    /// Sum of all elements of `x` (0.0 for an empty slice, per the tests).
    pub fn sum(self, x: &[f32]) -> f32 {
        match self {
            Self::Scalar => fallback::sum(x),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::sum(x) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::sum(x) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::sum(x) },
        }
    }

    /// Replaces each element of `x` with its square root, in place.
    pub fn elementwise_sqrt(self, x: &mut [f32]) {
        match self {
            Self::Scalar => fallback::elementwise_sqrt(x),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::elementwise_sqrt(x) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::elementwise_sqrt(x) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::elementwise_sqrt(x) },
        }
    }

    /// Elementwise product into a separate output: `z[i] = x[i] * y[i]`.
    pub fn elementwise_multiply(self, x: &[f32], y: &[f32], z: &mut [f32]) {
        debug_assert_eq!(x.len(), y.len());
        debug_assert_eq!(x.len(), z.len());
        match self {
            Self::Scalar => fallback::elementwise_multiply(x, y, z),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::elementwise_multiply(x, y, z) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::elementwise_multiply(x, y, z) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::elementwise_multiply(x, y, z) },
        }
    }

    /// Elementwise accumulation: `z[i] += x[i]`.
    pub fn elementwise_accumulate(self, x: &[f32], z: &mut [f32]) {
        debug_assert_eq!(x.len(), z.len());
        match self {
            Self::Scalar => fallback::elementwise_accumulate(x, z),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::elementwise_accumulate(x, z) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::elementwise_accumulate(x, z) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::elementwise_accumulate(x, z) },
        }
    }

    /// Squared magnitude of complex bins: `out[i] = re[i]^2 + im[i]^2`
    /// (e.g. 3,4 -> 25 per the tests).
    pub fn power_spectrum(self, re: &[f32], im: &[f32], out: &mut [f32]) {
        debug_assert_eq!(re.len(), im.len());
        debug_assert_eq!(re.len(), out.len());
        match self {
            Self::Scalar => fallback::power_spectrum(re, im, out),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::power_spectrum(re, im, out) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::power_spectrum(re, im, out) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::power_spectrum(re, im, out) },
        }
    }

    /// Elementwise minimum: `out[i] = min(a[i], b[i])`.
    pub fn elementwise_min(self, a: &[f32], b: &[f32], out: &mut [f32]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), out.len());
        match self {
            Self::Scalar => fallback::elementwise_min(a, b, out),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::elementwise_min(a, b, out) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::elementwise_min(a, b, out) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::elementwise_min(a, b, out) },
        }
    }

    /// Elementwise maximum: `out[i] = max(a[i], b[i])`.
    pub fn elementwise_max(self, a: &[f32], b: &[f32], out: &mut [f32]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), out.len());
        match self {
            Self::Scalar => fallback::elementwise_max(a, b, out),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe { sse2::elementwise_max(a, b, out) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe { avx2::elementwise_max(a, b, out) },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe { neon::elementwise_max(a, b, out) },
        }
    }

    /// Complex multiply-accumulate over split re/im arrays.
    ///
    /// NOTE(review): per the unit test (x = 1+2i, h = 3+4i -> 11 - 2i) this
    /// follows the conj(x)·h (correlation) convention, unlike the
    /// `_standard` variant below — confirm against the backend modules.
    /// All six slices must have the same length (debug-checked only).
    pub fn complex_multiply_accumulate(
        self,
        x_re: &[f32],
        x_im: &[f32],
        h_re: &[f32],
        h_im: &[f32],
        acc_re: &mut [f32],
        acc_im: &mut [f32],
    ) {
        debug_assert_eq!(x_re.len(), x_im.len());
        debug_assert_eq!(x_re.len(), h_re.len());
        debug_assert_eq!(x_re.len(), h_im.len());
        debug_assert_eq!(x_re.len(), acc_re.len());
        debug_assert_eq!(x_re.len(), acc_im.len());
        match self {
            Self::Scalar => {
                fallback::complex_multiply_accumulate(x_re, x_im, h_re, h_im, acc_re, acc_im);
            }
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe {
                sse2::complex_multiply_accumulate(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe {
                avx2::complex_multiply_accumulate(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe {
                neon::complex_multiply_accumulate(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
        }
    }

    /// Complex multiply-accumulate using the standard product convention:
    /// acc += x · h (test: (1+2i)(3+4i) = -5 + 10i).
    ///
    /// All six slices must have the same length (debug-checked only).
    pub fn complex_multiply_accumulate_standard(
        self,
        x_re: &[f32],
        x_im: &[f32],
        h_re: &[f32],
        h_im: &[f32],
        acc_re: &mut [f32],
        acc_im: &mut [f32],
    ) {
        debug_assert_eq!(x_re.len(), x_im.len());
        debug_assert_eq!(x_re.len(), h_re.len());
        debug_assert_eq!(x_re.len(), h_im.len());
        debug_assert_eq!(x_re.len(), acc_re.len());
        debug_assert_eq!(x_re.len(), acc_im.len());
        match self {
            Self::Scalar => {
                fallback::complex_multiply_accumulate_standard(
                    x_re, x_im, h_re, h_im, acc_re, acc_im,
                );
            }
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => unsafe {
                sse2::complex_multiply_accumulate_standard(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => unsafe {
                avx2::complex_multiply_accumulate_standard(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
            #[cfg(target_arch = "aarch64")]
            Self::Neon => unsafe {
                neon::complex_multiply_accumulate_standard(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
        }
    }
}
332
// Runtime CPU-feature probes generated by the `cpufeatures` crate; each
// macro invocation creates a module whose `get()` reports (and caches)
// support for the listed features on the running CPU.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
cpufeatures::new!(has_avx2_fma, "avx2", "fma");
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
cpufeatures::new!(has_sse2, "sse2");
338
339pub fn available_backends() -> Vec<SimdBackend> {
345 let mut backends = vec![SimdBackend::Scalar];
346
347 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
348 {
349 if has_sse2::get() {
350 backends.push(SimdBackend::Sse2);
351 }
352 if has_avx2_fma::get() {
353 backends.push(SimdBackend::Avx2);
354 }
355 }
356
357 #[cfg(target_arch = "aarch64")]
358 backends.push(SimdBackend::Neon);
359
360 backends
361}
362
363pub fn detect_backend() -> SimdBackend {
369 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
370 {
371 if has_avx2_fma::get() {
372 return SimdBackend::Avx2;
373 }
374 if has_sse2::get() {
375 return SimdBackend::Sse2;
376 }
377 }
378
379 #[cfg(target_arch = "aarch64")]
380 {
381 return SimdBackend::Neon;
382 }
383
384 #[allow(unreachable_code, reason = "fallback for architectures without SIMD")]
385 SimdBackend::Scalar
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    // Coverage pattern: each op gets a small "simple" test with hand-computed
    // values on the detected backend, plus a "matches_scalar" test comparing
    // every available SIMD backend against the scalar fallback across sizes
    // that deliberately include non-multiples of the vector width (to hit the
    // remainder/tail paths).
    // NOTE(review): `convolve_sinc` is the only dispatch method with no
    // coverage in this module.

    #[test]
    fn test_detect_backend() {
        let backend = detect_backend();
        println!("Detected SIMD backend: {}", backend.name());
        assert!(!backend.name().is_empty());
    }

    #[test]
    fn test_backend_is_copy() {
        let a = detect_backend();
        let b = a;
        assert_eq!(a, b);
    }

    #[test]
    fn test_dot_product_simple() {
        let ops = detect_backend();
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];
        let result = ops.dot_product(&a, &b);
        // 1*5 + 2*6 + 3*7 + 4*8 = 70
        assert!((result - 70.0).abs() < 1e-6);
    }

    #[test]
    fn test_dual_dot_product_simple() {
        let ops = detect_backend();
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let k1 = [1.0f32, 0.0, 1.0, 0.0];
        let k2 = [0.0f32, 1.0, 0.0, 1.0];
        let (d1, d2) = ops.dual_dot_product(&input, &k1, &k2);
        // k1 selects the even taps (1+3), k2 the odd taps (2+4).
        assert!((d1 - 4.0).abs() < 1e-6);
        assert!((d2 - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_multiply_accumulate_simple() {
        let ops = detect_backend();
        let mut acc = [10.0f32, 20.0, 30.0, 40.0];
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];
        ops.multiply_accumulate(&mut acc, &a, &b);
        assert!((acc[0] - 15.0).abs() < 1e-6);
        assert!((acc[1] - 32.0).abs() < 1e-6);
        assert!((acc[2] - 51.0).abs() < 1e-6);
        assert!((acc[3] - 72.0).abs() < 1e-6);
    }

    #[test]
    fn test_sum_simple() {
        let ops = detect_backend();
        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        assert!((ops.sum(&x) - 15.0).abs() < 1e-6);
    }

    // Zero-length inputs must be a no-op / zero result on every backend.
    #[test]
    fn test_empty_slices() {
        let ops = detect_backend();
        assert_eq!(ops.dot_product(&[], &[]), 0.0);
        assert_eq!(ops.sum(&[]), 0.0);
        let (d1, d2) = ops.dual_dot_product(&[], &[], &[]);
        assert_eq!(d1, 0.0);
        assert_eq!(d2, 0.0);
    }

    // Reductions use a loose 1e-3 tolerance: FMA and different summation
    // orders legitimately change low-order bits between backends.
    #[test]
    fn test_dot_product_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256] {
                let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
                let b: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.005).collect();

                let scalar_result = scalar.dot_product(&a, &b);
                let simd_result = backend.dot_product(&a, &b);

                assert!(
                    (scalar_result - simd_result).abs() < 1e-3,
                    "[{}] Mismatch for size {size}: scalar={scalar_result}, simd={simd_result}",
                    backend.name()
                );
            }
        }
    }

    #[test]
    fn test_dual_dot_product_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 16, 31, 64, 128, 256] {
                let input: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
                let k1: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.003).collect();
                let k2: Vec<f32> = (0..size).map(|i| 0.5 + (i as f32) * 0.002).collect();

                let (s1, s2) = scalar.dual_dot_product(&input, &k1, &k2);
                let (d1, d2) = backend.dual_dot_product(&input, &k1, &k2);

                assert!(
                    (s1 - d1).abs() < 1e-3,
                    "[{}] k1 mismatch for size {size}: scalar={s1}, simd={d1}",
                    backend.name()
                );
                assert!(
                    (s2 - d2).abs() < 1e-3,
                    "[{}] k2 mismatch for size {size}: scalar={s2}, simd={d2}",
                    backend.name()
                );
            }
        }
    }

    #[test]
    fn test_multiply_accumulate_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 16, 31, 64, 128] {
                let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
                let b: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.005).collect();

                let mut acc_scalar: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
                let mut acc_simd = acc_scalar.clone();

                scalar.multiply_accumulate(&mut acc_scalar, &a, &b);
                backend.multiply_accumulate(&mut acc_simd, &a, &b);

                for i in 0..size {
                    assert!(
                        (acc_scalar[i] - acc_simd[i]).abs() < 1e-4,
                        "[{}] Mismatch at index {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        acc_scalar[i],
                        acc_simd[i]
                    );
                }
            }
        }
    }

    #[test]
    fn test_elementwise_sqrt_simple() {
        let ops = detect_backend();
        let mut x = [4.0f32, 9.0, 16.0, 25.0, 36.0];
        ops.elementwise_sqrt(&mut x);
        assert!((x[0] - 2.0).abs() < 1e-6);
        assert!((x[1] - 3.0).abs() < 1e-6);
        assert!((x[2] - 4.0).abs() < 1e-6);
        assert!((x[3] - 5.0).abs() < 1e-6);
        assert!((x[4] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_elementwise_sqrt_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 8, 15, 16, 31, 64, 65, 128] {
                // Inputs offset by +0.1 to stay strictly positive.
                let mut x_scalar: Vec<f32> = (0..size).map(|i| (i as f32) * 0.5 + 0.1).collect();
                let mut x_simd = x_scalar.clone();

                scalar.elementwise_sqrt(&mut x_scalar);
                backend.elementwise_sqrt(&mut x_simd);

                for i in 0..size {
                    assert!(
                        (x_scalar[i] - x_simd[i]).abs() < 1e-6,
                        "[{}] sqrt mismatch at index {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        x_scalar[i],
                        x_simd[i]
                    );
                }
            }
        }
    }

    #[test]
    fn test_elementwise_multiply_simple() {
        let ops = detect_backend();
        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let y = [5.0f32, 4.0, 3.0, 2.0, 1.0];
        let mut z = [0.0f32; 5];
        ops.elementwise_multiply(&x, &y, &mut z);
        assert!((z[0] - 5.0).abs() < 1e-6);
        assert!((z[1] - 8.0).abs() < 1e-6);
        assert!((z[2] - 9.0).abs() < 1e-6);
        assert!((z[3] - 8.0).abs() < 1e-6);
        assert!((z[4] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_elementwise_multiply_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
                let x: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
                let y: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.005).collect();
                let mut z_scalar = vec![0.0f32; size];
                let mut z_simd = vec![0.0f32; size];

                scalar.elementwise_multiply(&x, &y, &mut z_scalar);
                backend.elementwise_multiply(&x, &y, &mut z_simd);

                for i in 0..size {
                    assert!(
                        (z_scalar[i] - z_simd[i]).abs() < 1e-6,
                        "[{}] multiply mismatch at index {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        z_scalar[i],
                        z_simd[i]
                    );
                }
            }
        }
    }

    #[test]
    fn test_elementwise_accumulate_simple() {
        let ops = detect_backend();
        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mut z = [10.0f32, 20.0, 30.0, 40.0, 50.0];
        ops.elementwise_accumulate(&x, &mut z);
        assert!((z[0] - 11.0).abs() < 1e-6);
        assert!((z[1] - 22.0).abs() < 1e-6);
        assert!((z[2] - 33.0).abs() < 1e-6);
        assert!((z[3] - 44.0).abs() < 1e-6);
        assert!((z[4] - 55.0).abs() < 1e-6);
    }

    #[test]
    fn test_elementwise_accumulate_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
                let x: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
                let mut z_scalar: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
                let mut z_simd = z_scalar.clone();

                scalar.elementwise_accumulate(&x, &mut z_scalar);
                backend.elementwise_accumulate(&x, &mut z_simd);

                for i in 0..size {
                    assert!(
                        (z_scalar[i] - z_simd[i]).abs() < 1e-6,
                        "[{}] accumulate mismatch at index {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        z_scalar[i],
                        z_simd[i]
                    );
                }
            }
        }
    }

    #[test]
    fn test_power_spectrum_simple() {
        let ops = detect_backend();
        // Pythagorean pairs so expected squared magnitudes are exact.
        let re = [3.0f32, 0.0, 1.0, 2.0, 5.0];
        let im = [4.0f32, 1.0, 0.0, 3.0, 12.0];
        let mut out = [0.0f32; 5];
        ops.power_spectrum(&re, &im, &mut out);
        // NOTE(review): collapsed onto one line (looks like a bad merge/format);
        // rustfmt would split these statements.
        assert!((out[0] - 25.0).abs() < 1e-6); assert!((out[1] - 1.0).abs() < 1e-6); assert!((out[2] - 1.0).abs() < 1e-6); assert!((out[3] - 13.0).abs() < 1e-6); assert!((out[4] - 169.0).abs() < 1e-6); }

    #[test]
    fn test_power_spectrum_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
                let re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
                let im: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
                let mut out_scalar = vec![0.0f32; size];
                let mut out_simd = vec![0.0f32; size];

                scalar.power_spectrum(&re, &im, &mut out_scalar);
                backend.power_spectrum(&re, &im, &mut out_simd);

                for i in 0..size {
                    assert!(
                        (out_scalar[i] - out_simd[i]).abs() < 1e-4,
                        "[{}] power_spectrum mismatch at index {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        out_scalar[i],
                        out_simd[i]
                    );
                }
            }
        }
    }

    #[test]
    fn test_elementwise_min_simple() {
        let ops = detect_backend();
        let a = [1.0f32, 5.0, 3.0, 8.0, 2.0];
        let b = [4.0f32, 2.0, 7.0, 1.0, 9.0];
        let mut out = [0.0f32; 5];
        ops.elementwise_min(&a, &b, &mut out);
        assert_eq!(out, [1.0, 2.0, 3.0, 1.0, 2.0]);
    }

    #[test]
    fn test_elementwise_min_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
                // a ramps up while b ramps down, so the winner flips mid-array.
                let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
                let b: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
                let mut out_scalar = vec![0.0f32; size];
                let mut out_simd = vec![0.0f32; size];

                scalar.elementwise_min(&a, &b, &mut out_scalar);
                backend.elementwise_min(&a, &b, &mut out_simd);

                for i in 0..size {
                    assert!(
                        (out_scalar[i] - out_simd[i]).abs() < 1e-6,
                        "[{}] min mismatch at index {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        out_scalar[i],
                        out_simd[i]
                    );
                }
            }
        }
    }

    #[test]
    fn test_elementwise_max_simple() {
        let ops = detect_backend();
        let a = [1.0f32, 5.0, 3.0, 8.0, 2.0];
        let b = [4.0f32, 2.0, 7.0, 1.0, 9.0];
        let mut out = [0.0f32; 5];
        ops.elementwise_max(&a, &b, &mut out);
        assert_eq!(out, [4.0, 5.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_elementwise_max_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
                let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
                let b: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
                let mut out_scalar = vec![0.0f32; size];
                let mut out_simd = vec![0.0f32; size];

                scalar.elementwise_max(&a, &b, &mut out_scalar);
                backend.elementwise_max(&a, &b, &mut out_simd);

                for i in 0..size {
                    assert!(
                        (out_scalar[i] - out_simd[i]).abs() < 1e-6,
                        "[{}] max mismatch at index {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        out_scalar[i],
                        out_simd[i]
                    );
                }
            }
        }
    }

    #[test]
    fn test_complex_multiply_accumulate_simple() {
        let ops = detect_backend();
        // x = 1+2i, h = 3+4i; expected 11 - 2i, i.e. conj(x)*h.
        let x_re = [1.0f32];
        let x_im = [2.0f32];
        let h_re = [3.0f32];
        let h_im = [4.0f32];
        let mut acc_re = [0.0f32];
        let mut acc_im = [0.0f32];
        ops.complex_multiply_accumulate(&x_re, &x_im, &h_re, &h_im, &mut acc_re, &mut acc_im);
        assert!((acc_re[0] - 11.0).abs() < 1e-6);
        assert!((acc_im[0] - (-2.0)).abs() < 1e-6);
    }

    #[test]
    fn test_complex_multiply_accumulate_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
                let x_re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
                let x_im: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
                let h_re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.05 + 0.5).collect();
                let h_im: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.03).collect();

                // Non-zero accumulators so the "+=" behavior is exercised too.
                let mut acc_re_scalar = vec![0.5f32; size];
                let mut acc_im_scalar = vec![-0.3f32; size];
                let mut acc_re_simd = acc_re_scalar.clone();
                let mut acc_im_simd = acc_im_scalar.clone();

                scalar.complex_multiply_accumulate(
                    &x_re,
                    &x_im,
                    &h_re,
                    &h_im,
                    &mut acc_re_scalar,
                    &mut acc_im_scalar,
                );
                backend.complex_multiply_accumulate(
                    &x_re,
                    &x_im,
                    &h_re,
                    &h_im,
                    &mut acc_re_simd,
                    &mut acc_im_simd,
                );

                for i in 0..size {
                    assert!(
                        (acc_re_scalar[i] - acc_re_simd[i]).abs() < 1e-4,
                        "[{}] cma re mismatch at {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        acc_re_scalar[i],
                        acc_re_simd[i]
                    );
                    assert!(
                        (acc_im_scalar[i] - acc_im_simd[i]).abs() < 1e-4,
                        "[{}] cma im mismatch at {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        acc_im_scalar[i],
                        acc_im_simd[i]
                    );
                }
            }
        }
    }

    #[test]
    fn test_complex_multiply_accumulate_standard_simple() {
        let ops = detect_backend();
        // Standard product: (1+2i)(3+4i) = -5 + 10i.
        let x_re = [1.0f32];
        let x_im = [2.0f32];
        let h_re = [3.0f32];
        let h_im = [4.0f32];
        let mut acc_re = [0.0f32];
        let mut acc_im = [0.0f32];
        ops.complex_multiply_accumulate_standard(
            &x_re,
            &x_im,
            &h_re,
            &h_im,
            &mut acc_re,
            &mut acc_im,
        );
        assert!((acc_re[0] - (-5.0)).abs() < 1e-6);
        assert!((acc_im[0] - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_complex_multiply_accumulate_standard_matches_scalar() {
        let scalar = SimdBackend::Scalar;

        for &backend in &available_backends() {
            if backend == SimdBackend::Scalar {
                continue;
            }
            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
                let x_re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
                let x_im: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
                let h_re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.05 + 0.5).collect();
                let h_im: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.03).collect();

                let mut acc_re_scalar = vec![0.5f32; size];
                let mut acc_im_scalar = vec![-0.3f32; size];
                let mut acc_re_simd = acc_re_scalar.clone();
                let mut acc_im_simd = acc_im_scalar.clone();

                scalar.complex_multiply_accumulate_standard(
                    &x_re,
                    &x_im,
                    &h_re,
                    &h_im,
                    &mut acc_re_scalar,
                    &mut acc_im_scalar,
                );
                backend.complex_multiply_accumulate_standard(
                    &x_re,
                    &x_im,
                    &h_re,
                    &h_im,
                    &mut acc_re_simd,
                    &mut acc_im_simd,
                );

                for i in 0..size {
                    assert!(
                        (acc_re_scalar[i] - acc_re_simd[i]).abs() < 1e-4,
                        "[{}] std cma re mismatch at {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        acc_re_scalar[i],
                        acc_re_simd[i]
                    );
                    assert!(
                        (acc_im_scalar[i] - acc_im_simd[i]).abs() < 1e-4,
                        "[{}] std cma im mismatch at {i} for size {size}: scalar={}, simd={}",
                        backend.name(),
                        acc_im_scalar[i],
                        acc_im_simd[i]
                    );
                }
            }
        }
    }
}