// ruvector_core/advanced_features/compaction.rs

//! LSM-tree style streaming index compaction.
//!
//! Implements a log-structured merge-tree (LSM-tree) index for write-heavy
//! vector workloads. Writes are absorbed by an in-memory [`MemTable`] and
//! flushed into immutable, id-sorted [`Segment`]s organised into tiered
//! levels; compaction merges segments to bound read amplification.
//!
//! LSM-trees turn random writes into sequential appends, which suits
//! high-throughput ingestion, streaming embedding updates, and frequent
//! deletes (implemented with tombstones).

use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet};

use serde::{Deserialize, Serialize};

use crate::types::{SearchResult, VectorId};

16/// Configuration for the LSM-tree index.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CompactionConfig {
19    /// Max entries in memtable before flush.
20    pub memtable_capacity: usize,
21    /// Size ratio between adjacent levels.
22    pub level_size_ratio: usize,
23    /// Maximum number of levels.
24    pub max_levels: usize,
25    /// Segments per level that triggers compaction.
26    pub merge_threshold: usize,
27    /// False-positive rate for bloom filters.
28    pub bloom_fp_rate: f64,
29}
30
31impl Default for CompactionConfig {
32    fn default() -> Self {
33        Self {
34            memtable_capacity: 1000,
35            level_size_ratio: 10,
36            max_levels: 4,
37            merge_threshold: 4,
38            bloom_fp_rate: 0.01,
39        }
40    }
41}
42
43/// Probabilistic set using double-hashing: `h_i(x) = h1(x) + i * h2(x)`.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct BloomFilter {
46    bits: Vec<bool>,
47    num_hashes: usize,
48}
49
50impl BloomFilter {
51    /// Create a bloom filter for `n` items at `fp_rate`.
52    pub fn new(n: usize, fp_rate: f64) -> Self {
53        let n = n.max(1);
54        let fp = fp_rate.clamp(1e-10, 0.5);
55        let m = (-(n as f64) * fp.ln() / 2.0_f64.ln().powi(2)).ceil() as usize;
56        let m = m.max(8);
57        let k = ((m as f64 / n as f64) * 2.0_f64.ln()).ceil().max(1.0) as usize;
58        Self {
59            bits: vec![false; m],
60            num_hashes: k,
61        }
62    }
63
64    /// Insert an element.
65    pub fn insert(&mut self, key: &str) {
66        let (h1, h2) = Self::hashes(key);
67        let m = self.bits.len();
68        for i in 0..self.num_hashes {
69            self.bits[h1.wrapping_add(i.wrapping_mul(h2)) % m] = true;
70        }
71    }
72
73    /// Test membership (may return false positives).
74    pub fn may_contain(&self, key: &str) -> bool {
75        let (h1, h2) = Self::hashes(key);
76        let m = self.bits.len();
77        (0..self.num_hashes).all(|i| self.bits[h1.wrapping_add(i.wrapping_mul(h2)) % m])
78    }
79
80    fn hashes(key: &str) -> (usize, usize) {
81        let (mut h1, mut h2): (u64, u64) = (0xcbf29ce484222325, 0x517cc1b727220a95);
82        for &b in key.as_bytes() {
83            h1 ^= b as u64;
84            h1 = h1.wrapping_mul(0x100000001b3);
85            h2 = h2.wrapping_mul(31).wrapping_add(b as u64);
86        }
87        (h1 as usize, (h2 | 1) as usize)
88    }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92struct LSMEntry {
93    id: VectorId,
94    vector: Option<Vec<f32>>, // None = tombstone
95    metadata: Option<HashMap<String, serde_json::Value>>,
96    seq: u64, // higher wins on conflict
97}
98
99/// In-memory sorted write buffer backed by `BTreeMap`.
100#[derive(Debug, Clone)]
101pub struct MemTable {
102    entries: BTreeMap<VectorId, LSMEntry>,
103    capacity: usize,
104}
105
106impl MemTable {
107    pub fn new(capacity: usize) -> Self {
108        Self {
109            entries: BTreeMap::new(),
110            capacity,
111        }
112    }
113
114    /// Insert/update. Returns `true` when full.
115    pub fn insert(
116        &mut self,
117        id: VectorId,
118        vector: Option<Vec<f32>>,
119        metadata: Option<HashMap<String, serde_json::Value>>,
120        seq: u64,
121    ) -> bool {
122        self.entries.insert(
123            id.clone(),
124            LSMEntry {
125                id,
126                vector,
127                metadata,
128                seq,
129            },
130        );
131        self.is_full()
132    }
133
134    /// Brute-force nearest-neighbour scan (Euclidean).
135    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
136        let mut heap: BinaryHeap<(OrdF32, VectorId)> = BinaryHeap::new();
137        for e in self.entries.values() {
138            let v = match &e.vector {
139                Some(v) => v,
140                None => continue,
141            };
142            let d = OrdF32(euclid(query, v));
143            if heap.len() < top_k {
144                heap.push((d, e.id.clone()));
145            } else if d < heap.peek().unwrap().0 {
146                heap.pop();
147                heap.push((d, e.id.clone()));
148            }
149        }
150        let mut r: Vec<SearchResult> = heap
151            .into_sorted_vec()
152            .into_iter()
153            .filter_map(|(OrdF32(s), id)| {
154                self.entries.get(&id).map(|e| SearchResult {
155                    id: e.id.clone(),
156                    score: s,
157                    vector: e.vector.clone(),
158                    metadata: e.metadata.clone(),
159                })
160            })
161            .collect();
162        r.sort_by(|a, b| {
163            a.score
164                .partial_cmp(&b.score)
165                .unwrap_or(std::cmp::Ordering::Equal)
166        });
167        r
168    }
169
170    /// Flush to an immutable segment, clearing the memtable.
171    pub fn flush(&mut self, level: usize, fp_rate: f64) -> Segment {
172        let entries: Vec<LSMEntry> = self.entries.values().cloned().collect();
173        self.entries.clear();
174        Segment::from_entries(entries, level, fp_rate)
175    }
176
177    pub fn len(&self) -> usize {
178        self.entries.len()
179    }
180    pub fn is_empty(&self) -> bool {
181        self.entries.is_empty()
182    }
183    pub fn is_full(&self) -> bool {
184        self.entries.len() >= self.capacity
185    }
186}
187
188/// Immutable sorted run with bloom filter for point lookups.
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct Segment {
191    entries: Vec<LSMEntry>,
192    bloom: BloomFilter,
193    pub level: usize,
194}
195
196impl Segment {
197    fn from_entries(entries: Vec<LSMEntry>, level: usize, fp_rate: f64) -> Self {
198        let mut bloom = BloomFilter::new(entries.len(), fp_rate);
199        for e in &entries {
200            bloom.insert(&e.id);
201        }
202        Self {
203            entries,
204            bloom,
205            level,
206        }
207    }
208
209    pub fn size(&self) -> usize {
210        self.entries.len()
211    }
212    pub fn contains(&self, id: &str) -> bool {
213        self.bloom.may_contain(id)
214    }
215
216    /// Brute-force search within this segment.
217    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
218        let mut heap: BinaryHeap<(OrdF32, usize)> = BinaryHeap::new();
219        for (i, e) in self.entries.iter().enumerate() {
220            let v = match &e.vector {
221                Some(v) => v,
222                None => continue,
223            };
224            let d = OrdF32(euclid(query, v));
225            if heap.len() < top_k {
226                heap.push((d, i));
227            } else if d < heap.peek().unwrap().0 {
228                heap.pop();
229                heap.push((d, i));
230            }
231        }
232        let mut r: Vec<SearchResult> = heap
233            .into_sorted_vec()
234            .into_iter()
235            .map(|(OrdF32(s), i)| {
236                let e = &self.entries[i];
237                SearchResult {
238                    id: e.id.clone(),
239                    score: s,
240                    vector: e.vector.clone(),
241                    metadata: e.metadata.clone(),
242                }
243            })
244            .collect();
245        r.sort_by(|a, b| {
246            a.score
247                .partial_cmp(&b.score)
248                .unwrap_or(std::cmp::Ordering::Equal)
249        });
250        r
251    }
252
253    /// K-way merge deduplicating by id (highest seq wins). Drops tombstones.
254    pub fn merge(segments: &[Segment], target_level: usize, fp_rate: f64) -> Segment {
255        let mut merged: BTreeMap<VectorId, LSMEntry> = BTreeMap::new();
256        for seg in segments {
257            for e in &seg.entries {
258                if merged.get(&e.id).map_or(true, |x| e.seq > x.seq) {
259                    merged.insert(e.id.clone(), e.clone());
260                }
261            }
262        }
263        let entries: Vec<LSMEntry> = merged
264            .into_values()
265            .filter(|e| e.vector.is_some())
266            .collect();
267        Segment::from_entries(entries, target_level, fp_rate)
268    }
269}
270
271/// Runtime statistics.
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct LSMStats {
274    pub num_levels: usize,
275    pub segments_per_level: Vec<usize>,
276    pub total_entries: usize,
277    pub write_amplification: f64,
278}
279
280/// Write-optimised vector index using LSM-tree tiered compaction.
281///
282/// Writes go to the [`MemTable`]; when full it flushes to level 0. Levels
283/// exceeding `merge_threshold` segments are compacted into the next level.
284#[derive(Debug, Clone)]
285pub struct LSMIndex {
286    config: CompactionConfig,
287    memtable: MemTable,
288    levels: Vec<Vec<Segment>>,
289    next_seq: u64,
290    bytes_written_user: u64,
291    bytes_written_total: u64,
292    deleted_ids: HashSet<VectorId>,
293}
294
295impl LSMIndex {
296    pub fn new(config: CompactionConfig) -> Self {
297        let cap = config.memtable_capacity;
298        let nl = config.max_levels;
299        Self {
300            config,
301            memtable: MemTable::new(cap),
302            levels: vec![Vec::new(); nl],
303            next_seq: 0,
304            bytes_written_user: 0,
305            bytes_written_total: 0,
306            deleted_ids: HashSet::new(),
307        }
308    }
309
310    /// Insert a vector. Auto-flushes and compacts as needed.
311    pub fn insert(
312        &mut self,
313        id: VectorId,
314        vector: Vec<f32>,
315        metadata: Option<HashMap<String, serde_json::Value>>,
316    ) {
317        let bytes = (vector.len() * 4 + id.len()) as u64;
318        self.bytes_written_user += bytes;
319        self.bytes_written_total += bytes;
320        self.deleted_ids.remove(&id);
321        let seq = self.next_seq;
322        self.next_seq += 1;
323        if self.memtable.insert(id, Some(vector), metadata, seq) {
324            self.flush_memtable();
325            self.auto_compact();
326        }
327    }
328
329    /// Mark a vector as deleted (tombstone).
330    pub fn delete(&mut self, id: VectorId) {
331        let bytes = id.len() as u64;
332        self.bytes_written_user += bytes;
333        self.bytes_written_total += bytes;
334        self.deleted_ids.insert(id.clone());
335        let seq = self.next_seq;
336        self.next_seq += 1;
337        if self.memtable.insert(id, None, None, seq) {
338            self.flush_memtable();
339            self.auto_compact();
340        }
341    }
342
343    /// Search across memtable and all levels, merging results.
344    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
345        let mut seen = HashSet::new();
346        let mut all = Vec::new();
347        for r in self.memtable.search(query, top_k) {
348            if !self.deleted_ids.contains(&r.id) {
349                seen.insert(r.id.clone());
350                all.push(r);
351            }
352        }
353        for level in &self.levels {
354            for seg in level.iter().rev() {
355                for r in seg.search(query, top_k) {
356                    if !seen.contains(&r.id) && !self.deleted_ids.contains(&r.id) {
357                        seen.insert(r.id.clone());
358                        all.push(r);
359                    }
360                }
361            }
362        }
363        all.sort_by(|a, b| {
364            a.score
365                .partial_cmp(&b.score)
366                .unwrap_or(std::cmp::Ordering::Equal)
367        });
368        all.truncate(top_k);
369        all
370    }
371
372    /// Manual compaction across all levels.
373    pub fn compact(&mut self) {
374        if !self.memtable.is_empty() {
375            self.flush_memtable();
376        }
377        for l in 0..self.config.max_levels.saturating_sub(1) {
378            if self.levels[l].len() >= 2 {
379                self.compact_level(l);
380            }
381        }
382    }
383
384    /// Auto-compact levels exceeding `merge_threshold`.
385    pub fn auto_compact(&mut self) {
386        for l in 0..self.config.max_levels.saturating_sub(1) {
387            if self.levels[l].len() >= self.config.merge_threshold {
388                self.compact_level(l);
389            }
390        }
391    }
392
393    pub fn stats(&self) -> LSMStats {
394        let spl: Vec<usize> = self.levels.iter().map(|l| l.len()).collect();
395        let total = self.memtable.len()
396            + self
397                .levels
398                .iter()
399                .flat_map(|l| l.iter())
400                .map(|s| s.size())
401                .sum::<usize>();
402        LSMStats {
403            num_levels: self.levels.len(),
404            segments_per_level: spl,
405            total_entries: total,
406            write_amplification: self.write_amplification(),
407        }
408    }
409
410    pub fn write_amplification(&self) -> f64 {
411        if self.bytes_written_user == 0 {
412            1.0
413        } else {
414            self.bytes_written_total as f64 / self.bytes_written_user as f64
415        }
416    }
417
418    fn flush_memtable(&mut self) {
419        let seg = self.memtable.flush(0, self.config.bloom_fp_rate);
420        self.bytes_written_total += entry_bytes(&seg.entries);
421        self.levels[0].push(seg);
422    }
423
424    fn compact_level(&mut self, level: usize) {
425        let target = level + 1;
426        if target >= self.config.max_levels {
427            return;
428        }
429        let segments = std::mem::take(&mut self.levels[level]);
430        let merged = Segment::merge(&segments, target, self.config.bloom_fp_rate);
431        self.bytes_written_total += entry_bytes(&merged.entries);
432        self.levels[target].push(merged);
433    }
434}
435
436fn entry_bytes(entries: &[LSMEntry]) -> u64 {
437    entries
438        .iter()
439        .map(|e| (e.vector.as_ref().map_or(0, |v| v.len() * 4) + e.id.len()) as u64)
440        .sum()
441}
442
/// `f32` wrapper with a lawful total order, so distances can be kept in a
/// `BinaryHeap` and sorted.
///
/// Non-NaN values compare numerically (`-0.0 == 0.0`). Every NaN, whatever
/// its sign bit, compares equal to other NaNs and greater than all numbers,
/// so a NaN distance ranks as the farthest candidate. The previous `Ord`
/// treated NaN as equal to everything, which broke transitivity.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, o: &Self) -> bool {
        // Consistent with `Ord`; the derived `f32 ==` was not reflexive for NaN.
        self.cmp(o) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, o: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(o))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, o: &Self) -> std::cmp::Ordering {
        // `false < true`: NaN sorts after every number. When neither side is
        // NaN, `partial_cmp` always returns `Some`, so the fallback is unreachable.
        self.0.is_nan().cmp(&o.0.is_nan()).then_with(|| {
            self.0
                .partial_cmp(&o.0)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }
}

/// Euclidean (L2) distance between `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` components are compared.
/// Mismatched dimensions are silently truncated rather than reported.
fn euclid(a: &[f32], b: &[f32]) -> f32 {
    let shared = a.len().min(b.len());
    (0..shared)
        .map(|i| (a[i] - b[i]).powi(2))
        .sum::<f32>()
        .sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Vector of `dim` components, all equal to `val`.
    fn v(dim: usize, val: f32) -> Vec<f32> {
        vec![val; dim]
    }

    /// Metadata-free entry for building segments directly.
    fn entry(id: &str, vec: Option<Vec<f32>>, seq: u64) -> LSMEntry {
        LSMEntry {
            id: id.into(),
            vector: vec,
            metadata: None,
            seq,
        }
    }

    #[test]
    fn memtable_insert_and_len() {
        let mut mt = MemTable::new(5);
        assert!(mt.is_empty());
        mt.insert("a".into(), Some(vec![1.0]), None, 0);
        mt.insert("b".into(), Some(vec![2.0]), None, 1);
        assert_eq!(mt.len(), 2);
        assert!(!mt.is_full());
    }

    #[test]
    fn memtable_insert_replaces_existing_id() {
        let mut mt = MemTable::new(5);
        mt.insert("a".into(), Some(vec![1.0]), None, 0);
        mt.insert("a".into(), Some(vec![2.0]), None, 1);
        assert_eq!(mt.len(), 1);
        let r = mt.search(&[2.0], 1);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].vector, Some(vec![2.0]));
    }

    #[test]
    fn memtable_is_full() {
        let mut mt = MemTable::new(2);
        mt.insert("a".into(), Some(vec![1.0]), None, 0);
        assert!(mt.insert("b".into(), Some(vec![2.0]), None, 1));
    }

    #[test]
    fn memtable_search_returns_closest() {
        let mut mt = MemTable::new(100);
        mt.insert("far".into(), Some(vec![10.0, 10.0]), None, 0);
        mt.insert("close".into(), Some(vec![1.0, 0.0]), None, 1);
        mt.insert("mid".into(), Some(vec![5.0, 5.0]), None, 2);
        let r = mt.search(&[0.0, 0.0], 2);
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].id, "close");
        assert_eq!(r[1].id, "mid");
    }

    #[test]
    fn memtable_flush_produces_segment() {
        let mut mt = MemTable::new(10);
        mt.insert("x".into(), Some(vec![1.0]), None, 0);
        mt.insert("y".into(), Some(vec![2.0]), None, 1);
        let seg = mt.flush(0, 0.01);
        assert_eq!(seg.size(), 2);
        assert_eq!(seg.level, 0);
        assert!(mt.is_empty());
    }

    #[test]
    fn segment_contains_flushed_ids() {
        let mut mt = MemTable::new(10);
        mt.insert("x".into(), Some(vec![1.0]), None, 0);
        mt.insert("y".into(), None, None, 1);
        let seg = mt.flush(0, 0.01);
        // Bloom filters never produce false negatives, tombstones included.
        assert!(seg.contains("x"));
        assert!(seg.contains("y"));
    }

    #[test]
    fn segment_merge_dedup_keeps_latest() {
        let s1 = Segment::from_entries(vec![entry("a", Some(vec![1.0]), 1)], 0, 0.01);
        let s2 = Segment::from_entries(vec![entry("a", Some(vec![9.0]), 5)], 0, 0.01);
        let m = Segment::merge(&[s1, s2], 1, 0.01);
        assert_eq!(m.size(), 1);
        assert_eq!(m.entries[0].vector.as_ref().unwrap(), &vec![9.0]);
    }

    #[test]
    fn segment_merge_drops_tombstones() {
        let s1 = Segment::from_entries(vec![entry("a", Some(vec![1.0]), 1)], 0, 0.01);
        let s2 = Segment::from_entries(vec![entry("a", None, 5)], 0, 0.01);
        assert_eq!(Segment::merge(&[s1, s2], 1, 0.01).size(), 0);
    }

    #[test]
    fn bloom_filter_no_false_negatives() {
        let mut bf = BloomFilter::new(100, 0.01);
        for i in 0..100 {
            bf.insert(&format!("key-{i}"));
        }
        for i in 0..100 {
            assert!(bf.may_contain(&format!("key-{i}")));
        }
    }

    #[test]
    fn bloom_filter_low_false_positive_rate() {
        let mut bf = BloomFilter::new(1000, 0.01);
        for i in 0..1000 {
            bf.insert(&format!("present-{i}"));
        }
        let fp: usize = (0..10_000)
            .filter(|i| bf.may_contain(&format!("absent-{i}")))
            .count();
        assert!(
            (fp as f64 / 10_000.0) < 0.05,
            "FP rate too high: {fp}/10000"
        );
    }

    #[test]
    fn lsm_insert_and_search() {
        let mut idx = LSMIndex::new(CompactionConfig {
            memtable_capacity: 10,
            ..Default::default()
        });
        idx.insert("v1".into(), vec![1.0, 0.0], None);
        idx.insert("v2".into(), vec![0.0, 1.0], None);
        let r = idx.search(&[1.0, 0.0], 1);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].id, "v1");
    }

    #[test]
    fn lsm_delete_with_tombstone() {
        let mut idx = LSMIndex::new(CompactionConfig {
            memtable_capacity: 100,
            ..Default::default()
        });
        idx.insert("v1".into(), vec![1.0, 0.0], None);
        idx.insert("v2".into(), vec![0.0, 1.0], None);
        idx.delete("v1".into());
        let r = idx.search(&[1.0, 0.0], 2);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].id, "v2");
    }

    #[test]
    fn lsm_reinsert_after_delete() {
        let mut idx = LSMIndex::new(CompactionConfig {
            memtable_capacity: 100,
            ..Default::default()
        });
        idx.insert("v1".into(), vec![1.0, 0.0], None);
        idx.delete("v1".into());
        idx.insert("v1".into(), vec![2.0, 0.0], None);
        let r = idx.search(&[2.0, 0.0], 1);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].id, "v1");
        assert_eq!(r[0].vector, Some(vec![2.0, 0.0]));
    }

    #[test]
    fn lsm_delete_survives_compaction() {
        let cfg = CompactionConfig {
            memtable_capacity: 2,
            merge_threshold: 2,
            max_levels: 3,
            ..Default::default()
        };
        let mut idx = LSMIndex::new(cfg);
        idx.insert("a".into(), vec![0.0], None);
        idx.insert("b".into(), vec![5.0], None);
        idx.delete("a".into());
        idx.insert("c".into(), vec![9.0], None);
        let r = idx.search(&[0.0], 3);
        let ids: Vec<&str> = r.iter().map(|x| x.id.as_str()).collect();
        assert_eq!(ids, vec!["b", "c"]);
    }

    #[test]
    fn lsm_auto_compaction_trigger() {
        let cfg = CompactionConfig {
            memtable_capacity: 2,
            merge_threshold: 2,
            max_levels: 3,
            ..Default::default()
        };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..10 {
            idx.insert(format!("v{i}"), vec![i as f32], None);
        }
        let stats = idx.stats();
        // L0 is merged as soon as it reaches `merge_threshold` (2) segments.
        assert!(stats.segments_per_level[0] < 2, "L0 should compact");
        assert_eq!(stats.total_entries, 10);
    }

    #[test]
    fn lsm_multi_level_compaction() {
        let cfg = CompactionConfig {
            memtable_capacity: 2,
            merge_threshold: 2,
            max_levels: 4,
            ..Default::default()
        };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..30 {
            idx.insert(format!("v{i}"), v(4, i as f32), None);
        }
        let stats = idx.stats();
        // 15 flushes cascade like a binary counter: 1 segment per level.
        assert_eq!(stats.segments_per_level, vec![1, 1, 1, 1]);
        assert_eq!(stats.total_entries, 30);
    }

    #[test]
    fn lsm_write_amplification_increases() {
        let cfg = CompactionConfig {
            memtable_capacity: 5,
            merge_threshold: 2,
            max_levels: 3,
            ..Default::default()
        };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..20 {
            idx.insert(format!("v{i}"), v(4, i as f32), None);
        }
        // Flushes and compactions rewrite data, so amplification exceeds 1.
        assert!(idx.write_amplification() > 1.0);
    }

    #[test]
    fn lsm_empty_index() {
        let idx = LSMIndex::new(CompactionConfig::default());
        assert!(idx.search(&[0.0, 0.0], 10).is_empty());
        let s = idx.stats();
        assert_eq!(s.total_entries, 0);
        assert!((s.write_amplification - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn lsm_large_batch_insert() {
        let cfg = CompactionConfig {
            memtable_capacity: 50,
            merge_threshold: 4,
            max_levels: 4,
            ..Default::default()
        };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..500 {
            idx.insert(format!("v{i}"), v(8, i as f32 * 0.01), None);
        }
        assert_eq!(idx.stats().total_entries, 500);
        let r = idx.search(&v(8, 0.0), 5);
        assert_eq!(r.len(), 5);
        assert_eq!(r[0].id, "v0");
    }

    #[test]
    fn lsm_search_across_levels() {
        let cfg = CompactionConfig {
            memtable_capacity: 3,
            merge_threshold: 3,
            max_levels: 3,
            ..Default::default()
        };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..9 {
            idx.insert(format!("v{i}"), vec![i as f32, 0.0], None);
        }
        idx.insert("latest".into(), vec![0.0, 0.0], None);
        let r = idx.search(&[0.0, 0.0], 3);
        assert_eq!(r.len(), 3);
        let ids: Vec<&str> = r.iter().map(|r| r.id.as_str()).collect();
        assert!(ids.contains(&"latest"));
        assert!(ids.contains(&"v0"));
    }
}