ruvector_core/
quantization.rs

1//! Quantization techniques for memory compression
2
3use crate::error::Result;
4use serde::{Deserialize, Serialize};
5
/// Common interface shared by every compressed vector representation.
///
/// Implementors must be `Send + Sync`.
pub trait QuantizedVector: Send + Sync {
    /// Builds the compressed representation from a full-precision slice.
    fn quantize(vector: &[f32]) -> Self;

    /// Returns the distance between `self` and `other`, computed on the
    /// quantized data. The metric itself is chosen by the implementor.
    fn distance(&self, other: &Self) -> f32;

    /// Expands the compressed data back into an approximate `f32` vector.
    fn reconstruct(&self) -> Vec<f32>;
}
17
18/// Scalar quantization to int8 (4x compression)
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ScalarQuantized {
21    /// Quantized values (int8)
22    pub data: Vec<u8>,
23    /// Minimum value for dequantization
24    pub min: f32,
25    /// Scale factor for dequantization
26    pub scale: f32,
27}
28
29impl QuantizedVector for ScalarQuantized {
30    fn quantize(vector: &[f32]) -> Self {
31        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
32        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
33
34        // Handle edge case where all values are the same (scale = 0)
35        let scale = if (max - min).abs() < f32::EPSILON {
36            1.0 // Arbitrary non-zero scale when all values are identical
37        } else {
38            (max - min) / 255.0
39        };
40
41        let data = vector
42            .iter()
43            .map(|&v| ((v - min) / scale).round().clamp(0.0, 255.0) as u8)
44            .collect();
45
46        Self { data, min, scale }
47    }
48
49    fn distance(&self, other: &Self) -> f32 {
50        // Fast int8 distance calculation
51        // Use i32 to avoid overflow: max diff is 255, and 255*255=65025 fits in i32
52
53        // Scale handling: We use the average of both scales for balanced comparison.
54        // Using max(scale) would bias toward the vector with larger range,
55        // while average provides a more symmetric distance metric.
56        // This ensures distance(a, b) ≈ distance(b, a) in the reconstructed space.
57        let avg_scale = (self.scale + other.scale) / 2.0;
58
59        self.data
60            .iter()
61            .zip(&other.data)
62            .map(|(&a, &b)| {
63                let diff = a as i32 - b as i32;
64                (diff * diff) as f32
65            })
66            .sum::<f32>()
67            .sqrt()
68            * avg_scale
69    }
70
71    fn reconstruct(&self) -> Vec<f32> {
72        self.data
73            .iter()
74            .map(|&v| self.min + (v as f32) * self.scale)
75            .collect()
76    }
77}
78
79/// Product quantization (8-16x compression)
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct ProductQuantized {
82    /// Quantized codes (one per subspace)
83    pub codes: Vec<u8>,
84    /// Codebooks for each subspace
85    pub codebooks: Vec<Vec<Vec<f32>>>,
86}
87
88impl ProductQuantized {
89    /// Train product quantization on a set of vectors
90    pub fn train(
91        vectors: &[Vec<f32>],
92        num_subspaces: usize,
93        codebook_size: usize,
94        iterations: usize,
95    ) -> Result<Self> {
96        if vectors.is_empty() {
97            return Err(crate::error::RuvectorError::InvalidInput(
98                "Cannot train on empty vector set".into(),
99            ));
100        }
101        if vectors[0].is_empty() {
102            return Err(crate::error::RuvectorError::InvalidInput(
103                "Cannot train on vectors with zero dimensions".into(),
104            ));
105        }
106        if codebook_size > 256 {
107            return Err(crate::error::RuvectorError::InvalidParameter(format!(
108                "Codebook size {} exceeds u8 maximum of 256",
109                codebook_size
110            )));
111        }
112        let dimensions = vectors[0].len();
113        let subspace_dim = dimensions / num_subspaces;
114
115        let mut codebooks = Vec::with_capacity(num_subspaces);
116
117        // Train codebook for each subspace using k-means
118        for subspace_idx in 0..num_subspaces {
119            let start = subspace_idx * subspace_dim;
120            let end = start + subspace_dim;
121
122            // Extract subspace vectors
123            let subspace_vectors: Vec<Vec<f32>> =
124                vectors.iter().map(|v| v[start..end].to_vec()).collect();
125
126            // Run k-means
127            let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
128            codebooks.push(codebook);
129        }
130
131        Ok(Self {
132            codes: vec![],
133            codebooks,
134        })
135    }
136
137    /// Quantize a vector using trained codebooks
138    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
139        let num_subspaces = self.codebooks.len();
140        let subspace_dim = vector.len() / num_subspaces;
141
142        let mut codes = Vec::with_capacity(num_subspaces);
143
144        for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
145            let start = subspace_idx * subspace_dim;
146            let end = start + subspace_dim;
147            let subvector = &vector[start..end];
148
149            // Find nearest centroid
150            let code = codebook
151                .iter()
152                .enumerate()
153                .min_by(|(_, a), (_, b)| {
154                    let dist_a = euclidean_squared(subvector, a);
155                    let dist_b = euclidean_squared(subvector, b);
156                    dist_a.partial_cmp(&dist_b).unwrap()
157                })
158                .map(|(idx, _)| idx as u8)
159                .unwrap_or(0);
160
161            codes.push(code);
162        }
163
164        codes
165    }
166}
167
168/// Binary quantization (32x compression)
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct BinaryQuantized {
171    /// Binary representation (1 bit per dimension, packed into bytes)
172    pub bits: Vec<u8>,
173    /// Number of dimensions
174    pub dimensions: usize,
175}
176
177impl QuantizedVector for BinaryQuantized {
178    fn quantize(vector: &[f32]) -> Self {
179        let dimensions = vector.len();
180        let num_bytes = (dimensions + 7) / 8;
181        let mut bits = vec![0u8; num_bytes];
182
183        for (i, &v) in vector.iter().enumerate() {
184            if v > 0.0 {
185                let byte_idx = i / 8;
186                let bit_idx = i % 8;
187                bits[byte_idx] |= 1 << bit_idx;
188            }
189        }
190
191        Self { bits, dimensions }
192    }
193
194    fn distance(&self, other: &Self) -> f32 {
195        // Hamming distance
196        let mut distance = 0u32;
197
198        for (&a, &b) in self.bits.iter().zip(&other.bits) {
199            distance += (a ^ b).count_ones();
200        }
201
202        distance as f32
203    }
204
205    fn reconstruct(&self) -> Vec<f32> {
206        let mut result = Vec::with_capacity(self.dimensions);
207
208        for i in 0..self.dimensions {
209            let byte_idx = i / 8;
210            let bit_idx = i % 8;
211            let bit = (self.bits[byte_idx] >> bit_idx) & 1;
212            result.push(if bit == 1 { 1.0 } else { -1.0 });
213        }
214
215        result
216    }
217}
218
219// Helper functions
220
/// Sum of squared component-wise differences between `a` and `b`
/// (squared Euclidean distance, no square root taken).
///
/// When the slices differ in length only the common prefix is compared.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(lhs, rhs)| lhs - rhs)
        .map(|delta| delta * delta)
        .sum()
}
230
231fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
232    use rand::seq::SliceRandom;
233    use rand::thread_rng;
234
235    let mut rng = thread_rng();
236
237    // Initialize centroids randomly
238    let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
239
240    for _ in 0..iterations {
241        // Assign vectors to nearest centroid
242        let mut assignments = vec![Vec::new(); k];
243
244        for vector in vectors {
245            let nearest = centroids
246                .iter()
247                .enumerate()
248                .min_by(|(_, a), (_, b)| {
249                    let dist_a = euclidean_squared(vector, a);
250                    let dist_b = euclidean_squared(vector, b);
251                    dist_a.partial_cmp(&dist_b).unwrap()
252                })
253                .map(|(idx, _)| idx)
254                .unwrap_or(0);
255
256            assignments[nearest].push(vector.clone());
257        }
258
259        // Update centroids
260        for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
261            if !assigned.is_empty() {
262                let dim = centroid.len();
263                *centroid = vec![0.0; dim];
264
265                for vector in assigned {
266                    for (i, &v) in vector.iter().enumerate() {
267                        centroid[i] += v;
268                    }
269                }
270
271                let count = assigned.len() as f32;
272                for v in centroid.iter_mut() {
273                    *v /= count;
274                }
275            }
276        }
277    }
278
279    centroids
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest element of `values`.
    fn smallest(values: &[f32]) -> f32 {
        values.iter().copied().fold(f32::INFINITY, f32::min)
    }

    /// Largest element of `values`.
    fn largest(values: &[f32]) -> f32 {
        values.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    }

    /// Scalar distances between `a` and `b`, measured in both directions.
    fn scalar_distances(a: &[f32], b: &[f32]) -> (f32, f32) {
        let qa = ScalarQuantized::quantize(a);
        let qb = ScalarQuantized::quantize(b);
        (qa.distance(&qb), qb.distance(&qa))
    }

    #[test]
    fn test_scalar_quantization() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let restored = ScalarQuantized::quantize(&original).reconstruct();

        // Reconstruction only has to be approximately right.
        for (expected, actual) in original.iter().zip(&restored) {
            assert!((expected - actual).abs() < 0.1);
        }
    }

    #[test]
    fn test_binary_quantization() {
        let quantized = BinaryQuantized::quantize(&[1.0, -1.0, 2.0, -2.0, 0.5]);

        assert_eq!(quantized.dimensions, 5);
        assert_eq!(quantized.bits.len(), 1); // 5 bits fit in 1 byte
    }

    #[test]
    fn test_binary_distance() {
        let all_positive = BinaryQuantized::quantize(&[1.0, 1.0, 1.0, 1.0]);
        let half_negative = BinaryQuantized::quantize(&[1.0, 1.0, -1.0, -1.0]);

        // Exactly two sign bits differ.
        assert_eq!(all_positive.distance(&half_negative), 2.0);
    }

    #[test]
    fn test_scalar_quantization_roundtrip() {
        // quantize -> reconstruct must land close to the original values.
        let cases: Vec<Vec<f32>> = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![-10.0, -5.0, 0.0, 5.0, 10.0],
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
            vec![100.0, 200.0, 300.0, 400.0, 500.0],
        ];

        for original in cases {
            let restored = ScalarQuantized::quantize(&original).reconstruct();
            assert_eq!(original.len(), restored.len());

            // One 8-bit step is (max-min)/255; allow twice that for rounding.
            let max_error = (largest(&original) - smallest(&original)) / 255.0 * 2.0;

            for (orig, recon) in original.iter().zip(restored.iter()) {
                assert!(
                    (orig - recon).abs() < max_error,
                    "Roundtrip error too large: orig={}, recon={}, error={}",
                    orig,
                    recon,
                    (orig - recon).abs()
                );
            }
        }
    }

    #[test]
    fn test_scalar_distance_symmetry() {
        // distance(a, b) must equal distance(b, a) up to float noise.
        let (dist_ab, dist_ba) =
            scalar_distances(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 3.0, 4.0, 5.0, 6.0]);

        assert!(
            (dist_ab - dist_ba).abs() < 0.01,
            "Distance is not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab,
            dist_ba
        );
    }

    #[test]
    fn test_scalar_distance_different_scales() {
        // The inputs span ranges of 4.0 and 40.0, so their scales differ.
        let (dist_ab, dist_ba) = scalar_distances(
            &[1.0, 2.0, 3.0, 4.0, 5.0],
            &[10.0, 20.0, 30.0, 40.0, 50.0],
        );

        // Averaging the two scales keeps the metric symmetric.
        assert!(
            (dist_ab - dist_ba).abs() < 0.01,
            "Distance with different scales not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab,
            dist_ba
        );
    }

    #[test]
    fn test_scalar_quantization_edge_cases() {
        // Constant input: the range is zero.
        let constant = vec![5.0, 5.0, 5.0, 5.0];
        let restored = ScalarQuantized::quantize(&constant).reconstruct();
        for (orig, recon) in constant.iter().zip(restored.iter()) {
            assert!((orig - recon).abs() < 0.01);
        }

        // Very wide range: only the output length is checked.
        let wide = vec![f32::MIN / 1e10, 0.0, f32::MAX / 1e10];
        let restored = ScalarQuantized::quantize(&wide).reconstruct();
        assert_eq!(wide.len(), restored.len());
    }

    #[test]
    fn test_binary_distance_symmetry() {
        // Hamming distance must not depend on operand order.
        let first = BinaryQuantized::quantize(&[1.0, -1.0, 1.0, -1.0]);
        let second = BinaryQuantized::quantize(&[1.0, 1.0, -1.0, -1.0]);

        let dist_ab = first.distance(&second);
        let dist_ba = second.distance(&first);

        assert_eq!(
            dist_ab, dist_ba,
            "Binary distance not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab, dist_ba
        );
    }
}