velesdb_core/quantization/
scalar.rs1use std::io;
8
9use super::QuantizationCodec;
10
/// A vector stored as one 8-bit code per dimension plus the `min`/`max`
/// bounds needed to map the codes back to `f32`.
///
/// Codes are produced by linearly scaling `[min, max]` onto `0..=255`.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// One quantized code per dimension.
    pub data: Vec<u8>,
    /// Lower bound of the quantization range (decodes from code `0`).
    pub min: f32,
    /// Upper bound of the quantization range (decodes from code `255`).
    pub max: f32,
}

impl QuantizedVector {
    /// Quantizes `vector` to 8 bits per component.
    ///
    /// The range is taken from the smallest and largest component. When that
    /// range is narrower than `f32::EPSILON`, every code is set to `128` and
    /// decoding yields `min` for each component.
    ///
    /// # Panics
    ///
    /// Panics if `vector` is empty.
    #[must_use]
    pub fn from_f32(vector: &[f32]) -> Self {
        assert!(!vector.is_empty(), "Cannot quantize empty vector");

        // Track both bounds in one pass over the input.
        let (min, max) = vector.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), &component| (lo.min(component), hi.max(component)),
        );
        let span = max - min;

        if span < f32::EPSILON {
            return Self {
                data: vec![128u8; vector.len()],
                min,
                max,
            };
        }

        let codes_per_unit = 255.0 / span;
        let mut data = Vec::with_capacity(vector.len());
        for &component in vector {
            let scaled = (component - min) * codes_per_unit;
            // The value is rounded and clamped to 0..=255 before the cast.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let code = scaled.round().clamp(0.0, 255.0) as u8;
            data.push(code);
        }

        Self { data, min, max }
    }

    /// Reconstructs an approximate `f32` vector from the stored codes.
    ///
    /// Each code decodes to `code * (max - min) / 255 + min`; if the range is
    /// narrower than `f32::EPSILON`, every component decodes to `min`.
    #[must_use]
    pub fn to_f32(&self) -> Vec<f32> {
        let span = self.max - self.min;
        if span < f32::EPSILON {
            return vec![self.min; self.data.len()];
        }

        let step = span / 255.0;
        let mut restored = Vec::with_capacity(self.data.len());
        restored.extend(
            self.data
                .iter()
                .map(|&code| f32::from(code) * step + self.min),
        );
        restored
    }

    /// Number of dimensions (one code per dimension).
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    /// Payload size in bytes: one byte per code plus the two `f32` bounds.
    #[must_use]
    pub fn memory_size(&self) -> usize {
        // `min` and `max` together occupy 8 bytes.
        const BOUNDS_BYTES: usize = 2 * std::mem::size_of::<f32>();
        self.data.len() + BOUNDS_BYTES
    }
}
94
95impl QuantizationCodec for QuantizedVector {
96 fn to_bytes(&self) -> Vec<u8> {
97 let mut bytes = Vec::with_capacity(8 + self.data.len());
98 bytes.extend_from_slice(&self.min.to_le_bytes());
99 bytes.extend_from_slice(&self.max.to_le_bytes());
100 bytes.extend_from_slice(&self.data);
101 bytes
102 }
103
104 fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
105 if bytes.len() < 8 {
106 return Err(io::Error::new(
107 io::ErrorKind::InvalidData,
108 "Not enough bytes for QuantizedVector header",
109 ));
110 }
111
112 let min = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
113 let max = f32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
114 let data = bytes[8..].to_vec();
115
116 Ok(Self { data, min, max })
117 }
118}
119
/// Coefficients of the linear map from an 8-bit code back to `f32`:
/// `value = code * scale + offset`.
struct DequantParams {
    /// Width of one quantization step.
    scale: f32,
    /// Value that code `0` decodes to.
    offset: f32,
}
132
133fn dequant_params(quantized: &QuantizedVector) -> Option<DequantParams> {
138 let range = quantized.max - quantized.min;
139 if range < f32::EPSILON {
140 return None;
141 }
142 Some(DequantParams {
143 scale: range / 255.0,
144 offset: quantized.min,
145 })
146}
147
148#[must_use]
156pub fn dot_product_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
157 debug_assert_eq!(
158 query.len(),
159 quantized.data.len(),
160 "Dimension mismatch in dot_product_quantized"
161 );
162
163 let Some(params) = dequant_params(quantized) else {
164 return query.iter().sum::<f32>() * quantized.min;
166 };
167
168 query
170 .iter()
171 .zip(quantized.data.iter())
172 .map(|(&q, &v)| q * (f32::from(v) * params.scale + params.offset))
173 .sum()
174}
175
176#[must_use]
178pub fn euclidean_squared_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
179 debug_assert_eq!(
180 query.len(),
181 quantized.data.len(),
182 "Dimension mismatch in euclidean_squared_quantized"
183 );
184
185 let Some(params) = dequant_params(quantized) else {
186 let value = quantized.min;
187 return query.iter().map(|&q| (q - value).powi(2)).sum();
188 };
189
190 query
191 .iter()
192 .zip(quantized.data.iter())
193 .map(|(&q, &v)| {
194 let dequantized = f32::from(v) * params.scale + params.offset;
195 (q - dequantized).powi(2)
196 })
197 .sum()
198}
199
200#[must_use]
208pub fn cosine_similarity_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
209 cosine_from_dot(dot_product_quantized(query, quantized), query, quantized)
210}
211
212fn cosine_from_dot(dot: f32, query: &[f32], quantized: &QuantizedVector) -> f32 {
218 use crate::simd_native;
219
220 let query_norm = simd_native::norm_native(query);
221
222 let quantized_norm = quantized_vector_norm(quantized);
224
225 if query_norm < f32::EPSILON || quantized_norm < f32::EPSILON {
226 return 0.0;
227 }
228
229 dot / (query_norm * quantized_norm)
230}
231
232#[inline]
237fn quantized_vector_norm(quantized: &QuantizedVector) -> f32 {
238 let Some(params) = dequant_params(quantized) else {
239 let value = quantized.min;
240 #[allow(clippy::cast_precision_loss)]
241 return value.abs() * (quantized.data.len() as f32).sqrt();
242 };
243
244 let len = quantized.data.len();
245 let chunks = len / 4;
246 let remainder = len % 4;
247
248 let mut sum0: f32 = 0.0;
249 let mut sum1: f32 = 0.0;
250 let mut sum2: f32 = 0.0;
251 let mut sum3: f32 = 0.0;
252
253 for i in 0..chunks {
254 let base = i * 4;
255 let d0 = f32::from(quantized.data[base]) * params.scale + params.offset;
256 let d1 = f32::from(quantized.data[base + 1]) * params.scale + params.offset;
257 let d2 = f32::from(quantized.data[base + 2]) * params.scale + params.offset;
258 let d3 = f32::from(quantized.data[base + 3]) * params.scale + params.offset;
259 sum0 += d0 * d0;
260 sum1 += d1 * d1;
261 sum2 += d2 * d2;
262 sum3 += d3 * d3;
263 }
264
265 let base = chunks * 4;
266 for i in 0..remainder {
267 let d = f32::from(quantized.data[base + i]) * params.scale + params.offset;
268 sum0 += d * d;
269 }
270
271 (sum0 + sum1 + sum2 + sum3).sqrt()
272}
273
274#[must_use]
286pub fn dot_product_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
287 debug_assert_eq!(
288 query.len(),
289 quantized.data.len(),
290 "Dimension mismatch in dot_product_quantized_simd"
291 );
292
293 let Some(params) = dequant_params(quantized) else {
294 return query.iter().sum::<f32>() * quantized.min;
295 };
296
297 dot_product_dequant_unrolled_8(query, &quantized.data, params.scale, params.offset)
298}
299
/// Dot product of `query` with the codes in `data`, decoding each code as
/// `code * scale + offset` and processing eight elements per outer iteration.
///
/// Only the common prefix of `query` and `data` is used, so mismatched
/// lengths never index out of bounds; this matches the `zip`-based pairing in
/// `dot_product_quantized`. Products are added to a single accumulator in
/// index order.
#[inline]
fn dot_product_dequant_unrolled_8(query: &[f32], data: &[u8], scale: f32, offset: f32) -> f32 {
    const LANES: usize = 8;

    // Slice both inputs to the same length once, up front, so the loops below
    // need no per-element bounds checks.
    let len = query.len().min(data.len());
    let (query, data) = (&query[..len], &data[..len]);

    let query_chunks = query.chunks_exact(LANES);
    let data_chunks = data.chunks_exact(LANES);
    let (query_tail, data_tail) = (query_chunks.remainder(), data_chunks.remainder());

    let mut sum = 0.0f32;

    for (q_chunk, d_chunk) in query_chunks.zip(data_chunks) {
        for (&q, &code) in q_chunk.iter().zip(d_chunk) {
            sum += q * (f32::from(code) * scale + offset);
        }
    }

    // Fewer than `LANES` elements remain.
    for (&q, &code) in query_tail.iter().zip(data_tail) {
        sum += q * (f32::from(code) * scale + offset);
    }

    sum
}
325
326#[must_use]
328pub fn euclidean_squared_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
329 debug_assert_eq!(
330 query.len(),
331 quantized.data.len(),
332 "Dimension mismatch in euclidean_squared_quantized_simd"
333 );
334
335 let Some(params) = dequant_params(quantized) else {
336 let value = quantized.min;
337 return query.iter().map(|&q| (q - value).powi(2)).sum();
338 };
339
340 let len = query.len();
342 let chunks = len / 4;
343 let remainder = len % 4;
344 let mut sum = 0.0f32;
345
346 for i in 0..chunks {
347 let base = i * 4;
348 let d0 = f32::from(quantized.data[base]) * params.scale + params.offset;
349 let d1 = f32::from(quantized.data[base + 1]) * params.scale + params.offset;
350 let d2 = f32::from(quantized.data[base + 2]) * params.scale + params.offset;
351 let d3 = f32::from(quantized.data[base + 3]) * params.scale + params.offset;
352
353 let diff0 = query[base] - d0;
354 let diff1 = query[base + 1] - d1;
355 let diff2 = query[base + 2] - d2;
356 let diff3 = query[base + 3] - d3;
357
358 sum += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
359 }
360
361 let base = chunks * 4;
362 for i in 0..remainder {
363 let dequant = f32::from(quantized.data[base + i]) * params.scale + params.offset;
364 let diff = query[base + i] - dequant;
365 sum += diff * diff;
366 }
367
368 sum
369}
370
371#[must_use]
376pub fn cosine_similarity_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
377 cosine_from_dot(
378 dot_product_quantized_simd(query, quantized),
379 query,
380 quantized,
381 )
382}