1use crate::config::RerankConfig;
7use crate::dispatch::DotI8Dispatcher;
8use crate::types::*;
9use half::f16;
10
/// Builds the quantized rerank data (i8 codes, per-vector scales and outlier
/// entries) from a borrowed set of `f32` vectors.
pub struct RerankBuilder<'a> {
    /// Settings read during the build (`num_outliers`, `scale_percentile`).
    config: &'a RerankConfig,
    /// Input vectors, one `Vec<f32>` per vector; passed to `new` as
    /// `rotated_vectors`.
    vectors: &'a [Vec<f32>],
}
16
17impl<'a> RerankBuilder<'a> {
18 pub fn new(config: &'a RerankConfig, rotated_vectors: &'a [Vec<f32>]) -> Self {
20 Self {
21 config,
22 vectors: rotated_vectors,
23 }
24 }
25
26 pub fn build_i8(&self) -> (Vec<i8>, Vec<f32>) {
29 let n_vec = self.vectors.len();
30 if n_vec == 0 {
31 return (Vec::new(), Vec::new());
32 }
33
34 let dim = self.vectors[0].len();
35 let mut i8_data = Vec::with_capacity(n_vec * dim);
36 let mut scales = Vec::with_capacity(n_vec);
37
38 for vec in self.vectors {
39 let outlier_indices = self.find_outlier_indices(vec);
41
42 let scale = self.compute_scale(vec, &outlier_indices);
44 scales.push(scale);
45
46 let inv_scale = if scale > 1e-10 { 1.0 / scale } else { 0.0 };
48 for (i, &v) in vec.iter().enumerate() {
49 if outlier_indices.contains(&(i as u16)) {
50 i8_data.push(0);
52 } else {
53 let quantized = (v * inv_scale * 127.0).clamp(-127.0, 127.0) as i8;
54 i8_data.push(quantized);
55 }
56 }
57 }
58
59 (i8_data, scales)
60 }
61
62 pub fn build_outliers(&self) -> Vec<OutlierEntry> {
64 let n_vec = self.vectors.len();
65 let num_outliers = self.config.num_outliers as usize;
66 let mut outliers = Vec::with_capacity(n_vec * num_outliers);
67
68 for vec in self.vectors {
69 let outlier_entries = self.extract_outliers(vec);
70 for entry in outlier_entries {
71 outliers.push(entry);
72 }
73 }
74
75 outliers
76 }
77
78 fn find_outlier_indices(&self, vec: &[f32]) -> Vec<DimIndex> {
80 let num_outliers = self.config.num_outliers as usize;
81 if num_outliers == 0 {
82 return Vec::new();
83 }
84
85 let mut indexed: Vec<(usize, f32)> =
86 vec.iter().enumerate().map(|(i, &v)| (i, v.abs())).collect();
87
88 if indexed.len() <= num_outliers {
89 return indexed.iter().map(|&(i, _)| i as DimIndex).collect();
90 }
91
92 indexed.select_nth_unstable_by(num_outliers - 1, |a, b| b.1.partial_cmp(&a.1).unwrap());
93
94 indexed
95 .iter()
96 .take(num_outliers)
97 .map(|&(i, _)| i as DimIndex)
98 .collect()
99 }
100
101 fn compute_scale(&self, vec: &[f32], outlier_indices: &[DimIndex]) -> f32 {
103 let mut values: Vec<f32> = vec
105 .iter()
106 .enumerate()
107 .filter(|&(i, _)| !outlier_indices.contains(&(i as DimIndex)))
108 .map(|(_, &v)| v.abs())
109 .collect();
110
111 if values.is_empty() {
112 return 1.0;
113 }
114
115 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
116
117 let idx = ((values.len() as f32) * self.config.scale_percentile) as usize;
119 let idx = idx.min(values.len() - 1);
120
121 values[idx].max(1e-10)
122 }
123
124 fn extract_outliers(&self, vec: &[f32]) -> Vec<OutlierEntry> {
126 let num_outliers = self.config.num_outliers as usize;
127 let mut entries = Vec::with_capacity(num_outliers);
128
129 let mut indexed: Vec<(usize, f32)> = vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
130
131 indexed.sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap());
133
134 for &(dim_id, value) in indexed.iter().take(num_outliers) {
135 entries.push(OutlierEntry::new(dim_id as DimIndex, f16::from_f32(value)));
136 }
137
138 while entries.len() < num_outliers {
140 entries.push(OutlierEntry::new(0, f16::from_f32(0.0)));
141 }
142
143 entries
144 }
145}
146
/// Scores candidate vectors against a quantized query using borrowed,
/// pre-built rerank data.
pub struct Reranker<'a> {
    /// Flat i8 codes; vector `vid` occupies `vid * dim .. (vid + 1) * dim`.
    i8_data: &'a [i8],
    /// Per-vector scale, indexed by vector id.
    scales: &'a [f32],
    /// Flat outlier entries; vector `vid` occupies
    /// `vid * num_outliers .. (vid + 1) * num_outliers`.
    outliers: &'a [OutlierEntry],
    /// Number of i8 codes per vector.
    dim: usize,
    /// Number of outlier entries per vector.
    num_outliers: usize,
}
155
156impl<'a> Reranker<'a> {
157 pub fn new(
159 i8_data: &'a [i8],
160 scales: &'a [f32],
161 outliers: &'a [OutlierEntry],
162 dim: usize,
163 num_outliers: usize,
164 ) -> Self {
165 Self {
166 i8_data,
167 scales,
168 outliers,
169 dim,
170 num_outliers,
171 }
172 }
173
174 pub fn score(&self, vid: VectorId, query_i8: &[i8], query_scale: f32) -> f32 {
181 self.score_with_fp32(vid, query_i8, query_scale, None)
184 }
185
186 pub fn score_with_fp32(
198 &self,
199 vid: VectorId,
200 query_i8: &[i8],
201 query_scale: f32,
202 query_fp32: Option<&[f32]>,
203 ) -> f32 {
204 let vid = vid as usize;
205 let offset = vid * self.dim;
206
207 if offset + self.dim > self.i8_data.len() {
208 return f32::NEG_INFINITY;
209 }
210
211 let vec_i8 = &self.i8_data[offset..offset + self.dim];
212 let vec_scale = self.scales[vid];
213
214 let dot_i8: i32 = DotI8Dispatcher::dot(&query_i8[..self.dim], vec_i8);
216
217 let mut score = (dot_i8 as f32) * query_scale * vec_scale / (127.0 * 127.0);
219
220 if self.num_outliers > 0 {
222 let outlier_offset = vid * self.num_outliers;
223 if outlier_offset + self.num_outliers <= self.outliers.len() {
224 let vec_outliers =
225 &self.outliers[outlier_offset..outlier_offset + self.num_outliers];
226
227 for outlier in vec_outliers {
228 let dim_id = outlier.dim_id as usize;
229 if dim_id < self.dim {
230 let v_val = outlier.get_value().to_f32();
231
232 let q_val = if let Some(fp32) = query_fp32 {
234 fp32[dim_id]
236 } else {
237 (query_i8[dim_id] as f32) * query_scale / 127.0
239 };
240
241 score += q_val * v_val;
242 }
243 }
244 }
245 }
246
247 score
248 }
249
250 pub fn score_batch(
252 &self,
253 candidates: &[VectorId],
254 query_i8: &[i8],
255 query_scale: f32,
256 ) -> Vec<ScoredCandidate> {
257 candidates
258 .iter()
259 .map(|&vid| ScoredCandidate {
260 id: vid,
261 score: self.score(vid, query_i8, query_scale),
262 })
263 .collect()
264 }
265
266 pub fn score_batch_with_fp32(
268 &self,
269 candidates: &[VectorId],
270 query_i8: &[i8],
271 query_scale: f32,
272 query_fp32: &[f32],
273 ) -> Vec<ScoredCandidate> {
274 candidates
275 .iter()
276 .map(|&vid| ScoredCandidate {
277 id: vid,
278 score: self.score_with_fp32(vid, query_i8, query_scale, Some(query_fp32)),
279 })
280 .collect()
281 }
282
283 pub fn rerank(
285 &self,
286 candidates: &[VectorId],
287 query_i8: &[i8],
288 query_scale: f32,
289 r: usize,
290 ) -> Vec<ScoredCandidate> {
291 let mut scored = self.score_batch(candidates, query_i8, query_scale);
292
293 if scored.len() <= r {
294 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
295 return scored;
296 }
297
298 scored.select_nth_unstable_by(r - 1, |a, b| b.score.partial_cmp(&a.score).unwrap());
299 scored.truncate(r);
300 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
301
302 scored
303 }
304
305 pub fn rerank_with_fp32(
307 &self,
308 candidates: &[VectorId],
309 query_i8: &[i8],
310 query_scale: f32,
311 query_fp32: &[f32],
312 r: usize,
313 ) -> Vec<ScoredCandidate> {
314 let mut scored = self.score_batch_with_fp32(candidates, query_i8, query_scale, query_fp32);
315
316 if scored.len() <= r {
317 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
318 return scored;
319 }
320
321 scored.select_nth_unstable_by(r - 1, |a, b| b.score.partial_cmp(&a.score).unwrap());
322 scored.truncate(r);
323 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
324
325 scored
326 }
327}
328
329pub fn quantize_query(query: &[f32], config: &RerankConfig) -> (Vec<i8>, f32) {
331 let mut abs_values: Vec<f32> = query.iter().map(|&v| v.abs()).collect();
333 abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
334
335 let idx = ((abs_values.len() as f32) * config.scale_percentile) as usize;
336 let idx = idx.min(abs_values.len() - 1);
337 let scale = abs_values[idx].max(1e-10);
338
339 let inv_scale = 1.0 / scale;
341 let i8_data: Vec<i8> = query
342 .iter()
343 .map(|&v| (v * inv_scale * 127.0).clamp(-127.0, 127.0) as i8)
344 .collect();
345
346 (i8_data, scale)
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Building from 100 vectors of 64 components with 4 outliers each must
    /// yield outputs of the matching sizes.
    #[test]
    fn test_rerank_build() {
        const N_VECTORS: usize = 100;
        const DIM: usize = 64;

        let config = RerankConfig {
            num_outliers: 4,
            percentile_quantization: true,
            scale_percentile: 0.99,
        };

        // The first four components grow with the vector index; the rest
        // follow a small ramp centred on component 32.
        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(N_VECTORS);
        for row in 0..N_VECTORS {
            let mut components = Vec::with_capacity(DIM);
            for col in 0..DIM {
                let value = if col < 4 {
                    (row as f32 + col as f32) * 0.1
                } else {
                    (col as f32 - 32.0) * 0.01
                };
                components.push(value);
            }
            vectors.push(components);
        }

        let builder = RerankBuilder::new(&config, &vectors);
        let outliers = builder.build_outliers();
        let (i8_data, scales) = builder.build_i8();

        assert_eq!(i8_data.len(), N_VECTORS * DIM);
        assert_eq!(scales.len(), N_VECTORS);
        assert_eq!(outliers.len(), N_VECTORS * 4);
    }

    /// A query equal to the first vector must score that vector above the
    /// other two.
    #[test]
    fn test_dot_product() {
        let config = RerankConfig {
            num_outliers: 2,
            percentile_quantization: true,
            scale_percentile: 0.99,
        };

        let vectors: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.5, 0.5, 0.0, 0.0],
        ];

        let builder = RerankBuilder::new(&config, &vectors);
        let (i8_data, scales) = builder.build_i8();
        let outliers = builder.build_outliers();
        let reranker = Reranker::new(&i8_data, &scales, &outliers, 4, 2);

        let query = vec![1.0f32, 0.0, 0.0, 0.0];
        let (q_i8, q_scale) = quantize_query(&query, &config);

        let scores: Vec<f32> = (0..3)
            .map(|vid| reranker.score(vid, &q_i8, q_scale))
            .collect();

        assert!(scores[0] > scores[1]);
        assert!(scores[0] > scores[2]);
    }
}