tinyquant_io/codec_file/
reader.rs1use crate::codec_file::header::{
4 decode_header, CorpusFileHeader, MAX_CONFIG_HASH_LEN, MAX_METADATA_LEN,
5};
6use crate::compressed_vector::from_bytes;
7use crate::compressed_vector::header::HEADER_SIZE;
8use crate::errors::IoError;
9use std::io::{Read, Seek, SeekFrom};
10use tinyquant_core::codec::CompressedVector;
11
12const MAX_RECORD_LEN: usize = 4 * 1024 * 1024; pub struct CodecFileReader<R: Read + Seek> {
19 inner: R,
20 header: CorpusFileHeader,
21 records_read: u64,
22}
23
24impl<R: Read + Seek> CodecFileReader<R> {
25 pub fn new(mut inner: R) -> Result<Self, IoError> {
35 let header = read_and_decode_header(&mut inner)?;
36 let body_offset = header.body_offset;
37 inner.seek(SeekFrom::Start(
38 u64::try_from(body_offset).map_err(|_| IoError::InvalidHeader)?,
39 ))?;
40 Ok(Self {
41 inner,
42 header,
43 records_read: 0,
44 })
45 }
46
47 pub const fn header(&self) -> &CorpusFileHeader {
49 &self.header
50 }
51
52 pub fn next_vector(&mut self) -> Result<Option<CompressedVector>, IoError> {
60 if self.records_read >= self.header.vector_count {
61 return Ok(None);
62 }
63 let max_record_len = max_record_len_for_header(&self.header);
64 let cv = read_record(&mut self.inner, max_record_len)?;
65 self.records_read += 1;
66 Ok(Some(cv))
67 }
68
69 pub const fn records_read(&self) -> u64 {
71 self.records_read
72 }
73}
74
75fn read_and_decode_header<R: Read + Seek>(r: &mut R) -> Result<CorpusFileHeader, IoError> {
77 let mut fixed = [0u8; 24];
79 r.read_exact(&mut fixed)?;
80
81 let chl_bytes: [u8; 2] = fixed
83 .get(22..24)
84 .ok_or(IoError::Truncated {
85 needed: 24,
86 got: fixed.len(),
87 })?
88 .try_into()
89 .map_err(|_| IoError::InvalidHeader)?;
90 let config_hash_len = u16::from_le_bytes(chl_bytes) as usize;
91 if config_hash_len > MAX_CONFIG_HASH_LEN {
92 return Err(IoError::InvalidHeader);
93 }
94
95 let mut var_prefix = vec![0u8; config_hash_len + 4];
97 r.read_exact(&mut var_prefix)?;
98
99 let ml_bytes: [u8; 4] = var_prefix
101 .get(config_hash_len..config_hash_len + 4)
102 .ok_or(IoError::Truncated {
103 needed: config_hash_len + 4,
104 got: var_prefix.len(),
105 })?
106 .try_into()
107 .map_err(|_| IoError::InvalidHeader)?;
108 let metadata_len = u32::from_le_bytes(ml_bytes) as usize;
109 if metadata_len > MAX_METADATA_LEN {
110 return Err(IoError::InvalidHeader);
111 }
112
113 let header_end = 24_usize
115 .checked_add(config_hash_len)
116 .and_then(|n| n.checked_add(4))
117 .and_then(|n| n.checked_add(metadata_len))
118 .ok_or(IoError::InvalidHeader)?;
119 let body_offset = header_end.next_multiple_of(8);
120 let remaining_header = body_offset
121 .checked_sub(24 + config_hash_len + 4)
122 .ok_or(IoError::InvalidHeader)?;
123 let mut rest = vec![0u8; remaining_header];
124 r.read_exact(&mut rest)?;
125
126 let mut full = Vec::with_capacity(body_offset);
128 full.extend_from_slice(&fixed);
129 full.extend_from_slice(&var_prefix);
130 full.extend_from_slice(&rest);
131
132 decode_header(&full)
133}
134
135fn max_record_len_for_header(header: &CorpusFileHeader) -> usize {
141 let dim = header.dimension as usize;
142 let bw = header.bit_width as usize;
143 let packed = dim.saturating_mul(bw).saturating_add(7) / 8;
145 let residual = if header.residual {
147 4_usize.saturating_add(dim.saturating_mul(2))
148 } else {
149 0
150 };
151 HEADER_SIZE
152 .saturating_add(packed)
153 .saturating_add(1) .saturating_add(residual)
155 .min(MAX_RECORD_LEN)
156}
157
158fn read_record<R: Read>(r: &mut R, max_record_len: usize) -> Result<CompressedVector, IoError> {
160 let mut len_buf = [0u8; 4];
161 r.read_exact(&mut len_buf)?;
162 let record_len = u32::from_le_bytes(len_buf) as usize;
163 if record_len > max_record_len {
164 return Err(IoError::InvalidHeader);
165 }
166
167 let mut payload = vec![0u8; record_len];
168 r.read_exact(&mut payload)?;
169
170 from_bytes(&payload)
171}