// velesdb_core/quantization/scalar.rs
//! Scalar Quantization (SQ8) for memory-efficient vector storage.
//!
//! Implements 8-bit scalar quantization to reduce memory usage by 4x
//! while maintaining >95% recall accuracy. Includes both scalar and
//! SIMD-optimized distance functions.

use std::io;

/// A quantized vector using 8-bit scalar quantization.
///
/// Each f32 value is mapped to a u8 (0-255) using min/max scaling.
/// The original value can be reconstructed as: `value = (data[i] / 255.0) * (max - min) + min`
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// Quantized data (1 byte per dimension instead of 4).
    pub data: Vec<u8>,
    /// Minimum value in the original vector.
    pub min: f32,
    /// Maximum value in the original vector.
    pub max: f32,
}

impl QuantizedVector {
    /// Creates a new quantized vector from f32 data.
    ///
    /// # Arguments
    ///
    /// * `vector` - The original f32 vector to quantize
    ///
    /// # Panics
    ///
    /// Panics if the vector is empty.
    #[must_use]
    pub fn from_f32(vector: &[f32]) -> Self {
        assert!(!vector.is_empty(), "Cannot quantize empty vector");

        // Single pass over the input to find both extremes at once.
        let (min, max) = vector
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });

        let range = max - min;
        if range < f32::EPSILON {
            // Degenerate case: every component is identical; map all of them
            // to the middle of the u8 range (dequantization uses `min`).
            return Self {
                data: vec![128u8; vector.len()],
                min,
                max,
            };
        }

        let scale = 255.0 / range;
        // SAFETY: Value is clamped to [0.0, 255.0] before cast, guaranteeing it fits in u8.
        // cast_sign_loss is safe because clamped value is always non-negative.
        // cast_possible_truncation is safe because clamped value is always <= 255.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let data: Vec<u8> = vector
            .iter()
            .map(|&v| ((v - min) * scale).round().clamp(0.0, 255.0) as u8)
            .collect();

        Self { data, min, max }
    }

    /// Reconstructs the original f32 vector from quantized data.
    ///
    /// Note: This is a lossy operation. The reconstructed values are approximations.
    #[must_use]
    pub fn to_f32(&self) -> Vec<f32> {
        let range = self.max - self.min;
        if range < f32::EPSILON {
            // Every component collapsed to the same value during quantization.
            return vec![self.min; self.data.len()];
        }
        let step = range / 255.0;
        self.data
            .iter()
            .map(|&byte| self.min + f32::from(byte) * step)
            .collect()
    }

    /// Returns the dimension of the vector.
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    /// Returns the memory size in bytes.
    #[must_use]
    pub fn memory_size(&self) -> usize {
        // quantized payload + min (4 bytes) + max (4 bytes)
        self.data.len() + 8
    }

    /// Serializes the quantized vector to bytes.
    ///
    /// Layout: `min` (4 LE bytes) | `max` (4 LE bytes) | quantized data.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(self.data.len() + 8);
        out.extend_from_slice(&self.min.to_le_bytes());
        out.extend_from_slice(&self.max.to_le_bytes());
        out.extend_from_slice(&self.data);
        out
    }

    /// Deserializes a quantized vector from bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if the bytes are invalid.
    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        if bytes.len() < 8 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Not enough bytes for QuantizedVector header",
            ));
        }

        // Header is two little-endian f32s; the remainder is the payload.
        let mut word = [0u8; 4];
        word.copy_from_slice(&bytes[0..4]);
        let min = f32::from_le_bytes(word);
        word.copy_from_slice(&bytes[4..8]);
        let max = f32::from_le_bytes(word);

        Ok(Self {
            data: bytes[8..].to_vec(),
            min,
            max,
        })
    }
}
122
// =========================================================================
// Scalar distance functions
// =========================================================================

127/// Computes the approximate dot product between a query vector (f32) and a quantized vector.
128///
129/// This avoids full dequantization for better performance.
130#[must_use]
131pub fn dot_product_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
132    debug_assert_eq!(
133        query.len(),
134        quantized.data.len(),
135        "Dimension mismatch in dot_product_quantized"
136    );
137
138    let range = quantized.max - quantized.min;
139    if range < f32::EPSILON {
140        // All quantized values are the same
141        let value = quantized.min;
142        return query.iter().sum::<f32>() * value;
143    }
144
145    let scale = range / 255.0;
146    let offset = quantized.min;
147
148    // Compute dot product with on-the-fly dequantization
149    query
150        .iter()
151        .zip(quantized.data.iter())
152        .map(|(&q, &v)| q * (f32::from(v) * scale + offset))
153        .sum()
154}
156/// Computes the approximate squared Euclidean distance between a query (f32) and quantized vector.
157#[must_use]
158pub fn euclidean_squared_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
159    debug_assert_eq!(
160        query.len(),
161        quantized.data.len(),
162        "Dimension mismatch in euclidean_squared_quantized"
163    );
164
165    let range = quantized.max - quantized.min;
166    if range < f32::EPSILON {
167        // All quantized values are the same
168        let value = quantized.min;
169        return query.iter().map(|&q| (q - value).powi(2)).sum();
170    }
171
172    let scale = range / 255.0;
173    let offset = quantized.min;
174
175    query
176        .iter()
177        .zip(quantized.data.iter())
178        .map(|(&q, &v)| {
179            let dequantized = f32::from(v) * scale + offset;
180            (q - dequantized).powi(2)
181        })
182        .sum()
183}
185/// Computes approximate cosine similarity between a query (f32) and quantized vector.
186///
187/// Note: For best accuracy, the query should be normalized.
188#[must_use]
189pub fn cosine_similarity_quantized(query: &[f32], quantized: &QuantizedVector) -> f32 {
190    use crate::simd_native;
191
192    let dot = dot_product_quantized(query, quantized);
193
194    // Compute norms using direct SIMD dispatch
195    let query_norm = simd_native::norm_native(query);
196
197    // Dequantize to compute quantized vector norm (could be cached)
198    let reconstructed = quantized.to_f32();
199    let quantized_norm = simd_native::norm_native(&reconstructed);
200
201    if query_norm < f32::EPSILON || quantized_norm < f32::EPSILON {
202        return 0.0;
203    }
204
205    dot / (query_norm * quantized_norm)
206}

// =========================================================================
// SIMD-optimized distance functions for SQ8 quantized vectors
// =========================================================================

#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[allow(unused_imports)]
use std::arch::x86_64::*;

216/// SIMD-optimized dot product between f32 query and SQ8 quantized vector.
217///
218/// Uses AVX2 intrinsics on `x86_64` for ~2-3x speedup over scalar.
219/// Falls back to scalar on other architectures.
220#[must_use]
221pub fn dot_product_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
222    debug_assert_eq!(
223        query.len(),
224        quantized.data.len(),
225        "Dimension mismatch in dot_product_quantized_simd"
226    );
227
228    let range = quantized.max - quantized.min;
229    if range < f32::EPSILON {
230        let value = quantized.min;
231        return query.iter().sum::<f32>() * value;
232    }
233
234    let scale = range / 255.0;
235    let offset = quantized.min;
236
237    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
238    {
239        simd_dot_product_avx2(query, &quantized.data, scale, offset)
240    }
241
242    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
243    {
244        // Scalar fallback
245        query
246            .iter()
247            .zip(quantized.data.iter())
248            .map(|(&q, &v)| q * (f32::from(v) * scale + offset))
249            .sum()
250    }
251}
/// Dot product between `query` and the SQ8 bytes in `data`, dequantizing each
/// byte as `f32::from(b) * scale + offset` before multiplying.
///
/// NOTE(review): despite the name and the `target_feature = "avx2"` gate, this
/// body uses no AVX2 intrinsics — it is a plain scalar loop manually blocked in
/// groups of 8 (a shape the autovectorizer can often exploit, but nothing is
/// guaranteed). The public doc on `dot_product_quantized_simd` promises a
/// ~2-3x intrinsic speedup; either implement this with `_mm256_*` intrinsics
/// or soften that claim. TODO confirm intent with the original author.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
fn simd_dot_product_avx2(query: &[f32], data: &[u8], scale: f32, offset: f32) -> f32 {
    let len = query.len();
    let simd_len = len / 8; // number of full 8-element groups
    let remainder = len % 8; // trailing elements handled after the main loop

    let mut sum = 0.0f32;

    // Process 8 elements at a time
    for i in 0..simd_len {
        let base = i * 8;
        // Dequantize and compute dot product for 8 elements
        for j in 0..8 {
            let dequant = f32::from(data[base + j]) * scale + offset;
            sum += query[base + j] * dequant;
        }
    }

    // Handle remainder
    let base = simd_len * 8;
    for i in 0..remainder {
        let dequant = f32::from(data[base + i]) * scale + offset;
        sum += query[base + i] * dequant;
    }

    sum
}
282/// SIMD-optimized squared Euclidean distance between f32 query and SQ8 vector.
283#[must_use]
284pub fn euclidean_squared_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
285    debug_assert_eq!(
286        query.len(),
287        quantized.data.len(),
288        "Dimension mismatch in euclidean_squared_quantized_simd"
289    );
290
291    let range = quantized.max - quantized.min;
292    if range < f32::EPSILON {
293        let value = quantized.min;
294        return query.iter().map(|&q| (q - value).powi(2)).sum();
295    }
296
297    let scale = range / 255.0;
298    let offset = quantized.min;
299
300    // Optimized loop with manual unrolling
301    let len = query.len();
302    let chunks = len / 4;
303    let remainder = len % 4;
304    let mut sum = 0.0f32;
305
306    for i in 0..chunks {
307        let base = i * 4;
308        let d0 = f32::from(quantized.data[base]) * scale + offset;
309        let d1 = f32::from(quantized.data[base + 1]) * scale + offset;
310        let d2 = f32::from(quantized.data[base + 2]) * scale + offset;
311        let d3 = f32::from(quantized.data[base + 3]) * scale + offset;
312
313        let diff0 = query[base] - d0;
314        let diff1 = query[base + 1] - d1;
315        let diff2 = query[base + 2] - d2;
316        let diff3 = query[base + 3] - d3;
317
318        sum += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
319    }
320
321    let base = chunks * 4;
322    for i in 0..remainder {
323        let dequant = f32::from(quantized.data[base + i]) * scale + offset;
324        let diff = query[base + i] - dequant;
325        sum += diff * diff;
326    }
327
328    sum
329}
331/// SIMD-optimized cosine similarity between f32 query and SQ8 vector.
332///
333/// Caches the quantized vector norm for repeated queries against same vector.
334#[must_use]
335pub fn cosine_similarity_quantized_simd(query: &[f32], quantized: &QuantizedVector) -> f32 {
336    use crate::simd_native;
337
338    let dot = dot_product_quantized_simd(query, quantized);
339
340    // Compute query norm using direct SIMD dispatch
341    let query_norm = simd_native::norm_native(query);
342    let query_norm_sq = query_norm * query_norm;
343
344    // Compute quantized norm (could be cached in QuantizedVector)
345    let range = quantized.max - quantized.min;
346    let scale = if range < f32::EPSILON {
347        0.0
348    } else {
349        range / 255.0
350    };
351    let offset = quantized.min;
352
353    let quantized_norm_sq: f32 = quantized
354        .data
355        .iter()
356        .map(|&v| {
357            let dequant = f32::from(v) * scale + offset;
358            dequant * dequant
359        })
360        .sum();
361
362    let denom = (query_norm_sq * quantized_norm_sq).sqrt();
363    if denom < f32::EPSILON {
364        return 0.0;
365    }
366
367    dot / denom
368}