ruvector_core/
quantization.rs1use crate::error::Result;
4use serde::{Deserialize, Serialize};
5
/// Common interface for compressed (quantized) representations of `f32` vectors.
///
/// Implementors must be `Send + Sync`.
pub trait QuantizedVector: Send + Sync {
    /// Builds the quantized representation from a slice of `f32` components.
    fn quantize(vector: &[f32]) -> Self;

    /// Returns a distance between `self` and `other`, computed on the quantized
    /// data. The exact metric is defined by each implementor.
    fn distance(&self, other: &Self) -> f32;

    /// Expands the quantized data back into an `f32` vector.
    fn reconstruct(&self) -> Vec<f32>;
}
17
/// Vector stored as one `u8` code per component plus the affine parameters
/// needed to map the codes back to `f32` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantized {
    /// One code per component of the source vector.
    pub data: Vec<u8>,
    /// Offset added back during reconstruction (`min + code * scale`).
    pub min: f32,
    /// Step size between two adjacent codes.
    pub scale: f32,
}
28
29impl QuantizedVector for ScalarQuantized {
30 fn quantize(vector: &[f32]) -> Self {
31 let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
32 let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
33 let scale = (max - min) / 255.0;
34
35 let data = vector
36 .iter()
37 .map(|&v| ((v - min) / scale).round() as u8)
38 .collect();
39
40 Self { data, min, scale }
41 }
42
43 fn distance(&self, other: &Self) -> f32 {
44 self.data
46 .iter()
47 .zip(&other.data)
48 .map(|(&a, &b)| {
49 let diff = a as i16 - b as i16;
50 (diff * diff) as f32
51 })
52 .sum::<f32>()
53 .sqrt()
54 * self.scale.max(other.scale)
55 }
56
57 fn reconstruct(&self) -> Vec<f32> {
58 self.data
59 .iter()
60 .map(|&v| self.min + (v as f32) * self.scale)
61 .collect()
62 }
63}
64
/// Product-quantization state: a set of per-subspace codebooks plus a code buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantized {
    /// Code bytes; `train` leaves this empty.
    pub codes: Vec<u8>,
    /// One codebook per subspace; each codebook is a list of centroid vectors.
    pub codebooks: Vec<Vec<Vec<f32>>>,
}
73
74impl ProductQuantized {
75 pub fn train(
77 vectors: &[Vec<f32>],
78 num_subspaces: usize,
79 codebook_size: usize,
80 iterations: usize,
81 ) -> Result<Self> {
82 if vectors.is_empty() {
83 return Err(crate::error::RuvectorError::InvalidInput(
84 "Cannot train on empty vector set".into(),
85 ));
86 }
87 if vectors[0].is_empty() {
88 return Err(crate::error::RuvectorError::InvalidInput(
89 "Cannot train on vectors with zero dimensions".into(),
90 ));
91 }
92 if codebook_size > 256 {
93 return Err(crate::error::RuvectorError::InvalidParameter(
94 format!("Codebook size {} exceeds u8 maximum of 256", codebook_size),
95 ));
96 }
97 let dimensions = vectors[0].len();
98 let subspace_dim = dimensions / num_subspaces;
99
100 let mut codebooks = Vec::with_capacity(num_subspaces);
101
102 for subspace_idx in 0..num_subspaces {
104 let start = subspace_idx * subspace_dim;
105 let end = start + subspace_dim;
106
107 let subspace_vectors: Vec<Vec<f32>> =
109 vectors.iter().map(|v| v[start..end].to_vec()).collect();
110
111 let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
113 codebooks.push(codebook);
114 }
115
116 Ok(Self {
117 codes: vec![],
118 codebooks,
119 })
120 }
121
122 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
124 let num_subspaces = self.codebooks.len();
125 let subspace_dim = vector.len() / num_subspaces;
126
127 let mut codes = Vec::with_capacity(num_subspaces);
128
129 for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
130 let start = subspace_idx * subspace_dim;
131 let end = start + subspace_dim;
132 let subvector = &vector[start..end];
133
134 let code = codebook
136 .iter()
137 .enumerate()
138 .min_by(|(_, a), (_, b)| {
139 let dist_a = euclidean_squared(subvector, a);
140 let dist_b = euclidean_squared(subvector, b);
141 dist_a.partial_cmp(&dist_b).unwrap()
142 })
143 .map(|(idx, _)| idx as u8)
144 .unwrap_or(0);
145
146 codes.push(code);
147 }
148
149 codes
150 }
151}
152
/// Vector reduced to one bit per component, packed eight to a byte.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantized {
    /// Packed bits: component `i` lives in bit `i % 8` of byte `i / 8`.
    pub bits: Vec<u8>,
    /// Number of components in the source vector.
    pub dimensions: usize,
}
161
162impl QuantizedVector for BinaryQuantized {
163 fn quantize(vector: &[f32]) -> Self {
164 let dimensions = vector.len();
165 let num_bytes = (dimensions + 7) / 8;
166 let mut bits = vec![0u8; num_bytes];
167
168 for (i, &v) in vector.iter().enumerate() {
169 if v > 0.0 {
170 let byte_idx = i / 8;
171 let bit_idx = i % 8;
172 bits[byte_idx] |= 1 << bit_idx;
173 }
174 }
175
176 Self { bits, dimensions }
177 }
178
179 fn distance(&self, other: &Self) -> f32 {
180 let mut distance = 0u32;
182
183 for (&a, &b) in self.bits.iter().zip(&other.bits) {
184 distance += (a ^ b).count_ones();
185 }
186
187 distance as f32
188 }
189
190 fn reconstruct(&self) -> Vec<f32> {
191 let mut result = Vec::with_capacity(self.dimensions);
192
193 for i in 0..self.dimensions {
194 let byte_idx = i / 8;
195 let bit_idx = i % 8;
196 let bit = (self.bits[byte_idx] >> bit_idx) & 1;
197 result.push(if bit == 1 { 1.0 } else { -1.0 });
198 }
199
200 result
201 }
202}
203
/// Sum of the squared component-wise differences between `a` and `b`.
///
/// Components are paired by position; when the slices differ in length, the
/// extra trailing components of the longer one are ignored.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b.iter()).map(|(lhs, rhs)| {
        let delta = lhs - rhs;
        delta * delta
    });

    squared_diffs.sum()
}
215
216fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
217 use rand::seq::SliceRandom;
218 use rand::thread_rng;
219
220 let mut rng = thread_rng();
221
222 let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
224
225 for _ in 0..iterations {
226 let mut assignments = vec![Vec::new(); k];
228
229 for vector in vectors {
230 let nearest = centroids
231 .iter()
232 .enumerate()
233 .min_by(|(_, a), (_, b)| {
234 let dist_a = euclidean_squared(vector, a);
235 let dist_b = euclidean_squared(vector, b);
236 dist_a.partial_cmp(&dist_b).unwrap()
237 })
238 .map(|(idx, _)| idx)
239 .unwrap_or(0);
240
241 assignments[nearest].push(vector.clone());
242 }
243
244 for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
246 if !assigned.is_empty() {
247 let dim = centroid.len();
248 *centroid = vec![0.0; dim];
249
250 for vector in assigned {
251 for (i, &v) in vector.iter().enumerate() {
252 centroid[i] += v;
253 }
254 }
255
256 let count = assigned.len() as f32;
257 for v in centroid.iter_mut() {
258 *v /= count;
259 }
260 }
261 }
262 }
263
264 centroids
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// A scalar-quantized vector reconstructs to within 0.1 of every original component.
    #[test]
    fn test_scalar_quantization() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let restored = ScalarQuantized::quantize(&original).reconstruct();

        let all_close = original
            .iter()
            .zip(restored.iter())
            .all(|(before, after)| (before - after).abs() < 0.1);
        assert!(all_close);
    }

    /// Five components are recorded as five dimensions packed into a single byte.
    #[test]
    fn test_binary_quantization() {
        let BinaryQuantized { bits, dimensions } =
            BinaryQuantized::quantize(&[1.0, -1.0, 2.0, -2.0, 0.5]);

        assert_eq!(dimensions, 5);
        assert_eq!(bits.len(), 1);
    }

    /// Two vectors whose signs differ in two positions are at distance 2.
    #[test]
    fn test_binary_distance() {
        let all_positive = BinaryQuantized::quantize(&[1.0; 4]);
        let half_negative = BinaryQuantized::quantize(&[1.0, 1.0, -1.0, -1.0]);

        assert_eq!(all_positive.distance(&half_negative), 2.0);
    }
}