// sochdb_vector/bmi2_paths.rs
//
// Copyright 2025 SochDB Authors
//
// Licensed under the Apache License, Version 2.0

//! BMI2 Fast Paths for Bit Manipulation
//!
//! This module provides PEXT/PDEP-accelerated bit packing/unpacking operations
//! with a fallback ladder:
//!
//! 1. **BMI2** (Intel Haswell+, AMD Zen3+): Native PEXT/PDEP
//! 2. **AVX2**: SIMD-based bit extraction without PEXT (planned; not yet
//!    implemented in this module)
//! 3. **Scalar**: Portable loop-based implementation
//!
//! # Operations
//!
//! - **PEXT (Parallel Extract)**: Gather the bits of a source word selected by a
//!   mask into the low bits of the result
//! - **PDEP (Parallel Deposit)**: Scatter the low bits of a source word into the
//!   bit positions selected by a mask
//!
//! # Use Cases
//!
//! - Unpacking 4-bit quantized values from packed storage
//! - Extracting specific dimensions from compressed vectors
//! - Bitmap operations for filtered candidate sets
//!
//! # Performance Warning
//!
//! AMD processors before Zen3 (including Zen/Zen2) implement PEXT/PDEP in
//! microcode with mask-dependent latency (roughly 18 cycles and up, versus
//! 3 cycles on Intel). `pext_u64`/`pdep_u64` take the BMI2 path whenever BMI2 is
//! available; consult `bmi2_fast()` to decide whether the hardware path is
//! worthwhile on the current CPU.

/// Cached result of the one-time BMI2 feature probe.
static BMI2_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

/// Returns `true` if the running CPU supports the BMI2 instruction set.
///
/// The probe runs once; later calls read the cached answer. Always `false` on
/// non-x86_64 targets.
#[inline]
pub fn bmi2_available() -> bool {
    *BMI2_AVAILABLE.get_or_init(|| {
        #[cfg(target_arch = "x86_64")]
        {
            is_x86_feature_detected!("bmi2")
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            false
        }
    })
}

/// Returns `true` if BMI2 is available *and* its PEXT/PDEP are expected to be fast.
///
/// AMD processors before Zen3 (CPUID family below 0x19, which covers
/// Zen/Zen+/Zen2 and the older BMI2-capable cores) and Zen-derived Hygon
/// processors implement PEXT/PDEP in microcode with mask-dependent latency.
/// This function returns `false` for them. The vendor/family check issues CPUID
/// (expensive, especially under virtualization) only once and caches the answer.
#[cfg(target_arch = "x86_64")]
pub fn bmi2_fast() -> bool {
    static BMI2_FAST: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *BMI2_FAST.get_or_init(|| {
        if !bmi2_available() {
            return false;
        }
        let (vendor, family) = cpu_vendor_and_family();
        !has_microcoded_bmi2(&vendor, family)
    })
}

/// Non-x86_64 targets have no BMI2, so it is never fast.
#[cfg(not(target_arch = "x86_64"))]
pub fn bmi2_fast() -> bool {
    false
}

/// Reads the CPUID vendor string (leaf 0) and the display family (leaf 1).
///
/// Returns the 12-byte vendor identifier (e.g. `b"AuthenticAMD"`). The family is
/// base family, plus extended family when base family is 0xF.
#[cfg(target_arch = "x86_64")]
// `__cpuid` is an `unsafe fn` on older toolchains and a safe fn on newer ones;
// the `unsafe` block is required on the former and merely redundant on the latter.
#[allow(unused_unsafe)]
fn cpu_vendor_and_family() -> ([u8; 12], u32) {
    use std::arch::x86_64::__cpuid;

    // SAFETY: every x86_64 processor implements CPUID, and basic leaves 0 and 1
    // are always supported.
    let (leaf0, leaf1) = unsafe { (__cpuid(0), __cpuid(1)) };

    // The vendor string is stored in EBX, EDX, ECX order.
    let mut vendor = [0u8; 12];
    vendor[..4].copy_from_slice(&leaf0.ebx.to_le_bytes());
    vendor[4..8].copy_from_slice(&leaf0.edx.to_le_bytes());
    vendor[8..].copy_from_slice(&leaf0.ecx.to_le_bytes());

    let base_family = (leaf1.eax >> 8) & 0xF;
    let family = if base_family == 0xF {
        base_family + ((leaf1.eax >> 20) & 0xFF)
    } else {
        base_family
    };
    (vendor, family)
}

/// Decides from a CPUID vendor string and display family whether PEXT/PDEP are
/// microcoded (slow).
///
/// AMD family 0x19 (Zen3/Zen4) and later have native PEXT/PDEP; earlier AMD
/// families do not. Hygon cores are Zen1-derived, so they are treated as slow.
/// Every other vendor is assumed fast.
#[cfg_attr(not(target_arch = "x86_64"), allow(dead_code))]
fn has_microcoded_bmi2(vendor: &[u8; 12], family: u32) -> bool {
    match vendor {
        b"AuthenticAMD" => family < 0x19,
        b"HygonGenuine" => true,
        _ => false,
    }
}

70/// Parallel bit extract: extract bits from `src` at positions specified by `mask`.
71///
72/// Example:
73/// ```text
74/// src  = 0b_1010_1100
75/// mask = 0b_0101_0101
76/// result = 0b_0000_0010 (bits at positions 0, 2, 4, 6)
77/// ```
78#[inline]
79pub fn pext_u64(src: u64, mask: u64) -> u64 {
80    #[cfg(target_arch = "x86_64")]
81    {
82        if bmi2_available() {
83            return unsafe { pext_u64_bmi2(src, mask) };
84        }
85    }
86
87    pext_u64_scalar(src, mask)
88}
89
90/// Parallel bit deposit: deposit bits from `src` to positions specified by `mask`.
91///
92/// Example:
93/// ```text
94/// src  = 0b_0000_0010
95/// mask = 0b_0101_0101
96/// result = 0b_0000_0100 (bits deposited at positions 0, 2, 4, 6)
97/// ```
98#[inline]
99pub fn pdep_u64(src: u64, mask: u64) -> u64 {
100    #[cfg(target_arch = "x86_64")]
101    {
102        if bmi2_available() {
103            return unsafe { pdep_u64_bmi2(src, mask) };
104        }
105    }
106
107    pdep_u64_scalar(src, mask)
108}
109
/// Hardware PEXT through the `_pext_u64` intrinsic.
///
/// # Safety
///
/// The caller must ensure the running CPU supports BMI2 (see `bmi2_available`);
/// executing this on a CPU without BMI2 is undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2")]
#[inline]
unsafe fn pext_u64_bmi2(src: u64, mask: u64) -> u64 {
    std::arch::x86_64::_pext_u64(src, mask)
}

/// Hardware PDEP through the `_pdep_u64` intrinsic.
///
/// # Safety
///
/// The caller must ensure the running CPU supports BMI2 (see `bmi2_available`);
/// executing this on a CPU without BMI2 is undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "bmi2")]
#[inline]
unsafe fn pdep_u64_bmi2(src: u64, mask: u64) -> u64 {
    std::arch::x86_64::_pdep_u64(src, mask)
}

/// Portable PEXT.
///
/// Walks the set bits of `mask` from least to most significant. The `src` bit
/// under each one is copied into the next free low bit of the result.
#[inline]
fn pext_u64_scalar(src: u64, mask: u64) -> u64 {
    let mut pending = mask;
    let mut gathered = 0u64;
    let mut out_pos = 0u32;

    while pending != 0 {
        let bit_index = pending.trailing_zeros();
        gathered |= ((src >> bit_index) & 1) << out_pos;
        pending &= pending - 1; // drop the mask bit just handled
        out_pos += 1;
    }

    gathered
}

/// Portable PDEP.
///
/// Walks the set bits of `mask` from least to most significant. The next
/// unconsumed low bit of `src` is placed at each one.
#[inline]
fn pdep_u64_scalar(src: u64, mask: u64) -> u64 {
    let mut pending = mask;
    let mut scattered = 0u64;
    let mut in_pos = 0u32;

    while pending != 0 {
        let target = pending & pending.wrapping_neg(); // lowest remaining mask bit
        if (src >> in_pos) & 1 == 1 {
            scattered |= target;
        }
        pending ^= target;
        in_pos += 1;
    }

    scattered
}

165/// 32-bit versions.
166#[inline]
167pub fn pext_u32(src: u32, mask: u32) -> u32 {
168    pext_u64(src as u64, mask as u64) as u32
169}
170
171#[inline]
172pub fn pdep_u32(src: u32, mask: u32) -> u32 {
173    pdep_u64(src as u64, mask as u64) as u32
174}
175
// ============================================================================
// Bit Packing/Unpacking Utilities
// ============================================================================

/// Packs 4-bit values two per byte: the first value of each pair goes in the
/// low nibble and the second in the high nibble.
///
/// Inputs are expected to be in `0..=15`; any higher bits are discarded. An
/// odd-length input leaves the high nibble of the final byte zero.
pub fn pack_4bit(values: &[u8]) -> Vec<u8> {
    values
        .chunks(2)
        .map(|pair| {
            let low = pair[0] & 0x0F;
            let high = pair.get(1).copied().unwrap_or(0) & 0x0F;
            (high << 4) | low
        })
        .collect()
}

/// Unpacks up to `count` 4-bit values from `packed`, low nibble of each byte
/// first (the inverse of `pack_4bit`).
///
/// If `packed` holds fewer than `count` nibbles, only the available ones are
/// returned, so the result may be shorter than `count`.
pub fn unpack_4bit(packed: &[u8], count: usize) -> Vec<u8> {
    // Reserve no more than the input can yield. An oversized `count` (e.g. from
    // a corrupted header) must not trigger a huge allocation or a capacity
    // overflow panic.
    let available = packed.len().saturating_mul(2);
    let mut values = Vec::with_capacity(count.min(available));
    values.extend(
        packed
            .iter()
            .flat_map(|&byte| [byte & 0x0F, byte >> 4])
            .take(count),
    );
    values
}

/// Packs `bits_per_value`-bit values into a little-endian bit stream.
///
/// Value `i` occupies bits `i * bits_per_value ..` of the output, and values may
/// straddle byte boundaries. Bits of each input above `bits_per_value` are
/// discarded. The output is `ceil(values.len() * bits_per_value / 8)` bytes.
///
/// # Panics
///
/// Panics if `bits_per_value` is not in `1..=8`.
pub fn pack_nbits(values: &[u8], bits_per_value: u8) -> Vec<u8> {
    assert!(
        (1..=8).contains(&bits_per_value),
        "bits_per_value must be in 1..=8, got {bits_per_value}"
    );

    let width = usize::from(bits_per_value);
    let mut packed = vec![0u8; (values.len() * width + 7) / 8];
    // `width` low one-bits. Shifting `u8::MAX` right avoids `1u8 << 8`, which
    // overflows for width 8 (panic in debug builds, a zero mask in release).
    let mask = u8::MAX >> (8 - width);

    for (index, &value) in values.iter().enumerate() {
        let bit_pos = index * width;
        let byte_idx = bit_pos / 8;
        // Widen to u16 so bits that cross a byte boundary land in the high byte
        // instead of being shifted out.
        let [low, high] = (u16::from(value & mask) << (bit_pos % 8)).to_le_bytes();
        packed[byte_idx] |= low;
        if high != 0 {
            // A non-zero spill means this value extends past `byte_idx`, so the
            // buffer (sized from the total bit count) has a next byte.
            packed[byte_idx + 1] |= high;
        }
    }

    packed
}

/// Unpacks up to `count` values of `bits_per_value` bits each from a
/// little-endian bit stream (the inverse of `pack_nbits`).
///
/// Decoding stops early once a value would start past the end of `packed`, so
/// the result may be shorter than `count`. A final value whose bits run past the
/// end of `packed` is returned with the missing high bits read as zero.
///
/// # Panics
///
/// Panics if `bits_per_value` is not in `1..=8`.
pub fn unpack_nbits(packed: &[u8], bits_per_value: u8, count: usize) -> Vec<u8> {
    assert!(
        (1..=8).contains(&bits_per_value),
        "bits_per_value must be in 1..=8, got {bits_per_value}"
    );

    let width = usize::from(bits_per_value);
    // u16 arithmetic keeps `1 << 8` in range for 8-bit values.
    let mask = (1u16 << bits_per_value) - 1;
    // Each value starts on a distinct bit of `packed`, so at most `len * 8`
    // values can be produced. Capping the reservation stops a bogus `count`
    // from forcing a huge allocation.
    let mut values = Vec::with_capacity(count.min(packed.len().saturating_mul(8)));

    for index in 0..count {
        let bit_pos = index * width;
        let byte_idx = bit_pos / 8;
        let Some(&low) = packed.get(byte_idx) else {
            break;
        };
        // A value spans at most two bytes (offset <= 7, width <= 8).
        let high = packed.get(byte_idx + 1).copied().unwrap_or(0);
        let window = u16::from_le_bytes([low, high]);
        values.push(((window >> (bit_pos % 8)) & mask) as u8);
    }

    values
}

// ============================================================================
// Batch 4-bit (nibble) operations
// ============================================================================

/// Splits a `u64` into its 16 nibbles, least-significant nibble first.
///
/// Element `i` of the result is bits `4 * i .. 4 * i + 4` of `packed`, so this
/// is the inverse of `deposit_4bit_batch`.
///
/// PEXT is deliberately not used. Nibbles at a fixed 4-bit stride need only a
/// shift and a mask per element. A PEXT-based split still needs that same
/// per-element work to spread the compacted nibbles into bytes, and it is
/// microcoded on AMD before Zen3.
#[inline]
pub fn extract_4bit_batch(packed: u64) -> [u8; 16] {
    std::array::from_fn(|i| ((packed >> (i * 4)) & 0x0F) as u8)
}

/// Packs 16 nibbles into a `u64`, with `values[i]` placed at bits
/// `4 * i .. 4 * i + 4`.
///
/// Bits of each input above the low nibble are discarded. This is the inverse
/// of `extract_4bit_batch`.
///
/// PDEP is deliberately not used. Assembling its compact even/odd inputs
/// already costs the same 16 shifts as placing the nibbles directly, and PDEP
/// is microcoded on AMD before Zen3.
#[inline]
pub fn deposit_4bit_batch(values: [u8; 16]) -> u64 {
    values
        .iter()
        .enumerate()
        .fold(0u64, |word, (i, &v)| word | (u64::from(v & 0x0F) << (i * 4)))
}

340/// Dispatch info for debugging.
341pub fn dispatch_info() -> &'static str {
342    #[cfg(target_arch = "x86_64")]
343    {
344        if bmi2_available() {
345            if bmi2_fast() {
346                return "BMI2 (fast)";
347            } else {
348                return "BMI2 (slow/microcode)";
349            }
350        }
351        return "Scalar (x86_64)";
352    }
353    #[cfg(target_arch = "aarch64")]
354    {
355        return "Scalar (ARM64)";
356    }
357    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
358    {
359        return "Scalar";
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic xorshift64 step so randomized cases are reproducible.
    fn xorshift(state: &mut u64) -> u64 {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        *state
    }

    #[test]
    fn test_pext_basic() {
        // Extract every other bit
        let src = 0b_1010_1010u64;
        let mask = 0b_0101_0101u64;

        let result = pext_u64(src, mask);
        assert_eq!(result, 0b_0000_0000); // No bits at even positions

        let result2 = pext_u64(src, 0b_1010_1010);
        assert_eq!(result2, 0b_0000_1111); // All bits at odd positions
    }

    #[test]
    fn test_pdep_basic() {
        // Deposit to every other position
        let src = 0b_0000_1111u64;
        let mask = 0b_0101_0101u64;

        let result = pdep_u64(src, mask);
        assert_eq!(result, 0b_0101_0101); // All 1s at even positions
    }

    #[test]
    fn test_pext_pdep_roundtrip() {
        let original = 0b_1100_1010_0011_0101u64;
        let mask = 0b_1111_0000_1111_0000u64;

        let extracted = pext_u64(original, mask);
        let restored = pdep_u64(extracted, mask);

        assert_eq!(original & mask, restored);
    }

    #[test]
    fn test_dispatch_matches_scalar() {
        // Edge patterns first, then pseudo-random words. On BMI2 hardware this
        // cross-checks the intrinsic path against the portable reference.
        let mut cases = vec![
            (0, 0),
            (u64::MAX, u64::MAX),
            (u64::MAX, 0x8000_0000_0000_0001),
            (0x1234_5678_9ABC_DEF0, 0xF0F0_F0F0_F0F0_F0F0),
        ];
        let mut state = 0x9E37_79B9_7F4A_7C15u64;
        for _ in 0..256 {
            let src = xorshift(&mut state);
            let mask = xorshift(&mut state);
            cases.push((src, mask));
        }

        for &(src, mask) in &cases {
            assert_eq!(
                pext_u64(src, mask),
                pext_u64_scalar(src, mask),
                "pext src={src:#x} mask={mask:#x}"
            );
            assert_eq!(
                pdep_u64(src, mask),
                pdep_u64_scalar(src, mask),
                "pdep src={src:#x} mask={mask:#x}"
            );
        }
    }

    #[test]
    fn test_pack_unpack_4bit() {
        let values: Vec<u8> = vec![0, 15, 7, 8, 3, 12];
        let packed = pack_4bit(&values);
        let unpacked = unpack_4bit(&packed, values.len());

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_unpack_4bit_short_input() {
        // Asking for more values than the buffer holds yields only what is there.
        assert_eq!(unpack_4bit(&[0x21], 5), vec![1, 2]);
    }

    #[test]
    fn test_pack_unpack_nbits() {
        // 3-bit values
        let values: Vec<u8> = vec![0, 7, 3, 5, 2, 6, 1, 4];
        let packed = pack_nbits(&values, 3);
        let unpacked = unpack_nbits(&packed, 3, values.len());

        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_extract_4bit_batch() {
        // Create a packed u64 where nibble i = i
        let mut packed: u64 = 0;
        for i in 0..16u64 {
            packed |= i << (i * 4);
        }

        let result = extract_4bit_batch(packed);

        for i in 0..16 {
            assert_eq!(result[i], i as u8, "nibble {} mismatch", i);
        }
    }

    #[test]
    fn test_deposit_4bit_batch() {
        let values: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let packed = deposit_4bit_batch(values);

        // Verify by extraction
        for i in 0..16 {
            assert_eq!(((packed >> (i * 4)) & 0x0F) as u8, i as u8);
        }
    }

    #[test]
    fn test_batch_roundtrip() {
        for word in [0u64, u64::MAX, 0x0123_4567_89AB_CDEF, 0xFEDC_BA98_7654_3210] {
            assert_eq!(deposit_4bit_batch(extract_4bit_batch(word)), word);
        }
    }

    #[test]
    fn test_dispatch_info() {
        let info = dispatch_info();
        assert!(!info.is_empty());
        // Should contain some descriptor
        println!("BMI2 dispatch: {}", info);
    }

    #[test]
    fn test_scalar_fallback_correctness() {
        // Test scalar implementations directly
        let src = 0xDEAD_BEEF_CAFE_BABEu64;
        let mask = 0x5555_5555_5555_5555u64;

        let extracted = pext_u64_scalar(src, mask);
        let deposited = pdep_u64_scalar(extracted, mask);

        assert_eq!(src & mask, deposited);
    }

    #[test]
    fn test_32bit_versions() {
        let src = 0xABCD_1234u32;
        let mask = 0xFF00_FF00u32;

        let extracted = pext_u32(src, mask);
        let deposited = pdep_u32(extracted, mask);

        assert_eq!(src & mask, deposited);
    }

    #[test]
    fn test_edge_cases() {
        // Zero mask
        assert_eq!(pext_u64(0xFFFF_FFFF, 0), 0);
        assert_eq!(pdep_u64(0xFFFF_FFFF, 0), 0);

        // Full mask
        assert_eq!(pext_u64(0x1234, 0xFFFF), 0x1234);
        assert_eq!(pdep_u64(0x1234, 0xFFFF), 0x1234);

        // Zero source
        assert_eq!(pext_u64(0, 0xFFFF), 0);
        assert_eq!(pdep_u64(0, 0xFFFF), 0);
    }
}