tinyquant_io/zero_copy/
view.rs1use crate::compressed_vector::header::{decode_header, FORMAT_VERSION, HEADER_SIZE};
4use crate::compressed_vector::unpack::unpack_indices;
5use crate::errors::IoError;
6use std::sync::Arc;
7use tinyquant_core::codec::CompressedVector;
8
/// Bit widths accepted by `CompressedVectorView::parse`; a header carrying any
/// other width is rejected with `IoError::InvalidBitWidth`.
const SUPPORTED_BIT_WIDTHS: &[u8] = &[2, 4, 8];
10
/// Borrowed, non-owning view over one serialized compressed-vector record.
///
/// All reference fields share the lifetime `'a` of the byte buffer handed to
/// `parse`, so the view cannot outlive that buffer.
pub struct CompressedVectorView<'a> {
    /// Format version reported by the record header; `parse` only accepts
    /// `FORMAT_VERSION`.
    pub format_version: u8,
    /// Configuration hash string returned by `decode_header`.
    pub config_hash: &'a str,
    /// Number of indices in the vector (the length `unpack_into` requires of
    /// its output slice).
    pub dimension: u32,
    /// Bits used per packed index; one of `SUPPORTED_BIT_WIDTHS`.
    pub bit_width: u8,
    /// Bit-packed index bytes, `ceil(dimension * bit_width / 8)` bytes taken
    /// from the input immediately after the header.
    pub packed_indices: &'a [u8],
    /// Residual payload bytes when the record's residual flag is `0x01`,
    /// `None` when it is `0x00`.
    pub residual: Option<&'a [u8]>,
}
35
36impl<'a> CompressedVectorView<'a> {
37 pub fn parse(data: &'a [u8]) -> Result<(Self, &'a [u8]), IoError> {
46 let (version, config_hash, dimension, bit_width) = decode_header(data)?;
47 if version != FORMAT_VERSION {
48 return Err(IoError::UnknownVersion { got: version });
49 }
50 if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
51 return Err(IoError::InvalidBitWidth { got: bit_width });
52 }
53 let dim = dimension as usize;
54 let packed_len = (dim * bit_width as usize + 7) / 8;
55 let flag_offset = HEADER_SIZE + packed_len;
56 if data.len() < flag_offset + 1 {
57 return Err(IoError::Truncated {
58 needed: flag_offset + 1,
59 got: data.len(),
60 });
61 }
62 let packed_indices = data
63 .get(HEADER_SIZE..flag_offset)
64 .ok_or(IoError::Truncated {
65 needed: flag_offset,
66 got: data.len(),
67 })?;
68 let residual_flag = data.get(flag_offset).copied().ok_or(IoError::Truncated {
69 needed: flag_offset + 1,
70 got: data.len(),
71 })?;
72 let (residual, record_end) = parse_residual(data, flag_offset, residual_flag)?;
73 let tail = data.get(record_end..).ok_or(IoError::Truncated {
74 needed: record_end,
75 got: data.len(),
76 })?;
77 Ok((
78 Self {
79 format_version: version,
80 config_hash,
81 dimension,
82 bit_width,
83 packed_indices,
84 residual,
85 },
86 tail,
87 ))
88 }
89
90 pub fn unpack_into(&self, out: &mut [u8]) -> Result<(), IoError> {
96 if out.len() != self.dimension as usize {
97 return Err(IoError::LengthMismatch);
98 }
99 unpack_indices(
100 self.packed_indices,
101 self.dimension as usize,
102 self.bit_width,
103 out,
104 );
105 Ok(())
106 }
107
108 pub fn to_owned_cv(&self) -> Result<CompressedVector, IoError> {
114 let mut indices = vec![0u8; self.dimension as usize];
115 self.unpack_into(&mut indices)?;
116 let residual = self.residual.map(|r| r.to_vec().into_boxed_slice());
117 let cv = CompressedVector::new(
118 indices.into_boxed_slice(),
119 residual,
120 Arc::from(self.config_hash),
121 self.dimension,
122 self.bit_width,
123 )?;
124 Ok(cv)
125 }
126}
127
128fn parse_residual(
131 data: &[u8],
132 flag_offset: usize,
133 residual_flag: u8,
134) -> Result<(Option<&[u8]>, usize), IoError> {
135 match residual_flag {
136 0x00 => Ok((None, flag_offset + 1)),
137 0x01 => {
138 let rlen_start = flag_offset + 1;
139 let rlen_end = rlen_start + 4;
140 if data.len() < rlen_end {
141 return Err(IoError::Truncated {
142 needed: rlen_end,
143 got: data.len(),
144 });
145 }
146 let rlen_bytes: [u8; 4] = data
147 .get(rlen_start..rlen_end)
148 .ok_or(IoError::Truncated {
149 needed: rlen_end,
150 got: data.len(),
151 })?
152 .try_into()
153 .map_err(|_| IoError::Truncated {
154 needed: rlen_end,
155 got: data.len(),
156 })?;
157 let rlen = u32::from_le_bytes(rlen_bytes) as usize;
158 let rdata_end = rlen_end + rlen;
159 if data.len() < rdata_end {
160 return Err(IoError::Truncated {
161 needed: rdata_end,
162 got: data.len(),
163 });
164 }
165 let rdata = data.get(rlen_end..rdata_end).ok_or(IoError::Truncated {
166 needed: rdata_end,
167 got: data.len(),
168 })?;
169 Ok((Some(rdata), rdata_end))
170 }
171 got => Err(IoError::Decode(
172 tinyquant_core::errors::CodecError::InvalidResidualFlag { got },
173 )),
174 }
175}