// velesdb_core/index/hnsw/native/quantization.rs
use std::sync::Arc;

/// Computes the squared L2 distance between two `u8` code vectors.
///
/// Elements are processed eight at a time into four independent `u32`
/// accumulators, which leaves the compiler free to vectorise the loop; no
/// explicit SIMD intrinsics are used. A tail of fewer than eight elements is
/// folded into the first accumulator.
///
/// The result is the *squared* distance: no square root is taken.
///
/// Both slices are expected to have the same length, which is only enforced
/// by `debug_assert_eq!`. Without debug assertions a shorter `b` panics on
/// the re-slice below, and surplus trailing elements of `b` are ignored.
#[inline]
fn distance_l2_quantized_simd(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());

    // Squared difference of two bytes; at most 255 * 255 = 65_025.
    let sq_diff = |x: u8, y: u8| -> u32 {
        let d = (i32::from(x) - i32::from(y)).unsigned_abs();
        d * d
    };

    // Split both inputs at the same offset. Re-slicing `b` to `a.len()` does
    // the bounds check once instead of once per element.
    let body_len = a.len() - a.len() % 8;
    let (a_body, a_tail) = a.split_at(body_len);
    let (b_body, b_tail) = b[..a.len()].split_at(body_len);

    let mut acc = [0_u32; 4];
    for (ca, cb) in a_body.chunks_exact(8).zip(b_body.chunks_exact(8)) {
        // Lane `k` accumulates elements `k` and `k + 4` of every chunk.
        for (lane, sum) in acc.iter_mut().enumerate() {
            *sum += sq_diff(ca[lane], cb[lane]) + sq_diff(ca[lane + 4], cb[lane + 4]);
        }
    }
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        acc[0] += sq_diff(x, y);
    }

    acc.iter().sum()
}

/// Computes the L2 distance between an `f32` query and a `u8` code vector.
///
/// Each code is dequantized on the fly as `code * inv_scale + min` and
/// compared with the matching query component, so the query itself is never
/// quantized. Elements are processed four at a time into four independent
/// accumulators; a tail of fewer than four elements goes into the first one.
/// No explicit SIMD intrinsics are used.
///
/// Returns the square root of the summed squared differences.
///
/// All four slices are expected to have the same length, which is only
/// enforced by `debug_assert_eq!`. Without debug assertions, any slice
/// shorter than `query` panics on the re-slice below.
#[inline]
fn distance_l2_asymmetric_simd(
    query: &[f32],
    quantized: &[u8],
    min_vals: &[f32],
    inv_scales: &[f32],
) -> f32 {
    debug_assert_eq!(query.len(), quantized.len());
    debug_assert_eq!(query.len(), min_vals.len());
    debug_assert_eq!(query.len(), inv_scales.len());

    // Re-slice once so the per-element accesses below need no further checks.
    let len = query.len();
    let (quantized, min_vals, inv_scales) =
        (&quantized[..len], &min_vals[..len], &inv_scales[..len]);

    // Squared error of component `i` against its dequantized code.
    let sq_err = |i: usize| -> f32 {
        let dequantized = f32::from(quantized[i]) * inv_scales[i] + min_vals[i];
        let diff = query[i] - dequantized;
        diff * diff
    };

    let body_len = len - len % 4;
    let mut acc = [0.0_f32; 4];
    for base in (0..body_len).step_by(4) {
        // Lane `k` accumulates element `base + k` of every group of four.
        for (lane, sum) in acc.iter_mut().enumerate() {
            *sum += sq_err(base + lane);
        }
    }
    for i in body_len..len {
        acc[0] += sq_err(i);
    }

    // Fixed left-to-right order keeps the float result reproducible.
    (acc[0] + acc[1] + acc[2] + acc[3]).sqrt()
}

/// Per-dimension scalar quantizer that maps `f32` components to `u8` codes.
///
/// All three vectors hold one entry per dimension.
#[derive(Debug, Clone)]
pub struct ScalarQuantizer {
    /// Offset subtracted from a component before scaling; `train` sets it to
    /// the smallest value seen in that dimension.
    pub min_vals: Vec<f32>,
    /// Multiplier applied to `value - min` when quantizing; `train` sets it to
    /// `255 / (max - min)`, or `1.0` for a (near-)constant dimension.
    pub scales: Vec<f32>,
    /// Reciprocal of each entry of `scales`, used when dequantizing.
    pub inv_scales: Vec<f32>,
    /// Number of components in every vector this quantizer handles.
    pub dimension: usize,
}

/// A single vector in quantized form: one `u8` code per dimension.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// The quantized components, in dimension order.
    pub data: Vec<u8>,
}

153#[derive(Debug, Clone)]
155pub struct QuantizedVectorStore {
156 quantizer: Arc<ScalarQuantizer>,
158 data: Vec<u8>,
160 count: usize,
162}
163
164impl ScalarQuantizer {
165 #[must_use]
175 pub fn train(vectors: &[&[f32]]) -> Self {
176 assert!(!vectors.is_empty(), "Cannot train on empty vectors");
177 let dimension = vectors[0].len();
178 assert!(
179 vectors.iter().all(|v| v.len() == dimension),
180 "All vectors must have same dimension"
181 );
182
183 let mut min_vals = vec![f32::MAX; dimension];
184 let mut max_vals = vec![f32::MIN; dimension];
185
186 for vec in vectors {
188 for (i, &val) in vec.iter().enumerate() {
189 min_vals[i] = min_vals[i].min(val);
190 max_vals[i] = max_vals[i].max(val);
191 }
192 }
193
194 let scales: Vec<f32> = min_vals
196 .iter()
197 .zip(max_vals.iter())
198 .map(|(&min, &max)| {
199 let range = max - min;
200 if range.abs() < 1e-10 {
201 1.0 } else {
203 255.0 / range
204 }
205 })
206 .collect();
207
208 let inv_scales: Vec<f32> = scales.iter().map(|&s| 1.0 / s).collect();
210
211 Self {
212 min_vals,
213 scales,
214 inv_scales,
215 dimension,
216 }
217 }
218
219 #[must_use]
221 pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
222 debug_assert_eq!(vector.len(), self.dimension);
223
224 let data: Vec<u8> = vector
225 .iter()
226 .zip(self.min_vals.iter())
227 .zip(self.scales.iter())
228 .map(|((&val, &min), &scale)| {
229 let q = ((val - min) * scale).round();
230 q.clamp(0.0, 255.0) as u8
231 })
232 .collect();
233
234 QuantizedVector { data }
235 }
236
237 #[must_use]
239 pub fn dequantize(&self, quantized: &QuantizedVector) -> Vec<f32> {
240 debug_assert_eq!(quantized.data.len(), self.dimension);
241
242 quantized
243 .data
244 .iter()
245 .zip(self.min_vals.iter())
246 .zip(self.inv_scales.iter())
247 .map(|((&q, &min), &inv_scale)| {
248 f32::from(q) * inv_scale + min
250 })
251 .collect()
252 }
253
254 #[inline]
258 #[must_use]
259 pub fn distance_l2_quantized(&self, a: &QuantizedVector, b: &QuantizedVector) -> u32 {
260 debug_assert_eq!(a.data.len(), b.data.len());
261 distance_l2_quantized_simd(&a.data, &b.data)
262 }
263
264 #[inline]
268 #[must_use]
269 pub fn distance_l2_quantized_slice(&self, a: &[u8], b: &[u8]) -> u32 {
270 debug_assert_eq!(a.len(), b.len());
271 distance_l2_quantized_simd(a, b)
272 }
273
274 #[inline]
279 #[must_use]
280 pub fn distance_l2_asymmetric(&self, query: &[f32], quantized: &QuantizedVector) -> f32 {
281 debug_assert_eq!(query.len(), self.dimension);
282 debug_assert_eq!(quantized.data.len(), self.dimension);
283
284 distance_l2_asymmetric_simd(query, &quantized.data, &self.min_vals, &self.inv_scales)
285 }
286
287 #[inline]
289 #[must_use]
290 pub fn distance_l2_asymmetric_slice(&self, query: &[f32], quantized: &[u8]) -> f32 {
291 debug_assert_eq!(query.len(), self.dimension);
292 debug_assert_eq!(quantized.len(), self.dimension);
293
294 distance_l2_asymmetric_simd(query, quantized, &self.min_vals, &self.inv_scales)
295 }
296}
297
298impl QuantizedVectorStore {
299 #[must_use]
301 pub fn new(quantizer: Arc<ScalarQuantizer>, capacity: usize) -> Self {
302 let dimension = quantizer.dimension;
303 Self {
304 quantizer,
305 data: Vec::with_capacity(capacity * dimension),
306 count: 0,
307 }
308 }
309
310 pub fn push(&mut self, vector: &[f32]) {
312 let quantized = self.quantizer.quantize(vector);
313 self.data.extend(quantized.data);
314 self.count += 1;
315 }
316
317 #[must_use]
319 pub fn get(&self, index: usize) -> Option<QuantizedVector> {
320 if index >= self.count {
321 return None;
322 }
323 let start = index * self.quantizer.dimension;
324 let end = start + self.quantizer.dimension;
325 Some(QuantizedVector {
326 data: self.data[start..end].to_vec(),
327 })
328 }
329
330 #[must_use]
332 pub fn get_slice(&self, index: usize) -> Option<&[u8]> {
333 if index >= self.count {
334 return None;
335 }
336 let start = index * self.quantizer.dimension;
337 let end = start + self.quantizer.dimension;
338 Some(&self.data[start..end])
339 }
340
341 #[must_use]
343 pub fn len(&self) -> usize {
344 self.count
345 }
346
347 #[must_use]
349 pub fn is_empty(&self) -> bool {
350 self.count == 0
351 }
352
353 #[must_use]
355 pub fn quantizer(&self) -> &ScalarQuantizer {
356 &self.quantizer
357 }
358}
359
/// Unit tests for the scalar quantizer, the distance helpers and the
/// quantized vector store.
#[cfg(test)]
#[allow(clippy::similar_names)]
mod tests {
    use super::*;

    // --- Training -----------------------------------------------------------

    #[test]
    fn test_train_computes_correct_min_max() {
        let v1 = vec![0.0, 10.0, -5.0];
        let v2 = vec![5.0, 20.0, 5.0];
        let v3 = vec![2.5, 15.0, 0.0];

        let quantizer = ScalarQuantizer::train(&[&v1, &v2, &v3]);

        assert_eq!(quantizer.dimension, 3);
        assert!((quantizer.min_vals[0] - 0.0).abs() < 1e-6);
        assert!((quantizer.min_vals[1] - 10.0).abs() < 1e-6);
        assert!((quantizer.min_vals[2] - (-5.0)).abs() < 1e-6);

        // scale = 255 / (max - min) for each dimension.
        assert!((quantizer.scales[0] - 255.0 / 5.0).abs() < 1e-4);
        assert!((quantizer.scales[1] - 255.0 / 10.0).abs() < 1e-4);
        assert!((quantizer.scales[2] - 255.0 / 10.0).abs() < 1e-4);
    }

    #[test]
    fn test_train_handles_constant_dimension() {
        let v1 = vec![1.0, 5.0, 5.0];
        let v2 = vec![2.0, 5.0, 5.0];

        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        // Dimensions 1 and 2 never vary, so their scale falls back to 1.0.
        assert!((quantizer.scales[1] - 1.0).abs() < 1e-6);
        assert!((quantizer.scales[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Cannot train on empty vectors")]
    fn test_train_panics_on_empty() {
        let _: ScalarQuantizer = ScalarQuantizer::train(&[]);
    }

    // --- Quantization -------------------------------------------------------

    #[test]
    fn test_quantize_min_becomes_zero() {
        let v = vec![0.0, 100.0];
        let quantizer = ScalarQuantizer::train(&[&v]);

        let qvec = quantizer.quantize(&[0.0, 100.0]);

        assert_eq!(qvec.data[0], 0);
    }

    #[test]
    fn test_quantize_range_maps_correctly() {
        let v1 = vec![0.0, 0.0];
        let v2 = vec![10.0, 100.0];
        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min.data[0], 0);
        assert_eq!(q_min.data[1], 0);

        let q_max = quantizer.quantize(&[10.0, 100.0]);
        assert_eq!(q_max.data[0], 255);
        assert_eq!(q_max.data[1], 255);

        // The midpoint lands on 127 or 128 depending on rounding.
        let q_mid = quantizer.quantize(&[5.0, 50.0]);
        assert!((i32::from(q_mid.data[0]) - 127).abs() <= 1);
        assert!((i32::from(q_mid.data[1]) - 127).abs() <= 1);
    }

    #[test]
    fn test_quantize_clamps_out_of_range() {
        let v1 = vec![0.0];
        let v2 = vec![10.0];
        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        let q_low = quantizer.quantize(&[-5.0]);
        assert_eq!(q_low.data[0], 0, "Should clamp to 0");

        let q_high = quantizer.quantize(&[20.0]);
        assert_eq!(q_high.data[0], 255, "Should clamp to 255");
    }

    #[test]
    fn test_dequantize_recovers_approximate_values() {
        let v1 = vec![0.0, -10.0, 100.0];
        let v2 = vec![10.0, 10.0, 200.0];
        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        let original = vec![5.0, 0.0, 150.0];
        let qvec = quantizer.quantize(&original);
        let recovered = quantizer.dequantize(&qvec);

        for (i, (&orig, &rec)) in original.iter().zip(recovered.iter()).enumerate() {
            // Error is measured relative to the trained range of the dimension.
            let range = v2[i] - v1[i];
            let error = (orig - rec).abs();
            let relative_error = error / range;
            assert!(
                relative_error < 0.01,
                "Dim {i}: orig={orig}, rec={rec}, error={relative_error:.4}"
            );
        }
    }

    // --- Distances ----------------------------------------------------------

    #[test]
    fn test_distance_l2_quantized_identical_is_zero() {
        let quantizer = ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]);
        let v = quantizer.quantize(&[5.0, 5.0]);

        let dist = quantizer.distance_l2_quantized(&v, &v);
        assert_eq!(dist, 0, "Distance to self should be 0");
    }

    #[test]
    fn test_distance_l2_quantized_symmetry() {
        let quantizer = ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]);
        let a = quantizer.quantize(&[2.0, 3.0]);
        let b = quantizer.quantize(&[7.0, 8.0]);

        let dist_ab = quantizer.distance_l2_quantized(&a, &b);
        let dist_ba = quantizer.distance_l2_quantized(&b, &a);

        assert_eq!(dist_ab, dist_ba, "Distance should be symmetric");
    }

    #[test]
    fn test_distance_l2_asymmetric_close_to_exact() {
        let v1 = vec![0.0; 128];
        let v2 = vec![10.0; 128];
        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);

        let query = vec![3.0; 128];
        let candidate = vec![7.0; 128];

        let quantized_candidate = quantizer.quantize(&candidate);
        let approx_dist = quantizer.distance_l2_asymmetric(&query, &quantized_candidate);

        let exact_dist: f32 = query
            .iter()
            .zip(candidate.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        let relative_error = (approx_dist - exact_dist).abs() / exact_dist;
        assert!(
            relative_error < 0.05,
            "approx={approx_dist}, exact={exact_dist}, error={relative_error:.4}"
        );
    }

    // --- Store --------------------------------------------------------------

    #[test]
    fn test_store_push_and_get() {
        let quantizer = Arc::new(ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]));
        let mut store = QuantizedVectorStore::new(Arc::clone(&quantizer), 100);

        store.push(&[2.0, 3.0]);
        store.push(&[7.0, 8.0]);

        assert_eq!(store.len(), 2);

        let v0 = store.get(0).expect("Should have index 0");
        let v1 = store.get(1).expect("Should have index 1");

        assert_ne!(v0.data, v1.data);
    }

    #[test]
    fn test_store_get_out_of_bounds_returns_none() {
        let quantizer = Arc::new(ScalarQuantizer::train(&[&[0.0], &[10.0]]));
        let store = QuantizedVectorStore::new(quantizer, 100);

        assert!(store.get(0).is_none());
        assert!(store.get(100).is_none());
    }

    #[test]
    fn test_store_get_slice_zero_copy() {
        let quantizer = Arc::new(ScalarQuantizer::train(&[&[0.0, 0.0], &[10.0, 10.0]]));
        let mut store = QuantizedVectorStore::new(Arc::clone(&quantizer), 100);

        store.push(&[5.0, 5.0]);

        let slice = store.get_slice(0).expect("Should have slice");
        assert_eq!(slice.len(), 2);

        assert!((i32::from(slice[0]) - 127).abs() <= 1);
        assert!((i32::from(slice[1]) - 127).abs() <= 1);
    }

    // --- Memory and realistic sizes -----------------------------------------

    #[test]
    fn test_memory_efficiency_4x_reduction() {
        let dim = 768;
        let count = 10_000;

        // Four bytes per f32 component versus one byte per u8 code.
        let float32_bytes = dim * 4 * count;
        let int8_bytes = dim * count;

        assert_eq!(float32_bytes / int8_bytes, 4, "Should be 4x reduction");
    }

    #[test]
    fn test_quantize_768d_embedding() {
        let v1: Vec<f32> = (0..768).map(|i| (i as f32 * 0.01).sin()).collect();
        let v2: Vec<f32> = (0..768).map(|i| (i as f32 * 0.01).cos()).collect();

        let quantizer = ScalarQuantizer::train(&[&v1, &v2]);
        assert_eq!(quantizer.dimension, 768);

        let qvec = quantizer.quantize(&v1);
        assert_eq!(qvec.data.len(), 768);

        let recovered = quantizer.dequantize(&qvec);
        assert_eq!(recovered.len(), 768);

        let mse: f32 = v1
            .iter()
            .zip(recovered.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / 768.0;

        assert!(mse < 0.001, "MSE should be small: {mse}");
    }
}