1use crate::error::Result;
4use serde::{Deserialize, Serialize};
5
/// A compressed representation of a float vector.
///
/// Implementations trade accuracy for memory: `quantize` compresses a raw
/// `f32` vector, `distance` compares two compressed vectors directly
/// (approximately), and `reconstruct` decompresses back to floats (lossy).
pub trait QuantizedVector: Send + Sync {
    /// Compresses `vector` into this quantized representation.
    fn quantize(vector: &[f32]) -> Self;

    /// Approximate distance between two quantized vectors.
    fn distance(&self, other: &Self) -> f32;

    /// Lossy decompression back to an `f32` vector.
    fn reconstruct(&self) -> Vec<f32>;
}
17
/// Scalar (per-vector) 8-bit quantization: each component is mapped linearly
/// from the vector's own `[min, max]` range onto a `u8` code in `0..=255`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantized {
    /// One 8-bit code per original component.
    pub data: Vec<u8>,
    /// Minimum value of the original vector; code 0 maps back to this.
    pub min: f32,
    /// Size of one quantization step: `(max - min) / 255`, or `1.0` when the
    /// vector is (near-)constant.
    pub scale: f32,
}
28
29impl QuantizedVector for ScalarQuantized {
30 fn quantize(vector: &[f32]) -> Self {
31 let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
32 let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
33
34 let scale = if (max - min).abs() < f32::EPSILON {
36 1.0 } else {
38 (max - min) / 255.0
39 };
40
41 let data = vector
42 .iter()
43 .map(|&v| ((v - min) / scale).round().clamp(0.0, 255.0) as u8)
44 .collect();
45
46 Self { data, min, scale }
47 }
48
49 fn distance(&self, other: &Self) -> f32 {
50 let avg_scale = (self.scale + other.scale) / 2.0;
58
59 self.data
60 .iter()
61 .zip(&other.data)
62 .map(|(&a, &b)| {
63 let diff = a as i32 - b as i32;
64 (diff * diff) as f32
65 })
66 .sum::<f32>()
67 .sqrt()
68 * avg_scale
69 }
70
71 fn reconstruct(&self) -> Vec<f32> {
72 self.data
73 .iter()
74 .map(|&v| self.min + (v as f32) * self.scale)
75 .collect()
76 }
77}
78
/// Product quantization: a vector is split into contiguous subspaces and each
/// subvector is replaced by the index of its nearest codebook centroid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantized {
    /// One codebook index per subspace (left empty by `train`; filled by
    /// callers using `encode`).
    pub codes: Vec<u8>,
    /// `codebooks[s][c]` is centroid `c` of subspace `s`.
    pub codebooks: Vec<Vec<Vec<f32>>>,
}
87
88impl ProductQuantized {
89 pub fn train(
91 vectors: &[Vec<f32>],
92 num_subspaces: usize,
93 codebook_size: usize,
94 iterations: usize,
95 ) -> Result<Self> {
96 if vectors.is_empty() {
97 return Err(crate::error::RuvectorError::InvalidInput(
98 "Cannot train on empty vector set".into(),
99 ));
100 }
101 if vectors[0].is_empty() {
102 return Err(crate::error::RuvectorError::InvalidInput(
103 "Cannot train on vectors with zero dimensions".into(),
104 ));
105 }
106 if codebook_size > 256 {
107 return Err(crate::error::RuvectorError::InvalidParameter(format!(
108 "Codebook size {} exceeds u8 maximum of 256",
109 codebook_size
110 )));
111 }
112 let dimensions = vectors[0].len();
113 let subspace_dim = dimensions / num_subspaces;
114
115 let mut codebooks = Vec::with_capacity(num_subspaces);
116
117 for subspace_idx in 0..num_subspaces {
119 let start = subspace_idx * subspace_dim;
120 let end = start + subspace_dim;
121
122 let subspace_vectors: Vec<Vec<f32>> =
124 vectors.iter().map(|v| v[start..end].to_vec()).collect();
125
126 let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
128 codebooks.push(codebook);
129 }
130
131 Ok(Self {
132 codes: vec![],
133 codebooks,
134 })
135 }
136
137 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
139 let num_subspaces = self.codebooks.len();
140 let subspace_dim = vector.len() / num_subspaces;
141
142 let mut codes = Vec::with_capacity(num_subspaces);
143
144 for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
145 let start = subspace_idx * subspace_dim;
146 let end = start + subspace_dim;
147 let subvector = &vector[start..end];
148
149 let code = codebook
151 .iter()
152 .enumerate()
153 .min_by(|(_, a), (_, b)| {
154 let dist_a = euclidean_squared(subvector, a);
155 let dist_b = euclidean_squared(subvector, b);
156 dist_a.partial_cmp(&dist_b).unwrap()
157 })
158 .map(|(idx, _)| idx as u8)
159 .unwrap_or(0);
160
161 codes.push(code);
162 }
163
164 codes
165 }
166}
167
/// One-bit-per-component quantization: a bit is set iff the component is > 0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantized {
    /// Packed sign bits, 8 components per byte, least-significant bit first.
    pub bits: Vec<u8>,
    /// Number of original components (needed because the last byte is padded).
    pub dimensions: usize,
}
176
177impl QuantizedVector for BinaryQuantized {
178 fn quantize(vector: &[f32]) -> Self {
179 let dimensions = vector.len();
180 let num_bytes = (dimensions + 7) / 8;
181 let mut bits = vec![0u8; num_bytes];
182
183 for (i, &v) in vector.iter().enumerate() {
184 if v > 0.0 {
185 let byte_idx = i / 8;
186 let bit_idx = i % 8;
187 bits[byte_idx] |= 1 << bit_idx;
188 }
189 }
190
191 Self { bits, dimensions }
192 }
193
194 fn distance(&self, other: &Self) -> f32 {
195 let mut distance = 0u32;
197
198 for (&a, &b) in self.bits.iter().zip(&other.bits) {
199 distance += (a ^ b).count_ones();
200 }
201
202 distance as f32
203 }
204
205 fn reconstruct(&self) -> Vec<f32> {
206 let mut result = Vec::with_capacity(self.dimensions);
207
208 for i in 0..self.dimensions {
209 let byte_idx = i / 8;
210 let bit_idx = i % 8;
211 let bit = (self.bits[byte_idx] >> bit_idx) & 1;
212 result.push(if bit == 1 { 1.0 } else { -1.0 });
213 }
214
215 result
216 }
217}
218
/// Squared Euclidean distance over the overlapping prefix of two slices.
///
/// Length mismatches are not checked; `zip` simply stops at the shorter slice.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        let delta = x - y;
        total += delta * delta;
    }
    total
}
230
231fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
232 use rand::seq::SliceRandom;
233 use rand::thread_rng;
234
235 let mut rng = thread_rng();
236
237 let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
239
240 for _ in 0..iterations {
241 let mut assignments = vec![Vec::new(); k];
243
244 for vector in vectors {
245 let nearest = centroids
246 .iter()
247 .enumerate()
248 .min_by(|(_, a), (_, b)| {
249 let dist_a = euclidean_squared(vector, a);
250 let dist_b = euclidean_squared(vector, b);
251 dist_a.partial_cmp(&dist_b).unwrap()
252 })
253 .map(|(idx, _)| idx)
254 .unwrap_or(0);
255
256 assignments[nearest].push(vector.clone());
257 }
258
259 for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
261 if !assigned.is_empty() {
262 let dim = centroid.len();
263 *centroid = vec![0.0; dim];
264
265 for vector in assigned {
266 for (i, &v) in vector.iter().enumerate() {
267 centroid[i] += v;
268 }
269 }
270
271 let count = assigned.len() as f32;
272 for v in centroid.iter_mut() {
273 *v /= count;
274 }
275 }
276 }
277 }
278
279 centroids
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar quantization should reproduce values within a small fraction of
    /// the vector's range.
    #[test]
    fn test_scalar_quantization() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let restored = ScalarQuantized::quantize(&input).reconstruct();

        for (before, after) in input.iter().zip(&restored) {
            assert!((before - after).abs() < 0.1);
        }
    }

    /// Five dimensions must pack into a single byte of sign bits.
    #[test]
    fn test_binary_quantization() {
        let input = vec![1.0, -1.0, 2.0, -2.0, 0.5];
        let packed = BinaryQuantized::quantize(&input);

        assert_eq!(packed.dimensions, 5);
        assert_eq!(packed.bits.len(), 1);
    }

    /// Two flipped signs yield a Hamming distance of exactly 2.
    #[test]
    fn test_binary_distance() {
        let lhs = BinaryQuantized::quantize(&[1.0, 1.0, 1.0, 1.0]);
        let rhs = BinaryQuantized::quantize(&[1.0, 1.0, -1.0, -1.0]);

        assert_eq!(lhs.distance(&rhs), 2.0);
    }

    /// Quantize-then-reconstruct stays within two quantization steps for a
    /// variety of value ranges.
    #[test]
    fn test_scalar_quantization_roundtrip() {
        let cases = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![-10.0, -5.0, 0.0, 5.0, 10.0],
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
            vec![100.0, 200.0, 300.0, 400.0, 500.0],
        ];

        for case in &cases {
            let restored = ScalarQuantized::quantize(case).reconstruct();
            assert_eq!(case.len(), restored.len());

            // Tolerance is two quantization steps of this vector's range
            // (hoisted out of the inner loop; it is loop-invariant).
            let hi = case.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let lo = case.iter().copied().fold(f32::INFINITY, f32::min);
            let max_error = (hi - lo) / 255.0 * 2.0;

            for (before, after) in case.iter().zip(restored.iter()) {
                assert!(
                    (before - after).abs() < max_error,
                    "Roundtrip error too large: orig={}, recon={}, error={}",
                    before,
                    after,
                    (before - after).abs()
                );
            }
        }
    }

    /// d(a, b) must equal d(b, a) for scalar-quantized vectors.
    #[test]
    fn test_scalar_distance_symmetry() {
        let lhs = ScalarQuantized::quantize(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let rhs = ScalarQuantized::quantize(&[2.0, 3.0, 4.0, 5.0, 6.0]);

        let forward = lhs.distance(&rhs);
        let backward = rhs.distance(&lhs);

        assert!(
            (forward - backward).abs() < 0.01,
            "Distance is not symmetric: d(a,b)={}, d(b,a)={}",
            forward,
            backward
        );
    }

    /// Symmetry must also hold when the two vectors have very different
    /// quantization scales.
    #[test]
    fn test_scalar_distance_different_scales() {
        let lhs = ScalarQuantized::quantize(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let rhs = ScalarQuantized::quantize(&[10.0, 20.0, 30.0, 40.0, 50.0]);

        let forward = lhs.distance(&rhs);
        let backward = rhs.distance(&lhs);

        assert!(
            (forward - backward).abs() < 0.01,
            "Distance with different scales not symmetric: d(a,b)={}, d(b,a)={}",
            forward,
            backward
        );
    }

    #[test]
    fn test_scalar_quantization_edge_cases() {
        // A constant vector has zero range; the scale falls back to 1.0 and
        // every component reconstructs to the shared value.
        let constant = vec![5.0, 5.0, 5.0, 5.0];
        let restored = ScalarQuantized::quantize(&constant).reconstruct();
        for (before, after) in constant.iter().zip(restored.iter()) {
            assert!((before - after).abs() < 0.01);
        }

        // Very large magnitudes should still roundtrip without panicking.
        let extreme = vec![f32::MIN / 1e10, 0.0, f32::MAX / 1e10];
        let restored = ScalarQuantized::quantize(&extreme).reconstruct();
        assert_eq!(extreme.len(), restored.len());
    }

    /// Hamming distance is symmetric by construction; verify anyway.
    #[test]
    fn test_binary_distance_symmetry() {
        let lhs = BinaryQuantized::quantize(&[1.0, -1.0, 1.0, -1.0]);
        let rhs = BinaryQuantized::quantize(&[1.0, 1.0, -1.0, -1.0]);

        let forward = lhs.distance(&rhs);
        let backward = rhs.distance(&lhs);

        assert_eq!(
            forward, backward,
            "Binary distance not symmetric: d(a,b)={}, d(b,a)={}",
            forward, backward
        );
    }
}