Skip to main content

sql_rs/vector/
quantization.rs

1use super::{Distance, DistanceMetric};
2use serde::{Deserialize, Serialize};
3
/// A scalar-quantized representation of an `f32` vector.
///
/// Each scalar is stored as an 8- or 16-bit code in `data`, together with the
/// per-dimension `[min, max]` range needed to map the code back to an `f32`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Original dimension of the vector (number of scalars that were quantized)
    pub dimension: usize,
    /// Number of bits per scalar (8, 16); any other value makes
    /// `dequantize` return an empty vector
    pub bits_per_scalar: u8,
    /// Quantized data: one byte per scalar for 8-bit codes, two bytes per
    /// scalar (most significant byte first) for 16-bit codes
    pub data: Vec<u8>,
    /// Minimum value for each dimension (for dequantization); indexed by
    /// dimension, so it must hold at least `dimension` entries
    pub min_values: Vec<f32>,
    /// Maximum value for each dimension (for dequantization); indexed by
    /// dimension, so it must hold at least `dimension` entries
    pub max_values: Vec<f32>,
}
17
18impl QuantizedVector {
19    /// Quantize a f32 vector to 8-bit integers
20    pub fn quantize_8bit(vector: &[f32]) -> Self {
21        let dimension = vector.len();
22        let mut min_values = Vec::with_capacity(dimension);
23        let mut max_values = Vec::with_capacity(dimension);
24
25        // Find min/max for each dimension
26        for i in 0..dimension {
27            let min = vector[i];
28            let max = vector[i];
29
30            // In a real implementation, we'd sample from multiple vectors
31            // For now, just use the single vector's values
32            min_values.push(min);
33            max_values.push(max);
34        }
35
36        // Quantize to 8-bit
37        let mut data = Vec::with_capacity(dimension);
38        for i in 0..dimension {
39            let min = min_values[i];
40            let max = max_values[i];
41
42            let normalized = if max - min > 0.0 {
43                (vector[i] - min) / (max - min)
44            } else {
45                0.5
46            };
47
48            let quantized = (normalized * 255.0).round() as u8;
49            data.push(quantized);
50        }
51
52        Self {
53            dimension,
54            bits_per_scalar: 8,
55            data,
56            min_values,
57            max_values,
58        }
59    }
60
61    /// Quantize a f32 vector to 16-bit integers
62    pub fn quantize_16bit(vector: &[f32]) -> Self {
63        let dimension = vector.len();
64        let mut min_values = Vec::with_capacity(dimension);
65        let mut max_values = Vec::with_capacity(dimension);
66
67        // Find min/max for each dimension
68        for i in 0..dimension {
69            let min = vector[i];
70            let max = vector[i];
71
72            min_values.push(min);
73            max_values.push(max);
74        }
75
76        // Quantize to 16-bit (stored as 2 bytes per value)
77        let mut data = Vec::with_capacity(dimension * 2);
78        for i in 0..dimension {
79            let min = min_values[i];
80            let max = max_values[i];
81
82            let normalized = if max - min > 0.0 {
83                (vector[i] - min) / (max - min)
84            } else {
85                0.5
86            };
87
88            let quantized = (normalized * 65535.0).round() as u16;
89            data.push((quantized >> 8) as u8);
90            data.push((quantized & 0xFF) as u8);
91        }
92
93        Self {
94            dimension,
95            bits_per_scalar: 16,
96            data,
97            min_values,
98            max_values,
99        }
100    }
101
102    /// Dequantize back to f32 vector
103    pub fn dequantize(&self) -> Vec<f32> {
104        let mut result = Vec::with_capacity(self.dimension);
105
106        if self.bits_per_scalar == 8 {
107            for i in 0..self.dimension {
108                let quantized = self.data[i] as f32;
109                let min = self.min_values[i];
110                let max = self.max_values[i];
111
112                let normalized = quantized / 255.0;
113                let value = min + normalized * (max - min);
114                result.push(value);
115            }
116        } else if self.bits_per_scalar == 16 {
117            for i in 0..self.dimension {
118                let quantized =
119                    ((self.data[i * 2] as u16) << 8 | self.data[i * 2 + 1] as u16) as f32;
120                let min = self.min_values[i];
121                let max = self.max_values[i];
122
123                let normalized = quantized / 65535.0;
124                let value = min + normalized * (max - min);
125                result.push(value);
126            }
127        }
128
129        result
130    }
131
132    /// Calculate distance between two quantized vectors
133    pub fn distance(&self, other: &QuantizedVector, metric: DistanceMetric) -> f32 {
134        if self.dimension != other.dimension {
135            return f32::INFINITY;
136        }
137
138        // For simplicity, dequantize both and calculate distance
139        // In a production implementation, we'd calculate distance directly in quantized space
140        let v1 = self.dequantize();
141        let v2 = other.dequantize();
142
143        v1.distance(&v2, metric)
144    }
145
146    /// Calculate memory usage in bytes
147    pub fn memory_usage(&self) -> usize {
148        std::mem::size_of::<Self>()
149            + self.data.len()
150            + self.min_values.len() * std::mem::size_of::<f32>()
151            + self.max_values.len() * std::mem::size_of::<f32>()
152    }
153
154    /// Calculate compression ratio compared to original f32 vector
155    pub fn compression_ratio(&self) -> f32 {
156        let original_size = self.dimension * std::mem::size_of::<f32>();
157        let compressed_size = self.data.len()
158            + self.min_values.len() * std::mem::size_of::<f32>()
159            + self.max_values.len() * std::mem::size_of::<f32>();
160        original_size as f32 / compressed_size as f32
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when every paired element of `expected` and `actual`
    /// differs by strictly less than `tolerance`.
    fn all_close(expected: &[f32], actual: &[f32], tolerance: f32) -> bool {
        expected
            .iter()
            .zip(actual)
            .all(|(want, got)| (want - got).abs() < tolerance)
    }

    /// 8-bit quantization keeps the dimension and round-trips within 0.01.
    #[test]
    fn test_quantize_8bit() {
        let input = [1.0_f32, -1.0, 0.5, 0.0];
        let packed = QuantizedVector::quantize_8bit(&input);

        assert_eq!(packed.dimension, 4);
        assert_eq!(packed.bits_per_scalar, 8);

        let restored = packed.dequantize();
        assert_eq!(restored.len(), 4);

        // Dequantized values must stay close to the originals
        assert!(all_close(&input, &restored, 0.01));
    }

    /// 16-bit quantization stores two bytes per value and round-trips within 0.001.
    #[test]
    fn test_quantize_16bit() {
        let input = [1.0_f32, -1.0, 0.5, 0.0];
        let packed = QuantizedVector::quantize_16bit(&input);

        assert_eq!(packed.dimension, 4);
        assert_eq!(packed.bits_per_scalar, 16);
        assert_eq!(packed.data.len(), 8); // 2 bytes per value

        let restored = packed.dequantize();
        assert_eq!(restored.len(), 4);

        // Dequantized values must stay close to the originals
        assert!(all_close(&input, &restored, 0.001));
    }

    /// The 8-bit form must compress better than the 16-bit form.
    #[test]
    fn test_compression_ratio() {
        let input = vec![1.0; 100];
        let ratio_8bit = QuantizedVector::quantize_8bit(&input).compression_ratio();
        let ratio_16bit = QuantizedVector::quantize_16bit(&input).compression_ratio();

        // Fewer bits per scalar means a higher ratio
        assert!(ratio_8bit > ratio_16bit);

        // The per-dimension min/max overhead means the ratio is not
        // necessarily above 1.0 with this simple scheme, so only require
        // that both ratios are positive.
        assert!(ratio_8bit > 0.0);
        assert!(ratio_16bit > 0.0);
    }

    /// Two orthogonal unit vectors are about sqrt(2) apart after quantization.
    #[test]
    fn test_quantized_distance() {
        let first = QuantizedVector::quantize_8bit(&[1.0, 0.0, 0.0]);
        let second = QuantizedVector::quantize_8bit(&[0.0, 1.0, 0.0]);

        let measured = first.distance(&second, DistanceMetric::Euclidean);
        assert!((measured - 1.414).abs() < 0.1);
    }
}