Skip to main content

sochdb_vector/simd/
bps_scan.rs

1//! BPS (Block Projection Sketch) L1 Distance Kernel
2//!
3//! This implements the vertical SIMD approach for computing L1 distances
4//! between a query sketch and many vector sketches stored in SoA layout.
5//!
6//! # Algorithm
7//!
8//! For each query sketch Q[0..n_blocks]:
9//!     For each vector V[i] in SoA layout:
10//!         distance[i] = Σ |Q[slot] - V[slot * n_vec + i]|
11//!
12//! # Memory Layout
13//!
14//! The BPS data uses Structure-of-Arrays (SoA) layout:
15//! - `bps[slot * n_vec + vec_id]` gives the sketch value for vector `vec_id` at `slot`
16//!
17//! # SIMD Strategy
18//!
19//! - **AVX2**: Process 32 vectors per iteration using 256-bit registers
20//! - **NEON**: Process 16 vectors per iteration using 128-bit registers
21//! - **Scalar**: Fallback for unsupported platforms
22//!
23//! # Math
24//!
25//! The L1 distance uses the identity:
26//! ```text
27//! |a - b| = max(a - b, 0) + max(b - a, 0) = (a ⊖ b) ∨ (b ⊖ a)
28//! ```
29//! where `⊖` is saturating subtraction and `∨` is bitwise OR.
30
31use super::dispatch::cpu_features;
32
33/// Compute BPS L1 distances between query and database vectors.
34///
35/// # Arguments
36/// * `bps` - BPS data in SoA layout: `bps[slot * n_vec + vec_id]`
37/// * `n_vec` - Number of vectors in the database
38/// * `n_blocks` - Number of blocks in each sketch
39/// * `query` - Query sketch values
40/// * `out` - Output distances (u16)
41///
42/// # Panics
43/// Panics if `query.len() < n_blocks` or `out.len() < n_vec`
44#[inline]
45pub fn bps_scan(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u16]) {
46    assert!(query.len() >= n_blocks, "query too short");
47    assert!(out.len() >= n_vec, "output buffer too small");
48
49    let features = cpu_features();
50
51    #[cfg(target_arch = "x86_64")]
52    {
53        if features.has_avx2 {
54            // Safety: AVX2 feature is verified
55            unsafe { bps_scan_avx2(bps, n_vec, n_blocks, query, out) };
56            return;
57        }
58    }
59
60    #[cfg(target_arch = "aarch64")]
61    {
62        if features.has_neon {
63            // Safety: NEON is mandatory on aarch64
64            unsafe { bps_scan_neon(bps, n_vec, n_blocks, query, out) };
65            return;
66        }
67    }
68
69    // Scalar fallback
70    bps_scan_scalar(bps, n_vec, n_blocks, query, out);
71}
72
73/// Compute BPS L1 distances with u32 output.
74///
75/// Same as `bps_scan` but outputs u32 distances for larger accumulations.
76#[inline]
77pub fn bps_scan_u32(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u32]) {
78    assert!(query.len() >= n_blocks, "query too short");
79    assert!(out.len() >= n_vec, "output buffer too small");
80
81    let features = cpu_features();
82
83    #[cfg(target_arch = "x86_64")]
84    {
85        if features.has_avx2 {
86            unsafe { bps_scan_avx2_u32(bps, n_vec, n_blocks, query, out) };
87            return;
88        }
89    }
90
91    #[cfg(target_arch = "aarch64")]
92    {
93        if features.has_neon {
94            unsafe { bps_scan_neon_u32(bps, n_vec, n_blocks, query, out) };
95            return;
96        }
97    }
98
99    bps_scan_scalar_u32(bps, n_vec, n_blocks, query, out);
100}
101
102// ============================================================================
103// x86_64 AVX2 Implementation
104// ============================================================================
105
/// AVX2 kernel for the BPS scan with `u16` output.
///
/// Handles 32 vectors per iteration (one 256-bit load of `u8` values per
/// slot) and finishes the `n_vec % 32` tail with scalar code. Per-vector sums
/// saturate at `u16::MAX` in both the vector and the tail path, matching the
/// scalar fallback.
///
/// # Safety
/// * The CPU must support AVX2.
/// * `bps.len() >= n_blocks * n_vec` and `out.len() >= n_vec`; the main loop
///   loads from `bps` and stores to `out` through unchecked raw pointers.
///
/// # Panics
/// Panics if `query.len() < n_blocks` (the query is indexed with bounds checks).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bps_scan_avx2(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u16]) {
    use std::arch::x86_64::*;

    debug_assert!(query.len() >= n_blocks);
    debug_assert!(out.len() >= n_vec);
    debug_assert!(n_blocks
        .checked_mul(n_vec)
        .map_or(false, |need| bps.len() >= need));

    // Number of vectors covered by whole 32-wide chunks (256 bits / 8 bits = 32).
    let vec_aligned = (n_vec / 32) * 32;

    // No up-front zeroing: every entry in `0..n_vec` is overwritten below,
    // either by a chunk store or by the scalar tail.
    unsafe {
        for chunk_start in (0..vec_aligned).step_by(32) {
            // 32 running sums, held as 2 x 16 u16 lanes.
            let mut acc_lo = _mm256_setzero_si256(); // Vectors 0-15
            let mut acc_hi = _mm256_setzero_si256(); // Vectors 16-31

            for slot in 0..n_blocks {
                let base = slot * n_vec + chunk_start;

                // 32 consecutive sketch bytes for this slot, plus the
                // query byte broadcast to every lane.
                let v = _mm256_loadu_si256(bps.as_ptr().add(base) as *const __m256i);
                let qv = _mm256_set1_epi8(query[slot] as i8);

                // |a - b| = (a ⊖ b) ∨ (b ⊖ a): one of the two saturating
                // differences is always zero, so OR merges them.
                let diff = _mm256_or_si256(_mm256_subs_epu8(v, qv), _mm256_subs_epu8(qv, v));

                // Zero-extend each 128-bit half from u8 to u16.
                let lo16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(diff));
                let hi16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(diff, 1));

                // Saturating add: a wrapping add would silently overflow
                // once n_blocks exceeds 257 (257 * 255 = u16::MAX) and
                // disagree with the scalar paths.
                acc_lo = _mm256_adds_epu16(acc_lo, lo16);
                acc_hi = _mm256_adds_epu16(acc_hi, hi16);
            }

            let dst = out.as_mut_ptr().add(chunk_start);
            _mm256_storeu_si256(dst as *mut __m256i, acc_lo);
            _mm256_storeu_si256(dst.add(16) as *mut __m256i, acc_hi);
        }
    }

    // Scalar tail for the vectors that do not fill a whole chunk.
    for i in vec_aligned..n_vec {
        let mut sum: u16 = 0;
        for slot in 0..n_blocks {
            let v = bps[slot * n_vec + i];
            let qv = query[slot];
            let diff = if v > qv { v - qv } else { qv - v };
            sum = sum.saturating_add(u16::from(diff));
        }
        out[i] = sum;
    }
}
172
/// AVX2 kernel for the BPS scan with `u32` output.
///
/// Handles 32 vectors per iteration and finishes the `n_vec % 32` tail with
/// scalar code. Differences are first summed in `u16` lanes; those lanes are
/// flushed into `u32` accumulators every 257 slots so they cannot overflow.
///
/// # Safety
/// * The CPU must support AVX2.
/// * `bps.len() >= n_blocks * n_vec` and `out.len() >= n_vec`; the main loop
///   loads from `bps` and stores to `out` through unchecked raw pointers.
///
/// # Panics
/// Panics if `query.len() < n_blocks` (the query is indexed with bounds checks).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn bps_scan_avx2_u32(
    bps: &[u8],
    n_vec: usize,
    n_blocks: usize,
    query: &[u8],
    out: &mut [u32],
) {
    use std::arch::x86_64::*;

    // Largest number of u8 differences a u16 lane can absorb without
    // overflowing: 257 * 255 = 65_535 = u16::MAX.
    const SLOTS_PER_FLUSH: usize = 257;

    debug_assert!(query.len() >= n_blocks);
    debug_assert!(out.len() >= n_vec);
    debug_assert!(n_blocks
        .checked_mul(n_vec)
        .map_or(false, |need| bps.len() >= need));

    // Number of vectors covered by whole 32-wide chunks.
    let vec_aligned = (n_vec / 32) * 32;

    // No up-front zeroing: every entry in `0..n_vec` is overwritten below,
    // either by a chunk store or by the scalar tail.
    unsafe {
        for chunk_start in (0..vec_aligned).step_by(32) {
            // 32 u32 totals, 8 per register.
            let mut sum0 = _mm256_setzero_si256(); // Vectors 0-7
            let mut sum1 = _mm256_setzero_si256(); // Vectors 8-15
            let mut sum2 = _mm256_setzero_si256(); // Vectors 16-23
            let mut sum3 = _mm256_setzero_si256(); // Vectors 24-31

            let mut batch_start = 0;
            while batch_start < n_blocks {
                let batch_end = batch_start + (n_blocks - batch_start).min(SLOTS_PER_FLUSH);

                // Cheap u16 partial sums for this batch of slots.
                let mut acc_lo = _mm256_setzero_si256(); // Vectors 0-15 as u16
                let mut acc_hi = _mm256_setzero_si256(); // Vectors 16-31 as u16

                for slot in batch_start..batch_end {
                    let base = slot * n_vec + chunk_start;
                    let v = _mm256_loadu_si256(bps.as_ptr().add(base) as *const __m256i);
                    let qv = _mm256_set1_epi8(query[slot] as i8);

                    // |a - b| = (a ⊖ b) ∨ (b ⊖ a)
                    let diff =
                        _mm256_or_si256(_mm256_subs_epu8(v, qv), _mm256_subs_epu8(qv, v));

                    let lo16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(diff));
                    let hi16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(diff, 1));

                    acc_lo = _mm256_add_epi16(acc_lo, lo16);
                    acc_hi = _mm256_add_epi16(acc_hi, hi16);
                }

                // Flush: zero-extend the u16 partial sums to u32 and add
                // them to the running totals before they can wrap.
                sum0 = _mm256_add_epi32(
                    sum0,
                    _mm256_cvtepu16_epi32(_mm256_castsi256_si128(acc_lo)),
                );
                sum1 = _mm256_add_epi32(
                    sum1,
                    _mm256_cvtepu16_epi32(_mm256_extracti128_si256(acc_lo, 1)),
                );
                sum2 = _mm256_add_epi32(
                    sum2,
                    _mm256_cvtepu16_epi32(_mm256_castsi256_si128(acc_hi)),
                );
                sum3 = _mm256_add_epi32(
                    sum3,
                    _mm256_cvtepu16_epi32(_mm256_extracti128_si256(acc_hi, 1)),
                );

                batch_start = batch_end;
            }

            let dst = out.as_mut_ptr().add(chunk_start);
            _mm256_storeu_si256(dst as *mut __m256i, sum0);
            _mm256_storeu_si256(dst.add(8) as *mut __m256i, sum1);
            _mm256_storeu_si256(dst.add(16) as *mut __m256i, sum2);
            _mm256_storeu_si256(dst.add(24) as *mut __m256i, sum3);
        }
    }

    // Scalar tail for the vectors that do not fill a whole chunk.
    for i in vec_aligned..n_vec {
        let mut sum: u32 = 0;
        for slot in 0..n_blocks {
            let v = bps[slot * n_vec + i];
            let qv = query[slot];
            let diff = if v > qv { v - qv } else { qv - v };
            sum += u32::from(diff);
        }
        out[i] = sum;
    }
}
258
259// ============================================================================
260// aarch64 NEON Implementation
261// ============================================================================
262
/// NEON kernel for the BPS scan with `u16` output.
///
/// Handles 16 vectors per iteration (one 128-bit load of `u8` values per
/// slot) and finishes the `n_vec % 16` tail with scalar code. Per-vector sums
/// saturate at `u16::MAX` in both the vector and the tail path, matching the
/// scalar fallback.
///
/// # Safety
/// * The CPU must support NEON.
/// * `bps.len() >= n_blocks * n_vec` and `out.len() >= n_vec`; the main loop
///   loads from `bps` and stores to `out` through unchecked raw pointers.
///
/// # Panics
/// Panics if `query.len() < n_blocks` (the query is indexed with bounds checks).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn bps_scan_neon(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u16]) {
    use std::arch::aarch64::*;

    debug_assert!(query.len() >= n_blocks);
    debug_assert!(out.len() >= n_vec);
    debug_assert!(n_blocks
        .checked_mul(n_vec)
        .map_or(false, |need| bps.len() >= need));

    // Number of vectors covered by whole 16-wide chunks (128 bits / 8 bits = 16).
    let vec_aligned = (n_vec / 16) * 16;

    // No up-front zeroing: every entry in `0..n_vec` is overwritten below,
    // either by a chunk store or by the scalar tail.
    unsafe {
        for chunk_start in (0..vec_aligned).step_by(16) {
            // 16 running sums, held as 2 x 8 u16 lanes.
            let mut acc_lo = vdupq_n_u16(0); // Vectors 0-7
            let mut acc_hi = vdupq_n_u16(0); // Vectors 8-15

            for slot in 0..n_blocks {
                let base = slot * n_vec + chunk_start;

                // Query byte broadcast to every lane, and 16 database bytes.
                let q = vdupq_n_u8(query[slot]);
                let db = vld1q_u8(bps.as_ptr().add(base));

                // |q - db| in a single instruction.
                let diff = vabdq_u8(q, db);

                // Widen to u16 and add with saturation: a wrapping add
                // would silently overflow once n_blocks exceeds 257
                // (257 * 255 = u16::MAX) and disagree with the scalar paths.
                acc_lo = vqaddq_u16(acc_lo, vmovl_u8(vget_low_u8(diff)));
                acc_hi = vqaddq_u16(acc_hi, vmovl_u8(vget_high_u8(diff)));
            }

            let dst = out.as_mut_ptr().add(chunk_start);
            vst1q_u16(dst, acc_lo);
            vst1q_u16(dst.add(8), acc_hi);
        }
    }

    // Scalar tail for the vectors that do not fill a whole chunk.
    for i in vec_aligned..n_vec {
        let mut sum: u16 = 0;
        for slot in 0..n_blocks {
            let v = bps[slot * n_vec + i];
            let qv = query[slot];
            let diff = if v > qv { v - qv } else { qv - v };
            sum = sum.saturating_add(u16::from(diff));
        }
        out[i] = sum;
    }
}
315
/// NEON kernel for the BPS scan with `u32` output.
///
/// Handles 16 vectors per iteration and finishes the `n_vec % 16` tail with
/// scalar code. Differences are first summed in `u16` lanes; those lanes are
/// flushed into `u32` accumulators every 257 slots so they cannot overflow.
///
/// # Safety
/// * The CPU must support NEON.
/// * `bps.len() >= n_blocks * n_vec` and `out.len() >= n_vec`; the main loop
///   loads from `bps` and stores to `out` through unchecked raw pointers.
///
/// # Panics
/// Panics if `query.len() < n_blocks` (the query is indexed with bounds checks).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn bps_scan_neon_u32(
    bps: &[u8],
    n_vec: usize,
    n_blocks: usize,
    query: &[u8],
    out: &mut [u32],
) {
    use std::arch::aarch64::*;

    // Largest number of u8 differences a u16 lane can absorb without
    // overflowing: 257 * 255 = 65_535 = u16::MAX.
    const SLOTS_PER_FLUSH: usize = 257;

    debug_assert!(query.len() >= n_blocks);
    debug_assert!(out.len() >= n_vec);
    debug_assert!(n_blocks
        .checked_mul(n_vec)
        .map_or(false, |need| bps.len() >= need));

    // Number of vectors covered by whole 16-wide chunks.
    let vec_aligned = (n_vec / 16) * 16;

    // No up-front zeroing: every entry in `0..n_vec` is overwritten below,
    // either by a chunk store or by the scalar tail.
    unsafe {
        for chunk_start in (0..vec_aligned).step_by(16) {
            // 16 u32 totals, 4 per register.
            let mut sum0 = vdupq_n_u32(0); // Vectors 0-3
            let mut sum1 = vdupq_n_u32(0); // Vectors 4-7
            let mut sum2 = vdupq_n_u32(0); // Vectors 8-11
            let mut sum3 = vdupq_n_u32(0); // Vectors 12-15

            let mut batch_start = 0;
            while batch_start < n_blocks {
                let batch_end = batch_start + (n_blocks - batch_start).min(SLOTS_PER_FLUSH);

                // Cheap u16 partial sums for this batch of slots.
                let mut acc_lo = vdupq_n_u16(0); // Vectors 0-7 as u16
                let mut acc_hi = vdupq_n_u16(0); // Vectors 8-15 as u16

                for slot in batch_start..batch_end {
                    let base = slot * n_vec + chunk_start;
                    let q = vdupq_n_u8(query[slot]);
                    let db = vld1q_u8(bps.as_ptr().add(base));
                    let diff = vabdq_u8(q, db);

                    acc_lo = vaddw_u8(acc_lo, vget_low_u8(diff));
                    acc_hi = vaddw_u8(acc_hi, vget_high_u8(diff));
                }

                // Flush: widen the u16 partial sums to u32 and add them to
                // the running totals before they can wrap.
                sum0 = vaddw_u16(sum0, vget_low_u16(acc_lo));
                sum1 = vaddw_u16(sum1, vget_high_u16(acc_lo));
                sum2 = vaddw_u16(sum2, vget_low_u16(acc_hi));
                sum3 = vaddw_u16(sum3, vget_high_u16(acc_hi));

                batch_start = batch_end;
            }

            let dst = out.as_mut_ptr().add(chunk_start);
            vst1q_u32(dst, sum0);
            vst1q_u32(dst.add(4), sum1);
            vst1q_u32(dst.add(8), sum2);
            vst1q_u32(dst.add(12), sum3);
        }
    }

    // Scalar tail for the vectors that do not fill a whole chunk.
    for i in vec_aligned..n_vec {
        let mut sum: u32 = 0;
        for slot in 0..n_blocks {
            let v = bps[slot * n_vec + i];
            let qv = query[slot];
            let diff = if v > qv { v - qv } else { qv - v };
            sum += u32::from(diff);
        }
        out[i] = sum;
    }
}
370
371// ============================================================================
372// Scalar Fallback
373// ============================================================================
374
/// Portable reference kernel for the BPS scan, producing `u16` distances.
///
/// For every `vec_id < n_vec`, `out[vec_id]` becomes the sum over all slots of
/// `|query[slot] - bps[slot * n_vec + vec_id]|`, saturating at `u16::MAX`.
#[inline]
fn bps_scan_scalar(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u16]) {
    // Start every requested distance from zero.
    for dist in out.iter_mut().take(n_vec) {
        *dist = 0;
    }

    for slot in 0..n_blocks {
        let q = query[slot];
        let row_start = slot * n_vec;
        // One contiguous SoA row holds this slot's value for every vector.
        let row = &bps[row_start..row_start + n_vec];

        for (dist, &v) in out[..n_vec].iter_mut().zip(row) {
            let gap = v.max(q) - v.min(q);
            *dist = dist.saturating_add(u16::from(gap));
        }
    }
}
392
/// Portable reference kernel for the BPS scan, producing `u32` distances.
///
/// For every `vec_id < n_vec`, `out[vec_id]` becomes the sum over all slots of
/// `|query[slot] - bps[slot * n_vec + vec_id]|`, accumulated with plain `u32`
/// addition.
#[inline]
fn bps_scan_scalar_u32(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u32]) {
    // Start every requested distance from zero.
    for dist in out.iter_mut().take(n_vec) {
        *dist = 0;
    }

    for slot in 0..n_blocks {
        let q = query[slot];
        let row_start = slot * n_vec;
        // One contiguous SoA row holds this slot's value for every vector.
        let row = &bps[row_start..row_start + n_vec];

        for (dist, &v) in out[..n_vec].iter_mut().zip(row) {
            let gap = v.max(q) - v.min(q);
            *dist += u32::from(gap);
        }
    }
}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `len` sketch bytes that cycle through 0..=255.
    fn ramp(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 256) as u8).collect()
    }

    /// The dispatched u16 scan must agree with the scalar reference.
    #[test]
    fn test_bps_scan_basic() {
        let n_vec: usize = 100;
        let n_blocks: usize = 8;
        let bps = ramp(n_vec * n_blocks);
        let query: Vec<u8> = (0..n_blocks).map(|i| (i * 10) as u8).collect();

        let mut expected = vec![0u16; n_vec];
        bps_scan_scalar(&bps, n_vec, n_blocks, &query, &mut expected);

        let mut out = vec![0u16; n_vec];
        bps_scan(&bps, n_vec, n_blocks, &query, &mut out);

        assert_eq!(out, expected);
    }

    /// The dispatched u32 scan must agree with the scalar reference.
    #[test]
    fn test_bps_scan_u32_basic() {
        let n_vec: usize = 100;
        let n_blocks: usize = 8;
        let bps = ramp(n_vec * n_blocks);
        let query: Vec<u8> = (0..n_blocks).map(|i| (i * 10) as u8).collect();

        let mut expected = vec![0u32; n_vec];
        bps_scan_scalar_u32(&bps, n_vec, n_blocks, &query, &mut expected);

        let mut out = vec![0u32; n_vec];
        bps_scan_u32(&bps, n_vec, n_blocks, &query, &mut out);

        assert_eq!(out, expected);
    }

    /// Vector counts that are not multiples of the SIMD width exercise the
    /// scalar tail of each kernel.
    #[test]
    fn test_bps_scan_alignment() {
        let n_blocks: usize = 4;
        let query: Vec<u8> = vec![128; n_blocks];

        for n_vec in [1, 15, 17, 31, 33, 63, 65, 127] {
            let bps = ramp(n_vec * n_blocks);

            let mut expected = vec![0u16; n_vec];
            bps_scan_scalar(&bps, n_vec, n_blocks, &query, &mut expected);

            let mut out = vec![0u16; n_vec];
            bps_scan(&bps, n_vec, n_blocks, &query, &mut out);

            assert_eq!(out, expected, "Mismatch for n_vec={}", n_vec);
        }
    }
}