1use anyhow::{anyhow, Result};
12use rand::seq::SliceRandom;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
/// Default number of centroids in each subspace codebook: one per value of the
/// single-byte (`u8`) code, i.e. 256.
pub const NUM_CENTROIDS: usize = u8::MAX as usize + 1;

/// Default number of consecutive vector components grouped into one subvector.
pub const DEFAULT_SUBVEC_DIM: usize = 8;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct PQConfig {
25 pub dimension: usize,
27 pub num_subvectors: usize,
29 pub subvec_dim: usize,
31 pub num_centroids: usize,
33 pub kmeans_iterations: usize,
35}
36
37impl PQConfig {
38 pub fn for_dimension(dimension: usize) -> Self {
40 let subvec_dim = DEFAULT_SUBVEC_DIM;
41 let num_subvectors = dimension / subvec_dim;
42
43 assert!(
44 dimension % subvec_dim == 0,
45 "Dimension {} must be divisible by subvec_dim {}",
46 dimension,
47 subvec_dim
48 );
49
50 Self {
51 dimension,
52 num_subvectors,
53 subvec_dim,
54 num_centroids: NUM_CENTROIDS,
55 kmeans_iterations: 20,
56 }
57 }
58
59 pub fn minilm() -> Self {
61 Self::for_dimension(384)
62 }
63
64 pub fn clip() -> Self {
66 Self::for_dimension(768)
67 }
68}
69
/// Product quantizer that encodes each vector as one byte per subvector.
///
/// Each byte is the index of the nearest centroid in that subspace's codebook.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// Parameters the quantizer was created with.
    pub config: PQConfig,
    /// Per-subspace codebooks: `centroids[s][c]` is centroid `c` of subspace `s`,
    /// a `subvec_dim`-long vector. Empty until the quantizer is trained.
    pub centroids: Vec<Vec<Vec<f32>>>,
    /// Set once training completes. The encode/decode/distance methods return an
    /// error while this is `false`.
    pub trained: bool,
}
82
83impl ProductQuantizer {
84 pub fn new(config: PQConfig) -> Self {
86 Self {
87 config,
88 centroids: Vec::new(),
89 trained: false,
90 }
91 }
92
93 pub fn train(config: PQConfig, training_vectors: &[Vec<f32>]) -> Result<Self> {
95 if training_vectors.is_empty() {
96 return Err(anyhow!("No training vectors provided"));
97 }
98
99 let first_dim = training_vectors[0].len();
100 if first_dim != config.dimension {
101 return Err(anyhow!(
102 "Training vector dimension {} doesn't match config {}",
103 first_dim,
104 config.dimension
105 ));
106 }
107
108 let mut pq = Self::new(config);
109 pq.fit(training_vectors)?;
110 Ok(pq)
111 }
112
113 fn fit(&mut self, vectors: &[Vec<f32>]) -> Result<()> {
115 let n_vectors = vectors.len();
116 let n_subvectors = self.config.num_subvectors;
117 let subvec_dim = self.config.subvec_dim;
118 let n_centroids = self.config.num_centroids.min(n_vectors);
119 let iterations = self.config.kmeans_iterations;
120
121 tracing::info!(
122 "Training PQ: {} vectors, {} subvectors, {} centroids, {} iterations",
123 n_vectors,
124 n_subvectors,
125 n_centroids,
126 iterations
127 );
128
129 self.centroids = Vec::with_capacity(n_subvectors);
131
132 for subvec_idx in 0..n_subvectors {
134 let start = subvec_idx * subvec_dim;
135 let end = start + subvec_dim;
136
137 let subvectors: Vec<Vec<f32>> =
139 vectors.iter().map(|v| v[start..end].to_vec()).collect();
140
141 let centroids = self.kmeans(&subvectors, n_centroids, iterations)?;
143 self.centroids.push(centroids);
144 }
145
146 self.trained = true;
147 tracing::info!("PQ training complete");
148 Ok(())
149 }
150
151 fn kmeans(&self, vectors: &[Vec<f32>], k: usize, iterations: usize) -> Result<Vec<Vec<f32>>> {
153 let dim = vectors[0].len();
154 let n = vectors.len();
155
156 let mut rng = rand::thread_rng();
158 let mut indices: Vec<usize> = (0..n).collect();
159 indices.shuffle(&mut rng);
160
161 let mut centroids: Vec<Vec<f32>> = indices
162 .iter()
163 .take(k)
164 .map(|&i| vectors[i].clone())
165 .collect();
166
167 while centroids.len() < k {
169 let idx = indices[centroids.len() % n];
170 centroids.push(vectors[idx].clone());
171 }
172
173 let mut assignments = vec![0usize; n];
174
175 for _ in 0..iterations {
177 for (i, vec) in vectors.iter().enumerate() {
179 let mut best_centroid = 0;
180 let mut best_dist = f32::MAX;
181
182 for (c, centroid) in centroids.iter().enumerate() {
183 let dist = squared_l2_distance(vec, centroid);
184 if dist < best_dist {
185 best_dist = dist;
186 best_centroid = c;
187 }
188 }
189 assignments[i] = best_centroid;
190 }
191
192 let mut new_centroids: Vec<Vec<f32>> = vec![vec![0.0; dim]; k];
194 let mut counts = vec![0usize; k];
195
196 for (i, vec) in vectors.iter().enumerate() {
197 let c = assignments[i];
198 counts[c] += 1;
199 for (j, &v) in vec.iter().enumerate() {
200 new_centroids[c][j] += v;
201 }
202 }
203
204 for c in 0..k {
206 if counts[c] > 0 {
207 for j in 0..dim {
208 new_centroids[c][j] /= counts[c] as f32;
209 }
210 centroids[c] = new_centroids[c].clone();
211 }
212 }
214 }
215
216 Ok(centroids)
217 }
218
219 pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
221 if !self.trained {
222 return Err(anyhow!("ProductQuantizer not trained"));
223 }
224
225 if vector.len() != self.config.dimension {
226 return Err(anyhow!(
227 "Vector dimension {} doesn't match config {}",
228 vector.len(),
229 self.config.dimension
230 ));
231 }
232
233 let mut codes = Vec::with_capacity(self.config.num_subvectors);
234 let subvec_dim = self.config.subvec_dim;
235
236 for (subvec_idx, subspace_centroids) in self.centroids.iter().enumerate() {
237 let start = subvec_idx * subvec_dim;
238 let end = start + subvec_dim;
239 let subvector = &vector[start..end];
240
241 let mut best_centroid = 0u8;
243 let mut best_dist = f32::MAX;
244
245 for (c, centroid) in subspace_centroids.iter().enumerate() {
246 let dist = squared_l2_distance_slice(subvector, centroid);
247 if dist < best_dist {
248 best_dist = dist;
249 best_centroid = c as u8;
250 }
251 }
252
253 codes.push(best_centroid);
254 }
255
256 Ok(codes)
257 }
258
259 pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
261 if !self.trained {
262 return Err(anyhow!("ProductQuantizer not trained"));
263 }
264
265 if codes.len() != self.config.num_subvectors {
266 return Err(anyhow!(
267 "Code length {} doesn't match num_subvectors {}",
268 codes.len(),
269 self.config.num_subvectors
270 ));
271 }
272
273 let mut vector = Vec::with_capacity(self.config.dimension);
274
275 for (subvec_idx, &code) in codes.iter().enumerate() {
276 let subspace = &self.centroids[subvec_idx];
277 let code_idx = code as usize;
278 if code_idx >= subspace.len() {
279 return Err(anyhow!(
280 "PQ code {} out of bounds for subspace {} with {} centroids (data corruption?)",
281 code_idx,
282 subvec_idx,
283 subspace.len()
284 ));
285 }
286 vector.extend_from_slice(&subspace[code_idx]);
287 }
288
289 Ok(vector)
290 }
291
292 pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> Result<f32> {
297 if !self.trained {
298 return Err(anyhow!("ProductQuantizer not trained"));
299 }
300
301 let subvec_dim = self.config.subvec_dim;
302 let mut total_dist = 0.0f32;
303
304 for (subvec_idx, &code) in codes.iter().enumerate() {
305 let start = subvec_idx * subvec_dim;
306 let end = start + subvec_dim;
307 let query_subvec = &query[start..end];
308 let subspace = &self.centroids[subvec_idx];
309 let code_idx = code as usize;
310 if code_idx >= subspace.len() {
311 return Err(anyhow!(
312 "PQ code {} out of bounds for subspace {} with {} centroids (data corruption?)",
313 code_idx,
314 subvec_idx,
315 subspace.len()
316 ));
317 }
318
319 total_dist += squared_l2_distance_slice(query_subvec, &subspace[code_idx]);
320 }
321
322 Ok(total_dist)
323 }
324
325 pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>> {
330 if !self.trained {
331 return Err(anyhow!("ProductQuantizer not trained"));
332 }
333
334 let subvec_dim = self.config.subvec_dim;
335 let n_centroids = self.config.num_centroids;
336 let mut table = Vec::with_capacity(self.config.num_subvectors);
337
338 for (subvec_idx, subspace_centroids) in self.centroids.iter().enumerate() {
339 let start = subvec_idx * subvec_dim;
340 let end = start + subvec_dim;
341 let query_subvec = &query[start..end];
342
343 let mut distances = Vec::with_capacity(n_centroids);
344 for centroid in subspace_centroids {
345 distances.push(squared_l2_distance_slice(query_subvec, centroid));
346 }
347 table.push(distances);
348 }
349
350 Ok(table)
351 }
352
353 #[inline]
358 pub fn distance_with_table(&self, table: &[Vec<f32>], codes: &[u8]) -> f32 {
359 let mut total = 0.0f32;
360 for (subvec_idx, &code) in codes.iter().enumerate() {
361 let code_idx = code as usize;
362 if subvec_idx >= table.len() || code_idx >= table[subvec_idx].len() {
363 return f32::MAX; }
365 total += table[subvec_idx][code_idx];
366 }
367 total
368 }
369
370 pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Result<Vec<Vec<u8>>> {
372 vectors.iter().map(|v| self.encode(v)).collect()
373 }
374
375 pub fn compressed_size(&self) -> usize {
377 self.config.num_subvectors }
379
380 pub fn original_size(&self) -> usize {
382 self.config.dimension * std::mem::size_of::<f32>()
383 }
384
385 pub fn compression_ratio(&self) -> f32 {
387 self.original_size() as f32 / self.compressed_size() as f32
388 }
389}
390
/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally and only the first
/// `min(a.len(), b.len())` pairs contribute; surplus components are ignored.
#[inline]
fn squared_l2_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}

/// Same computation as [`squared_l2_distance`]; kept as a separately named
/// entry point for callers that compare subvector slices.
#[inline]
fn squared_l2_distance_slice(a: &[f32], b: &[f32]) -> f32 {
    squared_l2_distance(a, b)
}
402
/// Store of product-quantized vectors keyed by a caller-supplied `u32` id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedVectorStore {
    /// Quantizer used to encode stored vectors and score queries.
    pub quantizer: ProductQuantizer,
    /// PQ codes (one byte per subvector) for each stored vector id.
    pub codes: HashMap<u32, Vec<u8>>,
}
411
412impl CompressedVectorStore {
413 pub fn new(quantizer: ProductQuantizer) -> Self {
415 Self {
416 quantizer,
417 codes: HashMap::new(),
418 }
419 }
420
421 pub fn train_and_create(config: PQConfig, training_vectors: &[Vec<f32>]) -> Result<Self> {
423 let quantizer = ProductQuantizer::train(config, training_vectors)?;
424 Ok(Self::new(quantizer))
425 }
426
427 pub fn add(&mut self, vector_id: u32, vector: &[f32]) -> Result<()> {
429 let codes = self.quantizer.encode(vector)?;
430 self.codes.insert(vector_id, codes);
431 Ok(())
432 }
433
434 pub fn get_codes(&self, vector_id: u32) -> Option<&Vec<u8>> {
436 self.codes.get(&vector_id)
437 }
438
439 pub fn decode(&self, vector_id: u32) -> Result<Vec<f32>> {
441 let codes = self
442 .codes
443 .get(&vector_id)
444 .ok_or_else(|| anyhow!("Vector {} not found", vector_id))?;
445 self.quantizer.decode(codes)
446 }
447
448 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>> {
450 let table = self.quantizer.build_distance_table(query)?;
452
453 let mut distances: Vec<(u32, f32)> = self
455 .codes
456 .iter()
457 .map(|(&id, codes)| (id, self.quantizer.distance_with_table(&table, codes)))
458 .collect();
459
460 distances.sort_by(|a, b| a.1.total_cmp(&b.1));
462 distances.truncate(k);
463
464 Ok(distances)
465 }
466
467 pub fn len(&self) -> usize {
469 self.codes.len()
470 }
471
472 pub fn is_empty(&self) -> bool {
474 self.codes.is_empty()
475 }
476
477 pub fn storage_bytes(&self) -> usize {
479 self.codes.len() * self.quantizer.compressed_size()
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` vectors of `dim` uniform random values drawn from `[0, 1)`.
    fn generate_random_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let mut out: Vec<Vec<f32>> = Vec::with_capacity(n);
        for _ in 0..n {
            out.push((0..dim).map(|_| rng.gen::<f32>()).collect());
        }
        out
    }

    /// Mean squared error between two vectors of the same length.
    fn mean_squared_error(a: &[f32], b: &[f32]) -> f32 {
        let total: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
        total / a.len() as f32
    }

    #[test]
    fn test_pq_encode_decode() {
        let training = generate_random_vectors(1000, 384);
        let pq = ProductQuantizer::train(PQConfig::minilm(), &training).unwrap();

        let sample = &training[0];
        let codes = pq.encode(sample).unwrap();
        let reconstructed = pq.decode(&codes).unwrap();

        // 384 components / 8 per subvector = 48 one-byte codes.
        assert_eq!(codes.len(), 48);
        assert_eq!(reconstructed.len(), 384);

        let mse = mean_squared_error(sample, &reconstructed);
        assert!(mse < 0.1, "MSE too high: {}", mse);
    }

    #[test]
    fn test_compression_ratio() {
        let pq = ProductQuantizer::new(PQConfig::minilm());

        assert_eq!(pq.original_size(), 384 * 4);
        assert_eq!(pq.compressed_size(), 48);
        assert!((pq.compression_ratio() - 32.0).abs() < 0.01);
    }

    #[test]
    fn test_distance_table() {
        let data = generate_random_vectors(100, 384);
        let pq = ProductQuantizer::train(PQConfig::minilm(), &data).unwrap();

        let query = &data[0];
        let target_codes = pq.encode(&data[1]).unwrap();

        // The direct and table-driven paths must agree.
        let direct = pq.asymmetric_distance(query, &target_codes).unwrap();
        let table = pq.build_distance_table(query).unwrap();
        let via_table = pq.distance_with_table(&table, &target_codes);

        assert!((direct - via_table).abs() < 1e-6);
    }

    #[test]
    fn test_compressed_store_search() {
        let vectors = generate_random_vectors(1000, 384);
        let mut store =
            CompressedVectorStore::train_and_create(PQConfig::minilm(), &vectors).unwrap();

        for (id, v) in (0u32..).zip(&vectors) {
            store.add(id, v).unwrap();
        }

        let results = store.search(&vectors[0], 10).unwrap();
        assert_eq!(results.len(), 10);

        // The query's own (quantized) copy should rank among the very nearest hits.
        let query_in_top_results = results[..5].iter().any(|&(id, _)| id == 0);
        assert!(
            query_in_top_results,
            "Query vector not found in top 5 results"
        );
    }
}