Skip to main content

sochdb_vector/
compressed_routing.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! # Cache-Resident Routing Layer (Task 4)
19//!
20//! Ensures routing fits in LLC (Last-Level Cache) to bound latency variance.
21//!
22//! ## Architecture
23//!
24//! Two-stage routing with compressed centroids:
25//! 1. **Coarse stage**: FP16/int8 centroids in compressed space (fits in LLC)
26//! 2. **Fine stage**: Refine top candidates (optionally in full precision)
27//!
28//! ## Math/Algorithm
29//!
30//! Cache complexity constraint: ensure routing working set W ≤ LLC_size
31//!
32//! For C centroids of dimension d:
33//! - FP32: 4·C·d bytes (often exceeds LLC)
34//! - FP16: 2·C·d bytes (50% reduction)
35//! - Int8: C·d bytes (75% reduction)
36//! - PQ: C·⌈m·b/8⌉ bytes for m sub-quantizers of b bits each (e.g. m = d/4, b = 8 gives C·d/4)
37//!
38//! Multi-stage ranking: O(C·d_compressed + k·d_full)
39//!
40//! ## Usage
41//!
42//! ```rust,ignore
43//! use sochdb_vector::compressed_routing::{RoutingLayer, RoutingConfig, CentroidCompression};
44//!
45//! let config = RoutingConfig::default()
46//!     .compression(CentroidCompression::Fp16)
47//!     .refine_top_k(32);
48//!
49//! let routing = RoutingLayer::build(&centroids, dim, config);
50//! let top_lists = routing.route(&query, 16);
51//! ```
52
53use crate::list_bounds::{DistanceMetric, SphericalCapMetadata};
54
55// ============================================================================
56// Centroid Compression
57// ============================================================================
58
/// Compression method for the coarse (first-stage) routing centroids.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CentroidCompression {
    /// Full FP32 precision (baseline, 4 bytes per element)
    Fp32,
    /// FP16 half precision (2 bytes per element, 2x compression)
    Fp16,
    /// Int8 quantization (1 byte per element, 4x compression)
    Int8,
    /// Product Quantization: `n_subquantizers` codes of `n_bits` bits each per centroid
    PQ { n_subquantizers: usize, n_bits: u8 },
    /// OPQ + PQ (optimized rotation before PQ); same code size as `PQ`
    OPQ { n_subquantizers: usize, n_bits: u8 },
}

impl CentroidCompression {
    /// Bytes needed to store one `dim`-dimensional centroid with this compression.
    ///
    /// PQ/OPQ codes are packed bit-wise and rounded up to whole bytes; `dim` does not enter
    /// their size and codebook storage is not counted.
    pub fn bytes_per_centroid(&self, dim: usize) -> usize {
        match *self {
            Self::Fp32 => dim * 4,
            Self::Fp16 => dim * 2,
            Self::Int8 => dim,
            Self::PQ {
                n_subquantizers,
                n_bits,
            }
            | Self::OPQ {
                n_subquantizers,
                n_bits,
            } => {
                // Each sub-quantizer emits `n_bits`; pack them and round up to whole bytes.
                (n_subquantizers * usize::from(n_bits) + 7) / 8
            }
        }
    }

    /// Whether `n_centroids` centroids of dimension `dim` fit in `cache_bytes`.
    ///
    /// A total size that overflows `usize` is reported as "does not fit" instead of
    /// panicking (debug builds) or wrapping to a small bogus value (release builds).
    pub fn fits_in_cache(&self, n_centroids: usize, dim: usize, cache_bytes: usize) -> bool {
        matches!(
            self.bytes_per_centroid(dim).checked_mul(n_centroids),
            Some(total) if total <= cache_bytes
        )
    }

    /// Pick the least aggressive compression whose centroid table fits in `cache_bytes`.
    ///
    /// Candidates are tried in order FP32, FP16, Int8, then PQ with `dim / 4` 8-bit
    /// sub-quantizers. If none fits, PQ with 16 sub-quantizers of 4 bits (8 bytes per
    /// centroid) is returned without checking that it fits.
    ///
    /// NOTE(review): for `dim < 4` the PQ candidate has zero sub-quantizers (0 bytes) and
    /// therefore always "fits"; confirm whether such a configuration should be rejected.
    pub fn recommend(n_centroids: usize, dim: usize, cache_bytes: usize) -> Self {
        let candidates = [
            Self::Fp32,
            Self::Fp16,
            Self::Int8,
            Self::PQ {
                n_subquantizers: dim / 4,
                n_bits: 8,
            },
        ];
        for compression in candidates {
            if compression.fits_in_cache(n_centroids, dim, cache_bytes) {
                return compression;
            }
        }
        // Fall back to aggressive PQ
        Self::PQ {
            n_subquantizers: 16,
            n_bits: 4,
        }
    }
}
123
124// ============================================================================
125// Routing Configuration
126// ============================================================================
127
128/// Configuration for routing layer
129#[derive(Debug, Clone)]
130pub struct RoutingConfig {
131    /// Compression method for coarse centroids
132    pub compression: CentroidCompression,
133
134    /// Number of top lists to refine in second stage
135    pub refine_top_k: usize,
136
137    /// Use full precision for refinement
138    pub full_precision_refine: bool,
139
140    /// Target LLC size (for cache-awareness)
141    pub target_llc_bytes: usize,
142
143    /// Distance metric
144    pub metric: DistanceMetric,
145
146    /// Prefetch depth for sequential access
147    pub prefetch_depth: usize,
148}
149
150impl Default for RoutingConfig {
151    fn default() -> Self {
152        Self {
153            compression: CentroidCompression::Fp16,
154            refine_top_k: 64,
155            full_precision_refine: true,
156            target_llc_bytes: 32 * 1024 * 1024, // 32 MB LLC
157            metric: DistanceMetric::Cosine,
158            prefetch_depth: 4,
159        }
160    }
161}
162
163impl RoutingConfig {
164    /// Set compression method
165    pub fn compression(mut self, compression: CentroidCompression) -> Self {
166        self.compression = compression;
167        self
168    }
169
170    /// Set refinement count
171    pub fn refine_top_k(mut self, k: usize) -> Self {
172        self.refine_top_k = k;
173        self
174    }
175
176    /// Set target LLC size
177    pub fn target_llc(mut self, bytes: usize) -> Self {
178        self.target_llc_bytes = bytes;
179        self
180    }
181
182    /// Set distance metric
183    pub fn metric(mut self, metric: DistanceMetric) -> Self {
184        self.metric = metric;
185        self
186    }
187}
188
189// ============================================================================
190// Compressed Centroid Storage
191// ============================================================================
192
193/// FP16 encoded centroid storage
194#[derive(Debug, Clone)]
195pub struct Fp16Centroids {
196    /// Packed FP16 data (2 bytes per element)
197    data: Vec<u16>,
198    /// Number of centroids
199    n_centroids: usize,
200    /// Dimension
201    dim: usize,
202}
203
204impl Fp16Centroids {
205    /// Build from FP32 centroids
206    pub fn from_fp32(centroids: &[f32], dim: usize) -> Self {
207        let n_centroids = centroids.len() / dim;
208        let data: Vec<u16> = centroids.iter().map(|&x| f32_to_f16(x)).collect();
209
210        Self {
211            data,
212            n_centroids,
213            dim,
214        }
215    }
216
217    /// Get centroid as FP32 (for refinement)
218    pub fn get_fp32(&self, idx: usize) -> Vec<f32> {
219        let start = idx * self.dim;
220        self.data[start..start + self.dim]
221            .iter()
222            .map(|&x| f16_to_f32(x))
223            .collect()
224    }
225
226    /// Compute dot products with query in FP16
227    pub fn dot_products(&self, query: &[f32]) -> Vec<f32> {
228        let query_f16: Vec<u16> = query.iter().map(|&x| f32_to_f16(x)).collect();
229
230        (0..self.n_centroids)
231            .map(|i| {
232                let start = i * self.dim;
233                let centroid = &self.data[start..start + self.dim];
234                dot_f16(centroid, &query_f16)
235            })
236            .collect()
237    }
238
239    /// Memory footprint in bytes
240    pub fn memory_bytes(&self) -> usize {
241        self.data.len() * 2
242    }
243}
244
/// Int8-quantized centroid storage.
///
/// Every dimension `j` has its own affine code: a stored `q: i8` stands for
/// `q as f32 * scales[j] + zero_points[j]`. The per-dimension range `[min_j, max_j]` over all
/// centroids is spread across the full signed range `-128..=127` (255 steps).
#[derive(Debug, Clone)]
pub struct Int8Centroids {
    /// Quantized codes, row-major: `n_centroids` rows of `dim` codes.
    data: Vec<i8>,
    /// Size of one quantization step, per dimension.
    scales: Vec<f32>,
    /// Dequantized value of code `0`, per dimension.
    zero_points: Vec<f32>,
    /// Number of centroids
    n_centroids: usize,
    /// Dimension
    dim: usize,
}

impl Int8Centroids {
    /// Quantize row-major FP32 centroids of width `dim` with per-dimension scales.
    ///
    /// Elements past the last whole centroid are quantized too but do not influence the
    /// per-dimension ranges.
    ///
    /// # Panics
    ///
    /// Panics if `dim == 0`.
    pub fn from_fp32(centroids: &[f32], dim: usize) -> Self {
        let n_centroids = centroids.len() / dim;

        // Per-dimension min/max over all whole centroids.
        let mut mins = vec![f32::MAX; dim];
        let mut maxs = vec![f32::MIN; dim];
        for row in centroids.chunks_exact(dim) {
            for (j, &val) in row.iter().enumerate() {
                mins[j] = mins[j].min(val);
                maxs[j] = maxs[j].max(val);
            }
        }

        let mut scales = Vec::with_capacity(dim);
        let mut zero_points = Vec::with_capacity(dim);
        for (&lo, &hi) in mins.iter().zip(&maxs) {
            let range = hi - lo;
            let (scale, zero_point) = if range > 1e-10 {
                let step = range / 255.0;
                // Anchor the minimum at code -128 so [min, max] maps onto the whole signed
                // range -128..=127. Anchoring it at code 0 would clip every value above the
                // midpoint of the range to code 127.
                (step, lo + 128.0 * step)
            } else {
                // Constant dimension: every value encodes as code 0 and decodes exactly.
                (1.0, lo)
            };
            scales.push(scale);
            zero_points.push(zero_point);
        }

        let data: Vec<i8> = centroids
            .iter()
            .enumerate()
            .map(|(idx, &val)| {
                let j = idx % dim;
                let q = ((val - zero_points[j]) / scales[j]).round() as i32;
                q.clamp(-128, 127) as i8
            })
            .collect();

        Self {
            data,
            scales,
            zero_points,
            n_centroids,
            dim,
        }
    }

    /// Get centroid `idx` as FP32 (dequantized).
    pub fn get_fp32(&self, idx: usize) -> Vec<f32> {
        let start = idx * self.dim;
        (0..self.dim)
            .map(|j| self.data[start + j] as f32 * self.scales[j] + self.zero_points[j])
            .collect()
    }

    /// Dot product of `query` with every dequantized centroid.
    ///
    /// The result for centroid `i` equals `get_fp32(i) · query` up to float rounding. The
    /// query stays in FP32: quantizing it with the centroid ranges would clamp any component
    /// outside those ranges. Extra query components beyond `dim` are ignored.
    pub fn dot_products(&self, query: &[f32]) -> Vec<f32> {
        // c_j ≈ q_j * s_j + z_j, so c · x ≈ Σ q_j (s_j x_j) + Σ z_j x_j.
        // Fold the per-dimension scale into the query once, and hoist the zero-point term,
        // which is identical for every centroid.
        let weights: Vec<f32> = query
            .iter()
            .zip(&self.scales)
            .map(|(&x, &s)| x * s)
            .collect();
        let offset: f32 = query
            .iter()
            .zip(&self.zero_points)
            .map(|(&x, &z)| x * z)
            .sum();

        (0..self.n_centroids)
            .map(|i| {
                let start = i * self.dim;
                let codes = &self.data[start..start + self.dim];
                let acc: f32 = codes
                    .iter()
                    .zip(&weights)
                    .map(|(&q, &w)| f32::from(q) * w)
                    .sum();
                acc + offset
            })
            .collect()
    }

    /// Memory footprint in bytes (codes plus per-dimension scales and zero points).
    pub fn memory_bytes(&self) -> usize {
        self.data.len() + self.scales.len() * 4 + self.zero_points.len() * 4
    }
}
351
352// ============================================================================
353// Routing Layer
354// ============================================================================
355
356/// Compressed routing layer for cache-resident operations
357pub struct RoutingLayer {
358    /// Compressed centroids for coarse search
359    compressed: CompressedCentroids,
360
361    /// Full-precision centroids for refinement (optional)
362    full_precision: Option<Vec<f32>>,
363
364    /// Spherical cap metadata per list
365    caps: Vec<SphericalCapMetadata>,
366
367    /// Configuration
368    config: RoutingConfig,
369
370    /// Dimension
371    dim: usize,
372
373    /// Number of lists
374    n_lists: usize,
375}
376
377/// Enum for different compression types
378enum CompressedCentroids {
379    Fp32(Vec<f32>),
380    Fp16(Fp16Centroids),
381    Int8(Int8Centroids),
382}
383
384impl RoutingLayer {
385    /// Build routing layer from FP32 centroids
386    pub fn build(centroids: &[f32], dim: usize, config: RoutingConfig) -> Self {
387        let n_lists = centroids.len() / dim;
388
389        // Build compressed centroids
390        let compressed = match config.compression {
391            CentroidCompression::Fp32 => CompressedCentroids::Fp32(centroids.to_vec()),
392            CentroidCompression::Fp16 => {
393                CompressedCentroids::Fp16(Fp16Centroids::from_fp32(centroids, dim))
394            }
395            CentroidCompression::Int8 => {
396                CompressedCentroids::Int8(Int8Centroids::from_fp32(centroids, dim))
397            }
398            _ => {
399                // PQ/OPQ not implemented yet, fall back to FP16
400                CompressedCentroids::Fp16(Fp16Centroids::from_fp32(centroids, dim))
401            }
402        };
403
404        // Store full precision for refinement if configured
405        let full_precision = if config.full_precision_refine {
406            Some(centroids.to_vec())
407        } else {
408            None
409        };
410
411        // Build spherical cap metadata per list
412        let caps: Vec<SphericalCapMetadata> = (0..n_lists)
413            .map(|i| {
414                let centroid = &centroids[i * dim..(i + 1) * dim];
415                SphericalCapMetadata {
416                    centroid: centroid.to_vec(),
417                    theta_max: 0.0, // Will be updated when vectors are added
418                    min_dot_to_centroid: 1.0,
419                    max_dot_to_centroid: 1.0,
420                    vector_count: 0,
421                    mean_dot_to_centroid: 1.0,
422                }
423            })
424            .collect();
425
426        Self {
427            compressed,
428            full_precision,
429            caps,
430            config,
431            dim,
432            n_lists,
433        }
434    }
435
436    /// Route query to top-k lists
437    ///
438    /// Two-stage process:
439    /// 1. Coarse ranking using compressed centroids (cache-resident)
440    /// 2. Refine top candidates using full precision (optional)
441    pub fn route(&self, query: &[f32], n_probes: usize) -> Vec<ListCandidate> {
442        let n_probes = n_probes.min(self.n_lists);
443
444        // Stage 1: Coarse ranking with compressed centroids
445        let coarse_scores = self.coarse_scores(query);
446
447        // Get top-k indices for refinement
448        let refine_k = self.config.refine_top_k.min(self.n_lists);
449        let mut indices: Vec<usize> = (0..self.n_lists).collect();
450
451        // Partial sort for top-k
452        if self.config.metric.higher_is_better() {
453            indices.select_nth_unstable_by(refine_k - 1, |&a, &b| {
454                coarse_scores[b].partial_cmp(&coarse_scores[a]).unwrap()
455            });
456        } else {
457            indices.select_nth_unstable_by(refine_k - 1, |&a, &b| {
458                coarse_scores[a].partial_cmp(&coarse_scores[b]).unwrap()
459            });
460        }
461
462        let top_indices = &indices[..refine_k];
463
464        // Stage 2: Refine with full precision (if available)
465        let refined_scores = if let Some(ref full) = self.full_precision {
466            self.refine_scores(query, top_indices, full)
467        } else {
468            top_indices.iter().map(|&i| coarse_scores[i]).collect()
469        };
470
471        // Build candidates with bounds
472        let mut candidates: Vec<ListCandidate> = top_indices
473            .iter()
474            .zip(refined_scores.iter())
475            .map(|(&idx, &score)| ListCandidate {
476                list_idx: idx as u32,
477                score,
478                bound: self.compute_bound(idx, query),
479                vector_count: self.caps[idx].vector_count,
480            })
481            .collect();
482
483        // Sort by score and take top n_probes
484        if self.config.metric.higher_is_better() {
485            candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
486        } else {
487            candidates.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
488        }
489
490        candidates.truncate(n_probes);
491        candidates
492    }
493
494    /// Compute coarse scores using compressed centroids
495    fn coarse_scores(&self, query: &[f32]) -> Vec<f32> {
496        match &self.compressed {
497            CompressedCentroids::Fp32(data) => self.dot_products_fp32(query, data),
498            CompressedCentroids::Fp16(fp16) => fp16.dot_products(query),
499            CompressedCentroids::Int8(int8) => int8.dot_products(query),
500        }
501    }
502
503    /// Compute full-precision dot products
504    fn dot_products_fp32(&self, query: &[f32], centroids: &[f32]) -> Vec<f32> {
505        (0..self.n_lists)
506            .map(|i| {
507                let centroid = &centroids[i * self.dim..(i + 1) * self.dim];
508                dot_product_f32(query, centroid)
509            })
510            .collect()
511    }
512
513    /// Refine scores for selected indices
514    fn refine_scores(&self, query: &[f32], indices: &[usize], centroids: &[f32]) -> Vec<f32> {
515        indices
516            .iter()
517            .map(|&i| {
518                let centroid = &centroids[i * self.dim..(i + 1) * self.dim];
519                dot_product_f32(query, centroid)
520            })
521            .collect()
522    }
523
524    /// Compute bound for a list
525    fn compute_bound(&self, idx: usize, query: &[f32]) -> f32 {
526        let cap = &self.caps[idx];
527        let dot = dot_product_f32(query, &cap.centroid);
528        let angle = dot.clamp(-1.0, 1.0).acos();
529        let min_angle = (angle - cap.theta_max).max(0.0);
530        min_angle.cos()
531    }
532
533    /// Update spherical cap metadata for a list
534    pub fn update_cap(&mut self, list_idx: usize, cap: SphericalCapMetadata) {
535        if list_idx < self.caps.len() {
536            self.caps[list_idx] = cap;
537        }
538    }
539
540    /// Get memory footprint
541    pub fn memory_bytes(&self) -> usize {
542        let compressed_bytes = match &self.compressed {
543            CompressedCentroids::Fp32(data) => data.len() * 4,
544            CompressedCentroids::Fp16(fp16) => fp16.memory_bytes(),
545            CompressedCentroids::Int8(int8) => int8.memory_bytes(),
546        };
547
548        let full_bytes = self
549            .full_precision
550            .as_ref()
551            .map(|v| v.len() * 4)
552            .unwrap_or(0);
553
554        let cap_bytes = self.caps.len() * std::mem::size_of::<SphericalCapMetadata>();
555
556        compressed_bytes + full_bytes + cap_bytes
557    }
558
559    /// Check if routing layer fits in target cache
560    pub fn fits_in_cache(&self) -> bool {
561        let compressed_bytes = match &self.compressed {
562            CompressedCentroids::Fp32(data) => data.len() * 4,
563            CompressedCentroids::Fp16(fp16) => fp16.memory_bytes(),
564            CompressedCentroids::Int8(int8) => int8.memory_bytes(),
565        };
566
567        compressed_bytes <= self.config.target_llc_bytes
568    }
569
570    /// Get routing statistics
571    pub fn stats(&self) -> RoutingStats {
572        RoutingStats {
573            n_lists: self.n_lists,
574            dim: self.dim,
575            compression: format!("{:?}", self.config.compression),
576            compressed_bytes: match &self.compressed {
577                CompressedCentroids::Fp32(data) => data.len() * 4,
578                CompressedCentroids::Fp16(fp16) => fp16.memory_bytes(),
579                CompressedCentroids::Int8(int8) => int8.memory_bytes(),
580            },
581            total_bytes: self.memory_bytes(),
582            fits_in_cache: self.fits_in_cache(),
583            target_cache_bytes: self.config.target_llc_bytes,
584        }
585    }
586}
587
/// One list selected by `RoutingLayer::route`.
#[derive(Clone, Debug)]
pub struct ListCandidate {
    /// Index of the list (the row of its centroid).
    pub list_idx: u32,
    /// Query–centroid score (the refined score when full precision was available).
    pub score: f32,
    /// Optimistic bound on the best score any vector in this list can reach.
    pub bound: f32,
    /// How many vectors the list holds, according to its cap metadata.
    pub vector_count: u32,
}
600
/// Size and cache-fit summary of a routing layer, as produced by `RoutingLayer::stats`.
#[derive(Clone, Debug)]
pub struct RoutingStats {
    /// Number of lists (centroids).
    pub n_lists: usize,
    /// Elements per centroid.
    pub dim: usize,
    /// `Debug` rendering of the configured compression.
    pub compression: String,
    /// Bytes used by the compressed centroid table.
    pub compressed_bytes: usize,
    /// Total footprint as reported by `memory_bytes()`.
    pub total_bytes: usize,
    /// Whether the compressed table fits within `target_cache_bytes`.
    pub fits_in_cache: bool,
    /// Cache budget the layer was configured with.
    pub target_cache_bytes: usize,
}
612
613// ============================================================================
614// Helper Functions
615// ============================================================================
616
/// FP32 to FP16 conversion (IEEE 754 binary16) with round-to-nearest-even.
///
/// * ±Inf stays ±Inf. Any NaN stays NaN: a quiet-NaN mantissa bit is forced, so NaNs whose
///   payload lives only in the low 13 bits cannot collapse into Infinity.
/// * Finite values whose magnitude rounds above 65504 (the largest FP16 value) become ±Inf.
/// * Values below the FP16 normal range become FP16 subnormals. Anything that rounds below
///   the smallest subnormal (2^-24) becomes ±0, so FP32 subnormals always map to ±0.
#[inline]
fn f32_to_f16(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xff) as i32;
    let frac = bits & 0x007f_ffff;

    if exp == 0xff {
        // Inf (frac == 0) or NaN (frac != 0).
        let quiet: u16 = if frac != 0 { 0x0200 } else { 0 };
        return sign | 0x7c00 | quiet | (frac >> 13) as u16;
    }

    // Re-bias the exponent from FP32 (bias 127) to FP16 (bias 15).
    let half_exp = exp - 127 + 15;

    if half_exp >= 0x1f {
        // Overflow to infinity.
        return sign | 0x7c00;
    }

    if half_exp <= 0 {
        // Below the FP16 normal range: encode as a subnormal m * 2^-24 (or zero).
        if half_exp < -10 {
            // Less than half of 2^-24 even before rounding: flush to signed zero.
            return sign;
        }
        // Restore the implicit leading 1, then shift into subnormal position.
        // m = mant * 2^(exp - 126) = mant >> (14 - half_exp).
        let mant = frac | 0x0080_0000;
        let shift = (14 - half_exp) as u32;
        let mut half = mant >> shift;
        let rem = mant & ((1u32 << shift) - 1);
        let halfway = 1u32 << (shift - 1);
        if rem > halfway || (rem == halfway && (half & 1) == 1) {
            // A carry into bit 10 correctly yields the smallest normal number.
            half += 1;
        }
        return sign | half as u16;
    }

    // Normal range: keep the top 10 mantissa bits and round on the 13 dropped bits.
    let mut half = ((half_exp as u32) << 10) | (frac >> 13);
    let rem = frac & 0x1fff;
    if rem > 0x1000 || (rem == 0x1000 && (half & 1) == 1) {
        // A carry out of the mantissa bumps the exponent (up to Inf), as IEEE requires.
        half += 1;
    }
    sign | half as u16
}
650
/// FP16 (IEEE 754 binary16 bit pattern) to FP32 conversion.
///
/// Zeros keep their sign, subnormals are scaled by 2^-14 / 1024, infinities keep their sign,
/// and every NaN pattern decodes to `f32::NAN`.
#[inline]
fn f16_to_f32(x: u16) -> f32 {
    let sign_bit = u32::from(x >> 15) << 31;
    let exponent = u32::from((x >> 10) & 0x1f);
    let mantissa = u32::from(x & 0x03ff);

    match (exponent, mantissa) {
        (0, 0) => f32::from_bits(sign_bit),
        (0, m) => {
            // Subnormal: m / 1024 * 2^-14.
            let magnitude = (m as f32) / 1024.0 * 2.0f32.powi(-14);
            if sign_bit != 0 {
                -magnitude
            } else {
                magnitude
            }
        }
        (0x1f, 0) => f32::from_bits(sign_bit | 0x7f80_0000),
        (0x1f, _) => f32::NAN,
        // Normal: re-bias exponent 15 -> 127 (i.e. +112) and widen the mantissa.
        (e, m) => f32::from_bits(sign_bit | ((e + 112) << 23) | (m << 13)),
    }
}
677
678/// FP16 dot product
679#[inline]
680fn dot_f16(a: &[u16], b: &[u16]) -> f32 {
681    // Compute in FP32 for accuracy
682    a.iter()
683        .zip(b.iter())
684        .map(|(&x, &y)| f16_to_f32(x) * f16_to_f32(y))
685        .sum()
686}
687
/// FP32 dot product; elements past the end of the shorter slice are ignored.
#[inline]
fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
693
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compression_bytes() {
        let dim = 768;
        let cases = [
            (CentroidCompression::Fp32, 3072),
            (CentroidCompression::Fp16, 1536),
            (CentroidCompression::Int8, 768),
        ];
        for (compression, expected) in cases {
            assert_eq!(compression.bytes_per_centroid(dim), expected);
        }
    }

    #[test]
    fn test_compression_recommendation() {
        let cache_32mb = 32 * 1024 * 1024;
        let dim = 768;
        let recommend =
            |n_centroids| CentroidCompression::recommend(n_centroids, dim, cache_32mb);

        // 10k × 768 × 4 B ≈ 29.3 MiB: FP32 already fits.
        assert!(matches!(recommend(10_000), CentroidCompression::Fp32));

        // 20k: FP32 needs ≈ 58.6 MiB; FP16 (≈ 29.3 MiB) fits.
        assert!(matches!(recommend(20_000), CentroidCompression::Fp16));

        // 40k: FP32 (≈ 117 MiB) and FP16 (≈ 58.6 MiB) are too large; Int8 (≈ 29.3 MiB) fits.
        // At 50k even Int8 needs ≈ 36.6 MiB, so recommend() falls through to PQ there.
        assert!(matches!(recommend(40_000), CentroidCompression::Int8));
    }

    #[test]
    fn test_fp16_conversion() {
        for x in [0.0, 1.0, -1.0, 0.5, 0.123, 100.0, -100.0] {
            let f16 = f32_to_f16(x);
            let back = f16_to_f32(f16);
            let abs_error = (x - back).abs();
            let rel_error = if x.abs() > 1e-10 {
                abs_error / x.abs()
            } else {
                abs_error
            };
            assert!(
                rel_error < 0.01,
                "FP16 roundtrip error too high: {} -> {} -> {}",
                x,
                f16,
                back
            );
        }
    }

    #[test]
    fn test_routing_layer() {
        let dim = 4;
        let n_centroids = 10;
        let total = n_centroids * dim;
        let centroids: Vec<f32> = (0..total).map(|i| i as f32 / total as f32).collect();

        let config = RoutingConfig::default()
            .compression(CentroidCompression::Fp16)
            .refine_top_k(5);
        let routing = RoutingLayer::build(&centroids, dim, config);

        let query = [0.5f32; 4];
        let candidates = routing.route(&query, 3);

        assert_eq!(candidates.len(), 3);
        assert!(routing.fits_in_cache());
    }

    #[test]
    fn test_int8_centroids() {
        let dim = 4;
        let centroids = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];

        let int8 = Int8Centroids::from_fp32(&centroids, dim);

        // Dequantization only needs to be approximate.
        let recovered = int8.get_fp32(0);
        for (got, want) in recovered.iter().zip(&centroids[..dim]) {
            assert!((got - want).abs() < 0.1, "Int8 quantization error too high");
        }
    }
}