Skip to main content

sochdb_vector/segment/
reader.rs

1//! Segment reader with mmap support.
2
3use memmap2::Mmap;
4use std::fs::File;
5use std::path::Path;
6use std::sync::Arc;
7
8use super::format::*;
9use crate::error::{Error, Result};
10use crate::types::*;
11
12/// Bounds-check every segment section's `[offset, offset+size)` against the real
13/// mmap length, with overflow-checked size arithmetic. Rejects crafted segments
14/// before any `from_raw_parts` accessor can read out of bounds.
15fn validate_segment_layout(h: &SegmentHeader, mmap_len: usize) -> Result<()> {
16    // A section of `bytes` at `off` must lie fully within the mapping.
17    let fits = |name: &str, off: u64, bytes: usize| -> Result<()> {
18        let off = off as usize;
19        let end = off
20            .checked_add(bytes)
21            .ok_or_else(|| Error::Segment(format!("segment section '{name}' size overflow")))?;
22        if end > mmap_len {
23            return Err(Error::Segment(format!(
24                "segment section '{name}' [{off}..{end}) exceeds file length {mmap_len}"
25            )));
26        }
27        Ok(())
28    };
29    // Overflow-checked element-count -> byte-size.
30    let bytes = |count: usize, elem: usize, name: &str| -> Result<usize> {
31        count
32            .checked_mul(elem)
33            .ok_or_else(|| Error::Segment(format!("segment section '{name}' byte-size overflow")))
34    };
35
36    let n_vec = h.n_vec as usize;
37    let dim = h.dim as usize;
38    let blocks = h.num_bps_blocks() as usize;
39
40    // Always-read sections (the accessors read these unconditionally).
41    fits("bps", h.off_bps, h.bps_size())?;
42    fits("i8", h.off_i8, bytes(n_vec, dim, "i8")?)?; // i8 = 1 byte
43    fits(
44        "scales",
45        h.off_scales,
46        bytes(bytes(blocks, n_vec, "scales")?, 4, "scales")?, // f32
47    )?;
48    fits(
49        "tombstone",
50        h.off_tombstone,
51        bytes(n_vec.div_ceil(64), 8, "tombstone")?, // u64 words
52    )?;
53
54    // Flag-gated optional sections.
55    if h.flags.has(SegmentFlags::HAS_OUTLIERS) {
56        let cnt = bytes(n_vec, h.num_outliers as usize, "outliers")?;
57        fits(
58            "outliers",
59            h.off_outliers,
60            bytes(cnt, std::mem::size_of::<OutlierEntry>(), "outliers")?,
61        )?;
62    }
63    if h.flags.has(SegmentFlags::HAS_RDF) {
64        fits(
65            "rdf_dir",
66            h.off_rdf_dir,
67            bytes(dim, std::mem::size_of::<PostingListEntry>(), "rdf_dir")?,
68        )?;
69        fits(
70            "dim_weights",
71            h.off_dim_weights,
72            bytes(dim, 4, "dim_weights")?,
73        )?;
74        // rdf_data is variable-length (posting lists indexed via the directory);
75        // bound the base offset here. Per-posting offsets are still consumed via
76        // the directory and should be validated at access time.
77        if h.off_rdf_data as usize > mmap_len {
78            return Err(Error::Segment(
79                "segment section 'rdf_data' offset exceeds file".into(),
80            ));
81        }
82    }
83    if h.flags.has(SegmentFlags::HAS_FP32) {
84        fits(
85            "fp32",
86            h.off_fp32,
87            bytes(bytes(n_vec, dim, "fp32")?, 4, "fp32")?,
88        )?;
89    }
90    if h.off_bps_qparams != 0 {
91        let cnt = bytes(blocks, h.bps_proj as usize, "bps_qparams")?;
92        fits(
93            "bps_qparams",
94            h.off_bps_qparams,
95            bytes(
96                cnt,
97                std::mem::size_of::<super::bps::BpsQParam>(),
98                "bps_qparams",
99            )?,
100        )?;
101    }
102    Ok(())
103}
104
105/// An immutable segment backed by mmap
106pub struct Segment {
107    /// Memory-mapped file
108    mmap: Arc<Mmap>,
109    /// Parsed header
110    header: SegmentHeader,
111    /// File path
112    path: String,
113}
114
115impl Segment {
116    /// Open a segment file
117    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
118        let path_str = path.as_ref().to_string_lossy().to_string();
119        let file = File::open(&path)?;
120        let mmap = unsafe { Mmap::map(&file)? };
121
122        if mmap.len() < SegmentHeader::SIZE {
123            return Err(Error::Segment("File too small for header".into()));
124        }
125
126        // Parse header
127        let header: SegmentHeader =
128            unsafe { std::ptr::read_unaligned(mmap.as_ptr() as *const SegmentHeader) };
129        header.validate()?;
130
131        // Validate file length
132        if mmap.len() < header.file_len as usize {
133            return Err(Error::Segment(format!(
134                "File size {} < declared length {}",
135                mmap.len(),
136                header.file_len
137            )));
138        }
139
140        // SECURITY: header.validate() only checks magic+version. The offset table
141        // and element counts (n_vec/dim) are attacker-controlled for any on-disk
142        // segment, and the accessors below build slices via from_raw_parts from
143        // them — so a crafted segment with an out-of-range offset or huge count
144        // yields an out-of-bounds read (crash / adjacent-memory disclosure into
145        // query results). Bounds-check every section against the real mmap length
146        // ONCE here, so the unsafe accessors are sound thereafter.
147        validate_segment_layout(&header, mmap.len())?;
148
149        Ok(Self {
150            mmap: Arc::new(mmap),
151            header,
152            path: path_str,
153        })
154    }
155
156    /// Get segment header
157    #[inline]
158    pub fn header(&self) -> &SegmentHeader {
159        &self.header
160    }
161
162    /// Number of vectors
163    #[inline]
164    pub fn num_vectors(&self) -> u32 {
165        self.header.n_vec
166    }
167
168    /// Vector dimension
169    #[inline]
170    pub fn dim(&self) -> u32 {
171        self.header.dim
172    }
173
174    /// Get raw pointer to BPS data
175    #[inline]
176    pub fn bps_ptr(&self) -> *const u8 {
177        unsafe { self.mmap.as_ptr().add(self.header.off_bps as usize) }
178    }
179
180    /// Get BPS data slice
181    pub fn bps_data(&self) -> &[u8] {
182        let size = self.header.bps_size();
183        unsafe { std::slice::from_raw_parts(self.bps_ptr(), size) }
184    }
185
186    /// Get raw pointer to int8 embedding data
187    #[inline]
188    pub fn i8_ptr(&self) -> *const i8 {
189        unsafe { self.mmap.as_ptr().add(self.header.off_i8 as usize) as *const i8 }
190    }
191
192    /// Get int8 embedding data slice
193    pub fn i8_data(&self) -> &[i8] {
194        let size = self.header.i8_size();
195        unsafe { std::slice::from_raw_parts(self.i8_ptr(), size) }
196    }
197
198    /// Get int8 vector for a specific ID
199    pub fn get_i8_vector(&self, vid: VectorId) -> Option<&[i8]> {
200        if vid >= self.header.n_vec {
201            return None;
202        }
203        let dim = self.header.dim as usize;
204        let offset = vid as usize * dim;
205        Some(&self.i8_data()[offset..offset + dim])
206    }
207
208    /// Get raw pointer to quantization scales
209    #[inline]
210    pub fn scales_ptr(&self) -> *const f32 {
211        unsafe { self.mmap.as_ptr().add(self.header.off_scales as usize) as *const f32 }
212    }
213
214    /// Get quantization scales
215    pub fn scales_data(&self) -> &[f32] {
216        let num_blocks = self.header.num_bps_blocks() as usize;
217        // One scale per block per vector
218        let size = num_blocks * self.header.n_vec as usize;
219        unsafe { std::slice::from_raw_parts(self.scales_ptr(), size) }
220    }
221
222    /// Get raw pointer to outlier data
223    #[inline]
224    pub fn outliers_ptr(&self) -> *const OutlierEntry {
225        unsafe { self.mmap.as_ptr().add(self.header.off_outliers as usize) as *const OutlierEntry }
226    }
227
228    /// Get outliers for a specific vector
229    pub fn get_outliers(&self, vid: VectorId) -> Option<&[OutlierEntry]> {
230        if vid >= self.header.n_vec || !self.header.flags.has(SegmentFlags::HAS_OUTLIERS) {
231            return None;
232        }
233        let num_outliers = self.header.num_outliers as usize;
234        let offset = vid as usize * num_outliers;
235        unsafe {
236            Some(std::slice::from_raw_parts(
237                self.outliers_ptr().add(offset),
238                num_outliers,
239            ))
240        }
241    }
242
243    /// Get raw pointer to tombstone bitset
244    #[inline]
245    pub fn tombstone_ptr(&self) -> *const u64 {
246        unsafe { self.mmap.as_ptr().add(self.header.off_tombstone as usize) as *const u64 }
247    }
248
249    /// Get tombstone bitset
250    pub fn tombstone_data(&self) -> &[u64] {
251        let num_words = (self.header.n_vec as usize + 63) / 64;
252        unsafe { std::slice::from_raw_parts(self.tombstone_ptr(), num_words) }
253    }
254
255    /// Check if a vector is tombstoned
256    pub fn is_tombstoned(&self, vid: VectorId) -> bool {
257        if vid >= self.header.n_vec {
258            return true;
259        }
260        let word_idx = vid as usize / 64;
261        let bit_idx = vid as usize % 64;
262        let tombstones = self.tombstone_data();
263        if word_idx >= tombstones.len() {
264            return false;
265        }
266        (tombstones[word_idx] & (1u64 << bit_idx)) != 0
267    }
268
269    /// Get RDF posting list directory
270    pub fn rdf_directory(&self) -> &[PostingListEntry] {
271        if !self.header.flags.has(SegmentFlags::HAS_RDF) {
272            return &[];
273        }
274        let dim = self.header.dim as usize;
275        unsafe {
276            std::slice::from_raw_parts(
277                self.mmap.as_ptr().add(self.header.off_rdf_dir as usize) as *const PostingListEntry,
278                dim,
279            )
280        }
281    }
282
283    /// Get raw pointer to RDF posting list data
284    #[inline]
285    pub fn rdf_data_ptr(&self) -> *const u8 {
286        unsafe { self.mmap.as_ptr().add(self.header.off_rdf_data as usize) }
287    }
288
289    /// Get dimension weights for RDF
290    pub fn dim_weights(&self) -> &[f32] {
291        if !self.header.flags.has(SegmentFlags::HAS_RDF) {
292            return &[];
293        }
294        let dim = self.header.dim as usize;
295        unsafe {
296            std::slice::from_raw_parts(
297                self.mmap.as_ptr().add(self.header.off_dim_weights as usize) as *const f32,
298                dim,
299            )
300        }
301    }
302
303    /// Get optional fp32 vectors for verification
304    pub fn fp32_data(&self) -> Option<&[f32]> {
305        if !self.header.flags.has(SegmentFlags::HAS_FP32) {
306            return None;
307        }
308        let size = self.header.n_vec as usize * self.header.dim as usize;
309        unsafe {
310            Some(std::slice::from_raw_parts(
311                self.mmap.as_ptr().add(self.header.off_fp32 as usize) as *const f32,
312                size,
313            ))
314        }
315    }
316
317    /// Get BPS quantization parameters (min, inv_range per slot).
318    ///
319    /// Returns `None` if qparams were not stored (legacy segments).
320    /// The number of slots = num_bps_blocks × bps_proj.
321    pub fn bps_qparams(&self) -> Option<&[super::bps::BpsQParam]> {
322        if self.header.off_bps_qparams == 0 {
323            return None;
324        }
325        let num_slots = self.header.num_bps_blocks() as usize * self.header.bps_proj as usize;
326        if num_slots == 0 {
327            return None;
328        }
329        unsafe {
330            Some(std::slice::from_raw_parts(
331                self.mmap.as_ptr().add(self.header.off_bps_qparams as usize)
332                    as *const super::bps::BpsQParam,
333                num_slots,
334            ))
335        }
336    }
337
338    /// Get fp32 vector for a specific ID
339    pub fn get_fp32_vector(&self, vid: VectorId) -> Option<&[f32]> {
340        let fp32 = self.fp32_data()?;
341        let dim = self.header.dim as usize;
342        let offset = vid as usize * dim;
343        Some(&fp32[offset..offset + dim])
344    }
345
346    /// Get file path
347    pub fn path(&self) -> &str {
348        &self.path
349    }
350
351    /// Clone the mmap handle (cheap, Arc-backed)
352    pub fn clone_mmap(&self) -> Arc<Mmap> {
353        Arc::clone(&self.mmap)
354    }
355}
356
357impl std::fmt::Debug for Segment {
358    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359        f.debug_struct("Segment")
360            .field("path", &self.path)
361            .field("n_vec", &self.header.n_vec)
362            .field("dim", &self.header.dim)
363            .field("flags", &self.header.flags)
364            .finish()
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write a minimal well-formed segment to a temp file.
    ///
    /// The segment holds 100 vectors of dimension 64. The BPS, int8, scale and
    /// tombstone sections are packed back to back after the header. BPS and int8
    /// bytes are zero, every scale is `1.0`, and no vector is tombstoned.
    fn create_test_segment() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();

        let n_vec = 100u32;
        let dim = 64u32;
        let num_blocks = (dim + 15) / 16;

        let mut header = SegmentHeader::new(n_vec, dim);
        header.flags.set(SegmentFlags::HAS_BPS);

        // Calculate offsets
        let mut offset = SegmentHeader::SIZE as u64;

        // BPS data
        header.off_bps = offset;
        let bps_size = (num_blocks as usize * n_vec as usize) as u64;
        offset += bps_size;

        // i8 data
        header.off_i8 = offset;
        let i8_size = (n_vec as usize * dim as usize) as u64;
        offset += i8_size;

        // Scales
        header.off_scales = offset;
        let scales_size = (num_blocks as usize * n_vec as usize * 4) as u64;
        offset += scales_size;

        // Tombstone
        header.off_tombstone = offset;
        let tombstone_size = ((n_vec as usize + 63) / 64 * 8) as u64;
        offset += tombstone_size;

        header.file_len = offset;

        // Write header
        file.write_all(bytemuck::bytes_of(&header)).unwrap();

        // Write BPS data (zeros)
        file.write_all(&vec![0u8; bps_size as usize]).unwrap();

        // Write i8 data (zeros)
        file.write_all(&vec![0u8; i8_size as usize]).unwrap();

        // Write scales (ones)
        for _ in 0..(num_blocks * n_vec) {
            file.write_all(&1.0f32.to_le_bytes()).unwrap();
        }

        // Write tombstone (zeros = no tombstones)
        file.write_all(&vec![0u8; tombstone_size as usize]).unwrap();

        file.flush().unwrap();
        file
    }

    #[test]
    fn test_segment_open() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        assert_eq!(segment.num_vectors(), 100);
        assert_eq!(segment.dim(), 64);
    }

    #[test]
    fn rejects_out_of_bounds_offsets() {
        // SECURITY (CWE-125): the crafted segment has a valid magic, version and
        // file_len, but a huge n_vec/dim that makes the section sizes point far
        // past the mapping. Without layout validation the from_raw_parts accessors
        // would read out of bounds; Segment::open must reject the file instead.
        let mut file = NamedTempFile::new().unwrap();
        let mut header = SegmentHeader::new(1_000_000_000, 512);
        header.flags.set(SegmentFlags::HAS_BPS);
        header.off_bps = SegmentHeader::SIZE as u64;
        header.off_i8 = SegmentHeader::SIZE as u64;
        // The file contains ONLY the header — every data section is out of range.
        header.file_len = SegmentHeader::SIZE as u64;
        file.write_all(bytemuck::bytes_of(&header)).unwrap();
        file.flush().unwrap();

        assert!(
            Segment::open(file.path()).is_err(),
            "segment with out-of-bounds section offsets must be rejected"
        );
    }

    #[test]
    fn rejects_offset_past_eof() {
        // A single out-of-range offset (i8 section starts beyond the file) must be
        // caught even when n_vec/dim are small.
        let valid = create_test_segment();
        let mut bytes = std::fs::read(valid.path()).unwrap();
        // `pod_read_unaligned` copies the header out, so this does not depend on
        // the `Vec<u8>` buffer being aligned for `SegmentHeader`
        // (`bytemuck::from_bytes` panics when it is not).
        let mut header: SegmentHeader =
            bytemuck::pod_read_unaligned(&bytes[..SegmentHeader::SIZE]);
        header.off_i8 = header.file_len + 4096; // point i8 past EOF
        bytes[..SegmentHeader::SIZE].copy_from_slice(bytemuck::bytes_of(&header));
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&bytes).unwrap();
        file.flush().unwrap();

        assert!(
            Segment::open(file.path()).is_err(),
            "segment with an offset past EOF must be rejected"
        );
    }

    #[test]
    fn test_tombstone_check() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        // No tombstones set
        assert!(!segment.is_tombstoned(0));
        assert!(!segment.is_tombstoned(50));
        assert!(!segment.is_tombstoned(99));

        // Out of range should return true
        assert!(segment.is_tombstoned(100));
    }

    #[test]
    fn test_i8_and_scale_accessors() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        // In-range vector: `dim` zero components, as written by the fixture.
        let first = segment.get_i8_vector(0).unwrap();
        assert_eq!(first.len(), 64);
        assert!(first.iter().all(|&x| x == 0));

        // Out-of-range id yields None rather than panicking.
        assert!(segment.get_i8_vector(100).is_none());

        // Every scale written by the fixture is 1.0.
        let scales = segment.scales_data();
        assert!(!scales.is_empty());
        assert!(scales.iter().all(|&s| s == 1.0));
    }
}