sochdb_vector/segment/
writer.rs1use std::fs::File;
4use std::io::{BufWriter, Seek, SeekFrom, Write};
5use std::path::Path;
6
7use super::bps::BpsBuilder;
8use super::format::*;
9use super::rdf::RdfBuilder;
10use super::rerank::RerankBuilder;
11use crate::config::EngineConfig;
12use crate::error::{Error, Result};
13use crate::rotation::Rotator;
14use crate::types::*;
15
16pub struct SegmentWriter {
18 config: EngineConfig,
19 rotator: Rotator,
20 vectors: Vec<Vec<f32>>,
21 rotated: Vec<Vec<f32>>,
22}
23
24impl SegmentWriter {
25 pub fn new(config: EngineConfig) -> Result<Self> {
27 config.validate()?;
28 let rotator = Rotator::new(config.dim);
29 Ok(Self {
30 config,
31 rotator,
32 vectors: Vec::new(),
33 rotated: Vec::new(),
34 })
35 }
36
37 pub fn add(&mut self, vector: &[f32]) -> Result<VectorId> {
39 if vector.len() != self.config.dim as usize {
40 return Err(Error::DimensionMismatch {
41 expected: self.config.dim,
42 got: vector.len() as u32,
43 });
44 }
45
46 let vid = self.vectors.len() as VectorId;
47 let vec_owned = vector.to_vec();
48
49 let rotated = self.rotator.rotate(&vec_owned);
51
52 self.vectors.push(vec_owned);
53 self.rotated.push(rotated);
54
55 Ok(vid)
56 }
57
58 pub fn add_batch(&mut self, vectors: &[Vec<f32>]) -> Result<Vec<VectorId>> {
60 vectors.iter().map(|v| self.add(v)).collect()
61 }
62
63 pub fn len(&self) -> usize {
65 self.vectors.len()
66 }
67
68 pub fn is_empty(&self) -> bool {
70 self.vectors.is_empty()
71 }
72
73 pub fn build<P: AsRef<Path>>(self, path: P) -> Result<()> {
75 if self.vectors.is_empty() {
76 return Err(Error::EmptyIndex);
77 }
78
79 let n_vec = self.vectors.len() as u32;
80 let dim = self.config.dim;
81 let _num_blocks = self.config.bps.num_blocks as usize;
82
83 let bps_builder = BpsBuilder::new(&self.config.bps, &self.rotated);
85 let rdf_builder = RdfBuilder::new(&self.config.rdf, dim, &self.rotated);
86 let rerank_builder = RerankBuilder::new(&self.config.rerank, &self.rotated);
87
88 let file = File::create(&path)?;
90 let mut writer = BufWriter::new(file);
91
92 let mut header = SegmentHeader::new(n_vec, dim);
94 header.bps_block = self.config.bps.block_size;
95 header.bps_proj = self.config.bps.num_projections;
96 header.rdf_t = self.config.rdf.top_t;
97 header.rdf_stripe_shift = self.config.rdf.stripe_shift;
98 header.num_outliers = self.config.rerank.num_outliers;
99 header.flags.set(SegmentFlags::HAS_BPS);
100 header.flags.set(SegmentFlags::HAS_RDF);
101 header.flags.set(SegmentFlags::HAS_OUTLIERS);
102 header.flags.set(SegmentFlags::ROTATED);
103 header.flags.set(SegmentFlags::HAS_FP32); writer.write_all(&[0u8; SegmentHeader::SIZE])?;
107 let mut offset = SegmentHeader::SIZE as u64;
108
109 header.off_bps = offset;
111 let (bps_data, bps_qparams) = bps_builder.build();
112 writer.write_all(&bps_data)?;
113 offset += bps_data.len() as u64;
114
115 header.off_bps_qparams = offset;
117 writer.write_all(bytemuck::cast_slice(&bps_qparams))?;
118 offset += (bps_qparams.len() * std::mem::size_of::<super::bps::BpsQParam>()) as u64;
119
120 header.off_i8 = offset;
122 let (i8_data, scales) = rerank_builder.build_i8();
123 writer.write_all(bytemuck::cast_slice(&i8_data))?;
124 offset += i8_data.len() as u64;
125
126 header.off_scales = offset;
128 writer.write_all(bytemuck::cast_slice(&scales))?;
129 offset += (scales.len() * 4) as u64;
130
131 header.off_outliers = offset;
133 let outliers = rerank_builder.build_outliers();
134 writer.write_all(bytemuck::cast_slice(&outliers))?;
135 offset += (outliers.len() * std::mem::size_of::<OutlierEntry>()) as u64;
136
137 header.off_tombstone = offset;
139 let tombstone_words = (n_vec as usize + 63) / 64;
140 let tombstone_data = vec![0u64; tombstone_words];
141 writer.write_all(bytemuck::cast_slice(&tombstone_data))?;
142 offset += (tombstone_words * 8) as u64;
143
144 header.off_rdf_dir = offset;
146 let (rdf_dir, rdf_data) = rdf_builder.build();
147 writer.write_all(bytemuck::cast_slice(&rdf_dir))?;
148 offset += (rdf_dir.len() * std::mem::size_of::<PostingListEntry>()) as u64;
149
150 header.off_rdf_data = offset;
152 writer.write_all(&rdf_data)?;
153 offset += rdf_data.len() as u64;
154
155 header.off_dim_weights = offset;
157 let weights = rdf_builder.dim_weights();
158 writer.write_all(bytemuck::cast_slice(&weights))?;
159 offset += (weights.len() * 4) as u64;
160
161 header.off_fp32 = offset;
163 for vec in &self.vectors {
164 writer.write_all(bytemuck::cast_slice(vec))?;
165 }
166 offset += (n_vec as usize * dim as usize * 4) as u64;
167
168 header.file_len = offset;
170
171 writer.seek(SeekFrom::Start(0))?;
173 writer.write_all(bytemuck::bytes_of(&header))?;
174 writer.flush()?;
175
176 Ok(())
177 }
178
179 pub fn config(&self) -> &EngineConfig {
181 &self.config
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use crate::segment::Segment;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use tempfile::NamedTempFile;

    /// Round-trips 100 vectors through `build` and `Segment::open`.
    #[test]
    fn test_segment_write_read() {
        let config = EngineConfig::with_dim(64);
        let mut writer = SegmentWriter::new(config).unwrap();

        // Fixed seed so any failure reproduces with exactly the same vectors.
        let mut rng = StdRng::seed_from_u64(0x5EED);
        for i in 0..100 {
            let vec: Vec<f32> = (0..64).map(|_| rng.gen_range(-1.0..1.0)).collect();
            // Ids are assigned sequentially from zero.
            assert_eq!(writer.add(&vec).unwrap(), i as VectorId);
        }
        assert_eq!(writer.len(), 100);

        let file = NamedTempFile::new().unwrap();
        writer.build(file.path()).unwrap();

        let segment = Segment::open(file.path()).unwrap();
        assert_eq!(segment.num_vectors(), 100);
        assert_eq!(segment.dim(), 64);
    }

    /// Building a writer with no vectors is rejected.
    #[test]
    fn test_build_empty_is_error() {
        let writer = SegmentWriter::new(EngineConfig::with_dim(64)).unwrap();
        assert!(writer.is_empty());

        let file = NamedTempFile::new().unwrap();
        assert!(matches!(writer.build(file.path()), Err(Error::EmptyIndex)));
    }

    /// A vector of the wrong length is rejected and nothing is stored.
    #[test]
    fn test_add_rejects_wrong_dimension() {
        let mut writer = SegmentWriter::new(EngineConfig::with_dim(64)).unwrap();

        let result = writer.add(&[0.0; 3]);
        assert!(matches!(
            result,
            Err(Error::DimensionMismatch { expected: 64, got: 3, .. })
        ));
        assert!(writer.is_empty());
    }
}