// sql_rs/vector/quantization.rs
use serde::{Deserialize, Serialize};

use super::{Distance, DistanceMetric};

4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct QuantizedVector {
6 pub dimension: usize,
8 pub bits_per_scalar: u8,
10 pub data: Vec<u8>,
12 pub min_values: Vec<f32>,
14 pub max_values: Vec<f32>,
16}
17
18impl QuantizedVector {
19 pub fn quantize_8bit(vector: &[f32]) -> Self {
21 let dimension = vector.len();
22 let mut min_values = Vec::with_capacity(dimension);
23 let mut max_values = Vec::with_capacity(dimension);
24
25 for i in 0..dimension {
27 let min = vector[i];
28 let max = vector[i];
29
30 min_values.push(min);
33 max_values.push(max);
34 }
35
36 let mut data = Vec::with_capacity(dimension);
38 for i in 0..dimension {
39 let min = min_values[i];
40 let max = max_values[i];
41
42 let normalized = if max - min > 0.0 {
43 (vector[i] - min) / (max - min)
44 } else {
45 0.5
46 };
47
48 let quantized = (normalized * 255.0).round() as u8;
49 data.push(quantized);
50 }
51
52 Self {
53 dimension,
54 bits_per_scalar: 8,
55 data,
56 min_values,
57 max_values,
58 }
59 }
60
61 pub fn quantize_16bit(vector: &[f32]) -> Self {
63 let dimension = vector.len();
64 let mut min_values = Vec::with_capacity(dimension);
65 let mut max_values = Vec::with_capacity(dimension);
66
67 for i in 0..dimension {
69 let min = vector[i];
70 let max = vector[i];
71
72 min_values.push(min);
73 max_values.push(max);
74 }
75
76 let mut data = Vec::with_capacity(dimension * 2);
78 for i in 0..dimension {
79 let min = min_values[i];
80 let max = max_values[i];
81
82 let normalized = if max - min > 0.0 {
83 (vector[i] - min) / (max - min)
84 } else {
85 0.5
86 };
87
88 let quantized = (normalized * 65535.0).round() as u16;
89 data.push((quantized >> 8) as u8);
90 data.push((quantized & 0xFF) as u8);
91 }
92
93 Self {
94 dimension,
95 bits_per_scalar: 16,
96 data,
97 min_values,
98 max_values,
99 }
100 }
101
102 pub fn dequantize(&self) -> Vec<f32> {
104 let mut result = Vec::with_capacity(self.dimension);
105
106 if self.bits_per_scalar == 8 {
107 for i in 0..self.dimension {
108 let quantized = self.data[i] as f32;
109 let min = self.min_values[i];
110 let max = self.max_values[i];
111
112 let normalized = quantized / 255.0;
113 let value = min + normalized * (max - min);
114 result.push(value);
115 }
116 } else if self.bits_per_scalar == 16 {
117 for i in 0..self.dimension {
118 let quantized =
119 ((self.data[i * 2] as u16) << 8 | self.data[i * 2 + 1] as u16) as f32;
120 let min = self.min_values[i];
121 let max = self.max_values[i];
122
123 let normalized = quantized / 65535.0;
124 let value = min + normalized * (max - min);
125 result.push(value);
126 }
127 }
128
129 result
130 }
131
132 pub fn distance(&self, other: &QuantizedVector, metric: DistanceMetric) -> f32 {
134 if self.dimension != other.dimension {
135 return f32::INFINITY;
136 }
137
138 let v1 = self.dequantize();
141 let v2 = other.dequantize();
142
143 v1.distance(&v2, metric)
144 }
145
146 pub fn memory_usage(&self) -> usize {
148 std::mem::size_of::<Self>()
149 + self.data.len()
150 + self.min_values.len() * std::mem::size_of::<f32>()
151 + self.max_values.len() * std::mem::size_of::<f32>()
152 }
153
154 pub fn compression_ratio(&self) -> f32 {
156 let original_size = self.dimension * std::mem::size_of::<f32>();
157 let compressed_size = self.data.len()
158 + self.min_values.len() * std::mem::size_of::<f32>()
159 + self.max_values.len() * std::mem::size_of::<f32>();
160 original_size as f32 / compressed_size as f32
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every component of `actual` is within `tolerance` of `expected`.
    fn assert_all_close(expected: &[f32], actual: &[f32], tolerance: f32) {
        expected
            .iter()
            .zip(actual)
            .for_each(|(want, got)| assert!((want - got).abs() < tolerance));
    }

    #[test]
    fn test_quantize_8bit() {
        let input: Vec<f32> = vec![1.0, -1.0, 0.5, 0.0];
        let encoded = QuantizedVector::quantize_8bit(&input);

        assert_eq!(encoded.dimension, 4);
        assert_eq!(encoded.bits_per_scalar, 8);

        let decoded = encoded.dequantize();
        assert_eq!(decoded.len(), 4);
        assert_all_close(&input, &decoded, 0.01);
    }

    #[test]
    fn test_quantize_16bit() {
        let input: Vec<f32> = vec![1.0, -1.0, 0.5, 0.0];
        let encoded = QuantizedVector::quantize_16bit(&input);

        assert_eq!(encoded.dimension, 4);
        assert_eq!(encoded.bits_per_scalar, 16);
        // Two bytes per component.
        assert_eq!(encoded.data.len(), 8);

        let decoded = encoded.dequantize();
        assert_eq!(decoded.len(), 4);
        assert_all_close(&input, &decoded, 0.001);
    }

    #[test]
    fn test_compression_ratio() {
        let input = vec![1.0_f32; 100];
        let ratio_8 = QuantizedVector::quantize_8bit(&input).compression_ratio();
        let ratio_16 = QuantizedVector::quantize_16bit(&input).compression_ratio();

        // Narrower codes must yield a strictly better ratio.
        assert!(ratio_8 > ratio_16);
        assert!(ratio_8 > 0.0);
        assert!(ratio_16 > 0.0);
    }

    #[test]
    fn test_quantized_distance() {
        let a = QuantizedVector::quantize_8bit(&[1.0_f32, 0.0, 0.0]);
        let b = QuantizedVector::quantize_8bit(&[0.0_f32, 1.0, 0.0]);

        // Orthogonal unit vectors are about sqrt(2) apart.
        let expected = 1.414;
        let d = a.distance(&b, DistanceMetric::Euclidean);
        assert!((d - expected).abs() < 0.1);
    }
}