ruvector_core/
quantization.rs1use crate::error::Result;
4use serde::{Deserialize, Serialize};
5
/// Common interface for compressed representations of an `f32` vector.
///
/// Implementors must be `Send + Sync`.
pub trait QuantizedVector: Send + Sync {
    /// Builds the quantized representation of `vector`.
    fn quantize(vector: &[f32]) -> Self;

    /// Returns a distance between `self` and `other`; the metric is chosen by
    /// each implementor.
    fn distance(&self, other: &Self) -> f32;

    /// Decodes the quantized data back into an approximate `f32` vector.
    fn reconstruct(&self) -> Vec<f32>;
}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ScalarQuantized {
21 pub data: Vec<u8>,
23 pub min: f32,
25 pub scale: f32,
27}
28
29impl QuantizedVector for ScalarQuantized {
30 fn quantize(vector: &[f32]) -> Self {
31 let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
32 let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
33 let scale = (max - min) / 255.0;
34
35 let data = vector
36 .iter()
37 .map(|&v| ((v - min) / scale).round() as u8)
38 .collect();
39
40 Self { data, min, scale }
41 }
42
43 fn distance(&self, other: &Self) -> f32 {
44 self.data
46 .iter()
47 .zip(&other.data)
48 .map(|(&a, &b)| {
49 let diff = a as i16 - b as i16;
50 (diff * diff) as f32
51 })
52 .sum::<f32>()
53 .sqrt()
54 * self.scale.max(other.scale)
55 }
56
57 fn reconstruct(&self) -> Vec<f32> {
58 self.data
59 .iter()
60 .map(|&v| self.min + (v as f32) * self.scale)
61 .collect()
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct ProductQuantized {
68 pub codes: Vec<u8>,
70 pub codebooks: Vec<Vec<Vec<f32>>>,
72}
73
74impl ProductQuantized {
75 pub fn train(
77 vectors: &[Vec<f32>],
78 num_subspaces: usize,
79 codebook_size: usize,
80 iterations: usize,
81 ) -> Result<Self> {
82 let dimensions = vectors[0].len();
83 let subspace_dim = dimensions / num_subspaces;
84
85 let mut codebooks = Vec::with_capacity(num_subspaces);
86
87 for subspace_idx in 0..num_subspaces {
89 let start = subspace_idx * subspace_dim;
90 let end = start + subspace_dim;
91
92 let subspace_vectors: Vec<Vec<f32>> =
94 vectors.iter().map(|v| v[start..end].to_vec()).collect();
95
96 let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
98 codebooks.push(codebook);
99 }
100
101 Ok(Self {
102 codes: vec![],
103 codebooks,
104 })
105 }
106
107 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
109 let num_subspaces = self.codebooks.len();
110 let subspace_dim = vector.len() / num_subspaces;
111
112 let mut codes = Vec::with_capacity(num_subspaces);
113
114 for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
115 let start = subspace_idx * subspace_dim;
116 let end = start + subspace_dim;
117 let subvector = &vector[start..end];
118
119 let code = codebook
121 .iter()
122 .enumerate()
123 .min_by(|(_, a), (_, b)| {
124 let dist_a = euclidean_squared(subvector, a);
125 let dist_b = euclidean_squared(subvector, b);
126 dist_a.partial_cmp(&dist_b).unwrap()
127 })
128 .map(|(idx, _)| idx as u8)
129 .unwrap_or(0);
130
131 codes.push(code);
132 }
133
134 codes
135 }
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct BinaryQuantized {
141 pub bits: Vec<u8>,
143 pub dimensions: usize,
145}
146
147impl QuantizedVector for BinaryQuantized {
148 fn quantize(vector: &[f32]) -> Self {
149 let dimensions = vector.len();
150 let num_bytes = (dimensions + 7) / 8;
151 let mut bits = vec![0u8; num_bytes];
152
153 for (i, &v) in vector.iter().enumerate() {
154 if v > 0.0 {
155 let byte_idx = i / 8;
156 let bit_idx = i % 8;
157 bits[byte_idx] |= 1 << bit_idx;
158 }
159 }
160
161 Self { bits, dimensions }
162 }
163
164 fn distance(&self, other: &Self) -> f32 {
165 let mut distance = 0u32;
167
168 for (&a, &b) in self.bits.iter().zip(&other.bits) {
169 distance += (a ^ b).count_ones();
170 }
171
172 distance as f32
173 }
174
175 fn reconstruct(&self) -> Vec<f32> {
176 let mut result = Vec::with_capacity(self.dimensions);
177
178 for i in 0..self.dimensions {
179 let byte_idx = i / 8;
180 let bit_idx = i % 8;
181 let bit = (self.bits[byte_idx] >> bit_idx) & 1;
182 result.push(if bit == 1 { 1.0 } else { -1.0 });
183 }
184
185 result
186 }
187}
188
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the overlapping prefix is compared: components past the end of the
/// shorter slice are ignored. The square root is skipped, which keeps
/// nearest-centroid comparisons in this module correct while doing less work.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let delta = *x - *y;
        total += delta * delta;
    }
    total
}
200
201fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
202 use rand::seq::SliceRandom;
203 use rand::thread_rng;
204
205 let mut rng = thread_rng();
206
207 let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
209
210 for _ in 0..iterations {
211 let mut assignments = vec![Vec::new(); k];
213
214 for vector in vectors {
215 let nearest = centroids
216 .iter()
217 .enumerate()
218 .min_by(|(_, a), (_, b)| {
219 let dist_a = euclidean_squared(vector, a);
220 let dist_b = euclidean_squared(vector, b);
221 dist_a.partial_cmp(&dist_b).unwrap()
222 })
223 .map(|(idx, _)| idx)
224 .unwrap_or(0);
225
226 assignments[nearest].push(vector.clone());
227 }
228
229 for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
231 if !assigned.is_empty() {
232 let dim = centroid.len();
233 *centroid = vec![0.0; dim];
234
235 for vector in assigned {
236 for (i, &v) in vector.iter().enumerate() {
237 centroid[i] += v;
238 }
239 }
240
241 let count = assigned.len() as f32;
242 for v in centroid.iter_mut() {
243 *v /= count;
244 }
245 }
246 }
247 }
248
249 centroids
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantization() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = ScalarQuantized::quantize(&vector);
        let reconstructed = quantized.reconstruct();

        assert_eq!(reconstructed.len(), vector.len());
        for (orig, recon) in vector.iter().zip(&reconstructed) {
            // Grid step is (5 - 1) / 255, so rounding error stays far below 0.1.
            assert!((orig - recon).abs() < 0.1);
        }
    }

    #[test]
    fn test_scalar_quantization_constant_vector() {
        let quantized = ScalarQuantized::quantize(&[3.0, 3.0, 3.0]);
        assert_eq!(quantized.data, vec![0, 0, 0]);
        assert_eq!(quantized.reconstruct(), vec![3.0, 3.0, 3.0]);
    }

    #[test]
    fn test_binary_quantization() {
        let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5];
        let quantized = BinaryQuantized::quantize(&vector);

        assert_eq!(quantized.dimensions, 5);
        assert_eq!(quantized.bits.len(), 1);
        // Positive components at indices 0, 2 and 4 set the matching low bits.
        assert_eq!(quantized.bits[0], 0b0001_0101);
    }

    #[test]
    fn test_binary_reconstruct() {
        // Zero is not positive, so it decodes to -1.0.
        let quantized = BinaryQuantized::quantize(&[1.0, -1.0, 0.0, 2.5]);
        assert_eq!(quantized.reconstruct(), vec![1.0, -1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_binary_distance() {
        let v1 = vec![1.0, 1.0, 1.0, 1.0];
        let v2 = vec![1.0, 1.0, -1.0, -1.0];

        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);

        let dist = q1.distance(&q2);
        assert_eq!(dist, 2.0);
    }
}