1use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet};
13use serde::{Deserialize, Serialize};
14use crate::types::{SearchResult, VectorId};
15
/// Tuning knobs for the LSM index's flush and compaction behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionConfig {
    /// Number of buffered entries at which the memtable is flushed to level 0.
    pub memtable_capacity: usize,
    /// Intended size fan-out between adjacent levels.
    /// NOTE(review): not referenced anywhere else in this module — confirm
    /// whether a size-based compaction policy was meant to consume it.
    pub level_size_ratio: usize,
    /// Total number of segment levels maintained by the index.
    pub max_levels: usize,
    /// Segment count at a level that triggers automatic merge into the next level.
    pub merge_threshold: usize,
    /// Target false-positive rate for each segment's Bloom filter.
    pub bloom_fp_rate: f64,
}
30
31impl Default for CompactionConfig {
32 fn default() -> Self {
33 Self { memtable_capacity: 1000, level_size_ratio: 10, max_levels: 4,
34 merge_threshold: 4, bloom_fp_rate: 0.01 }
35 }
36}
37
/// Simple Bloom filter over string keys using double hashing.
/// `bits` is the slot array (one bool per slot); `num_hashes` is k, the
/// number of probe positions per key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BloomFilter { bits: Vec<bool>, num_hashes: usize }
41
42impl BloomFilter {
43 pub fn new(n: usize, fp_rate: f64) -> Self {
45 let n = n.max(1);
46 let fp = fp_rate.clamp(1e-10, 0.5);
47 let m = (-(n as f64) * fp.ln() / 2.0_f64.ln().powi(2)).ceil() as usize;
48 let m = m.max(8);
49 let k = ((m as f64 / n as f64) * 2.0_f64.ln()).ceil().max(1.0) as usize;
50 Self { bits: vec![false; m], num_hashes: k }
51 }
52
53 pub fn insert(&mut self, key: &str) {
55 let (h1, h2) = Self::hashes(key);
56 let m = self.bits.len();
57 for i in 0..self.num_hashes { self.bits[h1.wrapping_add(i.wrapping_mul(h2)) % m] = true; }
58 }
59
60 pub fn may_contain(&self, key: &str) -> bool {
62 let (h1, h2) = Self::hashes(key);
63 let m = self.bits.len();
64 (0..self.num_hashes).all(|i| self.bits[h1.wrapping_add(i.wrapping_mul(h2)) % m])
65 }
66
67 fn hashes(key: &str) -> (usize, usize) {
68 let (mut h1, mut h2): (u64, u64) = (0xcbf29ce484222325, 0x517cc1b727220a95);
69 for &b in key.as_bytes() {
70 h1 ^= b as u64; h1 = h1.wrapping_mul(0x100000001b3);
71 h2 = h2.wrapping_mul(31).wrapping_add(b as u64);
72 }
73 (h1 as usize, (h2 | 1) as usize)
74 }
75}
76
/// One versioned record inside a memtable or segment.
/// `vector: None` marks a tombstone (deletion). `seq` is a monotonically
/// increasing write sequence number used to keep only the newest version
/// of an id during merges.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LSMEntry {
    id: VectorId,
    vector: Option<Vec<f32>>,
    metadata: Option<HashMap<String, serde_json::Value>>,
    seq: u64,
}
84
/// In-memory write buffer. Keyed by id in a BTreeMap so flushed segments
/// come out id-sorted; `capacity` is the flush threshold, not a hard cap.
#[derive(Debug, Clone)]
pub struct MemTable { entries: BTreeMap<VectorId, LSMEntry>, capacity: usize }
88
89impl MemTable {
90 pub fn new(capacity: usize) -> Self { Self { entries: BTreeMap::new(), capacity } }
91
92 pub fn insert(&mut self, id: VectorId, vector: Option<Vec<f32>>,
94 metadata: Option<HashMap<String, serde_json::Value>>, seq: u64) -> bool {
95 self.entries.insert(id.clone(), LSMEntry { id, vector, metadata, seq });
96 self.is_full()
97 }
98
99 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
101 let mut heap: BinaryHeap<(OrdF32, VectorId)> = BinaryHeap::new();
102 for e in self.entries.values() {
103 let v = match &e.vector { Some(v) => v, None => continue };
104 let d = OrdF32(euclid(query, v));
105 if heap.len() < top_k { heap.push((d, e.id.clone())); }
106 else if d < heap.peek().unwrap().0 { heap.pop(); heap.push((d, e.id.clone())); }
107 }
108 let mut r: Vec<SearchResult> = heap.into_sorted_vec().into_iter().filter_map(|(OrdF32(s), id)| {
109 self.entries.get(&id).map(|e| SearchResult { id: e.id.clone(), score: s,
110 vector: e.vector.clone(), metadata: e.metadata.clone() })
111 }).collect();
112 r.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)); r
113 }
114
115 pub fn flush(&mut self, level: usize, fp_rate: f64) -> Segment {
117 let entries: Vec<LSMEntry> = self.entries.values().cloned().collect();
118 self.entries.clear();
119 Segment::from_entries(entries, level, fp_rate)
120 }
121
122 pub fn len(&self) -> usize { self.entries.len() }
123 pub fn is_empty(&self) -> bool { self.entries.is_empty() }
124 pub fn is_full(&self) -> bool { self.entries.len() >= self.capacity }
125}
126
/// Immutable, id-deduplicated run of entries at a given LSM level, with a
/// Bloom filter over ids for fast negative membership checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Segment { entries: Vec<LSMEntry>, bloom: BloomFilter, pub level: usize }
130
131impl Segment {
132 fn from_entries(entries: Vec<LSMEntry>, level: usize, fp_rate: f64) -> Self {
133 let mut bloom = BloomFilter::new(entries.len(), fp_rate);
134 for e in &entries { bloom.insert(&e.id); }
135 Self { entries, bloom, level }
136 }
137
138 pub fn size(&self) -> usize { self.entries.len() }
139 pub fn contains(&self, id: &str) -> bool { self.bloom.may_contain(id) }
140
141 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
143 let mut heap: BinaryHeap<(OrdF32, usize)> = BinaryHeap::new();
144 for (i, e) in self.entries.iter().enumerate() {
145 let v = match &e.vector { Some(v) => v, None => continue };
146 let d = OrdF32(euclid(query, v));
147 if heap.len() < top_k { heap.push((d, i)); }
148 else if d < heap.peek().unwrap().0 { heap.pop(); heap.push((d, i)); }
149 }
150 let mut r: Vec<SearchResult> = heap.into_sorted_vec().into_iter().map(|(OrdF32(s), i)| {
151 let e = &self.entries[i];
152 SearchResult { id: e.id.clone(), score: s, vector: e.vector.clone(), metadata: e.metadata.clone() }
153 }).collect();
154 r.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)); r
155 }
156
157 pub fn merge(segments: &[Segment], target_level: usize, fp_rate: f64) -> Segment {
159 let mut merged: BTreeMap<VectorId, LSMEntry> = BTreeMap::new();
160 for seg in segments {
161 for e in &seg.entries {
162 if merged.get(&e.id).map_or(true, |x| e.seq > x.seq) {
163 merged.insert(e.id.clone(), e.clone());
164 }
165 }
166 }
167 let entries: Vec<LSMEntry> = merged.into_values().filter(|e| e.vector.is_some()).collect();
168 Segment::from_entries(entries, target_level, fp_rate)
169 }
170}
171
/// Point-in-time snapshot of the index's shape and write cost.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LSMStats {
    /// Number of configured levels (fixed at construction).
    pub num_levels: usize,
    /// Segment count per level, index 0 = level 0.
    pub segments_per_level: Vec<usize>,
    /// Entries in memtable plus all segments (tombstones included).
    pub total_entries: usize,
    /// Physical bytes written divided by user-requested bytes (>= 1.0).
    pub write_amplification: f64,
}
180
/// In-memory LSM-tree-style vector index: writes land in a memtable, flush
/// to level-0 segments, and are merged down through `levels` by compaction.
#[derive(Debug, Clone)]
pub struct LSMIndex {
    config: CompactionConfig,
    // Active write buffer; flushed to levels[0] when full.
    memtable: MemTable,
    // levels[l] holds the segments of level l; later pushes are newer.
    levels: Vec<Vec<Segment>>,
    // Monotonic sequence number stamped on every write/tombstone.
    next_seq: u64,
    // Bytes the caller asked to write (inserts + deletes).
    bytes_written_user: u64,
    // Bytes physically written, including flush/compaction rewrites.
    bytes_written_total: u64,
    // Ids currently deleted; filters search results even after compaction
    // discards the tombstone entries themselves.
    deleted_ids: HashSet<VectorId>,
}
195
impl LSMIndex {
    /// Builds an empty index with one (empty) segment list per configured level.
    pub fn new(config: CompactionConfig) -> Self {
        let cap = config.memtable_capacity;
        let nl = config.max_levels;
        Self { config, memtable: MemTable::new(cap), levels: vec![Vec::new(); nl],
            next_seq: 0, bytes_written_user: 0, bytes_written_total: 0,
            deleted_ids: HashSet::new() }
    }

    /// Inserts (or overwrites) a vector. May trigger a memtable flush and a
    /// compaction pass once the memtable reaches capacity.
    pub fn insert(&mut self, id: VectorId, vector: Vec<f32>,
        metadata: Option<HashMap<String, serde_json::Value>>) {
        // User-write accounting: 4 bytes per f32 component plus the id length.
        let bytes = (vector.len() * 4 + id.len()) as u64;
        self.bytes_written_user += bytes;
        self.bytes_written_total += bytes;
        // Re-inserting a previously deleted id makes it live again.
        self.deleted_ids.remove(&id);
        let seq = self.next_seq; self.next_seq += 1;
        if self.memtable.insert(id, Some(vector), metadata, seq) {
            self.flush_memtable(); self.auto_compact();
        }
    }

    /// Deletes `id` by writing a tombstone and recording it in `deleted_ids`.
    ///
    /// NOTE(review): `deleted_ids` only shrinks on re-insert, so it grows
    /// without bound under delete-heavy workloads — confirm this trade-off.
    pub fn delete(&mut self, id: VectorId) {
        let bytes = id.len() as u64;
        self.bytes_written_user += bytes;
        self.bytes_written_total += bytes;
        self.deleted_ids.insert(id.clone());
        let seq = self.next_seq; self.next_seq += 1;
        // Tombstone = memtable entry with no vector.
        if self.memtable.insert(id, None, None, seq) {
            self.flush_memtable(); self.auto_compact();
        }
    }

    /// K-nearest search across the memtable and every segment, newest first,
    /// deduplicating by first occurrence (newest wins) and filtering deleted ids.
    ///
    /// NOTE(review): each segment contributes only its own local top_k, so if
    /// a newer segment holds an updated copy of an id outside its local top_k,
    /// an older segment's stale copy can still surface. Acceptable only if
    /// in-place updates are rare — confirm.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
        let mut seen = HashSet::new();
        let mut all = Vec::new();
        // Memtable first: it always holds the newest versions.
        for r in self.memtable.search(query, top_k) {
            if !self.deleted_ids.contains(&r.id) { seen.insert(r.id.clone()); all.push(r); }
        }
        for level in &self.levels {
            // Within a level, later segments are newer; visit them first.
            for seg in level.iter().rev() {
                for r in seg.search(query, top_k) {
                    if !seen.contains(&r.id) && !self.deleted_ids.contains(&r.id) {
                        seen.insert(r.id.clone()); all.push(r);
                    }
                }
            }
        }
        all.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal));
        all.truncate(top_k); all
    }

    /// Forces a flush of the memtable, then one merge pass over every level
    /// holding at least two segments.
    pub fn compact(&mut self) {
        if !self.memtable.is_empty() { self.flush_memtable(); }
        for l in 0..self.config.max_levels.saturating_sub(1) {
            if self.levels[l].len() >= 2 { self.compact_level(l); }
        }
    }

    /// Merges any level whose segment count reached `merge_threshold`.
    /// Ascending level order lets a merge at level l cascade into l + 1
    /// within the same pass.
    pub fn auto_compact(&mut self) {
        for l in 0..self.config.max_levels.saturating_sub(1) {
            if self.levels[l].len() >= self.config.merge_threshold { self.compact_level(l); }
        }
    }

    /// Snapshot of level occupancy, total entry count (memtable + segments,
    /// tombstones included) and current write amplification.
    pub fn stats(&self) -> LSMStats {
        let spl: Vec<usize> = self.levels.iter().map(|l| l.len()).collect();
        let total = self.memtable.len()
            + self.levels.iter().flat_map(|l| l.iter()).map(|s| s.size()).sum::<usize>();
        LSMStats { num_levels: self.levels.len(), segments_per_level: spl,
            total_entries: total, write_amplification: self.write_amplification() }
    }

    /// Physical bytes written divided by user bytes; defined as 1.0 before
    /// any write to avoid dividing by zero.
    pub fn write_amplification(&self) -> f64 {
        if self.bytes_written_user == 0 { 1.0 }
        else { self.bytes_written_total as f64 / self.bytes_written_user as f64 }
    }

    /// Drains the memtable into a new level-0 segment, charging its size to
    /// the total-write counter (a flush is a physical rewrite).
    fn flush_memtable(&mut self) {
        let seg = self.memtable.flush(0, self.config.bloom_fp_rate);
        self.bytes_written_total += entry_bytes(&seg.entries);
        self.levels[0].push(seg);
    }

    /// Merges all segments at `level` into a single segment at `level + 1`,
    /// charging the rewritten bytes to the total-write counter.
    fn compact_level(&mut self, level: usize) {
        let target = level + 1;
        if target >= self.config.max_levels { return; }
        let segments = std::mem::take(&mut self.levels[level]);
        let merged = Segment::merge(&segments, target, self.config.bloom_fp_rate);
        self.bytes_written_total += entry_bytes(&merged.entries);
        self.levels[target].push(merged);
    }
}
293
294fn entry_bytes(entries: &[LSMEntry]) -> u64 {
295 entries.iter().map(|e| {
296 (e.vector.as_ref().map_or(0, |v| v.len() * 4) + e.id.len()) as u64
297 }).sum()
298}
299
/// Total-ordering wrapper so f32 distances can live in a `BinaryHeap` and be
/// sorted. NaN compares equal to everything, which is tolerable here because
/// distances come from `euclid` over finite inputs.
///
/// Fix: dropped the `Serialize`/`Deserialize` derives — this private type is
/// a search-time temporary and is never part of any serialized structure.
#[derive(Debug, Clone, Copy, PartialEq)]
struct OrdF32(f32);
impl Eq for OrdF32 {}
impl PartialOrd for OrdF32 {
    // Delegate to `cmp` so partial and total orderings always agree.
    fn partial_cmp(&self, o: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(o)) }
}
impl Ord for OrdF32 {
    fn cmp(&self, o: &Self) -> std::cmp::Ordering {
        // NaN (the only incomparable case) is treated as equal.
        self.0.partial_cmp(&o.0).unwrap_or(std::cmp::Ordering::Equal)
    }
}
311
/// Euclidean (L2) distance between `a` and `b`.
///
/// NOTE(review): `zip` silently stops at the shorter slice, so mismatched
/// dimensions go undetected — callers are expected to pass equal-length
/// vectors.
fn euclid(a: &[f32], b: &[f32]) -> f32 {
    let mut sum_sq = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
315
// Unit tests for the memtable, bloom filter, segment merge semantics, and
// the end-to-end LSM index (insert/delete/search/compaction/write-amp).
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim`-dimensional vector filled with `val`.
    fn v(dim: usize, val: f32) -> Vec<f32> { vec![val; dim] }
    /// Convenience constructor for a raw entry (tombstone when `vec` is None).
    fn entry(id: &str, vec: Option<Vec<f32>>, seq: u64) -> LSMEntry {
        LSMEntry { id: id.into(), vector: vec, metadata: None, seq }
    }

    // --- MemTable ---

    #[test]
    fn memtable_insert_and_len() {
        let mut mt = MemTable::new(5);
        assert!(mt.is_empty());
        mt.insert("a".into(), Some(vec![1.0]), None, 0);
        mt.insert("b".into(), Some(vec![2.0]), None, 1);
        assert_eq!(mt.len(), 2);
        assert!(!mt.is_full());
    }

    #[test]
    fn memtable_is_full() {
        let mut mt = MemTable::new(2);
        mt.insert("a".into(), Some(vec![1.0]), None, 0);
        // The insert that reaches capacity must report "full".
        assert!(mt.insert("b".into(), Some(vec![2.0]), None, 1));
    }

    #[test]
    fn memtable_search_returns_closest() {
        let mut mt = MemTable::new(100);
        mt.insert("far".into(), Some(vec![10.0, 10.0]), None, 0);
        mt.insert("close".into(), Some(vec![1.0, 0.0]), None, 1);
        mt.insert("mid".into(), Some(vec![5.0, 5.0]), None, 2);
        let r = mt.search(&[0.0, 0.0], 2);
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].id, "close");
    }

    #[test]
    fn memtable_flush_produces_segment() {
        let mut mt = MemTable::new(10);
        mt.insert("x".into(), Some(vec![1.0]), None, 0);
        mt.insert("y".into(), Some(vec![2.0]), None, 1);
        let seg = mt.flush(0, 0.01);
        assert_eq!(seg.size(), 2);
        assert_eq!(seg.level, 0);
        // Flushing must leave the memtable empty.
        assert!(mt.is_empty());
    }

    // --- Segment merge semantics ---

    #[test]
    fn segment_merge_dedup_keeps_latest() {
        let s1 = Segment::from_entries(vec![entry("a", Some(vec![1.0]), 1)], 0, 0.01);
        let s2 = Segment::from_entries(vec![entry("a", Some(vec![9.0]), 5)], 0, 0.01);
        let m = Segment::merge(&[s1, s2], 1, 0.01);
        assert_eq!(m.size(), 1);
        // The higher-seq version wins.
        assert_eq!(m.entries[0].vector.as_ref().unwrap(), &vec![9.0]);
    }

    #[test]
    fn segment_merge_drops_tombstones() {
        let s1 = Segment::from_entries(vec![entry("a", Some(vec![1.0]), 1)], 0, 0.01);
        let s2 = Segment::from_entries(vec![entry("a", None, 5)], 0, 0.01);
        // The newest version is a tombstone, so the id disappears entirely.
        assert_eq!(Segment::merge(&[s1, s2], 1, 0.01).size(), 0);
    }

    // --- Bloom filter ---

    #[test]
    fn bloom_filter_no_false_negatives() {
        let mut bf = BloomFilter::new(100, 0.01);
        for i in 0..100 { bf.insert(&format!("key-{i}")); }
        for i in 0..100 { assert!(bf.may_contain(&format!("key-{i}"))); }
    }

    #[test]
    fn bloom_filter_low_false_positive_rate() {
        let mut bf = BloomFilter::new(1000, 0.01);
        for i in 0..1000 { bf.insert(&format!("present-{i}")); }
        // Allow generous headroom (5%) over the 1% target to keep this stable.
        let fp: usize = (0..10_000).filter(|i| bf.may_contain(&format!("absent-{i}"))).count();
        assert!((fp as f64 / 10_000.0) < 0.05, "FP rate too high: {fp}/10000");
    }

    // --- LSMIndex end-to-end ---

    #[test]
    fn lsm_insert_and_search() {
        let mut idx = LSMIndex::new(CompactionConfig { memtable_capacity: 10, ..Default::default() });
        idx.insert("v1".into(), vec![1.0, 0.0], None);
        idx.insert("v2".into(), vec![0.0, 1.0], None);
        let r = idx.search(&[1.0, 0.0], 1);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].id, "v1");
    }

    #[test]
    fn lsm_delete_with_tombstone() {
        let mut idx = LSMIndex::new(CompactionConfig { memtable_capacity: 100, ..Default::default() });
        idx.insert("v1".into(), vec![1.0, 0.0], None);
        idx.insert("v2".into(), vec![0.0, 1.0], None);
        idx.delete("v1".into());
        // The deleted id must not appear even though its data still exists.
        let r = idx.search(&[1.0, 0.0], 2);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].id, "v2");
    }

    #[test]
    fn lsm_auto_compaction_trigger() {
        let cfg = CompactionConfig { memtable_capacity: 2, merge_threshold: 2, max_levels: 3, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..10 { idx.insert(format!("v{i}"), vec![i as f32], None); }
        assert!(idx.stats().segments_per_level[0] < 4, "L0 should compact");
    }

    #[test]
    fn lsm_multi_level_compaction() {
        let cfg = CompactionConfig { memtable_capacity: 2, merge_threshold: 2, max_levels: 4, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..30 { idx.insert(format!("v{i}"), v(4, i as f32), None); }
        let total_seg: usize = idx.stats().segments_per_level.iter().sum();
        assert!(total_seg >= 1);
    }

    #[test]
    fn lsm_write_amplification_increases() {
        let cfg = CompactionConfig { memtable_capacity: 5, merge_threshold: 2, max_levels: 3, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..20 { idx.insert(format!("v{i}"), v(4, i as f32), None); }
        // Flushes/compactions rewrite data, so amplification never drops below 1.
        assert!(idx.write_amplification() >= 1.0);
    }

    #[test]
    fn lsm_empty_index() {
        let idx = LSMIndex::new(CompactionConfig::default());
        assert!(idx.search(&[0.0, 0.0], 10).is_empty());
        let s = idx.stats();
        assert_eq!(s.total_entries, 0);
        // No writes yet: amplification is defined as exactly 1.0.
        assert!((s.write_amplification - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn lsm_large_batch_insert() {
        let cfg = CompactionConfig { memtable_capacity: 50, merge_threshold: 4, max_levels: 4, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..500 { idx.insert(format!("v{i}"), v(8, i as f32 * 0.01), None); }
        assert!(idx.stats().total_entries > 0);
        let r = idx.search(&v(8, 0.0), 5);
        assert_eq!(r.len(), 5);
        assert_eq!(r[0].id, "v0");
    }

    #[test]
    fn lsm_search_across_levels() {
        let cfg = CompactionConfig { memtable_capacity: 3, merge_threshold: 3, max_levels: 3, ..Default::default() };
        let mut idx = LSMIndex::new(cfg);
        for i in 0..9 { idx.insert(format!("v{i}"), vec![i as f32, 0.0], None); }
        idx.insert("latest".into(), vec![0.0, 0.0], None);
        // Results must combine memtable entries with compacted segments.
        let r = idx.search(&[0.0, 0.0], 3);
        assert_eq!(r.len(), 3);
        let ids: Vec<&str> = r.iter().map(|r| r.id.as_str()).collect();
        assert!(ids.contains(&"latest"));
        assert!(ids.contains(&"v0"));
    }
}