Skip to main content

tinyquant_io/compressed_vector/
from_bytes.rs

1//! Decode a `CompressedVector` from its binary wire format.
2
3use crate::compressed_vector::header::{decode_header, FORMAT_VERSION, HEADER_SIZE};
4use crate::compressed_vector::unpack::unpack_indices;
5use crate::errors::IoError;
6use tinyquant_core::codec::CompressedVector;
7use tinyquant_core::errors::CodecError;
8
/// Index bit-widths that `from_bytes` accepts; a header carrying any other
/// value is rejected with `IoError::InvalidBitWidth`.
const SUPPORTED_BIT_WIDTHS: &[u8] = &[2, 4, 8];
10
11/// Deserialize a `CompressedVector` from `data`.
12///
13/// # Errors
14///
15/// Returns [`IoError`] if the data is malformed, truncated, has an unknown
16/// version, or contains an invalid bit-width or residual flag.
17pub fn from_bytes(data: &[u8]) -> Result<CompressedVector, IoError> {
18    let (version, config_hash, dimension, bit_width) = decode_header(data)?;
19
20    if version != FORMAT_VERSION {
21        return Err(IoError::UnknownVersion { got: version });
22    }
23    if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
24        return Err(IoError::InvalidBitWidth { got: bit_width });
25    }
26
27    let dim = dimension as usize;
28    let bw = bit_width as usize;
29    let packed_len = dim
30        .checked_mul(bw)
31        .and_then(|n| n.checked_add(7))
32        .map(|n| n / 8)
33        .ok_or(IoError::InvalidHeader)?;
34    let payload_start = HEADER_SIZE;
35    let flag_offset = payload_start + packed_len;
36
37    if data.len() < flag_offset + 1 {
38        return Err(IoError::Truncated {
39            needed: flag_offset + 1,
40            got: data.len(),
41        });
42    }
43
44    // Unpack indices
45    let packed = data
46        .get(payload_start..flag_offset)
47        .ok_or(IoError::Truncated {
48            needed: flag_offset,
49            got: data.len(),
50        })?;
51    let mut indices = vec![0u8; dim];
52    unpack_indices(packed, dim, bit_width, &mut indices);
53
54    // Residual flag
55    let residual_flag = data.get(flag_offset).copied().ok_or(IoError::Truncated {
56        needed: flag_offset + 1,
57        got: data.len(),
58    })?;
59
60    let residual: Option<Box<[u8]>> = match residual_flag {
61        0x00 => None,
62        0x01 => {
63            let rlen_start = flag_offset + 1;
64            let rlen_end = rlen_start + 4;
65            if data.len() < rlen_end {
66                return Err(IoError::Truncated {
67                    needed: rlen_end,
68                    got: data.len(),
69                });
70            }
71            let rlen_bytes: [u8; 4] = data
72                .get(rlen_start..rlen_end)
73                .ok_or(IoError::Truncated {
74                    needed: rlen_end,
75                    got: data.len(),
76                })?
77                .try_into()
78                .map_err(|_| IoError::Truncated {
79                    needed: rlen_end,
80                    got: data.len(),
81                })?;
82            let rlen = u32::from_le_bytes(rlen_bytes) as usize;
83            let rdata_start = rlen_end;
84            let rdata_end = rdata_start + rlen;
85            if data.len() < rdata_end {
86                return Err(IoError::Truncated {
87                    needed: rdata_end,
88                    got: data.len(),
89                });
90            }
91            let rdata = data.get(rdata_start..rdata_end).ok_or(IoError::Truncated {
92                needed: rdata_end,
93                got: data.len(),
94            })?;
95            Some(rdata.to_vec().into_boxed_slice())
96        }
97        got => {
98            return Err(IoError::Decode(CodecError::InvalidResidualFlag { got }));
99        }
100    };
101
102    let cv = CompressedVector::new(
103        indices.into_boxed_slice(),
104        residual,
105        std::sync::Arc::from(config_hash),
106        dimension,
107        bit_width,
108    )?;
109    Ok(cv)
110}