1use super::dispatch::cpu_features;
32
33#[inline]
45pub fn bps_scan(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u16]) {
46 assert!(query.len() >= n_blocks, "query too short");
47 assert!(out.len() >= n_vec, "output buffer too small");
48
49 let features = cpu_features();
50
51 #[cfg(target_arch = "x86_64")]
52 {
53 if features.has_avx2 {
54 unsafe { bps_scan_avx2(bps, n_vec, n_blocks, query, out) };
56 return;
57 }
58 }
59
60 #[cfg(target_arch = "aarch64")]
61 {
62 if features.has_neon {
63 unsafe { bps_scan_neon(bps, n_vec, n_blocks, query, out) };
65 return;
66 }
67 }
68
69 bps_scan_scalar(bps, n_vec, n_blocks, query, out);
71}
72
73#[inline]
77pub fn bps_scan_u32(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u32]) {
78 assert!(query.len() >= n_blocks, "query too short");
79 assert!(out.len() >= n_vec, "output buffer too small");
80
81 let features = cpu_features();
82
83 #[cfg(target_arch = "x86_64")]
84 {
85 if features.has_avx2 {
86 unsafe { bps_scan_avx2_u32(bps, n_vec, n_blocks, query, out) };
87 return;
88 }
89 }
90
91 #[cfg(target_arch = "aarch64")]
92 {
93 if features.has_neon {
94 unsafe { bps_scan_neon_u32(bps, n_vec, n_blocks, query, out) };
95 return;
96 }
97 }
98
99 bps_scan_scalar_u32(bps, n_vec, n_blocks, query, out);
100}
101
/// AVX2 kernel: per-vector sum of absolute byte differences, saturating at
/// `u16::MAX`.
///
/// Vectors are processed 32 at a time; for each block `slot` the 32 bytes at
/// `bps[slot * n_vec + chunk_start..]` are compared with `query[slot]`.
/// Vectors left over after the last full group of 32 are handled by a scalar
/// tail. Both paths saturate, so every element of `out[..n_vec]` follows the
/// same overflow rule regardless of where it falls.
///
/// # Safety
///
/// * The CPU must support AVX2.
/// * `bps.len() >= n_vec * n_blocks` and `out.len() >= n_vec`: the vector
///   loads and stores go through raw pointers without bounds checks.
///
/// # Panics
///
/// Panics if `query.len() < n_blocks` and at least one vector is processed.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bps_scan_avx2(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u16]) {
    use std::arch::x86_64::*;

    // Number of database vectors handled per 256-bit load.
    const LANES: usize = 32;

    debug_assert!(out.len() >= n_vec, "output buffer too small");
    debug_assert!(
        matches!(n_vec.checked_mul(n_blocks), Some(need) if bps.len() >= need),
        "bps buffer too small"
    );

    let vec_aligned = n_vec - n_vec % LANES;

    // No up-front zeroing of `out` is needed: the vector loop overwrites
    // `out[..vec_aligned]` and the tail overwrites `out[vec_aligned..n_vec]`.

    // SAFETY: the caller guarantees AVX2 and the buffer sizes listed above.
    // Every load reads 32 bytes starting at `slot * n_vec + chunk_start` with
    // `chunk_start + 32 <= n_vec`, hence inside `bps[..n_vec * n_blocks]`;
    // every store writes 16 `u16`s inside `out[..n_vec]`.
    unsafe {
        for chunk_start in (0..vec_aligned).step_by(LANES) {
            let mut acc_lo = _mm256_setzero_si256();
            let mut acc_hi = _mm256_setzero_si256();

            for slot in 0..n_blocks {
                let base = slot * n_vec + chunk_start;
                let v = _mm256_loadu_si256(bps.as_ptr().add(base) as *const __m256i);
                let qv = _mm256_set1_epi8(query[slot] as i8);

                // |v - q| on unsigned bytes: one of the two saturating
                // subtractions is zero, the other is the distance.
                let diff = _mm256_or_si256(_mm256_subs_epu8(v, qv), _mm256_subs_epu8(qv, v));

                // Zero-extend the low / high 16 bytes to 16-bit lanes.
                let lo16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(diff));
                let hi16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(diff, 1));

                // Unsigned saturating add, matching `saturating_add` in the
                // scalar tail and the scalar reference implementation.
                acc_lo = _mm256_adds_epu16(acc_lo, lo16);
                acc_hi = _mm256_adds_epu16(acc_hi, hi16);
            }

            let dst = out.as_mut_ptr().add(chunk_start);
            _mm256_storeu_si256(dst as *mut __m256i, acc_lo);
            _mm256_storeu_si256(dst.add(16) as *mut __m256i, acc_hi);
        }
    }

    // Scalar tail for the final `n_vec % 32` vectors.
    for i in vec_aligned..n_vec {
        let mut sum: u16 = 0;
        for slot in 0..n_blocks {
            let diff = bps[slot * n_vec + i].abs_diff(query[slot]);
            sum = sum.saturating_add(u16::from(diff));
        }
        out[i] = sum;
    }
}
172
/// AVX2 kernel: per-vector sum of absolute byte differences, written as `u32`.
///
/// Vectors are processed 32 at a time; for each block `slot` the 32 bytes at
/// `bps[slot * n_vec + chunk_start..]` are compared with `query[slot]`.
/// Differences are first summed in 16-bit lanes and flushed into 32-bit
/// accumulators every 256 blocks, so the 16-bit lanes can never overflow and
/// the result agrees with the scalar tail for any `n_blocks`.
///
/// # Safety
///
/// * The CPU must support AVX2.
/// * `bps.len() >= n_vec * n_blocks` and `out.len() >= n_vec`: the vector
///   loads and stores go through raw pointers without bounds checks.
///
/// # Panics
///
/// Panics if `query.len() < n_blocks` and at least one vector is processed.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bps_scan_avx2_u32(
    bps: &[u8],
    n_vec: usize,
    n_blocks: usize,
    query: &[u8],
    out: &mut [u32],
) {
    use std::arch::x86_64::*;

    // Number of database vectors handled per 256-bit load.
    const LANES: usize = 32;
    // 256 * 255 = 65280 <= u16::MAX, so a 16-bit lane cannot wrap within one
    // group of this many blocks.
    const SLOTS_PER_FLUSH: usize = 256;

    debug_assert!(out.len() >= n_vec, "output buffer too small");
    debug_assert!(
        matches!(n_vec.checked_mul(n_blocks), Some(need) if bps.len() >= need),
        "bps buffer too small"
    );

    let vec_aligned = n_vec - n_vec % LANES;

    // No up-front zeroing of `out` is needed: the vector loop overwrites
    // `out[..vec_aligned]` and the tail overwrites `out[vec_aligned..n_vec]`.

    // SAFETY: the caller guarantees AVX2 and the buffer sizes listed above.
    // Every load reads 32 bytes starting at `slot * n_vec + chunk_start` with
    // `chunk_start + 32 <= n_vec`, hence inside `bps[..n_vec * n_blocks]`;
    // the four stores write 32 consecutive `u32`s inside `out[..n_vec]`.
    unsafe {
        for chunk_start in (0..vec_aligned).step_by(LANES) {
            // 4 x 8 u32 lanes covering vectors chunk_start..chunk_start + 32.
            let mut sum0 = _mm256_setzero_si256();
            let mut sum1 = _mm256_setzero_si256();
            let mut sum2 = _mm256_setzero_si256();
            let mut sum3 = _mm256_setzero_si256();

            for group_start in (0..n_blocks).step_by(SLOTS_PER_FLUSH) {
                let group_end = group_start.saturating_add(SLOTS_PER_FLUSH).min(n_blocks);

                let mut acc_lo = _mm256_setzero_si256();
                let mut acc_hi = _mm256_setzero_si256();

                for slot in group_start..group_end {
                    let base = slot * n_vec + chunk_start;
                    let v = _mm256_loadu_si256(bps.as_ptr().add(base) as *const __m256i);
                    let qv = _mm256_set1_epi8(query[slot] as i8);

                    // |v - q| on unsigned bytes: one of the two saturating
                    // subtractions is zero, the other is the distance.
                    let diff =
                        _mm256_or_si256(_mm256_subs_epu8(v, qv), _mm256_subs_epu8(qv, v));

                    let lo16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(diff));
                    let hi16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(diff, 1));

                    acc_lo = _mm256_add_epi16(acc_lo, lo16);
                    acc_hi = _mm256_add_epi16(acc_hi, hi16);
                }

                // Widen the 16-bit partial sums to 32 bits and fold them in.
                sum0 = _mm256_add_epi32(
                    sum0,
                    _mm256_cvtepu16_epi32(_mm256_castsi256_si128(acc_lo)),
                );
                sum1 = _mm256_add_epi32(
                    sum1,
                    _mm256_cvtepu16_epi32(_mm256_extracti128_si256(acc_lo, 1)),
                );
                sum2 = _mm256_add_epi32(
                    sum2,
                    _mm256_cvtepu16_epi32(_mm256_castsi256_si128(acc_hi)),
                );
                sum3 = _mm256_add_epi32(
                    sum3,
                    _mm256_cvtepu16_epi32(_mm256_extracti128_si256(acc_hi, 1)),
                );
            }

            let dst = out.as_mut_ptr().add(chunk_start);
            _mm256_storeu_si256(dst as *mut __m256i, sum0);
            _mm256_storeu_si256(dst.add(8) as *mut __m256i, sum1);
            _mm256_storeu_si256(dst.add(16) as *mut __m256i, sum2);
            _mm256_storeu_si256(dst.add(24) as *mut __m256i, sum3);
        }
    }

    // Scalar tail for the final `n_vec % 32` vectors.
    for i in vec_aligned..n_vec {
        let mut sum: u32 = 0;
        for slot in 0..n_blocks {
            let diff = bps[slot * n_vec + i].abs_diff(query[slot]);
            sum += u32::from(diff);
        }
        out[i] = sum;
    }
}
258
/// NEON kernel: per-vector sum of absolute byte differences, saturating at
/// `u16::MAX`.
///
/// Vectors are processed 16 at a time; for each block `slot` the 16 bytes at
/// `bps[slot * n_vec + chunk_start..]` are compared with `query[slot]`.
/// Vectors left over after the last full group of 16 are handled by a scalar
/// tail. Both paths saturate, so every element of `out[..n_vec]` follows the
/// same overflow rule regardless of where it falls.
///
/// # Safety
///
/// * The CPU must support NEON.
/// * `bps.len() >= n_vec * n_blocks` and `out.len() >= n_vec`: the vector
///   loads and stores go through raw pointers without bounds checks.
///
/// # Panics
///
/// Panics if `query.len() < n_blocks` and at least one vector is processed.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn bps_scan_neon(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u16]) {
    use std::arch::aarch64::*;

    // Number of database vectors handled per 128-bit load.
    const LANES: usize = 16;

    debug_assert!(out.len() >= n_vec, "output buffer too small");
    debug_assert!(
        matches!(n_vec.checked_mul(n_blocks), Some(need) if bps.len() >= need),
        "bps buffer too small"
    );

    let vec_aligned = n_vec - n_vec % LANES;

    // No up-front zeroing of `out` is needed: the vector loop overwrites
    // `out[..vec_aligned]` and the tail overwrites `out[vec_aligned..n_vec]`.

    // SAFETY: the caller guarantees NEON and the buffer sizes listed above.
    // Every load reads 16 bytes starting at `slot * n_vec + chunk_start` with
    // `chunk_start + 16 <= n_vec`, hence inside `bps[..n_vec * n_blocks]`;
    // the two stores write 16 consecutive `u16`s inside `out[..n_vec]`.
    unsafe {
        for chunk_start in (0..vec_aligned).step_by(LANES) {
            let mut acc_lo = vdupq_n_u16(0);
            let mut acc_hi = vdupq_n_u16(0);

            for slot in 0..n_blocks {
                let base = slot * n_vec + chunk_start;
                let q = vdupq_n_u8(query[slot]);
                let db = vld1q_u8(bps.as_ptr().add(base));

                // Per-byte absolute difference.
                let diff = vabdq_u8(q, db);

                // Widen to 16 bits, then add with unsigned saturation to
                // match `saturating_add` in the scalar tail and the scalar
                // reference implementation.
                acc_lo = vqaddq_u16(acc_lo, vmovl_u8(vget_low_u8(diff)));
                acc_hi = vqaddq_u16(acc_hi, vmovl_u8(vget_high_u8(diff)));
            }

            let dst = out.as_mut_ptr().add(chunk_start);
            vst1q_u16(dst, acc_lo);
            vst1q_u16(dst.add(8), acc_hi);
        }
    }

    // Scalar tail for the final `n_vec % 16` vectors.
    for i in vec_aligned..n_vec {
        let mut sum: u16 = 0;
        for slot in 0..n_blocks {
            let diff = bps[slot * n_vec + i].abs_diff(query[slot]);
            sum = sum.saturating_add(u16::from(diff));
        }
        out[i] = sum;
    }
}
315
/// NEON kernel: per-vector sum of absolute byte differences, written as `u32`.
///
/// Vectors are processed 16 at a time; for each block `slot` the 16 bytes at
/// `bps[slot * n_vec + chunk_start..]` are compared with `query[slot]`.
/// Differences are first summed in 16-bit lanes and flushed into 32-bit
/// accumulators every 256 blocks, so the 16-bit lanes can never overflow and
/// the result agrees with the scalar tail for any `n_blocks`.
///
/// # Safety
///
/// * The CPU must support NEON.
/// * `bps.len() >= n_vec * n_blocks` and `out.len() >= n_vec`: the vector
///   loads and stores go through raw pointers without bounds checks.
///
/// # Panics
///
/// Panics if `query.len() < n_blocks` and at least one vector is processed.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn bps_scan_neon_u32(
    bps: &[u8],
    n_vec: usize,
    n_blocks: usize,
    query: &[u8],
    out: &mut [u32],
) {
    use std::arch::aarch64::*;

    // Number of database vectors handled per 128-bit load.
    const LANES: usize = 16;
    // 256 * 255 = 65280 <= u16::MAX, so a 16-bit lane cannot wrap within one
    // group of this many blocks.
    const SLOTS_PER_FLUSH: usize = 256;

    debug_assert!(out.len() >= n_vec, "output buffer too small");
    debug_assert!(
        matches!(n_vec.checked_mul(n_blocks), Some(need) if bps.len() >= need),
        "bps buffer too small"
    );

    let vec_aligned = n_vec - n_vec % LANES;

    // No up-front zeroing of `out` is needed: the vector loop overwrites
    // `out[..vec_aligned]` and the tail overwrites `out[vec_aligned..n_vec]`.

    // SAFETY: the caller guarantees NEON and the buffer sizes listed above.
    // Every load reads 16 bytes starting at `slot * n_vec + chunk_start` with
    // `chunk_start + 16 <= n_vec`, hence inside `bps[..n_vec * n_blocks]`;
    // the four stores write 16 consecutive `u32`s inside `out[..n_vec]`.
    unsafe {
        for chunk_start in (0..vec_aligned).step_by(LANES) {
            // 4 x 4 u32 lanes covering vectors chunk_start..chunk_start + 16.
            let mut sum0 = vdupq_n_u32(0);
            let mut sum1 = vdupq_n_u32(0);
            let mut sum2 = vdupq_n_u32(0);
            let mut sum3 = vdupq_n_u32(0);

            for group_start in (0..n_blocks).step_by(SLOTS_PER_FLUSH) {
                let group_end = group_start.saturating_add(SLOTS_PER_FLUSH).min(n_blocks);

                let mut acc_lo = vdupq_n_u16(0);
                let mut acc_hi = vdupq_n_u16(0);

                for slot in group_start..group_end {
                    let base = slot * n_vec + chunk_start;
                    let q = vdupq_n_u8(query[slot]);
                    let db = vld1q_u8(bps.as_ptr().add(base));

                    // Per-byte absolute difference.
                    let diff = vabdq_u8(q, db);

                    // Widening add of the low / high 8 bytes into 16-bit lanes.
                    acc_lo = vaddw_u8(acc_lo, vget_low_u8(diff));
                    acc_hi = vaddw_u8(acc_hi, vget_high_u8(diff));
                }

                // Widen the 16-bit partial sums to 32 bits and fold them in.
                sum0 = vaddw_u16(sum0, vget_low_u16(acc_lo));
                sum1 = vaddw_u16(sum1, vget_high_u16(acc_lo));
                sum2 = vaddw_u16(sum2, vget_low_u16(acc_hi));
                sum3 = vaddw_u16(sum3, vget_high_u16(acc_hi));
            }

            let dst = out.as_mut_ptr().add(chunk_start);
            vst1q_u32(dst, sum0);
            vst1q_u32(dst.add(4), sum1);
            vst1q_u32(dst.add(8), sum2);
            vst1q_u32(dst.add(12), sum3);
        }
    }

    // Scalar tail for the final `n_vec % 16` vectors.
    for i in vec_aligned..n_vec {
        let mut sum: u32 = 0;
        for slot in 0..n_blocks {
            let diff = bps[slot * n_vec + i].abs_diff(query[slot]);
            sum += u32::from(diff);
        }
        out[i] = sum;
    }
}
370
/// Portable fallback: per-vector sum of absolute byte differences, saturating
/// at `u16::MAX`.
///
/// Clears the first `n_vec` entries of `out` (fewer if `out` is shorter), then
/// walks the blocks in order. For block `block`, the bytes
/// `bps[block * n_vec..][..n_vec]` are compared with `query[block]` and the
/// absolute difference is added to the matching entry of `out`.
///
/// # Panics
///
/// Panics on out-of-range indexing when `query`, `bps` or `out` is too short
/// for the requested `n_vec` / `n_blocks`.
#[inline]
fn bps_scan_scalar(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u16]) {
    let clear_len = n_vec.min(out.len());
    out[..clear_len].fill(0);

    for block in 0..n_blocks {
        let target = query[block];
        let row_offset = block * n_vec;

        for idx in 0..n_vec {
            let delta = bps[row_offset + idx].abs_diff(target);
            out[idx] = out[idx].saturating_add(u16::from(delta));
        }
    }
}
392
/// Portable fallback: per-vector sum of absolute byte differences, accumulated
/// in `u32`.
///
/// Clears the first `n_vec` entries of `out` (fewer if `out` is shorter), then
/// walks the blocks in order. For block `block`, the bytes
/// `bps[block * n_vec..][..n_vec]` are compared with `query[block]` and the
/// absolute difference is added to the matching entry of `out`.
///
/// # Panics
///
/// Panics on out-of-range indexing when `query`, `bps` or `out` is too short
/// for the requested `n_vec` / `n_blocks`.
#[inline]
fn bps_scan_scalar_u32(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u32]) {
    let clear_len = n_vec.min(out.len());
    out[..clear_len].fill(0);

    for block in 0..n_blocks {
        let target = query[block];
        let row_offset = block * n_vec;

        for idx in 0..n_vec {
            let delta = bps[row_offset + idx].abs_diff(target);
            out[idx] += u32::from(delta);
        }
    }
}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector counts that straddle the 16- and 32-wide SIMD group sizes so
    /// both the vector body and the scalar tail of each kernel are exercised.
    const TAIL_SIZES: [usize; 8] = [1, 15, 17, 31, 33, 63, 65, 127];

    /// Builds a deterministic `n_vec * n_blocks` byte database (0, 1, 2, ...
    /// wrapping at 256).
    fn make_bps(n_vec: usize, n_blocks: usize) -> Vec<u8> {
        (0..n_vec * n_blocks).map(|i| (i % 256) as u8).collect()
    }

    /// The dispatched u16 scan must agree with the scalar reference.
    #[test]
    fn test_bps_scan_basic() {
        let n_vec = 100;
        let n_blocks = 8;
        let bps = make_bps(n_vec, n_blocks);
        let query: Vec<u8> = (0..n_blocks).map(|i| (i * 10) as u8).collect();
        let mut out = vec![0u16; n_vec];

        bps_scan(&bps, n_vec, n_blocks, &query, &mut out);

        let mut expected = vec![0u16; n_vec];
        bps_scan_scalar(&bps, n_vec, n_blocks, &query, &mut expected);

        assert_eq!(out, expected);
    }

    /// The dispatched u32 scan must agree with the scalar reference.
    #[test]
    fn test_bps_scan_u32_basic() {
        let n_vec = 100;
        let n_blocks = 8;
        let bps = make_bps(n_vec, n_blocks);
        let query: Vec<u8> = (0..n_blocks).map(|i| (i * 10) as u8).collect();
        let mut out = vec![0u32; n_vec];

        bps_scan_u32(&bps, n_vec, n_blocks, &query, &mut out);

        let mut expected = vec![0u32; n_vec];
        bps_scan_scalar_u32(&bps, n_vec, n_blocks, &query, &mut expected);

        assert_eq!(out, expected);
    }

    /// u16 scan with vector counts that are not multiples of the SIMD width.
    #[test]
    fn test_bps_scan_alignment() {
        for n_vec in TAIL_SIZES {
            let n_blocks = 4;
            let bps = make_bps(n_vec, n_blocks);
            let query: Vec<u8> = vec![128; n_blocks];
            let mut out = vec![0u16; n_vec];

            bps_scan(&bps, n_vec, n_blocks, &query, &mut out);

            let mut expected = vec![0u16; n_vec];
            bps_scan_scalar(&bps, n_vec, n_blocks, &query, &mut expected);

            assert_eq!(out, expected, "Mismatch for n_vec={}", n_vec);
        }
    }

    /// u32 scan with vector counts that are not multiples of the SIMD width.
    #[test]
    fn test_bps_scan_u32_alignment() {
        for n_vec in TAIL_SIZES {
            let n_blocks = 4;
            let bps = make_bps(n_vec, n_blocks);
            let query: Vec<u8> = vec![128; n_blocks];
            let mut out = vec![0u32; n_vec];

            bps_scan_u32(&bps, n_vec, n_blocks, &query, &mut out);

            let mut expected = vec![0u32; n_vec];
            bps_scan_scalar_u32(&bps, n_vec, n_blocks, &query, &mut expected);

            assert_eq!(out, expected, "Mismatch for n_vec={}", n_vec);
        }
    }

    /// A database shorter than `n_vec * n_blocks` must panic rather than be
    /// read out of bounds. `n_vec` is kept below every SIMD group size so the
    /// access is always a bounds-checked index.
    #[test]
    #[should_panic]
    fn test_bps_scan_rejects_short_bps() {
        let n_vec = 4;
        let n_blocks = 2;
        let bps = vec![0u8; n_vec]; // one block's worth, two requested
        let query = vec![0u8; n_blocks];
        let mut out = vec![0u16; n_vec];

        bps_scan(&bps, n_vec, n_blocks, &query, &mut out);
    }
}