1use std::collections::HashMap;
55use std::sync::Arc;
56
/// Tunables for a [`BatchSegmentWriter`].
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Number of `f32` components every incoming vector must have.
    pub dim: usize,

    /// Apply `hadamard_transform` to each vector before storing it.
    ///
    /// NOTE(review): `hadamard_transform` leaves vectors whose length is not a
    /// power of two untouched, so with the default `dim` of 768 this flag only
    /// costs time and changes nothing. Confirm this is intended.
    pub enable_rotation: bool,

    /// Minimum number of vectors in one `add_batch_contiguous` call before
    /// rotation is spread across worker threads.
    pub parallel_threshold: usize,

    /// Upper bound on the worker threads used for parallel rotation.
    pub rotation_threads: usize,

    /// Number of vector slots (and key-map entries) pre-allocated by the writer.
    pub initial_capacity: usize,
}

impl BatchConfig {
    const DEFAULT_DIM: usize = 768;
    const DEFAULT_PARALLEL_THRESHOLD: usize = 100;
    const DEFAULT_ROTATION_THREADS: usize = 4;
    const DEFAULT_INITIAL_CAPACITY: usize = 10_000;
}

impl Default for BatchConfig {
    /// 768-dimensional vectors with rotation enabled, 4 rotation threads,
    /// a parallel threshold of 100 vectors, and room for 10 000 vectors.
    fn default() -> Self {
        Self {
            dim: Self::DEFAULT_DIM,
            enable_rotation: true,
            parallel_threshold: Self::DEFAULT_PARALLEL_THRESHOLD,
            rotation_threads: Self::DEFAULT_ROTATION_THREADS,
            initial_capacity: Self::DEFAULT_INITIAL_CAPACITY,
        }
    }
}
87
/// Caller-supplied identifier for a stored vector; writers reject a key that was already added.
pub type VectorKey = u64;
90
/// Running counters collected by a batch writer.
#[derive(Debug, Clone, Default)]
pub struct BatchWriteStats {
    /// Total number of vectors accepted.
    pub vectors_added: usize,

    /// Total size of the accepted input vectors, in bytes (`f32` components × 4).
    pub bytes_processed: usize,

    /// Accumulated wall-clock time spent rotating vectors, in nanoseconds.
    pub rotation_time_ns: u64,

    /// Copy time in nanoseconds.
    ///
    /// NOTE(review): nothing in this module writes this field; it stays 0.
    pub copy_time_ns: u64,

    /// Number of successful batch calls (`add_batch` / `add_batch_contiguous`).
    pub batches_processed: usize,
}

impl BatchWriteStats {
    /// Bytes in one mebibyte; the throughput below is reported in MiB/s.
    const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;

    /// Rotation throughput: `bytes_processed` (in MiB) divided by `rotation_time_ns`
    /// (in seconds). Returns `0.0` when no rotation time has been recorded.
    pub fn rotation_mb_per_sec(&self) -> f64 {
        match self.rotation_time_ns {
            0 => 0.0,
            nanos => {
                let mebibytes = self.bytes_processed as f64 / Self::BYTES_PER_MIB;
                let seconds = nanos as f64 / 1e9;
                mebibytes / seconds
            }
        }
    }
}
120
121#[derive(Clone)]
123pub struct StoredVector {
124 pub key: VectorKey,
126
127 pub data: Vec<f32>,
129
130 pub index: u32,
132}
133
134#[derive(Debug, Clone)]
136pub enum BatchWriteError {
137 DimensionMismatch { expected: usize, actual: usize },
139
140 KeyCountMismatch { vectors: usize, keys: usize },
142
143 DuplicateKey(VectorKey),
145
146 BuildError(String),
148}
149
150impl std::fmt::Display for BatchWriteError {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 match self {
153 Self::DimensionMismatch { expected, actual } => {
154 write!(
155 f,
156 "dimension mismatch: expected {}, got {}",
157 expected, actual
158 )
159 }
160 Self::KeyCountMismatch { vectors, keys } => {
161 write!(f, "key count mismatch: {} vectors, {} keys", vectors, keys)
162 }
163 Self::DuplicateKey(k) => write!(f, "duplicate key: {}", k),
164 Self::BuildError(s) => write!(f, "build error: {}", s),
165 }
166 }
167}
168
169impl std::error::Error for BatchWriteError {}
170
171pub struct BatchSegmentWriter {
173 config: BatchConfig,
175
176 vectors: Vec<StoredVector>,
178
179 key_to_index: HashMap<VectorKey, u32>,
181
182 #[allow(dead_code)]
184 rotation_buffer: Vec<f32>,
185
186 stats: BatchWriteStats,
188}
189
190impl BatchSegmentWriter {
191 pub fn new(config: BatchConfig) -> Self {
193 let initial_capacity = config.initial_capacity;
194 let dim = config.dim;
195
196 Self {
197 config,
198 vectors: Vec::with_capacity(initial_capacity),
199 key_to_index: HashMap::with_capacity(initial_capacity),
200 rotation_buffer: vec![0.0; dim],
201 stats: BatchWriteStats::default(),
202 }
203 }
204
205 pub fn add(&mut self, key: VectorKey, vector: &[f32]) -> Result<u32, BatchWriteError> {
207 if vector.len() != self.config.dim {
208 return Err(BatchWriteError::DimensionMismatch {
209 expected: self.config.dim,
210 actual: vector.len(),
211 });
212 }
213
214 if self.key_to_index.contains_key(&key) {
215 return Err(BatchWriteError::DuplicateKey(key));
216 }
217
218 let index = self.vectors.len() as u32;
219
220 let data = if self.config.enable_rotation {
222 let start = std::time::Instant::now();
223 let rotated = self.rotate_vector(vector);
224 self.stats.rotation_time_ns += start.elapsed().as_nanos() as u64;
225 rotated
226 } else {
227 vector.to_vec()
228 };
229
230 self.vectors.push(StoredVector { key, data, index });
231 self.key_to_index.insert(key, index);
232 self.stats.vectors_added += 1;
233 self.stats.bytes_processed += vector.len() * 4;
234
235 Ok(index)
236 }
237
238 pub fn add_batch(
240 &mut self,
241 keys: &[VectorKey],
242 vectors: &[Vec<f32>],
243 ) -> Result<Vec<u32>, BatchWriteError> {
244 if keys.len() != vectors.len() {
245 return Err(BatchWriteError::KeyCountMismatch {
246 vectors: vectors.len(),
247 keys: keys.len(),
248 });
249 }
250
251 let mut indices = Vec::with_capacity(keys.len());
252
253 for (key, vector) in keys.iter().zip(vectors.iter()) {
254 let index = self.add(*key, vector)?;
255 indices.push(index);
256 }
257
258 self.stats.batches_processed += 1;
259
260 Ok(indices)
261 }
262
263 pub fn add_batch_contiguous(
267 &mut self,
268 flat_data: &[f32],
269 keys: &[VectorKey],
270 ) -> Result<Vec<u32>, BatchWriteError> {
271 let dim = self.config.dim;
272 let num_vectors = flat_data.len() / dim;
273
274 if flat_data.len() % dim != 0 {
275 return Err(BatchWriteError::DimensionMismatch {
276 expected: dim * keys.len(),
277 actual: flat_data.len(),
278 });
279 }
280
281 if keys.len() != num_vectors {
282 return Err(BatchWriteError::KeyCountMismatch {
283 vectors: num_vectors,
284 keys: keys.len(),
285 });
286 }
287
288 for key in keys {
290 if self.key_to_index.contains_key(key) {
291 return Err(BatchWriteError::DuplicateKey(*key));
292 }
293 }
294
295 let start_index = self.vectors.len() as u32;
296 let mut indices = Vec::with_capacity(num_vectors);
297
298 if self.config.enable_rotation && num_vectors >= self.config.parallel_threshold {
300 let rotated_vectors = self.rotate_batch_parallel(flat_data, num_vectors);
301
302 for (i, (key, data)) in keys.iter().zip(rotated_vectors.into_iter()).enumerate() {
303 let index = start_index + i as u32;
304 self.vectors.push(StoredVector {
305 key: *key,
306 data,
307 index,
308 });
309 self.key_to_index.insert(*key, index);
310 indices.push(index);
311 }
312 } else {
313 for (i, key) in keys.iter().enumerate() {
315 let start = i * dim;
316 let vector = &flat_data[start..start + dim];
317
318 let data = if self.config.enable_rotation {
319 let start_time = std::time::Instant::now();
320 let rotated = self.rotate_vector(vector);
321 self.stats.rotation_time_ns += start_time.elapsed().as_nanos() as u64;
322 rotated
323 } else {
324 vector.to_vec()
325 };
326
327 let index = start_index + i as u32;
328 self.vectors.push(StoredVector {
329 key: *key,
330 data,
331 index,
332 });
333 self.key_to_index.insert(*key, index);
334 indices.push(index);
335 }
336 }
337
338 self.stats.vectors_added += num_vectors;
339 self.stats.bytes_processed += flat_data.len() * 4;
340 self.stats.batches_processed += 1;
341
342 Ok(indices)
343 }
344
345 fn rotate_vector(&self, vector: &[f32]) -> Vec<f32> {
347 let mut rotated = vector.to_vec();
348 hadamard_transform(&mut rotated);
349 rotated
350 }
351
352 fn rotate_batch_parallel(&self, flat_data: &[f32], num_vectors: usize) -> Vec<Vec<f32>> {
354 use std::thread;
355
356 let start = std::time::Instant::now();
357 let dim = self.config.dim;
358 let num_threads = self.config.rotation_threads.min(num_vectors);
359 let chunk_size = (num_vectors + num_threads - 1) / num_threads;
360
361 let flat_data = Arc::new(flat_data.to_vec());
362 let mut handles = Vec::with_capacity(num_threads);
363
364 for t in 0..num_threads {
365 let flat_data = Arc::clone(&flat_data);
366 let start_vec = t * chunk_size;
367 let end_vec = (start_vec + chunk_size).min(num_vectors);
368
369 handles.push(thread::spawn(move || {
370 let mut results = Vec::with_capacity(end_vec - start_vec);
371
372 for i in start_vec..end_vec {
373 let start_idx = i * dim;
374 let mut rotated = flat_data[start_idx..start_idx + dim].to_vec();
375 hadamard_transform(&mut rotated);
376 results.push(rotated);
377 }
378
379 results
380 }));
381 }
382
383 let mut all_results = Vec::with_capacity(num_vectors);
385 for handle in handles {
386 let chunk_results = handle.join().unwrap();
387 all_results.extend(chunk_results);
388 }
389
390 let _elapsed = start.elapsed().as_nanos() as u64;
392 all_results
395 }
396
397 pub fn len(&self) -> usize {
399 self.vectors.len()
400 }
401
402 pub fn is_empty(&self) -> bool {
404 self.vectors.is_empty()
405 }
406
407 pub fn stats(&self) -> &BatchWriteStats {
409 &self.stats
410 }
411
412 pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
414 self.key_to_index
415 .get(&key)
416 .map(|&idx| self.vectors[idx as usize].data.as_slice())
417 }
418
419 pub fn get_by_index(&self, index: u32) -> Option<&[f32]> {
421 self.vectors.get(index as usize).map(|v| v.data.as_slice())
422 }
423
424 pub fn build(self) -> Result<BuiltSegment, BatchWriteError> {
426 Ok(BuiltSegment {
427 vectors: self.vectors,
428 key_to_index: self.key_to_index,
429 dim: self.config.dim,
430 stats: self.stats,
431 })
432 }
433}
434
435pub struct BuiltSegment {
437 pub vectors: Vec<StoredVector>,
439
440 pub key_to_index: HashMap<VectorKey, u32>,
442
443 pub dim: usize,
445
446 pub stats: BatchWriteStats,
448}
449
450impl BuiltSegment {
451 pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
453 self.key_to_index
454 .get(&key)
455 .map(|&idx| self.vectors[idx as usize].data.as_slice())
456 }
457
458 pub fn get_all_data(&self) -> Vec<f32> {
460 let mut data = Vec::with_capacity(self.vectors.len() * self.dim);
461 for v in &self.vectors {
462 data.extend_from_slice(&v.data);
463 }
464 data
465 }
466
467 pub fn len(&self) -> usize {
469 self.vectors.len()
470 }
471
472 pub fn is_empty(&self) -> bool {
474 self.vectors.is_empty()
475 }
476}
477
/// In-place fast Walsh–Hadamard transform scaled by `1 / sqrt(len)`, so the
/// transform is orthonormal (it preserves the Euclidean norm).
///
/// Only lengths that are a power of two are transformed. Empty slices and
/// slices whose length is not a power of two are returned unchanged, without
/// any error. For example, the default `dim` of 768 falls in this no-op case.
fn hadamard_transform(data: &mut [f32]) {
    let len = data.len();
    if !len.is_power_of_two() {
        // Also covers `len == 0`, for which `is_power_of_two` is false.
        return;
    }

    // Butterfly passes: at each level, pair element `k` of every block's lower
    // half with element `k` of its upper half.
    let mut half = 1;
    while half < len {
        for block in data.chunks_exact_mut(half * 2) {
            let (low, high) = block.split_at_mut(half);
            for (a, b) in low.iter_mut().zip(high.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        half *= 2;
    }

    let norm = 1.0 / (len as f32).sqrt();
    data.iter_mut().for_each(|value| *value *= norm);
}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Config of the given dimension with rotation disabled, so stored data
    /// equals the input exactly.
    fn unrotated_config(dim: usize) -> BatchConfig {
        BatchConfig {
            dim,
            enable_rotation: false,
            ..Default::default()
        }
    }

    #[test]
    fn test_batch_writer_basic() {
        let mut writer = BatchSegmentWriter::new(unrotated_config(4));

        assert_eq!(writer.add(1, &[1.0, 2.0, 3.0, 4.0]).unwrap(), 0);
        assert_eq!(writer.get(1).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_batch_writer_contiguous() {
        let mut writer = BatchSegmentWriter::new(unrotated_config(4));
        let flat_data: Vec<f32> = (1..=12).map(|v| v as f32).collect();
        let keys: Vec<VectorKey> = vec![10, 20, 30];

        let indices = writer.add_batch_contiguous(&flat_data, &keys).unwrap();
        assert_eq!(indices, vec![0, 1, 2]);
        assert_eq!(writer.len(), 3);

        // Each key must map back to its own 4-float slice of the packed input.
        for (key, expected) in keys.iter().zip(flat_data.chunks(4)) {
            assert_eq!(writer.get(*key).unwrap(), expected);
        }
    }

    #[test]
    fn test_batch_writer_rotation() {
        let mut writer = BatchSegmentWriter::new(BatchConfig {
            dim: 4,
            enable_rotation: true,
            ..Default::default()
        });

        writer.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();

        // The rotation is orthonormal, so a unit vector must stay (close to) unit length.
        let norm_sq: f32 = writer.get(1).unwrap().iter().map(|x| x * x).sum();
        assert!((norm_sq - 1.0).abs() < 0.1, "norm_sq = {}", norm_sq);
    }

    #[test]
    fn test_duplicate_key_error() {
        let mut writer = BatchSegmentWriter::new(unrotated_config(4));
        writer.add(1, &[1.0, 2.0, 3.0, 4.0]).unwrap();

        let second = writer.add(1, &[5.0, 6.0, 7.0, 8.0]);
        assert!(matches!(second, Err(BatchWriteError::DuplicateKey(1))));
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let mut writer = BatchSegmentWriter::new(unrotated_config(4));

        let short = writer.add(1, &[1.0, 2.0, 3.0]);
        assert!(matches!(
            short,
            Err(BatchWriteError::DimensionMismatch {
                expected: 4,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_build_segment() {
        let mut writer = BatchSegmentWriter::new(unrotated_config(4));
        writer
            .add_batch_contiguous(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[100, 200])
            .unwrap();

        let segment = writer.build().unwrap();
        assert_eq!(segment.len(), 2);
        assert_eq!(segment.get(100).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(segment.get(200).unwrap(), &[5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_hadamard_transform() {
        let mut data = vec![1.0, 0.0, 0.0, 0.0];
        hadamard_transform(&mut data);

        // A unit impulse spreads evenly: every coefficient becomes 1/sqrt(4) = 0.5.
        data.iter()
            .for_each(|&x| assert!((x - 0.5).abs() < 0.01, "x = {}", x));
    }
}