Skip to main content

vector_ta/utilities/
helpers.rs

1use crate::utilities::enums::Kernel;
2use aligned_vec::AVec;
3use std::arch::is_x86_feature_detected;
4use std::sync::OnceLock;
5use std::{mem::MaybeUninit, ptr, slice};
6
7static BEST_SINGLE: OnceLock<Kernel> = OnceLock::new();
8static BEST_BATCH: OnceLock<Kernel> = OnceLock::new();
9
10#[inline(always)]
11pub fn detect_best_kernel() -> Kernel {
12    *BEST_SINGLE.get_or_init(|| {
13        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
14        {
15            if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma") {
16                return Kernel::Avx512;
17            }
18            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
19                return Kernel::Avx2;
20            }
21        }
22
23        Kernel::Scalar
24    })
25}
26
27#[inline(always)]
28pub fn detect_best_batch_kernel() -> Kernel {
29    *BEST_BATCH.get_or_init(|| match detect_best_kernel() {
30        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
31        Kernel::Avx512 => Kernel::Avx512Batch,
32        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
33        Kernel::Avx2 => Kernel::Avx2Batch,
34        _ => Kernel::ScalarBatch,
35    })
36}
37
/// Process-wide cache for the result of [`detect_wasm_kernel`] on `wasm32`.
#[cfg(target_arch = "wasm32")]
static BEST_WASM: OnceLock<Kernel> = OnceLock::new();

/// Returns the kernel to use on `wasm32` targets.
///
/// The selection distinguishes builds with and without the `simd128` target
/// feature, but both configurations currently resolve to [`Kernel::Scalar`].
/// The result is computed at most once and cached in `BEST_WASM`.
#[cfg(target_arch = "wasm32")]
#[inline(always)]
pub fn detect_wasm_kernel() -> Kernel {
    /// Picks the kernel for the compile-time `simd128` configuration.
    fn select() -> Kernel {
        #[cfg(target_feature = "simd128")]
        let chosen = Kernel::Scalar;
        #[cfg(not(target_feature = "simd128"))]
        let chosen = Kernel::Scalar;

        chosen
    }

    *BEST_WASM.get_or_init(select)
}
53
54#[cfg(not(target_arch = "wasm32"))]
55#[inline(always)]
56pub fn detect_wasm_kernel() -> Kernel {
57    Kernel::Scalar
58}
59
/// Returns early with `Ok(())` when the requested kernel cannot run here.
///
/// * `$kernel` - the `Kernel` value the caller is about to exercise.
/// * `$test_name` - a label printed in the skip message.
///
/// Two situations trigger a skip, each logged to stderr:
///
/// * The expansion site is compiled without `nightly-avx` on `x86_64` and
///   `$kernel` is one of the AVX2 / AVX-512 variants.
/// * The expansion site is compiled with `nightly-avx` on `x86_64`, but the
///   running CPU lacks the features the kernel needs (AVX-512F + FMA or
///   AVX2 + FMA).
///
/// Because the macro expands to `return Ok(());`, it must be invoked inside
/// a function whose return type accepts `Ok(())`. `$kernel` is evaluated more
/// than once, so pass a cheap, side-effect-free expression. The `cfg`
/// conditions are evaluated where the macro is expanded.
#[macro_export]
macro_rules! skip_if_unsupported {
    ($kernel:expr, $test_name:expr) => {{
        use std::arch::is_x86_feature_detected;
        use $crate::utilities::enums::Kernel;

        // Built without SIMD support: every AVX kernel is unavailable.
        #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
        {
            match $kernel {
                Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                    eprintln!(
                        "[{}] skipped {:?} – compiled without `nightly-avx`",
                        $test_name, $kernel
                    );
                    return Ok(());
                }
                _ => {}
            }
        }

        // Built with SIMD support: consult the CPU at runtime.
        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        {
            // `Some(description)` names the feature set the CPU is missing.
            let missing: Option<&'static str> = match $kernel {
                Kernel::Avx512 | Kernel::Avx512Batch
                    if !(is_x86_feature_detected!("avx512f")
                        && is_x86_feature_detected!("fma")) =>
                {
                    Some("AVX-512F + FMA")
                }
                Kernel::Avx2 | Kernel::Avx2Batch
                    if !(is_x86_feature_detected!("avx2")
                        && is_x86_feature_detected!("fma")) =>
                {
                    Some("AVX2 + FMA")
                }
                _ => None,
            };

            if let Some(features) = missing {
                eprintln!(
                    "[{}] skipped {:?} - CPU lacks {}",
                    $test_name, $kernel, features
                );
                return Ok(());
            }
        }
    }};
}
/// Allocates a `Vec<f64>` of length `len` whose first `warm` elements are
/// quiet NaN (bit pattern `0x7ff8_0000_0000_0000`).
///
/// `warm` is clamped to `len`, so a warm-up longer than the buffer simply
/// yields an all-NaN vector.
///
/// The elements in `warm..len` depend on the build profile:
///
/// * With `debug_assertions` they are filled with the poison pattern
///   `0x11111111_11111111`, which makes accidental reads easy to spot.
/// * Without `debug_assertions` they are left **uninitialized** to avoid a
///   redundant pass over the buffer. Callers must write every element in
///   `warm..len` before reading it.
#[inline(always)]
pub fn alloc_with_nan_prefix(len: usize, warm: usize) -> Vec<f64> {
    use std::mem::{ManuallyDrop, MaybeUninit};

    let warm = warm.min(len);
    let quiet_nan = MaybeUninit::new(f64::from_bits(0x7ff8_0000_0000_0000));

    let mut buf: Vec<MaybeUninit<f64>> = Vec::with_capacity(len);

    #[cfg(not(debug_assertions))]
    {
        // SAFETY: `with_capacity(len)` guarantees room for `len` elements, and
        // `MaybeUninit<f64>` is allowed to hold uninitialized memory.
        unsafe {
            buf.set_len(len);
        }
        // `warm <= len == buf.len()`, so the subslice is in bounds; iterating
        // it avoids a bounds check per element.
        for cell in &mut buf[..warm] {
            *cell = quiet_nan;
        }
    }

    #[cfg(debug_assertions)]
    {
        // Capacity is already reserved, so neither call reallocates.
        buf.resize(warm, quiet_nan);
        buf.resize(len, MaybeUninit::new(f64::from_bits(0x11111111_11111111)));
    }

    // Give up ownership *before* extracting the raw parts so the allocation
    // can never be freed by an early drop of `buf`.
    let mut buf = ManuallyDrop::new(buf);
    let ptr = buf.as_mut_ptr().cast::<f64>();
    let cap = buf.capacity();

    // SAFETY: `ptr` and `cap` come from a live `Vec<MaybeUninit<f64>>` that
    // will not be dropped; `MaybeUninit<f64>` has the same size and alignment
    // as `f64`; and `len <= cap`. In release builds the tail `warm..len` is
    // uninitialized, which is the documented contract of this function.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
135
/// Writes a quiet-NaN warm-up prefix into every row of a row-major matrix.
///
/// `buf` is interpreted as `buf.len() / cols` rows of `cols` cells each. For
/// row `r`, the first `warm_prefixes[r]` cells are set to the quiet NaN
/// `0x7ff8_0000_0000_0000`. The remaining cells are left untouched, except
/// in builds with `debug_assertions`, where the whole buffer is first filled
/// with the poison pattern `0x22222222_22222222`.
///
/// # Panics
///
/// * if `cols` is zero or `buf.len()` is not a multiple of `cols`;
/// * if `warm_prefixes.len()` differs from the number of rows;
/// * if any prefix is larger than `cols` (rows before the offending one have
///   already been written at that point).
#[inline]
pub fn init_matrix_prefixes(buf: &mut [MaybeUninit<f64>], cols: usize, warm_prefixes: &[usize]) {
    let divides_evenly = cols != 0 && buf.len() % cols == 0;
    assert!(divides_evenly, "`buf` length must be a multiple of `cols`");

    let row_count = buf.len() / cols;
    assert_eq!(
        row_count,
        warm_prefixes.len(),
        "`warm_prefixes` length must equal number of rows"
    );

    #[cfg(debug_assertions)]
    {
        // Poison everything so reads of never-written cells stand out.
        buf.fill(MaybeUninit::new(f64::from_bits(0x22222222_22222222)));
    }

    let quiet_nan = MaybeUninit::new(f64::from_bits(0x7ff8_0000_0000_0000));
    for (row, &warm) in buf.chunks_exact_mut(cols).zip(warm_prefixes) {
        assert!(warm <= cols, "warm prefix exceeds row width");
        row[..warm].fill(quiet_nan);
    }
}
165
/// Allocates backing storage for a `rows` x `cols` matrix of `f64` cells.
///
/// The returned vector has length `rows * cols`. In builds without
/// `debug_assertions` the cells are left uninitialized; with
/// `debug_assertions` every cell holds the poison pattern
/// `0x33333333_33333333`.
///
/// # Panics
///
/// * if `rows * cols` overflows `usize`;
/// * if the allocation cannot be reserved.
#[inline]
pub fn make_uninit_matrix(rows: usize, cols: usize) -> Vec<MaybeUninit<f64>> {
    let cell_count = match rows.checked_mul(cols) {
        Some(n) => n,
        None => panic!("rows * cols overflowed usize"),
    };

    let mut matrix: Vec<MaybeUninit<f64>> = Vec::new();
    if let Err(err) = matrix.try_reserve_exact(cell_count) {
        panic!("OOM in make_uninit_matrix: {err:?}");
    }

    #[cfg(not(debug_assertions))]
    {
        // SAFETY: the reservation above guarantees capacity for `cell_count`
        // elements, and `MaybeUninit<f64>` may hold uninitialized memory.
        unsafe {
            matrix.set_len(cell_count);
        }
    }

    #[cfg(debug_assertions)]
    {
        // Capacity is already reserved, so this never reallocates.
        matrix.resize(
            cell_count,
            MaybeUninit::new(f64::from_bits(0x33333333_33333333)),
        );
    }

    matrix
}