// shodh_memory/vector_db/pq.rs
//! Product Quantization (PQ) for vector compression
//!
//! Compresses high-dimensional vectors by splitting them into subvectors
//! and quantizing each subvector to its nearest centroid.
//!
//! For 384-dim MiniLM: 1536 bytes → 48 bytes (32x compression)
//! For 768-dim CLIP: 3072 bytes → 96 bytes (32x compression)
//!
//! Trade-off: ~95% recall accuracy for 32x storage reduction
use anyhow::{anyhow, Result};
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Number of centroids per subspace (2^8 = 256, so each code fits in a u8)
pub const NUM_CENTROIDS: usize = 256;

/// Default subvector dimension (8 floats per subvector)
pub const DEFAULT_SUBVEC_DIM: usize = 8;
/// Product Quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQConfig {
    /// Total vector dimension (e.g., 384 for MiniLM, 768 for CLIP)
    pub dimension: usize,
    /// Number of subvectors (dimension / subvec_dim); also the encoded size in bytes
    pub num_subvectors: usize,
    /// Dimension of each subvector (dimension must divide evenly by this)
    pub subvec_dim: usize,
    /// Number of centroids per subspace (default 256, so codes fit in a u8)
    pub num_centroids: usize,
    /// Number of k-means iterations used when training the quantizer
    pub kmeans_iterations: usize,
}
37impl PQConfig {
38    /// Create PQ config for a given vector dimension
39    pub fn for_dimension(dimension: usize) -> Self {
40        let subvec_dim = DEFAULT_SUBVEC_DIM;
41        let num_subvectors = dimension / subvec_dim;
42
43        assert!(
44            dimension % subvec_dim == 0,
45            "Dimension {} must be divisible by subvec_dim {}",
46            dimension,
47            subvec_dim
48        );
49
50        Self {
51            dimension,
52            num_subvectors,
53            subvec_dim,
54            num_centroids: NUM_CENTROIDS,
55            kmeans_iterations: 20,
56        }
57    }
58
59    /// Create config for MiniLM embeddings (384 dims)
60    pub fn minilm() -> Self {
61        Self::for_dimension(384)
62    }
63
64    /// Create config for CLIP embeddings (768 dims)
65    pub fn clip() -> Self {
66        Self::for_dimension(768)
67    }
68}
69
/// Trained Product Quantizer
///
/// Contains centroids for each subspace, learned from training data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Configuration the quantizer was built with
    pub config: PQConfig,
    /// Centroids for each subspace, indexed as
    /// [num_subvectors][num_centroids][subvec_dim]
    pub centroids: Vec<Vec<Vec<f32>>>,
    /// Whether the quantizer has been trained (encode/decode require this)
    pub trained: bool,
}
83impl ProductQuantizer {
84    /// Create a new untrained product quantizer
85    pub fn new(config: PQConfig) -> Self {
86        Self {
87            config,
88            centroids: Vec::new(),
89            trained: false,
90        }
91    }
92
93    /// Create and train a product quantizer on given vectors
94    pub fn train(config: PQConfig, training_vectors: &[Vec<f32>]) -> Result<Self> {
95        if training_vectors.is_empty() {
96            return Err(anyhow!("No training vectors provided"));
97        }
98
99        let first_dim = training_vectors[0].len();
100        if first_dim != config.dimension {
101            return Err(anyhow!(
102                "Training vector dimension {} doesn't match config {}",
103                first_dim,
104                config.dimension
105            ));
106        }
107
108        let mut pq = Self::new(config);
109        pq.fit(training_vectors)?;
110        Ok(pq)
111    }
112
113    /// Train centroids on a set of vectors using k-means
114    fn fit(&mut self, vectors: &[Vec<f32>]) -> Result<()> {
115        let n_vectors = vectors.len();
116        let n_subvectors = self.config.num_subvectors;
117        let subvec_dim = self.config.subvec_dim;
118        let n_centroids = self.config.num_centroids.min(n_vectors);
119        let iterations = self.config.kmeans_iterations;
120
121        tracing::info!(
122            "Training PQ: {} vectors, {} subvectors, {} centroids, {} iterations",
123            n_vectors,
124            n_subvectors,
125            n_centroids,
126            iterations
127        );
128
129        // Initialize centroids storage
130        self.centroids = Vec::with_capacity(n_subvectors);
131
132        // Train k-means for each subspace independently
133        for subvec_idx in 0..n_subvectors {
134            let start = subvec_idx * subvec_dim;
135            let end = start + subvec_dim;
136
137            // Extract subvectors for this subspace
138            let subvectors: Vec<Vec<f32>> =
139                vectors.iter().map(|v| v[start..end].to_vec()).collect();
140
141            // Run k-means clustering
142            let centroids = self.kmeans(&subvectors, n_centroids, iterations)?;
143            self.centroids.push(centroids);
144        }
145
146        self.trained = true;
147        tracing::info!("PQ training complete");
148        Ok(())
149    }
150
151    /// Simple k-means clustering
152    fn kmeans(&self, vectors: &[Vec<f32>], k: usize, iterations: usize) -> Result<Vec<Vec<f32>>> {
153        let dim = vectors[0].len();
154        let n = vectors.len();
155
156        // Initialize centroids by random sampling
157        let mut rng = rand::thread_rng();
158        let mut indices: Vec<usize> = (0..n).collect();
159        indices.shuffle(&mut rng);
160
161        let mut centroids: Vec<Vec<f32>> = indices
162            .iter()
163            .take(k)
164            .map(|&i| vectors[i].clone())
165            .collect();
166
167        // Pad with random vectors if needed (when n < k)
168        while centroids.len() < k {
169            let idx = indices[centroids.len() % n];
170            centroids.push(vectors[idx].clone());
171        }
172
173        let mut assignments = vec![0usize; n];
174
175        // K-means iterations
176        for _ in 0..iterations {
177            // Assign each vector to nearest centroid
178            for (i, vec) in vectors.iter().enumerate() {
179                let mut best_centroid = 0;
180                let mut best_dist = f32::MAX;
181
182                for (c, centroid) in centroids.iter().enumerate() {
183                    let dist = squared_l2_distance(vec, centroid);
184                    if dist < best_dist {
185                        best_dist = dist;
186                        best_centroid = c;
187                    }
188                }
189                assignments[i] = best_centroid;
190            }
191
192            // Update centroids
193            let mut new_centroids: Vec<Vec<f32>> = vec![vec![0.0; dim]; k];
194            let mut counts = vec![0usize; k];
195
196            for (i, vec) in vectors.iter().enumerate() {
197                let c = assignments[i];
198                counts[c] += 1;
199                for (j, &v) in vec.iter().enumerate() {
200                    new_centroids[c][j] += v;
201                }
202            }
203
204            // Average and handle empty clusters
205            for c in 0..k {
206                if counts[c] > 0 {
207                    for j in 0..dim {
208                        new_centroids[c][j] /= counts[c] as f32;
209                    }
210                    centroids[c] = new_centroids[c].clone();
211                }
212                // Keep old centroid if cluster is empty
213            }
214        }
215
216        Ok(centroids)
217    }
218
219    /// Encode a vector to PQ codes (one u8 per subvector)
220    pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
221        if !self.trained {
222            return Err(anyhow!("ProductQuantizer not trained"));
223        }
224
225        if vector.len() != self.config.dimension {
226            return Err(anyhow!(
227                "Vector dimension {} doesn't match config {}",
228                vector.len(),
229                self.config.dimension
230            ));
231        }
232
233        let mut codes = Vec::with_capacity(self.config.num_subvectors);
234        let subvec_dim = self.config.subvec_dim;
235
236        for (subvec_idx, subspace_centroids) in self.centroids.iter().enumerate() {
237            let start = subvec_idx * subvec_dim;
238            let end = start + subvec_dim;
239            let subvector = &vector[start..end];
240
241            // Find nearest centroid
242            let mut best_centroid = 0u8;
243            let mut best_dist = f32::MAX;
244
245            for (c, centroid) in subspace_centroids.iter().enumerate() {
246                let dist = squared_l2_distance_slice(subvector, centroid);
247                if dist < best_dist {
248                    best_dist = dist;
249                    best_centroid = c as u8;
250                }
251            }
252
253            codes.push(best_centroid);
254        }
255
256        Ok(codes)
257    }
258
259    /// Decode PQ codes back to approximate vector
260    pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
261        if !self.trained {
262            return Err(anyhow!("ProductQuantizer not trained"));
263        }
264
265        if codes.len() != self.config.num_subvectors {
266            return Err(anyhow!(
267                "Code length {} doesn't match num_subvectors {}",
268                codes.len(),
269                self.config.num_subvectors
270            ));
271        }
272
273        let mut vector = Vec::with_capacity(self.config.dimension);
274
275        for (subvec_idx, &code) in codes.iter().enumerate() {
276            let subspace = &self.centroids[subvec_idx];
277            let code_idx = code as usize;
278            if code_idx >= subspace.len() {
279                return Err(anyhow!(
280                    "PQ code {} out of bounds for subspace {} with {} centroids (data corruption?)",
281                    code_idx,
282                    subvec_idx,
283                    subspace.len()
284                ));
285            }
286            vector.extend_from_slice(&subspace[code_idx]);
287        }
288
289        Ok(vector)
290    }
291
292    /// Compute asymmetric distance between query vector and encoded vector
293    ///
294    /// ADC (Asymmetric Distance Computation) is more accurate than SDC
295    /// because the query is not quantized.
296    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> Result<f32> {
297        if !self.trained {
298            return Err(anyhow!("ProductQuantizer not trained"));
299        }
300
301        let subvec_dim = self.config.subvec_dim;
302        let mut total_dist = 0.0f32;
303
304        for (subvec_idx, &code) in codes.iter().enumerate() {
305            let start = subvec_idx * subvec_dim;
306            let end = start + subvec_dim;
307            let query_subvec = &query[start..end];
308            let subspace = &self.centroids[subvec_idx];
309            let code_idx = code as usize;
310            if code_idx >= subspace.len() {
311                return Err(anyhow!(
312                    "PQ code {} out of bounds for subspace {} with {} centroids (data corruption?)",
313                    code_idx,
314                    subvec_idx,
315                    subspace.len()
316                ));
317            }
318
319            total_dist += squared_l2_distance_slice(query_subvec, &subspace[code_idx]);
320        }
321
322        Ok(total_dist)
323    }
324
325    /// Build distance lookup table for a query (ADC optimization)
326    ///
327    /// Pre-computes distances from each query subvector to all centroids.
328    /// This allows O(M) distance computation per encoded vector instead of O(D).
329    pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>> {
330        if !self.trained {
331            return Err(anyhow!("ProductQuantizer not trained"));
332        }
333
334        let subvec_dim = self.config.subvec_dim;
335        let n_centroids = self.config.num_centroids;
336        let mut table = Vec::with_capacity(self.config.num_subvectors);
337
338        for (subvec_idx, subspace_centroids) in self.centroids.iter().enumerate() {
339            let start = subvec_idx * subvec_dim;
340            let end = start + subvec_dim;
341            let query_subvec = &query[start..end];
342
343            let mut distances = Vec::with_capacity(n_centroids);
344            for centroid in subspace_centroids {
345                distances.push(squared_l2_distance_slice(query_subvec, centroid));
346            }
347            table.push(distances);
348        }
349
350        Ok(table)
351    }
352
353    /// Fast distance computation using pre-built lookup table
354    ///
355    /// Returns `f32::MAX` if a PQ code is out of bounds (corrupted data)
356    /// rather than panicking, since this is called in the hot search path.
357    #[inline]
358    pub fn distance_with_table(&self, table: &[Vec<f32>], codes: &[u8]) -> f32 {
359        let mut total = 0.0f32;
360        for (subvec_idx, &code) in codes.iter().enumerate() {
361            let code_idx = code as usize;
362            if subvec_idx >= table.len() || code_idx >= table[subvec_idx].len() {
363                return f32::MAX; // Corrupted code — push to bottom of results
364            }
365            total += table[subvec_idx][code_idx];
366        }
367        total
368    }
369
370    /// Batch encode multiple vectors
371    pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Result<Vec<Vec<u8>>> {
372        vectors.iter().map(|v| self.encode(v)).collect()
373    }
374
375    /// Compressed size in bytes for one vector
376    pub fn compressed_size(&self) -> usize {
377        self.config.num_subvectors // One byte per subvector
378    }
379
380    /// Original size in bytes for one vector
381    pub fn original_size(&self) -> usize {
382        self.config.dimension * std::mem::size_of::<f32>()
383    }
384
385    /// Compression ratio
386    pub fn compression_ratio(&self) -> f32 {
387        self.original_size() as f32 / self.compressed_size() as f32
388    }
389}
390
/// Squared Euclidean (L2) distance between two equal-length vectors.
#[inline]
fn squared_l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        total += diff * diff;
    }
    total
}
/// Squared L2 distance between two subvector slices (kept as a separate
/// name for call-site clarity in the PQ code paths).
#[inline]
fn squared_l2_distance_slice(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}
/// Compressed vector storage for PQ-encoded vectors
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedVectorStore {
    /// The trained quantizer used to encode/decode all stored vectors
    pub quantizer: ProductQuantizer,
    /// Encoded vectors: vector_id -> PQ codes (one u8 per subvector)
    pub codes: HashMap<u32, Vec<u8>>,
}
412impl CompressedVectorStore {
413    /// Create a new compressed vector store
414    pub fn new(quantizer: ProductQuantizer) -> Self {
415        Self {
416            quantizer,
417            codes: HashMap::new(),
418        }
419    }
420
421    /// Train quantizer and create store from training vectors
422    pub fn train_and_create(config: PQConfig, training_vectors: &[Vec<f32>]) -> Result<Self> {
423        let quantizer = ProductQuantizer::train(config, training_vectors)?;
424        Ok(Self::new(quantizer))
425    }
426
427    /// Add a vector to the store
428    pub fn add(&mut self, vector_id: u32, vector: &[f32]) -> Result<()> {
429        let codes = self.quantizer.encode(vector)?;
430        self.codes.insert(vector_id, codes);
431        Ok(())
432    }
433
434    /// Get PQ codes for a vector
435    pub fn get_codes(&self, vector_id: u32) -> Option<&Vec<u8>> {
436        self.codes.get(&vector_id)
437    }
438
439    /// Decode a vector back to approximate floats
440    pub fn decode(&self, vector_id: u32) -> Result<Vec<f32>> {
441        let codes = self
442            .codes
443            .get(&vector_id)
444            .ok_or_else(|| anyhow!("Vector {} not found", vector_id))?;
445        self.quantizer.decode(codes)
446    }
447
448    /// Search for k nearest neighbors using PQ distance
449    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>> {
450        // Build distance table for fast lookup
451        let table = self.quantizer.build_distance_table(query)?;
452
453        // Compute distances to all vectors
454        let mut distances: Vec<(u32, f32)> = self
455            .codes
456            .iter()
457            .map(|(&id, codes)| (id, self.quantizer.distance_with_table(&table, codes)))
458            .collect();
459
460        // Sort by distance and take top k
461        distances.sort_by(|a, b| a.1.total_cmp(&b.1));
462        distances.truncate(k);
463
464        Ok(distances)
465    }
466
467    /// Number of vectors in store
468    pub fn len(&self) -> usize {
469        self.codes.len()
470    }
471
472    /// Check if store is empty
473    pub fn is_empty(&self) -> bool {
474        self.codes.is_empty()
475    }
476
477    /// Total compressed storage size in bytes
478    pub fn storage_bytes(&self) -> usize {
479        self.codes.len() * self.quantizer.compressed_size()
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` vectors of dimension `dim` with uniform random components.
    fn generate_random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let mut out = Vec::with_capacity(n);
        for _ in 0..n {
            let v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
            out.push(v);
        }
        out
    }

    #[test]
    fn test_pq_encode_decode() {
        let data = generate_random_vectors(1000, 384);
        let pq = ProductQuantizer::train(PQConfig::minilm(), &data).unwrap();

        let original = &data[0];
        let codes = pq.encode(original).unwrap();
        let decoded = pq.decode(&codes).unwrap();

        // 384 dims / 8 per subvector = 48 one-byte codes.
        assert_eq!(codes.len(), 48);
        assert_eq!(decoded.len(), 384);

        // Reconstruction is lossy; verify the mean squared error stays small.
        let mut sq_err = 0.0f32;
        for (a, b) in original.iter().zip(decoded.iter()) {
            sq_err += (a - b).powi(2);
        }
        let mse = sq_err / 384.0;
        assert!(mse < 0.1, "MSE too high: {}", mse);
    }

    #[test]
    fn test_compression_ratio() {
        let pq = ProductQuantizer::new(PQConfig::minilm());

        assert_eq!(pq.original_size(), 1536); // 384 floats * 4 bytes
        assert_eq!(pq.compressed_size(), 48); // one byte per subvector
        assert!((pq.compression_ratio() - 32.0).abs() < 0.01);
    }

    #[test]
    fn test_distance_table() {
        let data = generate_random_vectors(100, 384);
        let pq = ProductQuantizer::train(PQConfig::minilm(), &data).unwrap();

        let query = &data[0];
        let codes = pq.encode(&data[1]).unwrap();

        // The table-based path must agree with the direct ADC computation.
        let direct = pq.asymmetric_distance(query, &codes).unwrap();
        let table = pq.build_distance_table(query).unwrap();
        let via_table = pq.distance_with_table(&table, &codes);

        assert!((direct - via_table).abs() < 1e-6);
    }

    #[test]
    fn test_compressed_store_search() {
        let data = generate_random_vectors(1000, 384);
        let mut store =
            CompressedVectorStore::train_and_create(PQConfig::minilm(), &data).unwrap();

        for (i, v) in data.iter().enumerate() {
            store.add(i as u32, v).unwrap();
        }

        let results = store.search(&data[0], 10).unwrap();
        assert_eq!(results.len(), 10);

        // Quantization error means distance-to-self isn't exactly zero, but
        // the query vector should still rank among the top few results.
        let found = results.iter().take(5).any(|(id, _)| *id == 0);
        assert!(found, "Query vector not found in top 5 results");
    }
}