1use crate::errors::IoError;
17
/// Magic bytes that `decode_header` accepts as the start of a corpus file.
pub const MAGIC_FINAL: &[u8; 4] = b"TQCV";
/// Magic bytes written by `encode_header`.
///
/// `decode_header` reports a file starting with these bytes as truncated.
/// NOTE(review): presumably this marks a file whose writer has not finished yet — confirm.
pub const MAGIC_TENTATIVE: &[u8; 4] = b"TQCX";
/// Header format version written by `encode_header` and required by `decode_header`.
pub const FORMAT_VERSION: u8 = 1;
/// Size in bytes of the fixed part of the header.
///
/// The fields are: magic (4), version (1), reserved (3), vector count (8),
/// dimension (4), bit width (1), residual flag (1) and config-hash length (2).
pub const FIXED_HEADER_SIZE: usize = 4 + 1 + 3 + 8 + 4 + 1 + 1 + 2;
/// Longest config hash, in bytes, that may appear in a header.
const MAX_CONFIG_HASH_LEN: usize = 256;
/// Largest metadata payload, in bytes (16 MiB), accepted by `encode_header`.
const MAX_METADATA_LEN: usize = 1 << 24;
/// Bit widths permitted in the header's bit-width field.
const SUPPORTED_BIT_WIDTHS: &[u8] = &[2, 4, 8];
30
/// A corpus file header as produced by `decode_header`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorpusFileHeader {
    /// Number of vectors recorded in the header (bytes `8..16`, little-endian).
    pub vector_count: u64,
    /// Vector dimension (bytes `16..20`, little-endian); never zero in a decoded header.
    pub dimension: u32,
    /// Bit width (byte `20`); one of `SUPPORTED_BIT_WIDTHS`.
    pub bit_width: u8,
    /// Residual flag (byte `21`, stored as `0` or `1`).
    pub residual: bool,
    /// Config hash, at most `MAX_CONFIG_HASH_LEN` bytes of UTF-8.
    pub config_hash: String,
    /// Opaque metadata bytes stored after the config hash.
    pub metadata: Vec<u8>,
    /// Byte offset at which the body begins.
    ///
    /// This is the end of the metadata, rounded up to a multiple of 8.
    pub body_offset: usize,
}
48
49pub fn encode_header(
59 config_hash: &str,
60 dimension: u32,
61 bit_width: u8,
62 residual: bool,
63 metadata: &[u8],
64 vector_count: u64,
65) -> Result<Vec<u8>, IoError> {
66 let hash_bytes = config_hash.as_bytes();
67 if hash_bytes.len() > MAX_CONFIG_HASH_LEN {
68 return Err(IoError::InvalidHeader);
69 }
70 if metadata.len() > MAX_METADATA_LEN {
71 return Err(IoError::InvalidHeader);
72 }
73 if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
74 return Err(IoError::InvalidBitWidth { got: bit_width });
75 }
76 if dimension == 0 {
77 return Err(IoError::InvalidHeader);
78 }
79
80 #[allow(clippy::cast_possible_truncation)]
81 let config_hash_len = hash_bytes.len() as u16;
82 #[allow(clippy::cast_possible_truncation)]
83 let metadata_len = metadata.len() as u32;
84
85 let capacity = FIXED_HEADER_SIZE + hash_bytes.len() + 4 + metadata.len() + 7;
86 let mut out = Vec::with_capacity(capacity);
87
88 out.extend_from_slice(MAGIC_TENTATIVE); out.push(FORMAT_VERSION); out.extend_from_slice(&[0u8; 3]); out.extend_from_slice(&vector_count.to_le_bytes()); out.extend_from_slice(&dimension.to_le_bytes()); out.push(bit_width); out.push(u8::from(residual)); out.extend_from_slice(&config_hash_len.to_le_bytes()); out.extend_from_slice(hash_bytes);
100 out.extend_from_slice(&metadata_len.to_le_bytes());
101 out.extend_from_slice(metadata);
102
103 let pad = (8 - (out.len() % 8)) % 8;
105 out.resize(out.len() + pad, 0x00);
106 Ok(out)
107}
108
109fn decode_fixed_header(data: &[u8]) -> Result<(u64, u32, u8, bool, usize), IoError> {
112 let magic: [u8; 4] = data
113 .get(0..4)
114 .ok_or(IoError::Truncated {
115 needed: 4,
116 got: data.len(),
117 })?
118 .try_into()
119 .map_err(|_| IoError::Truncated {
120 needed: 4,
121 got: data.len(),
122 })?;
123
124 match &magic {
125 m if m == MAGIC_FINAL => {}
126 m if m == MAGIC_TENTATIVE => {
127 return Err(IoError::Truncated {
132 needed: data.len() + 1,
133 got: data.len(),
134 });
135 }
136 _ => return Err(IoError::BadMagic { got: magic }),
137 }
138
139 let version = data.get(4).copied().ok_or(IoError::Truncated {
140 needed: 5,
141 got: data.len(),
142 })?;
143 if version != FORMAT_VERSION {
144 return Err(IoError::UnknownVersion { got: version });
145 }
146
147 let reserved = data.get(5..8).ok_or(IoError::Truncated {
148 needed: 8,
149 got: data.len(),
150 })?;
151 if reserved != [0u8, 0, 0] {
152 return Err(IoError::InvalidHeader);
153 }
154
155 let vc_bytes: [u8; 8] = data
156 .get(8..16)
157 .ok_or(IoError::Truncated {
158 needed: 16,
159 got: data.len(),
160 })?
161 .try_into()
162 .map_err(|_| IoError::Truncated {
163 needed: 16,
164 got: data.len(),
165 })?;
166 let vector_count = u64::from_le_bytes(vc_bytes);
167
168 let dim_bytes: [u8; 4] = data
169 .get(16..20)
170 .ok_or(IoError::Truncated {
171 needed: 20,
172 got: data.len(),
173 })?
174 .try_into()
175 .map_err(|_| IoError::Truncated {
176 needed: 20,
177 got: data.len(),
178 })?;
179 let dimension = u32::from_le_bytes(dim_bytes);
180 if dimension == 0 {
181 return Err(IoError::InvalidHeader);
182 }
183
184 let bit_width = data.get(20).copied().ok_or(IoError::Truncated {
185 needed: 21,
186 got: data.len(),
187 })?;
188 if !SUPPORTED_BIT_WIDTHS.contains(&bit_width) {
189 return Err(IoError::InvalidBitWidth { got: bit_width });
190 }
191
192 let residual_flag = data.get(21).copied().ok_or(IoError::Truncated {
193 needed: 22,
194 got: data.len(),
195 })?;
196 let residual = match residual_flag {
197 0x00 => false,
198 0x01 => true,
199 _ => return Err(IoError::InvalidHeader),
200 };
201
202 let chl_bytes: [u8; 2] = data
203 .get(22..24)
204 .ok_or(IoError::Truncated {
205 needed: 24,
206 got: data.len(),
207 })?
208 .try_into()
209 .map_err(|_| IoError::Truncated {
210 needed: 24,
211 got: data.len(),
212 })?;
213 let config_hash_len = u16::from_le_bytes(chl_bytes) as usize;
214 if config_hash_len > MAX_CONFIG_HASH_LEN {
215 return Err(IoError::InvalidHeader);
216 }
217
218 Ok((
219 vector_count,
220 dimension,
221 bit_width,
222 residual,
223 config_hash_len,
224 ))
225}
226
227pub fn decode_header(data: &[u8]) -> Result<CorpusFileHeader, IoError> {
234 if data.len() < FIXED_HEADER_SIZE {
235 return Err(IoError::Truncated {
236 needed: FIXED_HEADER_SIZE,
237 got: data.len(),
238 });
239 }
240
241 let (vector_count, dimension, bit_width, residual, config_hash_len) =
242 decode_fixed_header(data)?;
243
244 let (config_hash, metadata, body_offset) =
245 decode_variable_prefix(data, FIXED_HEADER_SIZE, config_hash_len)?;
246
247 Ok(CorpusFileHeader {
248 vector_count,
249 dimension,
250 bit_width,
251 residual,
252 config_hash,
253 metadata,
254 body_offset,
255 })
256}
257
258fn decode_variable_prefix(
260 data: &[u8],
261 start: usize,
262 config_hash_len: usize,
263) -> Result<(String, Vec<u8>, usize), IoError> {
264 let mut pos = start;
265
266 let hash_end = pos + config_hash_len;
267 if data.len() < hash_end {
268 return Err(IoError::Truncated {
269 needed: hash_end,
270 got: data.len(),
271 });
272 }
273 let hash_bytes = data.get(pos..hash_end).ok_or(IoError::Truncated {
274 needed: hash_end,
275 got: data.len(),
276 })?;
277 let config_hash = std::str::from_utf8(hash_bytes)
278 .map_err(|_| IoError::InvalidUtf8)?
279 .to_owned();
280 pos = hash_end;
281
282 let ml_end = pos + 4;
283 if data.len() < ml_end {
284 return Err(IoError::Truncated {
285 needed: ml_end,
286 got: data.len(),
287 });
288 }
289 let ml_bytes: [u8; 4] = data
290 .get(pos..ml_end)
291 .ok_or(IoError::Truncated {
292 needed: ml_end,
293 got: data.len(),
294 })?
295 .try_into()
296 .map_err(|_| IoError::Truncated {
297 needed: ml_end,
298 got: data.len(),
299 })?;
300 let metadata_len = u32::from_le_bytes(ml_bytes) as usize;
301 pos = ml_end;
302
303 let meta_end = pos + metadata_len;
304 if data.len() < meta_end {
305 return Err(IoError::Truncated {
306 needed: meta_end,
307 got: data.len(),
308 });
309 }
310 let metadata = data
311 .get(pos..meta_end)
312 .ok_or(IoError::Truncated {
313 needed: meta_end,
314 got: data.len(),
315 })?
316 .to_vec();
317 pos = meta_end;
318
319 let body_offset = ((pos + 7) / 8) * 8;
320
321 Ok((config_hash, metadata, body_offset))
322}