// sonora_simd — crate root (lib.rs)

#![doc = include_str!("../README.md")]

mod fallback;

#[cfg(target_arch = "aarch64")]
mod neon;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod sse2;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod avx2;

/// Available SIMD backends, selected at runtime based on CPU features.
///
/// Obtain a value via [`detect_backend()`] (fastest supported backend) or
/// [`available_backends()`] (every supported backend — useful in tests).
///
/// NOTE(review): variants are publicly constructible, yet the `unsafe`
/// dispatch in the methods relies on the CPU actually supporting the
/// corresponding feature (see the SAFETY comments in the impl). Hand-building
/// e.g. `Avx2` on a CPU without AVX2 and calling an op would be UB — consider
/// whether construction should be restricted to the detection functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdBackend {
    /// Scalar fallback — works on all platforms.
    Scalar,
    /// x86/x86_64 SSE2 (128-bit, 4 floats at a time).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    Sse2,
    /// x86/x86_64 AVX2 + FMA (256-bit, 8 floats at a time).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    Avx2,
    /// ARM aarch64 NEON (128-bit, 4 floats at a time).
    #[cfg(target_arch = "aarch64")]
    Neon,
}

// All operations dispatch on `self` to an architecture-specific module
// (fallback / sse2 / avx2 / neon). The scalar fallback is the reference
// implementation every SIMD backend is expected to match (see the
// `*_matches_scalar` tests).
//
// Slice-length preconditions are checked with `debug_assert_eq!` only, so
// release builds rely on callers passing equal-length slices.
impl SimdBackend {
    /// Returns the name of this backend.
    pub fn name(self) -> &'static str {
        match self {
            Self::Scalar => "scalar",
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Sse2 => "sse2",
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Self::Avx2 => "avx2+fma",
            #[cfg(target_arch = "aarch64")]
            Self::Neon => "neon",
        }
    }

    /// Compute the dot product of two float slices.
    ///
    /// `a` and `b` must have the same length. Returns `sum of a[i]*b[i]`.
    pub fn dot_product(self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        match self {
            Self::Scalar => fallback::dot_product(a, b),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::dot_product(a, b) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::dot_product(a, b) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::dot_product(a, b) },
        }
    }

    /// Compute two dot products in parallel (for sinc resampler convolution).
    ///
    /// Returns (dot(input, k1), dot(input, k2)). All slices must have the
    /// same length.
    pub fn dual_dot_product(self, input: &[f32], k1: &[f32], k2: &[f32]) -> (f32, f32) {
        debug_assert_eq!(input.len(), k1.len());
        debug_assert_eq!(input.len(), k2.len());
        match self {
            Self::Scalar => fallback::dual_dot_product(input, k1, k2),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::dual_dot_product(input, k1, k2) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::dual_dot_product(input, k1, k2) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::dual_dot_product(input, k1, k2) },
        }
    }

    /// Sinc resampler convolution: dual dot product with interpolation.
    ///
    /// Computes `(1-f)*dot(input,k1) + f*dot(input,k2)` where `f` is the
    /// `kernel_interpolation_factor`. Unlike [`dual_dot_product()`](Self::dual_dot_product) followed by
    /// scalar interpolation, this performs interpolation on SIMD vectors
    /// *before* horizontal reduction, matching C++ `SincResampler::Convolve_*`
    /// rounding behavior exactly.
    ///
    /// The scalar fallback interpolates in `f64` matching C++ `Convolve_C`.
    pub fn convolve_sinc(
        self,
        input: &[f32],
        k1: &[f32],
        k2: &[f32],
        kernel_interpolation_factor: f64,
    ) -> f32 {
        debug_assert_eq!(input.len(), k1.len());
        debug_assert_eq!(input.len(), k2.len());
        match self {
            Self::Scalar => fallback::convolve_sinc(input, k1, k2, kernel_interpolation_factor),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe {
                sse2::convolve_sinc(input, k1, k2, kernel_interpolation_factor)
            },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe {
                avx2::convolve_sinc(input, k1, k2, kernel_interpolation_factor)
            },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe {
                neon::convolve_sinc(input, k1, k2, kernel_interpolation_factor)
            },
        }
    }

    /// Element-wise multiply-accumulate: `acc[i] += a[i] * b[i]`
    ///
    /// `acc`, `a`, and `b` must have the same length.
    pub fn multiply_accumulate(self, acc: &mut [f32], a: &[f32], b: &[f32]) {
        debug_assert_eq!(acc.len(), a.len());
        debug_assert_eq!(acc.len(), b.len());
        match self {
            Self::Scalar => fallback::multiply_accumulate(acc, a, b),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::multiply_accumulate(acc, a, b) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::multiply_accumulate(acc, a, b) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::multiply_accumulate(acc, a, b) },
        }
    }

    /// Compute the sum of all elements in a slice.
    pub fn sum(self, x: &[f32]) -> f32 {
        match self {
            Self::Scalar => fallback::sum(x),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::sum(x) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::sum(x) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::sum(x) },
        }
    }

    /// Elementwise square root: `x[i] = sqrt(x[i])`
    pub fn elementwise_sqrt(self, x: &mut [f32]) {
        match self {
            Self::Scalar => fallback::elementwise_sqrt(x),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::elementwise_sqrt(x) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::elementwise_sqrt(x) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::elementwise_sqrt(x) },
        }
    }

    /// Elementwise vector multiplication: `z[i] = x[i] * y[i]`
    ///
    /// `x`, `y`, and `z` must have the same length.
    pub fn elementwise_multiply(self, x: &[f32], y: &[f32], z: &mut [f32]) {
        debug_assert_eq!(x.len(), y.len());
        debug_assert_eq!(x.len(), z.len());
        match self {
            Self::Scalar => fallback::elementwise_multiply(x, y, z),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::elementwise_multiply(x, y, z) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::elementwise_multiply(x, y, z) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::elementwise_multiply(x, y, z) },
        }
    }

    /// Elementwise accumulate: `z[i] += x[i]`
    ///
    /// `x` and `z` must have the same length.
    pub fn elementwise_accumulate(self, x: &[f32], z: &mut [f32]) {
        debug_assert_eq!(x.len(), z.len());
        match self {
            Self::Scalar => fallback::elementwise_accumulate(x, z),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::elementwise_accumulate(x, z) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::elementwise_accumulate(x, z) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::elementwise_accumulate(x, z) },
        }
    }

    /// Compute the power spectrum: `out[i] = re[i]^2 + im[i]^2`
    ///
    /// `re`, `im`, and `out` must have the same length.
    pub fn power_spectrum(self, re: &[f32], im: &[f32], out: &mut [f32]) {
        debug_assert_eq!(re.len(), im.len());
        debug_assert_eq!(re.len(), out.len());
        match self {
            Self::Scalar => fallback::power_spectrum(re, im, out),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::power_spectrum(re, im, out) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::power_spectrum(re, im, out) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::power_spectrum(re, im, out) },
        }
    }

    /// Elementwise minimum: `out[i] = min(a[i], b[i])`
    ///
    /// `a`, `b`, and `out` must have the same length.
    pub fn elementwise_min(self, a: &[f32], b: &[f32], out: &mut [f32]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), out.len());
        match self {
            Self::Scalar => fallback::elementwise_min(a, b, out),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::elementwise_min(a, b, out) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::elementwise_min(a, b, out) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::elementwise_min(a, b, out) },
        }
    }

    /// Elementwise maximum: `out[i] = max(a[i], b[i])`
    ///
    /// `a`, `b`, and `out` must have the same length.
    pub fn elementwise_max(self, a: &[f32], b: &[f32], out: &mut [f32]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), out.len());
        match self {
            Self::Scalar => fallback::elementwise_max(a, b, out),
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe { sse2::elementwise_max(a, b, out) },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe { avx2::elementwise_max(a, b, out) },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe { neon::elementwise_max(a, b, out) },
        }
    }

    /// Complex multiply-accumulate (AEC3 conjugate convention):
    ///   `acc_re[i] += x_re[i]*h_re[i] + x_im[i]*h_im[i]`
    ///   `acc_im[i] += x_re[i]*h_im[i] - x_im[i]*h_re[i]`
    ///
    /// All slices must have the same length.
    pub fn complex_multiply_accumulate(
        self,
        x_re: &[f32],
        x_im: &[f32],
        h_re: &[f32],
        h_im: &[f32],
        acc_re: &mut [f32],
        acc_im: &mut [f32],
    ) {
        debug_assert_eq!(x_re.len(), x_im.len());
        debug_assert_eq!(x_re.len(), h_re.len());
        debug_assert_eq!(x_re.len(), h_im.len());
        debug_assert_eq!(x_re.len(), acc_re.len());
        debug_assert_eq!(x_re.len(), acc_im.len());
        match self {
            Self::Scalar => {
                fallback::complex_multiply_accumulate(x_re, x_im, h_re, h_im, acc_re, acc_im);
            }
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe {
                sse2::complex_multiply_accumulate(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe {
                avx2::complex_multiply_accumulate(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe {
                neon::complex_multiply_accumulate(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
        }
    }

    /// Standard complex multiply-accumulate:
    ///   `acc_re[i] += x_re[i]*h_re[i] - x_im[i]*h_im[i]`
    ///   `acc_im[i] += x_re[i]*h_im[i] + x_im[i]*h_re[i]`
    ///
    /// All slices must have the same length.
    pub fn complex_multiply_accumulate_standard(
        self,
        x_re: &[f32],
        x_im: &[f32],
        h_re: &[f32],
        h_im: &[f32],
        acc_re: &mut [f32],
        acc_im: &mut [f32],
    ) {
        debug_assert_eq!(x_re.len(), x_im.len());
        debug_assert_eq!(x_re.len(), h_re.len());
        debug_assert_eq!(x_re.len(), h_im.len());
        debug_assert_eq!(x_re.len(), acc_re.len());
        debug_assert_eq!(x_re.len(), acc_im.len());
        match self {
            Self::Scalar => {
                fallback::complex_multiply_accumulate_standard(
                    x_re, x_im, h_re, h_im, acc_re, acc_im,
                );
            }
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Sse2 after confirming sse2 support.
            Self::Sse2 => unsafe {
                sse2::complex_multiply_accumulate_standard(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            // SAFETY: detect_backend() only returns Avx2 after confirming avx2+fma support.
            Self::Avx2 => unsafe {
                avx2::complex_multiply_accumulate_standard(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
            #[cfg(target_arch = "aarch64")]
            // SAFETY: NEON is always available on aarch64.
            Self::Neon => unsafe {
                neon::complex_multiply_accumulate_standard(x_re, x_im, h_re, h_im, acc_re, acc_im);
            },
        }
    }
}

// Runtime CPU feature detection via cpufeatures (atomic-cached).
// `cpufeatures::new!` expands to a module exposing `get()`; per the crate's
// docs, the CPUID query runs once and the result is cached in an atomic,
// so repeated calls are cheap.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
cpufeatures::new!(has_avx2_fma, "avx2", "fma");
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
cpufeatures::new!(has_sse2, "sse2");

339/// Return all SIMD backends supported by the current CPU.
340///
341/// Always includes [`SimdBackend::Scalar`]. On x86/x86_64, includes SSE2
342/// and/or AVX2+FMA depending on what the CPU reports. On aarch64, includes
343/// NEON. Useful for testing every backend, not just the fastest one.
344pub fn available_backends() -> Vec<SimdBackend> {
345    let mut backends = vec![SimdBackend::Scalar];
346
347    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
348    {
349        if has_sse2::get() {
350            backends.push(SimdBackend::Sse2);
351        }
352        if has_avx2_fma::get() {
353            backends.push(SimdBackend::Avx2);
354        }
355    }
356
357    #[cfg(target_arch = "aarch64")]
358    backends.push(SimdBackend::Neon);
359
360    backends
361}
362
363/// Detect the best available SIMD backend for the current CPU.
364///
365/// Uses runtime feature detection on x86/x86_64 (cached atomically after
366/// first call via `cpufeatures`). On aarch64, NEON is always available.
367/// Falls back to scalar on unknown architectures.
368pub fn detect_backend() -> SimdBackend {
369    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
370    {
371        if has_avx2_fma::get() {
372            return SimdBackend::Avx2;
373        }
374        if has_sse2::get() {
375            return SimdBackend::Sse2;
376        }
377    }
378
379    #[cfg(target_arch = "aarch64")]
380    {
381        return SimdBackend::Neon;
382    }
383
384    #[allow(unreachable_code, reason = "fallback for architectures without SIMD")]
385    SimdBackend::Scalar
386}
387
#[cfg(test)]
mod tests {
    use super::*;

392    #[test]
393    fn test_detect_backend() {
394        let backend = detect_backend();
395        println!("Detected SIMD backend: {}", backend.name());
396        assert!(!backend.name().is_empty());
397    }
398
399    #[test]
400    fn test_backend_is_copy() {
401        let a = detect_backend();
402        let b = a;
403        assert_eq!(a, b);
404    }
405
406    #[test]
407    fn test_dot_product_simple() {
408        let ops = detect_backend();
409        let a = [1.0f32, 2.0, 3.0, 4.0];
410        let b = [5.0f32, 6.0, 7.0, 8.0];
411        let result = ops.dot_product(&a, &b);
412        // 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
413        assert!((result - 70.0).abs() < 1e-6);
414    }
415
416    #[test]
417    fn test_dual_dot_product_simple() {
418        let ops = detect_backend();
419        let input = [1.0f32, 2.0, 3.0, 4.0];
420        let k1 = [1.0f32, 0.0, 1.0, 0.0];
421        let k2 = [0.0f32, 1.0, 0.0, 1.0];
422        let (d1, d2) = ops.dual_dot_product(&input, &k1, &k2);
423        assert!((d1 - 4.0).abs() < 1e-6);
424        assert!((d2 - 6.0).abs() < 1e-6);
425    }
426
427    #[test]
428    fn test_multiply_accumulate_simple() {
429        let ops = detect_backend();
430        let mut acc = [10.0f32, 20.0, 30.0, 40.0];
431        let a = [1.0f32, 2.0, 3.0, 4.0];
432        let b = [5.0f32, 6.0, 7.0, 8.0];
433        ops.multiply_accumulate(&mut acc, &a, &b);
434        assert!((acc[0] - 15.0).abs() < 1e-6);
435        assert!((acc[1] - 32.0).abs() < 1e-6);
436        assert!((acc[2] - 51.0).abs() < 1e-6);
437        assert!((acc[3] - 72.0).abs() < 1e-6);
438    }
439
440    #[test]
441    fn test_sum_simple() {
442        let ops = detect_backend();
443        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0];
444        assert!((ops.sum(&x) - 15.0).abs() < 1e-6);
445    }
446
447    #[test]
448    fn test_empty_slices() {
449        let ops = detect_backend();
450        assert_eq!(ops.dot_product(&[], &[]), 0.0);
451        assert_eq!(ops.sum(&[]), 0.0);
452        let (d1, d2) = ops.dual_dot_product(&[], &[], &[]);
453        assert_eq!(d1, 0.0);
454        assert_eq!(d2, 0.0);
455    }
456
457    /// Compare all SIMD backends against scalar fallback with larger inputs.
458    #[test]
459    fn test_dot_product_matches_scalar() {
460        let scalar = SimdBackend::Scalar;
461
462        for &backend in &available_backends() {
463            if backend == SimdBackend::Scalar {
464                continue;
465            }
466            for size in [0, 1, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256] {
467                let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
468                let b: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.005).collect();
469
470                let scalar_result = scalar.dot_product(&a, &b);
471                let simd_result = backend.dot_product(&a, &b);
472
473                assert!(
474                    (scalar_result - simd_result).abs() < 1e-3,
475                    "[{}] Mismatch for size {size}: scalar={scalar_result}, simd={simd_result}",
476                    backend.name()
477                );
478            }
479        }
480    }
481
482    #[test]
483    fn test_dual_dot_product_matches_scalar() {
484        let scalar = SimdBackend::Scalar;
485
486        for &backend in &available_backends() {
487            if backend == SimdBackend::Scalar {
488                continue;
489            }
490            for size in [0, 1, 4, 7, 16, 31, 64, 128, 256] {
491                let input: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
492                let k1: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.003).collect();
493                let k2: Vec<f32> = (0..size).map(|i| 0.5 + (i as f32) * 0.002).collect();
494
495                let (s1, s2) = scalar.dual_dot_product(&input, &k1, &k2);
496                let (d1, d2) = backend.dual_dot_product(&input, &k1, &k2);
497
498                assert!(
499                    (s1 - d1).abs() < 1e-3,
500                    "[{}] k1 mismatch for size {size}: scalar={s1}, simd={d1}",
501                    backend.name()
502                );
503                assert!(
504                    (s2 - d2).abs() < 1e-3,
505                    "[{}] k2 mismatch for size {size}: scalar={s2}, simd={d2}",
506                    backend.name()
507                );
508            }
509        }
510    }
511
512    #[test]
513    fn test_multiply_accumulate_matches_scalar() {
514        let scalar = SimdBackend::Scalar;
515
516        for &backend in &available_backends() {
517            if backend == SimdBackend::Scalar {
518                continue;
519            }
520            for size in [0, 1, 4, 7, 16, 31, 64, 128] {
521                let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
522                let b: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.005).collect();
523
524                let mut acc_scalar: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
525                let mut acc_simd = acc_scalar.clone();
526
527                scalar.multiply_accumulate(&mut acc_scalar, &a, &b);
528                backend.multiply_accumulate(&mut acc_simd, &a, &b);
529
530                for i in 0..size {
531                    assert!(
532                        (acc_scalar[i] - acc_simd[i]).abs() < 1e-4,
533                        "[{}] Mismatch at index {i} for size {size}: scalar={}, simd={}",
534                        backend.name(),
535                        acc_scalar[i],
536                        acc_simd[i]
537                    );
538                }
539            }
540        }
541    }
542
543    #[test]
544    fn test_elementwise_sqrt_simple() {
545        let ops = detect_backend();
546        let mut x = [4.0f32, 9.0, 16.0, 25.0, 36.0];
547        ops.elementwise_sqrt(&mut x);
548        assert!((x[0] - 2.0).abs() < 1e-6);
549        assert!((x[1] - 3.0).abs() < 1e-6);
550        assert!((x[2] - 4.0).abs() < 1e-6);
551        assert!((x[3] - 5.0).abs() < 1e-6);
552        assert!((x[4] - 6.0).abs() < 1e-6);
553    }
554
555    #[test]
556    fn test_elementwise_sqrt_matches_scalar() {
557        let scalar = SimdBackend::Scalar;
558
559        for &backend in &available_backends() {
560            if backend == SimdBackend::Scalar {
561                continue;
562            }
563            for size in [0, 1, 4, 7, 8, 15, 16, 31, 64, 65, 128] {
564                let mut x_scalar: Vec<f32> = (0..size).map(|i| (i as f32) * 0.5 + 0.1).collect();
565                let mut x_simd = x_scalar.clone();
566
567                scalar.elementwise_sqrt(&mut x_scalar);
568                backend.elementwise_sqrt(&mut x_simd);
569
570                for i in 0..size {
571                    assert!(
572                        (x_scalar[i] - x_simd[i]).abs() < 1e-6,
573                        "[{}] sqrt mismatch at index {i} for size {size}: scalar={}, simd={}",
574                        backend.name(),
575                        x_scalar[i],
576                        x_simd[i]
577                    );
578                }
579            }
580        }
581    }
582
583    #[test]
584    fn test_elementwise_multiply_simple() {
585        let ops = detect_backend();
586        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0];
587        let y = [5.0f32, 4.0, 3.0, 2.0, 1.0];
588        let mut z = [0.0f32; 5];
589        ops.elementwise_multiply(&x, &y, &mut z);
590        assert!((z[0] - 5.0).abs() < 1e-6);
591        assert!((z[1] - 8.0).abs() < 1e-6);
592        assert!((z[2] - 9.0).abs() < 1e-6);
593        assert!((z[3] - 8.0).abs() < 1e-6);
594        assert!((z[4] - 5.0).abs() < 1e-6);
595    }
596
597    #[test]
598    fn test_elementwise_multiply_matches_scalar() {
599        let scalar = SimdBackend::Scalar;
600
601        for &backend in &available_backends() {
602            if backend == SimdBackend::Scalar {
603                continue;
604            }
605            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
606                let x: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
607                let y: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.005).collect();
608                let mut z_scalar = vec![0.0f32; size];
609                let mut z_simd = vec![0.0f32; size];
610
611                scalar.elementwise_multiply(&x, &y, &mut z_scalar);
612                backend.elementwise_multiply(&x, &y, &mut z_simd);
613
614                for i in 0..size {
615                    assert!(
616                        (z_scalar[i] - z_simd[i]).abs() < 1e-6,
617                        "[{}] multiply mismatch at index {i} for size {size}: scalar={}, simd={}",
618                        backend.name(),
619                        z_scalar[i],
620                        z_simd[i]
621                    );
622                }
623            }
624        }
625    }
626
627    #[test]
628    fn test_elementwise_accumulate_simple() {
629        let ops = detect_backend();
630        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0];
631        let mut z = [10.0f32, 20.0, 30.0, 40.0, 50.0];
632        ops.elementwise_accumulate(&x, &mut z);
633        assert!((z[0] - 11.0).abs() < 1e-6);
634        assert!((z[1] - 22.0).abs() < 1e-6);
635        assert!((z[2] - 33.0).abs() < 1e-6);
636        assert!((z[3] - 44.0).abs() < 1e-6);
637        assert!((z[4] - 55.0).abs() < 1e-6);
638    }
639
640    #[test]
641    fn test_elementwise_accumulate_matches_scalar() {
642        let scalar = SimdBackend::Scalar;
643
644        for &backend in &available_backends() {
645            if backend == SimdBackend::Scalar {
646                continue;
647            }
648            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
649                let x: Vec<f32> = (0..size).map(|i| (i as f32) * 0.01).collect();
650                let mut z_scalar: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
651                let mut z_simd = z_scalar.clone();
652
653                scalar.elementwise_accumulate(&x, &mut z_scalar);
654                backend.elementwise_accumulate(&x, &mut z_simd);
655
656                for i in 0..size {
657                    assert!(
658                        (z_scalar[i] - z_simd[i]).abs() < 1e-6,
659                        "[{}] accumulate mismatch at index {i} for size {size}: scalar={}, simd={}",
660                        backend.name(),
661                        z_scalar[i],
662                        z_simd[i]
663                    );
664                }
665            }
666        }
667    }
668
669    #[test]
670    fn test_power_spectrum_simple() {
671        let ops = detect_backend();
672        let re = [3.0f32, 0.0, 1.0, 2.0, 5.0];
673        let im = [4.0f32, 1.0, 0.0, 3.0, 12.0];
674        let mut out = [0.0f32; 5];
675        ops.power_spectrum(&re, &im, &mut out);
676        assert!((out[0] - 25.0).abs() < 1e-6); // 9 + 16
677        assert!((out[1] - 1.0).abs() < 1e-6); // 0 + 1
678        assert!((out[2] - 1.0).abs() < 1e-6); // 1 + 0
679        assert!((out[3] - 13.0).abs() < 1e-6); // 4 + 9
680        assert!((out[4] - 169.0).abs() < 1e-6); // 25 + 144
681    }
682
683    #[test]
684    fn test_power_spectrum_matches_scalar() {
685        let scalar = SimdBackend::Scalar;
686
687        for &backend in &available_backends() {
688            if backend == SimdBackend::Scalar {
689                continue;
690            }
691            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
692                let re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
693                let im: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
694                let mut out_scalar = vec![0.0f32; size];
695                let mut out_simd = vec![0.0f32; size];
696
697                scalar.power_spectrum(&re, &im, &mut out_scalar);
698                backend.power_spectrum(&re, &im, &mut out_simd);
699
700                for i in 0..size {
701                    assert!(
702                        (out_scalar[i] - out_simd[i]).abs() < 1e-4,
703                        "[{}] power_spectrum mismatch at index {i} for size {size}: scalar={}, simd={}",
704                        backend.name(),
705                        out_scalar[i],
706                        out_simd[i]
707                    );
708                }
709            }
710        }
711    }
712
713    #[test]
714    fn test_elementwise_min_simple() {
715        let ops = detect_backend();
716        let a = [1.0f32, 5.0, 3.0, 8.0, 2.0];
717        let b = [4.0f32, 2.0, 7.0, 1.0, 9.0];
718        let mut out = [0.0f32; 5];
719        ops.elementwise_min(&a, &b, &mut out);
720        assert_eq!(out, [1.0, 2.0, 3.0, 1.0, 2.0]);
721    }
722
723    #[test]
724    fn test_elementwise_min_matches_scalar() {
725        let scalar = SimdBackend::Scalar;
726
727        for &backend in &available_backends() {
728            if backend == SimdBackend::Scalar {
729                continue;
730            }
731            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
732                let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
733                let b: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
734                let mut out_scalar = vec![0.0f32; size];
735                let mut out_simd = vec![0.0f32; size];
736
737                scalar.elementwise_min(&a, &b, &mut out_scalar);
738                backend.elementwise_min(&a, &b, &mut out_simd);
739
740                for i in 0..size {
741                    assert!(
742                        (out_scalar[i] - out_simd[i]).abs() < 1e-6,
743                        "[{}] min mismatch at index {i} for size {size}: scalar={}, simd={}",
744                        backend.name(),
745                        out_scalar[i],
746                        out_simd[i]
747                    );
748                }
749            }
750        }
751    }
752
753    #[test]
754    fn test_elementwise_max_simple() {
755        let ops = detect_backend();
756        let a = [1.0f32, 5.0, 3.0, 8.0, 2.0];
757        let b = [4.0f32, 2.0, 7.0, 1.0, 9.0];
758        let mut out = [0.0f32; 5];
759        ops.elementwise_max(&a, &b, &mut out);
760        assert_eq!(out, [4.0, 5.0, 7.0, 8.0, 9.0]);
761    }
762
763    #[test]
764    fn test_elementwise_max_matches_scalar() {
765        let scalar = SimdBackend::Scalar;
766
767        for &backend in &available_backends() {
768            if backend == SimdBackend::Scalar {
769                continue;
770            }
771            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
772                let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
773                let b: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
774                let mut out_scalar = vec![0.0f32; size];
775                let mut out_simd = vec![0.0f32; size];
776
777                scalar.elementwise_max(&a, &b, &mut out_scalar);
778                backend.elementwise_max(&a, &b, &mut out_simd);
779
780                for i in 0..size {
781                    assert!(
782                        (out_scalar[i] - out_simd[i]).abs() < 1e-6,
783                        "[{}] max mismatch at index {i} for size {size}: scalar={}, simd={}",
784                        backend.name(),
785                        out_scalar[i],
786                        out_simd[i]
787                    );
788                }
789            }
790        }
791    }
792
793    #[test]
794    fn test_complex_multiply_accumulate_simple() {
795        let ops = detect_backend();
796        // (1+2j) * (3+4j) in AEC3 conjugate convention:
797        //   re = 1*3 + 2*4 = 11
798        //   im = 1*4 - 2*3 = -2
799        let x_re = [1.0f32];
800        let x_im = [2.0f32];
801        let h_re = [3.0f32];
802        let h_im = [4.0f32];
803        let mut acc_re = [0.0f32];
804        let mut acc_im = [0.0f32];
805        ops.complex_multiply_accumulate(&x_re, &x_im, &h_re, &h_im, &mut acc_re, &mut acc_im);
806        assert!((acc_re[0] - 11.0).abs() < 1e-6);
807        assert!((acc_im[0] - (-2.0)).abs() < 1e-6);
808    }
809
810    #[test]
811    fn test_complex_multiply_accumulate_matches_scalar() {
812        let scalar = SimdBackend::Scalar;
813
814        for &backend in &available_backends() {
815            if backend == SimdBackend::Scalar {
816                continue;
817            }
818            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
819                let x_re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
820                let x_im: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
821                let h_re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.05 + 0.5).collect();
822                let h_im: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.03).collect();
823
824                let mut acc_re_scalar = vec![0.5f32; size];
825                let mut acc_im_scalar = vec![-0.3f32; size];
826                let mut acc_re_simd = acc_re_scalar.clone();
827                let mut acc_im_simd = acc_im_scalar.clone();
828
829                scalar.complex_multiply_accumulate(
830                    &x_re,
831                    &x_im,
832                    &h_re,
833                    &h_im,
834                    &mut acc_re_scalar,
835                    &mut acc_im_scalar,
836                );
837                backend.complex_multiply_accumulate(
838                    &x_re,
839                    &x_im,
840                    &h_re,
841                    &h_im,
842                    &mut acc_re_simd,
843                    &mut acc_im_simd,
844                );
845
846                for i in 0..size {
847                    assert!(
848                        (acc_re_scalar[i] - acc_re_simd[i]).abs() < 1e-4,
849                        "[{}] cma re mismatch at {i} for size {size}: scalar={}, simd={}",
850                        backend.name(),
851                        acc_re_scalar[i],
852                        acc_re_simd[i]
853                    );
854                    assert!(
855                        (acc_im_scalar[i] - acc_im_simd[i]).abs() < 1e-4,
856                        "[{}] cma im mismatch at {i} for size {size}: scalar={}, simd={}",
857                        backend.name(),
858                        acc_im_scalar[i],
859                        acc_im_simd[i]
860                    );
861                }
862            }
863        }
864    }
865
866    #[test]
867    fn test_complex_multiply_accumulate_standard_simple() {
868        let ops = detect_backend();
869        // (1+2j) * (3+4j) in standard convention:
870        //   re = 1*3 - 2*4 = -5
871        //   im = 1*4 + 2*3 = 10
872        let x_re = [1.0f32];
873        let x_im = [2.0f32];
874        let h_re = [3.0f32];
875        let h_im = [4.0f32];
876        let mut acc_re = [0.0f32];
877        let mut acc_im = [0.0f32];
878        ops.complex_multiply_accumulate_standard(
879            &x_re,
880            &x_im,
881            &h_re,
882            &h_im,
883            &mut acc_re,
884            &mut acc_im,
885        );
886        assert!((acc_re[0] - (-5.0)).abs() < 1e-6);
887        assert!((acc_im[0] - 10.0).abs() < 1e-6);
888    }
889
890    #[test]
891    fn test_complex_multiply_accumulate_standard_matches_scalar() {
892        let scalar = SimdBackend::Scalar;
893
894        for &backend in &available_backends() {
895            if backend == SimdBackend::Scalar {
896                continue;
897            }
898            for size in [0, 1, 4, 7, 8, 16, 31, 64, 65, 128] {
899                let x_re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1 - 3.0).collect();
900                let x_im: Vec<f32> = (0..size).map(|i| 2.0 - (i as f32) * 0.07).collect();
901                let h_re: Vec<f32> = (0..size).map(|i| (i as f32) * 0.05 + 0.5).collect();
902                let h_im: Vec<f32> = (0..size).map(|i| 1.0 - (i as f32) * 0.03).collect();
903
904                let mut acc_re_scalar = vec![0.5f32; size];
905                let mut acc_im_scalar = vec![-0.3f32; size];
906                let mut acc_re_simd = acc_re_scalar.clone();
907                let mut acc_im_simd = acc_im_scalar.clone();
908
909                scalar.complex_multiply_accumulate_standard(
910                    &x_re,
911                    &x_im,
912                    &h_re,
913                    &h_im,
914                    &mut acc_re_scalar,
915                    &mut acc_im_scalar,
916                );
917                backend.complex_multiply_accumulate_standard(
918                    &x_re,
919                    &x_im,
920                    &h_re,
921                    &h_im,
922                    &mut acc_re_simd,
923                    &mut acc_im_simd,
924                );
925
926                for i in 0..size {
927                    assert!(
928                        (acc_re_scalar[i] - acc_re_simd[i]).abs() < 1e-4,
929                        "[{}] std cma re mismatch at {i} for size {size}: scalar={}, simd={}",
930                        backend.name(),
931                        acc_re_scalar[i],
932                        acc_re_simd[i]
933                    );
934                    assert!(
935                        (acc_im_scalar[i] - acc_im_simd[i]).abs() < 1e-4,
936                        "[{}] std cma im mismatch at {i} for size {size}: scalar={}, simd={}",
937                        backend.name(),
938                        acc_im_scalar[i],
939                        acc_im_simd[i]
940                    );
941                }
942            }
943        }
944    }
945}