Skip to main content

sochdb_vector/segment/
bps.rs

1//! BPS (Block Projection Sketch) builder and utilities.
2//!
3//! BPS divides vectors into blocks and computes short projections per block,
4//! stored in SoA layout for vertical SIMD scanning.
5
6use crate::config::BpsConfig;
7use crate::dispatch::BpsScanDispatcher;
8use bytemuck::{Pod, Zeroable};
9use rand::Rng;
10use rand::SeedableRng;
11use rand_xoshiro::Xoshiro256PlusPlus;
12
/// BPS quantization parameters for a single slot (one block/projection pair).
///
/// A raw projection value `x` is mapped to a byte as
/// `((x - min) * inv_range).clamp(0.0, 255.0) as u8`; the same formula is
/// applied when building the index and when sketching a query.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
pub struct BpsQParam {
    /// Smallest raw projection value seen for this slot at build time.
    pub min: f32,
    /// Quantization scale, stored as `255.0 / (max - min)` with the range
    /// floored at `1e-10` (so despite the name it already includes the 255 factor).
    pub inv_range: f32,
}
20
/// Builder for BPS sketches.
///
/// Borrows the configuration and the input vectors, and owns the random
/// projection directions generated when the builder is created.
pub struct BpsBuilder<'a> {
    /// Sketch geometry: `num_blocks`, `block_size` and `num_projections`.
    config: &'a BpsConfig,
    /// Input vectors to sketch (the `rotated_vectors` argument of `new`).
    vectors: &'a [Vec<f32>],
    /// One direction of length `block_size` per slot, indexed as
    /// `block_idx * num_projections + proj_idx`.
    projection_vectors: Vec<Vec<f32>>,
}
27
/// Seed for deterministic BPS projections.
///
/// The builder and the query-sketch functions each seed their own generator
/// with this value and draw the projection components in the same order, which
/// is what makes index-side and query-side projections agree.
/// NOTE(review): changing this value presumably invalidates sketches built with
/// the old one — confirm before touching it.
const BPS_SEED: u64 = 0xBEEF_CAFE_1234_5678;
30
31impl<'a> BpsBuilder<'a> {
32    /// Create a new BPS builder
33    pub fn new(config: &'a BpsConfig, rotated_vectors: &'a [Vec<f32>]) -> Self {
34        // Generate random projection vectors (one per block per projection)
35        let mut rng = Xoshiro256PlusPlus::seed_from_u64(BPS_SEED);
36        let num_blocks = config.num_blocks as usize;
37        let block_size = config.block_size as usize;
38        let num_proj = config.num_projections as usize;
39
40        let mut projection_vectors = Vec::with_capacity(num_blocks * num_proj);
41        for _ in 0..(num_blocks * num_proj) {
42            let proj: Vec<f32> = (0..block_size)
43                .map(|_| {
44                    // Random unit vector component
45                    let v: f32 = rng.gen_range(-1.0..1.0);
46                    v
47                })
48                .collect();
49            // Normalize
50            let norm: f32 = proj.iter().map(|x| x * x).sum::<f32>().sqrt();
51            let normalized: Vec<f32> = proj.iter().map(|x| x / norm.max(1e-10)).collect();
52            projection_vectors.push(normalized);
53        }
54
55        Self {
56            config,
57            vectors: rotated_vectors,
58            projection_vectors,
59        }
60    }
61
62    /// Build BPS data in SoA layout with quantization parameters
63    /// Layout: [(block * num_proj + proj) * n_vec + vec]
64    /// Returns: (bps_data, qparams per slot)
65    pub fn build(&self) -> (Vec<u8>, Vec<BpsQParam>) {
66        let n_vec = self.vectors.len();
67        let num_blocks = self.config.num_blocks as usize;
68        let num_proj = self.config.num_projections as usize;
69        let block_size = self.config.block_size as usize;
70        let num_slots = num_blocks * num_proj;
71
72        // Compute projections for all vectors
73        let mut projections: Vec<Vec<f32>> = Vec::with_capacity(n_vec);
74        for vec in self.vectors {
75            let mut vec_proj = Vec::with_capacity(num_slots);
76            for block_idx in 0..num_blocks {
77                let block_start = block_idx * block_size;
78                let block_end = (block_start + block_size).min(vec.len());
79
80                for proj_idx in 0..num_proj {
81                    let proj_vec = &self.projection_vectors[block_idx * num_proj + proj_idx];
82                    let mut dot = 0.0f32;
83                    for (i, &v) in vec[block_start..block_end].iter().enumerate() {
84                        if i < proj_vec.len() {
85                            dot += v * proj_vec[i];
86                        }
87                    }
88                    vec_proj.push(dot);
89                }
90            }
91            projections.push(vec_proj);
92        }
93
94        // Find min/max per slot for quantization
95        let mut mins = vec![f32::MAX; num_slots];
96        let mut maxs = vec![f32::MIN; num_slots];
97        for proj in &projections {
98            for (i, &v) in proj.iter().enumerate() {
99                mins[i] = mins[i].min(v);
100                maxs[i] = maxs[i].max(v);
101            }
102        }
103
104        // Build qparams
105        let mut qparams = Vec::with_capacity(num_slots);
106        for slot in 0..num_slots {
107            let range = (maxs[slot] - mins[slot]).max(1e-10);
108            qparams.push(BpsQParam {
109                min: mins[slot],
110                inv_range: 255.0 / range,
111            });
112        }
113
114        // Quantize to u8 and store in SoA layout
115        let mut bps_data = vec![0u8; num_slots * n_vec];
116        for (vec_id, proj) in projections.iter().enumerate() {
117            for (slot_idx, &value) in proj.iter().enumerate() {
118                let normalized = ((value - qparams[slot_idx].min) * qparams[slot_idx].inv_range)
119                    .clamp(0.0, 255.0);
120
121                // SoA index: slot_idx * n_vec + vec_id
122                let idx = slot_idx * n_vec + vec_id;
123                bps_data[idx] = normalized as u8;
124            }
125        }
126
127        (bps_data, qparams)
128    }
129
130    /// Compute query sketch using stored quantization parameters
131    pub fn compute_query_sketch_with_params(
132        config: &BpsConfig,
133        rotated_query: &[f32],
134        qparams: &[BpsQParam],
135    ) -> Vec<u8> {
136        let mut rng = Xoshiro256PlusPlus::seed_from_u64(BPS_SEED);
137        let num_blocks = config.num_blocks as usize;
138        let block_size = config.block_size as usize;
139        let num_proj = config.num_projections as usize;
140
141        let mut sketch = Vec::with_capacity(num_blocks * num_proj);
142        let mut slot_idx = 0;
143
144        for block_idx in 0..num_blocks {
145            let block_start = block_idx * block_size;
146            let block_end = (block_start + block_size).min(rotated_query.len());
147
148            for _ in 0..num_proj {
149                // Generate same random projection (must match builder)
150                let proj: Vec<f32> = (0..block_size)
151                    .map(|_| rng.gen_range(-1.0f32..1.0))
152                    .collect();
153                let norm: f32 = proj.iter().map(|x| x * x).sum::<f32>().sqrt();
154
155                let mut dot = 0.0f32;
156                for (i, &v) in rotated_query[block_start..block_end].iter().enumerate() {
157                    if i < proj.len() {
158                        dot += v * (proj[i] / norm.max(1e-10));
159                    }
160                }
161
162                // Use stored qparams for correct quantization
163                if slot_idx < qparams.len() {
164                    let qp = &qparams[slot_idx];
165                    let quantized = ((dot - qp.min) * qp.inv_range).clamp(0.0, 255.0) as u8;
166                    sketch.push(quantized);
167                } else {
168                    // Fallback: symmetric quantization
169                    let quantized = ((dot + 1.0) * 127.5).clamp(0.0, 255.0) as u8;
170                    sketch.push(quantized);
171                }
172                slot_idx += 1;
173            }
174        }
175
176        sketch
177    }
178
179    /// Legacy: Compute query sketch without stored params (for backwards compat)
180    ///
181    /// **DEPRECATED**: This function uses symmetric quantization which does NOT
182    /// match the asymmetric quantization used when building the index.  The
183    /// mismatch inflates L1 distances and degrades recall.
184    /// Use `compute_query_sketch_with_params()` instead, passing the `BpsQParam`s
185    /// stored during index build.
186    #[deprecated(
187        since = "0.5.0",
188        note = "use compute_query_sketch_with_params() — symmetric quantization mismatches index qparams"
189    )]
190    pub fn compute_query_sketch(config: &BpsConfig, rotated_query: &[f32]) -> Vec<u8> {
191        let mut rng = Xoshiro256PlusPlus::seed_from_u64(BPS_SEED);
192        let num_blocks = config.num_blocks as usize;
193        let block_size = config.block_size as usize;
194        let num_proj = config.num_projections as usize;
195
196        let mut sketch = Vec::with_capacity(num_blocks * num_proj);
197
198        for block_idx in 0..num_blocks {
199            let block_start = block_idx * block_size;
200            let block_end = (block_start + block_size).min(rotated_query.len());
201
202            for _ in 0..num_proj {
203                // Generate same random projection (must match builder)
204                let proj: Vec<f32> = (0..block_size)
205                    .map(|_| rng.gen_range(-1.0f32..1.0))
206                    .collect();
207                let norm: f32 = proj.iter().map(|x| x * x).sum::<f32>().sqrt();
208
209                let mut dot = 0.0f32;
210                for (i, &v) in rotated_query[block_start..block_end].iter().enumerate() {
211                    if i < proj.len() {
212                        dot += v * (proj[i] / norm.max(1e-10));
213                    }
214                }
215
216                // Symmetric quantization (less accurate without qparams)
217                let quantized = ((dot + 1.0) * 127.5).clamp(0.0, 255.0) as u8;
218                sketch.push(quantized);
219            }
220        }
221
222        sketch
223    }
224}
225
/// BPS scanner for streaming candidate generation.
///
/// Holds a borrowed view of the quantized sketch bytes plus the dimensions
/// needed to interpret them; it does not own or copy the data.
pub struct BpsScanner<'a> {
    /// Quantized sketch bytes; the Rust fallback scan reads the entry for a
    /// slot/vector pair at `slot_idx * n_vec + vec_id`.
    bps_data: &'a [u8],
    /// Number of vectors covered by `bps_data`.
    n_vec: usize,
    /// Number of blocks per vector.
    num_blocks: usize,
    /// Number of projections per block; `num_blocks * num_proj` is the slot count.
    num_proj: usize,
}
233
234impl<'a> BpsScanner<'a> {
235    /// Create a new BPS scanner
236    pub fn new(bps_data: &'a [u8], n_vec: usize, num_blocks: usize, num_proj: usize) -> Self {
237        Self {
238            bps_data,
239            n_vec,
240            num_blocks,
241            num_proj,
242        }
243    }
244
245    /// Scan and compute L1 distances to query sketch
246    /// Returns distances for all vectors (lower = more similar)
247    ///
248    /// Uses SIMD-accelerated C++ kernels via FFI when available:
249    /// - AVX2: 32 vectors per cycle (~32x speedup)
250    /// - AVX512: 64 vectors per cycle (~64x speedup)
251    /// - NEON: 16 vectors per cycle (~16x speedup)
252    pub fn scan(&self, query_sketch: &[u8]) -> Vec<u16> {
253        let mut distances = vec![0u16; self.n_vec];
254        let n_slots = self.num_blocks * self.num_proj;
255
256        // Dispatch to C++ SIMD kernel (AVX2/AVX512/NEON) via FFI
257        BpsScanDispatcher::scan(
258            self.bps_data,
259            self.n_vec,
260            n_slots, // n_blocks for dispatcher = total slots
261            1,       // proj = 1 (legacy param)
262            query_sketch,
263            &mut distances,
264        );
265
266        distances
267    }
268
269    /// Fallback Rust implementation (kept for testing/verification)
270    /// Uses saturating_add to prevent overflow for safety.
271    #[allow(dead_code)]
272    fn scan_fallback(&self, query_sketch: &[u8], distances: &mut [u16]) {
273        let slots = self.num_blocks * self.num_proj;
274
275        for slot_idx in 0..slots {
276            let q = query_sketch[slot_idx] as i16;
277            let base = slot_idx * self.n_vec;
278
279            for vec_id in 0..self.n_vec {
280                let v = self.bps_data[base + vec_id] as i16;
281                let diff = (q - v).abs() as u16;
282                // Use saturating_add to prevent overflow (safety measure)
283                distances[vec_id] = distances[vec_id].saturating_add(diff);
284            }
285        }
286    }
287
288    /// Get top-k candidates by distance (lower is better)
289    pub fn top_k(&self, query_sketch: &[u8], k: usize) -> Vec<(u32, u16)> {
290        let distances = self.scan(query_sketch);
291
292        // Use partial selection for efficiency
293        let mut candidates: Vec<(u32, u16)> = distances
294            .into_iter()
295            .enumerate()
296            .map(|(i, d)| (i as u32, d))
297            .collect();
298
299        if candidates.len() <= k {
300            candidates.sort_by_key(|&(_, d)| d);
301            return candidates;
302        }
303
304        // Partial sort for top k
305        candidates.select_nth_unstable_by_key(k - 1, |&(_, d)| d);
306        candidates.truncate(k);
307        candidates.sort_by_key(|&(_, d)| d);
308
309        candidates
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of vectors in the shared fixture.
    const N_VEC: usize = 100;
    /// Slots per vector in the shared fixture (4 blocks x 1 projection).
    const N_SLOTS: usize = 4;

    /// Sketch geometry shared by all tests: 4 blocks of 16 dims, 1 projection each.
    fn test_config() -> BpsConfig {
        BpsConfig {
            block_size: 16,
            num_blocks: 4,
            num_projections: 1,
        }
    }

    /// 100 deterministic 64-dimensional vectors forming a simple ramp.
    fn test_vectors() -> Vec<Vec<f32>> {
        (0..100)
            .map(|i| (0..64).map(|j| (i * 64 + j) as f32 / 1000.0).collect())
            .collect()
    }

    #[test]
    fn test_bps_build() {
        let config = test_config();
        let vectors = test_vectors();

        let builder = BpsBuilder::new(&config, &vectors);
        let (bps_data, qparams) = builder.build();

        // Should have num_blocks * num_proj * n_vec bytes
        assert_eq!(bps_data.len(), N_SLOTS * N_VEC);
        // Should have qparams for each slot
        assert_eq!(qparams.len(), N_SLOTS);
    }

    #[test]
    fn test_bps_scan() {
        let config = test_config();
        let vectors = test_vectors();

        let builder = BpsBuilder::new(&config, &vectors);
        let (bps_data, _qparams) = builder.build();

        let scanner = BpsScanner::new(&bps_data, N_VEC, 4, 1);

        // Query with a constant mid-range ("neutral") sketch.
        let query_sketch = vec![128u8; N_SLOTS];
        let candidates = scanner.top_k(&query_sketch, 10);

        assert_eq!(candidates.len(), 10);
        // Every id must refer to a stored vector.
        assert!(candidates.iter().all(|&(id, _)| (id as usize) < N_VEC));
        // Results must come back ordered by ascending distance.
        assert!(candidates.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn test_bps_scan_self_query() {
        let config = test_config();
        let vectors = test_vectors();

        let builder = BpsBuilder::new(&config, &vectors);
        let (bps_data, qparams) = builder.build();

        let scanner = BpsScanner::new(&bps_data, N_VEC, 4, 1);

        // Query with the first vector's own sketch: the query path repeats the
        // builder's projection and quantization, so the stored sketch of
        // vector 0 matches exactly and the best distance must be 0.
        let query_sketch =
            BpsBuilder::compute_query_sketch_with_params(&config, &vectors[0], &qparams);
        assert_eq!(query_sketch.len(), N_SLOTS);

        let candidates = scanner.top_k(&query_sketch, 10);

        assert_eq!(candidates.len(), 10);
        assert_eq!(candidates[0].1, 0);
    }
}
361}