// ruvector_router_core/quantization.rs
1//! Quantization techniques for memory compression
2
3use crate::error::{Result, VectorDbError};
4use crate::types::QuantizationType;
5use serde::{Deserialize, Serialize};
6
/// Quantized vector representation
///
/// Each variant carries the metadata needed to dequantize on its own,
/// except `Product`, whose codebooks are expected to be stored elsewhere
/// (see the placeholder in `dequantize`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizedVector {
    /// No quantization - full precision float32
    None(Vec<f32>),
    /// Scalar quantization to int8
    Scalar {
        /// Quantized values, one byte per original dimension
        data: Vec<u8>,
        /// Minimum value for dequantization (the value mapped to step 0)
        min: f32,
        /// Scale factor for dequantization; quantization computed
        /// `q = (value - min) * scale`, so reconstruction is `min + q / scale`
        scale: f32,
    },
    /// Product quantization
    Product {
        /// Codebook indices, one code per subspace
        codes: Vec<u8>,
        /// Number of subspaces the original vector was split into
        subspaces: usize,
    },
    /// Binary quantization (1 bit per dimension)
    Binary {
        /// Packed binary data, bits stored LSB-first within each byte
        data: Vec<u8>,
        /// Threshold value that separated 1-bits from 0-bits
        threshold: f32,
        /// Number of original dimensions (needed because bits are byte-packed)
        dimensions: usize,
    },
}
38
39/// Quantize a vector using specified quantization type
40pub fn quantize(vector: &[f32], qtype: QuantizationType) -> Result<QuantizedVector> {
41    match qtype {
42        QuantizationType::None => Ok(QuantizedVector::None(vector.to_vec())),
43        QuantizationType::Scalar => Ok(scalar_quantize(vector)),
44        QuantizationType::Product { subspaces, k } => product_quantize(vector, subspaces, k),
45        QuantizationType::Binary => Ok(binary_quantize(vector)),
46    }
47}
48
49/// Dequantize a quantized vector back to float32
50pub fn dequantize(quantized: &QuantizedVector) -> Vec<f32> {
51    match quantized {
52        QuantizedVector::None(v) => v.clone(),
53        QuantizedVector::Scalar { data, min, scale } => scalar_dequantize(data, *min, *scale),
54        QuantizedVector::Product { codes, subspaces } => {
55            // Placeholder - would need codebooks stored separately
56            vec![0.0; codes.len() * (codes.len() / subspaces)]
57        }
58        QuantizedVector::Binary {
59            data,
60            threshold,
61            dimensions,
62        } => binary_dequantize(data, *threshold, *dimensions),
63    }
64}
65
66/// Scalar quantization to int8
67fn scalar_quantize(vector: &[f32]) -> QuantizedVector {
68    let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
69    let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
70
71    let scale = if max > min { 255.0 / (max - min) } else { 1.0 };
72
73    let data: Vec<u8> = vector
74        .iter()
75        .map(|&v| ((v - min) * scale).clamp(0.0, 255.0) as u8)
76        .collect();
77
78    QuantizedVector::Scalar { data, min, scale }
79}
80
/// Dequantize scalar quantized vector
///
/// Inverse of `scalar_quantize`: quantization computed
/// `q = (value - min) * scale` with `scale = 255 / (max - min)`, so
/// reconstruction is `value = min + q / scale`.
fn scalar_dequantize(data: &[u8], min: f32, scale: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(data.len());
    for &q in data {
        restored.push(min + f32::from(q) / scale);
    }
    restored
}
91
92/// Product quantization (simplified version)
93fn product_quantize(vector: &[f32], subspaces: usize, _k: usize) -> Result<QuantizedVector> {
94    if !vector.len().is_multiple_of(subspaces) {
95        return Err(VectorDbError::Quantization(
96            "Vector length must be divisible by number of subspaces".to_string(),
97        ));
98    }
99
100    // Simplified: just store subspace indices
101    // In production, this would involve k-means clustering per subspace
102    let subspace_dim = vector.len() / subspaces;
103    let codes: Vec<u8> = (0..subspaces)
104        .map(|i| {
105            let start = i * subspace_dim;
106            let subvec = &vector[start..start + subspace_dim];
107            // Placeholder: hash to a code (0-255)
108            (subvec.iter().sum::<f32>() as u32 % 256) as u8
109        })
110        .collect();
111
112    Ok(QuantizedVector::Product { codes, subspaces })
113}
114
115/// Binary quantization (1 bit per dimension)
116fn binary_quantize(vector: &[f32]) -> QuantizedVector {
117    let threshold = vector.iter().sum::<f32>() / vector.len() as f32;
118    let dimensions = vector.len();
119
120    let num_bytes = dimensions.div_ceil(8);
121    let mut data = vec![0u8; num_bytes];
122
123    for (i, &val) in vector.iter().enumerate() {
124        if val > threshold {
125            let byte_idx = i / 8;
126            let bit_idx = i % 8;
127            data[byte_idx] |= 1 << bit_idx;
128        }
129    }
130
131    QuantizedVector::Binary {
132        data,
133        threshold,
134        dimensions,
135    }
136}
137
/// Dequantize binary quantized vector
///
/// Reconstructs each dimension as `threshold + 1.0` for a 1-bit and
/// `threshold - 1.0` for a 0-bit: exact magnitudes are lost, only each
/// component's side of the threshold is preserved. Bits are read LSB-first,
/// matching the packing in `binary_quantize`. At most `dimensions` values are
/// produced (fewer if `data` carries fewer bits than that).
fn binary_dequantize(data: &[u8], threshold: f32, dimensions: usize) -> Vec<f32> {
    data.iter()
        .flat_map(|&byte| (0..8).map(move |bit_idx| (byte >> bit_idx) & 1))
        .take(dimensions)
        .map(|bit| {
            if bit == 1 {
                threshold + 1.0
            } else {
                threshold - 1.0
            }
        })
        .collect()
}
161
162/// Calculate memory savings from quantization
163pub fn calculate_compression_ratio(original_dims: usize, qtype: QuantizationType) -> f32 {
164    let original_bytes = original_dims * 4; // float32 = 4 bytes
165    let quantized_bytes = match qtype {
166        QuantizationType::None => original_bytes,
167        QuantizationType::Scalar => original_dims + 8, // u8 per dim + min + scale
168        QuantizationType::Product { subspaces, .. } => subspaces + 4, // u8 per subspace + overhead
169        QuantizationType::Binary => original_dims.div_ceil(8) + 4, // 1 bit per dim + threshold
170    };
171
172    original_bytes as f32 / quantized_bytes as f32
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantization() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let roundtripped = dequantize(&scalar_quantize(&input));

        // Check approximate equality (quantization loses precision)
        for (original, restored) in input.iter().zip(roundtripped.iter()) {
            assert!((original - restored).abs() < 0.1);
        }
    }

    #[test]
    fn test_binary_quantization() {
        let input = vec![1.0, 5.0, 2.0, 8.0, 3.0];

        if let QuantizedVector::Binary {
            data, dimensions, ..
        } = binary_quantize(&input)
        {
            assert!(!data.is_empty());
            assert_eq!(dimensions, 5);
        } else {
            panic!("Expected binary quantization");
        }
    }

    #[test]
    fn test_compression_ratio() {
        // Should be close to 4x
        assert!(calculate_compression_ratio(384, QuantizationType::Scalar) > 3.0);
        // Should be close to 32x
        assert!(calculate_compression_ratio(384, QuantizationType::Binary) > 20.0);
    }

    #[test]
    fn test_scalar_quantization_roundtrip() {
        // Test that quantize -> dequantize produces values close to original
        let cases = [
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![-10.0, -5.0, 0.0, 5.0, 10.0],
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
            vec![100.0, 200.0, 300.0, 400.0, 500.0],
        ];

        for input in cases {
            let restored = dequantize(&scalar_quantize(&input));
            assert_eq!(input.len(), restored.len());

            // With 8-bit quantization, max error is roughly (max-min)/255
            let hi = input.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let lo = input.iter().copied().fold(f32::INFINITY, f32::min);
            let tolerance = (hi - lo) / 255.0 * 2.0; // Allow 2x for rounding

            for (original, recovered) in input.iter().zip(restored.iter()) {
                assert!(
                    (original - recovered).abs() < tolerance,
                    "Roundtrip error too large: orig={}, deq={}, error={}",
                    original,
                    recovered,
                    (original - recovered).abs()
                );
            }
        }
    }

    #[test]
    fn test_scalar_quantization_edge_cases() {
        // Test with all same values
        let constant = vec![5.0, 5.0, 5.0, 5.0];
        let restored = dequantize(&scalar_quantize(&constant));
        for (original, recovered) in constant.iter().zip(restored.iter()) {
            assert!((original - recovered).abs() < 0.01);
        }

        // Test with extreme ranges
        let extreme = vec![f32::MIN / 1e10, 0.0, f32::MAX / 1e10];
        let restored = dequantize(&scalar_quantize(&extreme));
        assert_eq!(extreme.len(), restored.len());
    }

    #[test]
    fn test_binary_quantization_roundtrip() {
        let input = vec![1.0, -1.0, 2.0, -2.0, 0.5, -0.5];
        let quantized = binary_quantize(&input);
        let restored = dequantize(&quantized);

        // Binary quantization doesn't preserve exact values,
        // but should preserve the sign relative to threshold
        assert_eq!(
            input.len(),
            restored.len(),
            "Dequantized vector should have same length as original"
        );

        match quantized {
            QuantizedVector::Binary {
                threshold,
                dimensions,
                ..
            } => {
                assert_eq!(dimensions, input.len());
                // Check that both have same relationship to threshold
                for (original, recovered) in input.iter().zip(restored.iter()) {
                    assert_eq!(*original > threshold, *recovered > threshold);
                }
            }
            _ => panic!("Expected binary quantization"),
        }
    }
}