Skip to main content

sochdb_vector/segment/
rerank.rs

1//! Rerank builder and int8 quantization with outlier handling.
2//!
3//! Uses percentile-based symmetric quantization with separate outlier storage
4//! to preserve dot product accuracy.
5
6use crate::config::RerankConfig;
7use crate::dispatch::DotI8Dispatcher;
8use crate::types::*;
9use half::f16;
10
11/// Builder for rerank data (int8 embeddings + outliers)
12pub struct RerankBuilder<'a> {
13    config: &'a RerankConfig,
14    vectors: &'a [Vec<f32>],
15}
16
17impl<'a> RerankBuilder<'a> {
18    /// Create a new rerank builder
19    pub fn new(config: &'a RerankConfig, rotated_vectors: &'a [Vec<f32>]) -> Self {
20        Self {
21            config,
22            vectors: rotated_vectors,
23        }
24    }
25
26    /// Build int8 embeddings with per-vector scales
27    /// Returns (i8_data, scales)
28    pub fn build_i8(&self) -> (Vec<i8>, Vec<f32>) {
29        let n_vec = self.vectors.len();
30        if n_vec == 0 {
31            return (Vec::new(), Vec::new());
32        }
33
34        let dim = self.vectors[0].len();
35        let mut i8_data = Vec::with_capacity(n_vec * dim);
36        let mut scales = Vec::with_capacity(n_vec);
37
38        for vec in self.vectors {
39            // Find outlier indices (we'll zero them in i8)
40            let outlier_indices = self.find_outlier_indices(vec);
41
42            // Compute scale using percentile (excluding outliers)
43            let scale = self.compute_scale(vec, &outlier_indices);
44            scales.push(scale);
45
46            // Quantize
47            let inv_scale = if scale > 1e-10 { 1.0 / scale } else { 0.0 };
48            for (i, &v) in vec.iter().enumerate() {
49                if outlier_indices.contains(&(i as u16)) {
50                    // Zero out outlier positions (will be added back during rerank)
51                    i8_data.push(0);
52                } else {
53                    let quantized = (v * inv_scale * 127.0).clamp(-127.0, 127.0) as i8;
54                    i8_data.push(quantized);
55                }
56            }
57        }
58
59        (i8_data, scales)
60    }
61
62    /// Build outlier entries
63    pub fn build_outliers(&self) -> Vec<OutlierEntry> {
64        let n_vec = self.vectors.len();
65        let num_outliers = self.config.num_outliers as usize;
66        let mut outliers = Vec::with_capacity(n_vec * num_outliers);
67
68        for vec in self.vectors {
69            let outlier_entries = self.extract_outliers(vec);
70            for entry in outlier_entries {
71                outliers.push(entry);
72            }
73        }
74
75        outliers
76    }
77
78    /// Find indices of top-o outliers by absolute value
79    fn find_outlier_indices(&self, vec: &[f32]) -> Vec<DimIndex> {
80        let num_outliers = self.config.num_outliers as usize;
81        if num_outliers == 0 {
82            return Vec::new();
83        }
84
85        let mut indexed: Vec<(usize, f32)> =
86            vec.iter().enumerate().map(|(i, &v)| (i, v.abs())).collect();
87
88        if indexed.len() <= num_outliers {
89            return indexed.iter().map(|&(i, _)| i as DimIndex).collect();
90        }
91
92        indexed.select_nth_unstable_by(num_outliers - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
93
94        indexed
95            .iter()
96            .take(num_outliers)
97            .map(|&(i, _)| i as DimIndex)
98            .collect()
99    }
100
101    /// Compute scale using percentile-based approach
102    fn compute_scale(&self, vec: &[f32], outlier_indices: &[DimIndex]) -> f32 {
103        // Collect non-outlier absolute values
104        let mut values: Vec<f32> = vec
105            .iter()
106            .enumerate()
107            .filter(|&(i, _)| !outlier_indices.contains(&(i as DimIndex)))
108            .map(|(_, &v)| v.abs())
109            .collect();
110
111        if values.is_empty() {
112            return 1.0;
113        }
114
115        values.sort_by(|a, b| a.partial_cmp(b).unwrap());
116
117        // Use percentile
118        let idx = ((values.len() as f32) * self.config.scale_percentile) as usize;
119        let idx = idx.min(values.len() - 1);
120
121        values[idx].max(1e-10)
122    }
123
124    /// Extract outliers with their values
125    fn extract_outliers(&self, vec: &[f32]) -> Vec<OutlierEntry> {
126        let num_outliers = self.config.num_outliers as usize;
127        let mut entries = Vec::with_capacity(num_outliers);
128
129        let mut indexed: Vec<(usize, f32)> = vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
130
131        // Sort by absolute value descending
132        indexed.sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap());
133
134        for &(dim_id, value) in indexed.iter().take(num_outliers) {
135            entries.push(OutlierEntry::new(dim_id as DimIndex, f16::from_f32(value)));
136        }
137
138        // Pad with zeros if needed
139        while entries.len() < num_outliers {
140            entries.push(OutlierEntry::new(0, f16::from_f32(0.0)));
141        }
142
143        entries
144    }
145}
146
147/// Reranker for computing int8 dot products with outlier correction
148pub struct Reranker<'a> {
149    i8_data: &'a [i8],
150    scales: &'a [f32],
151    outliers: &'a [OutlierEntry],
152    dim: usize,
153    num_outliers: usize,
154}
155
156impl<'a> Reranker<'a> {
157    /// Create a new reranker
158    pub fn new(
159        i8_data: &'a [i8],
160        scales: &'a [f32],
161        outliers: &'a [OutlierEntry],
162        dim: usize,
163        num_outliers: usize,
164    ) -> Self {
165        Self {
166            i8_data,
167            scales,
168            outliers,
169            dim,
170            num_outliers,
171        }
172    }
173
174    /// Compute dot product score for a single candidate
175    ///
176    /// Uses SIMD-accelerated C++ kernels via FFI when available:
177    /// - AVX2: 32 int8 ops per cycle (8x speedup for dim=768)
178    /// - AVX512: 64 int8 ops per cycle (16x speedup)
179    /// - NEON: 16 int8 ops per cycle (4x speedup)
180    pub fn score(&self, vid: VectorId, query_i8: &[i8], query_scale: f32) -> f32 {
181        // Delegate to score_with_fp32 with None for outlier query values
182        // This maintains backward compatibility while the approximation is used
183        self.score_with_fp32(vid, query_i8, query_scale, None)
184    }
185
186    /// Compute dot product score with optional fp32 query for accurate outlier computation.
187    ///
188    /// When `query_fp32` is provided, outlier contributions use exact fp32 values
189    /// instead of reconstructing from quantized int8, reducing error from O(1/127)
190    /// to floating-point epsilon.
191    ///
192    /// # Arguments
193    /// * `vid` - Vector ID to score
194    /// * `query_i8` - Quantized query vector (for main dot product)
195    /// * `query_scale` - Query quantization scale
196    /// * `query_fp32` - Optional original fp32 query (for accurate outlier scoring)
197    pub fn score_with_fp32(
198        &self,
199        vid: VectorId,
200        query_i8: &[i8],
201        query_scale: f32,
202        query_fp32: Option<&[f32]>,
203    ) -> f32 {
204        let vid = vid as usize;
205        let offset = vid * self.dim;
206
207        if offset + self.dim > self.i8_data.len() {
208            return f32::NEG_INFINITY;
209        }
210
211        let vec_i8 = &self.i8_data[offset..offset + self.dim];
212        let vec_scale = self.scales[vid];
213
214        // SIMD-accelerated int8 dot product via C++ FFI
215        let dot_i8: i32 = DotI8Dispatcher::dot(&query_i8[..self.dim], vec_i8);
216
217        // Dequantize
218        let mut score = (dot_i8 as f32) * query_scale * vec_scale / (127.0 * 127.0);
219
220        // Add outlier contributions
221        if self.num_outliers > 0 {
222            let outlier_offset = vid * self.num_outliers;
223            if outlier_offset + self.num_outliers <= self.outliers.len() {
224                let vec_outliers =
225                    &self.outliers[outlier_offset..outlier_offset + self.num_outliers];
226
227                for outlier in vec_outliers {
228                    let dim_id = outlier.dim_id as usize;
229                    if dim_id < self.dim {
230                        let v_val = outlier.get_value().to_f32();
231
232                        // Use fp32 query if available (accurate), otherwise approximate from int8
233                        let q_val = if let Some(fp32) = query_fp32 {
234                            // Exact fp32 value - no quantization error
235                            fp32[dim_id]
236                        } else {
237                            // Approximate: reconstruct from int8 (introduces ~0.78% error per dim)
238                            (query_i8[dim_id] as f32) * query_scale / 127.0
239                        };
240
241                        score += q_val * v_val;
242                    }
243                }
244            }
245        }
246
247        score
248    }
249
250    /// Score multiple candidates in batch
251    pub fn score_batch(
252        &self,
253        candidates: &[VectorId],
254        query_i8: &[i8],
255        query_scale: f32,
256    ) -> Vec<ScoredCandidate> {
257        candidates
258            .iter()
259            .map(|&vid| ScoredCandidate {
260                id: vid,
261                score: self.score(vid, query_i8, query_scale),
262            })
263            .collect()
264    }
265
266    /// Score multiple candidates with fp32 query for accurate outlier computation
267    pub fn score_batch_with_fp32(
268        &self,
269        candidates: &[VectorId],
270        query_i8: &[i8],
271        query_scale: f32,
272        query_fp32: &[f32],
273    ) -> Vec<ScoredCandidate> {
274        candidates
275            .iter()
276            .map(|&vid| ScoredCandidate {
277                id: vid,
278                score: self.score_with_fp32(vid, query_i8, query_scale, Some(query_fp32)),
279            })
280            .collect()
281    }
282
283    /// Rerank and return top R candidates
284    pub fn rerank(
285        &self,
286        candidates: &[VectorId],
287        query_i8: &[i8],
288        query_scale: f32,
289        r: usize,
290    ) -> Vec<ScoredCandidate> {
291        let mut scored = self.score_batch(candidates, query_i8, query_scale);
292
293        if scored.len() <= r {
294            scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
295            return scored;
296        }
297
298        scored.select_nth_unstable_by(r - 1, |a, b| b.score.partial_cmp(&a.score).unwrap());
299        scored.truncate(r);
300        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
301
302        scored
303    }
304
305    /// Rerank with fp32 query for accurate outlier computation
306    pub fn rerank_with_fp32(
307        &self,
308        candidates: &[VectorId],
309        query_i8: &[i8],
310        query_scale: f32,
311        query_fp32: &[f32],
312        r: usize,
313    ) -> Vec<ScoredCandidate> {
314        let mut scored = self.score_batch_with_fp32(candidates, query_i8, query_scale, query_fp32);
315
316        if scored.len() <= r {
317            scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
318            return scored;
319        }
320
321        scored.select_nth_unstable_by(r - 1, |a, b| b.score.partial_cmp(&a.score).unwrap());
322        scored.truncate(r);
323        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
324
325        scored
326    }
327}
328
329/// Quantize a query vector for reranking
330pub fn quantize_query(query: &[f32], config: &RerankConfig) -> (Vec<i8>, f32) {
331    // Compute scale using percentile
332    let mut abs_values: Vec<f32> = query.iter().map(|&v| v.abs()).collect();
333    abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
334
335    let idx = ((abs_values.len() as f32) * config.scale_percentile) as usize;
336    let idx = idx.min(abs_values.len() - 1);
337    let scale = abs_values[idx].max(1e-10);
338
339    // Quantize
340    let inv_scale = 1.0 / scale;
341    let i8_data: Vec<i8> = query
342        .iter()
343        .map(|&v| (v * inv_scale * 127.0).clamp(-127.0, 127.0) as i8)
344        .collect();
345
346    (i8_data, scale)
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Building 100 vectors of 64 dimensions with 4 outliers each must yield
    /// buffers of the matching sizes.
    #[test]
    fn test_rerank_build() {
        const N_VECTORS: usize = 100;
        const DIM: usize = 64;

        let config = RerankConfig {
            num_outliers: 4,
            percentile_quantization: true,
            scale_percentile: 0.99,
        };

        // The first four dimensions carry large, row-dependent values; the
        // rest form a small ramp centred on zero.
        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(N_VECTORS);
        for row in 0..N_VECTORS {
            let components: Vec<f32> = (0..DIM)
                .map(|col| {
                    if col < 4 {
                        (row as f32 + col as f32) * 0.1
                    } else {
                        (col as f32 - 32.0) * 0.01
                    }
                })
                .collect();
            vectors.push(components);
        }

        let builder = RerankBuilder::new(&config, &vectors);
        let (quantized, per_vector_scales) = builder.build_i8();
        let outlier_entries = builder.build_outliers();

        assert_eq!(quantized.len(), 100 * 64);
        assert_eq!(per_vector_scales.len(), 100);
        assert_eq!(outlier_entries.len(), 100 * 4);
    }

    /// A query equal to the first stored vector must score that vector
    /// strictly higher than the other two.
    #[test]
    fn test_dot_product() {
        let config = RerankConfig {
            num_outliers: 2,
            percentile_quantization: true,
            scale_percentile: 0.99,
        };

        // Create orthogonal-ish vectors
        let vectors: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.5, 0.5, 0.0, 0.0],
        ];

        let builder = RerankBuilder::new(&config, &vectors);
        let (quantized, per_vector_scales) = builder.build_i8();
        let outlier_entries = builder.build_outliers();
        let reranker = Reranker::new(&quantized, &per_vector_scales, &outlier_entries, 4, 2);

        // Query similar to first vector
        let query = vec![1.0f32, 0.0, 0.0, 0.0];
        let (query_i8, query_scale) = quantize_query(&query, &config);

        let scores: Vec<f32> = (0..3)
            .map(|vid| reranker.score(vid, &query_i8, query_scale))
            .collect();

        // Vector 0 should have highest score (most similar to query)
        assert!(scores[0] > scores[1]);
        assert!(scores[0] > scores[2]);
    }
}