// zenjxl_decoder_simd — lib.rs
1// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6#![allow(clippy::too_many_arguments)]
7#![forbid(unsafe_code)]
8
9use std::{
10    fmt::Debug,
11    ops::{
12        Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
13        DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
14    },
15};
16
17#[cfg(target_arch = "x86_64")]
18mod x86_64;
19
20#[cfg(target_arch = "aarch64")]
21mod aarch64;
22
23#[cfg(target_arch = "wasm32")]
24mod wasm32;
25
26pub mod float16;
27pub mod scalar;
28
29pub use float16::f16;
30
31#[cfg(all(target_arch = "x86_64", feature = "avx"))]
32pub use x86_64::avx::AvxDescriptor;
33#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
34pub use x86_64::avx512::Avx512Descriptor;
35#[cfg(all(target_arch = "x86_64", feature = "sse42"))]
36pub use x86_64::sse42::Sse42Descriptor;
37
38#[cfg(all(target_arch = "aarch64", feature = "neon"))]
39pub use aarch64::neon::NeonDescriptor;
40
41#[cfg(all(target_arch = "wasm32", feature = "wasm128"))]
42pub use wasm32::simd128::Wasm128Descriptor;
43
44pub use scalar::ScalarDescriptor;
45
46// Re-exports for simd_function! macro internals.
47// Concrete token types must keep their original names (no aliases) because
48// the #[arcane] proc macro matches on the last path segment to determine
49// which #[target_feature] attributes to emit.
50#[cfg(target_arch = "aarch64")]
51#[doc(hidden)]
52pub use archmage::NeonToken;
53#[doc(hidden)]
54pub use archmage::ScalarToken;
55#[doc(hidden)]
56pub use archmage::SimdToken as __SimdToken;
57#[cfg(target_arch = "wasm32")]
58#[doc(hidden)]
59pub use archmage::Wasm128Token;
60#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
61#[doc(hidden)]
62pub use archmage::X64V4Token;
63#[doc(hidden)]
64pub use archmage::arcane as __arcane;
65#[cfg(target_arch = "x86_64")]
66#[doc(hidden)]
67pub use archmage::{X64V2Token, X64V3Token};
68
/// Entry point for a concrete SIMD backend (AVX, NEON, WASM SIMD128, scalar
/// fallback, ...).
///
/// A descriptor value acts as proof that the corresponding instruction set
/// may be used; every vector constructor takes one, so feature detection is
/// performed once up front rather than per operation.
pub trait SimdDescriptor: Sized + Copy + Debug + Send + Sync {
    /// f32 lane vector for this backend.
    type F32Vec: F32SimdVec<Descriptor = Self>;

    /// i32 lane vector for this backend.
    type I32Vec: I32SimdVec<Descriptor = Self>;

    /// u32 lane vector for this backend.
    type U32Vec: U32SimdVec<Descriptor = Self>;

    /// u16 lane vector for this backend.
    type U16Vec: U16SimdVec<Descriptor = Self>;

    /// u8 lane vector for this backend.
    type U8Vec: U8SimdVec<Descriptor = Self>;

    /// Lane mask produced by comparisons and consumed by select operations.
    type Mask: SimdMask<Descriptor = Self>;

    /// Prepared 8-entry BF16 lookup table for fast approximate lookups.
    /// Use `F32SimdVec::prepare_table_bf16_8` to create and
    /// `F32SimdVec::table_lookup_bf16_8` to use.
    type Bf16Table8: Copy;

    /// Descriptor used for 256-bit operations (see `maybe_downgrade_256bit`).
    type Descriptor256: SimdDescriptor<Descriptor256 = Self::Descriptor256>;
    /// Descriptor used for 128-bit operations (see `maybe_downgrade_128bit`).
    type Descriptor128: SimdDescriptor<Descriptor128 = Self::Descriptor128>;

    /// Constructs the descriptor.
    /// NOTE(review): presumably returns `None` when the required CPU features
    /// are not available at runtime — confirm per backend implementation.
    fn new() -> Option<Self>;

    /// Returns a vector descriptor suitable for operations on vectors of length 256 (Self if the
    /// current vector type is suitable). Note that it might still be beneficial to use `Self` for
    /// .call(), as the compiler could make use of features from more advanced instruction sets.
    fn maybe_downgrade_256bit(self) -> Self::Descriptor256;

    /// Same as Self::maybe_downgrade_256bit, but for 128 bits.
    fn maybe_downgrade_128bit(self) -> Self::Descriptor128;

    /// Calls the given closure within a target feature context.
    /// This enables establishing an unbroken chain of inline functions from the feature-annotated
    /// gateway up to the closure, allowing SIMD intrinsics to be used safely.
    fn call<R>(self, f: impl FnOnce(Self) -> R) -> R;
}
105
/// Vector of f32 lanes with arithmetic, load/store, interleaving and
/// conversion operations. All slice-taking methods panic when the slice is
/// shorter than documented; implementations are selected via `Descriptor`.
pub trait F32SimdVec:
    Sized
    + Copy
    + Debug
    + Send
    + Sync
    + Add<Self, Output = Self>
    + Mul<Self, Output = Self>
    + Sub<Self, Output = Self>
    + Div<Self, Output = Self>
    + AddAssign<Self>
    + MulAssign<Self>
    + SubAssign<Self>
    + DivAssign<Self>
{
    type Descriptor: SimdDescriptor;

    /// Number of f32 lanes in this vector type.
    const LEN: usize;

    /// An array of f32 of length Self::LEN.
    type UnderlyingArray: Copy + Default + Debug;

    /// Broadcasts `v` into every lane.
    fn splat(d: Self::Descriptor, v: f32) -> Self;

    /// Returns a vector with every lane set to zero.
    fn zero(d: Self::Descriptor) -> Self;

    /// Computes `self * mul + add` (see `neg_mul_add` for the negated form).
    fn mul_add(self, mul: Self, add: Self) -> Self;

    /// Computes `add - self * mul`, equivalent to `self * (-mul) + add`.
    /// Uses fused multiply-add with negation when available (FMA3 fnmadd).
    fn neg_mul_add(self, mul: Self, add: Self) -> Self;

    /// Loads Self::LEN values from the start of `mem`.
    /// Requires `mem.len() >= Self::LEN` or it will panic.
    fn load(d: Self::Descriptor, mem: &[f32]) -> Self;

    /// Loads Self::LEN f32 values starting at `mem[offset..]`.
    /// Equivalent to `Self::load(d, &mem[offset..])`.
    #[inline(always)]
    fn load_from(d: Self::Descriptor, mem: &[f32], offset: usize) -> Self {
        Self::load(d, &mem[offset..])
    }

    /// Loads from a fixed-size array (length is guaranteed by the type).
    fn load_array(d: Self::Descriptor, mem: &Self::UnderlyingArray) -> Self;

    /// Stores Self::LEN values at the start of `mem`.
    /// Requires `mem.len() >= Self::LEN` or it will panic.
    fn store(&self, mem: &mut [f32]);

    /// Stores Self::LEN f32 values starting at `mem[offset..]`.
    /// Equivalent to `self.store(&mut mem[offset..])`.
    #[inline(always)]
    fn store_at(&self, mem: &mut [f32], offset: usize) {
        self.store(&mut mem[offset..]);
    }

    /// Stores into a fixed-size array (length is guaranteed by the type).
    fn store_array(&self, mem: &mut Self::UnderlyingArray);

    /// Stores two vectors interleaved: [a0, b0, a1, b1, a2, b2, ...].
    /// Requires `dest.len() >= 2 * Self::LEN` or it will panic.
    fn store_interleaved_2(a: Self, b: Self, dest: &mut [f32]);

    /// Stores three vectors interleaved: [a0, b0, c0, a1, b1, c1, ...].
    /// Requires `dest.len() >= 3 * Self::LEN` or it will panic.
    fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [f32]);

    /// Stores four vectors interleaved: [a0, b0, c0, d0, a1, b1, c1, d1, ...].
    /// Requires `dest.len() >= 4 * Self::LEN` or it will panic.
    fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [f32]);

    /// Stores eight vectors interleaved: [a0, b0, c0, d0, e0, f0, g0, h0, a1, ...].
    /// Requires `dest.len() >= 8 * Self::LEN` or it will panic.
    fn store_interleaved_8(
        a: Self,
        b: Self,
        c: Self,
        d: Self,
        e: Self,
        f: Self,
        g: Self,
        h: Self,
        dest: &mut [f32],
    );

    /// Loads two vectors from interleaved data: [a0, b0, a1, b1, a2, b2, ...].
    /// Returns (a, b) where a = [a0, a1, a2, ...] and b = [b0, b1, b2, ...].
    /// Requires `src.len() >= 2 * Self::LEN` or it will panic.
    fn load_deinterleaved_2(d: Self::Descriptor, src: &[f32]) -> (Self, Self);

    /// Loads three vectors from interleaved data: [a0, b0, c0, a1, b1, c1, ...].
    /// Returns (a, b, c) where a = [a0, a1, ...], b = [b0, b1, ...], c = [c0, c1, ...].
    /// Requires `src.len() >= 3 * Self::LEN` or it will panic.
    fn load_deinterleaved_3(d: Self::Descriptor, src: &[f32]) -> (Self, Self, Self);

    /// Loads four vectors from interleaved data: [a0, b0, c0, d0, a1, b1, c1, d1, ...].
    /// Returns (a, b, c, d) where each vector contains the deinterleaved components.
    /// Requires `src.len() >= 4 * Self::LEN` or it will panic.
    fn load_deinterleaved_4(d: Self::Descriptor, src: &[f32]) -> (Self, Self, Self, Self);

    /// Rounds to nearest integer and stores as u8.
    /// Behavior is unspecified if values would overflow u8.
    /// Requires `dest.len() >= Self::LEN` or it will panic.
    fn round_store_u8(self, dest: &mut [u8]);

    /// Rounds to nearest integer and stores as u8 at the given offset.
    /// Equivalent to `self.round_store_u8(&mut dest[offset..])`.
    #[inline(always)]
    fn round_store_u8_at(self, dest: &mut [u8], offset: usize) {
        self.round_store_u8(&mut dest[offset..]);
    }

    /// Rounds to nearest integer and stores as u16.
    /// Behavior is unspecified if values would overflow u16.
    /// Requires `dest.len() >= Self::LEN` or it will panic.
    fn round_store_u16(self, dest: &mut [u16]);

    /// Lane-wise absolute value.
    fn abs(self) -> Self;

    /// Lane-wise round toward negative infinity.
    fn floor(self) -> Self;

    /// Lane-wise square root.
    fn sqrt(self) -> Self;

    /// Negates all elements. Currently unused but kept for API completeness.
    #[allow(dead_code)]
    fn neg(self) -> Self;

    /// Per lane: magnitude of `self` combined with the sign of `sign`.
    fn copysign(self, sign: Self) -> Self;

    /// Lane-wise maximum (NaN handling is backend-specific — NOTE(review): confirm).
    fn max(self, other: Self) -> Self;

    /// Lane-wise minimum (NaN handling is backend-specific — NOTE(review): confirm).
    fn min(self, other: Self) -> Self;

    /// Lane-wise `self > other`, returned as a mask.
    fn gt(self, other: Self) -> <<Self as F32SimdVec>::Descriptor as SimdDescriptor>::Mask;

    /// Numeric conversion of each lane to i32
    /// (rounding mode is backend-specific — NOTE(review): confirm).
    fn as_i32(self) -> <<Self as F32SimdVec>::Descriptor as SimdDescriptor>::I32Vec;

    /// Reinterprets the bits of each lane as i32 (no numeric conversion).
    fn bitcast_to_i32(self) -> <<Self as F32SimdVec>::Descriptor as SimdDescriptor>::I32Vec;

    /// Prepares an 8-entry f32 table for fast approximate lookups.
    /// Values are converted to BF16 format (loses lower 16 mantissa bits).
    ///
    /// Use this when you need to perform multiple lookups with the same table.
    /// The prepared table can be reused with [`table_lookup_bf16_8`].
    fn prepare_table_bf16_8(
        d: Self::Descriptor,
        table: &[f32; 8],
    ) -> <<Self as F32SimdVec>::Descriptor as SimdDescriptor>::Bf16Table8;

    /// Performs fast approximate table lookup using a prepared BF16 table.
    ///
    /// This is the fastest lookup method when the same table is used multiple times.
    /// Use [`prepare_table_bf16_8`] to create the prepared table.
    ///
    /// # Panics
    /// May panic or produce undefined results if indices contain values outside 0..8 range.
    fn table_lookup_bf16_8(
        d: Self::Descriptor,
        table: <<Self as F32SimdVec>::Descriptor as SimdDescriptor>::Bf16Table8,
        indices: <<Self as F32SimdVec>::Descriptor as SimdDescriptor>::I32Vec,
    ) -> Self;

    /// Converts a slice of f32 into a slice of Self::UnderlyingArray. If slice.len() is not a
    /// multiple of `Self::LEN` this will panic.
    fn make_array_slice(slice: &[f32]) -> &[Self::UnderlyingArray];

    /// Converts a mut slice of f32 into a slice of Self::UnderlyingArray. If slice.len() is not a
    /// multiple of `Self::LEN` this will panic.
    fn make_array_slice_mut(slice: &mut [f32]) -> &mut [Self::UnderlyingArray];

    /// Transposes the Self::LEN x Self::LEN matrix formed by array elements
    /// `data[stride * i]` for i = 0..Self::LEN.
    fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize);

    /// Loads f16 values (stored as u16 bit patterns) and converts them to f32.
    /// Uses hardware conversion instructions when available (F16C on x86, NEON fp16 on ARM).
    /// Requires `mem.len() >= Self::LEN` or it will panic.
    fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self;

    /// Converts f32 values to f16 and stores as u16 bit patterns.
    /// Uses hardware conversion instructions when available (F16C on x86, NEON fp16 on ARM).
    /// Requires `dest.len() >= Self::LEN` or it will panic.
    fn store_f16_bits(self, dest: &mut [u16]);
}
288
/// Vector of i32 lanes with integer arithmetic, bit operations, comparisons
/// and narrowing stores. Slice-taking methods panic when the slice is shorter
/// than documented.
pub trait I32SimdVec:
    Sized
    + Copy
    + Debug
    + Send
    + Sync
    + Add<Self, Output = Self>
    + Mul<Self, Output = Self>
    + Sub<Self, Output = Self>
    + Neg<Output = Self>
    + BitAnd<Self, Output = Self>
    + BitOr<Self, Output = Self>
    + BitXor<Self, Output = Self>
    + AddAssign<Self>
    + MulAssign<Self>
    + SubAssign<Self>
    + BitAndAssign<Self>
    + BitOrAssign<Self>
    + BitXorAssign<Self>
{
    type Descriptor: SimdDescriptor;

    /// Number of i32 lanes in this vector type.
    #[allow(dead_code)]
    const LEN: usize;

    /// Broadcasts `v` into every lane.
    fn splat(d: Self::Descriptor, v: i32) -> Self;

    /// Loads Self::LEN values from the start of `mem`.
    /// Requires `mem.len() >= Self::LEN` or it will panic.
    fn load(d: Self::Descriptor, mem: &[i32]) -> Self;

    /// Loads Self::LEN i32 values starting at `mem[offset..]`.
    /// Equivalent to `Self::load(d, &mem[offset..])`.
    #[inline(always)]
    fn load_from(d: Self::Descriptor, mem: &[i32], offset: usize) -> Self {
        Self::load(d, &mem[offset..])
    }

    /// Stores Self::LEN values at the start of `mem`.
    /// Requires `mem.len() >= Self::LEN` or it will panic.
    fn store(&self, mem: &mut [i32]);

    /// Lane-wise absolute value
    /// (behavior for `i32::MIN` is backend-specific — NOTE(review): confirm).
    fn abs(self) -> Self;

    /// Numeric conversion of each lane to f32.
    fn as_f32(self) -> <<Self as I32SimdVec>::Descriptor as SimdDescriptor>::F32Vec;

    /// Reinterprets the bits of each lane as f32 (no numeric conversion).
    fn bitcast_to_f32(self) -> <<Self as I32SimdVec>::Descriptor as SimdDescriptor>::F32Vec;

    /// Reinterprets the bits of each lane as u32.
    fn bitcast_to_u32(self) -> <<Self as I32SimdVec>::Descriptor as SimdDescriptor>::U32Vec;

    /// Lane-wise `self > other`, returned as a mask.
    fn gt(self, other: Self) -> <<Self as I32SimdVec>::Descriptor as SimdDescriptor>::Mask;

    /// Lane-wise `self < 0`, returned as a mask.
    fn lt_zero(self) -> <<Self as I32SimdVec>::Descriptor as SimdDescriptor>::Mask;

    /// Lane-wise `self == other`, returned as a mask.
    fn eq(self, other: Self) -> <<Self as I32SimdVec>::Descriptor as SimdDescriptor>::Mask;

    /// Lane-wise `self == 0`, returned as a mask.
    fn eq_zero(self) -> <<Self as I32SimdVec>::Descriptor as SimdDescriptor>::Mask;

    /// Shifts each lane left by a compile-time amount. Both const parameters
    /// must encode the same value; use the `shl!` macro, which passes the
    /// same literal as both `u32` and `i32`.
    fn shl<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self;

    /// Shifts each lane right by a compile-time amount (see `shl` for the
    /// double const-parameter convention; use the `shr!` macro).
    /// NOTE(review): arithmetic vs. logical shift is not specified here —
    /// confirm against the backend implementations.
    fn shr<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self;

    /// Widening lane-wise multiply keeping the high half of each product.
    /// NOTE(review): confirm exact width/signedness semantics per backend.
    fn mul_wide_take_high(self, rhs: Self) -> Self;

    /// Stores the lower 16 bits of each i32 lane as u16 values.
    /// Requires `dest.len() >= Self::LEN` or it will panic.
    fn store_u16(self, dest: &mut [u16]);

    /// Stores the lower 8 bits of each i32 lane as u8 values.
    /// Requires `dest.len() >= Self::LEN` or it will panic.
    fn store_u8(self, dest: &mut [u8]);
}
360
/// Vector of u32 lanes. Minimal surface: only bit reinterpretation and
/// logical shifts are needed by this crate.
pub trait U32SimdVec: Sized + Copy + Debug + Send + Sync {
    type Descriptor: SimdDescriptor;

    /// Number of u32 lanes in this vector type.
    #[allow(dead_code)]
    const LEN: usize;

    /// Reinterprets the bits of each lane as i32 (no numeric conversion).
    fn bitcast_to_i32(self) -> <<Self as U32SimdVec>::Descriptor as SimdDescriptor>::I32Vec;

    /// Shifts each lane right by a compile-time amount. Both const parameters
    /// must encode the same value; use the `shr!` macro.
    fn shr<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self;
}
371
/// Vector of u8 lanes with load/store/splat and interleaved stores.
/// Slice-taking methods panic when the slice is shorter than documented.
pub trait U8SimdVec: Sized + Copy + Debug + Send + Sync {
    type Descriptor: SimdDescriptor;

    /// Number of u8 lanes in this vector type.
    const LEN: usize;

    /// Loads Self::LEN values; requires `mem.len() >= Self::LEN` or it will panic.
    fn load(d: Self::Descriptor, mem: &[u8]) -> Self;
    /// Broadcasts `v` into every lane.
    fn splat(d: Self::Descriptor, v: u8) -> Self;
    /// Stores Self::LEN values; requires `mem.len() >= Self::LEN` or it will panic.
    fn store(&self, mem: &mut [u8]);

    /// Stores two vectors interleaved: [a0, b0, a1, b1, a2, b2, ...].
    /// Requires `dest.len() >= 2 * Self::LEN` or it will panic.
    fn store_interleaved_2(a: Self, b: Self, dest: &mut [u8]);

    /// Stores three vectors interleaved: [a0, b0, c0, a1, b1, c1, ...].
    /// Requires `dest.len() >= 3 * Self::LEN` or it will panic.
    fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [u8]);

    /// Stores four vectors interleaved: [a0, b0, c0, d0, a1, b1, c1, d1, ...].
    /// Requires `dest.len() >= 4 * Self::LEN` or it will panic.
    fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [u8]);
}
393
/// Vector of u16 lanes with load/store/splat and interleaved stores.
/// Mirrors `U8SimdVec`; slice-taking methods panic on short slices.
pub trait U16SimdVec: Sized + Copy + Debug + Send + Sync {
    type Descriptor: SimdDescriptor;

    /// Number of u16 lanes in this vector type.
    const LEN: usize;

    /// Loads Self::LEN values; requires `mem.len() >= Self::LEN` or it will panic.
    fn load(d: Self::Descriptor, mem: &[u16]) -> Self;
    /// Broadcasts `v` into every lane.
    fn splat(d: Self::Descriptor, v: u16) -> Self;
    /// Stores Self::LEN values; requires `mem.len() >= Self::LEN` or it will panic.
    fn store(&self, mem: &mut [u16]);

    /// Stores two vectors interleaved: [a0, b0, a1, b1, a2, b2, ...].
    /// Requires `dest.len() >= 2 * Self::LEN` or it will panic.
    fn store_interleaved_2(a: Self, b: Self, dest: &mut [u16]);

    /// Stores three vectors interleaved: [a0, b0, c0, a1, b1, c1, ...].
    /// Requires `dest.len() >= 3 * Self::LEN` or it will panic.
    fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [u16]);

    /// Stores four vectors interleaved: [a0, b0, c0, d0, a1, b1, c1, d1, ...].
    /// Requires `dest.len() >= 4 * Self::LEN` or it will panic.
    fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [u16]);
}
415
/// Left-shifts each lane of `$val` by the literal `$amount`.
///
/// Expands to `shl::<AMOUNT_U, AMOUNT_I>()`, passing the same literal as both
/// the `u32` and the `i32` const generic so implementations can use whichever
/// type their intrinsics require.
#[macro_export]
macro_rules! shl {
    ($val: expr, $amount: literal) => {
        $val.shl::<{ $amount as u32 }, { $amount as i32 }>()
    };
}
422
/// Right-shifts each lane of `$val` by the literal `$amount`.
///
/// Counterpart of [`shl!`]: forwards the same literal as both const generics
/// of the trait method `shr`.
#[macro_export]
macro_rules! shr {
    ($val: expr, $amount: literal) => {
        $val.shr::<{ $amount as u32 }, { $amount as i32 }>()
    };
}
429
/// Per-lane boolean mask produced by vector comparisons (`gt`, `eq`, ...)
/// and consumed by the select/blend operations below.
pub trait SimdMask:
    Sized + Copy + Debug + Send + Sync + BitAnd<Self, Output = Self> + BitOr<Self, Output = Self>
{
    type Descriptor: SimdDescriptor;

    /// Per-lane select on f32 vectors: `if_true` where the mask is set,
    /// `if_false` elsewhere.
    fn if_then_else_f32(
        self,
        if_true: <<Self as SimdMask>::Descriptor as SimdDescriptor>::F32Vec,
        if_false: <<Self as SimdMask>::Descriptor as SimdDescriptor>::F32Vec,
    ) -> <<Self as SimdMask>::Descriptor as SimdDescriptor>::F32Vec;

    /// Per-lane select on i32 vectors: `if_true` where the mask is set,
    /// `if_false` elsewhere.
    fn if_then_else_i32(
        self,
        if_true: <<Self as SimdMask>::Descriptor as SimdDescriptor>::I32Vec,
        if_false: <<Self as SimdMask>::Descriptor as SimdDescriptor>::I32Vec,
    ) -> <<Self as SimdMask>::Descriptor as SimdDescriptor>::I32Vec;

    /// Zero-masking select: presumably keeps `v` in set lanes and zeroes the
    /// rest (name follows the AVX-512 `maskz` convention) — NOTE(review):
    /// confirm against backend implementations.
    fn maskz_i32(
        self,
        v: <<Self as SimdMask>::Descriptor as SimdDescriptor>::I32Vec,
    ) -> <<Self as SimdMask>::Descriptor as SimdDescriptor>::I32Vec;

    /// Returns true when every lane of the mask is set.
    fn all(self) -> bool;

    /// Computes `!self & rhs`.
    fn andnot(self, rhs: Self) -> Self;
}
457
/// Implements the array-based portion of `F32SimdVec` (`UnderlyingArray`,
/// `make_array_slice{,_mut}`, `load_array`, `store_array`) in terms of the
/// implementer's `LEN`, `load` and `store`. Intended to be invoked inside an
/// `impl F32SimdVec for ...` block.
macro_rules! impl_f32_array_interface {
    () => {
        type UnderlyingArray = [f32; Self::LEN];

        #[inline(always)]
        fn make_array_slice(slice: &[f32]) -> &[Self::UnderlyingArray] {
            // as_chunks splits into [f32; LEN] arrays; the remainder must be
            // empty because callers guarantee len is a multiple of LEN.
            let (ret, rem) = slice.as_chunks();
            assert!(rem.is_empty());
            ret
        }

        #[inline(always)]
        fn make_array_slice_mut(slice: &mut [f32]) -> &mut [Self::UnderlyingArray] {
            // Mutable twin of make_array_slice; same length contract.
            let (ret, rem) = slice.as_chunks_mut();
            assert!(rem.is_empty());
            ret
        }

        #[inline(always)]
        fn load_array(d: Self::Descriptor, mem: &Self::UnderlyingArray) -> Self {
            // The array type fixes the length, so the slice load cannot panic.
            Self::load(d, mem)
        }

        #[inline(always)]
        fn store_array(&self, mem: &mut Self::UnderlyingArray) {
            self.store(mem);
        }
    };
}
487
488pub(crate) use impl_f32_array_interface;
489
490#[cfg(test)]
491mod test {
492    use arbtest::arbitrary::Unstructured;
493
494    use crate::{
495        F32SimdVec, I32SimdVec, ScalarDescriptor, SimdDescriptor, U8SimdVec, U16SimdVec,
496        test_all_instruction_sets,
497    };
498
    /// Value distributions used by `arb_vec` when generating random inputs.
    enum Distribution {
        /// Arbitrary finite floats: a random i32 divided by (1 + random u32).
        Floats,
        /// Floats that are never zero (both numerator and denominator >= 1),
        /// suitable as divisors.
        NonZeroFloats,
    }
503
504    fn arb_vec<D: SimdDescriptor>(_: D, u: &mut Unstructured, dist: Distribution) -> Vec<f32> {
505        let mut res = vec![0.0; D::F32Vec::LEN];
506        for v in res.iter_mut() {
507            match dist {
508                Distribution::Floats => {
509                    *v = u.arbitrary::<i32>().unwrap() as f32
510                        / (1.0 + u.arbitrary::<u32>().unwrap() as f32)
511                }
512                Distribution::NonZeroFloats => {
513                    let sign = if u.arbitrary::<bool>().unwrap() {
514                        1.0
515                    } else {
516                        -1.0
517                    };
518                    *v = sign * (1.0 + u.arbitrary::<u32>().unwrap() as f32)
519                        / (1.0 + u.arbitrary::<u32>().unwrap() as f32);
520                }
521            }
522        }
523        res
524    }
525
526    fn compare_scalar_simd(scalar: f32, simd: f32, max_abs: f32, max_rel: f32) {
527        let abs = (simd - scalar).abs();
528        let max = simd.abs().max(scalar.abs());
529        let rel = abs / max;
530        assert!(
531            abs < max_abs || rel < max_rel,
532            "simd {simd}, scalar {scalar}, abs {abs:?} rel {rel:?}",
533        );
534    }
535
    // Generates a property test named `$name` that applies `$block` with both
    // the scalar reference backend and the backend under test, then compares
    // the results element-wise via `compare_scalar_simd`. One arm per closure
    // arity (1, 2 or 3 input vectors); each input names its own distribution.
    macro_rules! test_instruction {
        ($name:ident, |$a:ident: $a_dist:ident| $block:expr) => {
            fn $name<D: SimdDescriptor>(d: D) {
                // Applies the operation one vector at a time over the input.
                fn compute<D: SimdDescriptor>(d: D, a: &[f32]) -> Vec<f32> {
                    let closure = |$a: D::F32Vec| $block;
                    let mut res = vec![0f32; a.len()];
                    for idx in (0..a.len()).step_by(D::F32Vec::LEN) {
                        closure(D::F32Vec::load(d, &a[idx..])).store(&mut res[idx..]);
                    }
                    res
                }
                arbtest::arbtest(|u| {
                    let a = arb_vec(d, u, Distribution::$a_dist);
                    let scalar_res = compute(ScalarDescriptor::new().unwrap(), &a);
                    let simd_res = compute(d, &a);
                    for (scalar, simd) in scalar_res.iter().zip(simd_res.iter()) {
                        compare_scalar_simd(*scalar, *simd, 1e-6, 1e-6);
                    }
                    Ok(())
                })
                .size_min(64);
            }
            test_all_instruction_sets!($name);
        };
        ($name:ident, |$a:ident: $a_dist:ident, $b:ident: $b_dist:ident| $block:expr) => {
            fn $name<D: SimdDescriptor>(d: D) {
                // Two-input variant; b gets an independent input vector.
                fn compute<D: SimdDescriptor>(d: D, a: &[f32], b: &[f32]) -> Vec<f32> {
                    let closure = |$a: D::F32Vec, $b: D::F32Vec| $block;
                    let mut res = vec![0f32; a.len()];
                    for idx in (0..a.len()).step_by(D::F32Vec::LEN) {
                        closure(D::F32Vec::load(d, &a[idx..]), D::F32Vec::load(d, &b[idx..]))
                            .store(&mut res[idx..]);
                    }
                    res
                }
                arbtest::arbtest(|u| {
                    let a = arb_vec(d, u, Distribution::$a_dist);
                    let b = arb_vec(d, u, Distribution::$b_dist);
                    let scalar_res = compute(ScalarDescriptor::new().unwrap(), &a, &b);
                    let simd_res = compute(d, &a, &b);
                    for (scalar, simd) in scalar_res.iter().zip(simd_res.iter()) {
                        compare_scalar_simd(*scalar, *simd, 1e-6, 1e-6);
                    }
                    Ok(())
                })
                .size_min(128);
            }
            test_all_instruction_sets!($name);
        };
        ($name:ident, |$a:ident: $a_dist:ident, $b:ident: $b_dist:ident, $c:ident: $c_dist:ident| $block:expr) => {
            fn $name<D: SimdDescriptor>(d: D) {
                // Three-input variant, used for fused multiply-add style ops.
                fn compute<D: SimdDescriptor>(d: D, a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
                    let closure = |$a: D::F32Vec, $b: D::F32Vec, $c: D::F32Vec| $block;
                    let mut res = vec![0f32; a.len()];
                    for idx in (0..a.len()).step_by(D::F32Vec::LEN) {
                        closure(
                            D::F32Vec::load(d, &a[idx..]),
                            D::F32Vec::load(d, &b[idx..]),
                            D::F32Vec::load(d, &c[idx..]),
                        )
                        .store(&mut res[idx..]);
                    }
                    res
                }
                arbtest::arbtest(|u| {
                    let a = arb_vec(d, u, Distribution::$a_dist);
                    let b = arb_vec(d, u, Distribution::$b_dist);
                    let c = arb_vec(d, u, Distribution::$c_dist);
                    let scalar_res = compute(ScalarDescriptor::new().unwrap(), &a, &b, &c);
                    let simd_res = compute(d, &a, &b, &c);
                    for (scalar, simd) in scalar_res.iter().zip(simd_res.iter()) {
                        // Less strict requirements because of fma.
                        compare_scalar_simd(*scalar, *simd, 2e-5, 2e-5);
                    }
                    Ok(())
                })
                .size_min(172);
            }
            test_all_instruction_sets!($name);
        };
    }
617
    // Element-wise binary operators, compared against the scalar backend.
    test_instruction!(add, |a: Floats, b: Floats| { a + b });
    test_instruction!(mul, |a: Floats, b: Floats| { a * b });
    test_instruction!(sub, |a: Floats, b: Floats| { a - b });
    // Division draws non-zero denominators to avoid dividing by zero.
    test_instruction!(div, |a: Floats, b: NonZeroFloats| { a / b });

    // Compound-assignment forms of the same operators.
    test_instruction!(add_assign, |a: Floats, b: Floats| {
        let mut res = a;
        res += b;
        res
    });
    test_instruction!(mul_assign, |a: Floats, b: Floats| {
        let mut res = a;
        res *= b;
        res
    });
    test_instruction!(sub_assign, |a: Floats, b: Floats| {
        let mut res = a;
        res -= b;
        res
    });
    test_instruction!(div_assign, |a: Floats, b: NonZeroFloats| {
        let mut res = a;
        res /= b;
        res
    });

    // Fused multiply-add variants (checked with looser tolerances, see macro).
    test_instruction!(mul_add, |a: Floats, b: Floats, c: Floats| {
        a.mul_add(b, c)
    });

    test_instruction!(neg_mul_add, |a: Floats, b: Floats, c: Floats| {
        a.neg_mul_add(b, c)
    });
651
652    // Validate that neg_mul_add computes c - a * b correctly
653    fn test_neg_mul_add_correctness<D: SimdDescriptor>(d: D) {
654        let a_vals = [
655            2.0, 3.0, 4.0, 5.0, 1.5, 2.5, 3.5, 4.5, 2.5, 3.5, 4.5, 5.5, 1.0, 2.0, 3.0, 4.0,
656        ];
657        let b_vals = [
658            1.0, 2.0, 3.0, 4.0, 0.5, 1.5, 2.5, 3.5, 1.5, 2.5, 3.5, 4.5, 0.25, 0.75, 1.25, 1.75,
659        ];
660        let c_vals = [
661            10.0, 20.0, 30.0, 40.0, 5.0, 15.0, 25.0, 35.0, 12.0, 22.0, 32.0, 42.0, 6.0, 16.0, 26.0,
662            36.0,
663        ];
664
665        let a = D::F32Vec::load(d, &a_vals[..D::F32Vec::LEN]);
666        let b = D::F32Vec::load(d, &b_vals[..D::F32Vec::LEN]);
667        let c = D::F32Vec::load(d, &c_vals[..D::F32Vec::LEN]);
668
669        let result = a.neg_mul_add(b, c);
670        let expected = c - a * b;
671
672        let mut result_vals = [0.0; 16];
673        let mut expected_vals = [0.0; 16];
674        result.store(&mut result_vals[..D::F32Vec::LEN]);
675        expected.store(&mut expected_vals[..D::F32Vec::LEN]);
676
677        for i in 0..D::F32Vec::LEN {
678            assert!(
679                (result_vals[i] - expected_vals[i]).abs() < 1e-5,
680                "neg_mul_add correctness failed at index {}: got {}, expected {}",
681                i,
682                result_vals[i],
683                expected_vals[i]
684            );
685        }
686    }
687
688    test_all_instruction_sets!(test_neg_mul_add_correctness);
689
    // Unary absolute value plus lane-wise min/max.
    test_instruction!(abs, |a: Floats| { a.abs() });
    test_instruction!(max, |a: Floats, b: Floats| { a.max(b) });
    test_instruction!(min, |a: Floats, b: Floats| { a.min(b) });
693
694    // Test that the call method works, compiles, and can capture arguments
695    fn test_call<D: SimdDescriptor>(d: D) {
696        // Test basic call functionality
697        let result = d.call(|_d| 42);
698        assert_eq!(result, 42);
699
700        // Test with capturing variables
701        let multiplier = 3.0f32;
702        let addend = 5.0f32;
703
704        // Test SIMD operations inside call with captures
705        let input = vec![1.0f32; D::F32Vec::LEN * 4];
706        let mut output = vec![0.0f32; D::F32Vec::LEN * 4];
707
708        d.call(|d| {
709            let mult_vec = D::F32Vec::splat(d, multiplier);
710            let add_vec = D::F32Vec::splat(d, addend);
711
712            for idx in (0..input.len()).step_by(D::F32Vec::LEN) {
713                let vec = D::F32Vec::load(d, &input[idx..]);
714                let result = vec * mult_vec + add_vec;
715                result.store(&mut output[idx..]);
716            }
717        });
718
719        // Verify results
720        for &val in &output {
721            assert_eq!(val, 1.0 * multiplier + addend);
722        }
723    }
724    test_all_instruction_sets!(test_call);
725
726    fn test_neg<D: SimdDescriptor>(d: D) {
727        // Test negation operation with enough elements for any SIMD size
728        let len = D::F32Vec::LEN * 2; // Ensure we have at least 2 full vectors
729        let input: Vec<f32> = (0..len)
730            .map(|i| if i % 2 == 0 { i as f32 } else { -(i as f32) })
731            .collect();
732        let expected: Vec<f32> = (0..len)
733            .map(|i| if i % 2 == 0 { -(i as f32) } else { i as f32 })
734            .collect();
735        let mut output = vec![0.0f32; input.len()];
736
737        for idx in (0..input.len()).step_by(D::F32Vec::LEN) {
738            let vec = D::F32Vec::load(d, &input[idx..]);
739            let negated = vec.neg();
740            negated.store(&mut output[idx..]);
741        }
742
743        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
744            assert_eq!(
745                out, exp,
746                "Mismatch at index {}: expected {}, got {}",
747                i, exp, out
748            );
749        }
750    }
751    test_all_instruction_sets!(test_neg);
752
753    fn test_transpose_square<D: SimdDescriptor>(d: D) {
754        // Test square matrix transpose
755        let len = D::F32Vec::LEN;
756        // Input: sequential values 0..
757        let mut input = vec![0.0f32; len * len];
758        for (i, val) in input.iter_mut().enumerate() {
759            *val = i as f32;
760        }
761
762        let mut output = input.clone();
763        D::F32Vec::transpose_square(d, D::F32Vec::make_array_slice_mut(&mut output), 1);
764
765        // Verify transpose: output[i*len+j] should equal input[j*len+i]
766        for i in 0..len {
767            for j in 0..len {
768                let expected = input[j * len + i];
769                let actual = output[i * len + j];
770                assert_eq!(
771                    actual, expected,
772                    "Mismatch at position ({}, {}): expected {}, got {}",
773                    i, j, expected, actual
774                );
775            }
776        }
777    }
778    test_all_instruction_sets!(test_transpose_square);
779
780    fn test_store_interleaved_2<D: SimdDescriptor>(d: D) {
781        let len = D::F32Vec::LEN;
782        let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
783        let b: Vec<f32> = (0..len).map(|i| (i + 100) as f32).collect();
784        let mut output = vec![0.0f32; 2 * len];
785
786        let a_vec = D::F32Vec::load(d, &a);
787        let b_vec = D::F32Vec::load(d, &b);
788        D::F32Vec::store_interleaved_2(a_vec, b_vec, &mut output);
789
790        // Verify interleaved output: [a0, b0, a1, b1, ...]
791        for i in 0..len {
792            assert_eq!(
793                output[2 * i],
794                a[i],
795                "store_interleaved_2 failed at position {}: expected a[{}]={}, got {}",
796                2 * i,
797                i,
798                a[i],
799                output[2 * i]
800            );
801            assert_eq!(
802                output[2 * i + 1],
803                b[i],
804                "store_interleaved_2 failed at position {}: expected b[{}]={}, got {}",
805                2 * i + 1,
806                i,
807                b[i],
808                output[2 * i + 1]
809            );
810        }
811    }
812    test_all_instruction_sets!(test_store_interleaved_2);
813
814    fn test_store_interleaved_3<D: SimdDescriptor>(d: D) {
815        let len = D::F32Vec::LEN;
816        let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
817        let b: Vec<f32> = (0..len).map(|i| (i + 100) as f32).collect();
818        let c: Vec<f32> = (0..len).map(|i| (i + 200) as f32).collect();
819        let mut output = vec![0.0f32; 3 * len];
820
821        let a_vec = D::F32Vec::load(d, &a);
822        let b_vec = D::F32Vec::load(d, &b);
823        let c_vec = D::F32Vec::load(d, &c);
824        D::F32Vec::store_interleaved_3(a_vec, b_vec, c_vec, &mut output);
825
826        // Verify interleaved output: [a0, b0, c0, a1, b1, c1, ...]
827        for i in 0..len {
828            assert_eq!(
829                output[3 * i],
830                a[i],
831                "store_interleaved_3 failed at position {}: expected a[{}]={}, got {}",
832                3 * i,
833                i,
834                a[i],
835                output[3 * i]
836            );
837            assert_eq!(
838                output[3 * i + 1],
839                b[i],
840                "store_interleaved_3 failed at position {}: expected b[{}]={}, got {}",
841                3 * i + 1,
842                i,
843                b[i],
844                output[3 * i + 1]
845            );
846            assert_eq!(
847                output[3 * i + 2],
848                c[i],
849                "store_interleaved_3 failed at position {}: expected c[{}]={}, got {}",
850                3 * i + 2,
851                i,
852                c[i],
853                output[3 * i + 2]
854            );
855        }
856    }
857    test_all_instruction_sets!(test_store_interleaved_3);
858
859    fn test_store_interleaved_4<D: SimdDescriptor>(d: D) {
860        let len = D::F32Vec::LEN;
861        let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
862        let b: Vec<f32> = (0..len).map(|i| (i + 100) as f32).collect();
863        let c: Vec<f32> = (0..len).map(|i| (i + 200) as f32).collect();
864        let e: Vec<f32> = (0..len).map(|i| (i + 300) as f32).collect();
865        let mut output = vec![0.0f32; 4 * len];
866
867        let a_vec = D::F32Vec::load(d, &a);
868        let b_vec = D::F32Vec::load(d, &b);
869        let c_vec = D::F32Vec::load(d, &c);
870        let d_vec = D::F32Vec::load(d, &e);
871        D::F32Vec::store_interleaved_4(a_vec, b_vec, c_vec, d_vec, &mut output);
872
873        // Verify interleaved output: [a0, b0, c0, d0, a1, b1, c1, d1, ...]
874        for i in 0..len {
875            assert_eq!(
876                output[4 * i],
877                a[i],
878                "store_interleaved_4 failed at position {}: expected a[{}]={}, got {}",
879                4 * i,
880                i,
881                a[i],
882                output[4 * i]
883            );
884            assert_eq!(
885                output[4 * i + 1],
886                b[i],
887                "store_interleaved_4 failed at position {}: expected b[{}]={}, got {}",
888                4 * i + 1,
889                i,
890                b[i],
891                output[4 * i + 1]
892            );
893            assert_eq!(
894                output[4 * i + 2],
895                c[i],
896                "store_interleaved_4 failed at position {}: expected c[{}]={}, got {}",
897                4 * i + 2,
898                i,
899                c[i],
900                output[4 * i + 2]
901            );
902            assert_eq!(
903                output[4 * i + 3],
904                e[i],
905                "store_interleaved_4 failed at position {}: expected d[{}]={}, got {}",
906                4 * i + 3,
907                i,
908                e[i],
909                output[4 * i + 3]
910            );
911        }
912    }
913    test_all_instruction_sets!(test_store_interleaved_4);
914
915    fn test_store_interleaved_8<D: SimdDescriptor>(d: D) {
916        let len = D::F32Vec::LEN;
917        let arr_a: Vec<f32> = (0..len).map(|i| i as f32).collect();
918        let arr_b: Vec<f32> = (0..len).map(|i| (i + 100) as f32).collect();
919        let arr_c: Vec<f32> = (0..len).map(|i| (i + 200) as f32).collect();
920        let arr_d: Vec<f32> = (0..len).map(|i| (i + 300) as f32).collect();
921        let arr_e: Vec<f32> = (0..len).map(|i| (i + 400) as f32).collect();
922        let arr_f: Vec<f32> = (0..len).map(|i| (i + 500) as f32).collect();
923        let arr_g: Vec<f32> = (0..len).map(|i| (i + 600) as f32).collect();
924        let arr_h: Vec<f32> = (0..len).map(|i| (i + 700) as f32).collect();
925        let mut output = vec![0.0f32; 8 * len];
926
927        let a = D::F32Vec::load(d, &arr_a);
928        let b = D::F32Vec::load(d, &arr_b);
929        let c = D::F32Vec::load(d, &arr_c);
930        let dv = D::F32Vec::load(d, &arr_d);
931        let e = D::F32Vec::load(d, &arr_e);
932        let f = D::F32Vec::load(d, &arr_f);
933        let g = D::F32Vec::load(d, &arr_g);
934        let h = D::F32Vec::load(d, &arr_h);
935        D::F32Vec::store_interleaved_8(a, b, c, dv, e, f, g, h, &mut output);
936
937        // Verify interleaved output: [a0, b0, c0, d0, e0, f0, g0, h0, a1, ...]
938        let arrays = [
939            &arr_a, &arr_b, &arr_c, &arr_d, &arr_e, &arr_f, &arr_g, &arr_h,
940        ];
941        for i in 0..len {
942            for (j, arr) in arrays.iter().enumerate() {
943                assert_eq!(
944                    output[8 * i + j],
945                    arr[i],
946                    "store_interleaved_8 failed at position {}: expected {}[{}]={}, got {}",
947                    8 * i + j,
948                    ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'][j],
949                    i,
950                    arr[i],
951                    output[8 * i + j]
952                );
953            }
954        }
955    }
956    test_all_instruction_sets!(test_store_interleaved_8);
957
958    fn test_load_deinterleaved_2<D: SimdDescriptor>(d: D) {
959        let len = D::F32Vec::LEN;
960        // Create interleaved input: [a0, b0, a1, b1, ...]
961        let mut interleaved = vec![0.0f32; 2 * len];
962        let expected_a: Vec<f32> = (0..len).map(|i| i as f32).collect();
963        let expected_b: Vec<f32> = (0..len).map(|i| (i + 100) as f32).collect();
964        for i in 0..len {
965            interleaved[2 * i] = expected_a[i];
966            interleaved[2 * i + 1] = expected_b[i];
967        }
968
969        let (a_vec, b_vec) = D::F32Vec::load_deinterleaved_2(d, &interleaved);
970
971        let mut out_a = vec![0.0f32; len];
972        let mut out_b = vec![0.0f32; len];
973        a_vec.store(&mut out_a);
974        b_vec.store(&mut out_b);
975
976        for i in 0..len {
977            assert_eq!(
978                out_a[i], expected_a[i],
979                "load_deinterleaved_2 failed for channel a at {}: expected {}, got {}",
980                i, expected_a[i], out_a[i]
981            );
982            assert_eq!(
983                out_b[i], expected_b[i],
984                "load_deinterleaved_2 failed for channel b at {}: expected {}, got {}",
985                i, expected_b[i], out_b[i]
986            );
987        }
988    }
989    test_all_instruction_sets!(test_load_deinterleaved_2);
990
991    fn test_load_deinterleaved_3<D: SimdDescriptor>(d: D) {
992        let len = D::F32Vec::LEN;
993        // Create interleaved input: [a0, b0, c0, a1, b1, c1, ...]
994        let mut interleaved = vec![0.0f32; 3 * len];
995        let expected_a: Vec<f32> = (0..len).map(|i| i as f32).collect();
996        let expected_b: Vec<f32> = (0..len).map(|i| (i + 100) as f32).collect();
997        let expected_c: Vec<f32> = (0..len).map(|i| (i + 200) as f32).collect();
998        for i in 0..len {
999            interleaved[3 * i] = expected_a[i];
1000            interleaved[3 * i + 1] = expected_b[i];
1001            interleaved[3 * i + 2] = expected_c[i];
1002        }
1003
1004        let (a_vec, b_vec, c_vec) = D::F32Vec::load_deinterleaved_3(d, &interleaved);
1005
1006        let mut out_a = vec![0.0f32; len];
1007        let mut out_b = vec![0.0f32; len];
1008        let mut out_c = vec![0.0f32; len];
1009        a_vec.store(&mut out_a);
1010        b_vec.store(&mut out_b);
1011        c_vec.store(&mut out_c);
1012
1013        for i in 0..len {
1014            assert_eq!(
1015                out_a[i], expected_a[i],
1016                "load_deinterleaved_3 failed for channel a at {}: expected {}, got {}",
1017                i, expected_a[i], out_a[i]
1018            );
1019            assert_eq!(
1020                out_b[i], expected_b[i],
1021                "load_deinterleaved_3 failed for channel b at {}: expected {}, got {}",
1022                i, expected_b[i], out_b[i]
1023            );
1024            assert_eq!(
1025                out_c[i], expected_c[i],
1026                "load_deinterleaved_3 failed for channel c at {}: expected {}, got {}",
1027                i, expected_c[i], out_c[i]
1028            );
1029        }
1030    }
1031    test_all_instruction_sets!(test_load_deinterleaved_3);
1032
1033    fn test_load_deinterleaved_4<D: SimdDescriptor>(d: D) {
1034        let len = D::F32Vec::LEN;
1035        // Create interleaved input: [a0, b0, c0, d0, a1, b1, c1, d1, ...]
1036        let mut interleaved = vec![0.0f32; 4 * len];
1037        let expected_a: Vec<f32> = (0..len).map(|i| i as f32).collect();
1038        let expected_b: Vec<f32> = (0..len).map(|i| (i + 100) as f32).collect();
1039        let expected_c: Vec<f32> = (0..len).map(|i| (i + 200) as f32).collect();
1040        let expected_d: Vec<f32> = (0..len).map(|i| (i + 300) as f32).collect();
1041        for i in 0..len {
1042            interleaved[4 * i] = expected_a[i];
1043            interleaved[4 * i + 1] = expected_b[i];
1044            interleaved[4 * i + 2] = expected_c[i];
1045            interleaved[4 * i + 3] = expected_d[i];
1046        }
1047
1048        let (a_vec, b_vec, c_vec, d_vec) = D::F32Vec::load_deinterleaved_4(d, &interleaved);
1049
1050        let mut out_a = vec![0.0f32; len];
1051        let mut out_b = vec![0.0f32; len];
1052        let mut out_c = vec![0.0f32; len];
1053        let mut out_d = vec![0.0f32; len];
1054        a_vec.store(&mut out_a);
1055        b_vec.store(&mut out_b);
1056        c_vec.store(&mut out_c);
1057        d_vec.store(&mut out_d);
1058
1059        for i in 0..len {
1060            assert_eq!(
1061                out_a[i], expected_a[i],
1062                "load_deinterleaved_4 failed for channel a at {}: expected {}, got {}",
1063                i, expected_a[i], out_a[i]
1064            );
1065            assert_eq!(
1066                out_b[i], expected_b[i],
1067                "load_deinterleaved_4 failed for channel b at {}: expected {}, got {}",
1068                i, expected_b[i], out_b[i]
1069            );
1070            assert_eq!(
1071                out_c[i], expected_c[i],
1072                "load_deinterleaved_4 failed for channel c at {}: expected {}, got {}",
1073                i, expected_c[i], out_c[i]
1074            );
1075            assert_eq!(
1076                out_d[i], expected_d[i],
1077                "load_deinterleaved_4 failed for channel d at {}: expected {}, got {}",
1078                i, expected_d[i], out_d[i]
1079            );
1080        }
1081    }
1082    test_all_instruction_sets!(test_load_deinterleaved_4);
1083
1084    // Roundtrip tests: verify store_interleaved + load_deinterleaved returns original data
1085    fn test_interleave_roundtrip_2<D: SimdDescriptor>(d: D) {
1086        let len = D::F32Vec::LEN;
1087        let a: Vec<f32> = (0..len).map(|i| (i * 7 + 3) as f32).collect();
1088        let b: Vec<f32> = (0..len).map(|i| (i * 11 + 5) as f32).collect();
1089
1090        let a_vec = D::F32Vec::load(d, &a);
1091        let b_vec = D::F32Vec::load(d, &b);
1092
1093        let mut interleaved = vec![0.0f32; 2 * len];
1094        D::F32Vec::store_interleaved_2(a_vec, b_vec, &mut interleaved);
1095
1096        let (a_out, b_out) = D::F32Vec::load_deinterleaved_2(d, &interleaved);
1097
1098        let mut out_a = vec![0.0f32; len];
1099        let mut out_b = vec![0.0f32; len];
1100        a_out.store(&mut out_a);
1101        b_out.store(&mut out_b);
1102
1103        assert_eq!(out_a, a, "interleave_roundtrip_2 failed for channel a");
1104        assert_eq!(out_b, b, "interleave_roundtrip_2 failed for channel b");
1105    }
1106    test_all_instruction_sets!(test_interleave_roundtrip_2);
1107
1108    fn test_interleave_roundtrip_3<D: SimdDescriptor>(d: D) {
1109        let len = D::F32Vec::LEN;
1110        let a: Vec<f32> = (0..len).map(|i| (i * 7 + 3) as f32).collect();
1111        let b: Vec<f32> = (0..len).map(|i| (i * 11 + 5) as f32).collect();
1112        let c: Vec<f32> = (0..len).map(|i| (i * 13 + 9) as f32).collect();
1113
1114        let a_vec = D::F32Vec::load(d, &a);
1115        let b_vec = D::F32Vec::load(d, &b);
1116        let c_vec = D::F32Vec::load(d, &c);
1117
1118        let mut interleaved = vec![0.0f32; 3 * len];
1119        D::F32Vec::store_interleaved_3(a_vec, b_vec, c_vec, &mut interleaved);
1120
1121        let (a_out, b_out, c_out) = D::F32Vec::load_deinterleaved_3(d, &interleaved);
1122
1123        let mut out_a = vec![0.0f32; len];
1124        let mut out_b = vec![0.0f32; len];
1125        let mut out_c = vec![0.0f32; len];
1126        a_out.store(&mut out_a);
1127        b_out.store(&mut out_b);
1128        c_out.store(&mut out_c);
1129
1130        assert_eq!(out_a, a, "interleave_roundtrip_3 failed for channel a");
1131        assert_eq!(out_b, b, "interleave_roundtrip_3 failed for channel b");
1132        assert_eq!(out_c, c, "interleave_roundtrip_3 failed for channel c");
1133    }
1134    test_all_instruction_sets!(test_interleave_roundtrip_3);
1135
1136    fn test_interleave_roundtrip_4<D: SimdDescriptor>(d: D) {
1137        let len = D::F32Vec::LEN;
1138        let a: Vec<f32> = (0..len).map(|i| (i * 7 + 3) as f32).collect();
1139        let b: Vec<f32> = (0..len).map(|i| (i * 11 + 5) as f32).collect();
1140        let c: Vec<f32> = (0..len).map(|i| (i * 13 + 9) as f32).collect();
1141        let e: Vec<f32> = (0..len).map(|i| (i * 17 + 1) as f32).collect();
1142
1143        let a_vec = D::F32Vec::load(d, &a);
1144        let b_vec = D::F32Vec::load(d, &b);
1145        let c_vec = D::F32Vec::load(d, &c);
1146        let d_vec = D::F32Vec::load(d, &e);
1147
1148        let mut interleaved = vec![0.0f32; 4 * len];
1149        D::F32Vec::store_interleaved_4(a_vec, b_vec, c_vec, d_vec, &mut interleaved);
1150
1151        let (a_out, b_out, c_out, d_out) = D::F32Vec::load_deinterleaved_4(d, &interleaved);
1152
1153        let mut out_a = vec![0.0f32; len];
1154        let mut out_b = vec![0.0f32; len];
1155        let mut out_c = vec![0.0f32; len];
1156        let mut out_d = vec![0.0f32; len];
1157        a_out.store(&mut out_a);
1158        b_out.store(&mut out_b);
1159        c_out.store(&mut out_c);
1160        d_out.store(&mut out_d);
1161
1162        assert_eq!(out_a, a, "interleave_roundtrip_4 failed for channel a");
1163        assert_eq!(out_b, b, "interleave_roundtrip_4 failed for channel b");
1164        assert_eq!(out_c, c, "interleave_roundtrip_4 failed for channel c");
1165        assert_eq!(out_d, e, "interleave_roundtrip_4 failed for channel d");
1166    }
1167    test_all_instruction_sets!(test_interleave_roundtrip_4);
1168
1169    fn test_prepare_table_bf16_8<D: SimdDescriptor>(d: D) {
1170        // Create an 8-entry lookup table with known values
1171        // Use integer values that are exactly representable in BF16
1172        let lut: [f32; 8] = [0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0];
1173        let len = D::F32Vec::LEN;
1174
1175        // Prepare the table once
1176        let prepared = D::F32Vec::prepare_table_bf16_8(d, &lut);
1177
1178        // Create indices that are valid for the LUT (0..8)
1179        let indices: Vec<i32> = (0..len).map(|i| (i % 8) as i32).collect();
1180        let expected: Vec<f32> = indices.iter().map(|&i| lut[i as usize]).collect();
1181
1182        // Perform table lookup with prepared table
1183        let indices_vec = D::I32Vec::load(d, &indices);
1184        let result = D::F32Vec::table_lookup_bf16_8(d, prepared, indices_vec);
1185
1186        let mut output = vec![0.0f32; len];
1187        result.store(&mut output);
1188
1189        // Verify results - prepared lookup may have BF16 precision loss
1190        // BF16 has ~0.4% relative error for typical values
1191        for i in 0..len {
1192            let tolerance = if expected[i] == 0.0 {
1193                0.01
1194            } else {
1195                expected[i].abs() * 0.01
1196            };
1197            assert!(
1198                (output[i] - expected[i]).abs() < tolerance,
1199                "table_lookup_bf16_8 failed at position {}: expected {}, got {}",
1200                i,
1201                expected[i],
1202                output[i]
1203            );
1204        }
1205    }
1206    test_all_instruction_sets!(test_prepare_table_bf16_8);
1207
1208    /// Test that I32 multiplication operates on all elements, not just alternating lanes.
1209    /// This catches the bug where _mm_mul_epi32 was used instead of _mm_mullo_epi32.
1210    fn test_i32_mul_all_elements<D: SimdDescriptor>(d: D) {
1211        let len = D::I32Vec::LEN;
1212
1213        // Create input vectors where each lane has a unique value
1214        let mut a_data = vec![0i32; len];
1215        let mut b_data = vec![0i32; len];
1216        for i in 0..len {
1217            a_data[i] = (i + 1) as i32; // [1, 2, 3, 4, ...]
1218            b_data[i] = (i + 10) as i32; // [10, 11, 12, 13, ...]
1219        }
1220
1221        let a = D::I32Vec::load(d, &a_data);
1222        let b = D::I32Vec::load(d, &b_data);
1223        let result = a * b;
1224
1225        let mut result_data = vec![0i32; len];
1226        result.store(&mut result_data);
1227
1228        // Verify EVERY element was multiplied correctly
1229        for i in 0..len {
1230            let expected = a_data[i] * b_data[i];
1231            assert_eq!(
1232                result_data[i], expected,
1233                "I32 mul failed at index {}: {} * {} = {}, got {}",
1234                i, a_data[i], b_data[i], expected, result_data[i]
1235            );
1236        }
1237    }
1238    test_all_instruction_sets!(test_i32_mul_all_elements);
1239
1240    fn test_store_u16<D: SimdDescriptor>(d: D) {
1241        let data = [
1242            0xbabau32 as i32,
1243            0x1234u32 as i32,
1244            0xdeadbabau32 as i32,
1245            0xdead1234u32 as i32,
1246            0x1111babau32 as i32,
1247            0x11111234u32 as i32,
1248            0x76543210u32 as i32,
1249            0x01234567u32 as i32,
1250            0x00000000u32 as i32,
1251            0xffffffffu32 as i32,
1252            0x23949289u32 as i32,
1253            0xf9371913u32 as i32,
1254            0xdeadbeefu32 as i32,
1255            0xbeefdeadu32 as i32,
1256            0xaaaaaaaau32 as i32,
1257            0xbbbbbbbbu32 as i32,
1258        ];
1259        let mut output = [0u16; 16];
1260        for i in (0..16).step_by(D::I32Vec::LEN) {
1261            let vec = D::I32Vec::load(d, &data[i..]);
1262            vec.store_u16(&mut output[i..]);
1263        }
1264
1265        for i in 0..16 {
1266            let expected = data[i] as u16;
1267            assert_eq!(
1268                output[i], expected,
1269                "store_u16 failed at index {}: expected {}, got {}",
1270                i, expected, output[i]
1271            );
1272        }
1273    }
1274    test_all_instruction_sets!(test_store_u16);
1275
1276    fn test_store_interleaved_2_u8<D: SimdDescriptor>(d: D) {
1277        let len = D::U8Vec::LEN;
1278        let a: Vec<u8> = (0..len).map(|i| i as u8).collect();
1279        let b: Vec<u8> = (0..len).map(|i| (i + 100) as u8).collect();
1280        let mut output = vec![0u8; 2 * len];
1281
1282        let a_vec = D::U8Vec::load(d, &a);
1283        let b_vec = D::U8Vec::load(d, &b);
1284        D::U8Vec::store_interleaved_2(a_vec, b_vec, &mut output);
1285
1286        for i in 0..len {
1287            assert_eq!(output[2 * i], a[i]);
1288            assert_eq!(output[2 * i + 1], b[i]);
1289        }
1290    }
1291    test_all_instruction_sets!(test_store_interleaved_2_u8);
1292
1293    fn test_store_interleaved_3_u8<D: SimdDescriptor>(d: D) {
1294        let len = D::U8Vec::LEN;
1295        let a: Vec<u8> = (0..len).map(|i| i as u8).collect();
1296        let b: Vec<u8> = (0..len).map(|i| (i + 100) as u8).collect();
1297        let c: Vec<u8> = (0..len).map(|i| (i + 50) as u8).collect();
1298        let mut output = vec![0u8; 3 * len];
1299
1300        let a_vec = D::U8Vec::load(d, &a);
1301        let b_vec = D::U8Vec::load(d, &b);
1302        let c_vec = D::U8Vec::load(d, &c);
1303        D::U8Vec::store_interleaved_3(a_vec, b_vec, c_vec, &mut output);
1304
1305        for i in 0..len {
1306            assert_eq!(output[3 * i], a[i]);
1307            assert_eq!(output[3 * i + 1], b[i]);
1308            assert_eq!(output[3 * i + 2], c[i]);
1309        }
1310    }
1311    test_all_instruction_sets!(test_store_interleaved_3_u8);
1312
1313    fn test_store_interleaved_4_u8<D: SimdDescriptor>(d: D) {
1314        let len = D::U8Vec::LEN;
1315        let a: Vec<u8> = (0..len).map(|i| i as u8).collect();
1316        let b: Vec<u8> = (0..len).map(|i| (i + 100) as u8).collect();
1317        let c: Vec<u8> = (0..len).map(|i| (i + 50) as u8).collect();
1318        let e: Vec<u8> = (0..len).map(|i| (i + 200) as u8).collect();
1319        let mut output = vec![0u8; 4 * len];
1320
1321        let a_vec = D::U8Vec::load(d, &a);
1322        let b_vec = D::U8Vec::load(d, &b);
1323        let c_vec = D::U8Vec::load(d, &c);
1324        let d_vec = D::U8Vec::load(d, &e);
1325        D::U8Vec::store_interleaved_4(a_vec, b_vec, c_vec, d_vec, &mut output);
1326
1327        for i in 0..len {
1328            assert_eq!(output[4 * i], a[i]);
1329            assert_eq!(output[4 * i + 1], b[i]);
1330            assert_eq!(output[4 * i + 2], c[i]);
1331            assert_eq!(output[4 * i + 3], e[i]);
1332        }
1333    }
1334    test_all_instruction_sets!(test_store_interleaved_4_u8);
1335
1336    fn test_store_interleaved_2_u16<D: SimdDescriptor>(d: D) {
1337        let len = D::U16Vec::LEN;
1338        let a: Vec<u16> = (0..len).map(|i| i as u16).collect();
1339        let b: Vec<u16> = (0..len).map(|i| (i + 1000) as u16).collect();
1340        let mut output = vec![0u16; 2 * len];
1341
1342        let a_vec = D::U16Vec::load(d, &a);
1343        let b_vec = D::U16Vec::load(d, &b);
1344        D::U16Vec::store_interleaved_2(a_vec, b_vec, &mut output);
1345
1346        for i in 0..len {
1347            assert_eq!(output[2 * i], a[i]);
1348            assert_eq!(output[2 * i + 1], b[i]);
1349        }
1350    }
1351    test_all_instruction_sets!(test_store_interleaved_2_u16);
1352
1353    fn test_store_interleaved_3_u16<D: SimdDescriptor>(d: D) {
1354        let len = D::U16Vec::LEN;
1355        let a: Vec<u16> = (0..len).map(|i| i as u16).collect();
1356        let b: Vec<u16> = (0..len).map(|i| (i + 1000) as u16).collect();
1357        let c: Vec<u16> = (0..len).map(|i| (i + 2000) as u16).collect();
1358        let mut output = vec![0u16; 3 * len];
1359
1360        let a_vec = D::U16Vec::load(d, &a);
1361        let b_vec = D::U16Vec::load(d, &b);
1362        let c_vec = D::U16Vec::load(d, &c);
1363        D::U16Vec::store_interleaved_3(a_vec, b_vec, c_vec, &mut output);
1364
1365        for i in 0..len {
1366            assert_eq!(output[3 * i], a[i]);
1367            assert_eq!(output[3 * i + 1], b[i]);
1368            assert_eq!(output[3 * i + 2], c[i]);
1369        }
1370    }
1371    test_all_instruction_sets!(test_store_interleaved_3_u16);
1372
1373    fn test_store_interleaved_4_u16<D: SimdDescriptor>(d: D) {
1374        let len = D::U16Vec::LEN;
1375        let a: Vec<u16> = (0..len).map(|i| i as u16).collect();
1376        let b: Vec<u16> = (0..len).map(|i| (i + 1000) as u16).collect();
1377        let c: Vec<u16> = (0..len).map(|i| (i + 2000) as u16).collect();
1378        let e: Vec<u16> = (0..len).map(|i| (i + 3000) as u16).collect();
1379        let mut output = vec![0u16; 4 * len];
1380
1381        let a_vec = D::U16Vec::load(d, &a);
1382        let b_vec = D::U16Vec::load(d, &b);
1383        let c_vec = D::U16Vec::load(d, &c);
1384        let d_vec = D::U16Vec::load(d, &e);
1385        D::U16Vec::store_interleaved_4(a_vec, b_vec, c_vec, d_vec, &mut output);
1386
1387        for i in 0..len {
1388            assert_eq!(output[4 * i], a[i]);
1389            assert_eq!(output[4 * i + 1], b[i]);
1390            assert_eq!(output[4 * i + 2], c[i]);
1391            assert_eq!(output[4 * i + 3], e[i]);
1392        }
1393    }
1394    test_all_instruction_sets!(test_store_interleaved_4_u16);
1395}
1396
1397/// Soundness regression tests for the SIMD API.
1398///
1399/// Every test in this module uses **only safe code** — no `unsafe` blocks.
1400/// All out-of-bounds accesses are caught as panics in **all build profiles**
1401/// (debug and release) thanks to checked slice indexing in trait defaults
1402/// and `assert!` guards in backend load/store implementations.
1403///
1404/// These tests serve as regression tests: if any future change weakens
1405/// the bounds checking (e.g. replacing `assert!` with `debug_assert!`),
1406/// the release-mode tests will fail.
1407#[cfg(test)]
1408mod soundness_tests {
1409    use super::*;
1410    use std::num::Wrapping;
1411
1412    // =========================================================================
1413    // Category 1: Trait default methods with checked slice indexing
1414    //
1415    // These use `&mem[offset..]` / `&mut mem[offset..]` which panics on
1416    // out-of-bounds access in all build profiles. Previously used
1417    // `get_unchecked` behind `debug_assert` — now fully sound.
1418    // =========================================================================
1419
1420    /// `F32SimdVec::load_from` with offset past end of slice.
1421    /// Panics from checked slice indexing.
1422    #[test]
1423    #[should_panic(expected = "range start index")]
1424    fn soundness_f32_load_from_oob_offset() {
1425        let data = [1.0f32, 2.0];
1426        let _ = f32::load_from(ScalarDescriptor, &data, 5);
1427    }
1428
1429    /// Same on the I32 side.
1430    #[test]
1431    #[should_panic(expected = "range start index")]
1432    fn soundness_i32_load_from_oob_offset() {
1433        let data = [1i32, 2];
1434        let _ = Wrapping::<i32>::load_from(ScalarDescriptor, &data, 5);
1435    }
1436
1437    /// `F32SimdVec::store_at` with offset past end of slice.
1438    /// Panics from checked slice indexing.
1439    #[test]
1440    #[should_panic(expected = "range start index")]
1441    fn soundness_f32_store_at_oob_offset() {
1442        let mut data = [0.0f32; 2];
1443        let v = f32::splat(ScalarDescriptor, 42.0);
1444        v.store_at(&mut data, 5);
1445    }
1446
1447    /// `F32SimdVec::round_store_u8_at` with offset past end of slice.
1448    #[test]
1449    #[should_panic(expected = "range start index")]
1450    fn soundness_f32_round_store_u8_at_oob() {
1451        let mut data = [0u8; 2];
1452        let v = f32::splat(ScalarDescriptor, 42.0);
1453        v.round_store_u8_at(&mut data, 5);
1454    }
1455
1456    /// `load_from` with offset exactly at len (offset == len, LEN == 1).
1457    /// Slice `&data[2..]` is empty, then `load` panics on the short slice.
1458    #[test]
1459    #[should_panic]
1460    fn soundness_f32_load_from_at_exact_len() {
1461        let data = [1.0f32, 2.0];
1462        let _ = f32::load_from(ScalarDescriptor, &data, 2);
1463    }
1464
1465    // =========================================================================
1466    // Category 2: SIMD backend load/store with too-short slices
1467    //
1468    // The load() and store() trait methods are declared safe. All backends
1469    // (including SIMD) now use `assert!` before unsafe pointer operations,
1470    // catching violations in all build profiles.
1471    // =========================================================================
1472
1473    /// `F32SimdVec::load` with empty slice.
1474    #[test]
1475    #[should_panic]
1476    fn soundness_f32_load_empty_slice() {
1477        let data: &[f32] = &[];
1478        let _ = f32::load(ScalarDescriptor, data);
1479    }
1480
1481    /// `F32SimdVec::store` with empty destination.
1482    #[test]
1483    #[should_panic]
1484    fn soundness_f32_store_empty_slice() {
1485        let v = f32::splat(ScalarDescriptor, 1.0);
1486        let mut data: Vec<f32> = vec![];
1487        v.store(&mut data);
1488    }
1489
1490    /// `I32SimdVec::load` with empty slice.
1491    #[test]
1492    #[should_panic]
1493    fn soundness_i32_load_empty_slice() {
1494        let data: &[i32] = &[];
1495        let _ = Wrapping::<i32>::load(ScalarDescriptor, data);
1496    }
1497
1498    // =========================================================================
1499    // Category 3: Interleaved store to too-short destination
1500    //
1501    // store_interleaved_2 etc. require dest.len() >= 2 * LEN. All backends
1502    // now use `assert!` guards (not `debug_assert!`), catching violations
1503    // in all build profiles.
1504    // =========================================================================
1505
1506    /// `store_interleaved_2` with destination smaller than 2 * LEN.
1507    #[test]
1508    #[should_panic]
1509    fn soundness_f32_store_interleaved_2_short_dest() {
1510        let a = f32::splat(ScalarDescriptor, 1.0);
1511        let b = f32::splat(ScalarDescriptor, 2.0);
1512        // LEN=1 for scalar, so need 2 elements. Provide only 1.
1513        let mut dest = [0.0f32; 1];
1514        f32::store_interleaved_2(a, b, &mut dest);
1515    }
1516
1517    /// `load_deinterleaved_2` with source smaller than 2 * LEN.
1518    #[test]
1519    #[should_panic]
1520    fn soundness_f32_load_deinterleaved_2_short_src() {
1521        let src = [1.0f32]; // Need 2 for scalar (2 * LEN where LEN=1)
1522        let _ = f32::load_deinterleaved_2(ScalarDescriptor, &src);
1523    }
1524
1525    // =========================================================================
1526    // Category 4: SIMD-backend-specific bounds checking
1527    //
1528    // These verify that SIMD backends correctly panic on too-short slices
1529    // in all build profiles. Now that backends use `assert!` (not
1530    // `debug_assert!`), behavior is consistent across debug/release.
1531    //
1532    // Scalar (LEN=1) is skipped since 1-element operations on 1-element
1533    // slices are valid. We use catch_unwind to assert the panic for SIMD.
1534    // =========================================================================
1535
1536    fn soundness_simd_load_short_slice<D: SimdDescriptor>(d: D) {
1537        if D::F32Vec::LEN <= 1 {
1538            return; // scalar: 1 element from 1-element slice is valid
1539        }
1540        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1541            let data = [1.0f32]; // 1 element, SIMD needs 4/8/16
1542            let _ = D::F32Vec::load(d, &data);
1543        }));
1544        assert!(
1545            result.is_err(),
1546            "load from 1-element slice should panic for LEN={}",
1547            D::F32Vec::LEN
1548        );
1549    }
1550    test_all_instruction_sets!(soundness_simd_load_short_slice);
1551
1552    fn soundness_simd_store_short_slice<D: SimdDescriptor>(d: D) {
1553        if D::F32Vec::LEN <= 1 {
1554            return;
1555        }
1556        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1557            let v = D::F32Vec::splat(d, 42.0);
1558            let mut data = [0.0f32]; // 1 element
1559            v.store(&mut data);
1560        }));
1561        assert!(
1562            result.is_err(),
1563            "store to 1-element slice should panic for LEN={}",
1564            D::F32Vec::LEN
1565        );
1566    }
1567    test_all_instruction_sets!(soundness_simd_store_short_slice);
1568
1569    fn soundness_simd_load_from_partial_oob<D: SimdDescriptor>(d: D) {
1570        if D::F32Vec::LEN <= 1 {
1571            return;
1572        }
1573        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1574            // Provide exactly LEN elements, offset=1 → only LEN-1 remain.
1575            let data: Vec<f32> = (0..D::F32Vec::LEN).map(|i| i as f32).collect();
1576            let _ = D::F32Vec::load_from(d, &data, 1);
1577        }));
1578        assert!(
1579            result.is_err(),
1580            "load_from with partial OOB should panic for LEN={}",
1581            D::F32Vec::LEN
1582        );
1583    }
1584    test_all_instruction_sets!(soundness_simd_load_from_partial_oob);
1585
1586    fn soundness_simd_interleaved_2_short<D: SimdDescriptor>(d: D) {
1587        if D::F32Vec::LEN <= 1 {
1588            return;
1589        }
1590        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1591            let a = D::F32Vec::splat(d, 1.0);
1592            let b = D::F32Vec::splat(d, 2.0);
1593            let mut dest = vec![0.0f32; D::F32Vec::LEN]; // need 2*LEN
1594            D::F32Vec::store_interleaved_2(a, b, &mut dest);
1595        }));
1596        assert!(
1597            result.is_err(),
1598            "interleaved_2 with short dest should panic for LEN={}",
1599            D::F32Vec::LEN
1600        );
1601    }
1602    test_all_instruction_sets!(soundness_simd_interleaved_2_short);
1603
1604    fn soundness_simd_round_store_u8_short<D: SimdDescriptor>(d: D) {
1605        if D::F32Vec::LEN <= 1 {
1606            return;
1607        }
1608        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1609            let v = D::F32Vec::splat(d, 128.0);
1610            let mut dest = [0u8; 1]; // need LEN bytes
1611            v.round_store_u8(&mut dest);
1612        }));
1613        assert!(
1614            result.is_err(),
1615            "round_store_u8 with short dest should panic for LEN={}",
1616            D::F32Vec::LEN
1617        );
1618    }
1619    test_all_instruction_sets!(soundness_simd_round_store_u8_short);
1620}