
ruvector_core/advanced_features/compaction.rs

//! LSM-Tree Style Streaming Index Compaction
//!
//! Implements a Log-Structured Merge-tree (LSM-tree) index for write-heavy
//! vector workloads. Writes are absorbed by an in-memory [`MemTable`] and
//! flushed into immutable, sorted [`Segment`]s across tiered levels.
//! Compaction merges segments to bound read amplification.
//!
//! LSM-trees turn random writes into sequential appends, ideal for
//! high-throughput ingestion, streaming embedding updates, and frequent
//! deletes (tombstone-based).
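//!
//! # Example
//!
//! A minimal usage sketch; the module path is assumed from this file's
//! location and may differ from the crate's actual re-exports:
//!
//! ```ignore
//! use ruvector_core::advanced_features::compaction::{CompactionConfig, LSMIndex};
//!
//! let mut idx = LSMIndex::new(CompactionConfig::default());
//! idx.insert("doc-1".into(), vec![0.1, 0.9], None);
//! idx.insert("doc-2".into(), vec![0.8, 0.2], None);
//! idx.delete("doc-2".into());
//!
//! // Only "doc-1" survives its neighbour's tombstone.
//! let hits = idx.search(&[0.1, 0.9], 10);
//! assert_eq!(hits.len(), 1);
//! ```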

use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::types::{SearchResult, VectorId};

/// Configuration for the LSM-tree index.
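///
/// # Example
///
/// ```ignore
/// // Illustrative tuning for small, frequent flushes with eager merging.
/// let cfg = CompactionConfig { memtable_capacity: 256, merge_threshold: 2,
///                              ..Default::default() };
/// ```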
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionConfig {
    /// Maximum number of entries in the memtable before it flushes.
    pub memtable_capacity: usize,
    /// Target size ratio between adjacent levels (not yet consulted by the
    /// compaction logic in this module).
    pub level_size_ratio: usize,
    /// Maximum number of levels.
    pub max_levels: usize,
    /// Number of segments in a level that triggers compaction.
    pub merge_threshold: usize,
    /// False-positive rate for bloom filters.
    pub bloom_fp_rate: f64,
}

impl Default for CompactionConfig {
    fn default() -> Self {
        Self { memtable_capacity: 1000, level_size_ratio: 10, max_levels: 4,
               merge_threshold: 4, bloom_fp_rate: 0.01 }
    }
}

/// Probabilistic set using double-hashing: `h_i(x) = h1(x) + i * h2(x)`.
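///
/// # Example
///
/// ```ignore
/// let mut bf = BloomFilter::new(1_000, 0.01);
/// bf.insert("alpha");
/// assert!(bf.may_contain("alpha")); // inserted keys are never reported absent
/// // Keys never inserted are *probably* absent; expect roughly 1% false positives.
/// ```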
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BloomFilter { bits: Vec<bool>, num_hashes: usize }

impl BloomFilter {
    /// Create a bloom filter for `n` items at `fp_rate`.
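    ///
    /// Sizing follows the standard bloom-filter bounds:
    /// `m = ceil(-n * ln(p) / ln(2)^2)` bits and
    /// `k = ceil((m / n) * ln(2))` hash functions.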
    pub fn new(n: usize, fp_rate: f64) -> Self {
        let n = n.max(1);
        let fp = fp_rate.clamp(1e-10, 0.5);
        let m = (-(n as f64) * fp.ln() / 2.0_f64.ln().powi(2)).ceil() as usize;
        let m = m.max(8);
        let k = ((m as f64 / n as f64) * 2.0_f64.ln()).ceil().max(1.0) as usize;
        Self { bits: vec![false; m], num_hashes: k }
    }

    /// Insert an element.
    pub fn insert(&mut self, key: &str) {
        let (h1, h2) = Self::hashes(key);
        let m = self.bits.len();
        for i in 0..self.num_hashes { self.bits[h1.wrapping_add(i.wrapping_mul(h2)) % m] = true; }
    }

    /// Test membership (may return false positives, never false negatives).
    pub fn may_contain(&self, key: &str) -> bool {
        let (h1, h2) = Self::hashes(key);
        let m = self.bits.len();
        (0..self.num_hashes).all(|i| self.bits[h1.wrapping_add(i.wrapping_mul(h2)) % m])
    }

    /// FNV-1a for `h1`; a 31-based polynomial hash for `h2`, forced odd so the
    /// double-hashing stride never collapses.
    fn hashes(key: &str) -> (usize, usize) {
        let (mut h1, mut h2): (u64, u64) = (0xcbf29ce484222325, 0x517cc1b727220a95);
        for &b in key.as_bytes() {
            h1 ^= b as u64; h1 = h1.wrapping_mul(0x100000001b3);
            h2 = h2.wrapping_mul(31).wrapping_add(b as u64);
        }
        (h1 as usize, (h2 | 1) as usize)
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct LSMEntry {
    id: VectorId,
    vector: Option<Vec<f32>>, // None = tombstone
    metadata: Option<HashMap<String, serde_json::Value>>,
    seq: u64, // higher wins on conflict
}

/// In-memory sorted write buffer backed by `BTreeMap`.
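///
/// # Example
///
/// ```ignore
/// let mut mt = MemTable::new(2);
/// mt.insert("a".into(), Some(vec![1.0]), None, 0);
/// // `insert` returns true once capacity is reached, signalling a flush.
/// assert!(mt.insert("b".into(), Some(vec![2.0]), None, 1));
/// let seg = mt.flush(0, 0.01); // drain into a level-0 segment
/// assert!(mt.is_empty());
/// ```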
#[derive(Debug, Clone)]
pub struct MemTable { entries: BTreeMap<VectorId, LSMEntry>, capacity: usize }

impl MemTable {
    pub fn new(capacity: usize) -> Self { Self { entries: BTreeMap::new(), capacity } }

    /// Insert/update. Returns `true` when the memtable is full and should flush.
    pub fn insert(&mut self, id: VectorId, vector: Option<Vec<f32>>,
                  metadata: Option<HashMap<String, serde_json::Value>>, seq: u64) -> bool {
        self.entries.insert(id.clone(), LSMEntry { id, vector, metadata, seq });
        self.is_full()
    }
    /// Brute-force nearest-neighbour scan (Euclidean). Scores are distances,
    /// so results are ordered ascending (lower is closer); tombstones are skipped.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
        let mut heap: BinaryHeap<(OrdF32, VectorId)> = BinaryHeap::new();
        for e in self.entries.values() {
            let v = match &e.vector { Some(v) => v, None => continue };
            let d = OrdF32(euclid(query, v));
            if heap.len() < top_k { heap.push((d, e.id.clone())); }
            else if d < heap.peek().unwrap().0 { heap.pop(); heap.push((d, e.id.clone())); }
        }
        let mut r: Vec<SearchResult> = heap.into_sorted_vec().into_iter().filter_map(|(OrdF32(s), id)| {
            self.entries.get(&id).map(|e| SearchResult { id: e.id.clone(), score: s,
                vector: e.vector.clone(), metadata: e.metadata.clone() })
        }).collect();
        r.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)); r
    }

    /// Flush to an immutable segment, clearing the memtable.
    pub fn flush(&mut self, level: usize, fp_rate: f64) -> Segment {
        let entries: Vec<LSMEntry> = self.entries.values().cloned().collect();
        self.entries.clear();
        Segment::from_entries(entries, level, fp_rate)
    }

    pub fn len(&self) -> usize { self.entries.len() }
    pub fn is_empty(&self) -> bool { self.entries.is_empty() }
    pub fn is_full(&self) -> bool { self.entries.len() >= self.capacity }
}

/// Immutable sorted run with bloom filter for point lookups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Segment { entries: Vec<LSMEntry>, bloom: BloomFilter, pub level: usize }

impl Segment {
    fn from_entries(entries: Vec<LSMEntry>, level: usize, fp_rate: f64) -> Self {
        let mut bloom = BloomFilter::new(entries.len(), fp_rate);
        for e in &entries { bloom.insert(&e.id); }
        Self { entries, bloom, level }
    }

    pub fn size(&self) -> usize { self.entries.len() }
    pub fn contains(&self, id: &str) -> bool { self.bloom.may_contain(id) }

    /// Brute-force search within this segment.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
        let mut heap: BinaryHeap<(OrdF32, usize)> = BinaryHeap::new();
        for (i, e) in self.entries.iter().enumerate() {
            let v = match &e.vector { Some(v) => v, None => continue };
            let d = OrdF32(euclid(query, v));
            if heap.len() < top_k { heap.push((d, i)); }
            else if d < heap.peek().unwrap().0 { heap.pop(); heap.push((d, i)); }
        }
        let mut r: Vec<SearchResult> = heap.into_sorted_vec().into_iter().map(|(OrdF32(s), i)| {
            let e = &self.entries[i];
            SearchResult { id: e.id.clone(), score: s, vector: e.vector.clone(), metadata: e.metadata.clone() }
        }).collect();
        r.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)); r
    }

    /// K-way merge deduplicating by id (highest seq wins). Drops tombstones.
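    ///
    /// ```ignore
    /// // Sketch with hypothetical runs: both contain id "a", and the
    /// // tombstone carries the higher seq, so "a" is dropped entirely.
    /// let merged = Segment::merge(&[older_live_run, newer_tombstone_run], 1, 0.01);
    /// assert_eq!(merged.size(), 0);
    /// ```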
    pub fn merge(segments: &[Segment], target_level: usize, fp_rate: f64) -> Segment {
        let mut merged: BTreeMap<VectorId, LSMEntry> = BTreeMap::new();
        for seg in segments {
            for e in &seg.entries {
                if merged.get(&e.id).map_or(true, |x| e.seq > x.seq) {
                    merged.insert(e.id.clone(), e.clone());
                }
            }
        }
        let entries: Vec<LSMEntry> = merged.into_values().filter(|e| e.vector.is_some()).collect();
        Segment::from_entries(entries, target_level, fp_rate)
    }
}

/// Runtime statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LSMStats {
    pub num_levels: usize,
    pub segments_per_level: Vec<usize>,
    pub total_entries: usize,
    pub write_amplification: f64,
}

/// Write-optimised vector index using LSM-tree tiered compaction.
///
/// Writes go to the [`MemTable`]; when full it flushes to level 0. Levels
/// exceeding `merge_threshold` segments are compacted into the next level.
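///
/// # Example
///
/// ```ignore
/// // Illustrative: a tiny memtable forces frequent flushes and compactions.
/// let mut idx = LSMIndex::new(CompactionConfig { memtable_capacity: 4,
///                                                ..Default::default() });
/// for i in 0..32 { idx.insert(format!("v{i}"), vec![i as f32], None); }
/// let stats = idx.stats();
/// assert_eq!(stats.total_entries, 32);       // unique ids survive every merge
/// assert!(stats.write_amplification >= 1.0); // rewrites only add bytes
/// ```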
#[derive(Debug, Clone)]
pub struct LSMIndex {
    config: CompactionConfig,
    memtable: MemTable,
    levels: Vec<Vec<Segment>>,
    next_seq: u64,
    bytes_written_user: u64,
    bytes_written_total: u64,
    deleted_ids: HashSet<VectorId>,
}

impl LSMIndex {
    pub fn new(config: CompactionConfig) -> Self {
        let cap = config.memtable_capacity;
        let nl = config.max_levels;
        Self { config, memtable: MemTable::new(cap), levels: vec![Vec::new(); nl],
               next_seq: 0, bytes_written_user: 0, bytes_written_total: 0,
               deleted_ids: HashSet::new() }
    }

    /// Insert a vector. Auto-flushes and compacts as needed.
    pub fn insert(&mut self, id: VectorId, vector: Vec<f32>,
                  metadata: Option<HashMap<String, serde_json::Value>>) {
        let bytes = (vector.len() * 4 + id.len()) as u64;
        self.bytes_written_user += bytes;
        self.bytes_written_total += bytes;
        self.deleted_ids.remove(&id);
        let seq = self.next_seq; self.next_seq += 1;
        if self.memtable.insert(id, Some(vector), metadata, seq) {
            self.flush_memtable(); self.auto_compact();
        }
    }

    /// Mark a vector as deleted (tombstone).
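    ///
    /// The id is also recorded in `deleted_ids` so searches filter out stale
    /// copies that still live in older segments until compaction drops them.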
    pub fn delete(&mut self, id: VectorId) {
        let bytes = id.len() as u64;
        self.bytes_written_user += bytes;
        self.bytes_written_total += bytes;
        self.deleted_ids.insert(id.clone());
        let seq = self.next_seq; self.next_seq += 1;
        if self.memtable.insert(id, None, None, seq) {
            self.flush_memtable(); self.auto_compact();
        }
    }

    /// Search across memtable and all levels, merging results.
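    ///
    /// The memtable is consulted first, then segments newest-first, so the
    /// most recent version of each id wins the dedup.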
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
        let mut seen = HashSet::new();
        let mut all = Vec::new();
        for r in self.memtable.search(query, top_k) {
            if !self.deleted_ids.contains(&r.id) { seen.insert(r.id.clone()); all.push(r); }
        }
        for level in &self.levels {
            for seg in level.iter().rev() {
                for r in seg.search(query, top_k) {
                    if !seen.contains(&r.id) && !self.deleted_ids.contains(&r.id) {
                        seen.insert(r.id.clone()); all.push(r);
                    }
                }
            }
        }
        all.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal));
        all.truncate(top_k); all
    }

    /// Manual compaction across all levels.
    pub fn compact(&mut self) {
        if !self.memtable.is_empty() { self.flush_memtable(); }
        for l in 0..self.config.max_levels.saturating_sub(1) {
            if self.levels[l].len() >= 2 { self.compact_level(l); }
        }
    }

    /// Auto-compact levels exceeding `merge_threshold`.
    pub fn auto_compact(&mut self) {
        for l in 0..self.config.max_levels.saturating_sub(1) {
            if self.levels[l].len() >= self.config.merge_threshold { self.compact_level(l); }
        }
    }

    pub fn stats(&self) -> LSMStats {
        let spl: Vec<usize> = self.levels.iter().map(|l| l.len()).collect();
        let total = self.memtable.len()
            + self.levels.iter().flat_map(|l| l.iter()).map(|s| s.size()).sum::<usize>();
        LSMStats { num_levels: self.levels.len(), segments_per_level: spl,
                   total_entries: total, write_amplification: self.write_amplification() }
    }

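    /// Ratio of total bytes written (user writes plus flush/compaction
    /// rewrites) to bytes submitted by the caller; `1.0` means no rewrite
    /// overhead has been incurred yet.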
    pub fn write_amplification(&self) -> f64 {
        if self.bytes_written_user == 0 { 1.0 }
        else { self.bytes_written_total as f64 / self.bytes_written_user as f64 }
    }

    fn flush_memtable(&mut self) {
        let seg = self.memtable.flush(0, self.config.bloom_fp_rate);
        self.bytes_written_total += entry_bytes(&seg.entries);
        self.levels[0].push(seg);
    }

    fn compact_level(&mut self, level: usize) {
        let target = level + 1;
        if target >= self.config.max_levels { return; }
        let segments = std::mem::take(&mut self.levels[level]);
        let merged = Segment::merge(&segments, target, self.config.bloom_fp_rate);
        self.bytes_written_total += entry_bytes(&merged.entries);
        self.levels[target].push(merged);
    }
}

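/// Approximate payload size: 4 bytes per `f32` component plus the id length
/// (metadata is ignored in this accounting).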
fn entry_bytes(entries: &[LSMEntry]) -> u64 {
    entries.iter().map(|e| {
        (e.vector.as_ref().map_or(0, |v| v.len() * 4) + e.id.len()) as u64
    }).sum()
}

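/// Total-order wrapper so `f32` distances can live in a `BinaryHeap`;
/// NaN comparisons fall back to `Equal` rather than panicking.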
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
struct OrdF32(f32);
impl Eq for OrdF32 {}
impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, o: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(o)) }
}
impl Ord for OrdF32 {
    fn cmp(&self, o: &Self) -> std::cmp::Ordering {
        self.0.partial_cmp(&o.0).unwrap_or(std::cmp::Ordering::Equal)
    }
}

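/// Euclidean (L2) distance. `zip` truncates silently if the slices differ in length.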
fn euclid(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;
    fn v(dim: usize, val: f32) -> Vec<f32> { vec![val; dim] }
    fn entry(id: &str, vec: Option<Vec<f32>>, seq: u64) -> LSMEntry {
        LSMEntry { id: id.into(), vector: vec, metadata: None, seq }
    }

    #[test]
    fn memtable_insert_and_len() {
        let mut mt = MemTable::new(5);
        assert!(mt.is_empty());
        mt.insert("a".into(), Some(vec![1.0]), None, 0);
        mt.insert("b".into(), Some(vec![2.0]), None, 1);
        assert_eq!(mt.len(), 2);
        assert!(!mt.is_full());
    }

    #[test]
    fn memtable_is_full() {
        let mut mt = MemTable::new(2);
        mt.insert("a".into(), Some(vec![1.0]), None, 0);
        assert!(mt.insert("b".into(), Some(vec![2.0]), None, 1));
    }

    #[test]
    fn memtable_search_returns_closest() {
        let mut mt = MemTable::new(100);
        mt.insert("far".into(), Some(vec![10.0, 10.0]), None, 0);
        mt.insert("close".into(), Some(vec![1.0, 0.0]), None, 1);
        mt.insert("mid".into(), Some(vec![5.0, 5.0]), None, 2);
        let r = mt.search(&[0.0, 0.0], 2);
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].id, "close");
    }

    #[test]
    fn memtable_flush_produces_segment() {
        let mut mt = MemTable::new(10);
        mt.insert("x".into(), Some(vec![1.0]), None, 0);
        mt.insert("y".into(), Some(vec![2.0]), None, 1);
        let seg = mt.flush(0, 0.01);
        assert_eq!(seg.size(), 2);
        assert_eq!(seg.level, 0);
        assert!(mt.is_empty());
    }

    #[test]
    fn segment_merge_dedup_keeps_latest() {
        let s1 = Segment::from_entries(vec![entry("a", Some(vec![1.0]), 1)], 0, 0.01);
        let s2 = Segment::from_entries(vec![entry("a", Some(vec![9.0]), 5)], 0, 0.01);
        let m = Segment::merge(&[s1, s2], 1, 0.01);
        assert_eq!(m.size(), 1);
        assert_eq!(m.entries[0].vector.as_ref().unwrap(), &vec![9.0]);
    }

    #[test]
    fn segment_merge_drops_tombstones() {
        let s1 = Segment::from_entries(vec![entry("a", Some(vec![1.0]), 1)], 0, 0.01);
        let s2 = Segment::from_entries(vec![entry("a", None, 5)], 0, 0.01);
        assert_eq!(Segment::merge(&[s1, s2], 1, 0.01).size(), 0);
    }

    #[test]
    fn bloom_filter_no_false_negatives() {
        let mut bf = BloomFilter::new(100, 0.01);
        for i in 0..100 { bf.insert(&format!("key-{i}")); }
        for i in 0..100 { assert!(bf.may_contain(&format!("key-{i}"))); }
    }

    #[test]
    fn bloom_filter_low_false_positive_rate() {
        let mut bf = BloomFilter::new(1000, 0.01);
        for i in 0..1000 { bf.insert(&format!("present-{i}")); }
        let fp: usize = (0..10_000).filter(|i| bf.may_contain(&format!("absent-{i}"))).count();
        assert!((fp as f64 / 10_000.0) < 0.05, "FP rate too high: {fp}/10000");
    }

    #[test]
    fn lsm_insert_and_search() {
        let mut idx = LSMIndex::new(CompactionConfig { memtable_capacity: 10, ..Default::default() });
        idx.insert("v1".into(), vec![1.0, 0.0], None);
        idx.insert("v2".into(), vec![0.0, 1.0], None);
        let r = idx.search(&[1.0, 0.0], 1);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].id, "v1");
    }

    #[test]
    fn lsm_delete_with_tombstone() {
        let mut idx = LSMIndex::new(CompactionConfig { memtable_capacity: 100, ..Default::default() });
        idx.insert("v1".into(), vec![1.0, 0.0], None);
        idx.insert("v2".into(), vec![0.0, 1.0], None);
        idx.delete("v1".into());
        let r = idx.search(&[1.0, 0.0], 2);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].id, "v2");
    }

    #[test]
    fn lsm_auto_compaction_trigger() {
        let cfg = CompactionConfig { memtable_capacity: 2, merge_threshold: 2, max_levels: 3, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..10 { idx.insert(format!("v{i}"), vec![i as f32], None); }
        assert!(idx.stats().segments_per_level[0] < 4, "L0 should compact");
    }

    #[test]
    fn lsm_multi_level_compaction() {
        let cfg = CompactionConfig { memtable_capacity: 2, merge_threshold: 2, max_levels: 4, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..30 { idx.insert(format!("v{i}"), v(4, i as f32), None); }
        let total_seg: usize = idx.stats().segments_per_level.iter().sum();
        assert!(total_seg >= 1);
    }

    #[test]
    fn lsm_write_amplification_increases() {
        let cfg = CompactionConfig { memtable_capacity: 5, merge_threshold: 2, max_levels: 3, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..20 { idx.insert(format!("v{i}"), v(4, i as f32), None); }
        assert!(idx.write_amplification() >= 1.0);
    }

    #[test]
    fn lsm_empty_index() {
        let idx = LSMIndex::new(CompactionConfig::default());
        assert!(idx.search(&[0.0, 0.0], 10).is_empty());
        let s = idx.stats();
        assert_eq!(s.total_entries, 0);
        assert!((s.write_amplification - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn lsm_large_batch_insert() {
        let cfg = CompactionConfig { memtable_capacity: 50, merge_threshold: 4, max_levels: 4, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..500 { idx.insert(format!("v{i}"), v(8, i as f32 * 0.01), None); }
        assert!(idx.stats().total_entries > 0);
        let r = idx.search(&v(8, 0.0), 5);
        assert_eq!(r.len(), 5);
        assert_eq!(r[0].id, "v0");
    }

    #[test]
    fn lsm_search_across_levels() {
        let cfg = CompactionConfig { memtable_capacity: 3, merge_threshold: 3, max_levels: 3, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..9 { idx.insert(format!("v{i}"), vec![i as f32, 0.0], None); }
        idx.insert("latest".into(), vec![0.0, 0.0], None);
        let r = idx.search(&[0.0, 0.0], 3);
        assert_eq!(r.len(), 3);
        let ids: Vec<&str> = r.iter().map(|r| r.id.as_str()).collect();
        assert!(ids.contains(&"latest"));
        assert!(ids.contains(&"v0"));
    }
}