1use crate::types::{SearchResult, VectorId};
13use serde::{Deserialize, Serialize};
14use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CompactionConfig {
19 pub memtable_capacity: usize,
21 pub level_size_ratio: usize,
23 pub max_levels: usize,
25 pub merge_threshold: usize,
27 pub bloom_fp_rate: f64,
29}
30
31impl Default for CompactionConfig {
32 fn default() -> Self {
33 Self {
34 memtable_capacity: 1000,
35 level_size_ratio: 10,
36 max_levels: 4,
37 merge_threshold: 4,
38 bloom_fp_rate: 0.01,
39 }
40 }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct BloomFilter {
46 bits: Vec<bool>,
47 num_hashes: usize,
48}
49
50impl BloomFilter {
51 pub fn new(n: usize, fp_rate: f64) -> Self {
53 let n = n.max(1);
54 let fp = fp_rate.clamp(1e-10, 0.5);
55 let m = (-(n as f64) * fp.ln() / 2.0_f64.ln().powi(2)).ceil() as usize;
56 let m = m.max(8);
57 let k = ((m as f64 / n as f64) * 2.0_f64.ln()).ceil().max(1.0) as usize;
58 Self {
59 bits: vec![false; m],
60 num_hashes: k,
61 }
62 }
63
64 pub fn insert(&mut self, key: &str) {
66 let (h1, h2) = Self::hashes(key);
67 let m = self.bits.len();
68 for i in 0..self.num_hashes {
69 self.bits[h1.wrapping_add(i.wrapping_mul(h2)) % m] = true;
70 }
71 }
72
73 pub fn may_contain(&self, key: &str) -> bool {
75 let (h1, h2) = Self::hashes(key);
76 let m = self.bits.len();
77 (0..self.num_hashes).all(|i| self.bits[h1.wrapping_add(i.wrapping_mul(h2)) % m])
78 }
79
80 fn hashes(key: &str) -> (usize, usize) {
81 let (mut h1, mut h2): (u64, u64) = (0xcbf29ce484222325, 0x517cc1b727220a95);
82 for &b in key.as_bytes() {
83 h1 ^= b as u64;
84 h1 = h1.wrapping_mul(0x100000001b3);
85 h2 = h2.wrapping_mul(31).wrapping_add(b as u64);
86 }
87 (h1 as usize, (h2 | 1) as usize)
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92struct LSMEntry {
93 id: VectorId,
94 vector: Option<Vec<f32>>, metadata: Option<HashMap<String, serde_json::Value>>,
96 seq: u64, }
98
99#[derive(Debug, Clone)]
101pub struct MemTable {
102 entries: BTreeMap<VectorId, LSMEntry>,
103 capacity: usize,
104}
105
106impl MemTable {
107 pub fn new(capacity: usize) -> Self {
108 Self {
109 entries: BTreeMap::new(),
110 capacity,
111 }
112 }
113
114 pub fn insert(
116 &mut self,
117 id: VectorId,
118 vector: Option<Vec<f32>>,
119 metadata: Option<HashMap<String, serde_json::Value>>,
120 seq: u64,
121 ) -> bool {
122 self.entries.insert(
123 id.clone(),
124 LSMEntry {
125 id,
126 vector,
127 metadata,
128 seq,
129 },
130 );
131 self.is_full()
132 }
133
134 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
136 let mut heap: BinaryHeap<(OrdF32, VectorId)> = BinaryHeap::new();
137 for e in self.entries.values() {
138 let v = match &e.vector {
139 Some(v) => v,
140 None => continue,
141 };
142 let d = OrdF32(euclid(query, v));
143 if heap.len() < top_k {
144 heap.push((d, e.id.clone()));
145 } else if d < heap.peek().unwrap().0 {
146 heap.pop();
147 heap.push((d, e.id.clone()));
148 }
149 }
150 let mut r: Vec<SearchResult> = heap
151 .into_sorted_vec()
152 .into_iter()
153 .filter_map(|(OrdF32(s), id)| {
154 self.entries.get(&id).map(|e| SearchResult {
155 id: e.id.clone(),
156 score: s,
157 vector: e.vector.clone(),
158 metadata: e.metadata.clone(),
159 })
160 })
161 .collect();
162 r.sort_by(|a, b| {
163 a.score
164 .partial_cmp(&b.score)
165 .unwrap_or(std::cmp::Ordering::Equal)
166 });
167 r
168 }
169
170 pub fn flush(&mut self, level: usize, fp_rate: f64) -> Segment {
172 let entries: Vec<LSMEntry> = self.entries.values().cloned().collect();
173 self.entries.clear();
174 Segment::from_entries(entries, level, fp_rate)
175 }
176
177 pub fn len(&self) -> usize {
178 self.entries.len()
179 }
180 pub fn is_empty(&self) -> bool {
181 self.entries.is_empty()
182 }
183 pub fn is_full(&self) -> bool {
184 self.entries.len() >= self.capacity
185 }
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct Segment {
191 entries: Vec<LSMEntry>,
192 bloom: BloomFilter,
193 pub level: usize,
194}
195
196impl Segment {
197 fn from_entries(entries: Vec<LSMEntry>, level: usize, fp_rate: f64) -> Self {
198 let mut bloom = BloomFilter::new(entries.len(), fp_rate);
199 for e in &entries {
200 bloom.insert(&e.id);
201 }
202 Self {
203 entries,
204 bloom,
205 level,
206 }
207 }
208
209 pub fn size(&self) -> usize {
210 self.entries.len()
211 }
212 pub fn contains(&self, id: &str) -> bool {
213 self.bloom.may_contain(id)
214 }
215
216 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
218 let mut heap: BinaryHeap<(OrdF32, usize)> = BinaryHeap::new();
219 for (i, e) in self.entries.iter().enumerate() {
220 let v = match &e.vector {
221 Some(v) => v,
222 None => continue,
223 };
224 let d = OrdF32(euclid(query, v));
225 if heap.len() < top_k {
226 heap.push((d, i));
227 } else if d < heap.peek().unwrap().0 {
228 heap.pop();
229 heap.push((d, i));
230 }
231 }
232 let mut r: Vec<SearchResult> = heap
233 .into_sorted_vec()
234 .into_iter()
235 .map(|(OrdF32(s), i)| {
236 let e = &self.entries[i];
237 SearchResult {
238 id: e.id.clone(),
239 score: s,
240 vector: e.vector.clone(),
241 metadata: e.metadata.clone(),
242 }
243 })
244 .collect();
245 r.sort_by(|a, b| {
246 a.score
247 .partial_cmp(&b.score)
248 .unwrap_or(std::cmp::Ordering::Equal)
249 });
250 r
251 }
252
253 pub fn merge(segments: &[Segment], target_level: usize, fp_rate: f64) -> Segment {
255 let mut merged: BTreeMap<VectorId, LSMEntry> = BTreeMap::new();
256 for seg in segments {
257 for e in &seg.entries {
258 if merged.get(&e.id).map_or(true, |x| e.seq > x.seq) {
259 merged.insert(e.id.clone(), e.clone());
260 }
261 }
262 }
263 let entries: Vec<LSMEntry> = merged
264 .into_values()
265 .filter(|e| e.vector.is_some())
266 .collect();
267 Segment::from_entries(entries, target_level, fp_rate)
268 }
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct LSMStats {
274 pub num_levels: usize,
275 pub segments_per_level: Vec<usize>,
276 pub total_entries: usize,
277 pub write_amplification: f64,
278}
279
280#[derive(Debug, Clone)]
285pub struct LSMIndex {
286 config: CompactionConfig,
287 memtable: MemTable,
288 levels: Vec<Vec<Segment>>,
289 next_seq: u64,
290 bytes_written_user: u64,
291 bytes_written_total: u64,
292 deleted_ids: HashSet<VectorId>,
293}
294
295impl LSMIndex {
296 pub fn new(config: CompactionConfig) -> Self {
297 let cap = config.memtable_capacity;
298 let nl = config.max_levels;
299 Self {
300 config,
301 memtable: MemTable::new(cap),
302 levels: vec![Vec::new(); nl],
303 next_seq: 0,
304 bytes_written_user: 0,
305 bytes_written_total: 0,
306 deleted_ids: HashSet::new(),
307 }
308 }
309
310 pub fn insert(
312 &mut self,
313 id: VectorId,
314 vector: Vec<f32>,
315 metadata: Option<HashMap<String, serde_json::Value>>,
316 ) {
317 let bytes = (vector.len() * 4 + id.len()) as u64;
318 self.bytes_written_user += bytes;
319 self.bytes_written_total += bytes;
320 self.deleted_ids.remove(&id);
321 let seq = self.next_seq;
322 self.next_seq += 1;
323 if self.memtable.insert(id, Some(vector), metadata, seq) {
324 self.flush_memtable();
325 self.auto_compact();
326 }
327 }
328
329 pub fn delete(&mut self, id: VectorId) {
331 let bytes = id.len() as u64;
332 self.bytes_written_user += bytes;
333 self.bytes_written_total += bytes;
334 self.deleted_ids.insert(id.clone());
335 let seq = self.next_seq;
336 self.next_seq += 1;
337 if self.memtable.insert(id, None, None, seq) {
338 self.flush_memtable();
339 self.auto_compact();
340 }
341 }
342
343 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
345 let mut seen = HashSet::new();
346 let mut all = Vec::new();
347 for r in self.memtable.search(query, top_k) {
348 if !self.deleted_ids.contains(&r.id) {
349 seen.insert(r.id.clone());
350 all.push(r);
351 }
352 }
353 for level in &self.levels {
354 for seg in level.iter().rev() {
355 for r in seg.search(query, top_k) {
356 if !seen.contains(&r.id) && !self.deleted_ids.contains(&r.id) {
357 seen.insert(r.id.clone());
358 all.push(r);
359 }
360 }
361 }
362 }
363 all.sort_by(|a, b| {
364 a.score
365 .partial_cmp(&b.score)
366 .unwrap_or(std::cmp::Ordering::Equal)
367 });
368 all.truncate(top_k);
369 all
370 }
371
372 pub fn compact(&mut self) {
374 if !self.memtable.is_empty() {
375 self.flush_memtable();
376 }
377 for l in 0..self.config.max_levels.saturating_sub(1) {
378 if self.levels[l].len() >= 2 {
379 self.compact_level(l);
380 }
381 }
382 }
383
384 pub fn auto_compact(&mut self) {
386 for l in 0..self.config.max_levels.saturating_sub(1) {
387 if self.levels[l].len() >= self.config.merge_threshold {
388 self.compact_level(l);
389 }
390 }
391 }
392
393 pub fn stats(&self) -> LSMStats {
394 let spl: Vec<usize> = self.levels.iter().map(|l| l.len()).collect();
395 let total = self.memtable.len()
396 + self
397 .levels
398 .iter()
399 .flat_map(|l| l.iter())
400 .map(|s| s.size())
401 .sum::<usize>();
402 LSMStats {
403 num_levels: self.levels.len(),
404 segments_per_level: spl,
405 total_entries: total,
406 write_amplification: self.write_amplification(),
407 }
408 }
409
410 pub fn write_amplification(&self) -> f64 {
411 if self.bytes_written_user == 0 {
412 1.0
413 } else {
414 self.bytes_written_total as f64 / self.bytes_written_user as f64
415 }
416 }
417
418 fn flush_memtable(&mut self) {
419 let seg = self.memtable.flush(0, self.config.bloom_fp_rate);
420 self.bytes_written_total += entry_bytes(&seg.entries);
421 self.levels[0].push(seg);
422 }
423
424 fn compact_level(&mut self, level: usize) {
425 let target = level + 1;
426 if target >= self.config.max_levels {
427 return;
428 }
429 let segments = std::mem::take(&mut self.levels[level]);
430 let merged = Segment::merge(&segments, target, self.config.bloom_fp_rate);
431 self.bytes_written_total += entry_bytes(&merged.entries);
432 self.levels[target].push(merged);
433 }
434}
435
436fn entry_bytes(entries: &[LSMEntry]) -> u64 {
437 entries
438 .iter()
439 .map(|e| (e.vector.as_ref().map_or(0, |v| v.len() * 4) + e.id.len()) as u64)
440 .sum()
441}
442
443#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
444struct OrdF32(f32);
445impl Eq for OrdF32 {}
446impl PartialOrd for OrdF32 {
447 fn partial_cmp(&self, o: &Self) -> Option<std::cmp::Ordering> {
448 Some(self.cmp(o))
449 }
450}
451impl Ord for OrdF32 {
452 fn cmp(&self, o: &Self) -> std::cmp::Ordering {
453 self.0
454 .partial_cmp(&o.0)
455 .unwrap_or(std::cmp::Ordering::Equal)
456 }
457}
458
/// Euclidean distance between `a` and `b`.
///
/// Components are paired positionally; when the slices differ in length the
/// unpaired tail of the longer one is ignored.
fn euclid(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
466
467#[cfg(test)]
468mod tests {
469 use super::*;
470 fn v(dim: usize, val: f32) -> Vec<f32> {
471 vec![val; dim]
472 }
473 fn entry(id: &str, vec: Option<Vec<f32>>, seq: u64) -> LSMEntry {
474 LSMEntry {
475 id: id.into(),
476 vector: vec,
477 metadata: None,
478 seq,
479 }
480 }
481
482 #[test]
483 fn memtable_insert_and_len() {
484 let mut mt = MemTable::new(5);
485 assert!(mt.is_empty());
486 mt.insert("a".into(), Some(vec![1.0]), None, 0);
487 mt.insert("b".into(), Some(vec![2.0]), None, 1);
488 assert_eq!(mt.len(), 2);
489 assert!(!mt.is_full());
490 }
491
492 #[test]
493 fn memtable_is_full() {
494 let mut mt = MemTable::new(2);
495 mt.insert("a".into(), Some(vec![1.0]), None, 0);
496 assert!(mt.insert("b".into(), Some(vec![2.0]), None, 1));
497 }
498
499 #[test]
500 fn memtable_search_returns_closest() {
501 let mut mt = MemTable::new(100);
502 mt.insert("far".into(), Some(vec![10.0, 10.0]), None, 0);
503 mt.insert("close".into(), Some(vec![1.0, 0.0]), None, 1);
504 mt.insert("mid".into(), Some(vec![5.0, 5.0]), None, 2);
505 let r = mt.search(&[0.0, 0.0], 2);
506 assert_eq!(r.len(), 2);
507 assert_eq!(r[0].id, "close");
508 }
509
510 #[test]
511 fn memtable_flush_produces_segment() {
512 let mut mt = MemTable::new(10);
513 mt.insert("x".into(), Some(vec![1.0]), None, 0);
514 mt.insert("y".into(), Some(vec![2.0]), None, 1);
515 let seg = mt.flush(0, 0.01);
516 assert_eq!(seg.size(), 2);
517 assert_eq!(seg.level, 0);
518 assert!(mt.is_empty());
519 }
520
521 #[test]
522 fn segment_merge_dedup_keeps_latest() {
523 let s1 = Segment::from_entries(vec![entry("a", Some(vec![1.0]), 1)], 0, 0.01);
524 let s2 = Segment::from_entries(vec![entry("a", Some(vec![9.0]), 5)], 0, 0.01);
525 let m = Segment::merge(&[s1, s2], 1, 0.01);
526 assert_eq!(m.size(), 1);
527 assert_eq!(m.entries[0].vector.as_ref().unwrap(), &vec![9.0]);
528 }
529
530 #[test]
531 fn segment_merge_drops_tombstones() {
532 let s1 = Segment::from_entries(vec![entry("a", Some(vec![1.0]), 1)], 0, 0.01);
533 let s2 = Segment::from_entries(vec![entry("a", None, 5)], 0, 0.01);
534 assert_eq!(Segment::merge(&[s1, s2], 1, 0.01).size(), 0);
535 }
536
537 #[test]
538 fn bloom_filter_no_false_negatives() {
539 let mut bf = BloomFilter::new(100, 0.01);
540 for i in 0..100 {
541 bf.insert(&format!("key-{i}"));
542 }
543 for i in 0..100 {
544 assert!(bf.may_contain(&format!("key-{i}")));
545 }
546 }
547
548 #[test]
549 fn bloom_filter_low_false_positive_rate() {
550 let mut bf = BloomFilter::new(1000, 0.01);
551 for i in 0..1000 {
552 bf.insert(&format!("present-{i}"));
553 }
554 let fp: usize = (0..10_000)
555 .filter(|i| bf.may_contain(&format!("absent-{i}")))
556 .count();
557 assert!(
558 (fp as f64 / 10_000.0) < 0.05,
559 "FP rate too high: {fp}/10000"
560 );
561 }
562
563 #[test]
564 fn lsm_insert_and_search() {
565 let mut idx = LSMIndex::new(CompactionConfig {
566 memtable_capacity: 10,
567 ..Default::default()
568 });
569 idx.insert("v1".into(), vec![1.0, 0.0], None);
570 idx.insert("v2".into(), vec![0.0, 1.0], None);
571 let r = idx.search(&[1.0, 0.0], 1);
572 assert_eq!(r.len(), 1);
573 assert_eq!(r[0].id, "v1");
574 }
575
576 #[test]
577 fn lsm_delete_with_tombstone() {
578 let mut idx = LSMIndex::new(CompactionConfig {
579 memtable_capacity: 100,
580 ..Default::default()
581 });
582 idx.insert("v1".into(), vec![1.0, 0.0], None);
583 idx.insert("v2".into(), vec![0.0, 1.0], None);
584 idx.delete("v1".into());
585 let r = idx.search(&[1.0, 0.0], 2);
586 assert_eq!(r.len(), 1);
587 assert_eq!(r[0].id, "v2");
588 }
589
590 #[test]
591 fn lsm_auto_compaction_trigger() {
592 let cfg = CompactionConfig {
593 memtable_capacity: 2,
594 merge_threshold: 2,
595 max_levels: 3,
596 ..Default::default()
597 };
598 let mut idx = LSMIndex::new(cfg);
599 for i in 0..10 {
600 idx.insert(format!("v{i}"), vec![i as f32], None);
601 }
602 assert!(idx.stats().segments_per_level[0] < 4, "L0 should compact");
603 }
604
605 #[test]
606 fn lsm_multi_level_compaction() {
607 let cfg = CompactionConfig {
608 memtable_capacity: 2,
609 merge_threshold: 2,
610 max_levels: 4,
611 ..Default::default()
612 };
613 let mut idx = LSMIndex::new(cfg);
614 for i in 0..30 {
615 idx.insert(format!("v{i}"), v(4, i as f32), None);
616 }
617 let total_seg: usize = idx.stats().segments_per_level.iter().sum();
618 assert!(total_seg >= 1);
619 }
620
621 #[test]
622 fn lsm_write_amplification_increases() {
623 let cfg = CompactionConfig {
624 memtable_capacity: 5,
625 merge_threshold: 2,
626 max_levels: 3,
627 ..Default::default()
628 };
629 let mut idx = LSMIndex::new(cfg);
630 for i in 0..20 {
631 idx.insert(format!("v{i}"), v(4, i as f32), None);
632 }
633 assert!(idx.write_amplification() >= 1.0);
634 }
635
636 #[test]
637 fn lsm_empty_index() {
638 let idx = LSMIndex::new(CompactionConfig::default());
639 assert!(idx.search(&[0.0, 0.0], 10).is_empty());
640 let s = idx.stats();
641 assert_eq!(s.total_entries, 0);
642 assert!((s.write_amplification - 1.0).abs() < f64::EPSILON);
643 }
644
645 #[test]
646 fn lsm_large_batch_insert() {
647 let cfg = CompactionConfig {
648 memtable_capacity: 50,
649 merge_threshold: 4,
650 max_levels: 4,
651 ..Default::default()
652 };
653 let mut idx = LSMIndex::new(cfg);
654 for i in 0..500 {
655 idx.insert(format!("v{i}"), v(8, i as f32 * 0.01), None);
656 }
657 assert!(idx.stats().total_entries > 0);
658 let r = idx.search(&v(8, 0.0), 5);
659 assert_eq!(r.len(), 5);
660 assert_eq!(r[0].id, "v0");
661 }
662
663 #[test]
664 fn lsm_search_across_levels() {
665 let cfg = CompactionConfig {
666 memtable_capacity: 3,
667 merge_threshold: 3,
668 max_levels: 3,
669 ..Default::default()
670 };
671 let mut idx = LSMIndex::new(cfg);
672 for i in 0..9 {
673 idx.insert(format!("v{i}"), vec![i as f32, 0.0], None);
674 }
675 idx.insert("latest".into(), vec![0.0, 0.0], None);
676 let r = idx.search(&[0.0, 0.0], 3);
677 assert_eq!(r.len(), 3);
678 let ids: Vec<&str> = r.iter().map(|r| r.id.as_str()).collect();
679 assert!(ids.contains(&"latest"));
680 assert!(ids.contains(&"v0"));
681 }
682}