ruvector_router_core/
quantization.rs1use crate::error::{Result, VectorDbError};
4use crate::types::QuantizationType;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum QuantizedVector {
10 None(Vec<f32>),
12 Scalar {
14 data: Vec<u8>,
16 min: f32,
18 scale: f32,
20 },
21 Product {
23 codes: Vec<u8>,
25 subspaces: usize,
27 },
28 Binary {
30 data: Vec<u8>,
32 threshold: f32,
34 },
35}
36
37pub fn quantize(vector: &[f32], qtype: QuantizationType) -> Result<QuantizedVector> {
39 match qtype {
40 QuantizationType::None => Ok(QuantizedVector::None(vector.to_vec())),
41 QuantizationType::Scalar => Ok(scalar_quantize(vector)),
42 QuantizationType::Product { subspaces, k } => product_quantize(vector, subspaces, k),
43 QuantizationType::Binary => Ok(binary_quantize(vector)),
44 }
45}
46
47pub fn dequantize(quantized: &QuantizedVector) -> Vec<f32> {
49 match quantized {
50 QuantizedVector::None(v) => v.clone(),
51 QuantizedVector::Scalar { data, min, scale } => scalar_dequantize(data, *min, *scale),
52 QuantizedVector::Product { codes, subspaces } => {
53 vec![0.0; codes.len() * (codes.len() / subspaces)]
55 }
56 QuantizedVector::Binary { data, threshold } => binary_dequantize(data, *threshold),
57 }
58}
59
60fn scalar_quantize(vector: &[f32]) -> QuantizedVector {
62 let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
63 let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
64
65 let scale = if max > min { 255.0 / (max - min) } else { 1.0 };
66
67 let data: Vec<u8> = vector
68 .iter()
69 .map(|&v| ((v - min) * scale).clamp(0.0, 255.0) as u8)
70 .collect();
71
72 QuantizedVector::Scalar { data, min, scale }
73}
74
/// Maps scalar codes back to `f32` values.
///
/// Each code is divided by `scale` and then shifted by `min`, the inverse of
/// the mapping applied when encoding.
fn scalar_dequantize(data: &[u8], min: f32, scale: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(data.len());
    for &code in data {
        restored.push(f32::from(code) / scale + min);
    }
    restored
}
79
80fn product_quantize(vector: &[f32], subspaces: usize, _k: usize) -> Result<QuantizedVector> {
82 if !vector.len().is_multiple_of(subspaces) {
83 return Err(VectorDbError::Quantization(
84 "Vector length must be divisible by number of subspaces".to_string(),
85 ));
86 }
87
88 let subspace_dim = vector.len() / subspaces;
91 let codes: Vec<u8> = (0..subspaces)
92 .map(|i| {
93 let start = i * subspace_dim;
94 let subvec = &vector[start..start + subspace_dim];
95 (subvec.iter().sum::<f32>() as u32 % 256) as u8
97 })
98 .collect();
99
100 Ok(QuantizedVector::Product { codes, subspaces })
101}
102
103fn binary_quantize(vector: &[f32]) -> QuantizedVector {
105 let threshold = vector.iter().sum::<f32>() / vector.len() as f32;
106
107 let num_bytes = vector.len().div_ceil(8);
108 let mut data = vec![0u8; num_bytes];
109
110 for (i, &val) in vector.iter().enumerate() {
111 if val > threshold {
112 let byte_idx = i / 8;
113 let bit_idx = i % 8;
114 data[byte_idx] |= 1 << bit_idx;
115 }
116 }
117
118 QuantizedVector::Binary { data, threshold }
119}
120
/// Expands packed bits into `f32` values around `threshold`.
///
/// Bits are read from the lowest to the highest within each byte. A set bit
/// becomes `threshold + 1.0` and a clear bit becomes `threshold - 1.0`. The
/// result always holds eight values per input byte.
fn binary_dequantize(data: &[u8], threshold: f32) -> Vec<f32> {
    let above = threshold + 1.0;
    let below = threshold - 1.0;

    let mut values = Vec::with_capacity(data.len() * 8);
    values.extend(data.iter().flat_map(|&byte| {
        (0..8).map(move |shift| {
            if (byte >> shift) & 1 == 1 {
                above
            } else {
                below
            }
        })
    }));
    values
}
138
139pub fn calculate_compression_ratio(original_dims: usize, qtype: QuantizationType) -> f32 {
141 let original_bytes = original_dims * 4; let quantized_bytes = match qtype {
143 QuantizationType::None => original_bytes,
144 QuantizationType::Scalar => original_dims + 8, QuantizationType::Product { subspaces, .. } => subspaces + 4, QuantizationType::Binary => original_dims.div_ceil(8) + 4, };
148
149 original_bytes as f32 / quantized_bytes as f32
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// A scalar round trip must stay within 0.1 of every input component.
    #[test]
    fn test_scalar_quantization() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let restored = dequantize(&scalar_quantize(&input));

        for (expected, actual) in input.iter().zip(restored.iter()) {
            let error = (expected - actual).abs();
            assert!(error < 0.1);
        }
    }

    /// Binary quantization must yield the `Binary` variant with some data.
    #[test]
    fn test_binary_quantization() {
        let input = vec![1.0, 5.0, 2.0, 8.0, 3.0];

        let QuantizedVector::Binary { data, .. } = binary_quantize(&input) else {
            panic!("Expected binary quantization");
        };
        assert!(!data.is_empty());
    }

    /// Ratios for a 384-dimensional vector must clear the expected floors.
    #[test]
    fn test_compression_ratio() {
        let scalar_ratio = calculate_compression_ratio(384, QuantizationType::Scalar);
        assert!(scalar_ratio > 3.0);

        let binary_ratio = calculate_compression_ratio(384, QuantizationType::Binary);
        assert!(binary_ratio > 20.0);
    }
}