vector_ta/utilities/
helpers.rs1use crate::utilities::enums::Kernel;
2use aligned_vec::AVec;
3use std::arch::is_x86_feature_detected;
4use std::sync::OnceLock;
5use std::{mem::MaybeUninit, ptr, slice};
6
7static BEST_SINGLE: OnceLock<Kernel> = OnceLock::new();
8static BEST_BATCH: OnceLock<Kernel> = OnceLock::new();
9
10#[inline(always)]
11pub fn detect_best_kernel() -> Kernel {
12 *BEST_SINGLE.get_or_init(|| {
13 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
14 {
15 if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma") {
16 return Kernel::Avx512;
17 }
18 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
19 return Kernel::Avx2;
20 }
21 }
22
23 Kernel::Scalar
24 })
25}
26
27#[inline(always)]
28pub fn detect_best_batch_kernel() -> Kernel {
29 *BEST_BATCH.get_or_init(|| match detect_best_kernel() {
30 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
31 Kernel::Avx512 => Kernel::Avx512Batch,
32 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
33 Kernel::Avx2 => Kernel::Avx2Batch,
34 _ => Kernel::ScalarBatch,
35 })
36}
37
38#[cfg(target_arch = "wasm32")]
39static BEST_WASM: OnceLock<Kernel> = OnceLock::new();
40
41#[cfg(target_arch = "wasm32")]
42#[inline(always)]
43pub fn detect_wasm_kernel() -> Kernel {
44 *BEST_WASM.get_or_init(|| {
45 #[cfg(target_feature = "simd128")]
46 {
47 return Kernel::Scalar;
48 }
49
50 Kernel::Scalar
51 })
52}
53
54#[cfg(not(target_arch = "wasm32"))]
55#[inline(always)]
56pub fn detect_wasm_kernel() -> Kernel {
57 Kernel::Scalar
58}
59
/// Early-returns `Ok(())` from the enclosing function when `$kernel` cannot run in this
/// build or on this CPU, after printing a "skipped" notice to stderr.
///
/// - Built without `nightly-avx` on `x86_64`: the AVX2 / AVX-512 kernels (single and
///   batch) are always skipped.
/// - Built with `nightly-avx` on `x86_64`: AVX-512 kernels need AVX-512F + FMA and AVX2
///   kernels need AVX2 + FMA at runtime; any other kernel runs.
///
/// The enclosing function must return `Result<(), E>`. `$test_name` is printed with `{}`
/// and the kernel with `{:?}`. `$kernel` is evaluated exactly once and only borrowed, so
/// it may be any expression, including a reference to a `Kernel`.
#[macro_export]
macro_rules! skip_if_unsupported {
    ($kernel:expr, $test_name:expr) => {{
        use $crate::utilities::enums::Kernel;

        // Evaluate the argument once and borrow it: it is both inspected and printed below.
        let kernel_ref = &$kernel;

        #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
        {
            if matches!(
                kernel_ref,
                Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch
            ) {
                eprintln!(
                    "[{}] skipped {:?} – compiled without `nightly-avx`",
                    $test_name, kernel_ref
                );
                return Ok(());
            }
        }

        #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
        {
            // Imported only in the configuration that uses it, so other builds do not get
            // an `unused_imports` warning at every expansion site.
            use std::arch::is_x86_feature_detected;

            let need: (&'static str, fn() -> bool) = match kernel_ref {
                Kernel::Avx512 | Kernel::Avx512Batch => ("AVX-512F + FMA", || {
                    is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma")
                }),
                Kernel::Avx2 | Kernel::Avx2Batch => ("AVX2 + FMA", || {
                    is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
                }),
                _ => ("", || true),
            };

            if !(need.1)() {
                eprintln!(
                    "[{}] skipped {:?} - CPU lacks {}",
                    $test_name, kernel_ref, need.0
                );
                return Ok(());
            }
        }
    }};
}
/// Allocates a `Vec<f64>` of length `len` whose first `warm` entries are NaN.
///
/// `warm` is clamped to `len`, so a prefix longer than the buffer yields an all-NaN vector.
/// The remaining `len - warm` entries are placeholders that the caller is expected to
/// overwrite:
///
/// - debug builds fill them with the poison pattern `0x1111_1111_1111_1111`, so tests can
///   detect slots that were never written;
/// - release builds fill them with `0.0`.
///
/// The release tail is deliberately *initialized*. A safe function must not return a
/// `Vec<f64>` with uninitialized elements: `Vec::from_raw_parts` / `set_len` require the
/// first `len` elements to be initialized, and reading an uninitialized `f64` is undefined
/// behaviour. `vec![0.0; len]` may be served by a zeroed allocation, so this stays cheap.
/// Callers must not rely on any particular tail value in either build mode.
#[inline(always)]
pub fn alloc_with_nan_prefix(len: usize, warm: usize) -> Vec<f64> {
    let warm = warm.min(len);

    let mut out = if cfg!(debug_assertions) {
        vec![f64::from_bits(0x11111111_11111111); len]
    } else {
        vec![0.0; len]
    };

    // Quiet NaN written via its exact bit pattern so the prefix is bit-for-bit stable.
    out[..warm].fill(f64::from_bits(0x7ff8_0000_0000_0000));
    out
}
135
/// Writes a NaN warm-up prefix at the start of every row of a row-major matrix buffer.
///
/// `buf` is viewed as rows of `cols` cells, one row per entry of `warm_prefixes`. In row
/// `r`, the first `warm_prefixes[r]` cells are set to NaN (bits `0x7ff8_0000_0000_0000`).
/// Cells after the prefix are left untouched in release builds. Debug builds first
/// overwrite the whole buffer with the poison pattern `0x2222_2222_2222_2222`.
///
/// # Panics
/// - if `cols == 0` or `buf.len()` is not a multiple of `cols`;
/// - if `warm_prefixes.len()` differs from the number of rows;
/// - if any prefix exceeds `cols` (earlier rows have already been written by then).
#[inline]
pub fn init_matrix_prefixes(buf: &mut [MaybeUninit<f64>], cols: usize, warm_prefixes: &[usize]) {
    // `cols != 0` is checked first so the modulo below can never divide by zero.
    let shape_ok = cols != 0 && buf.len() % cols == 0;
    assert!(shape_ok, "`buf` length must be a multiple of `cols`");

    let row_count = buf.len() / cols;
    assert_eq!(
        row_count,
        warm_prefixes.len(),
        "`warm_prefixes` length must equal number of rows"
    );

    if cfg!(debug_assertions) {
        // Poison every cell so a slot that is never written is easy to spot later.
        buf.fill(MaybeUninit::new(f64::from_bits(0x22222222_22222222)));
    }

    let nan_cell = MaybeUninit::new(f64::from_bits(0x7ff8_0000_0000_0000));
    for (row, &warm) in buf.chunks_exact_mut(cols).zip(warm_prefixes) {
        assert!(warm <= cols, "warm prefix exceeds row width");
        row[..warm].fill(nan_cell);
    }
}
165
/// Allocates a `rows * cols` buffer of `MaybeUninit<f64>` cells, sized exactly.
///
/// Release builds leave every cell uninitialized. Debug builds fill each cell with the
/// poison pattern `0x3333_3333_3333_3333`, so reads of never-written cells stand out.
///
/// # Panics
/// Panics if `rows * cols` overflows `usize`, or if reserving the buffer fails.
#[inline]
pub fn make_uninit_matrix(rows: usize, cols: usize) -> Vec<MaybeUninit<f64>> {
    let cells = rows
        .checked_mul(cols)
        .expect("rows * cols overflowed usize");

    let mut matrix: Vec<MaybeUninit<f64>> = Vec::new();
    matrix
        .try_reserve_exact(cells)
        .expect("OOM in make_uninit_matrix");

    if cfg!(debug_assertions) {
        // Capacity is already reserved, so this fills in place without reallocating.
        matrix.resize(cells, MaybeUninit::new(f64::from_bits(0x33333333_33333333)));
    } else {
        // SAFETY: capacity >= `cells` was reserved above, and `MaybeUninit<f64>` is valid in
        // any state (including uninitialized), so exposing these slots is sound.
        unsafe { matrix.set_len(cells) };
    }

    matrix
}
191
/// Allocates a `Vec<f64>` of length `len` whose contents the caller will overwrite.
///
/// Debug builds fill it with the poison pattern `0x1111_1111_1111_1111`, so tests can catch
/// slots that were never written. Release builds fill it with `0.0`, which `vec![0.0; len]`
/// may serve from a zeroed allocation.
///
/// Despite the name, the contents are always initialized. A safe function must not return
/// a `Vec<f64>` whose elements are uninitialized: `Vec::set_len` requires the exposed
/// elements to be initialized, and reading an uninitialized `f64` is undefined behaviour.
/// Callers should use `MaybeUninit` buffers (see `make_uninit_matrix`) when they truly
/// need uninitialized storage. Callers must not rely on any particular value here.
#[inline(always)]
pub fn alloc_uninit_f64(len: usize) -> Vec<f64> {
    if cfg!(debug_assertions) {
        vec![f64::from_bits(0x11111111_11111111); len]
    } else {
        vec![0.0; len]
    }
}