1use half::f16;
30
31use crate::codebook::{get_codebook, nearest_centroid, Codebook};
32use crate::error::{check_values_match, Result};
33use crate::packed::{PackedBlock, TurboQuantConfig};
34use crate::rotation::{generate_sign_pattern, rotate, RotationOrder};
35
/// Norms below this threshold are treated as zero: `normalize_inplace` zeroes
/// the vector instead of dividing, and `select_scale` stores an exact 0 scale.
const MIN_NORM: f32 = 1e-10;
43
/// Euclidean (L2) norm of `data`: the square root of the sum of squares.
/// Returns `0.0` for an empty slice.
pub fn l2_norm(data: &[f32]) -> f32 {
    let sum_of_squares = data.iter().fold(0.0_f32, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
54
55pub fn normalize_inplace(data: &mut [f32], norm: f32) {
61 if norm < MIN_NORM {
62 for v in data.iter_mut() {
63 *v = 0.0;
64 }
65 } else {
66 let inv = 1.0 / norm;
67 for v in data.iter_mut() {
68 *v *= inv;
69 }
70 }
71}
72
/// Multiplies every element of `data` by `factor` in place.
pub fn scale_inplace(data: &mut [f32], factor: f32) {
    data.iter_mut().for_each(|v| *v *= factor);
}
81
82pub fn quantize_coordinates(rotated: &[f32], codebook: &Codebook) -> Vec<u8> {
88 rotated
89 .iter()
90 .map(|&v| nearest_centroid(v as f64, codebook))
91 .collect()
92}
93
94pub fn lookup_centroids(indices: &[u8], codebook: &Codebook) -> Vec<f32> {
98 indices
99 .iter()
100 .map(|&idx| codebook.centroids[idx as usize] as f32)
101 .collect()
102}
103
104pub fn lookup_centroids_into(indices: &[u8], codebook: &Codebook, out: &mut Vec<f32>) {
111 out.clear();
112 out.extend(
113 indices
114 .iter()
115 .map(|&idx| codebook.centroids[idx as usize] as f32),
116 );
117}
118
119fn select_scale(norm: f32) -> f16 {
124 let effective = if norm < MIN_NORM { 0.0 } else { norm };
125 f16::from_f32(effective)
126}
127
128pub fn quantize_vec(config: &TurboQuantConfig, data: &[f32]) -> Result<PackedBlock> {
143 check_values_match(data.len(), config.dim)?;
144
145 let codebook = get_codebook(config.bits, config.dim)?;
146 let sign_pattern = generate_sign_pattern(config.dim, config.rotation_seed);
147 let norm = l2_norm(data);
148
149 let mut working = data.to_vec();
150 normalize_inplace(&mut working, norm);
151 rotate(&mut working, &sign_pattern, RotationOrder::Forward)?;
152
153 let indices = quantize_coordinates(&working, &codebook);
154 let scale = select_scale(norm);
155
156 Ok(PackedBlock::new(config.bits, scale, &indices))
157}
158
159pub fn quantize_vec_with_codebook(
175 config: &TurboQuantConfig,
176 data: &[f32],
177 codebook: &Codebook,
178 sign_pattern: &[f32],
179) -> Result<PackedBlock> {
180 check_values_match(data.len(), config.dim)?;
181
182 let norm = l2_norm(data);
183
184 let mut working = data.to_vec();
185 normalize_inplace(&mut working, norm);
186 rotate(&mut working, sign_pattern, RotationOrder::Forward)?;
187
188 let indices = quantize_coordinates(&working, codebook);
189 let scale = select_scale(norm);
190
191 Ok(PackedBlock::new(config.bits, scale, &indices))
192}
193
194pub fn dequantize_vec(config: &TurboQuantConfig, block: &PackedBlock) -> Result<Vec<f32>> {
208 let codebook = get_codebook(config.bits, config.dim)?;
209 let sign_pattern = generate_sign_pattern(config.dim, config.rotation_seed);
210 dequantize_vec_with_codebook(config, block, &codebook, &sign_pattern)
211}
212
213pub fn dequantize_vec_with_codebook(
226 config: &TurboQuantConfig,
227 block: &PackedBlock,
228 codebook: &Codebook,
229 sign_pattern: &[f32],
230) -> Result<Vec<f32>> {
231 let indices = block.unpack(config.dim);
232 let mut reconstructed = lookup_centroids(&indices, codebook);
233
234 rotate(&mut reconstructed, sign_pattern, RotationOrder::Inverse)?;
235
236 let scale = block.scale.to_f32();
237 scale_inplace(&mut reconstructed, scale);
238
239 Ok(reconstructed)
240}
241
/// Allocation-free variant of [`dequantize_vec_with_codebook`]: reconstructs
/// the vector into `scratch.values`, reusing `scratch.indices` for the
/// unpacked centroid indices. The result is left in `scratch.values`.
///
/// # Errors
/// Returns an error if the inverse rotation fails.
pub fn dequantize_into_with_codebook(
    config: &TurboQuantConfig,
    block: &PackedBlock,
    codebook: &Codebook,
    sign_pattern: &[f32],
    scratch: &mut DequantScratch,
) -> Result<()> {
    // Unpack indices and resolve centroids into the reusable buffers.
    block.unpack_into(config.dim, &mut scratch.indices);
    lookup_centroids_into(&scratch.indices, codebook, &mut scratch.values);
    // Undo the forward rotation, then restore the original magnitude.
    rotate(&mut scratch.values, sign_pattern, RotationOrder::Inverse)?;
    scale_inplace(&mut scratch.values, block.scale.to_f32());
    Ok(())
}
268
/// Reusable buffers for [`dequantize_into_with_codebook`], letting repeated
/// dequantization calls avoid per-call heap allocation.
pub struct DequantScratch {
    // Unpacked centroid indices for the most recently processed block.
    pub(crate) indices: Vec<u8>,
    // Reconstructed coordinate values; holds the output after a call to
    // dequantize_into_with_codebook.
    pub(crate) values: Vec<f32>,
}
278
279impl DequantScratch {
280 pub fn new(dim: usize) -> Self {
282 Self {
283 indices: Vec::with_capacity(dim),
284 values: Vec::with_capacity(dim),
285 }
286 }
287}
288
289pub fn dequantize_rotated(config: &TurboQuantConfig, block: &PackedBlock) -> Result<Vec<f32>> {
307 let codebook = get_codebook(config.bits, config.dim)?;
308
309 let indices = block.unpack(config.dim);
310 let mut reconstructed = lookup_centroids(&indices, &codebook);
311
312 let scale = block.scale.to_f32();
313 scale_inplace(&mut reconstructed, scale);
314
315 Ok(reconstructed)
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed::{BITS_TQ2, BITS_TQ3, BITS_TQ4};
    use crate::test_utils::pseudo_random_vec;

    // Shared fixture dimensions and RNG seeds (distinct seeds keep each
    // test's pseudo-random input independent of the others).
    const TEST_DIM: usize = 64;
    const TEST_SMALL_DIM: usize = 8;
    const TEST_SEED: u64 = 42;
    const TEST_SEED_3BIT: u64 = 12345;
    const TEST_SEED_4BIT: u64 = 54321;
    const TEST_SEED_DETERM: u64 = 99999;
    const TEST_SEED_ROTATED: u64 = 77777;
    // Tolerances: exact float comparisons vs. looser norm comparisons.
    const FLOAT_EPSILON: f32 = 1e-6;
    const NORM_EPSILON: f32 = 0.02;
    // Quantization is lossy; a single vector only needs to keep its error
    // norm below the original norm.
    const MAX_SINGLE_VEC_RELATIVE_ERROR: f32 = 1.0;
    const TEST_SCALE_VALUE: f32 = 5.0;
    const TEST_NORM_VALUE: f32 = 2.0;
    const TEST_CONST_VAL_A: f32 = 2.5;
    const TEST_CONST_VAL_B: f32 = 3.0;

    // --- l2_norm ---

    #[test]
    fn l2_norm_of_unit_vector() {
        let mut v = vec![0.0_f32; TEST_DIM];
        v[0] = 1.0;
        let norm = l2_norm(&v);
        assert!((norm - 1.0).abs() < FLOAT_EPSILON);
    }

    #[test]
    fn l2_norm_of_zero_vector() {
        let v = vec![0.0_f32; TEST_DIM];
        let norm = l2_norm(&v);
        assert!(norm < FLOAT_EPSILON);
    }

    #[test]
    fn l2_norm_of_known_vector() {
        // Classic 3-4-5 triangle.
        let v = vec![3.0_f32, 4.0];
        let norm = l2_norm(&v);
        assert!((norm - TEST_SCALE_VALUE).abs() < FLOAT_EPSILON);
    }

    // --- normalize_inplace / scale_inplace ---

    #[test]
    fn normalize_inplace_unit_result() {
        let mut v = vec![3.0_f32, 4.0];
        normalize_inplace(&mut v, TEST_SCALE_VALUE);
        assert!((v[0] - 0.6).abs() < FLOAT_EPSILON);
        assert!((v[1] - 0.8).abs() < FLOAT_EPSILON);
    }

    #[test]
    fn normalize_inplace_zero_norm_gives_zeros() {
        // Zero norm must zero the vector rather than divide by zero.
        let mut v = vec![1.0_f32, 2.0, 3.0];
        normalize_inplace(&mut v, 0.0);
        for &val in &v {
            assert!(val.abs() < FLOAT_EPSILON);
        }
    }

    #[test]
    fn scale_inplace_doubles() {
        let mut v = vec![1.0_f32, 2.0, 3.0];
        scale_inplace(&mut v, TEST_NORM_VALUE);
        assert!((v[0] - 2.0).abs() < FLOAT_EPSILON);
        assert!((v[1] - 4.0).abs() < FLOAT_EPSILON);
        assert!((v[2] - 6.0).abs() < FLOAT_EPSILON);
    }

    // --- error-module helpers ---

    #[test]
    fn values_match_true() {
        assert!(crate::error::values_match(128, 128));
    }

    #[test]
    fn values_match_false() {
        assert!(!crate::error::values_match(64, 128));
    }

    // --- select_scale ---

    #[test]
    fn select_scale_zero_for_tiny_norm() {
        // Below MIN_NORM (1e-10) the scale snaps to exactly zero.
        assert_eq!(select_scale(1e-11).to_f32(), 0.0);
    }

    #[test]
    fn select_scale_preserves_normal_norm() {
        assert!((select_scale(1.0).to_f32() - 1.0).abs() < FLOAT_EPSILON);
    }

    // --- codebook round-trip ---

    #[test]
    fn quantize_lookup_roundtrip_preserves_structure() {
        // Feeding the centroids themselves through quantize/lookup should
        // recover them (each value is its own nearest centroid).
        let codebook = get_codebook(BITS_TQ3, TEST_DIM).unwrap();
        let coords: Vec<f32> = codebook.centroids.iter().map(|&c| c as f32).collect();
        let indices = quantize_coordinates(&coords, &codebook);
        let recovered = lookup_centroids(&indices, &codebook);
        for (i, (&orig, &rec)) in coords.iter().zip(recovered.iter()).enumerate() {
            assert!(
                (orig - rec).abs() < 0.01,
                "mismatch at index {i}: orig={orig}, rec={rec}"
            );
        }
    }

    // --- PackedBlock construction ---

    #[test]
    fn packed_block_tq3() {
        let indices = vec![0u8; TEST_DIM];
        let block = PackedBlock::new(BITS_TQ3, f16::from_f32(1.0), &indices);
        assert_eq!(block.bits, BITS_TQ3);
    }

    #[test]
    fn packed_block_tq4() {
        let indices = vec![0u8; TEST_DIM];
        let block = PackedBlock::new(BITS_TQ4, f16::from_f32(1.0), &indices);
        assert_eq!(block.bits, BITS_TQ4);
    }

    // --- quantize_vec validation and round-trips ---

    #[test]
    fn quantize_vec_rejects_wrong_dimension() {
        let config = TurboQuantConfig::new(BITS_TQ3, TEST_DIM).unwrap();
        let data = vec![1.0_f32; TEST_DIM + 1];
        assert!(quantize_vec(&config, &data).is_err());
    }

    #[test]
    fn quantize_dequantize_roundtrip_3bit() {
        let config = TurboQuantConfig::new(BITS_TQ3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_3BIT);
        let block = quantize_vec(&config, &data).unwrap();
        let recovered = dequantize_vec(&config, &block).unwrap();

        // Relative error = ||data - recovered|| / ||data||.
        let orig_norm = l2_norm(&data);
        let err_norm = l2_norm(
            &data
                .iter()
                .zip(recovered.iter())
                .map(|(&a, &b)| a - b)
                .collect::<Vec<_>>(),
        );
        let relative_error = err_norm / orig_norm;
        assert!(
            relative_error < MAX_SINGLE_VEC_RELATIVE_ERROR,
            "relative error too large: {relative_error}"
        );
    }

    #[test]
    fn quantize_dequantize_roundtrip_4bit() {
        let config = TurboQuantConfig::new(BITS_TQ4, TEST_DIM)
            .unwrap()
            .with_seed(TEST_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_4BIT);
        let block = quantize_vec(&config, &data).unwrap();
        let recovered = dequantize_vec(&config, &block).unwrap();

        // Relative error = ||data - recovered|| / ||data||.
        let orig_norm = l2_norm(&data);
        let err_norm = l2_norm(
            &data
                .iter()
                .zip(recovered.iter())
                .map(|(&a, &b)| a - b)
                .collect::<Vec<_>>(),
        );
        let relative_error = err_norm / orig_norm;
        assert!(
            relative_error < MAX_SINGLE_VEC_RELATIVE_ERROR,
            "relative error too large: {relative_error}"
        );
    }

    #[test]
    fn quantize_zero_vector_does_not_panic() {
        // Degenerate input: the zero vector should round-trip to ~zero.
        let config = TurboQuantConfig::new(BITS_TQ3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_SEED);
        let data = vec![0.0_f32; TEST_DIM];
        let block = quantize_vec(&config, &data).unwrap();
        let recovered = dequantize_vec(&config, &block).unwrap();
        let recovered_norm = l2_norm(&recovered);
        assert!(
            recovered_norm < NORM_EPSILON,
            "recovered norm should be near zero, got {recovered_norm}"
        );
    }

    #[test]
    fn quantize_is_deterministic() {
        // Same config + same data must produce identical reconstructions.
        let config = TurboQuantConfig::new(BITS_TQ3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_DETERM);

        let block_a = quantize_vec(&config, &data).unwrap();
        let block_b = quantize_vec(&config, &data).unwrap();

        let rec_a = dequantize_vec(&config, &block_a).unwrap();
        let rec_b = dequantize_vec(&config, &block_b).unwrap();

        assert_eq!(rec_a, rec_b);
    }

    // --- dequantize_rotated ---

    #[test]
    fn dequantize_rotated_differs_from_full() {
        let config = TurboQuantConfig::new(BITS_TQ3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_ROTATED);
        let block = quantize_vec(&config, &data).unwrap();

        let full = dequantize_vec(&config, &block).unwrap();
        let rotated = dequantize_rotated(&config, &block).unwrap();

        // Skipping the inverse rotation changes coordinates but the
        // rotation is norm-preserving, so the norms should agree.
        assert_ne!(full, rotated);
        let full_norm = l2_norm(&full);
        let rotated_norm = l2_norm(&rotated);
        assert!(
            (full_norm - rotated_norm).abs() < NORM_EPSILON,
            "norms should be approximately equal: full={full_norm}, rotated={rotated_norm}"
        );
    }

    // --- scale storage precision ---

    #[test]
    fn packed_block_scale_tq3() {
        let block = PackedBlock::new(
            BITS_TQ3,
            f16::from_f32(TEST_CONST_VAL_A),
            &[0u8; TEST_SMALL_DIM],
        );
        assert!((block.scale.to_f32() - TEST_CONST_VAL_A).abs() < 0.01);
    }

    #[test]
    fn packed_block_scale_tq4() {
        let block = PackedBlock::new(
            BITS_TQ4,
            f16::from_f32(TEST_CONST_VAL_B),
            &[0u8; TEST_SMALL_DIM],
        );
        assert!((block.scale.to_f32() - TEST_CONST_VAL_B).abs() < 0.01);
    }

    // --- block size sanity (scale header is 2 bytes) ---

    #[test]
    fn packed_block_size_bytes_tq3() {
        let config = TurboQuantConfig::new(BITS_TQ3, TEST_DIM)
            .unwrap()
            .with_seed(TEST_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_3BIT);
        let block = quantize_vec(&config, &data).unwrap();

        assert!(block.size_bytes() > 2);
    }

    #[test]
    fn packed_block_size_bytes_tq4() {
        let config = TurboQuantConfig::new(BITS_TQ4, TEST_DIM)
            .unwrap()
            .with_seed(TEST_SEED);
        let data = pseudo_random_vec(TEST_DIM, TEST_SEED_4BIT);
        let block = quantize_vec(&config, &data).unwrap();

        assert!(block.size_bytes() > 2);
    }

    // --- exact block sizes for d=128 ---
    // Expected bytes = SCALE_BYTES + ceil(dim * bits / 8):
    //   2-bit: 2 + 32 = 34,  3-bit: 2 + 48 = 50,  4-bit: 2 + 64 = 66.
    const BLOCK_SIZE_DIM: usize = 128;

    const SCALE_BYTES: usize = 2;

    const TQ2_D128_EXPECTED_BYTES: usize = 34;

    const TQ3_D128_EXPECTED_BYTES: usize = 50;

    const TQ4_D128_EXPECTED_BYTES: usize = 66;

    const BLOCK_SIZE_SEED: u64 = 42;

    const BLOCK_SIZE_DATA_SEED_2: u64 = 20001;

    const BLOCK_SIZE_DATA_SEED_3: u64 = 30001;

    const BLOCK_SIZE_DATA_SEED_4: u64 = 40001;

    #[test]
    fn polar_block_size_2bit_d128() {
        let config = TurboQuantConfig::new(BITS_TQ2, BLOCK_SIZE_DIM)
            .unwrap()
            .with_seed(BLOCK_SIZE_SEED);
        let data = pseudo_random_vec(BLOCK_SIZE_DIM, BLOCK_SIZE_DATA_SEED_2);
        let block = quantize_vec(&config, &data).unwrap();

        assert_eq!(
            block.size_bytes(),
            TQ2_D128_EXPECTED_BYTES,
            "2-bit polar block for d={BLOCK_SIZE_DIM}: expected {TQ2_D128_EXPECTED_BYTES} bytes, \
             got {} (scale={SCALE_BYTES}, packed={})",
            block.size_bytes(),
            block.size_bytes() - SCALE_BYTES
        );
    }

    #[test]
    fn polar_block_size_3bit_d128() {
        let config = TurboQuantConfig::new(BITS_TQ3, BLOCK_SIZE_DIM)
            .unwrap()
            .with_seed(BLOCK_SIZE_SEED);
        let data = pseudo_random_vec(BLOCK_SIZE_DIM, BLOCK_SIZE_DATA_SEED_3);
        let block = quantize_vec(&config, &data).unwrap();

        assert_eq!(
            block.size_bytes(),
            TQ3_D128_EXPECTED_BYTES,
            "3-bit polar block for d={BLOCK_SIZE_DIM}: expected {TQ3_D128_EXPECTED_BYTES} bytes, \
             got {} (scale={SCALE_BYTES}, packed={})",
            block.size_bytes(),
            block.size_bytes() - SCALE_BYTES
        );
    }

    #[test]
    fn polar_block_size_4bit_d128() {
        let config = TurboQuantConfig::new(BITS_TQ4, BLOCK_SIZE_DIM)
            .unwrap()
            .with_seed(BLOCK_SIZE_SEED);
        let data = pseudo_random_vec(BLOCK_SIZE_DIM, BLOCK_SIZE_DATA_SEED_4);
        let block = quantize_vec(&config, &data).unwrap();

        assert_eq!(
            block.size_bytes(),
            TQ4_D128_EXPECTED_BYTES,
            "4-bit polar block for d={BLOCK_SIZE_DIM}: expected {TQ4_D128_EXPECTED_BYTES} bytes, \
             got {} (scale={SCALE_BYTES}, packed={})",
            block.size_bytes(),
            block.size_bytes() - SCALE_BYTES
        );
    }
}