sqlite_knowledge_graph/vector/
turboquant.rs1use crate::error::{Error, Result};
12use rand::{rngs::StdRng, Rng, SeedableRng};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TurboQuantConfig {
19 pub dimension: usize,
21 pub bit_width: usize,
23 pub seed: u64,
25}
26
27impl Default for TurboQuantConfig {
28 fn default() -> Self {
29 Self {
30 dimension: 384,
31 bit_width: 3,
32 seed: 42,
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct TurboQuantIndex {
40 config: TurboQuantConfig,
41 rotation_matrix: Vec<Vec<f32>>,
43 codebook: Vec<f32>,
45 quantized_vectors: HashMap<i64, Vec<u8>>,
47 vector_norms: HashMap<i64, f32>,
49}
50
51pub struct LinearScanIndex {
53 config: TurboQuantConfig,
54 vectors: HashMap<i64, Vec<f32>>,
55}
56
57impl LinearScanIndex {
58 pub fn new(config: TurboQuantConfig) -> Result<Self> {
60 Ok(Self {
61 config,
62 vectors: HashMap::new(),
63 })
64 }
65
66 pub fn add_vector(&mut self, entity_id: i64, vector: &[f32]) -> Result<()> {
68 if vector.len() != self.config.dimension {
69 return Err(Error::InvalidVectorDimension {
70 expected: self.config.dimension,
71 actual: vector.len(),
72 });
73 }
74 self.vectors.insert(entity_id, vector.to_vec());
75 Ok(())
76 }
77
78 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
80 if query.len() != self.config.dimension {
81 return Err(Error::InvalidVectorDimension {
82 expected: self.config.dimension,
83 actual: query.len(),
84 });
85 }
86
87 let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
88
89 let mut results: Vec<(i64, f32)> = self
90 .vectors
91 .iter()
92 .map(|(&entity_id, vector)| {
93 let dot_product: f32 = query.iter().zip(vector.iter()).map(|(a, b)| a * b).sum();
94 let target_norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
95 let similarity = if query_norm > 0.0 && target_norm > 0.0 {
96 dot_product / (query_norm * target_norm)
97 } else {
98 0.0
99 };
100 (entity_id, similarity)
101 })
102 .collect();
103
104 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
105 results.truncate(k);
106
107 Ok(results)
108 }
109
110 pub fn stats(&self) -> LinearScanStats {
112 LinearScanStats {
113 num_vectors: self.vectors.len(),
114 dimension: self.config.dimension,
115 bytes_per_vector: self.config.dimension * 4, }
117 }
118
119 pub fn clear(&mut self) {
121 self.vectors.clear();
122 }
123
124 pub fn len(&self) -> usize {
126 self.vectors.len()
127 }
128
129 pub fn is_empty(&self) -> bool {
131 self.vectors.is_empty()
132 }
133}
134
/// Summary figures returned by `LinearScanIndex::stats`.
#[derive(Debug, Clone)]
pub struct LinearScanStats {
    /// Number of vectors currently stored.
    pub num_vectors: usize,
    /// Configured vector dimension.
    pub dimension: usize,
    /// Payload size of one stored vector in bytes (`dimension * 4`).
    pub bytes_per_vector: usize,
}
142
143impl TurboQuantIndex {
144 pub fn new(config: TurboQuantConfig) -> Result<Self> {
146 if config.bit_width < 1 || config.bit_width > 8 {
147 return Err(Error::InvalidInput(
148 "bit_width must be between 1 and 8".to_string(),
149 ));
150 }
151
152 let mut rng = StdRng::seed_from_u64(config.seed);
153
154 let rotation_matrix = Self::generate_rotation_matrix(config.dimension, &mut rng);
156
157 let codebook = Self::compute_codebook(config.bit_width);
159
160 Ok(Self {
161 config,
162 rotation_matrix,
163 codebook,
164 quantized_vectors: HashMap::new(),
165 vector_norms: HashMap::new(),
166 })
167 }
168
169 fn generate_rotation_matrix(d: usize, rng: &mut StdRng) -> Vec<Vec<f32>> {
171 let mut matrix = vec![vec![0.0f32; d]; d];
174
175 for row in &mut matrix {
176 for val in row.iter_mut() {
177 *val = rng.gen::<f32>() * 2.0 - 1.0;
178 }
179 }
180
181 matrix
184 }
185
/// Returns the `2^bit_width` quantization levels in ascending order.
///
/// Widths 1 to 3 use fixed tables; every other width gets evenly spaced
/// levels spanning `[-0.95, 0.95]`.
///
/// Callers are expected to pass a width in `1..=8` (`new` validates this);
/// a width of 0 would divide zero by zero and yield a single NaN level.
fn compute_codebook(bit_width: usize) -> Vec<f32> {
    match bit_width {
        1 => vec![-0.5, 0.5],
        2 => vec![-0.75, -0.25, 0.25, 0.75],
        3 => vec![-0.9, -0.6, -0.35, -0.1, 0.1, 0.35, 0.6, 0.9],
        _ => {
            let num_levels = 1usize << bit_width;
            let last = (num_levels - 1) as f32;
            (0..num_levels)
                .map(|i| {
                    // Map the level number onto [-1, 1], then shrink to 95 %.
                    let val = (i as f32 / last) * 2.0 - 1.0;
                    val * 0.95
                })
                .collect()
        }
    }
}
228
229 pub fn add_vector(&mut self, entity_id: i64, vector: &[f32]) -> Result<()> {
231 if vector.len() != self.config.dimension {
232 return Err(Error::InvalidVectorDimension {
233 expected: self.config.dimension,
234 actual: vector.len(),
235 });
236 }
237
238 let norm: f32 = vector.iter().map(|x| x * x).sum();
240 let norm = norm.sqrt();
241 self.vector_norms.insert(entity_id, norm);
242
243 let rotated = self.apply_rotation(vector);
245
246 let quantized = self.quantize_vector(&rotated);
248
249 self.quantized_vectors.insert(entity_id, quantized);
250
251 Ok(())
252 }
253
254 fn apply_rotation(&self, vector: &[f32]) -> Vec<f32> {
256 let d = self.config.dimension;
257 let mut rotated = vec![0.0f32; d];
258
259 for (i, rot_row) in self.rotation_matrix.iter().enumerate().take(d) {
260 for (j, &val) in vector.iter().enumerate().take(d) {
261 rotated[i] += rot_row[j] * val;
262 }
263 }
264
265 rotated
266 }
267
268 fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
270 vector
271 .iter()
272 .map(|&val| {
273 let mut best_idx = 0;
275 let mut best_dist = f32::MAX;
276
277 for (idx, ¢roid) in self.codebook.iter().enumerate() {
278 let dist = (val - centroid).abs();
279 if dist < best_dist {
280 best_dist = dist;
281 best_idx = idx;
282 }
283 }
284
285 best_idx as u8
286 })
287 .collect()
288 }
289
290 #[allow(dead_code)]
292 fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
293 quantized
294 .iter()
295 .map(|&idx| self.codebook[idx as usize])
296 .collect()
297 }
298
299 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(i64, f32)>> {
301 if query.len() != self.config.dimension {
302 return Err(Error::InvalidVectorDimension {
303 expected: self.config.dimension,
304 actual: query.len(),
305 });
306 }
307
308 let rotated_query = self.apply_rotation(query);
310 let quantized_query = self.quantize_vector(&rotated_query);
311
312 let query_norm: f32 = query.iter().map(|x| x * x).sum();
314 let query_norm = query_norm.sqrt();
315
316 let mut results: Vec<(i64, f32)> = self
318 .quantized_vectors
319 .iter()
320 .map(|(&entity_id, quantized_vec)| {
321 let similarity = self.compute_similarity(
322 &quantized_query,
323 quantized_vec,
324 query_norm,
325 self.vector_norms.get(&entity_id).copied().unwrap_or(1.0),
326 );
327 (entity_id, similarity)
328 })
329 .collect();
330
331 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
333 results.truncate(k);
334
335 Ok(results)
336 }
337
338 fn compute_similarity(
340 &self,
341 query: &[u8],
342 target: &[u8],
343 query_norm: f32,
344 target_norm: f32,
345 ) -> f32 {
346 if query.len() != target.len() {
347 return 0.0;
348 }
349
350 let mut dot_product = 0.0f32;
352 for i in 0..query.len() {
353 let q_val = self.codebook[query[i] as usize];
354 let t_val = self.codebook[target[i] as usize];
355 dot_product += q_val * t_val;
356 }
357
358 if query_norm > 0.0 && target_norm > 0.0 {
360 dot_product / (query_norm * target_norm)
361 } else {
362 0.0
363 }
364 }
365
366 pub fn add_vectors_batch(&mut self, vectors: &[(i64, Vec<f32>)]) -> Result<()> {
368 for (entity_id, vector) in vectors {
369 self.add_vector(*entity_id, vector)?;
370 }
371 Ok(())
372 }
373
374 pub fn stats(&self) -> TurboQuantStats {
376 TurboQuantStats {
377 num_vectors: self.quantized_vectors.len(),
378 dimension: self.config.dimension,
379 bit_width: self.config.bit_width,
380 bytes_per_vector: self.config.dimension, compression_ratio: 32.0 / self.config.bit_width as f32, }
383 }
384
385 pub fn remove_vector(&mut self, entity_id: i64) -> Result<()> {
387 self.quantized_vectors.remove(&entity_id);
388 self.vector_norms.remove(&entity_id);
389 Ok(())
390 }
391
392 pub fn clear(&mut self) {
394 self.quantized_vectors.clear();
395 self.vector_norms.clear();
396 }
397
398 pub fn len(&self) -> usize {
400 self.quantized_vectors.len()
401 }
402
403 pub fn is_empty(&self) -> bool {
405 self.quantized_vectors.is_empty()
406 }
407
408 pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
410 let serialized = serde_json::to_string(self)?;
411 std::fs::write(path, serialized)?;
412 Ok(())
413 }
414
415 pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
417 let contents = std::fs::read_to_string(path)?;
418 let index: Self = serde_json::from_str(&contents)?;
419 Ok(index)
420 }
421
422 pub fn config(&self) -> &TurboQuantConfig {
424 &self.config
425 }
426
427 pub fn search_batch(&self, queries: &[Vec<f32>], k: usize) -> Result<Vec<Vec<(i64, f32)>>> {
429 queries.iter().map(|query| self.search(query, k)).collect()
430 }
431}
432
433#[derive(Debug, Clone, Serialize, Deserialize)]
435pub struct TurboQuantStats {
436 pub num_vectors: usize,
437 pub dimension: usize,
438 pub bit_width: usize,
439 pub bytes_per_vector: usize,
440 pub compression_ratio: f32,
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3-bit, seed-42 configuration shared by all tests.
    fn config_with_dimension(dimension: usize) -> TurboQuantConfig {
        TurboQuantConfig {
            dimension,
            bit_width: 3,
            seed: 42,
        }
    }

    #[test]
    fn test_create_index() {
        let index = TurboQuantIndex::new(config_with_dimension(128)).unwrap();

        assert_eq!(index.config.dimension, 128);
        assert_eq!(index.config.bit_width, 3);
    }

    #[test]
    fn test_add_and_search() {
        let mut index = TurboQuantIndex::new(config_with_dimension(128)).unwrap();

        // Three distinct ramps: ascending, ascending shifted by half, descending.
        let vec1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let vec2: Vec<f32> = (0..128).map(|i| ((i + 64) % 128) as f32 / 128.0).collect();
        let vec3: Vec<f32> = (0..128).map(|i| 1.0 - (i as f32) / 128.0).collect();

        for (entity_id, vector) in [(1, &vec1), (2, &vec2), (3, &vec3)] {
            index.add_vector(entity_id, vector).unwrap();
        }

        // Querying with a stored vector must rank that vector first.
        let results = index.search(&vec1, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 1);
    }

    #[test]
    fn test_compression_ratio() {
        let index = TurboQuantIndex::new(config_with_dimension(384)).unwrap();

        // 32 / 3 bits per coordinate is a little above 10.
        let stats = index.stats();
        assert!(stats.compression_ratio > 10.0);
    }

    #[test]
    fn test_stats() {
        let mut index = TurboQuantIndex::new(config_with_dimension(384)).unwrap();

        let vec: Vec<f32> = vec![0.1; 384];
        for entity_id in [1, 2] {
            index.add_vector(entity_id, &vec).unwrap();
        }

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 2);
        assert_eq!(stats.dimension, 384);
    }
}