1use memmap2::Mmap;
4use std::fs::File;
5use std::path::Path;
6use std::sync::Arc;
7
8use super::format::*;
9use crate::error::{Error, Result};
10use crate::types::*;
11
12fn validate_segment_layout(h: &SegmentHeader, mmap_len: usize) -> Result<()> {
16 let fits = |name: &str, off: u64, bytes: usize| -> Result<()> {
18 let off = off as usize;
19 let end = off
20 .checked_add(bytes)
21 .ok_or_else(|| Error::Segment(format!("segment section '{name}' size overflow")))?;
22 if end > mmap_len {
23 return Err(Error::Segment(format!(
24 "segment section '{name}' [{off}..{end}) exceeds file length {mmap_len}"
25 )));
26 }
27 Ok(())
28 };
29 let bytes = |count: usize, elem: usize, name: &str| -> Result<usize> {
31 count
32 .checked_mul(elem)
33 .ok_or_else(|| Error::Segment(format!("segment section '{name}' byte-size overflow")))
34 };
35
36 let n_vec = h.n_vec as usize;
37 let dim = h.dim as usize;
38 let blocks = h.num_bps_blocks() as usize;
39
40 fits("bps", h.off_bps, h.bps_size())?;
42 fits("i8", h.off_i8, bytes(n_vec, dim, "i8")?)?; fits(
44 "scales",
45 h.off_scales,
46 bytes(bytes(blocks, n_vec, "scales")?, 4, "scales")?, )?;
48 fits(
49 "tombstone",
50 h.off_tombstone,
51 bytes(n_vec.div_ceil(64), 8, "tombstone")?, )?;
53
54 if h.flags.has(SegmentFlags::HAS_OUTLIERS) {
56 let cnt = bytes(n_vec, h.num_outliers as usize, "outliers")?;
57 fits(
58 "outliers",
59 h.off_outliers,
60 bytes(cnt, std::mem::size_of::<OutlierEntry>(), "outliers")?,
61 )?;
62 }
63 if h.flags.has(SegmentFlags::HAS_RDF) {
64 fits(
65 "rdf_dir",
66 h.off_rdf_dir,
67 bytes(dim, std::mem::size_of::<PostingListEntry>(), "rdf_dir")?,
68 )?;
69 fits(
70 "dim_weights",
71 h.off_dim_weights,
72 bytes(dim, 4, "dim_weights")?,
73 )?;
74 if h.off_rdf_data as usize > mmap_len {
78 return Err(Error::Segment(
79 "segment section 'rdf_data' offset exceeds file".into(),
80 ));
81 }
82 }
83 if h.flags.has(SegmentFlags::HAS_FP32) {
84 fits(
85 "fp32",
86 h.off_fp32,
87 bytes(bytes(n_vec, dim, "fp32")?, 4, "fp32")?,
88 )?;
89 }
90 if h.off_bps_qparams != 0 {
91 let cnt = bytes(blocks, h.bps_proj as usize, "bps_qparams")?;
92 fits(
93 "bps_qparams",
94 h.off_bps_qparams,
95 bytes(
96 cnt,
97 std::mem::size_of::<super::bps::BpsQParam>(),
98 "bps_qparams",
99 )?,
100 )?;
101 }
102 Ok(())
103}
104
/// A segment file opened through a memory map, together with a copy of its header.
pub struct Segment {
    /// Shared handle to the memory-mapped file; handed out by `clone_mmap`.
    mmap: Arc<Mmap>,
    /// Header copied out of the start of the mapping when the segment is opened.
    header: SegmentHeader,
    /// Path the segment was opened from, converted lossily to a `String`.
    path: String,
}
114
115impl Segment {
116 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
118 let path_str = path.as_ref().to_string_lossy().to_string();
119 let file = File::open(&path)?;
120 let mmap = unsafe { Mmap::map(&file)? };
121
122 if mmap.len() < SegmentHeader::SIZE {
123 return Err(Error::Segment("File too small for header".into()));
124 }
125
126 let header: SegmentHeader =
128 unsafe { std::ptr::read_unaligned(mmap.as_ptr() as *const SegmentHeader) };
129 header.validate()?;
130
131 if mmap.len() < header.file_len as usize {
133 return Err(Error::Segment(format!(
134 "File size {} < declared length {}",
135 mmap.len(),
136 header.file_len
137 )));
138 }
139
140 validate_segment_layout(&header, mmap.len())?;
148
149 Ok(Self {
150 mmap: Arc::new(mmap),
151 header,
152 path: path_str,
153 })
154 }
155
156 #[inline]
158 pub fn header(&self) -> &SegmentHeader {
159 &self.header
160 }
161
162 #[inline]
164 pub fn num_vectors(&self) -> u32 {
165 self.header.n_vec
166 }
167
168 #[inline]
170 pub fn dim(&self) -> u32 {
171 self.header.dim
172 }
173
174 #[inline]
176 pub fn bps_ptr(&self) -> *const u8 {
177 unsafe { self.mmap.as_ptr().add(self.header.off_bps as usize) }
178 }
179
180 pub fn bps_data(&self) -> &[u8] {
182 let size = self.header.bps_size();
183 unsafe { std::slice::from_raw_parts(self.bps_ptr(), size) }
184 }
185
186 #[inline]
188 pub fn i8_ptr(&self) -> *const i8 {
189 unsafe { self.mmap.as_ptr().add(self.header.off_i8 as usize) as *const i8 }
190 }
191
192 pub fn i8_data(&self) -> &[i8] {
194 let size = self.header.i8_size();
195 unsafe { std::slice::from_raw_parts(self.i8_ptr(), size) }
196 }
197
198 pub fn get_i8_vector(&self, vid: VectorId) -> Option<&[i8]> {
200 if vid >= self.header.n_vec {
201 return None;
202 }
203 let dim = self.header.dim as usize;
204 let offset = vid as usize * dim;
205 Some(&self.i8_data()[offset..offset + dim])
206 }
207
208 #[inline]
210 pub fn scales_ptr(&self) -> *const f32 {
211 unsafe { self.mmap.as_ptr().add(self.header.off_scales as usize) as *const f32 }
212 }
213
214 pub fn scales_data(&self) -> &[f32] {
216 let num_blocks = self.header.num_bps_blocks() as usize;
217 let size = num_blocks * self.header.n_vec as usize;
219 unsafe { std::slice::from_raw_parts(self.scales_ptr(), size) }
220 }
221
222 #[inline]
224 pub fn outliers_ptr(&self) -> *const OutlierEntry {
225 unsafe { self.mmap.as_ptr().add(self.header.off_outliers as usize) as *const OutlierEntry }
226 }
227
228 pub fn get_outliers(&self, vid: VectorId) -> Option<&[OutlierEntry]> {
230 if vid >= self.header.n_vec || !self.header.flags.has(SegmentFlags::HAS_OUTLIERS) {
231 return None;
232 }
233 let num_outliers = self.header.num_outliers as usize;
234 let offset = vid as usize * num_outliers;
235 unsafe {
236 Some(std::slice::from_raw_parts(
237 self.outliers_ptr().add(offset),
238 num_outliers,
239 ))
240 }
241 }
242
243 #[inline]
245 pub fn tombstone_ptr(&self) -> *const u64 {
246 unsafe { self.mmap.as_ptr().add(self.header.off_tombstone as usize) as *const u64 }
247 }
248
249 pub fn tombstone_data(&self) -> &[u64] {
251 let num_words = (self.header.n_vec as usize + 63) / 64;
252 unsafe { std::slice::from_raw_parts(self.tombstone_ptr(), num_words) }
253 }
254
255 pub fn is_tombstoned(&self, vid: VectorId) -> bool {
257 if vid >= self.header.n_vec {
258 return true;
259 }
260 let word_idx = vid as usize / 64;
261 let bit_idx = vid as usize % 64;
262 let tombstones = self.tombstone_data();
263 if word_idx >= tombstones.len() {
264 return false;
265 }
266 (tombstones[word_idx] & (1u64 << bit_idx)) != 0
267 }
268
269 pub fn rdf_directory(&self) -> &[PostingListEntry] {
271 if !self.header.flags.has(SegmentFlags::HAS_RDF) {
272 return &[];
273 }
274 let dim = self.header.dim as usize;
275 unsafe {
276 std::slice::from_raw_parts(
277 self.mmap.as_ptr().add(self.header.off_rdf_dir as usize) as *const PostingListEntry,
278 dim,
279 )
280 }
281 }
282
283 #[inline]
285 pub fn rdf_data_ptr(&self) -> *const u8 {
286 unsafe { self.mmap.as_ptr().add(self.header.off_rdf_data as usize) }
287 }
288
289 pub fn dim_weights(&self) -> &[f32] {
291 if !self.header.flags.has(SegmentFlags::HAS_RDF) {
292 return &[];
293 }
294 let dim = self.header.dim as usize;
295 unsafe {
296 std::slice::from_raw_parts(
297 self.mmap.as_ptr().add(self.header.off_dim_weights as usize) as *const f32,
298 dim,
299 )
300 }
301 }
302
303 pub fn fp32_data(&self) -> Option<&[f32]> {
305 if !self.header.flags.has(SegmentFlags::HAS_FP32) {
306 return None;
307 }
308 let size = self.header.n_vec as usize * self.header.dim as usize;
309 unsafe {
310 Some(std::slice::from_raw_parts(
311 self.mmap.as_ptr().add(self.header.off_fp32 as usize) as *const f32,
312 size,
313 ))
314 }
315 }
316
317 pub fn bps_qparams(&self) -> Option<&[super::bps::BpsQParam]> {
322 if self.header.off_bps_qparams == 0 {
323 return None;
324 }
325 let num_slots = self.header.num_bps_blocks() as usize * self.header.bps_proj as usize;
326 if num_slots == 0 {
327 return None;
328 }
329 unsafe {
330 Some(std::slice::from_raw_parts(
331 self.mmap.as_ptr().add(self.header.off_bps_qparams as usize)
332 as *const super::bps::BpsQParam,
333 num_slots,
334 ))
335 }
336 }
337
338 pub fn get_fp32_vector(&self, vid: VectorId) -> Option<&[f32]> {
340 let fp32 = self.fp32_data()?;
341 let dim = self.header.dim as usize;
342 let offset = vid as usize * dim;
343 Some(&fp32[offset..offset + dim])
344 }
345
346 pub fn path(&self) -> &str {
348 &self.path
349 }
350
351 pub fn clone_mmap(&self) -> Arc<Mmap> {
353 Arc::clone(&self.mmap)
354 }
355}
356
357impl std::fmt::Debug for Segment {
358 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 f.debug_struct("Segment")
360 .field("path", &self.path)
361 .field("n_vec", &self.header.n_vec)
362 .field("dim", &self.header.dim)
363 .field("flags", &self.header.flags)
364 .finish()
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a minimal well-formed segment to a temp file: header, then the `bps`,
    /// `i8`, `scales` and `tombstone` sections back to back (100 vectors, 64 dims).
    fn create_test_segment() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();

        let n_vec = 100u32;
        let dim = 64u32;
        let num_blocks = (dim + 15) / 16;

        let mut header = SegmentHeader::new(n_vec, dim);
        header.flags.set(SegmentFlags::HAS_BPS);

        // Lay the sections out contiguously right after the header.
        let mut offset = SegmentHeader::SIZE as u64;

        header.off_bps = offset;
        let bps_size = (num_blocks as usize * n_vec as usize) as u64;
        offset += bps_size;

        header.off_i8 = offset;
        let i8_size = (n_vec as usize * dim as usize) as u64;
        offset += i8_size;

        header.off_scales = offset;
        let scales_size = (num_blocks as usize * n_vec as usize * 4) as u64;
        offset += scales_size;

        header.off_tombstone = offset;
        let tombstone_size = ((n_vec as usize + 63) / 64 * 8) as u64;
        offset += tombstone_size;

        header.file_len = offset;

        file.write_all(bytemuck::bytes_of(&header)).unwrap();
        file.write_all(&vec![0u8; bps_size as usize]).unwrap();
        file.write_all(&vec![0u8; i8_size as usize]).unwrap();

        // Every scale is 1.0.
        for _ in 0..(num_blocks * n_vec) {
            file.write_all(&1.0f32.to_le_bytes()).unwrap();
        }

        // All-zero bitmap: nothing is tombstoned.
        file.write_all(&vec![0u8; tombstone_size as usize]).unwrap();

        file.flush().unwrap();
        file
    }

    #[test]
    fn test_segment_open() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        assert_eq!(segment.num_vectors(), 100);
        assert_eq!(segment.dim(), 64);
    }

    #[test]
    fn rejects_out_of_bounds_offsets() {
        let mut file = NamedTempFile::new().unwrap();
        // Header-only file that claims a billion 512-dim vectors.
        let mut header = SegmentHeader::new(1_000_000_000, 512);
        header.flags.set(SegmentFlags::HAS_BPS);
        header.off_bps = SegmentHeader::SIZE as u64;
        header.off_i8 = SegmentHeader::SIZE as u64;
        header.file_len = SegmentHeader::SIZE as u64;
        file.write_all(bytemuck::bytes_of(&header)).unwrap();
        file.flush().unwrap();

        assert!(
            Segment::open(file.path()).is_err(),
            "segment with out-of-bounds section offsets must be rejected"
        );
    }

    #[test]
    fn rejects_offset_past_eof() {
        let valid = create_test_segment();
        let mut bytes = std::fs::read(valid.path()).unwrap();
        // `pod_read_unaligned` copies the header out, so the test does not depend on the
        // heap buffer happening to be aligned for `SegmentHeader`.
        let mut header: SegmentHeader =
            bytemuck::pod_read_unaligned(&bytes[..SegmentHeader::SIZE]);
        header.off_i8 = header.file_len + 4096;
        bytes[..SegmentHeader::SIZE].copy_from_slice(bytemuck::bytes_of(&header));

        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&bytes).unwrap();
        file.flush().unwrap();

        assert!(
            Segment::open(file.path()).is_err(),
            "segment with an offset past EOF must be rejected"
        );
    }

    #[test]
    fn test_tombstone_check() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        assert!(!segment.is_tombstoned(0));
        assert!(!segment.is_tombstoned(50));
        assert!(!segment.is_tombstoned(99));

        // Ids at or past `n_vec` count as tombstoned.
        assert!(segment.is_tombstoned(100));
    }
}