1use half::f16;
9
10use crate::error::{require, Result, TurboQuantError};
11
12pub mod indices;
13pub use indices::{
14 pack_indices_2bit, pack_indices_3bit, pack_indices_4bit, unpack_indices_2bit,
15 unpack_indices_3bit, unpack_indices_4bit,
16};
17#[allow(unused_imports)]
19use indices::{
20 chunk_to_2bit_array, chunk_to_3bit_array, chunk_to_4bit_array, chunk_to_packed_3bit_array,
21 has_2bit_remainder, has_3bit_remainder, has_4bit_remainder, num_2bit_groups, num_3bit_groups,
22 num_4bit_pairs, packed_2bit_capacity, packed_3bit_capacity, packed_4bit_capacity,
23 pad_remainder_2bit, pad_remainder_3bit, trailing_4bit_pair,
24};
25
// --- Supported index widths -------------------------------------------------

/// Index width, in bits, of the 2-bit format (accepted by `is_valid_bits`).
pub(crate) const BITS_TQ2: u8 = 2;

/// Index width, in bits, of the 3-bit format (accepted by `is_valid_bits`).
pub(crate) const BITS_TQ3: u8 = 3;

/// Index width, in bits, of the 4-bit format (accepted by `is_valid_bits`).
pub(crate) const BITS_TQ4: u8 = 4;

// --- Packing group geometry -------------------------------------------------

/// Number of 2-bit values that `pack_2bit` stores in a single byte.
const PACK_2BIT_GROUP_SIZE: usize = 4;

/// Number of 3-bit values that `pack_3bit` stores in one packed group.
const PACK_3BIT_GROUP_SIZE: usize = 8;

/// Number of bytes produced by `pack_3bit` for one group of values.
const PACK_3BIT_BYTES: usize = 3;

/// Number of 4-bit values that `pack_4bit` stores in a single byte.
const PACK_4BIT_GROUP_SIZE: usize = 2;

// --- Bit masks (lowest N bits set) ------------------------------------------

/// Mask selecting the lowest bit.
const MASK_1BIT: u8 = 0b1;

/// Mask selecting the lowest two bits.
const MASK_2BIT: u8 = 0b11;

/// Mask selecting the lowest three bits.
const MASK_3BIT: u8 = 0b111;

/// Mask selecting the lowest four bits.
const MASK_4BIT: u8 = 0b1111;

// --- Shift distances used by the pack/unpack routines -----------------------

/// Shift by one bit position.
const SHIFT_1: u32 = 1;

/// Shift by two bit positions.
const SHIFT_2: u32 = 2;

/// Shift by three bit positions.
const SHIFT_3: u32 = 3;

/// Shift by four bit positions.
const SHIFT_4: u32 = 4;

/// Shift by five bit positions.
const SHIFT_5: u32 = 5;

/// Shift by six bit positions.
const SHIFT_6: u32 = 6;

/// Shift by seven bit positions.
const SHIFT_7: u32 = 7;

// --- Block layout -----------------------------------------------------------

/// Bytes that `PackedBlock::size_bytes` adds on top of the packed indices to
/// account for the `f16` scale.
const SCALE_SIZE_BYTES: usize = 2;
86
/// Quantizer settings: index bit width, vector dimension, and rotation seed.
///
/// Construct with `TurboQuantConfig::new`, which validates `bits` and `dim`;
/// the fields are crate-private, so outside code cannot bypass that check.
#[derive(Clone, Copy)]
pub struct TurboQuantConfig {
    /// Bits per quantized index; `new` only accepts values for which
    /// `is_valid_bits` returns true.
    pub(crate) bits: u8,
    /// Vector dimension; `new` only accepts values for which `is_valid_dim`
    /// returns true (non-zero powers of two).
    pub(crate) dim: usize,
    /// Rotation seed; `new` sets it to 0 and `with_seed` overrides it.
    /// NOTE(review): how the seed is consumed is not visible in this file.
    pub(crate) rotation_seed: u64,
}
101
102pub(crate) fn is_valid_bits(bits: u8) -> bool {
106 bits == BITS_TQ2 || bits == BITS_TQ3 || bits == BITS_TQ4
107}
108
/// Reports whether `dim` is a usable dimension, i.e. a non-zero power of two.
pub(crate) fn is_valid_dim(dim: usize) -> bool {
    // Exactly one set bit <=> power of two. Zero has no set bits, so it is
    // rejected without a separate check.
    dim.count_ones() == 1
}
115
116impl TurboQuantConfig {
117 pub fn new(bits: u8, dim: usize) -> Result<Self> {
124 require(is_valid_bits(bits), TurboQuantError::UnsupportedBits(bits))?;
125 require(is_valid_dim(dim), TurboQuantError::InvalidDimension(dim))?;
126 Ok(Self {
127 bits,
128 dim,
129 rotation_seed: 0,
130 })
131 }
132
133 pub fn with_seed(mut self, seed: u64) -> Self {
136 self.rotation_seed = seed;
137 self
138 }
139}
140
/// One quantized block: an index bit width, an `f16` scale, and the indices
/// packed into bytes.
pub struct PackedBlock {
    /// Bits per index. The pack/unpack methods only handle `BITS_TQ2`,
    /// `BITS_TQ3` and `BITS_TQ4`; any other value makes them panic.
    pub bits: u8,
    /// Scale stored with the block.
    /// NOTE(review): its numeric meaning is defined by callers outside this file.
    pub scale: f16,
    /// Packed index bytes. `new` fills this with the output of the matching
    /// `pack_indices_*bit` helper; `from_raw` stores the caller's bytes as-is.
    pub packed_indices: Vec<u8>,
}
157
158impl PackedBlock {
159 pub fn new(bits: u8, scale: f16, indices: &[u8]) -> Self {
166 let pack = |indices: &[u8]| -> Vec<u8> {
167 match bits {
168 BITS_TQ2 => pack_indices_2bit(indices),
169 BITS_TQ3 => pack_indices_3bit(indices),
170 BITS_TQ4 => pack_indices_4bit(indices),
171 _ => unreachable!("bits validated to be 2, 3, or 4"),
172 }
173 };
174 Self {
175 bits,
176 scale,
177 packed_indices: pack(indices),
178 }
179 }
180
181 pub fn size_bytes(&self) -> usize {
183 SCALE_SIZE_BYTES + self.packed_indices.len()
184 }
185
186 pub fn from_raw(bits: u8, scale: f16, packed_indices: Vec<u8>) -> Self {
194 Self {
195 bits,
196 scale,
197 packed_indices,
198 }
199 }
200
201 pub fn unpack_into(&self, count: usize, buf: &mut Vec<u8>) {
209 buf.clear();
210 let do_unpack = |packed: &[u8], out: &mut Vec<u8>| match self.bits {
211 BITS_TQ2 => out.extend_from_slice(&unpack_indices_2bit(packed, count)),
212 BITS_TQ3 => out.extend_from_slice(&unpack_indices_3bit(packed, count)),
213 BITS_TQ4 => out.extend_from_slice(&unpack_indices_4bit(packed, count)),
214 _ => unreachable!("bits validated"),
215 };
216 do_unpack(&self.packed_indices, buf);
217 buf.truncate(count);
218 }
219
220 pub fn unpack(&self, count: usize) -> Vec<u8> {
225 let do_unpack = |packed: &[u8]| match self.bits {
226 BITS_TQ2 => unpack_indices_2bit(packed, count),
227 BITS_TQ3 => unpack_indices_3bit(packed, count),
228 BITS_TQ4 => unpack_indices_4bit(packed, count),
229 _ => unreachable!("bits validated"),
230 };
231 do_unpack(&self.packed_indices)
232 }
233}
234
235pub fn pack_2bit(values: &[u8; PACK_2BIT_GROUP_SIZE]) -> u8 {
243 (values[0] & MASK_2BIT)
244 | ((values[1] & MASK_2BIT) << SHIFT_2)
245 | ((values[2] & MASK_2BIT) << SHIFT_4)
246 | ((values[3] & MASK_2BIT) << SHIFT_6)
247}
248
249pub fn unpack_2bit(packed: u8) -> [u8; PACK_2BIT_GROUP_SIZE] {
251 [
252 packed & MASK_2BIT,
253 (packed >> SHIFT_2) & MASK_2BIT,
254 (packed >> SHIFT_4) & MASK_2BIT,
255 (packed >> SHIFT_6) & MASK_2BIT,
256 ]
257}
258
259pub fn pack_3bit(values: &[u8; PACK_3BIT_GROUP_SIZE]) -> [u8; PACK_3BIT_BYTES] {
267 let mut packed = [0u8; PACK_3BIT_BYTES];
268 packed[0] = (values[0] & MASK_3BIT)
269 | ((values[1] & MASK_3BIT) << SHIFT_3)
270 | ((values[2] & MASK_2BIT) << SHIFT_6);
271 packed[1] = ((values[2] >> SHIFT_2) & MASK_1BIT)
272 | ((values[3] & MASK_3BIT) << SHIFT_1)
273 | ((values[4] & MASK_3BIT) << SHIFT_4)
274 | ((values[5] & MASK_1BIT) << SHIFT_7);
275 packed[2] = ((values[5] >> SHIFT_1) & MASK_2BIT)
276 | ((values[6] & MASK_3BIT) << SHIFT_2)
277 | ((values[7] & MASK_3BIT) << SHIFT_5);
278 packed
279}
280
281pub fn unpack_3bit(packed: &[u8; PACK_3BIT_BYTES]) -> [u8; PACK_3BIT_GROUP_SIZE] {
283 let mut values = [0u8; PACK_3BIT_GROUP_SIZE];
284 values[0] = packed[0] & MASK_3BIT;
285 values[1] = (packed[0] >> SHIFT_3) & MASK_3BIT;
286 values[2] = ((packed[0] >> SHIFT_6) & MASK_2BIT) | ((packed[1] & MASK_1BIT) << SHIFT_2);
287 values[3] = (packed[1] >> SHIFT_1) & MASK_3BIT;
288 values[4] = (packed[1] >> SHIFT_4) & MASK_3BIT;
289 values[5] = ((packed[1] >> SHIFT_7) & MASK_1BIT) | ((packed[2] & MASK_2BIT) << SHIFT_1);
290 values[6] = (packed[2] >> SHIFT_2) & MASK_3BIT;
291 values[7] = (packed[2] >> SHIFT_5) & MASK_3BIT;
292 values
293}
294
295pub fn pack_4bit(values: &[u8; 2]) -> u8 {
303 (values[0] & MASK_4BIT) | ((values[1] & MASK_4BIT) << SHIFT_4)
304}
305
306pub fn unpack_4bit(packed: u8) -> [u8; 2] {
308 [packed & MASK_4BIT, (packed >> SHIFT_4) & MASK_4BIT]
309}
310
// Unit tests for configuration validation, the helpers imported from
// `indices`, the single-group pack/unpack routines, and `PackedBlock`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of indices in a typical test block.
    const TEST_BLOCK_SIZE: usize = 32;
    /// A larger power-of-two dimension.
    const TEST_DIM_128: usize = 128;
    /// Group count used by the 3-bit capacity tests.
    const TEST_3BIT_GROUPS: usize = 4;
    /// Pair count used by the 4-bit capacity tests.
    const TEST_4BIT_PAIRS: usize = 5;
    /// Largest value representable in 3 bits.
    const MAX_3BIT_VALUE: u8 = 7;
    /// Largest value representable in 4 bits.
    const MAX_4BIT_VALUE: u8 = 15;
    /// Arbitrary value for the trailing 4-bit element test.
    const TEST_TRAILING_VALUE: u8 = 9;
    /// Index count that is an exact multiple of the 3-bit group size.
    const TEST_3BIT_EXACT_COUNT: usize = 16;
    /// Index count that leaves a partial 3-bit group.
    const TEST_3BIT_REMAINDER_COUNT: usize = 11;
    /// Even index count for 4-bit round trips.
    const TEST_4BIT_EVEN_COUNT: usize = 10;
    /// Odd index count for 4-bit round trips (leaves a trailing element).
    const TEST_4BIT_ODD_COUNT: usize = 7;
    /// Number of distinct 4-bit values.
    const TEST_4BIT_LEVELS: u8 = 16;

    /// Number of distinct 3-bit values.
    const TEST_3BIT_LEVELS: usize = 8;
    /// Arbitrary scale used when the scale value itself is irrelevant.
    const TEST_SCALE: f32 = 1.5;
    /// A second arbitrary scale.
    const TEST_SCALE_HALF: f32 = 0.5;
    /// Largest value representable in 2 bits.
    const MAX_2BIT_VALUE: u8 = 3;
    /// Index count that is an exact multiple of the 2-bit group size.
    const TEST_2BIT_EXACT_COUNT: usize = 12;
    /// Index count that leaves a partial 2-bit group.
    const TEST_2BIT_REMAINDER_COUNT: usize = 7;

    // --- validation helpers ---------------------------------------------

    #[test]
    fn is_valid_bits_accepts_2_3_and_4() {
        assert!(is_valid_bits(BITS_TQ2));
        assert!(is_valid_bits(BITS_TQ3));
        assert!(is_valid_bits(BITS_TQ4));
    }

    #[test]
    fn is_valid_bits_rejects_others() {
        assert!(!is_valid_bits(0));
        assert!(!is_valid_bits(1));
        assert!(!is_valid_bits(5));
    }

    #[test]
    fn is_valid_dim_accepts_powers_of_two() {
        assert!(is_valid_dim(TEST_DIM_128 / 2));
        assert!(is_valid_dim(TEST_DIM_128));
    }

    #[test]
    fn is_valid_dim_rejects_invalid() {
        assert!(!is_valid_dim(0));
        assert!(!is_valid_dim(3));
        assert!(!is_valid_dim(100));
    }

    // --- capacity helpers -----------------------------------------------

    #[test]
    fn packed_3bit_capacity_no_remainder() {
        assert_eq!(
            packed_3bit_capacity(TEST_3BIT_GROUPS, false),
            TEST_3BIT_GROUPS * PACK_3BIT_BYTES
        );
    }

    #[test]
    fn packed_3bit_capacity_with_remainder() {
        assert_eq!(
            packed_3bit_capacity(TEST_3BIT_GROUPS, true),
            TEST_3BIT_GROUPS * PACK_3BIT_BYTES + PACK_3BIT_BYTES
        );
    }

    #[test]
    fn packed_3bit_capacity_zero_groups() {
        assert_eq!(packed_3bit_capacity(0, false), 0);
        // A lone remainder still occupies one full packed group.
        assert_eq!(packed_3bit_capacity(0, true), PACK_3BIT_BYTES);
    }

    #[test]
    fn packed_4bit_capacity_no_remainder() {
        assert_eq!(
            packed_4bit_capacity(TEST_4BIT_PAIRS, false),
            TEST_4BIT_PAIRS
        );
    }

    #[test]
    fn packed_4bit_capacity_with_remainder() {
        assert_eq!(
            packed_4bit_capacity(TEST_4BIT_PAIRS, true),
            TEST_4BIT_PAIRS + 1
        );
    }

    // --- chunk / padding helpers ----------------------------------------

    #[test]
    fn chunk_to_3bit_array_preserves_values() {
        let input: Vec<u8> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let arr = chunk_to_3bit_array(&input);
        assert_eq!(arr, [0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn chunk_to_4bit_array_preserves_values() {
        let input: Vec<u8> = vec![10, 15];
        let arr = chunk_to_4bit_array(&input);
        assert_eq!(arr, [10, 15]);
    }

    #[test]
    fn pad_remainder_3bit_pads_correctly() {
        let tail: Vec<u8> = vec![1, 2, 3];
        let padded = pad_remainder_3bit(&tail);
        assert_eq!(padded, [1, 2, 3, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn pad_remainder_3bit_single_element() {
        let tail: Vec<u8> = vec![5];
        let padded = pad_remainder_3bit(&tail);
        assert_eq!(padded, [5, 0, 0, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn trailing_4bit_pair_handles_single_element() {
        let pair = trailing_4bit_pair(TEST_TRAILING_VALUE);
        assert_eq!(pair, [TEST_TRAILING_VALUE, 0]);
    }

    #[test]
    fn chunk_to_packed_3bit_array_preserves_values() {
        let input: Vec<u8> = vec![0xAB, 0xCD, 0xEF];
        let arr = chunk_to_packed_3bit_array(&input);
        assert_eq!(arr, [0xAB, 0xCD, 0xEF]);
    }

    // --- single-group 3-bit pack/unpack ---------------------------------

    #[test]
    fn pack_unpack_3bit_identity() {
        let values: [u8; PACK_3BIT_GROUP_SIZE] = [0, 1, 2, 3, 4, 5, 6, MAX_3BIT_VALUE];
        let packed = pack_3bit(&values);
        let unpacked = unpack_3bit(&packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_3bit_zeros() {
        let values = [0u8; PACK_3BIT_GROUP_SIZE];
        assert_eq!(unpack_3bit(&pack_3bit(&values)), values);
    }

    #[test]
    fn pack_unpack_3bit_max() {
        let values = [MAX_3BIT_VALUE; PACK_3BIT_GROUP_SIZE];
        assert_eq!(unpack_3bit(&pack_3bit(&values)), values);
    }

    // --- single-group 4-bit pack/unpack ---------------------------------

    #[test]
    fn pack_unpack_4bit_identity() {
        let values: [u8; PACK_4BIT_GROUP_SIZE] = [0, MAX_4BIT_VALUE];
        let packed = pack_4bit(&values);
        let unpacked = unpack_4bit(packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_4bit_zeros() {
        let values = [0u8; PACK_4BIT_GROUP_SIZE];
        assert_eq!(unpack_4bit(pack_4bit(&values)), values);
    }

    #[test]
    fn pack_unpack_4bit_max() {
        let values = [MAX_4BIT_VALUE; PACK_4BIT_GROUP_SIZE];
        assert_eq!(unpack_4bit(pack_4bit(&values)), values);
    }

    // --- slice-level 3-bit and 4-bit round trips ------------------------

    #[test]
    fn roundtrip_3bit_exact_multiple() {
        let indices: Vec<u8> = (0..TEST_3BIT_EXACT_COUNT as u8)
            .map(|i| i % (MAX_3BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_3bit(&indices);
        let unpacked = unpack_indices_3bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_3bit_with_remainder() {
        let indices: Vec<u8> = (0..TEST_3BIT_REMAINDER_COUNT as u8)
            .map(|i| i % (MAX_3BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_3bit(&indices);
        let unpacked = unpack_indices_3bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_4bit_even_count() {
        let indices: Vec<u8> = (0..TEST_4BIT_EVEN_COUNT as u8)
            .map(|i| i % TEST_4BIT_LEVELS)
            .collect();
        let packed = pack_indices_4bit(&indices);
        let unpacked = unpack_indices_4bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_4bit_odd_count() {
        let indices: Vec<u8> = (0..TEST_4BIT_ODD_COUNT as u8)
            .map(|i| i % TEST_4BIT_LEVELS)
            .collect();
        let packed = pack_indices_4bit(&indices);
        let unpacked = unpack_indices_4bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    // --- TurboQuantConfig -----------------------------------------------

    #[test]
    fn config_rejects_invalid_bits() {
        assert!(TurboQuantConfig::new(1, TEST_BLOCK_SIZE).is_err());
        assert!(TurboQuantConfig::new(5, TEST_BLOCK_SIZE).is_err());
    }

    #[test]
    fn config_rejects_non_power_of_two() {
        assert!(TurboQuantConfig::new(BITS_TQ3, 33).is_err());
        assert!(TurboQuantConfig::new(BITS_TQ4, 0).is_err());
    }

    #[test]
    fn config_accepts_valid() {
        assert!(TurboQuantConfig::new(BITS_TQ2, TEST_BLOCK_SIZE).is_ok());
        assert!(TurboQuantConfig::new(BITS_TQ3, TEST_BLOCK_SIZE).is_ok());
        assert!(TurboQuantConfig::new(BITS_TQ4, TEST_DIM_128).is_ok());
    }

    // --- PackedBlock sizes ----------------------------------------------

    /// 32 indices x 2 bits = 64 bits = 8 bytes, plus the scale.
    const TQ2_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 8;

    /// 32 indices x 3 bits = 96 bits = 12 bytes, plus the scale.
    const TQ3_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 12;

    /// 32 indices x 4 bits = 128 bits = 16 bytes, plus the scale.
    const TQ4_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 16;

    /// 128 indices in 3-bit groups of 8, each group taking 3 bytes (= 48).
    const TQ3_D128_PACKED_LEN: usize = TEST_DIM_128 / PACK_3BIT_GROUP_SIZE * PACK_3BIT_BYTES;

    #[test]
    fn packed_block_tq3_size_bytes() {
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ3, f16::from_f32(1.0), &indices);
        assert_eq!(block.size_bytes(), TQ3_D32_EXPECTED_SIZE);
    }

    #[test]
    fn packed_block_tq4_size_bytes() {
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ4, f16::from_f32(1.0), &indices);
        assert_eq!(block.size_bytes(), TQ4_D32_EXPECTED_SIZE);
    }

    // --- single-group 2-bit pack/unpack ---------------------------------

    #[test]
    fn pack_unpack_2bit_identity() {
        let values: [u8; PACK_2BIT_GROUP_SIZE] = [0, 1, 2, MAX_2BIT_VALUE];
        let packed = pack_2bit(&values);
        let unpacked = unpack_2bit(packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_2bit_zeros() {
        let values = [0u8; PACK_2BIT_GROUP_SIZE];
        assert_eq!(unpack_2bit(pack_2bit(&values)), values);
    }

    #[test]
    fn pack_unpack_2bit_max() {
        let values = [MAX_2BIT_VALUE; PACK_2BIT_GROUP_SIZE];
        assert_eq!(unpack_2bit(pack_2bit(&values)), values);
    }

    // --- slice-level 2-bit round trips ----------------------------------

    #[test]
    fn roundtrip_2bit_exact_multiple() {
        let indices: Vec<u8> = (0..TEST_2BIT_EXACT_COUNT as u8)
            .map(|i| i % (MAX_2BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_2bit(&indices);
        let unpacked = unpack_indices_2bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_2bit_with_remainder() {
        let indices: Vec<u8> = (0..TEST_2BIT_REMAINDER_COUNT as u8)
            .map(|i| i % (MAX_2BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_2bit(&indices);
        let unpacked = unpack_indices_2bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn packed_block_tq2_size_bytes() {
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ2, f16::from_f32(1.0), &indices);
        assert_eq!(block.size_bytes(), TQ2_D32_EXPECTED_SIZE);
    }

    // --- PackedBlock raw bytes and reconstruction -----------------------

    #[test]
    fn packed_indices_returns_raw_bytes() {
        let indices = vec![1u8, 2, 3, 0, 1, 2, 3, 0];
        let block = PackedBlock::new(BITS_TQ2, f16::from_f32(TEST_SCALE), &indices);
        let raw = block.packed_indices;
        // Eight 2-bit indices fill exactly two bytes.
        assert_eq!(raw.len(), indices.len() / PACK_2BIT_GROUP_SIZE);
        // Packing the same input again must give the same bytes.
        let block2 = PackedBlock::new(BITS_TQ2, f16::from_f32(TEST_SCALE), &indices);
        assert_eq!(raw, block2.packed_indices);
    }

    #[test]
    fn packed_indices_3bit_length() {
        let indices = vec![0u8; TEST_DIM_128];
        let block = PackedBlock::new(BITS_TQ3, f16::from_f32(1.0), &indices);
        assert_eq!(block.packed_indices.len(), TQ3_D128_PACKED_LEN);
    }

    #[test]
    fn from_raw_roundtrip() {
        let indices = vec![3u8, 1, 0, 2, 3, 1, 0, 2];
        let original = PackedBlock::new(BITS_TQ2, f16::from_f32(2.0), &indices);
        let reconstructed = PackedBlock::from_raw(
            original.bits,
            original.scale,
            original.packed_indices.clone(),
        );
        assert_eq!(reconstructed.bits, original.bits);
        assert_eq!(reconstructed.scale, original.scale);
        assert_eq!(reconstructed.packed_indices, original.packed_indices);
        assert_eq!(reconstructed.unpack(indices.len()), indices);
    }

    #[test]
    fn from_raw_3bit_roundtrip() {
        let indices: Vec<u8> = (0..TEST_DIM_128)
            .map(|i| (i % TEST_3BIT_LEVELS) as u8)
            .collect();
        let original = PackedBlock::new(BITS_TQ3, f16::from_f32(TEST_SCALE_HALF), &indices);
        let reconstructed =
            PackedBlock::from_raw(BITS_TQ3, original.scale, original.packed_indices.clone());
        assert_eq!(reconstructed.unpack(TEST_DIM_128), indices);
    }

    // --- PackedBlock::unpack_into ---------------------------------------

    #[test]
    fn unpack_into_matches_unpack_for_every_width() {
        // Values 0..=3 are representable at every supported width, and 32
        // indices fill whole groups for 2, 3 and 4 bits alike.
        let indices: Vec<u8> = (0..TEST_BLOCK_SIZE)
            .map(|i| (i % (usize::from(MAX_2BIT_VALUE) + 1)) as u8)
            .collect();

        for &bits in &[BITS_TQ2, BITS_TQ3, BITS_TQ4] {
            let block = PackedBlock::new(bits, f16::from_f32(TEST_SCALE), &indices);

            // Start with a longer, non-zero buffer so leftover contents would
            // show up in the comparison.
            let mut buf = vec![MAX_4BIT_VALUE; TEST_DIM_128];
            block.unpack_into(indices.len(), &mut buf);

            assert_eq!(buf, indices);
            assert_eq!(buf, block.unpack(indices.len()));
        }
    }

    #[test]
    fn unpack_into_reuses_buffer_across_calls() {
        let first = vec![1u8, 2, 3, 0, 1, 2, 3, 0];
        let second = vec![3u8, 3, 0, 0];
        let block_a = PackedBlock::new(BITS_TQ2, f16::from_f32(TEST_SCALE), &first);
        let block_b = PackedBlock::new(BITS_TQ2, f16::from_f32(TEST_SCALE), &second);

        let mut buf = Vec::new();
        block_a.unpack_into(first.len(), &mut buf);
        assert_eq!(buf, first);

        // The second call must replace, not append to, the earlier contents.
        block_b.unpack_into(second.len(), &mut buf);
        assert_eq!(buf, second);
    }
}