// velesdb_core/quantization/pq.rs

use serde::{Deserialize, Serialize};
/// Trained product-quantization codebook: one k-means centroid table per subspace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCodebook {
    /// `centroids[s][c]` is centroid `c` of subspace `s`; each has length `subspace_dim`.
    pub centroids: Vec<Vec<Vec<f32>>>,
    /// Dimensionality of the original (full) vectors.
    pub dimension: usize,
    /// Number of independent subspaces the vector is split into.
    pub num_subspaces: usize,
    /// Number of centroids per subspace (must fit in `u16`).
    pub num_centroids: usize,
    /// Length of each sub-vector: `dimension / num_subspaces`.
    pub subspace_dim: usize,
}
22
/// A product-quantized vector: one centroid index (code) per subspace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQVector {
    /// `codes[s]` indexes into `PQCodebook::centroids[s]`.
    pub codes: Vec<u16>,
}
29
/// Product quantizer: owns a trained [`PQCodebook`] and provides
/// quantization (encode) and reconstruction (decode) of vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Trained codebook shared by `quantize`, `reconstruct`, and `distance_pq`.
    pub codebook: PQCodebook,
}
36
37impl ProductQuantizer {
38 #[must_use]
40 pub fn train(vectors: &[Vec<f32>], num_subspaces: usize, num_centroids: usize) -> Self {
41 assert!(!vectors.is_empty(), "Cannot train PQ with empty dataset");
42 assert!(num_subspaces > 0, "num_subspaces must be > 0");
43 assert!(num_centroids > 0, "num_centroids must be > 0");
44 assert!(
45 num_centroids <= usize::from(u16::MAX),
46 "num_centroids must fit in u16 (max 65535)"
47 );
48
49 let dimension = vectors[0].len();
50 assert!(
51 vectors.iter().all(|v| v.len() == dimension),
52 "All vectors must share the same dimension"
53 );
54 assert!(
55 dimension.is_multiple_of(num_subspaces),
56 "Dimension must be divisible by num_subspaces"
57 );
58
59 let subspace_dim = dimension / num_subspaces;
60 let mut centroids = Vec::with_capacity(num_subspaces);
61
62 for subspace in 0..num_subspaces {
63 let start = subspace * subspace_dim;
64 let end = start + subspace_dim;
65 let sub_vectors: Vec<Vec<f32>> =
66 vectors.iter().map(|v| v[start..end].to_vec()).collect();
67 centroids.push(kmeans_train(&sub_vectors, num_centroids, 25));
68 }
69
70 Self {
71 codebook: PQCodebook {
72 centroids,
73 dimension,
74 num_subspaces,
75 num_centroids,
76 subspace_dim,
77 },
78 }
79 }
80
81 #[must_use]
83 pub fn quantize(&self, vector: &[f32]) -> PQVector {
84 assert_eq!(vector.len(), self.codebook.dimension);
85
86 let mut codes = Vec::with_capacity(self.codebook.num_subspaces);
87 for subspace in 0..self.codebook.num_subspaces {
88 let start = subspace * self.codebook.subspace_dim;
89 let end = start + self.codebook.subspace_dim;
90 let code = nearest_centroid(&vector[start..end], &self.codebook.centroids[subspace]);
91 #[allow(clippy::cast_possible_truncation)]
94 codes.push(code as u16);
95 }
96
97 PQVector { codes }
98 }
99
100 #[must_use]
102 pub fn reconstruct(&self, pq_vector: &PQVector) -> Vec<f32> {
103 assert_eq!(pq_vector.codes.len(), self.codebook.num_subspaces);
104
105 let mut reconstructed = Vec::with_capacity(self.codebook.dimension);
106 for (subspace, &code) in pq_vector.codes.iter().enumerate() {
107 let centroid = &self.codebook.centroids[subspace][usize::from(code)];
108 reconstructed.extend_from_slice(centroid);
109 }
110
111 reconstructed
112 }
113}
114
115#[must_use]
117pub fn distance_pq(query_vector: &[f32], pq_vector: &PQVector, codebook: &PQCodebook) -> f32 {
118 assert_eq!(query_vector.len(), codebook.dimension);
119 assert_eq!(pq_vector.codes.len(), codebook.num_subspaces);
120
121 let mut lookup_tables = Vec::with_capacity(codebook.num_subspaces);
122 for subspace in 0..codebook.num_subspaces {
123 let start = subspace * codebook.subspace_dim;
124 let end = start + codebook.subspace_dim;
125 let q_sub = &query_vector[start..end];
126
127 let table: Vec<f32> = codebook.centroids[subspace]
128 .iter()
129 .map(|centroid| l2_squared(q_sub, centroid))
130 .collect();
131 lookup_tables.push(table);
132 }
133
134 pq_vector
135 .codes
136 .iter()
137 .enumerate()
138 .map(|(subspace, &code)| lookup_tables[subspace][usize::from(code)])
139 .sum::<f32>()
140 .sqrt()
141}
142
143fn nearest_centroid(vector: &[f32], centroids: &[Vec<f32>]) -> usize {
144 let mut best_idx = 0;
145 let mut best_dist = f32::MAX;
146
147 for (idx, centroid) in centroids.iter().enumerate() {
148 let dist = l2_squared(vector, centroid);
149 if dist < best_dist {
150 best_dist = dist;
151 best_idx = idx;
152 }
153 }
154
155 best_idx
156}
157
/// Squared Euclidean distance between two equal-length slices
/// (extra elements of the longer slice are ignored by `zip`).
fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).fold(0.0, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    })
}
167
168fn kmeans_train(samples: &[Vec<f32>], k: usize, max_iters: usize) -> Vec<Vec<f32>> {
169 assert!(!samples.is_empty());
170 let dim = samples[0].len();
171
172 let mut centroids: Vec<Vec<f32>> = (0..k).map(|i| samples[i % samples.len()].clone()).collect();
174
175 let mut assignments = vec![0usize; samples.len()];
176
177 for _ in 0..max_iters {
178 let mut changed = false;
179
180 for (i, sample) in samples.iter().enumerate() {
182 let new_assignment = nearest_centroid(sample, ¢roids);
183 if assignments[i] != new_assignment {
184 assignments[i] = new_assignment;
185 changed = true;
186 }
187 }
188
189 let mut new_centroids = vec![vec![0.0; dim]; k];
191 let mut counts = vec![0usize; k];
192
193 for (sample, &cluster) in samples.iter().zip(assignments.iter()) {
194 counts[cluster] += 1;
195 for (d, &val) in sample.iter().enumerate() {
196 new_centroids[cluster][d] += val;
197 }
198 }
199
200 for cluster in 0..k {
201 if counts[cluster] == 0 {
202 new_centroids[cluster] = samples[cluster % samples.len()].clone();
204 } else {
205 let inv = 1.0 / counts[cluster] as f32;
206 for d in 0..dim {
207 new_centroids[cluster][d] *= inv;
208 }
209 }
210 }
211
212 centroids = new_centroids;
213
214 if !changed {
215 break;
216 }
217 }
218
219 centroids
220}
221
#[cfg(test)]
mod tests {
    use super::{distance_pq, ProductQuantizer};

    // Training should produce one centroid table per subspace with the
    // requested number of centroids and sub-vector width.
    #[test]
    fn train_builds_expected_codebook_shape() {
        let training_set = vec![
            vec![0.0, 0.0, 10.0, 10.0],
            vec![0.1, 0.0, 9.9, 10.1],
            vec![8.0, 8.0, 1.0, 1.0],
            vec![8.1, 7.9, 1.2, 0.8],
        ];

        let quantizer = ProductQuantizer::train(&training_set, 2, 2);
        let cb = &quantizer.codebook;
        assert_eq!(cb.num_subspaces, 2);
        assert_eq!(cb.num_centroids, 2);
        assert_eq!(cb.subspace_dim, 2);
        assert_eq!(cb.centroids.len(), 2);
        assert_eq!(cb.centroids[0].len(), 2);
    }

    // Quantize followed by reconstruct should land close to the input
    // when the input sits near a training cluster.
    #[test]
    fn quantize_and_reconstruct_roundtrip_is_reasonable() {
        let training_set = vec![
            vec![0.0, 0.0, 10.0, 10.0],
            vec![0.1, -0.1, 10.1, 9.9],
            vec![8.0, 8.0, 1.0, 1.0],
            vec![8.1, 7.9, 1.2, 0.8],
        ];
        let quantizer = ProductQuantizer::train(&training_set, 2, 4);

        let input = vec![8.05, 8.0, 1.1, 1.0];
        let encoded = quantizer.quantize(&input);
        let decoded = quantizer.reconstruct(&encoded);

        assert_eq!(encoded.codes.len(), 2);
        assert_eq!(decoded.len(), input.len());

        let squared_error: f32 = input
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| {
                let d = a - b;
                d * d
            })
            .sum();
        let mse = squared_error / input.len() as f32;
        assert!(mse < 0.2, "reconstruction MSE too high: {mse}");
    }

    // ADC distances should rank a code near the query below a distant one.
    #[test]
    fn adc_distance_prefers_closer_codes() {
        let training_set = vec![
            vec![0.0, 0.0, 0.0, 0.0],
            vec![0.1, 0.1, 0.1, 0.1],
            vec![5.0, 5.0, 5.0, 5.0],
            vec![5.1, 4.9, 5.0, 5.2],
        ];
        let quantizer = ProductQuantizer::train(&training_set, 2, 2);

        let near = quantizer.quantize(&[0.05, 0.05, 0.0, 0.1]);
        let far = quantizer.quantize(&[5.0, 5.0, 5.0, 5.0]);
        let query = [0.0, 0.0, 0.0, 0.0];

        let d_near = distance_pq(&query, &near, &quantizer.codebook);
        let d_far = distance_pq(&query, &far, &quantizer.codebook);

        assert!(d_near < d_far, "ADC did not preserve proximity ordering");
    }
}