// tinyquant_io/compressed_vector/from_bytes.rs

use std::sync::Arc;

use tinyquant_core::codec::CompressedVector;
use tinyquant_core::errors::CodecError;

use crate::compressed_vector::header::{decode_header, FORMAT_VERSION, HEADER_SIZE};
use crate::compressed_vector::unpack::unpack_indices;
use crate::errors::IoError;

9const SUPPORTED_BIT_WIDTHS: &[u8] = &[2, 4, 8];
10
11pub fn from_bytes(data: &[u8]) -> Result<CompressedVector, IoError> {
18 let (version, config_hash, dimension, bit_width) = decode_header(data)?;
19
20 if version != FORMAT_VERSION {
21 return Err(IoError::UnknownVersion { got: version });
22 }
23 if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
24 return Err(IoError::InvalidBitWidth { got: bit_width });
25 }
26
27 let dim = dimension as usize;
28 let packed_len = (dim * bit_width as usize + 7) / 8;
29 let payload_start = HEADER_SIZE;
30 let flag_offset = payload_start + packed_len;
31
32 if data.len() < flag_offset + 1 {
33 return Err(IoError::Truncated {
34 needed: flag_offset + 1,
35 got: data.len(),
36 });
37 }
38
39 let packed = data
41 .get(payload_start..flag_offset)
42 .ok_or(IoError::Truncated {
43 needed: flag_offset,
44 got: data.len(),
45 })?;
46 let mut indices = vec![0u8; dim];
47 unpack_indices(packed, dim, bit_width, &mut indices);
48
49 let residual_flag = data.get(flag_offset).copied().ok_or(IoError::Truncated {
51 needed: flag_offset + 1,
52 got: data.len(),
53 })?;
54
55 let residual: Option<Box<[u8]>> = match residual_flag {
56 0x00 => None,
57 0x01 => {
58 let rlen_start = flag_offset + 1;
59 let rlen_end = rlen_start + 4;
60 if data.len() < rlen_end {
61 return Err(IoError::Truncated {
62 needed: rlen_end,
63 got: data.len(),
64 });
65 }
66 let rlen_bytes: [u8; 4] = data
67 .get(rlen_start..rlen_end)
68 .ok_or(IoError::Truncated {
69 needed: rlen_end,
70 got: data.len(),
71 })?
72 .try_into()
73 .map_err(|_| IoError::Truncated {
74 needed: rlen_end,
75 got: data.len(),
76 })?;
77 let rlen = u32::from_le_bytes(rlen_bytes) as usize;
78 let rdata_start = rlen_end;
79 let rdata_end = rdata_start + rlen;
80 if data.len() < rdata_end {
81 return Err(IoError::Truncated {
82 needed: rdata_end,
83 got: data.len(),
84 });
85 }
86 let rdata = data.get(rdata_start..rdata_end).ok_or(IoError::Truncated {
87 needed: rdata_end,
88 got: data.len(),
89 })?;
90 Some(rdata.to_vec().into_boxed_slice())
91 }
92 got => {
93 return Err(IoError::Decode(CodecError::InvalidResidualFlag { got }));
94 }
95 };
96
97 let cv = CompressedVector::new(
98 indices.into_boxed_slice(),
99 residual,
100 std::sync::Arc::from(config_hash),
101 dimension,
102 bit_width,
103 )?;
104 Ok(cv)
105}