Skip to main content

velesdb_core/quantization/
pq.rs

1//! Product Quantization (PQ) for aggressive lossy vector compression.
2//!
3//! PQ splits vectors into multiple subspaces and quantizes each subspace
4//! independently with its own codebook (k-means centroids).
5
6use serde::{Deserialize, Serialize};
7
/// Per-subspace centroid tables learned with k-means.
///
/// The serde derives allow a trained codebook to be serialized and restored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// Nested centroid tables, indexed as `centroids[subspace][centroid][component]`;
    /// `train` stores one k-means result per subspace here.
    pub centroids: Vec<Vec<Vec<f32>>>,
    /// Full vector dimension (the length `quantize` and `distance_pq` require of their input).
    pub dimension: usize,
    /// Number of subspaces `m`; also the number of codes in each `PQVector`.
    pub num_subspaces: usize,
    /// Number of centroids `k` per subspace; `train` requires it to be at most `u16::MAX`.
    pub num_centroids: usize,
    /// Dimension of each subspace, computed by `train` as `dimension / num_subspaces`.
    pub subspace_dim: usize,
}
22
/// Compressed representation of a vector: one centroid id per subspace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQVector {
    /// Selected centroid ids for each subspace: `codes[i]` indexes the codebook's
    /// centroid table for subspace `i`.
    pub codes: Vec<u16>,
}
29
/// Product quantizer model and helpers for train/encode/decode.
///
/// Build one with `ProductQuantizer::train`, then use `quantize` to encode vectors
/// and `reconstruct` to decode them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Trained codebook used by `quantize` and `reconstruct`.
    pub codebook: PQCodebook,
}
36
37impl ProductQuantizer {
38    /// Train a PQ codebook using simplified k-means for each subspace.
39    #[must_use]
40    pub fn train(vectors: &[Vec<f32>], num_subspaces: usize, num_centroids: usize) -> Self {
41        assert!(!vectors.is_empty(), "Cannot train PQ with empty dataset");
42        assert!(num_subspaces > 0, "num_subspaces must be > 0");
43        assert!(num_centroids > 0, "num_centroids must be > 0");
44        assert!(
45            num_centroids <= usize::from(u16::MAX),
46            "num_centroids must fit in u16 (max 65535)"
47        );
48
49        let dimension = vectors[0].len();
50        assert!(
51            vectors.iter().all(|v| v.len() == dimension),
52            "All vectors must share the same dimension"
53        );
54        assert!(
55            dimension.is_multiple_of(num_subspaces),
56            "Dimension must be divisible by num_subspaces"
57        );
58
59        let subspace_dim = dimension / num_subspaces;
60        let mut centroids = Vec::with_capacity(num_subspaces);
61
62        for subspace in 0..num_subspaces {
63            let start = subspace * subspace_dim;
64            let end = start + subspace_dim;
65            let sub_vectors: Vec<Vec<f32>> =
66                vectors.iter().map(|v| v[start..end].to_vec()).collect();
67            centroids.push(kmeans_train(&sub_vectors, num_centroids, 25));
68        }
69
70        Self {
71            codebook: PQCodebook {
72                centroids,
73                dimension,
74                num_subspaces,
75                num_centroids,
76                subspace_dim,
77            },
78        }
79    }
80
81    /// Quantize a full-precision vector into PQ codes.
82    #[must_use]
83    pub fn quantize(&self, vector: &[f32]) -> PQVector {
84        assert_eq!(vector.len(), self.codebook.dimension);
85
86        let mut codes = Vec::with_capacity(self.codebook.num_subspaces);
87        for subspace in 0..self.codebook.num_subspaces {
88            let start = subspace * self.codebook.subspace_dim;
89            let end = start + self.codebook.subspace_dim;
90            let code = nearest_centroid(&vector[start..end], &self.codebook.centroids[subspace]);
91            // SAFETY: `num_centroids` is validated to fit in u16 during `train()`.
92            // `nearest_centroid` returns an index < num_centroids, so it always fits.
93            #[allow(clippy::cast_possible_truncation)]
94            codes.push(code as u16);
95        }
96
97        PQVector { codes }
98    }
99
100    /// Reconstruct an approximate vector from PQ codes.
101    #[must_use]
102    pub fn reconstruct(&self, pq_vector: &PQVector) -> Vec<f32> {
103        assert_eq!(pq_vector.codes.len(), self.codebook.num_subspaces);
104
105        let mut reconstructed = Vec::with_capacity(self.codebook.dimension);
106        for (subspace, &code) in pq_vector.codes.iter().enumerate() {
107            let centroid = &self.codebook.centroids[subspace][usize::from(code)];
108            reconstructed.extend_from_slice(centroid);
109        }
110
111        reconstructed
112    }
113}
114
115/// Asymmetric distance computation (ADC): query is f32, candidate is PQ-coded.
116#[must_use]
117pub fn distance_pq(query_vector: &[f32], pq_vector: &PQVector, codebook: &PQCodebook) -> f32 {
118    assert_eq!(query_vector.len(), codebook.dimension);
119    assert_eq!(pq_vector.codes.len(), codebook.num_subspaces);
120
121    let mut lookup_tables = Vec::with_capacity(codebook.num_subspaces);
122    for subspace in 0..codebook.num_subspaces {
123        let start = subspace * codebook.subspace_dim;
124        let end = start + codebook.subspace_dim;
125        let q_sub = &query_vector[start..end];
126
127        let table: Vec<f32> = codebook.centroids[subspace]
128            .iter()
129            .map(|centroid| l2_squared(q_sub, centroid))
130            .collect();
131        lookup_tables.push(table);
132    }
133
134    pq_vector
135        .codes
136        .iter()
137        .enumerate()
138        .map(|(subspace, &code)| lookup_tables[subspace][usize::from(code)])
139        .sum::<f32>()
140        .sqrt()
141}
142
143fn nearest_centroid(vector: &[f32], centroids: &[Vec<f32>]) -> usize {
144    let mut best_idx = 0;
145    let mut best_dist = f32::MAX;
146
147    for (idx, centroid) in centroids.iter().enumerate() {
148        let dist = l2_squared(vector, centroid);
149        if dist < best_dist {
150            best_dist = dist;
151            best_idx = idx;
152        }
153    }
154
155    best_idx
156}
157
/// Squared Euclidean distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length, the
/// surplus components of the longer one are ignored.
fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = lhs - rhs;
        delta * delta
    });

    squared_diffs.sum()
}
167
168fn kmeans_train(samples: &[Vec<f32>], k: usize, max_iters: usize) -> Vec<Vec<f32>> {
169    assert!(!samples.is_empty());
170    let dim = samples[0].len();
171
172    // Deterministic init: first k (cycled if needed).
173    let mut centroids: Vec<Vec<f32>> = (0..k).map(|i| samples[i % samples.len()].clone()).collect();
174
175    let mut assignments = vec![0usize; samples.len()];
176
177    for _ in 0..max_iters {
178        let mut changed = false;
179
180        // Assignment step
181        for (i, sample) in samples.iter().enumerate() {
182            let new_assignment = nearest_centroid(sample, &centroids);
183            if assignments[i] != new_assignment {
184                assignments[i] = new_assignment;
185                changed = true;
186            }
187        }
188
189        // Update step
190        let mut new_centroids = vec![vec![0.0; dim]; k];
191        let mut counts = vec![0usize; k];
192
193        for (sample, &cluster) in samples.iter().zip(assignments.iter()) {
194            counts[cluster] += 1;
195            for (d, &val) in sample.iter().enumerate() {
196                new_centroids[cluster][d] += val;
197            }
198        }
199
200        for cluster in 0..k {
201            if counts[cluster] == 0 {
202                // Re-seed empty cluster deterministically.
203                new_centroids[cluster] = samples[cluster % samples.len()].clone();
204            } else {
205                let inv = 1.0 / counts[cluster] as f32;
206                for d in 0..dim {
207                    new_centroids[cluster][d] *= inv;
208                }
209            }
210        }
211
212        centroids = new_centroids;
213
214        if !changed {
215            break;
216        }
217    }
218
219    centroids
220}
221
#[cfg(test)]
mod tests {
    use super::{distance_pq, ProductQuantizer};

    /// Mean of the squared component-wise differences between two slices.
    fn mean_squared_error(expected: &[f32], actual: &[f32]) -> f32 {
        let total = expected
            .iter()
            .zip(actual.iter())
            .map(|(lhs, rhs)| {
                let delta = lhs - rhs;
                delta * delta
            })
            .sum::<f32>();
        total / expected.len() as f32
    }

    /// Training must record the requested parameters and size the tables accordingly.
    #[test]
    fn train_builds_expected_codebook_shape() {
        let training_set = vec![
            vec![0.0, 0.0, 10.0, 10.0],
            vec![0.1, 0.0, 9.9, 10.1],
            vec![8.0, 8.0, 1.0, 1.0],
            vec![8.1, 7.9, 1.2, 0.8],
        ];

        let quantizer = ProductQuantizer::train(&training_set, 2, 2);
        let codebook = &quantizer.codebook;

        assert_eq!(codebook.num_subspaces, 2);
        assert_eq!(codebook.num_centroids, 2);
        assert_eq!(codebook.subspace_dim, 2);
        assert_eq!(codebook.centroids.len(), 2);
        assert_eq!(codebook.centroids[0].len(), 2);
    }

    /// Encoding then decoding a vector close to the training data must stay close to it.
    #[test]
    fn quantize_and_reconstruct_roundtrip_is_reasonable() {
        let training_set = vec![
            vec![0.0, 0.0, 10.0, 10.0],
            vec![0.1, -0.1, 10.1, 9.9],
            vec![8.0, 8.0, 1.0, 1.0],
            vec![8.1, 7.9, 1.2, 0.8],
        ];
        let quantizer = ProductQuantizer::train(&training_set, 2, 4);

        let input = vec![8.05, 8.0, 1.1, 1.0];
        let encoded = quantizer.quantize(&input);
        let decoded = quantizer.reconstruct(&encoded);

        assert_eq!(encoded.codes.len(), 2);
        assert_eq!(decoded.len(), input.len());

        let mse = mean_squared_error(&input, &decoded);
        assert!(mse < 0.2, "reconstruction MSE too high: {mse}");
    }

    /// A code built from a point near the query must score a smaller ADC distance
    /// than a code built from a distant point.
    #[test]
    fn adc_distance_prefers_closer_codes() {
        let training_set = vec![
            vec![0.0, 0.0, 0.0, 0.0],
            vec![0.1, 0.1, 0.1, 0.1],
            vec![5.0, 5.0, 5.0, 5.0],
            vec![5.1, 4.9, 5.0, 5.2],
        ];
        let quantizer = ProductQuantizer::train(&training_set, 2, 2);
        let codebook = &quantizer.codebook;

        let query = [0.0, 0.0, 0.0, 0.0];
        let near_code = quantizer.quantize(&[0.05, 0.05, 0.0, 0.1]);
        let far_code = quantizer.quantize(&[5.0, 5.0, 5.0, 5.0]);

        let near_distance = distance_pq(&query, &near_code, codebook);
        let far_distance = distance_pq(&query, &far_code, codebook);

        assert!(
            near_distance < far_distance,
            "ADC did not preserve proximity ordering"
        );
    }
}