Skip to main content

tinyquant_io/codec_file/
reader.rs

1//! Level-2 TQCV corpus file streaming reader (non-mmap).
2
3use crate::codec_file::header::{
4    decode_header, CorpusFileHeader, MAX_CONFIG_HASH_LEN, MAX_METADATA_LEN,
5};
6use crate::compressed_vector::from_bytes;
7use crate::compressed_vector::header::HEADER_SIZE;
8use crate::errors::IoError;
9use std::io::{Read, Seek, SeekFrom};
10use tinyquant_core::codec::CompressedVector;
11
/// Absolute upper bound, in bytes, on one Level-1 record (defense in depth).
///
/// Applied on top of any tighter, header-derived limit.
const MAX_RECORD_LEN: usize = 4 << 20; // 4 MiB
14
15/// Streaming reader for a Level-2 TQCV corpus file.
16///
17/// Reads records sequentially without memory-mapping the file.
18pub struct CodecFileReader<R: Read + Seek> {
19    inner: R,
20    header: CorpusFileHeader,
21    records_read: u64,
22}
23
24impl<R: Read + Seek> CodecFileReader<R> {
25    /// Open a TQCV corpus file from a seekable reader.
26    ///
27    /// Reads and validates the Level-2 header, then positions the reader
28    /// at the start of the body.
29    ///
30    /// # Errors
31    ///
32    /// Returns [`IoError`] if the header is malformed, the magic is wrong,
33    /// or the underlying I/O fails.
34    pub fn new(mut inner: R) -> Result<Self, IoError> {
35        let header = read_and_decode_header(&mut inner)?;
36        let body_offset = header.body_offset;
37        inner.seek(SeekFrom::Start(
38            u64::try_from(body_offset).map_err(|_| IoError::InvalidHeader)?,
39        ))?;
40        Ok(Self {
41            inner,
42            header,
43            records_read: 0,
44        })
45    }
46
47    /// Return a reference to the decoded Level-2 header.
48    pub const fn header(&self) -> &CorpusFileHeader {
49        &self.header
50    }
51
52    /// Read the next [`CompressedVector`] from the file.
53    ///
54    /// Returns `Ok(None)` when all `vector_count` records have been read.
55    ///
56    /// # Errors
57    ///
58    /// Returns [`IoError`] if the record is malformed or I/O fails.
59    pub fn next_vector(&mut self) -> Result<Option<CompressedVector>, IoError> {
60        if self.records_read >= self.header.vector_count {
61            return Ok(None);
62        }
63        let max_record_len = max_record_len_for_header(&self.header);
64        let cv = read_record(&mut self.inner, max_record_len)?;
65        self.records_read += 1;
66        Ok(Some(cv))
67    }
68
69    /// Number of records read so far.
70    pub const fn records_read(&self) -> u64 {
71        self.records_read
72    }
73}
74
75/// Read and validate the Level-2 header from `r`.
76fn read_and_decode_header<R: Read + Seek>(r: &mut R) -> Result<CorpusFileHeader, IoError> {
77    // Read the fixed 24-byte header first to find config_hash_len.
78    let mut fixed = [0u8; 24];
79    r.read_exact(&mut fixed)?;
80
81    // Peek at config_hash_len (bytes 22..24)
82    let chl_bytes: [u8; 2] = fixed
83        .get(22..24)
84        .ok_or(IoError::Truncated {
85            needed: 24,
86            got: fixed.len(),
87        })?
88        .try_into()
89        .map_err(|_| IoError::InvalidHeader)?;
90    let config_hash_len = u16::from_le_bytes(chl_bytes) as usize;
91    if config_hash_len > MAX_CONFIG_HASH_LEN {
92        return Err(IoError::InvalidHeader);
93    }
94
95    // Read config_hash + 4-byte metadata_len
96    let mut var_prefix = vec![0u8; config_hash_len + 4];
97    r.read_exact(&mut var_prefix)?;
98
99    // Peek at metadata_len
100    let ml_bytes: [u8; 4] = var_prefix
101        .get(config_hash_len..config_hash_len + 4)
102        .ok_or(IoError::Truncated {
103            needed: config_hash_len + 4,
104            got: var_prefix.len(),
105        })?
106        .try_into()
107        .map_err(|_| IoError::InvalidHeader)?;
108    let metadata_len = u32::from_le_bytes(ml_bytes) as usize;
109    if metadata_len > MAX_METADATA_LEN {
110        return Err(IoError::InvalidHeader);
111    }
112
113    // Read metadata + alignment padding
114    let header_end = 24_usize
115        .checked_add(config_hash_len)
116        .and_then(|n| n.checked_add(4))
117        .and_then(|n| n.checked_add(metadata_len))
118        .ok_or(IoError::InvalidHeader)?;
119    let body_offset = header_end.next_multiple_of(8);
120    let remaining_header = body_offset
121        .checked_sub(24 + config_hash_len + 4)
122        .ok_or(IoError::InvalidHeader)?;
123    let mut rest = vec![0u8; remaining_header];
124    r.read_exact(&mut rest)?;
125
126    // Reassemble for decode_header
127    let mut full = Vec::with_capacity(body_offset);
128    full.extend_from_slice(&fixed);
129    full.extend_from_slice(&var_prefix);
130    full.extend_from_slice(&rest);
131
132    decode_header(&full)
133}
134
135/// Compute the maximum legal Level-1 record size from validated header fields.
136///
137/// Derived from dimension, `bit_width`, and residual flag so a crafted corpus
138/// cannot trigger a record-sized allocation before `from_bytes` validates the
139/// payload.  `MAX_RECORD_LEN` is applied as a hard cap regardless.
140fn max_record_len_for_header(header: &CorpusFileHeader) -> usize {
141    let dim = header.dimension as usize;
142    let bw = header.bit_width as usize;
143    // packed indices: ceil(dim * bw / 8)
144    let packed = dim.saturating_mul(bw).saturating_add(7) / 8;
145    // optional residual: 4-byte length prefix + 2 bytes/element (fp16)
146    let residual = if header.residual {
147        4_usize.saturating_add(dim.saturating_mul(2))
148    } else {
149        0
150    };
151    HEADER_SIZE
152        .saturating_add(packed)
153        .saturating_add(1) // residual flag byte
154        .saturating_add(residual)
155        .min(MAX_RECORD_LEN)
156}
157
158/// Read one length-prefixed Level-1 record from `r`.
159fn read_record<R: Read>(r: &mut R, max_record_len: usize) -> Result<CompressedVector, IoError> {
160    let mut len_buf = [0u8; 4];
161    r.read_exact(&mut len_buf)?;
162    let record_len = u32::from_le_bytes(len_buf) as usize;
163    if record_len > max_record_len {
164        return Err(IoError::InvalidHeader);
165    }
166
167    let mut payload = vec![0u8; record_len];
168    r.read_exact(&mut payload)?;
169
170    from_bytes(&payload)
171}