1use alloc::boxed::Box;
7use alloc::vec;
8use alloc::vec::Vec;
9
10use crate::binary;
11use crate::product::ProductQuantizer;
12use crate::rabitq::RabitqQuantizer;
13use crate::scalar::ScalarQuantizer;
14use crate::sketch::CountMinSketch;
15use crate::traits::Quantizer;
16
/// Type tag (header byte 0) of a scalar-quantizer segment.
const QUANT_TYPE_SCALAR: u8 = 0x00;
/// Type tag (header byte 0) of a product-quantizer segment.
const QUANT_TYPE_PRODUCT: u8 = 0x01;
/// Type tag (header byte 0) of a binary-quantizer segment.
const QUANT_TYPE_BINARY: u8 = 0x02;
/// Type tag (header byte 0) of a RaBitQ segment. Tag value 3 is not assigned
/// anywhere in this file.
const QUANT_TYPE_RABITQ: u8 = 0x04;

/// Layout version stored in header byte 4 of RaBitQ segments. The decoder
/// answers any other value with `CodecError::UnsupportedVersion`.
const RABITQ_VERSION: u8 = 0x01;
31
/// Failures reported by the segment decoders in this module.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CodecError {
    /// The input ends before the bytes the decoder needs.
    TooShort,
    /// Header byte 0 holds a type tag no decoder is registered for; the
    /// offending tag is carried along.
    UnknownQuantType(u8),
    /// The segment declares a layout version the decoder does not accept;
    /// the offending version is carried along.
    UnsupportedVersion(u8),
    /// A header field holds a value the decoder rejects as inconsistent.
    InvalidField,
}

impl core::fmt::Display for CodecError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::TooShort => f.write_str("input data too short"),
            Self::UnknownQuantType(t) => write!(f, "unknown quant_type: {}", t),
            Self::UnsupportedVersion(v) => write!(f, "unsupported quant_seg version: {}", v),
            Self::InvalidField => f.write_str("invalid quant_seg header field"),
        }
    }
}

// Implementing the standard error trait lets callers box this type as
// `dyn Error`, attach it as a `source`, or propagate it with `?` into generic
// error containers. `core::error::Error` keeps the impl usable without `std`.
impl core::error::Error for CodecError {}
55
56pub fn encode_quant_seg(quantizer: &dyn Quantizer) -> Vec<u8> {
64 let any: &dyn core::any::Any = quantizer;
67 if let Some(sq) = any.downcast_ref::<ScalarQuantizer>() {
68 encode_scalar_quantizer(sq)
69 } else if let Some(pq) = any.downcast_ref::<ProductQuantizer>() {
70 encode_product_quantizer(pq)
71 } else if let Some(rq) = any.downcast_ref::<RabitqQuantizer>() {
72 encode_rabitq_quantizer(rq)
73 } else if quantizer.tier() as u8 == 2 {
74 encode_binary_quant_seg(quantizer.dim() as u16)
76 } else {
77 panic!("unknown quantizer type")
78 }
79}
80
81pub fn decode_quant_seg(data: &[u8]) -> Result<Box<dyn Quantizer>, CodecError> {
83 if data.len() < 64 {
84 return Err(CodecError::TooShort);
85 }
86
87 let quant_type = data[0];
88 let _tier = data[1];
89 let dim = u16::from_le_bytes([data[2], data[3]]) as usize;
90 let body = &data[64..];
91
92 match quant_type {
93 QUANT_TYPE_SCALAR => Ok(Box::new(decode_scalar(body, dim)?)),
94 QUANT_TYPE_PRODUCT => Ok(Box::new(decode_product(body, dim)?)),
95 QUANT_TYPE_BINARY => Ok(Box::new(BinaryQuantizerWrapper { dim })),
96 QUANT_TYPE_RABITQ => Ok(Box::new(decode_rabitq(data, body, dim)?)),
97 _ => Err(CodecError::UnknownQuantType(quant_type)),
98 }
99}
100
101pub fn encode_scalar_quantizer(sq: &ScalarQuantizer) -> Vec<u8> {
107 let dim = sq.dim as u16;
108 let mut buf = vec![0u8; 64];
109 buf[0] = QUANT_TYPE_SCALAR;
110 buf[1] = 0; buf[2..4].copy_from_slice(&dim.to_le_bytes());
112
113 for &v in &sq.min_vals {
115 buf.extend_from_slice(&v.to_le_bytes());
116 }
117 for &v in &sq.max_vals {
118 buf.extend_from_slice(&v.to_le_bytes());
119 }
120 buf
121}
122
123fn decode_scalar(body: &[u8], dim: usize) -> Result<ScalarQuantizer, CodecError> {
124 let float_bytes = dim * 4;
125 if body.len() < float_bytes * 2 {
126 return Err(CodecError::TooShort);
127 }
128
129 let mut min_vals = Vec::with_capacity(dim);
130 let mut max_vals = Vec::with_capacity(dim);
131
132 for d in 0..dim {
133 let offset = d * 4;
134 let v = f32::from_le_bytes([
135 body[offset],
136 body[offset + 1],
137 body[offset + 2],
138 body[offset + 3],
139 ]);
140 min_vals.push(v);
141 }
142 for d in 0..dim {
143 let offset = (dim + d) * 4;
144 let v = f32::from_le_bytes([
145 body[offset],
146 body[offset + 1],
147 body[offset + 2],
148 body[offset + 3],
149 ]);
150 max_vals.push(v);
151 }
152
153 Ok(ScalarQuantizer {
154 min_vals,
155 max_vals,
156 dim,
157 })
158}
159
160pub fn encode_product_quantizer(pq: &ProductQuantizer) -> Vec<u8> {
166 let dim = (pq.m * pq.sub_dim) as u16;
167 let mut buf = vec![0u8; 64];
168 buf[0] = QUANT_TYPE_PRODUCT;
169 buf[1] = 1; buf[2..4].copy_from_slice(&dim.to_le_bytes());
171
172 buf.extend_from_slice(&(pq.m as u16).to_le_bytes());
175 buf.extend_from_slice(&(pq.k as u16).to_le_bytes());
176 buf.extend_from_slice(&(pq.sub_dim as u16).to_le_bytes());
177
178 for sub_book in &pq.codebooks {
180 for centroid in sub_book {
181 for &val in centroid {
182 buf.extend_from_slice(&val.to_le_bytes());
183 }
184 }
185 }
186
187 buf
188}
189
190fn decode_product(body: &[u8], _dim: usize) -> Result<ProductQuantizer, CodecError> {
191 if body.len() < 6 {
192 return Err(CodecError::TooShort);
193 }
194
195 let m = u16::from_le_bytes([body[0], body[1]]) as usize;
196 let k = u16::from_le_bytes([body[2], body[3]]) as usize;
197 let sub_dim = u16::from_le_bytes([body[4], body[5]]) as usize;
198
199 let codebook_bytes = (m as u64)
203 .checked_mul(k as u64)
204 .and_then(|v| v.checked_mul(sub_dim as u64))
205 .and_then(|v| v.checked_mul(4))
206 .ok_or(CodecError::InvalidField)?;
207 let expected = codebook_bytes
208 .checked_add(6)
209 .ok_or(CodecError::InvalidField)?;
210 if (body.len() as u64) < expected {
211 return Err(CodecError::TooShort);
212 }
213
214 let mut codebooks = Vec::with_capacity(m);
215 let mut offset = 6;
216 for _ in 0..m {
217 let mut sub_book = Vec::with_capacity(k);
218 for _ in 0..k {
219 let mut centroid = Vec::with_capacity(sub_dim);
220 for _ in 0..sub_dim {
221 let v = f32::from_le_bytes([
222 body[offset],
223 body[offset + 1],
224 body[offset + 2],
225 body[offset + 3],
226 ]);
227 centroid.push(v);
228 offset += 4;
229 }
230 sub_book.push(centroid);
231 }
232 codebooks.push(sub_book);
233 }
234
235 Ok(ProductQuantizer {
236 m,
237 k,
238 sub_dim,
239 codebooks,
240 })
241}
242
243fn encode_binary_quant_seg(dim: u16) -> Vec<u8> {
248 let mut buf = vec![0u8; 64];
249 buf[0] = QUANT_TYPE_BINARY;
250 buf[1] = 2; buf[2..4].copy_from_slice(&dim.to_le_bytes());
252 buf
254}
255
256struct BinaryQuantizerWrapper {
258 dim: usize,
259}
260
261impl Quantizer for BinaryQuantizerWrapper {
262 fn encode(&self, vector: &[f32]) -> Vec<u8> {
263 binary::encode_binary(vector)
264 }
265
266 fn decode(&self, codes: &[u8]) -> Vec<f32> {
267 binary::decode_binary(codes, self.dim)
268 }
269
270 fn tier(&self) -> crate::tier::TemperatureTier {
271 crate::tier::TemperatureTier::Cold
272 }
273
274 fn dim(&self) -> usize {
275 self.dim
276 }
277}
278
279pub fn encode_rabitq_quantizer(rq: &RabitqQuantizer) -> Vec<u8> {
294 let mut buf = vec![0u8; 64];
295 buf[0] = QUANT_TYPE_RABITQ;
296 buf[1] = 2; buf[2..4].copy_from_slice(&(rq.dim as u16).to_le_bytes());
298 buf[4] = RABITQ_VERSION;
299 buf[5] = rq.rounds;
300 buf[8..16].copy_from_slice(&rq.seed.to_le_bytes());
302 buf[16..20].copy_from_slice(&(rq.padded_dim as u32).to_le_bytes());
303
304 for &v in &rq.centroid {
305 buf.extend_from_slice(&v.to_le_bytes());
306 }
307 buf
308}
309
310fn decode_rabitq(data: &[u8], body: &[u8], dim: usize) -> Result<RabitqQuantizer, CodecError> {
315 let version = data[4];
317 if version != RABITQ_VERSION {
318 return Err(CodecError::UnsupportedVersion(version));
319 }
320 let rounds = data[5];
321 let seed = u64::from_le_bytes(data[8..16].try_into().expect("len checked"));
322 let padded_dim = u32::from_le_bytes(data[16..20].try_into().expect("len checked")) as usize;
323
324 if dim == 0 || rounds == 0 {
325 return Err(CodecError::InvalidField);
326 }
327 if padded_dim != dim.max(1).next_power_of_two() {
330 return Err(CodecError::InvalidField);
331 }
332
333 let centroid_bytes = dim.checked_mul(4).ok_or(CodecError::InvalidField)?;
334 if body.len() < centroid_bytes {
335 return Err(CodecError::TooShort);
336 }
337 let mut centroid = Vec::with_capacity(dim);
338 for d in 0..dim {
339 let offset = d * 4;
340 centroid.push(f32::from_le_bytes(
341 body[offset..offset + 4].try_into().expect("len checked"),
342 ));
343 }
344
345 Ok(RabitqQuantizer::with_centroid(dim, centroid, seed, rounds))
346}
347
348pub fn encode_sketch_seg(sketch: &CountMinSketch) -> Vec<u8> {
360 let mut buf = vec![0u8; 64]; buf[0..4].copy_from_slice(&(sketch.width as u32).to_le_bytes());
363 buf[4..8].copy_from_slice(&(sketch.depth as u32).to_le_bytes());
364 buf[8..16].copy_from_slice(&sketch.total_accesses.to_le_bytes());
365
366 for row in &sketch.counters {
368 buf.extend_from_slice(row);
369 }
370
371 buf
372}
373
374pub fn decode_sketch_seg(data: &[u8]) -> Result<CountMinSketch, CodecError> {
381 if data.len() < 64 {
382 return Err(CodecError::TooShort);
383 }
384
385 let width = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
386 let depth = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
387 let total_accesses = u64::from_le_bytes([
388 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
389 ]);
390
391 let body = &data[64..];
392
393 if width == 0 && depth != 0 {
397 return Err(CodecError::InvalidField);
398 }
399 let expected = (width as u64)
402 .checked_mul(depth as u64)
403 .ok_or(CodecError::InvalidField)?;
404 if (body.len() as u64) < expected {
405 return Err(CodecError::TooShort);
406 }
407
408 let mut counters = Vec::with_capacity(depth);
410 for row in 0..depth {
411 let start = row * width;
412 counters.push(body[start..start + width].to_vec());
413 }
414
415 Ok(CountMinSketch {
416 counters,
417 width,
418 depth,
419 total_accesses,
420 })
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tier::TemperatureTier;

    /// Four-dimensional scalar quantizer shared by the scalar tests.
    fn sample_scalar() -> ScalarQuantizer {
        ScalarQuantizer {
            min_vals: vec![-1.0, -2.0, -0.5, 0.0],
            max_vals: vec![1.0, 2.0, 0.5, 1.0],
            dim: 4,
        }
    }

    /// Zeroed 64-byte quantizer header with the type tag and dimension set.
    fn quant_header(quant_type: u8, dim: u16) -> Vec<u8> {
        let mut header = vec![0u8; 64];
        header[0] = quant_type;
        header[2..4].copy_from_slice(&dim.to_le_bytes());
        header
    }

    /// Zeroed 64-byte sketch header with width and depth set.
    fn sketch_header(width: u32, depth: u32) -> Vec<u8> {
        let mut header = vec![0u8; 64];
        header[0..4].copy_from_slice(&width.to_le_bytes());
        header[4..8].copy_from_slice(&depth.to_le_bytes());
        header
    }

    /// Appends each value as a little-endian `u16`.
    fn push_u16s(buf: &mut Vec<u8>, values: &[u16]) {
        for value in values {
            buf.extend_from_slice(&value.to_le_bytes());
        }
    }

    /// Decodes `bytes` as a quantizer segment and returns the error; fails
    /// the test if decoding succeeds.
    fn quant_err(bytes: &[u8]) -> CodecError {
        decode_quant_seg(bytes)
            .err()
            .expect("decode_quant_seg unexpectedly succeeded")
    }

    /// Decodes `bytes` as a sketch segment and returns the error; fails the
    /// test if decoding succeeds.
    fn sketch_err(bytes: &[u8]) -> CodecError {
        decode_sketch_seg(bytes)
            .err()
            .expect("decode_sketch_seg unexpectedly succeeded")
    }

    #[test]
    fn scalar_quant_seg_round_trip() {
        let original = sample_scalar();
        let restored = decode_quant_seg(&encode_scalar_quantizer(&original)).unwrap();

        assert_eq!(restored.dim(), 4);
        assert_eq!(restored.tier(), TemperatureTier::Hot);

        // Both quantizers must produce the same codes for the same vector.
        let probe: Vec<f32> = vec![0.5, 1.0, 0.0, 0.5];
        assert_eq!(original.encode_vec(&probe), restored.encode(&probe));
    }

    #[test]
    fn product_quant_seg_round_trip() {
        // Sixteen values: 2 sub-books x 4 centroids x 2 components.
        let flat: [f32; 16] = [
            0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5,
        ];
        let centroids: Vec<Vec<f32>> = flat.chunks(2).map(|pair| pair.to_vec()).collect();
        let codebooks: Vec<Vec<Vec<f32>>> =
            centroids.chunks(4).map(|book| book.to_vec()).collect();
        let original = ProductQuantizer {
            m: 2,
            k: 4,
            sub_dim: 2,
            codebooks,
        };

        let restored = decode_quant_seg(&encode_product_quantizer(&original)).unwrap();

        assert_eq!(restored.dim(), 4);
        assert_eq!(restored.tier(), TemperatureTier::Warm);

        let probe: Vec<f32> = vec![0.1, 0.2, 0.9, 1.0];
        assert_eq!(original.encode_vec(&probe), restored.encode(&probe));
    }

    #[test]
    fn binary_quant_seg_round_trip() {
        let restored = decode_quant_seg(&encode_binary_quant_seg(16)).unwrap();

        assert_eq!(restored.dim(), 16);
        assert_eq!(restored.tier(), TemperatureTier::Cold);

        // Alternating +1 / -1, starting with +1.
        let probe: Vec<f32> = [1.0f32, -1.0].iter().copied().cycle().take(16).collect();
        let codes = restored.encode(&probe);
        assert_eq!(restored.decode(&codes).len(), 16);
    }

    #[test]
    fn encode_quant_seg_scalar_round_trip() {
        let original = sample_scalar();
        let restored = decode_quant_seg(&encode_quant_seg(&original)).unwrap();

        let any: &dyn core::any::Any = restored.as_ref();
        let scalar = any
            .downcast_ref::<ScalarQuantizer>()
            .expect("expected ScalarQuantizer");
        assert_eq!(scalar.min_vals, original.min_vals);
        assert_eq!(scalar.max_vals, original.max_vals);
        assert_eq!(scalar.dim, original.dim);
    }

    #[test]
    fn encode_quant_seg_product_round_trip() {
        let original = ProductQuantizer {
            m: 2,
            k: 2,
            sub_dim: 2,
            codebooks: vec![
                vec![vec![0.0, 0.1], vec![0.2, 0.3]],
                vec![vec![0.8, 0.9], vec![1.0, 1.1]],
            ],
        };

        let restored = decode_quant_seg(&encode_quant_seg(&original)).unwrap();

        let any: &dyn core::any::Any = restored.as_ref();
        let product = any
            .downcast_ref::<ProductQuantizer>()
            .expect("expected ProductQuantizer");
        assert_eq!(product.m, original.m);
        assert_eq!(product.k, original.k);
        assert_eq!(product.sub_dim, original.sub_dim);
        assert_eq!(product.codebooks, original.codebooks);
    }

    #[test]
    fn encode_quant_seg_binary_round_trip() {
        let wrapper = BinaryQuantizerWrapper { dim: 16 };
        let restored = decode_quant_seg(&encode_quant_seg(&wrapper)).unwrap();

        assert_eq!(restored.dim(), 16);
        assert_eq!(restored.tier(), TemperatureTier::Cold);
    }

    #[test]
    fn decode_quant_seg_malformed_inputs() {
        // Shorter than the fixed header.
        assert_eq!(quant_err(&[0u8; 8]), CodecError::TooShort);

        // Full header, unassigned type tag.
        assert_eq!(quant_err(&quant_header(9, 0)), CodecError::UnknownQuantType(9));

        // Scalar header claiming dim 4 with no body at all.
        assert_eq!(quant_err(&quant_header(0, 4)), CodecError::TooShort);

        // Product header with shape fields but no codebook data.
        let mut pq_truncated = quant_header(1, 4);
        push_u16s(&mut pq_truncated, &[2, 4, 2]);
        assert_eq!(quant_err(&pq_truncated), CodecError::TooShort);
    }

    #[test]
    fn rabitq_quant_seg_round_trip() {
        let centroid: Vec<f32> = (0..20).map(|i| i as f32 * 0.1 - 1.0).collect();
        let original =
            RabitqQuantizer::with_centroid(20, centroid.clone(), 0x1234_5678_9ABC_DEF0, 3);

        let encoded = encode_rabitq_quantizer(&original);
        let restored = decode_quant_seg(&encoded).unwrap();
        assert_eq!(restored.dim(), 20);
        assert_eq!(restored.tier(), TemperatureTier::Cold);

        let any: &dyn core::any::Any = restored.as_ref();
        let rabitq = any
            .downcast_ref::<RabitqQuantizer>()
            .expect("expected RabitqQuantizer");
        assert_eq!(rabitq.dim, original.dim);
        assert_eq!(rabitq.padded_dim, 32);
        assert_eq!(rabitq.seed, original.seed);
        assert_eq!(rabitq.rounds, original.rounds);
        assert_eq!(rabitq.centroid, centroid);

        // The restored quantizer must encode exactly like the original.
        let probe: Vec<f32> = (0..20).map(|i| (i as f32 * 0.7).sin()).collect();
        assert_eq!(rabitq.encode(&probe), original.encode(&probe));

        // The generic entry point must emit the same bytes as the direct one.
        assert_eq!(encode_quant_seg(&original), encoded);
    }

    #[test]
    fn rabitq_quant_seg_rejects_bad_versions_and_fields() {
        let good = encode_rabitq_quantizer(&RabitqQuantizer::with_centroid(8, vec![0.0; 8], 7, 3));

        // A version newer than this decoder knows.
        let mut future = good.clone();
        future[4] = RABITQ_VERSION + 1;
        assert_eq!(
            quant_err(&future),
            CodecError::UnsupportedVersion(RABITQ_VERSION + 1)
        );

        // Padded dimension that is not the power-of-two roundup of dim.
        let mut bad_pad = good.clone();
        bad_pad[16..20].copy_from_slice(&7u32.to_le_bytes());
        assert_eq!(quant_err(&bad_pad), CodecError::InvalidField);

        // One centroid entry missing from the end.
        assert_eq!(quant_err(&good[..good.len() - 4]), CodecError::TooShort);

        // Zero rounds.
        let mut zero_rounds = good.clone();
        zero_rounds[5] = 0;
        assert_eq!(quant_err(&zero_rounds), CodecError::InvalidField);
    }

    #[test]
    fn pre_rabitq_payloads_still_decode() {
        // Hand-built binary header using raw tag and tier values.
        let mut legacy = quant_header(2, 24);
        legacy[1] = 2;
        let restored = decode_quant_seg(&legacy).unwrap();
        assert_eq!(restored.dim(), 24);
        assert_eq!(restored.tier(), TemperatureTier::Cold);

        let scalar = ScalarQuantizer {
            min_vals: vec![-1.0, 0.0],
            max_vals: vec![1.0, 2.0],
            dim: 2,
        };
        assert!(decode_quant_seg(&encode_scalar_quantizer(&scalar)).is_ok());
    }

    #[test]
    fn decode_product_rejects_huge_codebook_dimensions() {
        // Maximal shape fields with no codebook bytes behind them.
        let mut segment = quant_header(QUANT_TYPE_PRODUCT, 4);
        push_u16s(&mut segment, &[u16::MAX, u16::MAX, u16::MAX]);
        assert_eq!(quant_err(&segment), CodecError::TooShort);
    }

    #[test]
    fn decode_sketch_seg_rejects_malformed_inputs() {
        // Shorter than the fixed header.
        assert_eq!(sketch_err(&[]), CodecError::TooShort);
        assert_eq!(sketch_err(&[0u8; 16]), CodecError::TooShort);

        // Zero width paired with a huge depth.
        assert_eq!(
            sketch_err(&sketch_header(0, u32::MAX)),
            CodecError::InvalidField
        );

        // Maximal width and depth with an empty body.
        assert_eq!(
            sketch_err(&sketch_header(u32::MAX, u32::MAX)),
            CodecError::TooShort
        );

        // 8 x 4 counters announced, only 10 body bytes present.
        let mut truncated = sketch_header(8, 4);
        truncated.extend_from_slice(&[0u8; 10]);
        assert_eq!(sketch_err(&truncated), CodecError::TooShort);

        // An all-zero header is a valid, empty sketch.
        let empty = decode_sketch_seg(&[0u8; 64]).expect("empty sketch decodes");
        assert_eq!(empty.width, 0);
        assert_eq!(empty.depth, 0);
        assert!(empty.counters.is_empty());
    }

    #[test]
    fn sketch_seg_round_trip() {
        let mut sketch = CountMinSketch::new(64, 4);
        // Block `b` is incremented `b + 1` times.
        for block_id in 0..20u64 {
            for _ in 0..=block_id {
                sketch.increment(block_id);
            }
        }

        let restored = decode_sketch_seg(&encode_sketch_seg(&sketch))
            .expect("well-formed sketch should decode");

        assert_eq!(restored.width, sketch.width);
        assert_eq!(restored.depth, sketch.depth);
        assert_eq!(restored.total_accesses, sketch.total_accesses);

        for block_id in 0..20u64 {
            assert_eq!(restored.estimate(block_id), sketch.estimate(block_id));
        }
    }
}