Skip to main content

tinyquant_io/compressed_vector/
from_bytes.rs

//! Decode a `CompressedVector` from its binary wire format.

3use crate::compressed_vector::header::{decode_header, FORMAT_VERSION, HEADER_SIZE};
4use crate::compressed_vector::unpack::unpack_indices;
5use crate::errors::IoError;
6use tinyquant_core::codec::CompressedVector;
7use tinyquant_core::errors::CodecError;
8
/// Index bit-widths accepted by [`from_bytes`]; a header carrying any other
/// value is rejected with `IoError::InvalidBitWidth`.
const SUPPORTED_BIT_WIDTHS: &[u8] = &[2, 4, 8];
10
11/// Deserialize a `CompressedVector` from `data`.
12///
13/// # Errors
14///
15/// Returns [`IoError`] if the data is malformed, truncated, has an unknown
16/// version, or contains an invalid bit-width or residual flag.
17pub fn from_bytes(data: &[u8]) -> Result<CompressedVector, IoError> {
18    let (version, config_hash, dimension, bit_width) = decode_header(data)?;
19
20    if version != FORMAT_VERSION {
21        return Err(IoError::UnknownVersion { got: version });
22    }
23    if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
24        return Err(IoError::InvalidBitWidth { got: bit_width });
25    }
26
27    let dim = dimension as usize;
28    let packed_len = (dim * bit_width as usize + 7) / 8;
29    let payload_start = HEADER_SIZE;
30    let flag_offset = payload_start + packed_len;
31
32    if data.len() < flag_offset + 1 {
33        return Err(IoError::Truncated {
34            needed: flag_offset + 1,
35            got: data.len(),
36        });
37    }
38
39    // Unpack indices
40    let packed = data
41        .get(payload_start..flag_offset)
42        .ok_or(IoError::Truncated {
43            needed: flag_offset,
44            got: data.len(),
45        })?;
46    let mut indices = vec![0u8; dim];
47    unpack_indices(packed, dim, bit_width, &mut indices);
48
49    // Residual flag
50    let residual_flag = data.get(flag_offset).copied().ok_or(IoError::Truncated {
51        needed: flag_offset + 1,
52        got: data.len(),
53    })?;
54
55    let residual: Option<Box<[u8]>> = match residual_flag {
56        0x00 => None,
57        0x01 => {
58            let rlen_start = flag_offset + 1;
59            let rlen_end = rlen_start + 4;
60            if data.len() < rlen_end {
61                return Err(IoError::Truncated {
62                    needed: rlen_end,
63                    got: data.len(),
64                });
65            }
66            let rlen_bytes: [u8; 4] = data
67                .get(rlen_start..rlen_end)
68                .ok_or(IoError::Truncated {
69                    needed: rlen_end,
70                    got: data.len(),
71                })?
72                .try_into()
73                .map_err(|_| IoError::Truncated {
74                    needed: rlen_end,
75                    got: data.len(),
76                })?;
77            let rlen = u32::from_le_bytes(rlen_bytes) as usize;
78            let rdata_start = rlen_end;
79            let rdata_end = rdata_start + rlen;
80            if data.len() < rdata_end {
81                return Err(IoError::Truncated {
82                    needed: rdata_end,
83                    got: data.len(),
84                });
85            }
86            let rdata = data.get(rdata_start..rdata_end).ok_or(IoError::Truncated {
87                needed: rdata_end,
88                got: data.len(),
89            })?;
90            Some(rdata.to_vec().into_boxed_slice())
91        }
92        got => {
93            return Err(IoError::Decode(CodecError::InvalidResidualFlag { got }));
94        }
95    };
96
97    let cv = CompressedVector::new(
98        indices.into_boxed_slice(),
99        residual,
100        std::sync::Arc::from(config_hash),
101        dimension,
102        bit_width,
103    )?;
104    Ok(cv)
105}