ruvector_core/
quantization.rs

1//! Quantization techniques for memory compression
2
3use crate::error::Result;
4use serde::{Deserialize, Serialize};
5
/// Common interface for compressed encodings of `f32` vectors.
///
/// An implementor knows how to build its encoding from a full-precision slice,
/// expand that encoding back into an approximate `f32` vector, and compare two
/// encodings of the same kind. The distance metric is chosen by each implementor.
pub trait QuantizedVector: Send + Sync {
    /// Builds the compressed encoding of `vector`.
    fn quantize(vector: &[f32]) -> Self;

    /// Expands the encoding into an approximation of the original vector.
    fn reconstruct(&self) -> Vec<f32>;

    /// Distance between `self` and `other` under the implementor's metric.
    fn distance(&self, other: &Self) -> f32;
}
17
18/// Scalar quantization to int8 (4x compression)
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ScalarQuantized {
21    /// Quantized values (int8)
22    pub data: Vec<u8>,
23    /// Minimum value for dequantization
24    pub min: f32,
25    /// Scale factor for dequantization
26    pub scale: f32,
27}
28
29impl QuantizedVector for ScalarQuantized {
30    fn quantize(vector: &[f32]) -> Self {
31        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
32        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
33        let scale = (max - min) / 255.0;
34
35        let data = vector
36            .iter()
37            .map(|&v| ((v - min) / scale).round() as u8)
38            .collect();
39
40        Self { data, min, scale }
41    }
42
43    fn distance(&self, other: &Self) -> f32 {
44        // Fast int8 distance calculation
45        self.data
46            .iter()
47            .zip(&other.data)
48            .map(|(&a, &b)| {
49                let diff = a as i16 - b as i16;
50                (diff * diff) as f32
51            })
52            .sum::<f32>()
53            .sqrt()
54            * self.scale.max(other.scale)
55    }
56
57    fn reconstruct(&self) -> Vec<f32> {
58        self.data
59            .iter()
60            .map(|&v| self.min + (v as f32) * self.scale)
61            .collect()
62    }
63}
64
65/// Product quantization (8-16x compression)
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct ProductQuantized {
68    /// Quantized codes (one per subspace)
69    pub codes: Vec<u8>,
70    /// Codebooks for each subspace
71    pub codebooks: Vec<Vec<Vec<f32>>>,
72}
73
74impl ProductQuantized {
75    /// Train product quantization on a set of vectors
76    pub fn train(
77        vectors: &[Vec<f32>],
78        num_subspaces: usize,
79        codebook_size: usize,
80        iterations: usize,
81    ) -> Result<Self> {
82        let dimensions = vectors[0].len();
83        let subspace_dim = dimensions / num_subspaces;
84
85        let mut codebooks = Vec::with_capacity(num_subspaces);
86
87        // Train codebook for each subspace using k-means
88        for subspace_idx in 0..num_subspaces {
89            let start = subspace_idx * subspace_dim;
90            let end = start + subspace_dim;
91
92            // Extract subspace vectors
93            let subspace_vectors: Vec<Vec<f32>> =
94                vectors.iter().map(|v| v[start..end].to_vec()).collect();
95
96            // Run k-means
97            let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
98            codebooks.push(codebook);
99        }
100
101        Ok(Self {
102            codes: vec![],
103            codebooks,
104        })
105    }
106
107    /// Quantize a vector using trained codebooks
108    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
109        let num_subspaces = self.codebooks.len();
110        let subspace_dim = vector.len() / num_subspaces;
111
112        let mut codes = Vec::with_capacity(num_subspaces);
113
114        for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
115            let start = subspace_idx * subspace_dim;
116            let end = start + subspace_dim;
117            let subvector = &vector[start..end];
118
119            // Find nearest centroid
120            let code = codebook
121                .iter()
122                .enumerate()
123                .min_by(|(_, a), (_, b)| {
124                    let dist_a = euclidean_squared(subvector, a);
125                    let dist_b = euclidean_squared(subvector, b);
126                    dist_a.partial_cmp(&dist_b).unwrap()
127                })
128                .map(|(idx, _)| idx as u8)
129                .unwrap_or(0);
130
131            codes.push(code);
132        }
133
134        codes
135    }
136}
137
138/// Binary quantization (32x compression)
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct BinaryQuantized {
141    /// Binary representation (1 bit per dimension, packed into bytes)
142    pub bits: Vec<u8>,
143    /// Number of dimensions
144    pub dimensions: usize,
145}
146
147impl QuantizedVector for BinaryQuantized {
148    fn quantize(vector: &[f32]) -> Self {
149        let dimensions = vector.len();
150        let num_bytes = (dimensions + 7) / 8;
151        let mut bits = vec![0u8; num_bytes];
152
153        for (i, &v) in vector.iter().enumerate() {
154            if v > 0.0 {
155                let byte_idx = i / 8;
156                let bit_idx = i % 8;
157                bits[byte_idx] |= 1 << bit_idx;
158            }
159        }
160
161        Self { bits, dimensions }
162    }
163
164    fn distance(&self, other: &Self) -> f32 {
165        // Hamming distance
166        let mut distance = 0u32;
167
168        for (&a, &b) in self.bits.iter().zip(&other.bits) {
169            distance += (a ^ b).count_ones();
170        }
171
172        distance as f32
173    }
174
175    fn reconstruct(&self) -> Vec<f32> {
176        let mut result = Vec::with_capacity(self.dimensions);
177
178        for i in 0..self.dimensions {
179            let byte_idx = i / 8;
180            let bit_idx = i % 8;
181            let bit = (self.bits[byte_idx] >> bit_idx) & 1;
182            result.push(if bit == 1 { 1.0 } else { -1.0 });
183        }
184
185        result
186    }
187}
188
189// Helper functions
190
/// Squared Euclidean distance between `a` and `b` (no square root is taken).
///
/// Only the first `min(a.len(), b.len())` components are compared; any extra
/// components of the longer slice are ignored.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| x - y)
        .map(|delta| delta * delta)
        .sum()
}
200
201fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
202    use rand::seq::SliceRandom;
203    use rand::thread_rng;
204
205    let mut rng = thread_rng();
206
207    // Initialize centroids randomly
208    let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
209
210    for _ in 0..iterations {
211        // Assign vectors to nearest centroid
212        let mut assignments = vec![Vec::new(); k];
213
214        for vector in vectors {
215            let nearest = centroids
216                .iter()
217                .enumerate()
218                .min_by(|(_, a), (_, b)| {
219                    let dist_a = euclidean_squared(vector, a);
220                    let dist_b = euclidean_squared(vector, b);
221                    dist_a.partial_cmp(&dist_b).unwrap()
222                })
223                .map(|(idx, _)| idx)
224                .unwrap_or(0);
225
226            assignments[nearest].push(vector.clone());
227        }
228
229        // Update centroids
230        for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
231            if !assigned.is_empty() {
232                let dim = centroid.len();
233                *centroid = vec![0.0; dim];
234
235                for vector in assigned {
236                    for (i, &v) in vector.iter().enumerate() {
237                        centroid[i] += v;
238                    }
239                }
240
241                let count = assigned.len() as f32;
242                for v in centroid.iter_mut() {
243                    *v /= count;
244                }
245            }
246        }
247    }
248
249    centroids
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantization() {
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let restored = ScalarQuantized::quantize(&input).reconstruct();

        // Every component must survive the round trip to within 0.1.
        for (expected, actual) in input.iter().zip(restored.iter()) {
            assert!((expected - actual).abs() < 0.1);
        }
    }

    #[test]
    fn test_binary_quantization() {
        let encoded = BinaryQuantized::quantize(&[1.0, -1.0, 2.0, -2.0, 0.5]);

        assert_eq!(encoded.dimensions, 5);
        // Five one-bit components pack into a single byte.
        assert_eq!(encoded.bits.len(), 1);
    }

    #[test]
    fn test_binary_distance() {
        let all_positive = BinaryQuantized::quantize(&[1.0, 1.0, 1.0, 1.0]);
        let half_negative = BinaryQuantized::quantize(&[1.0, 1.0, -1.0, -1.0]);

        // Exactly the last two bits differ.
        assert_eq!(all_positive.distance(&half_negative), 2.0);
    }
}