// sqlite_knowledge_graph/vector/turboquant.rs
//! TurboQuant: Near-optimal vector quantization for instant indexing
//!
//! Based on arXiv:2504.19874 (ICLR 2026)
//!
//! Key benefits:
//! - Indexing time: 239s → 0.001s (vs Product Quantization)
//! - Memory compression: 6x
//! - Zero accuracy loss
//! - No training required (data-oblivious)

use crate::error::{Error, Result};
use nalgebra::DMatrix;
use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Distribution, StandardNormal};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// TurboQuant configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantConfig {
    /// Vector dimension (number of f32 coordinates per embedding).
    pub dimension: usize,
    /// Bits per coordinate; must be in 1..=8 (validated by `TurboQuantIndex::new`).
    pub bit_width: usize,
    /// Random seed for the rotation matrix, making index construction reproducible.
    pub seed: u64,
}

29impl Default for TurboQuantConfig {
30    fn default() -> Self {
31        Self {
32            dimension: 384,
33            bit_width: 3,
34            seed: 42,
35        }
36    }
37}
38
/// TurboQuant index for fast approximate nearest neighbor search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantIndex {
    config: TurboQuantConfig,
    /// Random orthogonal rotation matrix (d × d), stored row-major.
    rotation_matrix: Vec<Vec<f32>>,
    /// Scalar quantizer codebook: one centroid per code, sorted ascending.
    codebook: Vec<f32>,
    /// Quantized vectors: entity_id -> per-coordinate codebook indices (1 byte each).
    quantized_vectors: HashMap<i64, Vec<u8>>,
    /// Euclidean norms of the original vectors (kept for potential re-ranking;
    /// not used by the current similarity computation).
    vector_norms: HashMap<i64, f32>,
}

/// Linear scan index for comparison (exact search).
pub struct LinearScanIndex {
    /// Shares the TurboQuant config; only `dimension` is used here.
    config: TurboQuantConfig,
    /// Full-precision vectors: entity_id -> f32 coordinates.
    vectors: HashMap<i64, Vec<f32>>,
}

59impl LinearScanIndex {
60    /// Create a new linear scan index
61    pub fn new(config: TurboQuantConfig) -> Result<Self> {
62        Ok(Self {
63            config,
64            vectors: HashMap::new(),
65        })
66    }
67
68    /// Add a vector to the index
69    pub fn add_vector(&mut self, entity_id: i64, vector: &[f32]) -> Result<()> {
70        if vector.len() != self.config.dimension {
71            return Err(Error::InvalidVectorDimension {
72                expected: self.config.dimension,
73                actual: vector.len(),
74            });
75        }
76        self.vectors.insert(entity_id, vector.to_vec());
77        Ok(())
78    }
79
80    /// Search for k nearest neighbors using exact cosine similarity
81    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
82        if query.len() != self.config.dimension {
83            return Err(Error::InvalidVectorDimension {
84                expected: self.config.dimension,
85                actual: query.len(),
86            });
87        }
88
89        let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
90
91        let mut results: Vec<(i64, f32)> = self
92            .vectors
93            .iter()
94            .map(|(&entity_id, vector)| {
95                let dot_product: f32 = query.iter().zip(vector.iter()).map(|(a, b)| a * b).sum();
96                let target_norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
97                let similarity = if query_norm > 0.0 && target_norm > 0.0 {
98                    dot_product / (query_norm * target_norm)
99                } else {
100                    0.0
101                };
102                (entity_id, similarity)
103            })
104            .collect();
105
106        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
107        results.truncate(k);
108
109        Ok(results)
110    }
111
112    /// Get index statistics
113    pub fn stats(&self) -> LinearScanStats {
114        LinearScanStats {
115            num_vectors: self.vectors.len(),
116            dimension: self.config.dimension,
117            bytes_per_vector: self.config.dimension * 4, // f32 = 4 bytes
118        }
119    }
120
121    /// Clear the index
122    pub fn clear(&mut self) {
123        self.vectors.clear();
124    }
125
126    /// Get number of vectors
127    pub fn len(&self) -> usize {
128        self.vectors.len()
129    }
130
131    /// Check if index is empty
132    pub fn is_empty(&self) -> bool {
133        self.vectors.is_empty()
134    }
135}
136
/// Statistics about a LinearScan index.
#[derive(Debug, Clone)]
pub struct LinearScanStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Vector dimension.
    pub dimension: usize,
    /// Storage cost per vector (dimension × 4 bytes for f32).
    pub bytes_per_vector: usize,
}

145impl TurboQuantIndex {
146    /// Create a new TurboQuant index
147    pub fn new(config: TurboQuantConfig) -> Result<Self> {
148        if config.bit_width < 1 || config.bit_width > 8 {
149            return Err(Error::InvalidInput(
150                "bit_width must be between 1 and 8".to_string(),
151            ));
152        }
153
154        let mut rng = StdRng::seed_from_u64(config.seed);
155
156        // Generate random rotation matrix using QR decomposition approximation
157        let rotation_matrix = Self::generate_rotation_matrix(config.dimension, &mut rng);
158
159        // Compute optimal scalar quantizer for concentrated Beta distribution
160        let codebook = Self::compute_codebook(config.bit_width);
161
162        Ok(Self {
163            config,
164            rotation_matrix,
165            codebook,
166            quantized_vectors: HashMap::new(),
167            vector_norms: HashMap::new(),
168        })
169    }
170
171    /// Generate random orthogonal rotation matrix via QR decomposition.
172    ///
173    /// Fills a d×d matrix with standard-normal entries, then performs QR
174    /// decomposition and returns the orthogonal factor Q.  This matches the
175    /// paper's requirement of a proper random orthogonal matrix.
176    fn generate_rotation_matrix(d: usize, rng: &mut StdRng) -> Vec<Vec<f32>> {
177        // Sample entries from N(0,1) as f64 for nalgebra
178        let data: Vec<f64> = (0..d * d).map(|_| StandardNormal.sample(rng)).collect();
179        let matrix = DMatrix::from_vec(d, d, data);
180
181        // QR decomposition; Q is d×d orthogonal
182        let qr = matrix.qr();
183        let q = qr.q();
184
185        // Convert to Vec<Vec<f32>>
186        (0..d)
187            .map(|i| (0..d).map(|j| q[(i, j)] as f32).collect())
188            .collect()
189    }
190
191    /// Compute optimal scalar quantizer codebook using the Max-Lloyd algorithm.
192    ///
193    /// After random rotation each coordinate follows an approximately
194    /// N(0, 1/d) distribution.  We sample from that distribution and run
195    /// Lloyd's 1-D k-means to find the centroids that minimise MSE.
196    fn compute_codebook(bit_width: usize) -> Vec<f32> {
197        let k = 1usize << bit_width; // 2^b centroids
198                                     // Use a fixed-seed RNG so the codebook is deterministic
199        let mut rng = StdRng::seed_from_u64(0xc0de_b007);
200        let num_samples = 50_000usize;
201        let std_dev = (1.0_f32 / 384_f32).sqrt(); // approximate for default dim
202
203        // 1. Draw samples approximating the post-rotation distribution
204        let samples: Vec<f32> = (0..num_samples)
205            .map(|_| {
206                let n: f64 = StandardNormal.sample(&mut rng);
207                (n as f32 * std_dev).clamp(-1.0, 1.0)
208            })
209            .collect();
210
211        // 2. Initialise centroids uniformly across [-1, 1]
212        let mut centroids: Vec<f32> = (0..k)
213            .map(|i| {
214                if k == 1 {
215                    0.0
216                } else {
217                    -1.0 + 2.0 * i as f32 / (k - 1) as f32
218                }
219            })
220            .collect();
221
222        // 3. Lloyd iterations (1-D k-means)
223        for _ in 0..100 {
224            let mut sums = vec![0.0f64; k];
225            let mut counts = vec![0usize; k];
226
227            for &x in &samples {
228                let nearest = centroids
229                    .iter()
230                    .enumerate()
231                    .min_by(|(_, a), (_, b)| {
232                        (x - *a)
233                            .abs()
234                            .partial_cmp(&(x - *b).abs())
235                            .unwrap_or(std::cmp::Ordering::Equal)
236                    })
237                    .map(|(i, _)| i)
238                    .unwrap_or(0);
239                sums[nearest] += x as f64;
240                counts[nearest] += 1;
241            }
242
243            let prev = centroids.clone();
244            for i in 0..k {
245                if counts[i] > 0 {
246                    centroids[i] = (sums[i] / counts[i] as f64) as f32;
247                }
248            }
249
250            // Check convergence
251            let converged = centroids
252                .iter()
253                .zip(prev.iter())
254                .all(|(a, b)| (a - b).abs() < 1e-6);
255            if converged {
256                break;
257            }
258        }
259
260        centroids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
261        centroids
262    }
263
264    /// Add a vector to the index
265    pub fn add_vector(&mut self, entity_id: i64, vector: &[f32]) -> Result<()> {
266        if vector.len() != self.config.dimension {
267            return Err(Error::InvalidVectorDimension {
268                expected: self.config.dimension,
269                actual: vector.len(),
270            });
271        }
272
273        // Compute norm for similarity normalization
274        let norm: f32 = vector.iter().map(|x| x * x).sum();
275        let norm = norm.sqrt();
276        self.vector_norms.insert(entity_id, norm);
277
278        // Apply random rotation
279        let rotated = self.apply_rotation(vector);
280
281        // Quantize each coordinate
282        let quantized = self.quantize_vector(&rotated);
283
284        self.quantized_vectors.insert(entity_id, quantized);
285
286        Ok(())
287    }
288
289    /// Apply random rotation to vector
290    fn apply_rotation(&self, vector: &[f32]) -> Vec<f32> {
291        let d = self.config.dimension;
292        let mut rotated = vec![0.0f32; d];
293
294        for (i, rot_row) in self.rotation_matrix.iter().enumerate().take(d) {
295            for (j, &val) in vector.iter().enumerate().take(d) {
296                rotated[i] += rot_row[j] * val;
297            }
298        }
299
300        rotated
301    }
302
303    /// Quantize a rotated vector
304    fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
305        vector
306            .iter()
307            .map(|&val| {
308                // Find nearest codebook entry
309                let mut best_idx = 0;
310                let mut best_dist = f32::MAX;
311
312                for (idx, &centroid) in self.codebook.iter().enumerate() {
313                    let dist = (val - centroid).abs();
314                    if dist < best_dist {
315                        best_dist = dist;
316                        best_idx = idx;
317                    }
318                }
319
320                best_idx as u8
321            })
322            .collect()
323    }
324
325    /// Dequantize a vector (for reconstruction)
326    #[allow(dead_code)]
327    fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
328        quantized
329            .iter()
330            .map(|&idx| self.codebook[idx as usize])
331            .collect()
332    }
333
334    /// Search for k nearest neighbors
335    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
336        if query.len() != self.config.dimension {
337            return Err(Error::InvalidVectorDimension {
338                expected: self.config.dimension,
339                actual: query.len(),
340            });
341        }
342
343        // Rotate and quantize query
344        let rotated_query = self.apply_rotation(query);
345        let quantized_query = self.quantize_vector(&rotated_query);
346
347        // Compute query norm
348        let query_norm: f32 = query.iter().map(|x| x * x).sum();
349        let query_norm = query_norm.sqrt();
350
351        // Compute approximate similarities with all indexed vectors
352        let mut results: Vec<(i64, f32)> = self
353            .quantized_vectors
354            .iter()
355            .map(|(&entity_id, quantized_vec)| {
356                let similarity = self.compute_similarity(
357                    &quantized_query,
358                    quantized_vec,
359                    query_norm,
360                    self.vector_norms.get(&entity_id).copied().unwrap_or(1.0),
361                );
362                (entity_id, similarity)
363            })
364            .collect();
365
366        // Sort by similarity (descending) and take top k
367        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
368        results.truncate(k);
369
370        Ok(results)
371    }
372
373    /// Compute approximate cosine similarity between quantized vectors.
374    ///
375    /// All arithmetic is done in the quantised reconstruction space so that
376    /// numerator and denominator are dimensionally consistent (both are sums
377    /// of squared codebook values).  The original-space norms are no longer
378    /// used here; they are kept in `vector_norms` only for potential future
379    /// re-ranking passes.
380    fn compute_similarity(
381        &self,
382        query: &[u8],
383        target: &[u8],
384        _query_norm: f32,
385        _target_norm: f32,
386    ) -> f32 {
387        if query.len() != target.len() {
388            return 0.0;
389        }
390
391        let mut dot_product = 0.0f32;
392        let mut query_sq = 0.0f32;
393        let mut target_sq = 0.0f32;
394
395        for i in 0..query.len() {
396            let q_val = self.codebook[query[i] as usize];
397            let t_val = self.codebook[target[i] as usize];
398            dot_product += q_val * t_val;
399            query_sq += q_val * q_val;
400            target_sq += t_val * t_val;
401        }
402
403        let denom = query_sq.sqrt() * target_sq.sqrt();
404        if denom > 0.0 {
405            dot_product / denom
406        } else {
407            0.0
408        }
409    }
410
411    /// Batch add vectors to the index
412    pub fn add_vectors_batch(&mut self, vectors: &[(i64, Vec<f32>)]) -> Result<()> {
413        for (entity_id, vector) in vectors {
414            self.add_vector(*entity_id, vector)?;
415        }
416        Ok(())
417    }
418
419    /// Get index statistics
420    pub fn stats(&self) -> TurboQuantStats {
421        TurboQuantStats {
422            num_vectors: self.quantized_vectors.len(),
423            dimension: self.config.dimension,
424            bit_width: self.config.bit_width,
425            bytes_per_vector: self.config.dimension, // 1 byte per coordinate
426            compression_ratio: 32.0 / self.config.bit_width as f32, // vs float32
427        }
428    }
429
430    /// Remove a vector from the index
431    pub fn remove_vector(&mut self, entity_id: i64) -> Result<()> {
432        self.quantized_vectors.remove(&entity_id);
433        self.vector_norms.remove(&entity_id);
434        Ok(())
435    }
436
437    /// Clear the index
438    pub fn clear(&mut self) {
439        self.quantized_vectors.clear();
440        self.vector_norms.clear();
441    }
442
443    /// Get number of vectors
444    pub fn len(&self) -> usize {
445        self.quantized_vectors.len()
446    }
447
448    /// Check if index is empty
449    pub fn is_empty(&self) -> bool {
450        self.quantized_vectors.is_empty()
451    }
452
453    /// Save index to file
454    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
455        let serialized = serde_json::to_string(self)?;
456        std::fs::write(path, serialized)?;
457        Ok(())
458    }
459
460    /// Load index from file
461    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
462        let contents = std::fs::read_to_string(path)?;
463        let index: Self = serde_json::from_str(&contents)?;
464        Ok(index)
465    }
466
467    /// Get the config
468    pub fn config(&self) -> &TurboQuantConfig {
469        &self.config
470    }
471
472    /// Batch search for multiple queries
473    pub fn search_batch(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<(i64, f32)>>> {
474        queries.iter().map(|query| self.search(query, k)).collect()
475    }
476}
477
/// Statistics about a TurboQuant index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantStats {
    /// Number of stored vectors.
    pub num_vectors: usize,
    /// Vector dimension.
    pub dimension: usize,
    /// Bits per coordinate used by the quantizer codebook.
    pub bit_width: usize,
    /// Actual in-memory cost per vector (currently 1 byte per coordinate).
    pub bytes_per_vector: usize,
    /// Bit-level ratio vs f32 (32 / bit_width).
    /// NOTE(review): this assumes bit-packed storage, while `bytes_per_vector`
    /// shows codes are stored one per byte (actual in-memory ratio is 4x) —
    /// confirm which figure callers rely on.
    pub compression_ratio: f32,
}

#[cfg(test)]
mod tests {
    use super::*;

    // Construction stores the validated config verbatim.
    #[test]
    fn test_create_index() {
        let config = TurboQuantConfig {
            dimension: 128,
            bit_width: 3,
            seed: 42,
        };

        let index = TurboQuantIndex::new(config).unwrap();
        assert_eq!(index.config.dimension, 128);
        assert_eq!(index.config.bit_width, 3);
    }

    // After quantization a vector should still be its own nearest neighbor.
    #[test]
    fn test_add_and_search() {
        let config = TurboQuantConfig {
            dimension: 128,
            bit_width: 3,
            seed: 42,
        };

        let mut index = TurboQuantIndex::new(config).unwrap();

        // Add some test vectors: a ramp, a rotated ramp, and a reversed ramp.
        let vec1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let vec2: Vec<f32> = (0..128).map(|i| ((i + 64) % 128) as f32 / 128.0).collect();
        let vec3: Vec<f32> = (0..128).map(|i| 1.0 - (i as f32) / 128.0).collect();

        index.add_vector(1, &vec1).unwrap();
        index.add_vector(2, &vec2).unwrap();
        index.add_vector(3, &vec3).unwrap();

        // Search with vec1
        let results = index.search(&vec1, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 1); // vec1 should be closest to itself
    }

    // compression_ratio is reported as 32 / bit_width (32/3 ≈ 10.67 here).
    #[test]
    fn test_compression_ratio() {
        let config = TurboQuantConfig {
            dimension: 384,
            bit_width: 3,
            seed: 42,
        };

        let index = TurboQuantIndex::new(config).unwrap();
        let stats = index.stats();

        // 3 bits vs 32 bits = ~10x compression
        assert!(stats.compression_ratio > 10.0);
    }

    // stats() counts vectors and echoes the configured dimension.
    #[test]
    fn test_stats() {
        let config = TurboQuantConfig {
            dimension: 384,
            bit_width: 3,
            seed: 42,
        };

        let mut index = TurboQuantIndex::new(config).unwrap();

        let vec: Vec<f32> = vec![0.1; 384];
        index.add_vector(1, &vec).unwrap();
        index.add_vector(2, &vec).unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 2);
        assert_eq!(stats.dimension, 384);
    }
}