// ruvector_core/quantization.rs

//! Quantization techniques for memory compression
2
3use crate::error::Result;
4use serde::{Deserialize, Serialize};
5
/// Common interface shared by the compressed (quantized) vector encodings in this module.
///
/// The `Send + Sync` supertraits let encoded vectors be shared between threads.
pub trait QuantizedVector: Send + Sync {
    /// Builds the quantized form of the full-precision input `vector`.
    fn quantize(vector: &[f32]) -> Self;

    /// Measures how far `self` is from `other` in this encoding's own metric.
    fn distance(&self, other: &Self) -> f32;

    /// Decodes back into an approximation of the original full-precision vector.
    fn reconstruct(&self) -> Vec<f32>;
}
17
18/// Scalar quantization to int8 (4x compression)
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ScalarQuantized {
21    /// Quantized values (int8)
22    pub data: Vec<u8>,
23    /// Minimum value for dequantization
24    pub min: f32,
25    /// Scale factor for dequantization
26    pub scale: f32,
27}
28
29impl QuantizedVector for ScalarQuantized {
30    fn quantize(vector: &[f32]) -> Self {
31        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
32        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
33        let scale = (max - min) / 255.0;
34
35        let data = vector
36            .iter()
37            .map(|&v| ((v - min) / scale).round() as u8)
38            .collect();
39
40        Self { data, min, scale }
41    }
42
43    fn distance(&self, other: &Self) -> f32 {
44        // Fast int8 distance calculation
45        self.data
46            .iter()
47            .zip(&other.data)
48            .map(|(&a, &b)| {
49                let diff = a as i16 - b as i16;
50                (diff * diff) as f32
51            })
52            .sum::<f32>()
53            .sqrt()
54            * self.scale.max(other.scale)
55    }
56
57    fn reconstruct(&self) -> Vec<f32> {
58        self.data
59            .iter()
60            .map(|&v| self.min + (v as f32) * self.scale)
61            .collect()
62    }
63}
64
65/// Product quantization (8-16x compression)
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct ProductQuantized {
68    /// Quantized codes (one per subspace)
69    pub codes: Vec<u8>,
70    /// Codebooks for each subspace
71    pub codebooks: Vec<Vec<Vec<f32>>>,
72}
73
74impl ProductQuantized {
75    /// Train product quantization on a set of vectors
76    pub fn train(
77        vectors: &[Vec<f32>],
78        num_subspaces: usize,
79        codebook_size: usize,
80        iterations: usize,
81    ) -> Result<Self> {
82        if vectors.is_empty() {
83            return Err(crate::error::RuvectorError::InvalidInput(
84                "Cannot train on empty vector set".into(),
85            ));
86        }
87        if vectors[0].is_empty() {
88            return Err(crate::error::RuvectorError::InvalidInput(
89                "Cannot train on vectors with zero dimensions".into(),
90            ));
91        }
92        if codebook_size > 256 {
93            return Err(crate::error::RuvectorError::InvalidParameter(
94                format!("Codebook size {} exceeds u8 maximum of 256", codebook_size),
95            ));
96        }
97        let dimensions = vectors[0].len();
98        let subspace_dim = dimensions / num_subspaces;
99
100        let mut codebooks = Vec::with_capacity(num_subspaces);
101
102        // Train codebook for each subspace using k-means
103        for subspace_idx in 0..num_subspaces {
104            let start = subspace_idx * subspace_dim;
105            let end = start + subspace_dim;
106
107            // Extract subspace vectors
108            let subspace_vectors: Vec<Vec<f32>> =
109                vectors.iter().map(|v| v[start..end].to_vec()).collect();
110
111            // Run k-means
112            let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
113            codebooks.push(codebook);
114        }
115
116        Ok(Self {
117            codes: vec![],
118            codebooks,
119        })
120    }
121
122    /// Quantize a vector using trained codebooks
123    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
124        let num_subspaces = self.codebooks.len();
125        let subspace_dim = vector.len() / num_subspaces;
126
127        let mut codes = Vec::with_capacity(num_subspaces);
128
129        for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
130            let start = subspace_idx * subspace_dim;
131            let end = start + subspace_dim;
132            let subvector = &vector[start..end];
133
134            // Find nearest centroid
135            let code = codebook
136                .iter()
137                .enumerate()
138                .min_by(|(_, a), (_, b)| {
139                    let dist_a = euclidean_squared(subvector, a);
140                    let dist_b = euclidean_squared(subvector, b);
141                    dist_a.partial_cmp(&dist_b).unwrap()
142                })
143                .map(|(idx, _)| idx as u8)
144                .unwrap_or(0);
145
146            codes.push(code);
147        }
148
149        codes
150    }
151}
152
153/// Binary quantization (32x compression)
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct BinaryQuantized {
156    /// Binary representation (1 bit per dimension, packed into bytes)
157    pub bits: Vec<u8>,
158    /// Number of dimensions
159    pub dimensions: usize,
160}
161
162impl QuantizedVector for BinaryQuantized {
163    fn quantize(vector: &[f32]) -> Self {
164        let dimensions = vector.len();
165        let num_bytes = (dimensions + 7) / 8;
166        let mut bits = vec![0u8; num_bytes];
167
168        for (i, &v) in vector.iter().enumerate() {
169            if v > 0.0 {
170                let byte_idx = i / 8;
171                let bit_idx = i % 8;
172                bits[byte_idx] |= 1 << bit_idx;
173            }
174        }
175
176        Self { bits, dimensions }
177    }
178
179    fn distance(&self, other: &Self) -> f32 {
180        // Hamming distance
181        let mut distance = 0u32;
182
183        for (&a, &b) in self.bits.iter().zip(&other.bits) {
184            distance += (a ^ b).count_ones();
185        }
186
187        distance as f32
188    }
189
190    fn reconstruct(&self) -> Vec<f32> {
191        let mut result = Vec::with_capacity(self.dimensions);
192
193        for i in 0..self.dimensions {
194            let byte_idx = i / 8;
195            let bit_idx = i % 8;
196            let bit = (self.bits[byte_idx] >> bit_idx) & 1;
197            result.push(if bit == 1 { 1.0 } else { -1.0 });
198        }
199
200        result
201    }
202}
203
// Helper functions
205
/// Squared Euclidean distance between `a` and `b`, summed over the shorter of the two.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    let squared_gaps = a.iter().zip(b.iter()).map(|(x, y)| {
        let gap = *x - *y;
        gap * gap
    });
    squared_gaps.sum()
}
215
216fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
217    use rand::seq::SliceRandom;
218    use rand::thread_rng;
219
220    let mut rng = thread_rng();
221
222    // Initialize centroids randomly
223    let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
224
225    for _ in 0..iterations {
226        // Assign vectors to nearest centroid
227        let mut assignments = vec![Vec::new(); k];
228
229        for vector in vectors {
230            let nearest = centroids
231                .iter()
232                .enumerate()
233                .min_by(|(_, a), (_, b)| {
234                    let dist_a = euclidean_squared(vector, a);
235                    let dist_b = euclidean_squared(vector, b);
236                    dist_a.partial_cmp(&dist_b).unwrap()
237                })
238                .map(|(idx, _)| idx)
239                .unwrap_or(0);
240
241            assignments[nearest].push(vector.clone());
242        }
243
244        // Update centroids
245        for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
246            if !assigned.is_empty() {
247                let dim = centroid.len();
248                *centroid = vec![0.0; dim];
249
250                for vector in assigned {
251                    for (i, &v) in vector.iter().enumerate() {
252                        centroid[i] += v;
253                    }
254                }
255
256                let count = assigned.len() as f32;
257                for v in centroid.iter_mut() {
258                    *v /= count;
259                }
260            }
261        }
262    }
263
264    centroids
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantization() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let restored = ScalarQuantized::quantize(&original).reconstruct();

        // Every reconstructed component should land within 0.1 of its input.
        let all_close = original
            .iter()
            .zip(restored.iter())
            .all(|(expected, actual)| (expected - actual).abs() < 0.1);
        assert!(all_close);
    }

    #[test]
    fn test_binary_quantization() {
        let quantized = BinaryQuantized::quantize(&[1.0, -1.0, 2.0, -2.0, 0.5]);

        assert_eq!(quantized.dimensions, 5);
        // Five sign bits pack into a single byte.
        assert_eq!(quantized.bits.len(), 1);
    }

    #[test]
    fn test_binary_distance() {
        let all_positive = BinaryQuantized::quantize(&[1.0, 1.0, 1.0, 1.0]);
        let half_negative = BinaryQuantized::quantize(&[1.0, 1.0, -1.0, -1.0]);

        // Exactly two sign bits differ.
        assert_eq!(all_positive.distance(&half_negative), 2.0);
    }
}