1use crate::list_bounds::{DistanceMetric, SphericalCapMetadata};
54
/// Storage format for the centroid table, listed from largest to smallest footprint.
///
/// [`CentroidCompression::bytes_per_centroid`] gives the per-centroid size of each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CentroidCompression {
    /// 4 bytes per dimension.
    Fp32,
    /// 2 bytes per dimension.
    Fp16,
    /// 1 byte per dimension.
    Int8,
    /// `n_subquantizers` codes of `n_bits` bits each, rounded up to whole bytes.
    PQ { n_subquantizers: usize, n_bits: u8 },
    /// Sized exactly like [`CentroidCompression::PQ`].
    OPQ { n_subquantizers: usize, n_bits: u8 },
}

impl CentroidCompression {
    /// Returns the number of bytes one `dim`-dimensional centroid occupies in this format.
    ///
    /// For `PQ`/`OPQ` the size depends only on the code layout, not on `dim`.
    /// Arithmetic saturates at `usize::MAX` instead of overflowing.
    pub fn bytes_per_centroid(&self, dim: usize) -> usize {
        match *self {
            Self::Fp32 => dim.saturating_mul(4),
            Self::Fp16 => dim.saturating_mul(2),
            Self::Int8 => dim,
            Self::PQ {
                n_subquantizers,
                n_bits,
            }
            | Self::OPQ {
                n_subquantizers,
                n_bits,
            } => {
                // Total code bits, rounded up to whole bytes.
                n_subquantizers
                    .saturating_mul(usize::from(n_bits))
                    .saturating_add(7)
                    / 8
            }
        }
    }

    /// Returns `true` when `n_centroids` centroids of dimension `dim` occupy at most
    /// `cache_bytes` bytes in this format.
    ///
    /// A table whose size does not fit in `usize` never fits.
    pub fn fits_in_cache(&self, n_centroids: usize, dim: usize, cache_bytes: usize) -> bool {
        self.bytes_per_centroid(dim)
            .checked_mul(n_centroids)
            .map_or(false, |total| total <= cache_bytes)
    }

    /// Picks the least-compressed format whose table fits in `cache_bytes`.
    ///
    /// Candidates are tried in the order `Fp32`, `Fp16`, `Int8`, then 8-bit `PQ` with
    /// `dim / 4` sub-quantizers (at least one). If none fits, a fixed
    /// `PQ { n_subquantizers: 16, n_bits: 4 }` is returned without checking that it fits.
    pub fn recommend(n_centroids: usize, dim: usize, cache_bytes: usize) -> Self {
        let candidates = [
            Self::Fp32,
            Self::Fp16,
            Self::Int8,
            Self::PQ {
                // At least one sub-quantizer: `dim / 4` is 0 for `dim < 4`, which would
                // describe a zero-byte code that trivially "fits".
                n_subquantizers: (dim / 4).max(1),
                n_bits: 8,
            },
        ];

        candidates
            .iter()
            .copied()
            .find(|candidate| candidate.fits_in_cache(n_centroids, dim, cache_bytes))
            .unwrap_or(Self::PQ {
                n_subquantizers: 16,
                n_bits: 4,
            })
    }
}
123
124#[derive(Debug, Clone)]
130pub struct RoutingConfig {
131 pub compression: CentroidCompression,
133
134 pub refine_top_k: usize,
136
137 pub full_precision_refine: bool,
139
140 pub target_llc_bytes: usize,
142
143 pub metric: DistanceMetric,
145
146 pub prefetch_depth: usize,
148}
149
150impl Default for RoutingConfig {
151 fn default() -> Self {
152 Self {
153 compression: CentroidCompression::Fp16,
154 refine_top_k: 64,
155 full_precision_refine: true,
156 target_llc_bytes: 32 * 1024 * 1024, metric: DistanceMetric::Cosine,
158 prefetch_depth: 4,
159 }
160 }
161}
162
163impl RoutingConfig {
164 pub fn compression(mut self, compression: CentroidCompression) -> Self {
166 self.compression = compression;
167 self
168 }
169
170 pub fn refine_top_k(mut self, k: usize) -> Self {
172 self.refine_top_k = k;
173 self
174 }
175
176 pub fn target_llc(mut self, bytes: usize) -> Self {
178 self.target_llc_bytes = bytes;
179 self
180 }
181
182 pub fn metric(mut self, metric: DistanceMetric) -> Self {
184 self.metric = metric;
185 self
186 }
187}
188
/// Centroid table stored as 16-bit float bit patterns.
#[derive(Debug, Clone)]
pub struct Fp16Centroids {
    /// Half-precision bit patterns, one centroid after another, `dim` values each.
    data: Vec<u16>,
    /// Number of whole centroids in `data`.
    n_centroids: usize,
    /// Number of values per centroid.
    dim: usize,
}
203
204impl Fp16Centroids {
205 pub fn from_fp32(centroids: &[f32], dim: usize) -> Self {
207 let n_centroids = centroids.len() / dim;
208 let data: Vec<u16> = centroids.iter().map(|&x| f32_to_f16(x)).collect();
209
210 Self {
211 data,
212 n_centroids,
213 dim,
214 }
215 }
216
217 pub fn get_fp32(&self, idx: usize) -> Vec<f32> {
219 let start = idx * self.dim;
220 self.data[start..start + self.dim]
221 .iter()
222 .map(|&x| f16_to_f32(x))
223 .collect()
224 }
225
226 pub fn dot_products(&self, query: &[f32]) -> Vec<f32> {
228 let query_f16: Vec<u16> = query.iter().map(|&x| f32_to_f16(x)).collect();
229
230 (0..self.n_centroids)
231 .map(|i| {
232 let start = i * self.dim;
233 let centroid = &self.data[start..start + self.dim];
234 dot_f16(centroid, &query_f16)
235 })
236 .collect()
237 }
238
239 pub fn memory_bytes(&self) -> usize {
241 self.data.len() * 2
242 }
243}
244
/// Centroid table quantized to signed 8-bit codes with one affine mapping per dimension.
///
/// The value represented by code `c` in dimension `j` is `c * scales[j] + zero_points[j]`.
#[derive(Debug, Clone)]
pub struct Int8Centroids {
    /// Quantized codes, one centroid after another, `dim` codes each.
    data: Vec<i8>,
    /// Per-dimension step between adjacent codes.
    scales: Vec<f32>,
    /// Per-dimension value represented by code `0`.
    zero_points: Vec<f32>,
    /// Number of whole centroids in `data`.
    n_centroids: usize,
    /// Number of values per centroid.
    dim: usize,
}

impl Int8Centroids {
    /// Quantizes a flat FP32 centroid table (`dim` values per centroid) to 8-bit codes.
    ///
    /// Each dimension gets its own scale and zero point, chosen so that the smallest value
    /// seen in that dimension maps to code `-128` and the largest to `127`. A dimension
    /// whose values are all (nearly) equal is stored as code `0` with a scale of `1.0`.
    ///
    /// Only whole centroids are kept: trailing values that do not fill a complete row are
    /// dropped. A `dim` of zero, or an input with no complete row, yields an empty table.
    pub fn from_fp32(centroids: &[f32], dim: usize) -> Self {
        let n_centroids = centroids.len().checked_div(dim).unwrap_or(0);
        if n_centroids == 0 {
            return Self {
                data: Vec::new(),
                scales: vec![1.0; dim],
                zero_points: vec![0.0; dim],
                n_centroids: 0,
                dim,
            };
        }
        let rows = &centroids[..n_centroids * dim];

        // Per-dimension value range over all centroids.
        let mut mins = vec![f32::MAX; dim];
        let mut maxs = vec![f32::MIN; dim];
        for row in rows.chunks_exact(dim) {
            for (j, &val) in row.iter().enumerate() {
                mins[j] = mins[j].min(val);
                maxs[j] = maxs[j].max(val);
            }
        }

        let (scales, zero_points): (Vec<f32>, Vec<f32>) = mins
            .iter()
            .zip(&maxs)
            .map(|(&lo, &hi)| {
                let range = hi - lo;
                if range > 1e-10 {
                    let scale = range / 255.0;
                    // Shift the zero point by 128 steps so `lo` lands on code -128 and
                    // `hi` on 127; anchoring code 0 at `lo` would need codes 0..=255,
                    // which do not fit in an i8.
                    (scale, lo + 128.0 * scale)
                } else {
                    (1.0, lo)
                }
            })
            .unzip();

        let data: Vec<i8> = rows
            .iter()
            .enumerate()
            .map(|(idx, &val)| {
                let j = idx % dim;
                let code = ((val - zero_points[j]) / scales[j]).round() as i32;
                code.clamp(-128, 127) as i8
            })
            .collect();

        Self {
            data,
            scales,
            zero_points,
            n_centroids,
            dim,
        }
    }

    /// Returns centroid `idx` dequantized to FP32.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not a valid centroid index.
    pub fn get_fp32(&self, idx: usize) -> Vec<f32> {
        let start = idx * self.dim;
        self.data[start..start + self.dim]
            .iter()
            .zip(self.scales.iter().zip(&self.zero_points))
            .map(|(&code, (&scale, &zero_point))| f32::from(code) * scale + zero_point)
            .collect()
    }

    /// Returns the dot product of `query` with every dequantized centroid, in centroid order.
    ///
    /// The query stays in FP32. With codes `c_j`, scales `s_j` and zero points `z_j`,
    /// `sum_j q_j * (c_j * s_j + z_j)` is evaluated as
    /// `sum_j c_j * (q_j * s_j) + sum_j q_j * z_j`, so the per-dimension scaling and the
    /// zero-point term are computed once per query rather than once per centroid.
    /// If `query` is shorter than `dim`, only the leading dimensions contribute.
    pub fn dot_products(&self, query: &[f32]) -> Vec<f32> {
        if self.dim == 0 {
            return vec![0.0; self.n_centroids];
        }

        let weights: Vec<f32> = query
            .iter()
            .zip(&self.scales)
            .map(|(&q, &scale)| q * scale)
            .collect();
        let bias: f32 = query
            .iter()
            .zip(&self.zero_points)
            .map(|(&q, &zero_point)| q * zero_point)
            .sum();

        self.data
            .chunks_exact(self.dim)
            .map(|codes| {
                let weighted: f32 = codes
                    .iter()
                    .zip(&weights)
                    .map(|(&code, &weight)| f32::from(code) * weight)
                    .sum();
                weighted + bias
            })
            .collect()
    }

    /// Returns the size in bytes of the codes plus the per-dimension scales and zero points.
    pub fn memory_bytes(&self) -> usize {
        self.data.len()
            + (self.scales.len() + self.zero_points.len()) * std::mem::size_of::<f32>()
    }
}
351
352pub struct RoutingLayer {
358 compressed: CompressedCentroids,
360
361 full_precision: Option<Vec<f32>>,
363
364 caps: Vec<SphericalCapMetadata>,
366
367 config: RoutingConfig,
369
370 dim: usize,
372
373 n_lists: usize,
375}
376
377enum CompressedCentroids {
379 Fp32(Vec<f32>),
380 Fp16(Fp16Centroids),
381 Int8(Int8Centroids),
382}
383
384impl RoutingLayer {
385 pub fn build(centroids: &[f32], dim: usize, config: RoutingConfig) -> Self {
387 let n_lists = centroids.len() / dim;
388
389 let compressed = match config.compression {
391 CentroidCompression::Fp32 => CompressedCentroids::Fp32(centroids.to_vec()),
392 CentroidCompression::Fp16 => {
393 CompressedCentroids::Fp16(Fp16Centroids::from_fp32(centroids, dim))
394 }
395 CentroidCompression::Int8 => {
396 CompressedCentroids::Int8(Int8Centroids::from_fp32(centroids, dim))
397 }
398 _ => {
399 CompressedCentroids::Fp16(Fp16Centroids::from_fp32(centroids, dim))
401 }
402 };
403
404 let full_precision = if config.full_precision_refine {
406 Some(centroids.to_vec())
407 } else {
408 None
409 };
410
411 let caps: Vec<SphericalCapMetadata> = (0..n_lists)
413 .map(|i| {
414 let centroid = ¢roids[i * dim..(i + 1) * dim];
415 SphericalCapMetadata {
416 centroid: centroid.to_vec(),
417 theta_max: 0.0, min_dot_to_centroid: 1.0,
419 max_dot_to_centroid: 1.0,
420 vector_count: 0,
421 mean_dot_to_centroid: 1.0,
422 }
423 })
424 .collect();
425
426 Self {
427 compressed,
428 full_precision,
429 caps,
430 config,
431 dim,
432 n_lists,
433 }
434 }
435
436 pub fn route(&self, query: &[f32], n_probes: usize) -> Vec<ListCandidate> {
442 let n_probes = n_probes.min(self.n_lists);
443
444 let coarse_scores = self.coarse_scores(query);
446
447 let refine_k = self.config.refine_top_k.min(self.n_lists);
449 let mut indices: Vec<usize> = (0..self.n_lists).collect();
450
451 if self.config.metric.higher_is_better() {
453 indices.select_nth_unstable_by(refine_k - 1, |&a, &b| {
454 coarse_scores[b].partial_cmp(&coarse_scores[a]).unwrap()
455 });
456 } else {
457 indices.select_nth_unstable_by(refine_k - 1, |&a, &b| {
458 coarse_scores[a].partial_cmp(&coarse_scores[b]).unwrap()
459 });
460 }
461
462 let top_indices = &indices[..refine_k];
463
464 let refined_scores = if let Some(ref full) = self.full_precision {
466 self.refine_scores(query, top_indices, full)
467 } else {
468 top_indices.iter().map(|&i| coarse_scores[i]).collect()
469 };
470
471 let mut candidates: Vec<ListCandidate> = top_indices
473 .iter()
474 .zip(refined_scores.iter())
475 .map(|(&idx, &score)| ListCandidate {
476 list_idx: idx as u32,
477 score,
478 bound: self.compute_bound(idx, query),
479 vector_count: self.caps[idx].vector_count,
480 })
481 .collect();
482
483 if self.config.metric.higher_is_better() {
485 candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
486 } else {
487 candidates.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
488 }
489
490 candidates.truncate(n_probes);
491 candidates
492 }
493
494 fn coarse_scores(&self, query: &[f32]) -> Vec<f32> {
496 match &self.compressed {
497 CompressedCentroids::Fp32(data) => self.dot_products_fp32(query, data),
498 CompressedCentroids::Fp16(fp16) => fp16.dot_products(query),
499 CompressedCentroids::Int8(int8) => int8.dot_products(query),
500 }
501 }
502
503 fn dot_products_fp32(&self, query: &[f32], centroids: &[f32]) -> Vec<f32> {
505 (0..self.n_lists)
506 .map(|i| {
507 let centroid = ¢roids[i * self.dim..(i + 1) * self.dim];
508 dot_product_f32(query, centroid)
509 })
510 .collect()
511 }
512
513 fn refine_scores(&self, query: &[f32], indices: &[usize], centroids: &[f32]) -> Vec<f32> {
515 indices
516 .iter()
517 .map(|&i| {
518 let centroid = ¢roids[i * self.dim..(i + 1) * self.dim];
519 dot_product_f32(query, centroid)
520 })
521 .collect()
522 }
523
524 fn compute_bound(&self, idx: usize, query: &[f32]) -> f32 {
526 let cap = &self.caps[idx];
527 let dot = dot_product_f32(query, &cap.centroid);
528 let angle = dot.clamp(-1.0, 1.0).acos();
529 let min_angle = (angle - cap.theta_max).max(0.0);
530 min_angle.cos()
531 }
532
533 pub fn update_cap(&mut self, list_idx: usize, cap: SphericalCapMetadata) {
535 if list_idx < self.caps.len() {
536 self.caps[list_idx] = cap;
537 }
538 }
539
540 pub fn memory_bytes(&self) -> usize {
542 let compressed_bytes = match &self.compressed {
543 CompressedCentroids::Fp32(data) => data.len() * 4,
544 CompressedCentroids::Fp16(fp16) => fp16.memory_bytes(),
545 CompressedCentroids::Int8(int8) => int8.memory_bytes(),
546 };
547
548 let full_bytes = self
549 .full_precision
550 .as_ref()
551 .map(|v| v.len() * 4)
552 .unwrap_or(0);
553
554 let cap_bytes = self.caps.len() * std::mem::size_of::<SphericalCapMetadata>();
555
556 compressed_bytes + full_bytes + cap_bytes
557 }
558
559 pub fn fits_in_cache(&self) -> bool {
561 let compressed_bytes = match &self.compressed {
562 CompressedCentroids::Fp32(data) => data.len() * 4,
563 CompressedCentroids::Fp16(fp16) => fp16.memory_bytes(),
564 CompressedCentroids::Int8(int8) => int8.memory_bytes(),
565 };
566
567 compressed_bytes <= self.config.target_llc_bytes
568 }
569
570 pub fn stats(&self) -> RoutingStats {
572 RoutingStats {
573 n_lists: self.n_lists,
574 dim: self.dim,
575 compression: format!("{:?}", self.config.compression),
576 compressed_bytes: match &self.compressed {
577 CompressedCentroids::Fp32(data) => data.len() * 4,
578 CompressedCentroids::Fp16(fp16) => fp16.memory_bytes(),
579 CompressedCentroids::Int8(int8) => int8.memory_bytes(),
580 },
581 total_bytes: self.memory_bytes(),
582 fits_in_cache: self.fits_in_cache(),
583 target_cache_bytes: self.config.target_llc_bytes,
584 }
585 }
586}
587
/// One list selected by [`RoutingLayer::route`].
#[derive(Debug, Clone)]
pub struct ListCandidate {
    /// Index of the list within the routing layer.
    pub list_idx: u32,
    /// Score used for the final ranking: the FP32 re-score when a full-precision copy is
    /// kept, otherwise the coarse score.
    pub score: f32,
    /// Bound derived from the list's cap metadata and the query.
    pub bound: f32,
    /// `vector_count` copied from the list's cap metadata.
    pub vector_count: u32,
}
600
/// Summary returned by [`RoutingLayer::stats`].
#[derive(Debug, Clone)]
pub struct RoutingStats {
    /// Number of lists in the layer.
    pub n_lists: usize,
    /// Number of values per centroid.
    pub dim: usize,
    /// `Debug` rendering of the configured compression.
    pub compression: String,
    /// Size in bytes of the compressed centroid table.
    pub compressed_bytes: usize,
    /// Value of [`RoutingLayer::memory_bytes`].
    pub total_bytes: usize,
    /// Whether the compressed table is within `target_cache_bytes`.
    pub fits_in_cache: bool,
    /// Configured byte budget.
    pub target_cache_bytes: usize,
}
612
/// Converts an `f32` to the bit pattern of the nearest IEEE 754 binary16 value.
///
/// Rounding is to nearest, ties to even. Magnitudes too large for binary16 become
/// infinity, values in the binary16 subnormal range are encoded as subnormals, and
/// anything smaller rounds to a signed zero. NaN stays NaN (quiet), keeping its sign and
/// the top mantissa bits.
#[inline]
fn f32_to_f16(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xff) as i32;
    let frac = bits & 0x007f_ffff;

    if exp == 0xff {
        if frac == 0 {
            return sign | 0x7c00;
        }
        // Force the quiet bit so a payload living only in the discarded low bits cannot
        // collapse the NaN into an infinity.
        return sign | 0x7c00 | 0x0200 | (frac >> 13) as u16;
    }

    // Re-bias the exponent from f32 (127) to binary16 (15).
    let half_exp = exp - 127 + 15;

    if half_exp >= 31 {
        return sign | 0x7c00;
    }

    if half_exp <= 0 {
        // Below 2^-25 (this includes every f32 zero and subnormal) the nearest binary16
        // value is zero.
        if half_exp < -10 {
            return sign;
        }
        // Binary16 subnormal: shift the 24-bit significand (implicit leading 1 restored)
        // down to a multiple of 2^-24.
        let significand = frac | 0x0080_0000;
        let shift = (14 - half_exp) as u32;
        let kept = significand >> shift;
        let dropped = significand & ((1u32 << shift) - 1);
        let halfway = 1u32 << (shift - 1);
        let round_up = dropped > halfway || (dropped == halfway && kept & 1 == 1);
        // A carry out of the mantissa yields 0x0400, the smallest normal value.
        return sign | (kept + u32::from(round_up)) as u16;
    }

    let kept = frac >> 13;
    let dropped = frac & 0x1fff;
    let mut magnitude = ((half_exp as u32) << 10) | kept;
    if dropped > 0x1000 || (dropped == 0x1000 && kept & 1 == 1) {
        // A mantissa carry bumps the exponent; from the top binade it yields infinity.
        magnitude += 1;
    }
    sign | magnitude as u16
}
650
/// Widens an IEEE 754 binary16 bit pattern to `f32`.
///
/// Every binary16 value, subnormals included, is exactly representable in `f32`, so the
/// conversion is lossless. NaN inputs produce a quiet NaN that keeps the input's sign and
/// mantissa payload.
#[inline]
fn f16_to_f32(x: u16) -> f32 {
    let bits = u32::from(x);
    let sign = (bits & 0x8000) << 16;
    let exp = (bits >> 10) & 0x1f;
    let frac = bits & 0x03ff;

    match exp {
        0 => {
            // Zero or subnormal: the value is `frac * 2^-24`. 0x3380_0000 is 2^-24.
            let magnitude = frac as f32 * f32::from_bits(0x3380_0000);
            f32::from_bits(sign | magnitude.to_bits())
        }
        0x1f => {
            if frac == 0 {
                f32::from_bits(sign | 0x7f80_0000)
            } else {
                // 0x0040_0000 is the quiet bit.
                f32::from_bits(sign | 0x7f80_0000 | 0x0040_0000 | (frac << 13))
            }
        }
        // Normal: re-bias the exponent from 15 to 127 (+112) and widen the mantissa.
        _ => f32::from_bits(sign | ((exp + 112) << 23) | (frac << 13)),
    }
}
677
678#[inline]
680fn dot_f16(a: &[u16], b: &[u16]) -> f32 {
681 a.iter()
683 .zip(b.iter())
684 .map(|(&x, &y)| f16_to_f32(x) * f16_to_f32(y))
685 .sum()
686}
687
/// Returns the dot product of two `f32` slices.
///
/// When the slices differ in length, only the leading common part contributes.
#[inline]
fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
}
693
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-centroid sizes of the fixed-width formats at dimension 768.
    #[test]
    fn test_compression_bytes() {
        let dim = 768;

        assert_eq!(CentroidCompression::Fp32.bytes_per_centroid(dim), 3072);
        assert_eq!(CentroidCompression::Fp16.bytes_per_centroid(dim), 1536);
        assert_eq!(CentroidCompression::Int8.bytes_per_centroid(dim), 768);
    }

    /// With a 32 MiB budget at dimension 768, larger tables need stronger compression.
    #[test]
    fn test_compression_recommendation() {
        let cache_32mb = 32 * 1024 * 1024;
        let dim = 768;

        // 10k * 3072 B fits as FP32.
        let rec1 = CentroidCompression::recommend(10_000, dim, cache_32mb);
        assert!(matches!(rec1, CentroidCompression::Fp32));

        // 20k is too large as FP32 but fits as FP16.
        let rec2 = CentroidCompression::recommend(20_000, dim, cache_32mb);
        assert!(matches!(rec2, CentroidCompression::Fp16));

        // 40k is too large as FP16 but fits as INT8.
        let rec3 = CentroidCompression::recommend(40_000, dim, cache_32mb);
        assert!(matches!(rec3, CentroidCompression::Int8));
    }

    /// FP32 -> FP16 -> FP32 stays within 1% relative error.
    #[test]
    fn test_fp16_conversion() {
        let values = [0.0, 1.0, -1.0, 0.5, 0.123, 100.0, -100.0];

        for &x in &values {
            let f16 = f32_to_f16(x);
            let back = f16_to_f32(f16);
            let rel_error = if x.abs() > 1e-10 {
                (x - back).abs() / x.abs()
            } else {
                (x - back).abs()
            };
            assert!(
                rel_error < 0.01,
                "FP16 roundtrip error too high: {} -> {} -> {}",
                x,
                f16,
                back
            );
        }
    }

    /// Routing over ten small centroids returns exactly the requested number of probes.
    #[test]
    fn test_routing_layer() {
        let dim = 4;
        let n_centroids = 10;
        let centroids: Vec<f32> = (0..n_centroids * dim)
            .map(|i| i as f32 / (n_centroids * dim) as f32)
            .collect();

        let config = RoutingConfig::default()
            .compression(CentroidCompression::Fp16)
            .refine_top_k(5);

        let routing = RoutingLayer::build(&centroids, dim, config);

        let query = vec![0.5, 0.5, 0.5, 0.5];
        let candidates = routing.route(&query, 3);

        assert_eq!(candidates.len(), 3);
        assert!(routing.fits_in_cache());
    }

    /// The first centroid survives INT8 quantization within a loose tolerance.
    #[test]
    fn test_int8_centroids() {
        let dim = 4;
        let centroids = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];

        let int8 = Int8Centroids::from_fp32(&centroids, dim);

        let recovered = int8.get_fp32(0);
        for i in 0..dim {
            let error = (recovered[i] - centroids[i]).abs();
            assert!(error < 0.1, "Int8 quantization error too high");
        }
    }
}