1use crate::multivector::types::WarpIndexConfig;
12use crate::Result;
13use serde::{Deserialize, Serialize};
14
/// Residual compressor for token embeddings.
///
/// An embedding is encoded as the id of its nearest centroid plus one `nbits`-bit code per
/// dimension for the residual (embedding minus centroid). The codes index per-dimension bucket
/// tables learned during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResidualCodec {
    /// Centroids stored row-major: centroid `id` occupies `[id * dim, (id + 1) * dim)`.
    centroids: Vec<f32>,
    /// Number of centroids searched by nearest-centroid lookup.
    num_centroids: usize,
    /// Dimensionality of every embedding and centroid.
    dim: usize,
    /// Quantization thresholds: `2^nbits - 1` consecutive values per dimension.
    bucket_cutoffs: Vec<f32>,
    /// Per-dimension bucket values (mean training residual of each bucket) used when scoring:
    /// `2^nbits` consecutive values per dimension.
    bucket_weights: Vec<f32>,
    /// Bits per residual code; `train` only accepts 2 or 4.
    nbits: u8,
}
48
49impl ResidualCodec {
50 pub fn train(
64 embeddings: &[f32],
65 dim: usize,
66 num_centroids: usize,
67 nbits: u8,
68 iterations: usize,
69 ) -> Result<Self> {
70 if nbits != 2 && nbits != 4 {
71 return Err(crate::Error::InvalidInput("nbits must be 2 or 4".to_string()));
72 }
73
74 let n = embeddings.len() / dim;
75 if n < num_centroids {
76 return Err(crate::Error::InvalidInput(format!(
77 "Insufficient training data: {n} samples for {num_centroids} centroids"
78 )));
79 }
80
81 let centroids = Self::kmeans_clustering(embeddings, dim, num_centroids, iterations);
83
84 let residuals = Self::compute_all_residuals(embeddings, dim, ¢roids, num_centroids);
86
87 let (bucket_cutoffs, bucket_weights) =
89 Self::learn_quantization_params(&residuals, dim, nbits);
90
91 Ok(Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits })
92 }
93
94 #[must_use]
96 pub fn with_params(
97 centroids: Vec<f32>,
98 num_centroids: usize,
99 dim: usize,
100 bucket_cutoffs: Vec<f32>,
101 bucket_weights: Vec<f32>,
102 nbits: u8,
103 ) -> Self {
104 Self { centroids, num_centroids, dim, bucket_cutoffs, bucket_weights, nbits }
105 }
106
107 #[must_use]
109 pub fn num_centroids(&self) -> usize {
110 self.num_centroids
111 }
112
113 #[must_use]
115 pub fn dim(&self) -> usize {
116 self.dim
117 }
118
119 #[must_use]
121 pub fn nbits(&self) -> u8 {
122 self.nbits
123 }
124
125 #[must_use]
127 pub fn packed_size(&self) -> usize {
128 (self.dim * self.nbits as usize + 7) / 8
129 }
130
131 #[must_use]
133 pub fn centroid(&self, id: usize) -> &[f32] {
134 let start = id * self.dim;
135 &self.centroids[start..start + self.dim]
136 }
137
138 #[must_use]
140 pub fn centroids(&self) -> &[f32] {
141 &self.centroids
142 }
143
144 #[must_use]
146 pub fn find_nearest_centroid(&self, embedding: &[f32]) -> usize {
147 let mut best_id = 0;
148 let mut best_dist = f32::MAX;
149
150 for c in 0..self.num_centroids {
151 let centroid = self.centroid(c);
152 let dist = Self::squared_distance(embedding, centroid);
153 if dist < best_dist {
154 best_dist = dist;
155 best_id = c;
156 }
157 }
158
159 best_id
160 }
161
162 #[must_use]
164 pub fn compress(&self, embedding: &[f32]) -> (usize, Vec<u8>) {
165 let centroid_id = self.find_nearest_centroid(embedding);
167 let centroid = self.centroid(centroid_id);
168
169 let residual: Vec<f32> =
171 embedding.iter().zip(centroid.iter()).map(|(e, c)| e - c).collect();
172
173 let codes = self.quantize_residual(&residual);
175
176 let packed = self.pack_codes(&codes);
178
179 (centroid_id, packed)
180 }
181
182 #[must_use]
193 pub fn decompress_score(
194 &self,
195 query_token: &[f32],
196 centroid_id: usize,
197 centroid_score: f32,
198 packed_residual: &[u8],
199 ) -> f32 {
200 let _ = centroid_id; let codes = self.unpack_codes(packed_residual);
204
205 let num_buckets = 1usize << self.nbits;
207 let residual_score: f32 = codes
208 .iter()
209 .enumerate()
210 .map(|(d, &code)| {
211 let weight_idx = d * num_buckets + code as usize;
212 query_token[d] * self.bucket_weights[weight_idx]
213 })
214 .sum();
215
216 centroid_score + residual_score
217 }
218
219 #[must_use]
221 pub fn centroid_score(&self, query_token: &[f32], centroid_id: usize) -> f32 {
222 let centroid = self.centroid(centroid_id);
223 Self::dot_product(query_token, centroid)
224 }
225
226 fn quantize_residual(&self, residual: &[f32]) -> Vec<u8> {
228 let num_buckets = 1usize << self.nbits;
229
230 residual
231 .iter()
232 .enumerate()
233 .map(|(d, &value)| {
234 let cutoff_start = d * (num_buckets - 1);
236 let cutoffs = &self.bucket_cutoffs[cutoff_start..cutoff_start + num_buckets - 1];
237
238 cutoffs.iter().position(|&c| value < c).unwrap_or(num_buckets - 1) as u8
240 })
241 .collect()
242 }
243
244 fn pack_codes(&self, codes: &[u8]) -> Vec<u8> {
246 match self.nbits {
247 2 => {
248 codes
250 .chunks(4)
251 .map(|chunk| {
252 let mut byte = 0u8;
253 for (i, &code) in chunk.iter().enumerate() {
254 byte |= (code & 0x03) << (i * 2);
255 }
256 byte
257 })
258 .collect()
259 }
260 4 => {
261 codes
263 .chunks(2)
264 .map(|chunk| {
265 let low = chunk.first().copied().unwrap_or(0) & 0x0F;
266 let high = chunk.get(1).copied().unwrap_or(0) & 0x0F;
267 low | (high << 4)
268 })
269 .collect()
270 }
271 _ => panic!("Unsupported nbits: {}", self.nbits),
272 }
273 }
274
275 fn unpack_codes(&self, packed: &[u8]) -> Vec<u8> {
277 match self.nbits {
278 2 => packed
279 .iter()
280 .flat_map(|&byte| (0..4).map(move |i| (byte >> (i * 2)) & 0x03))
281 .take(self.dim)
282 .collect(),
283 4 => packed
284 .iter()
285 .flat_map(|&byte| vec![byte & 0x0F, (byte >> 4) & 0x0F])
286 .take(self.dim)
287 .collect(),
288 _ => panic!("Unsupported nbits: {}", self.nbits),
289 }
290 }
291
292 fn kmeans_clustering(embeddings: &[f32], dim: usize, k: usize, iterations: usize) -> Vec<f32> {
296 let n = embeddings.len() / dim;
297
298 let mut centroids = Self::kmeans_plus_plus_init(embeddings, dim, k);
300 let mut assignments = vec![0usize; n];
301
302 for _ in 0..iterations {
303 for i in 0..n {
305 let point = &embeddings[i * dim..(i + 1) * dim];
306 let mut best_dist = f32::MAX;
307 let mut best_c = 0;
308
309 for c in 0..k {
310 let centroid = ¢roids[c * dim..(c + 1) * dim];
311 let dist = Self::squared_distance(point, centroid);
312 if dist < best_dist {
313 best_dist = dist;
314 best_c = c;
315 }
316 }
317 assignments[i] = best_c;
318 }
319
320 let mut new_centroids = vec![0.0f32; k * dim];
322 let mut counts = vec![0usize; k];
323
324 for i in 0..n {
325 let c = assignments[i];
326 counts[c] += 1;
327 let point = &embeddings[i * dim..(i + 1) * dim];
328 for d in 0..dim {
329 new_centroids[c * dim + d] += point[d];
330 }
331 }
332
333 for c in 0..k {
334 if counts[c] > 0 {
335 for d in 0..dim {
336 new_centroids[c * dim + d] /= counts[c] as f32;
337 }
338 } else {
339 for d in 0..dim {
341 new_centroids[c * dim + d] = centroids[c * dim + d];
342 }
343 }
344 }
345
346 centroids = new_centroids;
347 }
348
349 centroids
350 }
351
352 fn kmeans_plus_plus_init(embeddings: &[f32], dim: usize, k: usize) -> Vec<f32> {
354 let n = embeddings.len() / dim;
355 let mut centroids = Vec::with_capacity(k * dim);
356 let mut rng_state = 42u64; let first_idx = Self::simple_random(&mut rng_state, n);
360 centroids.extend_from_slice(&embeddings[first_idx * dim..(first_idx + 1) * dim]);
361
362 let mut distances = vec![f32::MAX; n];
363
364 for _ in 1..k {
365 let num_centroids = centroids.len() / dim;
366
367 for i in 0..n {
369 let point = &embeddings[i * dim..(i + 1) * dim];
370 let centroid = ¢roids[(num_centroids - 1) * dim..num_centroids * dim];
371 let dist = Self::squared_distance(point, centroid);
372 distances[i] = distances[i].min(dist);
373 }
374
375 let total: f32 = distances.iter().sum();
377 if total <= 0.0 {
378 let idx = Self::simple_random(&mut rng_state, n);
380 centroids.extend_from_slice(&embeddings[idx * dim..(idx + 1) * dim]);
381 continue;
382 }
383
384 let threshold = Self::simple_random_f32(&mut rng_state) * total;
385 let mut cumsum = 0.0f32;
386 let mut chosen = 0;
387
388 for (i, &d) in distances.iter().enumerate() {
389 cumsum += d;
390 if cumsum >= threshold {
391 chosen = i;
392 break;
393 }
394 }
395
396 centroids.extend_from_slice(&embeddings[chosen * dim..(chosen + 1) * dim]);
397 }
398
399 centroids
400 }
401
402 fn compute_all_residuals(
404 embeddings: &[f32],
405 dim: usize,
406 centroids: &[f32],
407 num_centroids: usize,
408 ) -> Vec<f32> {
409 let n = embeddings.len() / dim;
410 let mut residuals = Vec::with_capacity(n * dim);
411
412 for i in 0..n {
413 let point = &embeddings[i * dim..(i + 1) * dim];
414
415 let mut best_c = 0;
417 let mut best_dist = f32::MAX;
418 for c in 0..num_centroids {
419 let centroid = ¢roids[c * dim..(c + 1) * dim];
420 let dist = Self::squared_distance(point, centroid);
421 if dist < best_dist {
422 best_dist = dist;
423 best_c = c;
424 }
425 }
426
427 let centroid = ¢roids[best_c * dim..(best_c + 1) * dim];
429 for d in 0..dim {
430 residuals.push(point[d] - centroid[d]);
431 }
432 }
433
434 residuals
435 }
436
437 fn learn_quantization_params(residuals: &[f32], dim: usize, nbits: u8) -> (Vec<f32>, Vec<f32>) {
439 let num_buckets = 1usize << nbits;
440 let n = residuals.len() / dim;
441
442 let mut cutoffs = Vec::with_capacity(dim * (num_buckets - 1));
443 let mut weights = Vec::with_capacity(dim * num_buckets);
444
445 for d in 0..dim {
446 let mut values: Vec<f32> = (0..n).map(|i| residuals[i * dim + d]).collect();
448 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
449
450 for b in 1..num_buckets {
452 let quantile_idx = (b * n) / num_buckets;
453 cutoffs.push(values[quantile_idx.min(n - 1)]);
454 }
455
456 for b in 0..num_buckets {
458 let start = (b * n) / num_buckets;
459 let end = ((b + 1) * n) / num_buckets;
460 let end = end.max(start + 1).min(n);
461
462 let sum: f32 = values[start..end].iter().sum();
463 let mean = sum / (end - start) as f32;
464 weights.push(mean);
465 }
466 }
467
468 (cutoffs, weights)
469 }
470
471 fn squared_distance(a: &[f32], b: &[f32]) -> f32 {
474 a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
475 }
476
477 fn dot_product(a: &[f32], b: &[f32]) -> f32 {
478 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
479 }
480
481 fn simple_random(state: &mut u64, max: usize) -> usize {
482 *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
483 ((*state >> 33) as usize) % max
484 }
485
486 fn simple_random_f32(state: &mut u64) -> f32 {
487 *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
488 ((*state >> 33) as f32) / (u32::MAX as f32)
489 }
490}
491
/// Trains a [`ResidualCodec`] using the parameters stored in a [`WarpIndexConfig`].
pub struct ResidualCodecBuilder {
    /// Source of `token_dim`, `num_centroids`, `nbits` and `kmeans_iterations` for training.
    config: WarpIndexConfig,
}
496
497impl ResidualCodecBuilder {
498 #[must_use]
500 pub fn new(config: WarpIndexConfig) -> Self {
501 Self { config }
502 }
503
504 pub fn train(&self, embeddings: &[f32]) -> Result<ResidualCodec> {
506 ResidualCodec::train(
507 embeddings,
508 self.config.token_dim,
509 self.config.num_centroids,
510 self.config.nbits,
511 self.config.kmeans_iterations,
512 )
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Deterministic pseudo-random embeddings spread uniformly over `[-1, 1)`.
    ///
    /// Uses the same LCG multiplier as the codec's internal RNG. The top 24 bits of each state
    /// are mapped exactly into `[0, 1)` before rescaling. The previous `(state >> 33) / u32::MAX`
    /// mapping only reached 0.5, so every value came out negative.
    fn generate_seeded_embeddings(n: usize, dim: usize, seed: u64) -> Vec<f32> {
        let mut rng_state = seed;
        (0..n * dim)
            .map(|_| {
                rng_state = rng_state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
                ((rng_state >> 40) as f32 / (1u64 << 24) as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Fixed-seed embeddings shared by the example-based tests.
    fn generate_test_embeddings(n: usize, dim: usize) -> Vec<f32> {
        generate_seeded_embeddings(n, dim, 12345)
    }

    #[test]
    fn test_codec_train_2bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
        assert_eq!(codec.nbits(), 2);
    }

    #[test]
    fn test_codec_train_4bit() {
        let embeddings = generate_test_embeddings(1000, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 4, 5).unwrap();

        assert_eq!(codec.nbits(), 4);
    }

    #[test]
    fn test_codec_train_insufficient_data() {
        let embeddings = generate_test_embeddings(5, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 2, 5);

        assert!(result.is_err());
    }

    #[test]
    fn test_codec_train_invalid_nbits() {
        let embeddings = generate_test_embeddings(100, 32);
        let result = ResidualCodec::train(&embeddings, 32, 16, 3, 5);

        assert!(result.is_err());
    }

    #[test]
    fn test_codec_compress() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let test_vec = &embeddings[0..32];
        let (centroid_id, packed) = codec.compress(test_vec);

        assert!(centroid_id < 16);
        assert_eq!(packed.len(), codec.packed_size());
    }

    #[test]
    fn test_codec_packed_size_2bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 2, 5).unwrap();

        // 128 dims * 2 bits = 32 bytes.
        assert_eq!(codec.packed_size(), 32);
    }

    #[test]
    fn test_codec_packed_size_4bit() {
        let embeddings = generate_test_embeddings(500, 128);
        let codec = ResidualCodec::train(&embeddings, 128, 16, 4, 5).unwrap();

        // 128 dims * 4 bits = 64 bytes.
        assert_eq!(codec.packed_size(), 64);
    }

    #[test]
    fn test_pack_unpack_2bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 2, 5).unwrap();

        let codes: Vec<u8> = vec![0, 1, 2, 3, 0, 1, 2, 3];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);

        assert_eq!(codes, unpacked);
    }

    #[test]
    fn test_pack_unpack_4bit() {
        let embeddings = generate_test_embeddings(500, 8);
        let codec = ResidualCodec::train(&embeddings, 8, 16, 4, 5).unwrap();

        let codes: Vec<u8> = vec![0, 5, 10, 15, 1, 6, 11, 14];
        let packed = codec.pack_codes(&codes);
        let unpacked = codec.unpack_codes(&packed);

        assert_eq!(codes, unpacked);
    }

    #[test]
    fn test_decompress_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let query = &embeddings[0..32];
        let doc = &embeddings[32..64];

        let (centroid_id, packed) = codec.compress(doc);
        let centroid_score = codec.centroid_score(query, centroid_id);
        let approx_score = codec.decompress_score(query, centroid_id, centroid_score, &packed);

        let exact_score: f32 = query.iter().zip(doc.iter()).map(|(q, d)| q * d).sum();

        // 2-bit residuals are lossy; only require the approximation to be in the right range.
        let error = (approx_score - exact_score).abs();
        assert!(
            error < exact_score.abs() * 0.5 + 1.0,
            "Error too large: approx={}, exact={}, error={}",
            approx_score,
            exact_score,
            error
        );
    }

    #[test]
    fn test_centroid_score() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        let query = &embeddings[0..32];
        let centroid = codec.centroid(0);

        let expected: f32 = query.iter().zip(centroid.iter()).map(|(q, c)| q * c).sum();
        let actual = codec.centroid_score(query, 0);

        assert!((expected - actual).abs() < 1e-6);
    }

    #[test]
    fn test_find_nearest_centroid() {
        let embeddings = generate_test_embeddings(500, 32);
        let codec = ResidualCodec::train(&embeddings, 32, 16, 2, 5).unwrap();

        // A centroid is at distance zero from itself.
        let centroid_0 = codec.centroid(0).to_vec();
        let nearest = codec.find_nearest_centroid(&centroid_0);
        assert_eq!(nearest, 0);
    }

    #[test]
    fn test_codec_builder() {
        let config = WarpIndexConfig::new(2, 16, 32).with_kmeans_iterations(5);
        let builder = ResidualCodecBuilder::new(config);

        let embeddings = generate_test_embeddings(500, 32);
        let codec = builder.train(&embeddings).unwrap();

        assert_eq!(codec.num_centroids(), 16);
        assert_eq!(codec.dim(), 32);
    }

    #[test]
    fn test_codec_serialization() {
        let embeddings = generate_test_embeddings(500, 16);
        let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 5).unwrap();

        let json = serde_json::to_string(&codec).unwrap();
        let deserialized: ResidualCodec = serde_json::from_str(&json).unwrap();

        assert_eq!(codec.num_centroids(), deserialized.num_centroids());
        assert_eq!(codec.dim(), deserialized.dim());
        assert_eq!(codec.nbits(), deserialized.nbits());
        assert_eq!(codec.centroids().len(), deserialized.centroids().len());
        assert_eq!(codec.packed_size(), deserialized.packed_size());
    }

    proptest! {
        #[test]
        fn prop_compress_produces_valid_centroid_id(seed in 0u64..1000) {
            let embeddings = generate_seeded_embeddings(200, 16, seed);

            let codec = ResidualCodec::train(&embeddings, 16, 8, 2, 3).unwrap();
            let (centroid_id, packed) = codec.compress(&embeddings[0..16]);

            prop_assert!(centroid_id < 8);
            prop_assert_eq!(packed.len(), codec.packed_size());
        }

        #[test]
        fn prop_packed_size_matches_config(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 8usize..64
        ) {
            let embeddings = generate_test_embeddings(100, dim);

            // 100 samples always suffice for 8 centroids, so training must succeed; a failure
            // here is a bug rather than a case to skip.
            let codec = ResidualCodec::train(&embeddings, dim, 8, nbits, 3).unwrap();
            let expected_size = (dim * usize::from(nbits) + 7) / 8;
            prop_assert_eq!(codec.packed_size(), expected_size);
        }
    }
}