// sqlite_knowledge_graph/vector/turboquant.rs
use crate::error::{Error, Result};
12use nalgebra::DMatrix;
13use rand::{rngs::StdRng, SeedableRng};
14use rand_distr::{Distribution, StandardNormal};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17
/// Configuration for a TurboQuant-style scalar-quantization index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantConfig {
    /// Dimensionality of the vectors to be indexed.
    pub dimension: usize,
    /// Bits per quantized component; `TurboQuantIndex::new` accepts 1..=8.
    pub bit_width: usize,
    /// RNG seed used to derive the random rotation matrix deterministically.
    pub seed: u64,
}
28
29impl Default for TurboQuantConfig {
30 fn default() -> Self {
31 Self {
32 dimension: 384,
33 bit_width: 3,
34 seed: 42,
35 }
36 }
37}
38
/// Quantization-based approximate nearest-neighbor index.
///
/// Vectors are rotated by a fixed random orthonormal matrix and each rotated
/// component is mapped to the nearest entry of a small scalar codebook, so a
/// vector is stored as one codebook index per dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantIndex {
    config: TurboQuantConfig,
    /// Random orthonormal rotation, row-major, dimension x dimension.
    rotation_matrix: Vec<Vec<f32>>,
    /// Sorted scalar codebook with 2^bit_width centroids.
    codebook: Vec<f32>,
    /// entity_id -> per-component codebook indices (one byte per component).
    quantized_vectors: HashMap<i64, Vec<u8>>,
    /// entity_id -> L2 norm of the original vector.
    /// NOTE(review): currently unused by `compute_similarity` — kept for the
    /// stored-vector bookkeeping; confirm before removing.
    vector_norms: HashMap<i64, f32>,
}
52
/// Exact brute-force cosine-similarity index; a full-precision baseline to
/// compare against `TurboQuantIndex`.
pub struct LinearScanIndex {
    config: TurboQuantConfig,
    /// entity_id -> full-precision vector.
    vectors: HashMap<i64, Vec<f32>>,
}
58
59impl LinearScanIndex {
60 pub fn new(config: TurboQuantConfig) -> Result<Self> {
62 Ok(Self {
63 config,
64 vectors: HashMap::new(),
65 })
66 }
67
68 pub fn add_vector(&mut self, entity_id: i64, vector: &[f32]) -> Result<()> {
70 if vector.len() != self.config.dimension {
71 return Err(Error::InvalidVectorDimension {
72 expected: self.config.dimension,
73 actual: vector.len(),
74 });
75 }
76 self.vectors.insert(entity_id, vector.to_vec());
77 Ok(())
78 }
79
80 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
82 if query.len() != self.config.dimension {
83 return Err(Error::InvalidVectorDimension {
84 expected: self.config.dimension,
85 actual: query.len(),
86 });
87 }
88
89 let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
90
91 let mut results: Vec<(i64, f32)> = self
92 .vectors
93 .iter()
94 .map(|(&entity_id, vector)| {
95 let dot_product: f32 = query.iter().zip(vector.iter()).map(|(a, b)| a * b).sum();
96 let target_norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
97 let similarity = if query_norm > 0.0 && target_norm > 0.0 {
98 dot_product / (query_norm * target_norm)
99 } else {
100 0.0
101 };
102 (entity_id, similarity)
103 })
104 .collect();
105
106 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
107 results.truncate(k);
108
109 Ok(results)
110 }
111
112 pub fn stats(&self) -> LinearScanStats {
114 LinearScanStats {
115 num_vectors: self.vectors.len(),
116 dimension: self.config.dimension,
117 bytes_per_vector: self.config.dimension * 4, }
119 }
120
121 pub fn clear(&mut self) {
123 self.vectors.clear();
124 }
125
126 pub fn len(&self) -> usize {
128 self.vectors.len()
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.vectors.is_empty()
134 }
135}
136
/// Size statistics reported by `LinearScanIndex::stats`.
#[derive(Debug, Clone)]
pub struct LinearScanStats {
    /// Number of vectors currently stored.
    pub num_vectors: usize,
    /// Configured vector dimension.
    pub dimension: usize,
    /// Storage per vector in bytes (dimension * 4 for f32 components).
    pub bytes_per_vector: usize,
}
144
145impl TurboQuantIndex {
146 pub fn new(config: TurboQuantConfig) -> Result<Self> {
148 if config.bit_width < 1 || config.bit_width > 8 {
149 return Err(Error::InvalidInput(
150 "bit_width must be between 1 and 8".to_string(),
151 ));
152 }
153
154 let mut rng = StdRng::seed_from_u64(config.seed);
155
156 let rotation_matrix = Self::generate_rotation_matrix(config.dimension, &mut rng);
158
159 let codebook = Self::compute_codebook(config.bit_width);
161
162 Ok(Self {
163 config,
164 rotation_matrix,
165 codebook,
166 quantized_vectors: HashMap::new(),
167 vector_norms: HashMap::new(),
168 })
169 }
170
171 fn generate_rotation_matrix(d: usize, rng: &mut StdRng) -> Vec<Vec<f32>> {
177 let data: Vec<f64> = (0..d * d).map(|_| StandardNormal.sample(rng)).collect();
179 let matrix = DMatrix::from_vec(d, d, data);
180
181 let qr = matrix.qr();
183 let q = qr.q();
184
185 (0..d)
187 .map(|i| (0..d).map(|j| q[(i, j)] as f32).collect())
188 .collect()
189 }
190
191 fn compute_codebook(bit_width: usize) -> Vec<f32> {
197 let k = 1usize << bit_width; let mut rng = StdRng::seed_from_u64(0xc0de_b007);
200 let num_samples = 50_000usize;
201 let std_dev = (1.0_f32 / 384_f32).sqrt(); let samples: Vec<f32> = (0..num_samples)
205 .map(|_| {
206 let n: f64 = StandardNormal.sample(&mut rng);
207 (n as f32 * std_dev).clamp(-1.0, 1.0)
208 })
209 .collect();
210
211 let mut centroids: Vec<f32> = (0..k)
213 .map(|i| {
214 if k == 1 {
215 0.0
216 } else {
217 -1.0 + 2.0 * i as f32 / (k - 1) as f32
218 }
219 })
220 .collect();
221
222 for _ in 0..100 {
224 let mut sums = vec![0.0f64; k];
225 let mut counts = vec![0usize; k];
226
227 for &x in &samples {
228 let nearest = centroids
229 .iter()
230 .enumerate()
231 .min_by(|(_, a), (_, b)| {
232 (x - *a)
233 .abs()
234 .partial_cmp(&(x - *b).abs())
235 .unwrap_or(std::cmp::Ordering::Equal)
236 })
237 .map(|(i, _)| i)
238 .unwrap_or(0);
239 sums[nearest] += x as f64;
240 counts[nearest] += 1;
241 }
242
243 let prev = centroids.clone();
244 for i in 0..k {
245 if counts[i] > 0 {
246 centroids[i] = (sums[i] / counts[i] as f64) as f32;
247 }
248 }
249
250 let converged = centroids
252 .iter()
253 .zip(prev.iter())
254 .all(|(a, b)| (a - b).abs() < 1e-6);
255 if converged {
256 break;
257 }
258 }
259
260 centroids.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
261 centroids
262 }
263
264 pub fn add_vector(&mut self, entity_id: i64, vector: &[f32]) -> Result<()> {
266 if vector.len() != self.config.dimension {
267 return Err(Error::InvalidVectorDimension {
268 expected: self.config.dimension,
269 actual: vector.len(),
270 });
271 }
272
273 let norm: f32 = vector.iter().map(|x| x * x).sum();
275 let norm = norm.sqrt();
276 self.vector_norms.insert(entity_id, norm);
277
278 let rotated = self.apply_rotation(vector);
280
281 let quantized = self.quantize_vector(&rotated);
283
284 self.quantized_vectors.insert(entity_id, quantized);
285
286 Ok(())
287 }
288
289 fn apply_rotation(&self, vector: &[f32]) -> Vec<f32> {
291 let d = self.config.dimension;
292 let mut rotated = vec![0.0f32; d];
293
294 for (i, rot_row) in self.rotation_matrix.iter().enumerate().take(d) {
295 for (j, &val) in vector.iter().enumerate().take(d) {
296 rotated[i] += rot_row[j] * val;
297 }
298 }
299
300 rotated
301 }
302
303 fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
305 vector
306 .iter()
307 .map(|&val| {
308 let mut best_idx = 0;
310 let mut best_dist = f32::MAX;
311
312 for (idx, ¢roid) in self.codebook.iter().enumerate() {
313 let dist = (val - centroid).abs();
314 if dist < best_dist {
315 best_dist = dist;
316 best_idx = idx;
317 }
318 }
319
320 best_idx as u8
321 })
322 .collect()
323 }
324
325 #[allow(dead_code)]
327 fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
328 quantized
329 .iter()
330 .map(|&idx| self.codebook[idx as usize])
331 .collect()
332 }
333
334 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
336 if query.len() != self.config.dimension {
337 return Err(Error::InvalidVectorDimension {
338 expected: self.config.dimension,
339 actual: query.len(),
340 });
341 }
342
343 let rotated_query = self.apply_rotation(query);
345 let quantized_query = self.quantize_vector(&rotated_query);
346
347 let query_norm: f32 = query.iter().map(|x| x * x).sum();
349 let query_norm = query_norm.sqrt();
350
351 let mut results: Vec<(i64, f32)> = self
353 .quantized_vectors
354 .iter()
355 .map(|(&entity_id, quantized_vec)| {
356 let similarity = self.compute_similarity(
357 &quantized_query,
358 quantized_vec,
359 query_norm,
360 self.vector_norms.get(&entity_id).copied().unwrap_or(1.0),
361 );
362 (entity_id, similarity)
363 })
364 .collect();
365
366 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
368 results.truncate(k);
369
370 Ok(results)
371 }
372
373 fn compute_similarity(
381 &self,
382 query: &[u8],
383 target: &[u8],
384 _query_norm: f32,
385 _target_norm: f32,
386 ) -> f32 {
387 if query.len() != target.len() {
388 return 0.0;
389 }
390
391 let mut dot_product = 0.0f32;
392 let mut query_sq = 0.0f32;
393 let mut target_sq = 0.0f32;
394
395 for i in 0..query.len() {
396 let q_val = self.codebook[query[i] as usize];
397 let t_val = self.codebook[target[i] as usize];
398 dot_product += q_val * t_val;
399 query_sq += q_val * q_val;
400 target_sq += t_val * t_val;
401 }
402
403 let denom = query_sq.sqrt() * target_sq.sqrt();
404 if denom > 0.0 {
405 dot_product / denom
406 } else {
407 0.0
408 }
409 }
410
411 pub fn add_vectors_batch(&mut self, vectors: &[(i64, Vec<f32>)]) -> Result<()> {
413 for (entity_id, vector) in vectors {
414 self.add_vector(*entity_id, vector)?;
415 }
416 Ok(())
417 }
418
419 pub fn stats(&self) -> TurboQuantStats {
421 TurboQuantStats {
422 num_vectors: self.quantized_vectors.len(),
423 dimension: self.config.dimension,
424 bit_width: self.config.bit_width,
425 bytes_per_vector: self.config.dimension, compression_ratio: 32.0 / self.config.bit_width as f32, }
428 }
429
430 pub fn remove_vector(&mut self, entity_id: i64) -> Result<()> {
432 self.quantized_vectors.remove(&entity_id);
433 self.vector_norms.remove(&entity_id);
434 Ok(())
435 }
436
437 pub fn clear(&mut self) {
439 self.quantized_vectors.clear();
440 self.vector_norms.clear();
441 }
442
443 pub fn len(&self) -> usize {
445 self.quantized_vectors.len()
446 }
447
448 pub fn is_empty(&self) -> bool {
450 self.quantized_vectors.is_empty()
451 }
452
453 pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
455 let serialized = serde_json::to_string(self)?;
456 std::fs::write(path, serialized)?;
457 Ok(())
458 }
459
460 pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
462 let contents = std::fs::read_to_string(path)?;
463 let index: Self = serde_json::from_str(&contents)?;
464 Ok(index)
465 }
466
467 pub fn config(&self) -> &TurboQuantConfig {
469 &self.config
470 }
471
472 pub fn search_batch(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<(i64, f32)>>> {
474 queries.iter().map(|query| self.search(query, k)).collect()
475 }
476}
477
/// Size and compression statistics reported by `TurboQuantIndex::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantStats {
    /// Number of quantized vectors currently stored.
    pub num_vectors: usize,
    /// Configured vector dimension.
    pub dimension: usize,
    /// Configured bits per quantized component.
    pub bit_width: usize,
    /// Bytes stored per vector (one byte per component in the current layout).
    pub bytes_per_vector: usize,
    /// f32-to-bit_width compression ratio (32 / bit_width).
    /// NOTE(review): this assumes bit packing, while storage uses whole bytes
    /// per code — confirm which figure callers expect.
    pub compression_ratio: f32,
}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the config shape shared by every test, varying only the dimension.
    fn test_config(dimension: usize) -> TurboQuantConfig {
        TurboQuantConfig {
            dimension,
            bit_width: 3,
            seed: 42,
        }
    }

    #[test]
    fn test_create_index() {
        let index = TurboQuantIndex::new(test_config(128)).unwrap();
        assert_eq!(index.config.dimension, 128);
        assert_eq!(index.config.bit_width, 3);
    }

    #[test]
    fn test_add_and_search() {
        let mut index = TurboQuantIndex::new(test_config(128)).unwrap();

        let ascending: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let shifted: Vec<f32> = (0..128).map(|i| ((i + 64) % 128) as f32 / 128.0).collect();
        let descending: Vec<f32> = (0..128).map(|i| 1.0 - (i as f32) / 128.0).collect();

        index.add_vector(1, &ascending).unwrap();
        index.add_vector(2, &shifted).unwrap();
        index.add_vector(3, &descending).unwrap();

        let hits = index.search(&ascending, 2).unwrap();
        assert_eq!(hits.len(), 2);
        // The query vector itself must come back as the best match.
        assert_eq!(hits[0].0, 1);
    }

    #[test]
    fn test_compression_ratio() {
        let index = TurboQuantIndex::new(test_config(384)).unwrap();
        let stats = index.stats();
        assert!(stats.compression_ratio > 10.0);
    }

    #[test]
    fn test_stats() {
        let mut index = TurboQuantIndex::new(test_config(384)).unwrap();

        let uniform = vec![0.1f32; 384];
        index.add_vector(1, &uniform).unwrap();
        index.add_vector(2, &uniform).unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 2);
        assert_eq!(stats.dimension, 384);
    }
}