//! Vector quantization schemes (none, scalar, product, binary) for `ruvector_router_core`.
quantization.rs1use crate::error::{Result, VectorDbError};
4use crate::types::QuantizationType;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum QuantizedVector {
10 None(Vec<f32>),
12 Scalar {
14 data: Vec<u8>,
16 min: f32,
18 scale: f32,
20 },
21 Product {
23 codes: Vec<u8>,
25 subspaces: usize,
27 },
28 Binary {
30 data: Vec<u8>,
32 threshold: f32,
34 },
35}
36
37pub fn quantize(vector: &[f32], qtype: QuantizationType) -> Result<QuantizedVector> {
39 match qtype {
40 QuantizationType::None => Ok(QuantizedVector::None(vector.to_vec())),
41 QuantizationType::Scalar => Ok(scalar_quantize(vector)),
42 QuantizationType::Product { subspaces, k } => {
43 product_quantize(vector, subspaces, k)
44 }
45 QuantizationType::Binary => Ok(binary_quantize(vector)),
46 }
47}
48
49pub fn dequantize(quantized: &QuantizedVector) -> Vec<f32> {
51 match quantized {
52 QuantizedVector::None(v) => v.clone(),
53 QuantizedVector::Scalar { data, min, scale } => {
54 scalar_dequantize(data, *min, *scale)
55 }
56 QuantizedVector::Product { codes, subspaces } => {
57 vec![0.0; codes.len() * (codes.len() / subspaces)]
59 }
60 QuantizedVector::Binary { data, threshold } => {
61 binary_dequantize(data, *threshold)
62 }
63 }
64}
65
66fn scalar_quantize(vector: &[f32]) -> QuantizedVector {
68 let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
69 let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
70
71 let scale = if max > min {
72 255.0 / (max - min)
73 } else {
74 1.0
75 };
76
77 let data: Vec<u8> = vector
78 .iter()
79 .map(|&v| ((v - min) * scale).clamp(0.0, 255.0) as u8)
80 .collect();
81
82 QuantizedVector::Scalar { data, min, scale }
83}
84
/// Inverts the scalar mapping, producing `byte / scale + min` for each stored byte.
fn scalar_dequantize(data: &[u8], min: f32, scale: f32) -> Vec<f32> {
    let mut restored = Vec::with_capacity(data.len());
    for &byte in data {
        restored.push(f32::from(byte) / scale + min);
    }
    restored
}
91
92fn product_quantize(
94 vector: &[f32],
95 subspaces: usize,
96 _k: usize,
97) -> Result<QuantizedVector> {
98 if !vector.len().is_multiple_of(subspaces) {
99 return Err(VectorDbError::Quantization(
100 "Vector length must be divisible by number of subspaces".to_string(),
101 ));
102 }
103
104 let subspace_dim = vector.len() / subspaces;
107 let codes: Vec<u8> = (0..subspaces)
108 .map(|i| {
109 let start = i * subspace_dim;
110 let subvec = &vector[start..start + subspace_dim];
111 (subvec.iter().sum::<f32>() as u32 % 256) as u8
113 })
114 .collect();
115
116 Ok(QuantizedVector::Product { codes, subspaces })
117}
118
119fn binary_quantize(vector: &[f32]) -> QuantizedVector {
121 let threshold = vector.iter().sum::<f32>() / vector.len() as f32;
122
123 let num_bytes = vector.len().div_ceil(8);
124 let mut data = vec![0u8; num_bytes];
125
126 for (i, &val) in vector.iter().enumerate() {
127 if val > threshold {
128 let byte_idx = i / 8;
129 let bit_idx = i % 8;
130 data[byte_idx] |= 1 << bit_idx;
131 }
132 }
133
134 QuantizedVector::Binary { data, threshold }
135}
136
/// Expands bit-packed `data` back to `f32` values, least-significant bit first.
///
/// Set bits become `threshold + 1.0` and clear bits become `threshold - 1.0`. Every
/// byte contributes eight values, so the output length is always `data.len() * 8`.
fn binary_dequantize(data: &[u8], threshold: f32) -> Vec<f32> {
    let high = threshold + 1.0;
    let low = threshold - 1.0;

    let mut values = Vec::with_capacity(data.len() * 8);
    values.extend(data.iter().flat_map(|&byte| {
        (0..8).map(move |bit| if (byte >> bit) & 1 == 1 { high } else { low })
    }));
    values
}
150
151pub fn calculate_compression_ratio(
153 original_dims: usize,
154 qtype: QuantizationType,
155) -> f32 {
156 let original_bytes = original_dims * 4; let quantized_bytes = match qtype {
158 QuantizationType::None => original_bytes,
159 QuantizationType::Scalar => original_dims + 8, QuantizationType::Product { subspaces, .. } => subspaces + 4, QuantizationType::Binary => original_dims.div_ceil(8) + 4, };
163
164 original_bytes as f32 / quantized_bytes as f32
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantization() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = scalar_quantize(&vector);
        let dequantized = dequantize(&quantized);

        assert_eq!(dequantized.len(), vector.len());
        for (orig, deq) in vector.iter().zip(dequantized.iter()) {
            // One quantization step over [1, 5] is 4 / 255 (about 0.016), well under 0.1.
            assert!((orig - deq).abs() < 0.1);
        }
    }

    #[test]
    fn test_binary_quantization() {
        let vector = vec![1.0, 5.0, 2.0, 8.0, 3.0];
        let quantized = binary_quantize(&vector);

        match quantized {
            QuantizedVector::Binary { data, threshold } => {
                // The mean is 19 / 5 = 3.8. Only 5.0 (index 1) and 8.0 (index 3) exceed it.
                assert!((threshold - 3.8).abs() < 1e-6);
                assert_eq!(data, vec![0b0000_1010]);
            }
            _ => panic!("Expected binary quantization"),
        }
    }

    #[test]
    fn test_product_quantization_rejects_indivisible_length() {
        assert!(product_quantize(&[1.0, 2.0, 3.0], 2, 256).is_err());
    }

    #[test]
    fn test_compression_ratio() {
        let ratio = calculate_compression_ratio(384, QuantizationType::Scalar);
        assert!(ratio > 3.0);

        let ratio = calculate_compression_ratio(384, QuantizationType::Binary);
        assert!(ratio > 20.0);
    }
}