velesdb_core/index/hnsw/native/quantization.rs

//! Scalar Quantization (SQ8) for fast HNSW traversal.
//!
//! Based on VSAG paper (arXiv:2503.17911): dual-precision architecture
//! using int8 for graph traversal and float32 for final re-ranking.
//!
//! # Performance Benefits
//!
//! - **4x memory bandwidth reduction** during traversal
//! - **SIMD-friendly**: 32 int8 values fit in 256-bit register (vs 8 float32)
//! - **Cache efficiency**: More vectors fit in L1/L2 cache
//!
//! # Algorithm
//!
//! For each dimension:
//! - Compute min/max from training data
//! - Scale to [0, 255] range: `q = round((x - min) / (max - min) * 255)`
//! - Store scale and offset for reconstruction
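//!
//! # Example
//!
//! A minimal usage sketch of the types defined in this module (illustrative
//! values only):
//!
//! ```ignore
//! // Learn per-dimension min/scale from a training sample.
//! let v1 = vec![0.0_f32, -1.0, 10.0];
//! let v2 = vec![5.0_f32, 1.0, 20.0];
//! let quantizer = ScalarQuantizer::train(&[&v1, &v2]);
//!
//! // Int8 codes used during graph traversal.
//! let q1 = quantizer.quantize(&v1);
//! let q2 = quantizer.quantize(&v2);
//! let traversal_dist = quantizer.distance_l2_quantized(&q1, &q2);
//!
//! // Asymmetric distance keeps the query in float32.
//! let query = vec![2.0_f32, 0.0, 15.0];
//! let rerank_dist = quantizer.distance_l2_asymmetric(&query, &q2);
//! ```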

use std::sync::Arc;

// =============================================================================
// SIMD-optimized distance computation for int8 quantized vectors
// =============================================================================

/// Computes squared L2 distance between two quantized vectors.
///
/// Uses 8-wide unrolling with four independent accumulators for better
/// instruction-level parallelism; the loop is written so the compiler can
/// auto-vectorize it (on x86_64 with AVX2 this typically maps to 256-bit
/// integer operations).
///
/// # Performance
///
/// - **4x memory bandwidth reduction** vs float32
/// - **Better SIMD utilization**: 32 int8 values fit in a 256-bit register vs 8 float32
#[inline]
fn distance_l2_quantized_simd(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());

    // Process in chunks of 8 for better ILP (Instruction Level Parallelism)
    let chunks = a.len() / 8;
    let remainder = a.len() % 8;

    let mut sum0: u32 = 0;
    let mut sum1: u32 = 0;
    let mut sum2: u32 = 0;
    let mut sum3: u32 = 0;

    // Main loop: 8-wide unrolling
    for i in 0..chunks {
        let base = i * 8;

        // Unroll 8 iterations with 4 accumulators
        let d0 = i32::from(a[base]) - i32::from(b[base]);
        let d1 = i32::from(a[base + 1]) - i32::from(b[base + 1]);
        let d2 = i32::from(a[base + 2]) - i32::from(b[base + 2]);
        let d3 = i32::from(a[base + 3]) - i32::from(b[base + 3]);
        let d4 = i32::from(a[base + 4]) - i32::from(b[base + 4]);
        let d5 = i32::from(a[base + 5]) - i32::from(b[base + 5]);
        let d6 = i32::from(a[base + 6]) - i32::from(b[base + 6]);
        let d7 = i32::from(a[base + 7]) - i32::from(b[base + 7]);

        sum0 += (d0 * d0) as u32 + (d4 * d4) as u32;
        sum1 += (d1 * d1) as u32 + (d5 * d5) as u32;
        sum2 += (d2 * d2) as u32 + (d6 * d6) as u32;
        sum3 += (d3 * d3) as u32 + (d7 * d7) as u32;
    }

    // Handle remainder
    let base = chunks * 8;
    for i in 0..remainder {
        let diff = i32::from(a[base + i]) - i32::from(b[base + i]);
        sum0 += (diff * diff) as u32;
    }

    sum0 + sum1 + sum2 + sum3
}

/// Computes asymmetric L2 distance: float32 query vs quantized candidate.
///
/// The candidate is dequantized on the fly using the precomputed per-dimension
/// inverse scales and offsets, while the query stays in float32. This follows
/// the dual-precision (asymmetric distance) idea from the VSAG paper.
#[inline]
fn distance_l2_asymmetric_simd(
    query: &[f32],
    quantized: &[u8],
    min_vals: &[f32],
    inv_scales: &[f32],
) -> f32 {
    debug_assert_eq!(query.len(), quantized.len());
    debug_assert_eq!(query.len(), min_vals.len());
    debug_assert_eq!(query.len(), inv_scales.len());

    // Process in chunks of 4 for SIMD-friendly access
    let chunks = query.len() / 4;
    let remainder = query.len() % 4;

    let mut sum0: f32 = 0.0;
    let mut sum1: f32 = 0.0;
    let mut sum2: f32 = 0.0;
    let mut sum3: f32 = 0.0;

    for i in 0..chunks {
        let base = i * 4;

        // Dequantize and compute squared difference
        let dq0 = f32::from(quantized[base]) * inv_scales[base] + min_vals[base];
        let dq1 = f32::from(quantized[base + 1]) * inv_scales[base + 1] + min_vals[base + 1];
        let dq2 = f32::from(quantized[base + 2]) * inv_scales[base + 2] + min_vals[base + 2];
        let dq3 = f32::from(quantized[base + 3]) * inv_scales[base + 3] + min_vals[base + 3];

        let d0 = query[base] - dq0;
        let d1 = query[base + 1] - dq1;
        let d2 = query[base + 2] - dq2;
        let d3 = query[base + 3] - dq3;

        sum0 += d0 * d0;
        sum1 += d1 * d1;
        sum2 += d2 * d2;
        sum3 += d3 * d3;
    }

    // Handle remainder
    let base = chunks * 4;
    for i in 0..remainder {
        let idx = base + i;
        let dq = f32::from(quantized[idx]) * inv_scales[idx] + min_vals[idx];
        let diff = query[idx] - dq;
        sum0 += diff * diff;
    }

    (sum0 + sum1 + sum2 + sum3).sqrt()
}

/// Quantization parameters learned from training data.
#[derive(Debug, Clone)]
pub struct ScalarQuantizer {
    /// Minimum value per dimension
    pub min_vals: Vec<f32>,
    /// Scale factor per dimension: 255 / (max - min)
    pub scales: Vec<f32>,
    /// Inverse scale factor: 1 / scale (precomputed for fast dequantization)
    pub inv_scales: Vec<f32>,
    /// Vector dimension
    pub dimension: usize,
}

/// Quantized vector storage (int8 per dimension).
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// Quantized values [0, 255]
    pub data: Vec<u8>,
}

/// Quantized vector storage with shared quantizer reference.
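///
/// Vector `i` occupies bytes `i * dimension .. (i + 1) * dimension` of the
/// backing buffer, so candidates can be read as zero-copy slices.
///
/// # Example
///
/// A minimal sketch (illustrative values only):
///
/// ```ignore
/// let quantizer = Arc::new(ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]));
/// let mut store = QuantizedVectorStore::new(quantizer.clone(), 16);
/// store.push(&[2.0, 3.0]);
/// store.push(&[7.0, 8.0]);
///
/// let a = store.get_slice(0).unwrap();
/// let b = store.get_slice(1).unwrap();
/// let dist = store.quantizer().distance_l2_quantized_slice(a, b);
/// ```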
#[derive(Debug, Clone)]
pub struct QuantizedVectorStore {
    /// Shared quantizer parameters
    quantizer: Arc<ScalarQuantizer>,
    /// Quantized vectors (flattened: node_id * dimension + dim_idx)
    data: Vec<u8>,
    /// Number of vectors stored
    count: usize,
}

impl ScalarQuantizer {
    /// Creates a new quantizer from training vectors.
    ///
    /// # Arguments
    ///
    /// * `vectors` - Training vectors to compute min/max per dimension
    ///
    /// # Panics
    ///
    /// Panics if vectors is empty or vectors have inconsistent dimensions.
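    ///
    /// # Example
    ///
    /// A small sketch with made-up training data:
    ///
    /// ```ignore
    /// let a = vec![0.0_f32, 10.0];
    /// let b = vec![5.0_f32, 20.0];
    /// let quantizer = ScalarQuantizer::train(&[&a, &b]);
    /// assert_eq!(quantizer.dimension, 2);
    /// // Per-dimension scale is 255 / (max - min): 255/5 for dim 0, 255/10 for dim 1.
    /// ```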
    #[must_use]
    pub fn train(vectors: &[&[f32]]) -> Self {
        assert!(!vectors.is_empty(), "Cannot train on empty vectors");
        let dimension = vectors[0].len();
        assert!(
            vectors.iter().all(|v| v.len() == dimension),
            "All vectors must have same dimension"
        );

        let mut min_vals = vec![f32::MAX; dimension];
        let mut max_vals = vec![f32::MIN; dimension];

        // Find min/max per dimension
        for vec in vectors {
            for (i, &val) in vec.iter().enumerate() {
                min_vals[i] = min_vals[i].min(val);
                max_vals[i] = max_vals[i].max(val);
            }
        }

        // Compute scales (avoid division by zero)
        let scales: Vec<f32> = min_vals
            .iter()
            .zip(max_vals.iter())
            .map(|(&min, &max)| {
                let range = max - min;
                if range.abs() < 1e-10 {
                    1.0 // Constant dimension, scale doesn't matter
                } else {
                    255.0 / range
                }
            })
            .collect();

        // Precompute inverse scales for fast dequantization
        let inv_scales: Vec<f32> = scales.iter().map(|&s| 1.0 / s).collect();

        Self {
            min_vals,
            scales,
            inv_scales,
            dimension,
        }
    }

    /// Quantizes a float32 vector to int8.
    #[must_use]
    pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
        debug_assert_eq!(vector.len(), self.dimension);

        let data: Vec<u8> = vector
            .iter()
            .zip(self.min_vals.iter())
            .zip(self.scales.iter())
            .map(|((&val, &min), &scale)| {
                let q = ((val - min) * scale).round();
                q.clamp(0.0, 255.0) as u8
            })
            .collect();

        QuantizedVector { data }
    }

    /// Dequantizes an int8 vector back to float32.
    #[must_use]
    pub fn dequantize(&self, quantized: &QuantizedVector) -> Vec<f32> {
        debug_assert_eq!(quantized.data.len(), self.dimension);

        quantized
            .data
            .iter()
            .zip(self.min_vals.iter())
            .zip(self.inv_scales.iter())
            .map(|((&q, &min), &inv_scale)| {
                // x = q * inv_scale + min (multiplication is faster than division)
                f32::from(q) * inv_scale + min
            })
            .collect()
    }

    /// Computes the approximate squared L2 distance between two quantized vectors
    /// (computed directly in the int8 code space).
    ///
    /// This is ~4x faster than float32 due to reduced memory traffic and SIMD efficiency.
    #[inline]
    #[must_use]
    pub fn distance_l2_quantized(&self, a: &QuantizedVector, b: &QuantizedVector) -> u32 {
        debug_assert_eq!(a.data.len(), b.data.len());
        distance_l2_quantized_simd(&a.data, &b.data)
    }

    /// Computes the approximate squared L2 distance using raw slices (zero-copy).
    ///
    /// Useful for the `QuantizedVectorStore::get_slice()` access pattern.
    #[inline]
    #[must_use]
    pub fn distance_l2_quantized_slice(&self, a: &[u8], b: &[u8]) -> u32 {
        debug_assert_eq!(a.len(), b.len());
        distance_l2_quantized_simd(a, b)
    }

    /// Computes the approximate L2 distance between a float32 query and a quantized candidate.
    ///
    /// Asymmetric distance: the query stays in float32 while candidates stay in int8,
    /// following the dual-precision approach described in the VSAG paper.
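    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative values only):
    ///
    /// ```ignore
    /// let quantizer = ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]);
    /// let candidate = quantizer.quantize(&[7.0, 8.0]);
    /// let dist = quantizer.distance_l2_asymmetric(&[2.0, 3.0], &candidate);
    /// assert!(dist > 0.0);
    /// ```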
    #[inline]
    #[must_use]
    pub fn distance_l2_asymmetric(&self, query: &[f32], quantized: &QuantizedVector) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(quantized.data.len(), self.dimension);

        distance_l2_asymmetric_simd(query, &quantized.data, &self.min_vals, &self.inv_scales)
    }

    /// Computes asymmetric L2 distance using raw slice (zero-copy).
    #[inline]
    #[must_use]
    pub fn distance_l2_asymmetric_slice(&self, query: &[f32], quantized: &[u8]) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(quantized.len(), self.dimension);

        distance_l2_asymmetric_simd(query, quantized, &self.min_vals, &self.inv_scales)
    }
}

impl QuantizedVectorStore {
    /// Creates a new quantized vector store.
    #[must_use]
    pub fn new(quantizer: Arc<ScalarQuantizer>, capacity: usize) -> Self {
        let dimension = quantizer.dimension;
        Self {
            quantizer,
            data: Vec::with_capacity(capacity * dimension),
            count: 0,
        }
    }

    /// Adds a vector to the store (quantizes it first).
    pub fn push(&mut self, vector: &[f32]) {
        let quantized = self.quantizer.quantize(vector);
        self.data.extend(quantized.data);
        self.count += 1;
    }

    /// Gets a quantized vector by index.
    #[must_use]
    pub fn get(&self, index: usize) -> Option<QuantizedVector> {
        if index >= self.count {
            return None;
        }
        let start = index * self.quantizer.dimension;
        let end = start + self.quantizer.dimension;
        Some(QuantizedVector {
            data: self.data[start..end].to_vec(),
        })
    }

    /// Gets raw slice for a quantized vector (zero-copy).
    #[must_use]
    pub fn get_slice(&self, index: usize) -> Option<&[u8]> {
        if index >= self.count {
            return None;
        }
        let start = index * self.quantizer.dimension;
        let end = start + self.quantizer.dimension;
        Some(&self.data[start..end])
    }

    /// Returns the number of vectors.
    #[must_use]
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns true if empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns reference to quantizer.
    #[must_use]
    pub fn quantizer(&self) -> &ScalarQuantizer {
        &self.quantizer
    }
}

#[cfg(test)]
#[allow(clippy::similar_names)]
mod tests {
    use super::*;

    // =========================================================================
    // TDD Tests: ScalarQuantizer training
    // =========================================================================

    #[test]
    fn test_train_computes_correct_min_max() {
        let v1 = vec![0.0, 10.0, -5.0];
        let v2 = vec![5.0, 20.0, 5.0];
        let v3 = vec![2.5, 15.0, 0.0];

        let quantizer = ScalarQuantizer::train(&[&v1, &v2, &v3]);

        assert_eq!(quantizer.dimension, 3);
        assert!((quantizer.min_vals[0] - 0.0).abs() < 1e-6);
        assert!((quantizer.min_vals[1] - 10.0).abs() < 1e-6);
        assert!((quantizer.min_vals[2] - (-5.0)).abs() < 1e-6);

        // Scale = 255 / (max - min)
        assert!((quantizer.scales[0] - 255.0 / 5.0).abs() < 1e-4);
        assert!((quantizer.scales[1] - 255.0 / 10.0).abs() < 1e-4);
        assert!((quantizer.scales[2] - 255.0 / 10.0).abs() < 1e-4);
    }

    #[test]
    fn test_train_handles_constant_dimension() {
        let v1 = vec![1.0, 5.0, 5.0]; // dim 1 and 2 are constant
        let v2 = vec![2.0, 5.0, 5.0];

        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        // Constant dimensions should have scale = 1.0 (fallback)
        assert!((quantizer.scales[1] - 1.0).abs() < 1e-6);
        assert!((quantizer.scales[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Cannot train on empty vectors")]
    fn test_train_panics_on_empty() {
        let _: ScalarQuantizer = ScalarQuantizer::train(&[]);
    }

    // =========================================================================
    // TDD Tests: Quantization and dequantization
    // =========================================================================

    #[test]
    fn test_quantize_min_becomes_zero() {
        let v = vec![0.0, 100.0];
        let quantizer = ScalarQuantizer::train(&[&v]);

        let qvec = quantizer.quantize(&[0.0, 100.0]);

        // With a single training vector, min == max per dimension, so the scale
        // falls back to 1.0 and the minimum (dim 0 here) quantizes to 0.
        assert_eq!(qvec.data[0], 0);
    }

    #[test]
    fn test_quantize_range_maps_correctly() {
        let v1 = vec![0.0, 0.0];
        let v2 = vec![10.0, 100.0];
        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        // Test min values -> 0
        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min.data[0], 0);
        assert_eq!(q_min.data[1], 0);

        // Test max values -> 255
        let q_max = quantizer.quantize(&[10.0, 100.0]);
        assert_eq!(q_max.data[0], 255);
        assert_eq!(q_max.data[1], 255);

        // Test mid values -> ~127-128
        let q_mid = quantizer.quantize(&[5.0, 50.0]);
        assert!((i32::from(q_mid.data[0]) - 127).abs() <= 1);
        assert!((i32::from(q_mid.data[1]) - 127).abs() <= 1);
    }

    #[test]
    fn test_quantize_clamps_out_of_range() {
        let v1 = vec![0.0];
        let v2 = vec![10.0];
        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        // Value below training min
        let q_low = quantizer.quantize(&[-5.0]);
        assert_eq!(q_low.data[0], 0, "Should clamp to 0");

        // Value above training max
        let q_high = quantizer.quantize(&[20.0]);
        assert_eq!(q_high.data[0], 255, "Should clamp to 255");
    }

    #[test]
    fn test_dequantize_recovers_approximate_values() {
        let v1 = vec![0.0, -10.0, 100.0];
        let v2 = vec![10.0, 10.0, 200.0];
        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        let original = vec![5.0, 0.0, 150.0];
        let qvec = quantizer.quantize(&original);
        let recovered = quantizer.dequantize(&qvec);

        // Should be approximately equal (quantization error < 1% of range)
        for (i, (&orig, &rec)) in original.iter().zip(recovered.iter()).enumerate() {
            let range = v2[i] - v1[i];
            let error = (orig - rec).abs();
            let relative_error = error / range;
            assert!(
                relative_error < 0.01,
                "Dim {i}: orig={orig}, rec={rec}, error={relative_error:.4}"
            );
        }
    }

    // =========================================================================
    // TDD Tests: Distance computation
    // =========================================================================

    #[test]
    fn test_distance_l2_quantized_identical_is_zero() {
        let quantizer = ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]);
        let v = quantizer.quantize(&[5.0, 5.0]);

        let dist = quantizer.distance_l2_quantized(&v, &v);
        assert_eq!(dist, 0, "Distance to self should be 0");
    }

    #[test]
    fn test_distance_l2_quantized_symmetry() {
        let quantizer = ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]);
        let a = quantizer.quantize(&[2.0, 3.0]);
        let b = quantizer.quantize(&[7.0, 8.0]);

        let dist_ab = quantizer.distance_l2_quantized(&a, &b);
        let dist_ba = quantizer.distance_l2_quantized(&b, &a);

        assert_eq!(dist_ab, dist_ba, "Distance should be symmetric");
    }

    #[test]
    fn test_distance_l2_asymmetric_close_to_exact() {
        let v1 = vec![0.0; 128];
        let v2 = vec![10.0; 128];
        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        let query = vec![3.0; 128];
        let candidate = vec![7.0; 128];

        let quantized_candidate = quantizer.quantize(&candidate);
        let approx_dist = quantizer.distance_l2_asymmetric(&query, &quantized_candidate);

        // Exact L2 distance
        let exact_dist: f32 = query
            .iter()
            .zip(candidate.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        // Asymmetric distance should be within 5% of exact
        let relative_error = (approx_dist - exact_dist).abs() / exact_dist;
        assert!(
            relative_error < 0.05,
            "approx={approx_dist}, exact={exact_dist}, error={relative_error:.4}"
        );
    }

    // =========================================================================
    // TDD Tests: QuantizedVectorStore
    // =========================================================================

    #[test]
    fn test_store_push_and_get() {
        let quantizer = Arc::new(ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]));
        let mut store = QuantizedVectorStore::new(quantizer.clone(), 100);

        store.push(&[2.0, 3.0]);
        store.push(&[7.0, 8.0]);

        assert_eq!(store.len(), 2);

        let v0 = store.get(0).expect("Should have index 0");
        let v1 = store.get(1).expect("Should have index 1");

        // Verify values are different
        assert_ne!(v0.data, v1.data);
    }

    #[test]
    fn test_store_get_out_of_bounds_returns_none() {
        let quantizer = Arc::new(ScalarQuantizer::train(&[&[0.0], &[10.0]]));
        let store = QuantizedVectorStore::new(quantizer, 100);

        assert!(store.get(0).is_none());
        assert!(store.get(100).is_none());
    }

    #[test]
    fn test_store_get_slice_zero_copy() {
        let quantizer = Arc::new(ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]));
        let mut store = QuantizedVectorStore::new(quantizer.clone(), 100);

        store.push(&[5.0, 5.0]);

        let slice = store.get_slice(0).expect("Should have slice");
        assert_eq!(slice.len(), 2);

        // Verify it's the expected quantized value (~127)
        assert!((i32::from(slice[0]) - 127).abs() <= 1);
        assert!((i32::from(slice[1]) - 127).abs() <= 1);
    }

    // =========================================================================
    // TDD Tests: Memory efficiency
    // =========================================================================

    #[test]
    fn test_memory_efficiency_4x_reduction() {
        let dim = 768;
        let count = 10_000;

        // Float32 storage: 768 * 4 * 10000 = 30.72 MB
        let float32_bytes = dim * 4 * count;

        // Int8 storage: 768 * 1 * 10000 = 7.68 MB
        let int8_bytes = dim * count;

        assert_eq!(float32_bytes / int8_bytes, 4, "Should be 4x reduction");
    }

    // =========================================================================
    // TDD Tests: High-dimensional vectors (realistic embedding sizes)
    // =========================================================================

    #[test]
    fn test_quantize_768d_embedding() {
        // Typical embedding size (BERT, etc.)
        let v1: Vec<f32> = (0..768).map(|i| (i as f32 * 0.01).sin()).collect();
        let v2: Vec<f32> = (0..768).map(|i| (i as f32 * 0.01).cos()).collect();

        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);
        assert_eq!(quantizer.dimension, 768);

        let qvec = quantizer.quantize(&v1);
        assert_eq!(qvec.data.len(), 768);

        let recovered = quantizer.dequantize(&qvec);
        assert_eq!(recovered.len(), 768);

        // Check reconstruction error is reasonable
        let mse: f32 = v1
            .iter()
            .zip(recovered.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / 768.0;

        assert!(mse < 0.001, "MSE should be small: {mse}");
    }
}