// sochdb_vector/segment/reader.rs

//! Segment reader with mmap support.
2
use memmap2::Mmap;
use std::fs::File;
use std::mem::{align_of, size_of};
use std::path::Path;
use std::sync::Arc;

use super::format::*;
use crate::error::{Error, Result};
use crate::types::*;
11
12/// An immutable segment backed by mmap
13pub struct Segment {
14    /// Memory-mapped file
15    mmap: Arc<Mmap>,
16    /// Parsed header
17    header: SegmentHeader,
18    /// File path
19    path: String,
20}
21
22impl Segment {
23    /// Open a segment file
24    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
25        let path_str = path.as_ref().to_string_lossy().to_string();
26        let file = File::open(&path)?;
27        let mmap = unsafe { Mmap::map(&file)? };
28
29        if mmap.len() < SegmentHeader::SIZE {
30            return Err(Error::Segment("File too small for header".into()));
31        }
32
33        // Parse header
34        let header: SegmentHeader =
35            unsafe { std::ptr::read_unaligned(mmap.as_ptr() as *const SegmentHeader) };
36        header.validate()?;
37
38        // Validate file length
39        if mmap.len() < header.file_len as usize {
40            return Err(Error::Segment(format!(
41                "File size {} < declared length {}",
42                mmap.len(),
43                header.file_len
44            )));
45        }
46
47        Ok(Self {
48            mmap: Arc::new(mmap),
49            header,
50            path: path_str,
51        })
52    }
53
54    /// Get segment header
55    #[inline]
56    pub fn header(&self) -> &SegmentHeader {
57        &self.header
58    }
59
60    /// Number of vectors
61    #[inline]
62    pub fn num_vectors(&self) -> u32 {
63        self.header.n_vec
64    }
65
66    /// Vector dimension
67    #[inline]
68    pub fn dim(&self) -> u32 {
69        self.header.dim
70    }
71
72    /// Get raw pointer to BPS data
73    #[inline]
74    pub fn bps_ptr(&self) -> *const u8 {
75        unsafe { self.mmap.as_ptr().add(self.header.off_bps as usize) }
76    }
77
78    /// Get BPS data slice
79    pub fn bps_data(&self) -> &[u8] {
80        let size = self.header.bps_size();
81        unsafe { std::slice::from_raw_parts(self.bps_ptr(), size) }
82    }
83
84    /// Get raw pointer to int8 embedding data
85    #[inline]
86    pub fn i8_ptr(&self) -> *const i8 {
87        unsafe { self.mmap.as_ptr().add(self.header.off_i8 as usize) as *const i8 }
88    }
89
90    /// Get int8 embedding data slice
91    pub fn i8_data(&self) -> &[i8] {
92        let size = self.header.i8_size();
93        unsafe { std::slice::from_raw_parts(self.i8_ptr(), size) }
94    }
95
96    /// Get int8 vector for a specific ID
97    pub fn get_i8_vector(&self, vid: VectorId) -> Option<&[i8]> {
98        if vid >= self.header.n_vec {
99            return None;
100        }
101        let dim = self.header.dim as usize;
102        let offset = vid as usize * dim;
103        Some(&self.i8_data()[offset..offset + dim])
104    }
105
106    /// Get raw pointer to quantization scales
107    #[inline]
108    pub fn scales_ptr(&self) -> *const f32 {
109        unsafe { self.mmap.as_ptr().add(self.header.off_scales as usize) as *const f32 }
110    }
111
112    /// Get quantization scales
113    pub fn scales_data(&self) -> &[f32] {
114        let num_blocks = self.header.num_bps_blocks() as usize;
115        // One scale per block per vector
116        let size = num_blocks * self.header.n_vec as usize;
117        unsafe { std::slice::from_raw_parts(self.scales_ptr(), size) }
118    }
119
120    /// Get raw pointer to outlier data
121    #[inline]
122    pub fn outliers_ptr(&self) -> *const OutlierEntry {
123        unsafe { self.mmap.as_ptr().add(self.header.off_outliers as usize) as *const OutlierEntry }
124    }
125
126    /// Get outliers for a specific vector
127    pub fn get_outliers(&self, vid: VectorId) -> Option<&[OutlierEntry]> {
128        if vid >= self.header.n_vec || !self.header.flags.has(SegmentFlags::HAS_OUTLIERS) {
129            return None;
130        }
131        let num_outliers = self.header.num_outliers as usize;
132        let offset = vid as usize * num_outliers;
133        unsafe {
134            Some(std::slice::from_raw_parts(
135                self.outliers_ptr().add(offset),
136                num_outliers,
137            ))
138        }
139    }
140
141    /// Get raw pointer to tombstone bitset
142    #[inline]
143    pub fn tombstone_ptr(&self) -> *const u64 {
144        unsafe { self.mmap.as_ptr().add(self.header.off_tombstone as usize) as *const u64 }
145    }
146
147    /// Get tombstone bitset
148    pub fn tombstone_data(&self) -> &[u64] {
149        let num_words = (self.header.n_vec as usize + 63) / 64;
150        unsafe { std::slice::from_raw_parts(self.tombstone_ptr(), num_words) }
151    }
152
153    /// Check if a vector is tombstoned
154    pub fn is_tombstoned(&self, vid: VectorId) -> bool {
155        if vid >= self.header.n_vec {
156            return true;
157        }
158        let word_idx = vid as usize / 64;
159        let bit_idx = vid as usize % 64;
160        let tombstones = self.tombstone_data();
161        if word_idx >= tombstones.len() {
162            return false;
163        }
164        (tombstones[word_idx] & (1u64 << bit_idx)) != 0
165    }
166
167    /// Get RDF posting list directory
168    pub fn rdf_directory(&self) -> &[PostingListEntry] {
169        if !self.header.flags.has(SegmentFlags::HAS_RDF) {
170            return &[];
171        }
172        let dim = self.header.dim as usize;
173        unsafe {
174            std::slice::from_raw_parts(
175                self.mmap.as_ptr().add(self.header.off_rdf_dir as usize) as *const PostingListEntry,
176                dim,
177            )
178        }
179    }
180
181    /// Get raw pointer to RDF posting list data
182    #[inline]
183    pub fn rdf_data_ptr(&self) -> *const u8 {
184        unsafe { self.mmap.as_ptr().add(self.header.off_rdf_data as usize) }
185    }
186
187    /// Get dimension weights for RDF
188    pub fn dim_weights(&self) -> &[f32] {
189        if !self.header.flags.has(SegmentFlags::HAS_RDF) {
190            return &[];
191        }
192        let dim = self.header.dim as usize;
193        unsafe {
194            std::slice::from_raw_parts(
195                self.mmap.as_ptr().add(self.header.off_dim_weights as usize) as *const f32,
196                dim,
197            )
198        }
199    }
200
201    /// Get optional fp32 vectors for verification
202    pub fn fp32_data(&self) -> Option<&[f32]> {
203        if !self.header.flags.has(SegmentFlags::HAS_FP32) {
204            return None;
205        }
206        let size = self.header.n_vec as usize * self.header.dim as usize;
207        unsafe {
208            Some(std::slice::from_raw_parts(
209                self.mmap.as_ptr().add(self.header.off_fp32 as usize) as *const f32,
210                size,
211            ))
212        }
213    }
214
215    /// Get BPS quantization parameters (min, inv_range per slot).
216    ///
217    /// Returns `None` if qparams were not stored (legacy segments).
218    /// The number of slots = num_bps_blocks × bps_proj.
219    pub fn bps_qparams(&self) -> Option<&[super::bps::BpsQParam]> {
220        if self.header.off_bps_qparams == 0 {
221            return None;
222        }
223        let num_slots = self.header.num_bps_blocks() as usize * self.header.bps_proj as usize;
224        if num_slots == 0 {
225            return None;
226        }
227        unsafe {
228            Some(std::slice::from_raw_parts(
229                self.mmap.as_ptr().add(self.header.off_bps_qparams as usize)
230                    as *const super::bps::BpsQParam,
231                num_slots,
232            ))
233        }
234    }
235
236    /// Get fp32 vector for a specific ID
237    pub fn get_fp32_vector(&self, vid: VectorId) -> Option<&[f32]> {
238        let fp32 = self.fp32_data()?;
239        let dim = self.header.dim as usize;
240        let offset = vid as usize * dim;
241        Some(&fp32[offset..offset + dim])
242    }
243
244    /// Get file path
245    pub fn path(&self) -> &str {
246        &self.path
247    }
248
249    /// Clone the mmap handle (cheap, Arc-backed)
250    pub fn clone_mmap(&self) -> Arc<Mmap> {
251        Arc::clone(&self.mmap)
252    }
253}
254
255impl std::fmt::Debug for Segment {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        f.debug_struct("Segment")
258            .field("path", &self.path)
259            .field("n_vec", &self.header.n_vec)
260            .field("dim", &self.header.dim)
261            .field("flags", &self.header.flags)
262            .finish()
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write a minimal segment (100 vectors, dim 64) to a temp file.
    ///
    /// Layout after the header: BPS bytes (zeros), int8 data (zeros), scales
    /// (all 1.0) and the tombstone bitset (zeros = nothing tombstoned).
    fn create_test_segment() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();

        let n_vec = 100u32;
        let dim = 64u32;
        let num_blocks = (dim + 15) / 16;

        // Section sizes in bytes.
        let bps_size = num_blocks as usize * n_vec as usize;
        let i8_size = n_vec as usize * dim as usize;
        let scale_count = num_blocks as usize * n_vec as usize;
        let scales_size = scale_count * 4;
        let tombstone_size = (n_vec as usize + 63) / 64 * 8;

        let mut header = SegmentHeader::new(n_vec, dim);
        header.flags.set(SegmentFlags::HAS_BPS);

        // Lay the sections out back to back, right after the header.
        let mut offset = SegmentHeader::SIZE as u64;

        header.off_bps = offset;
        offset += bps_size as u64;

        header.off_i8 = offset;
        offset += i8_size as u64;

        header.off_scales = offset;
        offset += scales_size as u64;

        header.off_tombstone = offset;
        offset += tombstone_size as u64;

        header.file_len = offset;

        // Header first, then each section in the order laid out above.
        file.write_all(bytemuck::bytes_of(&header)).unwrap();
        file.write_all(&vec![0u8; bps_size]).unwrap();
        file.write_all(&vec![0u8; i8_size]).unwrap();

        let scales: Vec<u8> = std::iter::repeat(1.0f32.to_le_bytes())
            .take(scale_count)
            .flatten()
            .collect();
        file.write_all(&scales).unwrap();

        file.write_all(&vec![0u8; tombstone_size]).unwrap();

        file.flush().unwrap();
        file
    }

    #[test]
    fn test_segment_open() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        assert_eq!(segment.num_vectors(), 100);
        assert_eq!(segment.dim(), 64);
    }

    #[test]
    fn test_open_rejects_truncated_file() {
        // Four bytes cannot hold a segment header.
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&[0u8; 4]).unwrap();
        file.flush().unwrap();

        assert!(Segment::open(file.path()).is_err());
    }

    #[test]
    fn test_tombstone_check() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        // No tombstones set
        assert!(!segment.is_tombstoned(0));
        assert!(!segment.is_tombstoned(50));
        assert!(!segment.is_tombstoned(99));

        // Out of range should return true
        assert!(segment.is_tombstoned(100));
    }

    #[test]
    fn test_out_of_range_ids_yield_none() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        // IDs at or past `n_vec` are rejected before any data is touched.
        assert!(segment.get_i8_vector(100).is_none());
        assert!(segment.get_outliers(100).is_none());
    }
}