Skip to main content

tinyquant_io/zero_copy/
view.rs

1//! Zero-copy view of a Level-1 serialized `CompressedVector`.
2
3use crate::compressed_vector::header::{decode_header, FORMAT_VERSION, HEADER_SIZE};
4use crate::compressed_vector::unpack::unpack_indices;
5use crate::errors::IoError;
6use std::sync::Arc;
7use tinyquant_core::codec::CompressedVector;
8
/// Bit widths accepted by [`CompressedVectorView::parse`]; a header carrying
/// any other value is rejected with [`IoError::InvalidBitWidth`].
const SUPPORTED_BIT_WIDTHS: &[u8] = &[2, 4, 8];
10
/// A zero-copy view of a single Level-1 serialized [`CompressedVector`].
///
/// The string and byte-slice fields (`config_hash`, `packed_indices`,
/// `residual`) borrow from the input byte slice for the lifetime `'a`; the
/// remaining fields are small scalars copied out of the header. The view is
/// designed so that no heap allocation is needed by
/// [`parse`](CompressedVectorView::parse) or
/// [`unpack_into`](CompressedVectorView::unpack_into).
///
/// Because it only holds scalars and shared borrows, the view is cheap to
/// copy and is therefore `Copy`; it is also `Debug` so it can be logged and
/// used in assertion messages.
///
/// The view is deliberately *not* serializable:
///
/// ```compile_fail
/// fn assert_not_serialize<T: serde::Serialize>() {}
/// assert_not_serialize::<tinyquant_io::CompressedVectorView<'_>>();
/// ```
#[derive(Debug, Clone, Copy)]
pub struct CompressedVectorView<'a> {
    /// Format version byte from the Level-1 header.
    pub format_version: u8,
    /// Config hash string borrowed from the Level-1 header.
    pub config_hash: &'a str,
    /// Number of dimensions, i.e. the number of indices packed in
    /// `packed_indices`.
    pub dimension: u32,
    /// Bit width of each packed index (2, 4, or 8).
    pub bit_width: u8,
    /// Packed index bytes (LSB-first), `ceil(dimension * bit_width / 8)` bytes
    /// when produced by `parse`.
    pub packed_indices: &'a [u8],
    /// Residual bytes exactly as stored in the record, or `None` when the
    /// record carries no residual section. Expected to be 2 bytes per
    /// dimension; `parse` itself does not check that length.
    pub residual: Option<&'a [u8]>,
}
35
36impl<'a> CompressedVectorView<'a> {
37    /// Parse a single Level-1 record from the start of `data`.
38    ///
39    /// Returns `(view, unconsumed_tail)`.
40    ///
41    /// # Errors
42    ///
43    /// Returns [`IoError`] if the data is truncated, has an unknown version,
44    /// an invalid bit-width, or a bad residual flag.
45    pub fn parse(data: &'a [u8]) -> Result<(Self, &'a [u8]), IoError> {
46        let (version, config_hash, dimension, bit_width) = decode_header(data)?;
47        if version != FORMAT_VERSION {
48            return Err(IoError::UnknownVersion { got: version });
49        }
50        if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
51            return Err(IoError::InvalidBitWidth { got: bit_width });
52        }
53        let dim = dimension as usize;
54        let packed_len = (dim * bit_width as usize + 7) / 8;
55        let flag_offset = HEADER_SIZE + packed_len;
56        if data.len() < flag_offset + 1 {
57            return Err(IoError::Truncated {
58                needed: flag_offset + 1,
59                got: data.len(),
60            });
61        }
62        let packed_indices = data
63            .get(HEADER_SIZE..flag_offset)
64            .ok_or(IoError::Truncated {
65                needed: flag_offset,
66                got: data.len(),
67            })?;
68        let residual_flag = data.get(flag_offset).copied().ok_or(IoError::Truncated {
69            needed: flag_offset + 1,
70            got: data.len(),
71        })?;
72        let (residual, record_end) = parse_residual(data, flag_offset, residual_flag)?;
73        let tail = data.get(record_end..).ok_or(IoError::Truncated {
74            needed: record_end,
75            got: data.len(),
76        })?;
77        Ok((
78            Self {
79                format_version: version,
80                config_hash,
81                dimension,
82                bit_width,
83                packed_indices,
84                residual,
85            },
86            tail,
87        ))
88    }
89
90    /// Unpack indices into caller-provided buffer (no allocation).
91    ///
92    /// # Errors
93    ///
94    /// Returns [`IoError::LengthMismatch`] if `out.len() != self.dimension as usize`.
95    pub fn unpack_into(&self, out: &mut [u8]) -> Result<(), IoError> {
96        if out.len() != self.dimension as usize {
97            return Err(IoError::LengthMismatch);
98        }
99        unpack_indices(
100            self.packed_indices,
101            self.dimension as usize,
102            self.bit_width,
103            out,
104        );
105        Ok(())
106    }
107
108    /// Allocating escape hatch — copies borrowed bytes into an owned [`CompressedVector`].
109    ///
110    /// # Errors
111    ///
112    /// Returns [`IoError`] if unpacking or construction fails.
113    pub fn to_owned_cv(&self) -> Result<CompressedVector, IoError> {
114        let mut indices = vec![0u8; self.dimension as usize];
115        self.unpack_into(&mut indices)?;
116        let residual = self.residual.map(|r| r.to_vec().into_boxed_slice());
117        let cv = CompressedVector::new(
118            indices.into_boxed_slice(),
119            residual,
120            Arc::from(self.config_hash),
121            self.dimension,
122            self.bit_width,
123        )?;
124        Ok(cv)
125    }
126}
127
128/// Parse the residual section starting at `flag_offset` in `data`.
129/// Returns `(residual_slice_or_none, record_end)`.
130fn parse_residual(
131    data: &[u8],
132    flag_offset: usize,
133    residual_flag: u8,
134) -> Result<(Option<&[u8]>, usize), IoError> {
135    match residual_flag {
136        0x00 => Ok((None, flag_offset + 1)),
137        0x01 => {
138            let rlen_start = flag_offset + 1;
139            let rlen_end = rlen_start + 4;
140            if data.len() < rlen_end {
141                return Err(IoError::Truncated {
142                    needed: rlen_end,
143                    got: data.len(),
144                });
145            }
146            let rlen_bytes: [u8; 4] = data
147                .get(rlen_start..rlen_end)
148                .ok_or(IoError::Truncated {
149                    needed: rlen_end,
150                    got: data.len(),
151                })?
152                .try_into()
153                .map_err(|_| IoError::Truncated {
154                    needed: rlen_end,
155                    got: data.len(),
156                })?;
157            let rlen = u32::from_le_bytes(rlen_bytes) as usize;
158            let rdata_end = rlen_end + rlen;
159            if data.len() < rdata_end {
160                return Err(IoError::Truncated {
161                    needed: rdata_end,
162                    got: data.len(),
163                });
164            }
165            let rdata = data.get(rlen_end..rdata_end).ok_or(IoError::Truncated {
166                needed: rdata_end,
167                got: data.len(),
168            })?;
169            Ok((Some(rdata), rdata_end))
170        }
171        got => Err(IoError::Decode(
172            tinyquant_core::errors::CodecError::InvalidResidualFlag { got },
173        )),
174    }
175}