1use crate::error::{Result, VectorDbError};
4use crate::types::QuantizationType;
5use serde::{Deserialize, Serialize};
6
/// A vector stored in one of the supported quantized encodings.
///
/// NOTE: this type derives `Serialize`/`Deserialize`, so renaming or reordering
/// variants or fields may change its serialized form and break previously
/// persisted data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizedVector {
    /// No quantization: the original `f32` components, copied verbatim.
    None(Vec<f32>),
    /// 8-bit scalar quantization. `dequantize` reconstructs component `i` as
    /// `min + data[i] as f32 / scale`.
    Scalar {
        /// One level in `0..=255` per component, derived from `(x - min) * scale`.
        data: Vec<u8>,
        /// Smallest component as computed by `scalar_quantize` (the value of level 0).
        min: f32,
        /// Levels per unit of input: `255 / (max - min)`, or `1.0` when that
        /// range is degenerate.
        scale: f32,
    },
    /// Output of `product_quantize`: one code byte per subspace.
    ///
    /// NOTE(review): neither a codebook nor the sub-vector length is stored, so
    /// `dequantize` cannot reconstruct the original components from this variant.
    Product {
        /// One code per subspace.
        codes: Vec<u8>,
        /// Number of subspaces the source vector was split into.
        subspaces: usize,
    },
    /// 1-bit quantization: bit `i % 8` of byte `i / 8` is set when component
    /// `i` was greater than `threshold`.
    Binary {
        /// Packed bits, `ceil(dimensions / 8)` bytes; unused high bits of the
        /// last byte are zero.
        data: Vec<u8>,
        /// Cut-off the components were compared against (their mean, as
        /// computed by `binary_quantize`).
        threshold: f32,
        /// Number of components in the source vector.
        dimensions: usize,
    },
}
38
39pub fn quantize(vector: &[f32], qtype: QuantizationType) -> Result<QuantizedVector> {
41 match qtype {
42 QuantizationType::None => Ok(QuantizedVector::None(vector.to_vec())),
43 QuantizationType::Scalar => Ok(scalar_quantize(vector)),
44 QuantizationType::Product { subspaces, k } => product_quantize(vector, subspaces, k),
45 QuantizationType::Binary => Ok(binary_quantize(vector)),
46 }
47}
48
49pub fn dequantize(quantized: &QuantizedVector) -> Vec<f32> {
51 match quantized {
52 QuantizedVector::None(v) => v.clone(),
53 QuantizedVector::Scalar { data, min, scale } => scalar_dequantize(data, *min, *scale),
54 QuantizedVector::Product { codes, subspaces } => {
55 vec![0.0; codes.len() * (codes.len() / subspaces)]
57 }
58 QuantizedVector::Binary {
59 data,
60 threshold,
61 dimensions,
62 } => binary_dequantize(data, *threshold, *dimensions),
63 }
64}
65
66fn scalar_quantize(vector: &[f32]) -> QuantizedVector {
68 let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
69 let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
70
71 let scale = if max > min { 255.0 / (max - min) } else { 1.0 };
72
73 let data: Vec<u8> = vector
74 .iter()
75 .map(|&v| ((v - min) * scale).clamp(0.0, 255.0) as u8)
76 .collect();
77
78 QuantizedVector::Scalar { data, min, scale }
79}
80
/// Inverts scalar quantization: level `q` maps back to `min + q / scale`.
fn scalar_dequantize(data: &[u8], min: f32, scale: f32) -> Vec<f32> {
    data.iter()
        .copied()
        .map(f32::from)
        .map(|level| min + level / scale)
        .collect()
}
91
92fn product_quantize(vector: &[f32], subspaces: usize, _k: usize) -> Result<QuantizedVector> {
94 if !vector.len().is_multiple_of(subspaces) {
95 return Err(VectorDbError::Quantization(
96 "Vector length must be divisible by number of subspaces".to_string(),
97 ));
98 }
99
100 let subspace_dim = vector.len() / subspaces;
103 let codes: Vec<u8> = (0..subspaces)
104 .map(|i| {
105 let start = i * subspace_dim;
106 let subvec = &vector[start..start + subspace_dim];
107 (subvec.iter().sum::<f32>() as u32 % 256) as u8
109 })
110 .collect();
111
112 Ok(QuantizedVector::Product { codes, subspaces })
113}
114
115fn binary_quantize(vector: &[f32]) -> QuantizedVector {
117 let threshold = vector.iter().sum::<f32>() / vector.len() as f32;
118 let dimensions = vector.len();
119
120 let num_bytes = dimensions.div_ceil(8);
121 let mut data = vec![0u8; num_bytes];
122
123 for (i, &val) in vector.iter().enumerate() {
124 if val > threshold {
125 let byte_idx = i / 8;
126 let bit_idx = i % 8;
127 data[byte_idx] |= 1 << bit_idx;
128 }
129 }
130
131 QuantizedVector::Binary {
132 data,
133 threshold,
134 dimensions,
135 }
136}
137
/// Inverts binary quantization.
///
/// Each set bit becomes `threshold + 1.0` and each clear bit becomes
/// `threshold - 1.0`. Bits are read least-significant first within each byte.
///
/// At most `dimensions` values are produced, and never more than
/// `8 * data.len()`. If `data` is too short, the result is shorter than
/// `dimensions`.
fn binary_dequantize(data: &[u8], threshold: f32, dimensions: usize) -> Vec<f32> {
    let above = threshold + 1.0;
    let below = threshold - 1.0;

    // Cap the reservation by what `data` can actually yield, so a corrupt
    // `dimensions` (e.g. from deserialized input) cannot trigger a huge
    // allocation or a capacity-overflow panic.
    let mut result = Vec::with_capacity(dimensions.min(data.len().saturating_mul(8)));
    result.extend(
        data.iter()
            .flat_map(|&byte| (0..8u32).map(move |bit| (byte >> bit) & 1 == 1))
            .take(dimensions)
            .map(|set| if set { above } else { below }),
    );
    result
}
161
162pub fn calculate_compression_ratio(original_dims: usize, qtype: QuantizationType) -> f32 {
164 let original_bytes = original_dims * 4; let quantized_bytes = match qtype {
166 QuantizationType::None => original_bytes,
167 QuantizationType::Scalar => original_dims + 8, QuantizationType::Product { subspaces, .. } => subspaces + 4, QuantizationType::Binary => original_dims.div_ceil(8) + 4, };
171
172 original_bytes as f32 / quantized_bytes as f32
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantization() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = scalar_quantize(&vector);
        let dequantized = dequantize(&quantized);

        // `zip` stops at the shorter side, so check lengths explicitly;
        // otherwise an empty result would pass vacuously.
        assert_eq!(vector.len(), dequantized.len());
        for (orig, deq) in vector.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    #[test]
    fn test_binary_quantization() {
        let vector = vec![1.0, 5.0, 2.0, 8.0, 3.0];
        let quantized = binary_quantize(&vector);

        match quantized {
            QuantizedVector::Binary {
                data, dimensions, ..
            } => {
                assert!(!data.is_empty());
                assert_eq!(dimensions, 5);
            }
            _ => panic!("Expected binary quantization"),
        }
    }

    #[test]
    fn test_compression_ratio() {
        let ratio = calculate_compression_ratio(384, QuantizationType::Scalar);
        assert!(ratio > 3.0);

        let ratio = calculate_compression_ratio(384, QuantizationType::Binary);
        assert!(ratio > 20.0);

        let ratio = calculate_compression_ratio(384, QuantizationType::None);
        assert_eq!(ratio, 1.0);
    }

    #[test]
    fn test_scalar_quantization_roundtrip() {
        let test_vectors = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![-10.0, -5.0, 0.0, 5.0, 10.0],
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
            vec![100.0, 200.0, 300.0, 400.0, 500.0],
        ];

        for vector in test_vectors {
            let quantized = scalar_quantize(&vector);
            let dequantized = dequantize(&quantized);

            assert_eq!(vector.len(), dequantized.len());

            // Tolerate two quantization steps of error. The bound depends only
            // on the vector, so compute it once rather than per component.
            let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
            let max_error = (max - min) / 255.0 * 2.0;

            for (orig, deq) in vector.iter().zip(dequantized.iter()) {
                assert!(
                    (orig - deq).abs() < max_error,
                    "Roundtrip error too large: orig={}, deq={}, error={}",
                    orig,
                    deq,
                    (orig - deq).abs()
                );
            }
        }
    }

    #[test]
    fn test_scalar_quantization_edge_cases() {
        let same_values = vec![5.0, 5.0, 5.0, 5.0];
        let quantized = scalar_quantize(&same_values);
        let dequantized = dequantize(&quantized);

        assert_eq!(same_values.len(), dequantized.len());
        for (orig, deq) in same_values.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.01);
        }

        let extreme = vec![f32::MIN / 1e10, 0.0, f32::MAX / 1e10];
        let quantized = scalar_quantize(&extreme);
        let dequantized = dequantize(&quantized);

        assert_eq!(extreme.len(), dequantized.len());
    }

    #[test]
    fn test_binary_quantization_roundtrip() {
        let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5, -0.5];
        let quantized = binary_quantize(&vector);
        let dequantized = dequantize(&quantized);

        assert_eq!(
            vector.len(),
            dequantized.len(),
            "Dequantized vector should have same length as original"
        );

        match quantized {
            QuantizedVector::Binary {
                threshold,
                dimensions,
                ..
            } => {
                assert_eq!(dimensions, vector.len());
                // Every reconstructed component must land on the same side of
                // the threshold as the original component.
                for (orig, deq) in vector.iter().zip(dequantized.iter()) {
                    assert_eq!(*orig > threshold, *deq > threshold);
                }
            }
            _ => panic!("Expected binary quantization"),
        }
    }

    #[test]
    fn test_binary_quantization_partial_byte_roundtrip() {
        // 10 components need two bytes, the second only partially used.
        let vector: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let dequantized = dequantize(&binary_quantize(&vector));
        assert_eq!(dequantized.len(), vector.len());
    }

    #[test]
    fn test_none_quantization_is_lossless() {
        let vector = vec![0.25, -3.5, 1e-3];
        let Ok(quantized) = quantize(&vector, QuantizationType::None) else {
            panic!("QuantizationType::None must not fail");
        };
        assert_eq!(dequantize(&quantized), vector);
    }

    #[test]
    fn test_product_quantization_rejects_indivisible_length() {
        let result = quantize(
            &[1.0, 2.0, 3.0],
            QuantizationType::Product {
                subspaces: 2,
                k: 256,
            },
        );
        assert!(result.is_err());
    }
}