tinyquant_io/compressed_vector/
from_bytes.rs1use crate::compressed_vector::header::{decode_header, FORMAT_VERSION, HEADER_SIZE};
4use crate::compressed_vector::unpack::unpack_indices;
5use crate::errors::IoError;
6use tinyquant_core::codec::CompressedVector;
7use tinyquant_core::errors::CodecError;
8
/// Quantization bit widths that [`from_bytes`] accepts; a header carrying any
/// other bit width is rejected with `IoError::InvalidBitWidth`.
const SUPPORTED_BIT_WIDTHS: [u8; 3] = [2, 4, 8];
10
11pub fn from_bytes(data: &[u8]) -> Result<CompressedVector, IoError> {
18 let (version, config_hash, dimension, bit_width) = decode_header(data)?;
19
20 if version != FORMAT_VERSION {
21 return Err(IoError::UnknownVersion { got: version });
22 }
23 if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
24 return Err(IoError::InvalidBitWidth { got: bit_width });
25 }
26
27 let dim = dimension as usize;
28 let bw = bit_width as usize;
29 let packed_len = dim
30 .checked_mul(bw)
31 .and_then(|n| n.checked_add(7))
32 .map(|n| n / 8)
33 .ok_or(IoError::InvalidHeader)?;
34 let payload_start = HEADER_SIZE;
35 let flag_offset = payload_start + packed_len;
36
37 if data.len() < flag_offset + 1 {
38 return Err(IoError::Truncated {
39 needed: flag_offset + 1,
40 got: data.len(),
41 });
42 }
43
44 let packed = data
46 .get(payload_start..flag_offset)
47 .ok_or(IoError::Truncated {
48 needed: flag_offset,
49 got: data.len(),
50 })?;
51 let mut indices = vec![0u8; dim];
52 unpack_indices(packed, dim, bit_width, &mut indices);
53
54 let residual_flag = data.get(flag_offset).copied().ok_or(IoError::Truncated {
56 needed: flag_offset + 1,
57 got: data.len(),
58 })?;
59
60 let residual: Option<Box<[u8]>> = match residual_flag {
61 0x00 => None,
62 0x01 => {
63 let rlen_start = flag_offset + 1;
64 let rlen_end = rlen_start + 4;
65 if data.len() < rlen_end {
66 return Err(IoError::Truncated {
67 needed: rlen_end,
68 got: data.len(),
69 });
70 }
71 let rlen_bytes: [u8; 4] = data
72 .get(rlen_start..rlen_end)
73 .ok_or(IoError::Truncated {
74 needed: rlen_end,
75 got: data.len(),
76 })?
77 .try_into()
78 .map_err(|_| IoError::Truncated {
79 needed: rlen_end,
80 got: data.len(),
81 })?;
82 let rlen = u32::from_le_bytes(rlen_bytes) as usize;
83 let rdata_start = rlen_end;
84 let rdata_end = rdata_start + rlen;
85 if data.len() < rdata_end {
86 return Err(IoError::Truncated {
87 needed: rdata_end,
88 got: data.len(),
89 });
90 }
91 let rdata = data.get(rdata_start..rdata_end).ok_or(IoError::Truncated {
92 needed: rdata_end,
93 got: data.len(),
94 })?;
95 Some(rdata.to_vec().into_boxed_slice())
96 }
97 got => {
98 return Err(IoError::Decode(CodecError::InvalidResidualFlag { got }));
99 }
100 };
101
102 let cv = CompressedVector::new(
103 indices.into_boxed_slice(),
104 residual,
105 std::sync::Arc::from(config_hash),
106 dimension,
107 bit_width,
108 )?;
109 Ok(cv)
110}