Skip to main content

sochdb_vector/segment/
writer.rs

1//! Segment writer for building immutable segments.
2
3use std::fs::File;
4use std::io::{BufWriter, Seek, SeekFrom, Write};
5use std::path::Path;
6
7use super::bps::BpsBuilder;
8use super::format::*;
9use super::rdf::RdfBuilder;
10use super::rerank::RerankBuilder;
11use crate::config::EngineConfig;
12use crate::error::{Error, Result};
13use crate::rotation::Rotator;
14use crate::types::*;
15
16/// Builder that accumulates vectors in memory and serializes them into one segment file via `build`.
17pub struct SegmentWriter {
18    config: EngineConfig, // validated in `new`; supplies `dim` and the bps/rdf/rerank settings read by `build`
19    rotator: Rotator, // constructed from `config.dim`; applied to every vector passed to `add`
20    vectors: Vec<Vec<f32>>, // original input vectors; the id returned by `add` is the index into this list
21    rotated: Vec<Vec<f32>>, // `rotator.rotate` output for each entry of `vectors`, pushed in the same order
22}
23
24impl SegmentWriter {
25    /// Create a new segment writer
26    pub fn new(config: EngineConfig) -> Result<Self> {
27        config.validate()?;
28        let rotator = Rotator::new(config.dim);
29        Ok(Self {
30            config,
31            rotator,
32            vectors: Vec::new(),
33            rotated: Vec::new(),
34        })
35    }
36
37    /// Add a vector to the segment
38    pub fn add(&mut self, vector: &[f32]) -> Result<VectorId> {
39        if vector.len() != self.config.dim as usize {
40            return Err(Error::DimensionMismatch {
41                expected: self.config.dim,
42                got: vector.len() as u32,
43            });
44        }
45
46        let vid = self.vectors.len() as VectorId;
47        let vec_owned = vector.to_vec();
48
49        // Apply rotation
50        let rotated = self.rotator.rotate(&vec_owned);
51
52        self.vectors.push(vec_owned);
53        self.rotated.push(rotated);
54
55        Ok(vid)
56    }
57
58    /// Add multiple vectors
59    pub fn add_batch(&mut self, vectors: &[Vec<f32>]) -> Result<Vec<VectorId>> {
60        vectors.iter().map(|v| self.add(v)).collect()
61    }
62
63    /// Number of vectors added
64    pub fn len(&self) -> usize {
65        self.vectors.len()
66    }
67
68    /// Check if empty
69    pub fn is_empty(&self) -> bool {
70        self.vectors.is_empty()
71    }
72
73    /// Build and write segment to file
74    pub fn build<P: AsRef<Path>>(self, path: P) -> Result<()> {
75        if self.vectors.is_empty() {
76            return Err(Error::EmptyIndex);
77        }
78
79        let n_vec = self.vectors.len() as u32;
80        let dim = self.config.dim;
81        let _num_blocks = self.config.bps.num_blocks as usize;
82
83        // Create builders
84        let bps_builder = BpsBuilder::new(&self.config.bps, &self.rotated);
85        let rdf_builder = RdfBuilder::new(&self.config.rdf, dim, &self.rotated);
86        let rerank_builder = RerankBuilder::new(&self.config.rerank, &self.rotated);
87
88        // Open file for writing
89        let file = File::create(&path)?;
90        let mut writer = BufWriter::new(file);
91
92        // Reserve space for header
93        let mut header = SegmentHeader::new(n_vec, dim);
94        header.bps_block = self.config.bps.block_size;
95        header.bps_proj = self.config.bps.num_projections;
96        header.rdf_t = self.config.rdf.top_t;
97        header.rdf_stripe_shift = self.config.rdf.stripe_shift;
98        header.num_outliers = self.config.rerank.num_outliers;
99        header.flags.set(SegmentFlags::HAS_BPS);
100        header.flags.set(SegmentFlags::HAS_RDF);
101        header.flags.set(SegmentFlags::HAS_OUTLIERS);
102        header.flags.set(SegmentFlags::ROTATED);
103        header.flags.set(SegmentFlags::HAS_FP32); // Store originals for verification
104
105        // Write placeholder header (will rewrite at end)
106        writer.write_all(&[0u8; SegmentHeader::SIZE])?;
107        let mut offset = SegmentHeader::SIZE as u64;
108
109        // Write BPS data (SoA layout)
110        header.off_bps = offset;
111        let (bps_data, bps_qparams) = bps_builder.build();
112        writer.write_all(&bps_data)?;
113        offset += bps_data.len() as u64;
114
115        // Write BPS quantization parameters
116        header.off_bps_qparams = offset;
117        writer.write_all(bytemuck::cast_slice(&bps_qparams))?;
118        offset += (bps_qparams.len() * std::mem::size_of::<super::bps::BpsQParam>()) as u64;
119
120        // Write int8 embeddings
121        header.off_i8 = offset;
122        let (i8_data, scales) = rerank_builder.build_i8();
123        writer.write_all(bytemuck::cast_slice(&i8_data))?;
124        offset += i8_data.len() as u64;
125
126        // Write scales
127        header.off_scales = offset;
128        writer.write_all(bytemuck::cast_slice(&scales))?;
129        offset += (scales.len() * 4) as u64;
130
131        // Write outliers
132        header.off_outliers = offset;
133        let outliers = rerank_builder.build_outliers();
134        writer.write_all(bytemuck::cast_slice(&outliers))?;
135        offset += (outliers.len() * std::mem::size_of::<OutlierEntry>()) as u64;
136
137        // Write tombstone bitset (all zeros = no tombstones)
138        header.off_tombstone = offset;
139        let tombstone_words = (n_vec as usize + 63) / 64;
140        let tombstone_data = vec![0u64; tombstone_words];
141        writer.write_all(bytemuck::cast_slice(&tombstone_data))?;
142        offset += (tombstone_words * 8) as u64;
143
144        // Write RDF directory
145        header.off_rdf_dir = offset;
146        let (rdf_dir, rdf_data) = rdf_builder.build();
147        writer.write_all(bytemuck::cast_slice(&rdf_dir))?;
148        offset += (rdf_dir.len() * std::mem::size_of::<PostingListEntry>()) as u64;
149
150        // Write RDF posting list data
151        header.off_rdf_data = offset;
152        writer.write_all(&rdf_data)?;
153        offset += rdf_data.len() as u64;
154
155        // Write dimension weights
156        header.off_dim_weights = offset;
157        let weights = rdf_builder.dim_weights();
158        writer.write_all(bytemuck::cast_slice(&weights))?;
159        offset += (weights.len() * 4) as u64;
160
161        // Write original fp32 vectors
162        header.off_fp32 = offset;
163        for vec in &self.vectors {
164            writer.write_all(bytemuck::cast_slice(vec))?;
165        }
166        offset += (n_vec as usize * dim as usize * 4) as u64;
167
168        // Update header with final file length
169        header.file_len = offset;
170
171        // Seek back and write final header
172        writer.seek(SeekFrom::Start(0))?;
173        writer.write_all(bytemuck::bytes_of(&header))?;
174        writer.flush()?;
175
176        Ok(())
177    }
178
179    /// Get config
180    pub fn config(&self) -> &EngineConfig {
181        &self.config
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::segment::Segment;
    use tempfile::NamedTempFile;

    /// Round trip: a segment built from random vectors can be reopened and
    /// reports the same vector count and dimension.
    #[test]
    fn test_segment_write_read() {
        use rand::Rng;

        let mut segment_writer = SegmentWriter::new(EngineConfig::with_dim(64)).unwrap();

        // Feed 100 random 64-dimensional vectors, components drawn from -1.0..1.0.
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let mut sample: Vec<f32> = Vec::with_capacity(64);
            for _ in 0..64 {
                sample.push(rng.gen_range(-1.0..1.0));
            }
            segment_writer.add(&sample).unwrap();
        }

        // Serialize into a temporary file.
        let tmp = NamedTempFile::new().unwrap();
        segment_writer.build(tmp.path()).unwrap();

        // Reopen and compare the header-level facts.
        let reopened = Segment::open(tmp.path()).unwrap();
        assert_eq!(reopened.num_vectors(), 100);
        assert_eq!(reopened.dim(), 64);
    }
}