1use std::io;
15
/// Selects which vector representation is used for storage.
///
/// `StorageMode::default()` yields [`StorageMode::Full`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum StorageMode {
    /// Unquantized storage (presumably the raw `f32` components — confirm against the
    /// storage layer, which is not visible here).
    #[default]
    Full,
    /// 8-bit scalar quantization (presumably backed by `QuantizedVector` — confirm).
    SQ8,
    /// 1-bit-per-component quantization (presumably backed by `BinaryQuantizedVector` —
    /// confirm).
    Binary,
}
28
/// A vector quantized to one bit per component, packed eight components per byte.
///
/// Component `i` lives in byte `i / 8` at bit position `i % 8` (least-significant bit
/// first). A bit is `1` when the source component was `>= 0.0` and `0` otherwise.
#[derive(Debug, Clone)]
pub struct BinaryQuantizedVector {
    /// Packed bits. The constructors of this type produce `dimension.div_ceil(8)` bytes
    /// and keep the unused high bits of the last byte cleared.
    pub data: Vec<u8>,
    /// Number of components (meaningful bits) represented by `data`.
    dimension: usize,
}

impl BinaryQuantizedVector {
    /// Quantizes `vector` by thresholding every component at zero.
    ///
    /// Components `>= 0.0` become `1`, everything else (including `NaN`, for which the
    /// comparison is false) becomes `0`.
    ///
    /// # Panics
    ///
    /// Panics if `vector` is empty.
    #[must_use]
    pub fn from_f32(vector: &[f32]) -> Self {
        assert!(!vector.is_empty(), "Cannot quantize empty vector");

        // Each group of up to eight components collapses into one byte, LSB first.
        let data = vector
            .chunks(8)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .filter(|&(_, &value)| value >= 0.0)
                    .fold(0u8, |byte, (bit_idx, _)| byte | (1 << bit_idx))
            })
            .collect();

        Self {
            data,
            dimension: vector.len(),
        }
    }

    /// Returns the number of components this vector represents.
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Returns the number of bytes used by the packed bit data.
    #[must_use]
    pub fn memory_size(&self) -> usize {
        self.data.len()
    }

    /// Unpacks the stored bits into one `bool` per component.
    ///
    /// # Panics
    ///
    /// Panics if `data` has been shortened below `dimension.div_ceil(8)` bytes
    /// (possible because `data` is a public field).
    #[must_use]
    pub fn get_bits(&self) -> Vec<bool> {
        (0..self.dimension)
            .map(|i| {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                (self.data[byte_idx] >> bit_idx) & 1 == 1
            })
            .collect()
    }

    /// Counts the bit positions in which `self` and `other` differ.
    ///
    /// Both vectors are expected to have the same dimension. This is only checked when
    /// debug assertions are enabled; otherwise only the bytes both vectors have in
    /// common are compared.
    #[must_use]
    pub fn hamming_distance(&self, other: &Self) -> u32 {
        debug_assert_eq!(
            self.dimension, other.dimension,
            "Dimension mismatch in hamming_distance"
        );

        self.data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| (a ^ b).count_ones())
            .sum()
    }

    /// Returns `1.0 - hamming_distance / dimension`.
    ///
    /// Identical vectors give `1.0`; vectors differing in every bit give `0.0`.
    /// A zero-dimension vector (obtainable through [`Self::from_bytes`]) is treated as
    /// identical to itself and yields `1.0` instead of the `NaN` a `0 / 0` division
    /// would produce.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn hamming_similarity(&self, other: &Self) -> f32 {
        let distance = self.hamming_distance(other);
        if self.dimension == 0 {
            return 1.0;
        }
        1.0 - (distance as f32 / self.dimension as f32)
    }

    /// Serializes the vector as a 4-byte little-endian dimension followed by the packed
    /// bit data.
    ///
    /// The dimension is written as a `u32`; a dimension larger than `u32::MAX` would be
    /// truncated by the cast.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        #[allow(clippy::cast_possible_truncation)]
        let header = (self.dimension as u32).to_le_bytes();

        let mut bytes = Vec::with_capacity(header.len() + self.data.len());
        bytes.extend_from_slice(&header);
        bytes.extend_from_slice(&self.data);
        bytes
    }

    /// Deserializes a vector from the layout written by [`Self::to_bytes`].
    ///
    /// Bytes beyond the expected payload are ignored. Bits of the last payload byte
    /// that lie beyond `dimension` are cleared so they cannot contribute to
    /// [`Self::hamming_distance`].
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] when `bytes` is shorter than the 4-byte
    /// header or shorter than the header plus the payload the header announces.
    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        let [d0, d1, d2, d3, payload @ ..] = bytes else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Not enough bytes for BinaryQuantizedVector header",
            ));
        };

        let dimension = u32::from_le_bytes([*d0, *d1, *d2, *d3]) as usize;
        let expected_data_len = dimension.div_ceil(8);

        let Some(packed) = payload.get(..expected_data_len) else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Not enough bytes for BinaryQuantizedVector data: expected {}, got {}",
                    4 + expected_data_len,
                    bytes.len()
                ),
            ));
        };

        let mut data = packed.to_vec();

        // Only the low `dimension % 8` bits of the final byte are meaningful when the
        // dimension is not a multiple of eight; drop anything above them.
        let used_bits = dimension % 8;
        if used_bits != 0 {
            if let Some(last) = data.last_mut() {
                *last &= (1u8 << used_bits) - 1;
            }
        }

        Ok(Self { data, dimension })
    }
}
188
/// A vector compressed to one `u8` code per component, together with the `min`/`max`
/// bounds needed to map the codes back to `f32`.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// One 8-bit code per component.
    pub data: Vec<u8>,
    /// Smallest source component; code `0` decodes to this value.
    pub min: f32,
    /// Largest source component; code `255` decodes to (approximately) this value.
    pub max: f32,
}

impl QuantizedVector {
    /// Quantizes `vector` by mapping `[min, max]` linearly onto the codes `0..=255`.
    ///
    /// When `max - min` is below `f32::EPSILON` the vector is treated as constant and
    /// every code is set to `128`.
    ///
    /// # Panics
    ///
    /// Panics if `vector` is empty.
    #[must_use]
    pub fn from_f32(vector: &[f32]) -> Self {
        assert!(!vector.is_empty(), "Cannot quantize empty vector");

        // One pass that tracks both bounds at once.
        let (min, max) = vector
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });

        let span = max - min;
        if span < f32::EPSILON {
            return Self {
                data: vec![128u8; vector.len()],
                min,
                max,
            };
        }

        let scale = 255.0 / span;
        let mut data = Vec::with_capacity(vector.len());
        for &component in vector {
            let scaled = (component - min) * scale;
            // The value is rounded and clamped to 0..=255 before the cast.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let code = scaled.round().clamp(0.0, 255.0) as u8;
            data.push(code);
        }

        Self { data, min, max }
    }

    /// Decodes the codes back into approximate `f32` components.
    ///
    /// A vector whose `max - min` is below `f32::EPSILON` decodes to `min` in every
    /// position, regardless of the stored codes.
    #[must_use]
    pub fn to_f32(&self) -> Vec<f32> {
        let span = self.max - self.min;
        if span < f32::EPSILON {
            return vec![self.min; self.data.len()];
        }

        let step = span / 255.0;
        let mut decoded = Vec::with_capacity(self.data.len());
        decoded.extend(
            self.data
                .iter()
                .map(|&code| f32::from(code) * step + self.min),
        );
        decoded
    }

    /// Returns the number of components.
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    /// Returns the payload size in bytes: one byte per code plus the two `f32` bounds.
    #[must_use]
    pub fn memory_size(&self) -> usize {
        self.data.len() + 2 * std::mem::size_of::<f32>()
    }

    /// Serializes as little-endian `min`, little-endian `max`, then the raw codes.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let min_bytes = self.min.to_le_bytes();
        let max_bytes = self.max.to_le_bytes();
        [&min_bytes[..], &max_bytes[..], &self.data[..]].concat()
    }

    /// Deserializes a vector from the layout written by [`Self::to_bytes`].
    ///
    /// Everything after the 8-byte header is taken as code data.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] when `bytes` is shorter than the 8-byte
    /// header.
    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        let [m0, m1, m2, m3, x0, x1, x2, x3, codes @ ..] = bytes else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Not enough bytes for QuantizedVector header",
            ));
        };

        Ok(Self {
            data: codes.to_vec(),
            min: f32::from_le_bytes([*m0, *m1, *m2, *m3]),
            max: f32::from_le_bytes([*x0, *x1, *x2, *x3]),
        })
    }
}
301
302#[must_use]
306pub fn dot_product_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
307 debug_assert_eq!(
308 query.len(),
309 quantized.data.len(),
310 "Dimension mismatch in dot_product_quantized"
311 );
312
313 let range = quantized.max - quantized.min;
314 if range < f32::EPSILON {
315 let value = quantized.min;
317 return query.iter().sum::<f32>() * value;
318 }
319
320 let scale = range / 255.0;
321 let offset = quantized.min;
322
323 query
325 .iter()
326 .zip(quantized.data.iter())
327 .map(|(&q, &v)| q * (f32::from(v) * scale + offset))
328 .sum()
329}
330
331#[must_use]
333pub fn euclidean_squared_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
334 debug_assert_eq!(
335 query.len(),
336 quantized.data.len(),
337 "Dimension mismatch in euclidean_squared_quantized"
338 );
339
340 let range = quantized.max - quantized.min;
341 if range < f32::EPSILON {
342 let value = quantized.min;
344 return query.iter().map(|&q| (q - value).powi(2)).sum();
345 }
346
347 let scale = range / 255.0;
348 let offset = quantized.min;
349
350 query
351 .iter()
352 .zip(quantized.data.iter())
353 .map(|(&q, &v)| {
354 let dequantized = f32::from(v) * scale + offset;
355 (q - dequantized).powi(2)
356 })
357 .sum()
358}
359
360#[must_use]
364pub fn cosine_similarity_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
365 let dot = dot_product_quantized(query, quantized);
366
367 let query_norm: f32 = query.iter().map(|&x| x * x).sum::<f32>().sqrt();
369
370 let reconstructed = quantized.to_f32();
372 let quantized_norm: f32 = reconstructed.iter().map(|&x| x * x).sum::<f32>().sqrt();
373
374 if query_norm < f32::EPSILON || quantized_norm < f32::EPSILON {
375 return 0.0;
376 }
377
378 dot / (query_norm * quantized_norm)
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // QuantizedVector (8-bit scalar quantization)
    // ---------------------------------------------------------------------

    /// The bounds map to codes 0 and 255; the midpoint rounds to 128.
    #[test]
    fn test_quantize_simple_vector() {
        let vector = vec![0.0, 0.5, 1.0];

        let quantized = QuantizedVector::from_f32(&vector);

        assert_eq!(quantized.dimension(), 3);
        assert!((quantized.min - 0.0).abs() < f32::EPSILON);
        assert!((quantized.max - 1.0).abs() < f32::EPSILON);
        assert_eq!(quantized.data[0], 0);
        assert_eq!(quantized.data[1], 128);
        assert_eq!(quantized.data[2], 255);
    }

    /// Negative components shift the range but not the code layout.
    #[test]
    fn test_quantize_negative_values() {
        let vector = vec![-1.0, 0.0, 1.0];

        let quantized = QuantizedVector::from_f32(&vector);

        assert!((quantized.min - (-1.0)).abs() < f32::EPSILON);
        assert!((quantized.max - 1.0).abs() < f32::EPSILON);
        assert_eq!(quantized.data[0], 0);
        assert_eq!(quantized.data[1], 128);
        assert_eq!(quantized.data[2], 255);
    }

    /// A zero-range vector is stored with the midpoint code everywhere.
    #[test]
    fn test_quantize_constant_vector() {
        let vector = vec![0.5, 0.5, 0.5];

        let quantized = QuantizedVector::from_f32(&vector);

        assert_eq!(quantized.dimension(), 3);
        for &v in &quantized.data {
            assert_eq!(v, 128);
        }
    }

    /// Decoding stays within one quantization step of the source values.
    #[test]
    fn test_dequantize_roundtrip() {
        let original = vec![0.1, 0.5, 0.9, -0.3, 0.0];

        let quantized = QuantizedVector::from_f32(&original);
        let reconstructed = quantized.to_f32();

        assert_eq!(reconstructed.len(), original.len());
        for (orig, recon) in original.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            let max_error = (quantized.max - quantized.min) / 255.0;
            assert!(
                error <= max_error + f32::EPSILON,
                "Error {error} exceeds max {max_error}"
            );
        }
    }

    /// 768 `f32` components (3072 bytes) shrink to 768 codes plus an 8-byte header.
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn test_memory_reduction() {
        let dimension = 768;
        let vector: Vec<f32> = (0..dimension)
            .map(|i| i as f32 / dimension as f32)
            .collect();

        let quantized = QuantizedVector::from_f32(&vector);

        let f32_size = dimension * 4;
        let sq8_size = quantized.memory_size();
        assert_eq!(f32_size, 3072);
        assert_eq!(sq8_size, 776);
        let ratio = f32_size as f32 / sq8_size as f32;
        assert!(ratio > 3.9);
    }

    /// `to_bytes` followed by `from_bytes` preserves bounds and codes.
    #[test]
    fn test_serialization_roundtrip() {
        let vector = vec![0.1, 0.5, 0.9, -0.3];
        let quantized = QuantizedVector::from_f32(&vector);

        let bytes = quantized.to_bytes();
        let deserialized = QuantizedVector::from_bytes(&bytes).unwrap();

        assert!((deserialized.min - quantized.min).abs() < f32::EPSILON);
        assert!((deserialized.max - quantized.max).abs() < f32::EPSILON);
        assert_eq!(deserialized.data, quantized.data);
    }

    /// Fewer than 8 bytes cannot hold the min/max header.
    #[test]
    fn test_from_bytes_invalid() {
        let bytes = vec![0u8; 5];

        let result = QuantizedVector::from_bytes(&bytes);

        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------
    // Distance kernels on quantized vectors
    // ---------------------------------------------------------------------

    #[test]
    fn test_dot_product_quantized_simple() {
        let query = vec![1.0, 0.0, 0.0];
        let vector = vec![1.0, 0.0, 0.0];
        let quantized = QuantizedVector::from_f32(&vector);

        let dot = dot_product_quantized(&query, &quantized);

        assert!(
            (dot - 1.0).abs() < 0.1,
            "Dot product {dot} not close to 1.0"
        );
    }

    #[test]
    fn test_dot_product_quantized_orthogonal() {
        let query = vec![1.0, 0.0, 0.0];
        let vector = vec![0.0, 1.0, 0.0];
        let quantized = QuantizedVector::from_f32(&vector);

        let dot = dot_product_quantized(&query, &quantized);

        assert!(dot.abs() < 0.1, "Dot product {dot} not close to 0");
    }

    #[test]
    fn test_euclidean_squared_quantized() {
        let query = vec![0.0, 0.0, 0.0];
        let vector = vec![1.0, 0.0, 0.0];
        let quantized = QuantizedVector::from_f32(&vector);

        let dist = euclidean_squared_quantized(&query, &quantized);

        assert!(
            (dist - 1.0).abs() < 0.1,
            "Euclidean squared {dist} not close to 1.0"
        );
    }

    #[test]
    fn test_euclidean_squared_quantized_same_point() {
        let vector = vec![0.5, 0.5, 0.5];
        let quantized = QuantizedVector::from_f32(&vector);

        let dist = euclidean_squared_quantized(&vector, &quantized);

        assert!(dist < 0.01, "Distance to self {dist} should be ~0");
    }

    #[test]
    fn test_cosine_similarity_quantized_identical() {
        let vector = vec![1.0, 2.0, 3.0];
        let quantized = QuantizedVector::from_f32(&vector);

        let similarity = cosine_similarity_quantized(&vector, &quantized);

        assert!(
            (similarity - 1.0).abs() < 0.05,
            "Cosine similarity to self {similarity} not close to 1.0"
        );
    }

    #[test]
    fn test_cosine_similarity_quantized_opposite() {
        let query = vec![1.0, 0.0, 0.0];
        let vector = vec![-1.0, 0.0, 0.0];
        let quantized = QuantizedVector::from_f32(&vector);

        let similarity = cosine_similarity_quantized(&query, &quantized);

        assert!(
            (similarity - (-1.0)).abs() < 0.1,
            "Cosine similarity {similarity} not close to -1.0"
        );
    }

    /// Top-10 neighbours by dot product must overlap at least 80% between the exact
    /// `f32` ranking and the quantized ranking.
    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn test_recall_accuracy_high_dimension() {
        let dimension = 768;
        let num_vectors = 100;

        // Deterministic pseudo-random vectors with components in [-1, 1).
        let vectors: Vec<Vec<f32>> = (0..num_vectors)
            .map(|i| {
                (0..dimension)
                    .map(|j| {
                        let x = ((i * 7 + j * 13) % 1000) as f32 / 1000.0;
                        x * 2.0 - 1.0
                    })
                    .collect()
            })
            .collect();

        let quantized: Vec<QuantizedVector> = vectors
            .iter()
            .map(|v| QuantizedVector::from_f32(v))
            .collect();

        let query = &vectors[0];

        let mut f32_distances: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| {
                let dot: f32 = query.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();

        let mut sq8_distances: Vec<(usize, f32)> = quantized
            .iter()
            .enumerate()
            .map(|(i, q)| (i, dot_product_quantized(query, q)))
            .collect();

        // Descending by score; `total_cmp` is a total order, so sorting cannot panic.
        f32_distances.sort_by(|a, b| b.1.total_cmp(&a.1));
        sq8_distances.sort_by(|a, b| b.1.total_cmp(&a.1));

        let k = 10;
        let f32_top_k: std::collections::HashSet<usize> =
            f32_distances.iter().take(k).map(|(i, _)| *i).collect();
        let sq8_top_k: std::collections::HashSet<usize> =
            sq8_distances.iter().take(k).map(|(i, _)| *i).collect();

        let recall = f32_top_k.intersection(&sq8_top_k).count() as f32 / k as f32;

        assert!(
            recall >= 0.8,
            "Recall@{k} is {recall}, expected >= 0.8 (80%)"
        );
    }

    // ---------------------------------------------------------------------
    // StorageMode
    // ---------------------------------------------------------------------

    #[test]
    fn test_storage_mode_enum() {
        let full = StorageMode::Full;
        let sq8 = StorageMode::SQ8;
        let binary = StorageMode::Binary;
        let default = StorageMode::default();

        assert_eq!(full, StorageMode::Full);
        assert_eq!(sq8, StorageMode::SQ8);
        assert_eq!(binary, StorageMode::Binary);
        assert_eq!(default, StorageMode::Full);
        assert_ne!(full, sq8);
        assert_ne!(sq8, binary);
    }

    // ---------------------------------------------------------------------
    // BinaryQuantizedVector (1 bit per component)
    // ---------------------------------------------------------------------

    /// Four components fit in one byte; non-negative components set their bit,
    /// least-significant bit first.
    #[test]
    fn test_binary_quantize_simple_vector() {
        let vector = vec![-1.0, 0.5, -0.5, 1.0];

        let binary = BinaryQuantizedVector::from_f32(&vector);

        assert_eq!(binary.dimension(), 4);
        assert_eq!(binary.data.len(), 1);
        // Components 1 and 3 are non-negative, so bits 1 and 3 are set.
        assert_eq!(binary.data[0], 0b0000_1010);
        assert_eq!(binary.get_bits(), vec![false, true, false, true]);
    }

    /// 768 components pack into 96 bytes, a 32x reduction over `f32`.
    #[test]
    fn test_binary_quantize_768d_memory() {
        let vector: Vec<f32> = (0..768)
            .map(|i| if i % 2 == 0 { 0.5 } else { -0.5 })
            .collect();

        let binary = BinaryQuantizedVector::from_f32(&vector);

        assert_eq!(binary.dimension(), 768);
        assert_eq!(binary.data.len(), 96);
        let f32_size = 768 * 4;
        let binary_size = binary.memory_size();
        assert_eq!(binary_size, 96);
        #[allow(clippy::cast_precision_loss)]
        let ratio = f32_size as f32 / binary_size as f32;
        assert!(ratio >= 32.0, "Expected 32x reduction, got {ratio}x");
    }

    /// The threshold is inclusive: `0.0` itself maps to a set bit.
    #[test]
    fn test_binary_quantize_threshold_at_zero() {
        let vector = vec![0.0, 0.001, -0.001, f32::EPSILON];

        let binary = BinaryQuantizedVector::from_f32(&vector);

        let bits = binary.get_bits();
        assert!(bits[0], "0.0 should be 1");
        assert!(bits[1], "0.001 should be 1");
        assert!(!bits[2], "-0.001 should be 0");
        assert!(bits[3], "EPSILON should be 1");
    }

    #[test]
    fn test_binary_hamming_distance_identical() {
        let vector = vec![0.5, -0.5, 0.5, -0.5, 0.5, -0.5, 0.5, -0.5];
        let binary = BinaryQuantizedVector::from_f32(&vector);

        let distance = binary.hamming_distance(&binary);

        assert_eq!(distance, 0);
    }

    #[test]
    fn test_binary_hamming_distance_opposite() {
        let v1 = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let v2 = vec![-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0];
        let b1 = BinaryQuantizedVector::from_f32(&v1);
        let b2 = BinaryQuantizedVector::from_f32(&v2);

        let distance = b1.hamming_distance(&b2);

        assert_eq!(distance, 8);
    }

    #[test]
    fn test_binary_hamming_distance_half_different() {
        let v1 = vec![1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0];
        let v2 = vec![1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0];
        let b1 = BinaryQuantizedVector::from_f32(&v1);
        let b2 = BinaryQuantizedVector::from_f32(&v2);

        let distance = b1.hamming_distance(&b2);

        assert_eq!(distance, 4);
    }

    #[test]
    fn test_binary_serialization_roundtrip() {
        let vector: Vec<f32> = (0..768)
            .map(|i| if i % 3 == 0 { 0.5 } else { -0.5 })
            .collect();
        let binary = BinaryQuantizedVector::from_f32(&vector);

        let bytes = binary.to_bytes();
        let deserialized = BinaryQuantizedVector::from_bytes(&bytes).unwrap();

        assert_eq!(deserialized.dimension(), binary.dimension());
        assert_eq!(deserialized.data, binary.data);
        assert_eq!(deserialized.hamming_distance(&binary), 0);
    }

    /// Fewer than 4 bytes cannot hold the dimension header.
    #[test]
    fn test_binary_from_bytes_invalid() {
        let bytes = vec![0u8; 3];

        let result = BinaryQuantizedVector::from_bytes(&bytes);

        assert!(result.is_err());
    }
}