sochdb_vector/segment/
bps.rs1use crate::config::BpsConfig;
7use crate::dispatch::BpsScanDispatcher;
8use bytemuck::{Pod, Zeroable};
9use rand::Rng;
10use rand::SeedableRng;
11use rand_xoshiro::Xoshiro256PlusPlus;
12
13#[repr(C)]
15#[derive(Debug, Clone, Copy, Pod, Zeroable)]
16pub struct BpsQParam {
17 pub min: f32,
18 pub inv_range: f32,
19}
20
21pub struct BpsBuilder<'a> {
23 config: &'a BpsConfig,
24 vectors: &'a [Vec<f32>],
25 projection_vectors: Vec<Vec<f32>>,
26}
27
28const BPS_SEED: u64 = 0xBEEF_CAFE_1234_5678;
30
31impl<'a> BpsBuilder<'a> {
32 pub fn new(config: &'a BpsConfig, rotated_vectors: &'a [Vec<f32>]) -> Self {
34 let mut rng = Xoshiro256PlusPlus::seed_from_u64(BPS_SEED);
36 let num_blocks = config.num_blocks as usize;
37 let block_size = config.block_size as usize;
38 let num_proj = config.num_projections as usize;
39
40 let mut projection_vectors = Vec::with_capacity(num_blocks * num_proj);
41 for _ in 0..(num_blocks * num_proj) {
42 let proj: Vec<f32> = (0..block_size)
43 .map(|_| {
44 let v: f32 = rng.gen_range(-1.0..1.0);
46 v
47 })
48 .collect();
49 let norm: f32 = proj.iter().map(|x| x * x).sum::<f32>().sqrt();
51 let normalized: Vec<f32> = proj.iter().map(|x| x / norm.max(1e-10)).collect();
52 projection_vectors.push(normalized);
53 }
54
55 Self {
56 config,
57 vectors: rotated_vectors,
58 projection_vectors,
59 }
60 }
61
62 pub fn build(&self) -> (Vec<u8>, Vec<BpsQParam>) {
66 let n_vec = self.vectors.len();
67 let num_blocks = self.config.num_blocks as usize;
68 let num_proj = self.config.num_projections as usize;
69 let block_size = self.config.block_size as usize;
70 let num_slots = num_blocks * num_proj;
71
72 let mut projections: Vec<Vec<f32>> = Vec::with_capacity(n_vec);
74 for vec in self.vectors {
75 let mut vec_proj = Vec::with_capacity(num_slots);
76 for block_idx in 0..num_blocks {
77 let block_start = block_idx * block_size;
78 let block_end = (block_start + block_size).min(vec.len());
79
80 for proj_idx in 0..num_proj {
81 let proj_vec = &self.projection_vectors[block_idx * num_proj + proj_idx];
82 let mut dot = 0.0f32;
83 for (i, &v) in vec[block_start..block_end].iter().enumerate() {
84 if i < proj_vec.len() {
85 dot += v * proj_vec[i];
86 }
87 }
88 vec_proj.push(dot);
89 }
90 }
91 projections.push(vec_proj);
92 }
93
94 let mut mins = vec![f32::MAX; num_slots];
96 let mut maxs = vec![f32::MIN; num_slots];
97 for proj in &projections {
98 for (i, &v) in proj.iter().enumerate() {
99 mins[i] = mins[i].min(v);
100 maxs[i] = maxs[i].max(v);
101 }
102 }
103
104 let mut qparams = Vec::with_capacity(num_slots);
106 for slot in 0..num_slots {
107 let range = (maxs[slot] - mins[slot]).max(1e-10);
108 qparams.push(BpsQParam {
109 min: mins[slot],
110 inv_range: 255.0 / range,
111 });
112 }
113
114 let mut bps_data = vec![0u8; num_slots * n_vec];
116 for (vec_id, proj) in projections.iter().enumerate() {
117 for (slot_idx, &value) in proj.iter().enumerate() {
118 let normalized = ((value - qparams[slot_idx].min) * qparams[slot_idx].inv_range)
119 .clamp(0.0, 255.0);
120
121 let idx = slot_idx * n_vec + vec_id;
123 bps_data[idx] = normalized as u8;
124 }
125 }
126
127 (bps_data, qparams)
128 }
129
130 pub fn compute_query_sketch_with_params(
132 config: &BpsConfig,
133 rotated_query: &[f32],
134 qparams: &[BpsQParam],
135 ) -> Vec<u8> {
136 let mut rng = Xoshiro256PlusPlus::seed_from_u64(BPS_SEED);
137 let num_blocks = config.num_blocks as usize;
138 let block_size = config.block_size as usize;
139 let num_proj = config.num_projections as usize;
140
141 let mut sketch = Vec::with_capacity(num_blocks * num_proj);
142 let mut slot_idx = 0;
143
144 for block_idx in 0..num_blocks {
145 let block_start = block_idx * block_size;
146 let block_end = (block_start + block_size).min(rotated_query.len());
147
148 for _ in 0..num_proj {
149 let proj: Vec<f32> = (0..block_size)
151 .map(|_| rng.gen_range(-1.0f32..1.0))
152 .collect();
153 let norm: f32 = proj.iter().map(|x| x * x).sum::<f32>().sqrt();
154
155 let mut dot = 0.0f32;
156 for (i, &v) in rotated_query[block_start..block_end].iter().enumerate() {
157 if i < proj.len() {
158 dot += v * (proj[i] / norm.max(1e-10));
159 }
160 }
161
162 if slot_idx < qparams.len() {
164 let qp = &qparams[slot_idx];
165 let quantized = ((dot - qp.min) * qp.inv_range).clamp(0.0, 255.0) as u8;
166 sketch.push(quantized);
167 } else {
168 let quantized = ((dot + 1.0) * 127.5).clamp(0.0, 255.0) as u8;
170 sketch.push(quantized);
171 }
172 slot_idx += 1;
173 }
174 }
175
176 sketch
177 }
178
179 #[deprecated(
187 since = "0.5.0",
188 note = "use compute_query_sketch_with_params() — symmetric quantization mismatches index qparams"
189 )]
190 pub fn compute_query_sketch(config: &BpsConfig, rotated_query: &[f32]) -> Vec<u8> {
191 let mut rng = Xoshiro256PlusPlus::seed_from_u64(BPS_SEED);
192 let num_blocks = config.num_blocks as usize;
193 let block_size = config.block_size as usize;
194 let num_proj = config.num_projections as usize;
195
196 let mut sketch = Vec::with_capacity(num_blocks * num_proj);
197
198 for block_idx in 0..num_blocks {
199 let block_start = block_idx * block_size;
200 let block_end = (block_start + block_size).min(rotated_query.len());
201
202 for _ in 0..num_proj {
203 let proj: Vec<f32> = (0..block_size)
205 .map(|_| rng.gen_range(-1.0f32..1.0))
206 .collect();
207 let norm: f32 = proj.iter().map(|x| x * x).sum::<f32>().sqrt();
208
209 let mut dot = 0.0f32;
210 for (i, &v) in rotated_query[block_start..block_end].iter().enumerate() {
211 if i < proj.len() {
212 dot += v * (proj[i] / norm.max(1e-10));
213 }
214 }
215
216 let quantized = ((dot + 1.0) * 127.5).clamp(0.0, 255.0) as u8;
218 sketch.push(quantized);
219 }
220 }
221
222 sketch
223 }
224}
225
/// Scans slot-major BPS codes against a query sketch.
pub struct BpsScanner<'a> {
    /// Slot-major codes: the code for vector `v` in slot `s` is at `s * n_vec + v`.
    bps_data: &'a [u8],
    /// Number of indexed vectors (length of each slot's run of codes).
    n_vec: usize,
    /// Number of blocks; slots = `num_blocks * num_proj`.
    num_blocks: usize,
    /// Projections per block.
    num_proj: usize,
}
233
234impl<'a> BpsScanner<'a> {
235 pub fn new(bps_data: &'a [u8], n_vec: usize, num_blocks: usize, num_proj: usize) -> Self {
237 Self {
238 bps_data,
239 n_vec,
240 num_blocks,
241 num_proj,
242 }
243 }
244
245 pub fn scan(&self, query_sketch: &[u8]) -> Vec<u16> {
253 let mut distances = vec![0u16; self.n_vec];
254 let n_slots = self.num_blocks * self.num_proj;
255
256 BpsScanDispatcher::scan(
258 self.bps_data,
259 self.n_vec,
260 n_slots, 1, query_sketch,
263 &mut distances,
264 );
265
266 distances
267 }
268
269 #[allow(dead_code)]
272 fn scan_fallback(&self, query_sketch: &[u8], distances: &mut [u16]) {
273 let slots = self.num_blocks * self.num_proj;
274
275 for slot_idx in 0..slots {
276 let q = query_sketch[slot_idx] as i16;
277 let base = slot_idx * self.n_vec;
278
279 for vec_id in 0..self.n_vec {
280 let v = self.bps_data[base + vec_id] as i16;
281 let diff = (q - v).abs() as u16;
282 distances[vec_id] = distances[vec_id].saturating_add(diff);
284 }
285 }
286 }
287
288 pub fn top_k(&self, query_sketch: &[u8], k: usize) -> Vec<(u32, u16)> {
290 let distances = self.scan(query_sketch);
291
292 let mut candidates: Vec<(u32, u16)> = distances
294 .into_iter()
295 .enumerate()
296 .map(|(i, d)| (i as u32, d))
297 .collect();
298
299 if candidates.len() <= k {
300 candidates.sort_by_key(|&(_, d)| d);
301 return candidates;
302 }
303
304 candidates.select_nth_unstable_by_key(k - 1, |&(_, d)| d);
306 candidates.truncate(k);
307 candidates.sort_by_key(|&(_, d)| d);
308
309 candidates
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    const N_VEC: usize = 100;
    const DIM: usize = 64;
    /// `num_blocks * num_projections` for `test_config()`.
    const NUM_SLOTS: usize = 4;

    fn test_config() -> BpsConfig {
        BpsConfig {
            block_size: 16,
            num_blocks: 4,
            num_projections: 1,
        }
    }

    /// Deterministic ramp vectors: vector `i`, dimension `j` holds `(i * DIM + j) / 1000`.
    fn test_vectors() -> Vec<Vec<f32>> {
        (0..N_VEC)
            .map(|i| (0..DIM).map(|j| (i * DIM + j) as f32 / 1000.0).collect())
            .collect()
    }

    #[test]
    fn test_bps_build() {
        let config = test_config();
        let vectors = test_vectors();
        let (bps_data, qparams) = BpsBuilder::new(&config, &vectors).build();

        assert_eq!(bps_data.len(), NUM_SLOTS * N_VEC);
        assert_eq!(qparams.len(), NUM_SLOTS);
    }

    #[test]
    fn test_bps_scan() {
        let config = test_config();
        let vectors = test_vectors();
        let (bps_data, _qparams) = BpsBuilder::new(&config, &vectors).build();

        let scanner = BpsScanner::new(&bps_data, N_VEC, 4, 1);
        let query_sketch = vec![128u8; NUM_SLOTS];
        let candidates = scanner.top_k(&query_sketch, 10);

        assert_eq!(candidates.len(), 10);
    }

    /// Index codes and query sketches must use identical projections and quantization, so
    /// sketching an indexed vector reproduces its stored (slot-major) column exactly.
    #[test]
    fn test_query_sketch_matches_index_codes() {
        let config = test_config();
        let vectors = test_vectors();
        let (bps_data, qparams) = BpsBuilder::new(&config, &vectors).build();

        for (vec_id, vector) in vectors.iter().enumerate() {
            let sketch = BpsBuilder::compute_query_sketch_with_params(&config, vector, &qparams);
            let indexed: Vec<u8> = (0..NUM_SLOTS)
                .map(|slot| bps_data[slot * N_VEC + vec_id])
                .collect();
            assert_eq!(sketch, indexed, "sketch mismatch for vector {}", vec_id);
        }
    }

    #[test]
    fn test_top_k_sorted_and_bounded() {
        let config = test_config();
        let vectors = test_vectors();
        let (bps_data, qparams) = BpsBuilder::new(&config, &vectors).build();
        let scanner = BpsScanner::new(&bps_data, N_VEC, 4, 1);
        let sketch = BpsBuilder::compute_query_sketch_with_params(&config, &vectors[42], &qparams);

        // Asking for more than the index holds returns every vector.
        let all = scanner.top_k(&sketch, N_VEC + 10);
        assert_eq!(all.len(), N_VEC);
        assert!(all.windows(2).all(|w| w[0].1 <= w[1].1));

        let top = scanner.top_k(&sketch, 5);
        assert_eq!(top.len(), 5);
        assert!(top.windows(2).all(|w| w[0].1 <= w[1].1));
        // The query's own codes are in the index, so the best distance is zero.
        assert_eq!(top[0].1, 0);
    }
}