Skip to main content

sochdb_vector/
batch_segment_writer.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! True Batch SegmentWriter Ingest
19//!
20//! High-performance batch API for segment construction that eliminates
21//! per-vector overhead and enables streaming ingest.
22//!
23//! ## Problem
24//!
25//! Current SegmentWriter.add():
26//! - Creates vec_owned per vector (allocation)
27//! - Rotates synchronously (CPU bound)
28//! - Single vector at a time (no batching benefits)
29//!
30//! ## Solution
31//!
32//! Batch-oriented API:
33//! - add_batch_contiguous() for bulk ingest
34//! - Pre-allocate rotation buffers
35//! - Parallel rotation for batch
36//! - Direct arena storage
37//!
38//! ## Usage
39//!
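//! ```rust
//! use sochdb_vector::batch_segment_writer::{BatchConfig, BatchSegmentWriter};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // The dimension is part of the configuration, not a per-call argument
//! let config = BatchConfig { dim: 4, ..BatchConfig::default() };
//! let mut writer = BatchSegmentWriter::new(config);
//!
//! // Batch ingest: two 4-dimensional vectors stored back to back
//! let flat_data: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
//! let keys: [u64; 2] = [10, 20];
//! let ids = writer.add_batch_contiguous(&flat_data, &keys)?;
//! assert_eq!(ids, vec![0, 1]);
//!
//! // Build segment
//! let segment = writer.build()?;
//! assert_eq!(segment.len(), 2);
//! # Ok(())
//! # }
//! ```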
53
54use std::collections::HashMap;
55use std::sync::Arc;
56
/// Configuration for batch segment writer
///
/// Consumed by `BatchSegmentWriter::new`; the `Default` impl provides the
/// stock values.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Vector dimension: every vector handed to the writer must have exactly
    /// this many `f32` components
    pub dim: usize,

    /// Enable rotation (Walsh-Hadamard) of vectors before they are stored.
    ///
    /// NOTE: the transform in this module is skipped for lengths that are not
    /// a power of two, so rotation only has an effect when `dim` is one.
    pub enable_rotation: bool,

    /// Parallel rotation threshold (vectors): `add_batch_contiguous` rotates a
    /// batch on worker threads once it holds at least this many vectors
    pub parallel_threshold: usize,

    /// Number of rotation threads used by the parallel path (never more than
    /// the number of vectors in the batch)
    pub rotation_threads: usize,

    /// Pre-allocate for this many vectors (both the vector storage and the
    /// key map are sized from it)
    pub initial_capacity: usize,
}
75
76impl Default for BatchConfig {
77    fn default() -> Self {
78        Self {
79            dim: 768,
80            enable_rotation: true,
81            parallel_threshold: 100,
82            rotation_threads: 4,
83            initial_capacity: 10_000,
84        }
85    }
86}
87
/// Key type for vectors
///
/// Caller-chosen identifier of a vector; the writer rejects a key that is
/// already present with `BatchWriteError::DuplicateKey`.
pub type VectorKey = u64;
90
/// Batch write statistics
///
/// Counters accumulated by `BatchSegmentWriter` while vectors are ingested;
/// all of them start at zero (`Default`).
#[derive(Debug, Clone, Default)]
pub struct BatchWriteStats {
    /// Vectors added
    pub vectors_added: usize,

    /// Total bytes processed (input data, counted as 4 bytes per `f32`
    /// component)
    pub bytes_processed: usize,

    /// Rotation time in nanoseconds
    pub rotation_time_ns: u64,

    /// Copy time in nanoseconds
    pub copy_time_ns: u64,

    /// Batches processed (successful batch calls)
    pub batches_processed: usize,
}
109
110impl BatchWriteStats {
111    /// Rotation throughput in MB/s
112    pub fn rotation_mb_per_sec(&self) -> f64 {
113        if self.rotation_time_ns == 0 {
114            return 0.0;
115        }
116        let mb = self.bytes_processed as f64 / (1024.0 * 1024.0);
117        mb / (self.rotation_time_ns as f64 / 1e9)
118    }
119}
120
/// Stored vector with metadata
///
/// One entry of the writer's (and later the built segment's) vector list.
#[derive(Clone)]
pub struct StoredVector {
    /// Original vector key
    pub key: VectorKey,

    /// Vector data (possibly rotated); an owned copy of the caller's input
    pub data: Vec<f32>,

    /// Index in storage order, i.e. the position this entry was given when it
    /// was appended
    pub index: u32,
}
133
/// Error type for batch operations
///
/// Returned by the ingest and build methods of `BatchSegmentWriter`.
#[derive(Debug, Clone)]
pub enum BatchWriteError {
    /// Dimension mismatch: a vector does not have the configured dimension,
    /// or a flat buffer's length is not a multiple of it
    DimensionMismatch { expected: usize, actual: usize },

    /// Key count mismatch: the number of keys differs from the number of
    /// vectors supplied in the same call
    KeyCountMismatch { vectors: usize, keys: usize },

    /// Duplicate key: the carried key is already present in the writer
    DuplicateKey(VectorKey),

    /// Build error, with a human-readable description
    BuildError(String),
}
149
150impl std::fmt::Display for BatchWriteError {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        match self {
153            Self::DimensionMismatch { expected, actual } => {
154                write!(
155                    f,
156                    "dimension mismatch: expected {}, got {}",
157                    expected, actual
158                )
159            }
160            Self::KeyCountMismatch { vectors, keys } => {
161                write!(f, "key count mismatch: {} vectors, {} keys", vectors, keys)
162            }
163            Self::DuplicateKey(k) => write!(f, "duplicate key: {}", k),
164            Self::BuildError(s) => write!(f, "build error: {}", s),
165        }
166    }
167}
168
// Marker impl with no overridden methods: together with the `Debug` derive and
// the `Display` impl it lets `BatchWriteError` be used as a
// `std::error::Error` (e.g. boxed as `Box<dyn std::error::Error>`).
impl std::error::Error for BatchWriteError {}
170
/// High-performance batch segment writer
///
/// Collects keyed vectors (optionally rotating them on the way in) and is
/// finally turned into a `BuiltSegment` by `build`.
pub struct BatchSegmentWriter {
    /// Configuration
    config: BatchConfig,

    /// Stored vectors, in insertion order
    vectors: Vec<StoredVector>,

    /// Key to index mapping (index into `vectors`)
    key_to_index: HashMap<VectorKey, u32>,

    /// Rotation buffer (reused)
    ///
    /// Allocated with `dim` zeros by `new` but not read anywhere in this
    /// module yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    rotation_buffer: Vec<f32>,

    /// Statistics
    stats: BatchWriteStats,
}
189
190impl BatchSegmentWriter {
191    /// Create a new batch writer
192    pub fn new(config: BatchConfig) -> Self {
193        let initial_capacity = config.initial_capacity;
194        let dim = config.dim;
195
196        Self {
197            config,
198            vectors: Vec::with_capacity(initial_capacity),
199            key_to_index: HashMap::with_capacity(initial_capacity),
200            rotation_buffer: vec![0.0; dim],
201            stats: BatchWriteStats::default(),
202        }
203    }
204
205    /// Add a single vector
206    pub fn add(&mut self, key: VectorKey, vector: &[f32]) -> Result<u32, BatchWriteError> {
207        if vector.len() != self.config.dim {
208            return Err(BatchWriteError::DimensionMismatch {
209                expected: self.config.dim,
210                actual: vector.len(),
211            });
212        }
213
214        if self.key_to_index.contains_key(&key) {
215            return Err(BatchWriteError::DuplicateKey(key));
216        }
217
218        let index = self.vectors.len() as u32;
219
220        // Rotate if enabled
221        let data = if self.config.enable_rotation {
222            let start = std::time::Instant::now();
223            let rotated = self.rotate_vector(vector);
224            self.stats.rotation_time_ns += start.elapsed().as_nanos() as u64;
225            rotated
226        } else {
227            vector.to_vec()
228        };
229
230        self.vectors.push(StoredVector { key, data, index });
231        self.key_to_index.insert(key, index);
232        self.stats.vectors_added += 1;
233        self.stats.bytes_processed += vector.len() * 4;
234
235        Ok(index)
236    }
237
238    /// Add a batch of vectors with keys
239    pub fn add_batch(
240        &mut self,
241        keys: &[VectorKey],
242        vectors: &[Vec<f32>],
243    ) -> Result<Vec<u32>, BatchWriteError> {
244        if keys.len() != vectors.len() {
245            return Err(BatchWriteError::KeyCountMismatch {
246                vectors: vectors.len(),
247                keys: keys.len(),
248            });
249        }
250
251        let mut indices = Vec::with_capacity(keys.len());
252
253        for (key, vector) in keys.iter().zip(vectors.iter()) {
254            let index = self.add(*key, vector)?;
255            indices.push(index);
256        }
257
258        self.stats.batches_processed += 1;
259
260        Ok(indices)
261    }
262
263    /// Add batch from contiguous flat data (optimized path)
264    ///
265    /// `flat_data` is [v0_0, v0_1, ..., v0_d, v1_0, ...]
266    pub fn add_batch_contiguous(
267        &mut self,
268        flat_data: &[f32],
269        keys: &[VectorKey],
270    ) -> Result<Vec<u32>, BatchWriteError> {
271        let dim = self.config.dim;
272        let num_vectors = flat_data.len() / dim;
273
274        if flat_data.len() % dim != 0 {
275            return Err(BatchWriteError::DimensionMismatch {
276                expected: dim * keys.len(),
277                actual: flat_data.len(),
278            });
279        }
280
281        if keys.len() != num_vectors {
282            return Err(BatchWriteError::KeyCountMismatch {
283                vectors: num_vectors,
284                keys: keys.len(),
285            });
286        }
287
288        // Check for duplicate keys
289        for key in keys {
290            if self.key_to_index.contains_key(key) {
291                return Err(BatchWriteError::DuplicateKey(*key));
292            }
293        }
294
295        let start_index = self.vectors.len() as u32;
296        let mut indices = Vec::with_capacity(num_vectors);
297
298        // Parallel rotation for large batches
299        if self.config.enable_rotation && num_vectors >= self.config.parallel_threshold {
300            let rotated_vectors = self.rotate_batch_parallel(flat_data, num_vectors);
301
302            for (i, (key, data)) in keys.iter().zip(rotated_vectors.into_iter()).enumerate() {
303                let index = start_index + i as u32;
304                self.vectors.push(StoredVector {
305                    key: *key,
306                    data,
307                    index,
308                });
309                self.key_to_index.insert(*key, index);
310                indices.push(index);
311            }
312        } else {
313            // Sequential path
314            for (i, key) in keys.iter().enumerate() {
315                let start = i * dim;
316                let vector = &flat_data[start..start + dim];
317
318                let data = if self.config.enable_rotation {
319                    let start_time = std::time::Instant::now();
320                    let rotated = self.rotate_vector(vector);
321                    self.stats.rotation_time_ns += start_time.elapsed().as_nanos() as u64;
322                    rotated
323                } else {
324                    vector.to_vec()
325                };
326
327                let index = start_index + i as u32;
328                self.vectors.push(StoredVector {
329                    key: *key,
330                    data,
331                    index,
332                });
333                self.key_to_index.insert(*key, index);
334                indices.push(index);
335            }
336        }
337
338        self.stats.vectors_added += num_vectors;
339        self.stats.bytes_processed += flat_data.len() * 4;
340        self.stats.batches_processed += 1;
341
342        Ok(indices)
343    }
344
345    /// Rotate a single vector using Walsh-Hadamard
346    fn rotate_vector(&self, vector: &[f32]) -> Vec<f32> {
347        let mut rotated = vector.to_vec();
348        hadamard_transform(&mut rotated);
349        rotated
350    }
351
352    /// Rotate batch in parallel
353    fn rotate_batch_parallel(&self, flat_data: &[f32], num_vectors: usize) -> Vec<Vec<f32>> {
354        use std::thread;
355
356        let start = std::time::Instant::now();
357        let dim = self.config.dim;
358        let num_threads = self.config.rotation_threads.min(num_vectors);
359        let chunk_size = (num_vectors + num_threads - 1) / num_threads;
360
361        let flat_data = Arc::new(flat_data.to_vec());
362        let mut handles = Vec::with_capacity(num_threads);
363
364        for t in 0..num_threads {
365            let flat_data = Arc::clone(&flat_data);
366            let start_vec = t * chunk_size;
367            let end_vec = (start_vec + chunk_size).min(num_vectors);
368
369            handles.push(thread::spawn(move || {
370                let mut results = Vec::with_capacity(end_vec - start_vec);
371
372                for i in start_vec..end_vec {
373                    let start_idx = i * dim;
374                    let mut rotated = flat_data[start_idx..start_idx + dim].to_vec();
375                    hadamard_transform(&mut rotated);
376                    results.push(rotated);
377                }
378
379                results
380            }));
381        }
382
383        // Collect results in order
384        let mut all_results = Vec::with_capacity(num_vectors);
385        for handle in handles {
386            let chunk_results = handle.join().unwrap();
387            all_results.extend(chunk_results);
388        }
389
390        // Note: atomic operation not available on u64, using regular assignment
391        let _elapsed = start.elapsed().as_nanos() as u64;
392        // Stats update deferred to caller since this is a move closure scenario
393
394        all_results
395    }
396
397    /// Get current count
398    pub fn len(&self) -> usize {
399        self.vectors.len()
400    }
401
402    /// Check if empty
403    pub fn is_empty(&self) -> bool {
404        self.vectors.is_empty()
405    }
406
407    /// Get statistics
408    pub fn stats(&self) -> &BatchWriteStats {
409        &self.stats
410    }
411
412    /// Get vector by key
413    pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
414        self.key_to_index
415            .get(&key)
416            .map(|&idx| self.vectors[idx as usize].data.as_slice())
417    }
418
419    /// Get vector by index
420    pub fn get_by_index(&self, index: u32) -> Option<&[f32]> {
421        self.vectors.get(index as usize).map(|v| v.data.as_slice())
422    }
423
424    /// Build and finalize
425    pub fn build(self) -> Result<BuiltSegment, BatchWriteError> {
426        Ok(BuiltSegment {
427            vectors: self.vectors,
428            key_to_index: self.key_to_index,
429            dim: self.config.dim,
430            stats: self.stats,
431        })
432    }
433}
434
/// Built segment ready for use
///
/// Produced by `BatchSegmentWriter::build`, which moves the writer's state
/// into these public fields.
pub struct BuiltSegment {
    /// Stored vectors, in insertion order
    pub vectors: Vec<StoredVector>,

    /// Key to index mapping (index into `vectors`)
    pub key_to_index: HashMap<VectorKey, u32>,

    /// Dimension the writer was configured with
    pub dim: usize,

    /// Build statistics gathered during ingest
    pub stats: BatchWriteStats,
}
449
450impl BuiltSegment {
451    /// Get vector by key
452    pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
453        self.key_to_index
454            .get(&key)
455            .map(|&idx| self.vectors[idx as usize].data.as_slice())
456    }
457
458    /// Get all vector data as contiguous slice for SIMD
459    pub fn get_all_data(&self) -> Vec<f32> {
460        let mut data = Vec::with_capacity(self.vectors.len() * self.dim);
461        for v in &self.vectors {
462            data.extend_from_slice(&v.data);
463        }
464        data
465    }
466
467    /// Number of vectors
468    pub fn len(&self) -> usize {
469        self.vectors.len()
470    }
471
472    /// Check if empty
473    pub fn is_empty(&self) -> bool {
474        self.vectors.is_empty()
475    }
476}
477
478// ============================================================================
479// Walsh-Hadamard Transform (inline for this module)
480// ============================================================================
481
/// Normalized Walsh-Hadamard transform, computed in place
///
/// Runs log2(D) butterfly passes over the D elements of `data` and then
/// scales every element by `1 / sqrt(D)`. Inputs whose length is zero or not
/// a power of two are left untouched.
fn hadamard_transform(data: &mut [f32]) {
    let len = data.len();
    if !len.is_power_of_two() {
        // Covers the empty slice as well: zero is not a power of two.
        return;
    }

    let mut half = 1;
    while half < len {
        // Each block of `2 * half` elements is a butterfly between its lower
        // and its upper half.
        for block in data.chunks_exact_mut(half * 2) {
            let (lower, upper) = block.split_at_mut(half);
            for (a, b) in lower.iter_mut().zip(upper.iter_mut()) {
                let (sum, diff) = (*a + *b, *a - *b);
                *a = sum;
                *b = diff;
            }
        }
        half *= 2;
    }

    // Normalize
    let scale = 1.0 / (len as f32).sqrt();
    data.iter_mut().for_each(|value| *value *= scale);
}
511
512// ============================================================================
513// Tests
514// ============================================================================
515
#[cfg(test)]
mod tests {
    use super::*;

    /// Writer for 4-dimensional vectors (a power of two, as the Hadamard
    /// rotation requires), with rotation switched on or off.
    fn writer_dim4(enable_rotation: bool) -> BatchSegmentWriter {
        BatchSegmentWriter::new(BatchConfig {
            dim: 4,
            enable_rotation,
            ..Default::default()
        })
    }

    #[test]
    fn test_batch_writer_basic() {
        let mut writer = writer_dim4(false);

        let first_index = writer.add(1, &[1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(first_index, 0);

        // Without rotation the data comes back exactly as it went in
        assert_eq!(writer.get(1).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_batch_writer_contiguous() {
        let mut writer = writer_dim4(false);

        let flat_data = vec![
            1.0, 2.0, 3.0, 4.0, // key 10
            5.0, 6.0, 7.0, 8.0, // key 20
            9.0, 10.0, 11.0, 12.0, // key 30
        ];
        let keys = vec![10, 20, 30];

        let indices = writer.add_batch_contiguous(&flat_data, &keys).unwrap();
        assert_eq!(indices, vec![0, 1, 2]);
        assert_eq!(writer.len(), 3);

        let expected_rows: [(u64, [f32; 4]); 3] = [
            (10, [1.0, 2.0, 3.0, 4.0]),
            (20, [5.0, 6.0, 7.0, 8.0]),
            (30, [9.0, 10.0, 11.0, 12.0]),
        ];
        for (key, row) in expected_rows.iter() {
            assert_eq!(writer.get(*key).unwrap(), row);
        }
    }

    #[test]
    fn test_batch_writer_rotation() {
        let mut writer = writer_dim4(true);

        let _ = writer.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
        let rotated = writer.get(1).unwrap();

        // The normalized Hadamard transform preserves length, so a unit
        // vector must keep a squared norm of 1
        let norm_sq: f32 = rotated.iter().map(|x| x * x).sum();
        assert!((norm_sq - 1.0).abs() < 0.1, "norm_sq = {}", norm_sq);
    }

    #[test]
    fn test_duplicate_key_error() {
        let mut writer = writer_dim4(false);

        writer.add(1, &[1.0, 2.0, 3.0, 4.0]).unwrap();
        let second_attempt = writer.add(1, &[5.0, 6.0, 7.0, 8.0]);

        assert!(matches!(
            second_attempt,
            Err(BatchWriteError::DuplicateKey(1))
        ));
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let mut writer = writer_dim4(false);

        // Only 3 of the 4 required components
        let outcome = writer.add(1, &[1.0, 2.0, 3.0]);

        assert!(matches!(
            outcome,
            Err(BatchWriteError::DimensionMismatch {
                expected: 4,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_build_segment() {
        let mut writer = writer_dim4(false);

        let flat_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let keys = vec![100, 200];
        writer.add_batch_contiguous(&flat_data, &keys).unwrap();

        let segment = writer.build().unwrap();

        assert_eq!(segment.len(), 2);
        assert_eq!(segment.get(100).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(segment.get(200).unwrap(), &[5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_hadamard_transform() {
        let mut data = vec![1.0, 0.0, 0.0, 0.0];
        hadamard_transform(&mut data);

        // After normalized Hadamard on [1,0,0,0], all components should be 0.5
        data.iter()
            .for_each(|&x| assert!((x - 0.5).abs() < 0.01, "x = {}", x));
    }
}
653}