1use std::collections::HashMap;
89use std::fs::{File, OpenOptions};
90use std::io::{BufWriter, Write};
91use std::path::{Path, PathBuf};
92use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
93use std::sync::{Arc, Mutex, RwLock};
94use std::thread::JoinHandle;
95
/// Tuning knobs for [`AsyncLsmManager`].
#[derive(Clone, Debug)]
pub struct LsmConfig {
    /// Number of `f32` components in every stored vector.
    pub dim: usize,

    /// Maximum number of distinct keys the mutable segment holds before the manager seals it.
    pub mutable_capacity: usize,

    /// Directory that holds the write-ahead log; the log file `wal.log` is created inside it.
    pub wal_path: PathBuf,

    /// NOTE(review): not read by any code in this file — confirm whether it should gate
    /// WAL fsyncs.
    pub sync_wal: bool,

    /// Number of WAL insert records written between automatic syncs.
    pub wal_batch_size: usize,

    /// Intended number of segment-build threads.
    /// NOTE(review): not read by any code in this file.
    pub build_threads: usize,

    /// Intended switch for automatic compaction.
    /// NOTE(review): not read by any code in this file.
    pub auto_compact: bool,

    /// Intended segment-count trigger for compaction.
    /// NOTE(review): not read by any code in this file.
    pub compact_threshold: usize,
}
123
124impl Default for LsmConfig {
125 fn default() -> Self {
126 Self {
127 dim: 768,
128 mutable_capacity: 10_000,
129 wal_path: PathBuf::from("./wal"),
130 sync_wal: true,
131 wal_batch_size: 100,
132 build_threads: 2,
133 auto_compact: true,
134 compact_threshold: 4,
135 }
136 }
137}
138
/// Identifier under which a vector is stored, written to the WAL, and returned by searches.
pub type VectorKey = u64;
141
142mod wal_record_compat {
154 use sochdb_core::txn::WalRecordType;
155
156 pub(super) fn to_disk_byte(rt: WalRecordType) -> u8 {
158 match rt {
159 WalRecordType::Data => 1, WalRecordType::Delete => 2, WalRecordType::Flush => 3, WalRecordType::Compaction => 4, _ => 0xFF, }
165 }
166}
167use sochdb_core::txn::WalRecordType;
168
/// Fixed-size header written in front of every WAL record.
///
/// `#[repr(C, packed)]` fixes field order and removes padding, because `WriteAheadLog`
/// serializes the header by viewing it as raw bytes. Multi-byte fields are therefore written
/// in the host's native byte order. Changing a field's order or type changes the on-disk format.
#[repr(C, packed)]
struct WalHeader {
    // Tag byte produced by `wal_record_compat::to_disk_byte`.
    record_type: u8,
    // Vector key for insert records; segment id for seal start/complete records.
    key: u64,
    // Number of `f32` values following the header (0 for seal records).
    dim: u32,
    // CRC32 from `WriteAheadLog::compute_checksum` for insert records; 0 for seal records.
    checksum: u32,
}
177
/// Append-only write-ahead log backed by a buffered file handle.
pub struct WriteAheadLog {
    /// Buffered writer over the log file, which is opened in append mode.
    writer: BufWriter<File>,
    /// Byte length of the log as tracked by this writer (seeded from the file size on open).
    position: u64,
    /// Insert records written since the last `sync`.
    pending: usize,
    /// Number of pending inserts that triggers an automatic `sync`.
    sync_interval: usize,
    /// Insert records written since open (statistics only).
    writes: AtomicU64,
    /// `sync` calls performed since open (statistics only).
    syncs: AtomicU64,
}
196
197impl WriteAheadLog {
198 pub fn open(path: &Path, sync_interval: usize) -> std::io::Result<Self> {
200 let file = OpenOptions::new()
201 .create(true)
202 .write(true)
203 .append(true)
204 .open(path)?;
205
206 let position = file.metadata()?.len();
207 let writer = BufWriter::with_capacity(64 * 1024, file);
208
209 Ok(Self {
210 writer,
211 position,
212 pending: 0,
213 sync_interval,
214 writes: AtomicU64::new(0),
215 syncs: AtomicU64::new(0),
216 })
217 }
218
219 pub fn write_insert(&mut self, key: VectorKey, vector: &[f32]) -> std::io::Result<()> {
221 let dim = vector.len() as u32;
222 let checksum = self.compute_checksum(key, vector);
223
224 let header = WalHeader {
226 record_type: wal_record_compat::to_disk_byte(WalRecordType::Data),
227 key,
228 dim,
229 checksum,
230 };
231
232 let header_bytes = unsafe {
233 std::slice::from_raw_parts(
234 &header as *const WalHeader as *const u8,
235 std::mem::size_of::<WalHeader>(),
236 )
237 };
238 self.writer.write_all(header_bytes)?;
239
240 let vector_bytes = unsafe {
242 std::slice::from_raw_parts(
243 vector.as_ptr() as *const u8,
244 vector.len() * std::mem::size_of::<f32>(),
245 )
246 };
247 self.writer.write_all(vector_bytes)?;
248
249 self.position += header_bytes.len() as u64 + vector_bytes.len() as u64;
250 self.pending += 1;
251 self.writes.fetch_add(1, Ordering::Relaxed);
252
253 if self.pending >= self.sync_interval {
255 self.sync()?;
256 }
257
258 Ok(())
259 }
260
261 pub fn write_seal_start(&mut self, segment_id: u64) -> std::io::Result<()> {
263 let header = WalHeader {
264 record_type: wal_record_compat::to_disk_byte(WalRecordType::Flush),
265 key: segment_id,
266 dim: 0,
267 checksum: 0,
268 };
269
270 let header_bytes = unsafe {
271 std::slice::from_raw_parts(
272 &header as *const WalHeader as *const u8,
273 std::mem::size_of::<WalHeader>(),
274 )
275 };
276 self.writer.write_all(header_bytes)?;
277 self.sync()?;
278
279 Ok(())
280 }
281
282 pub fn write_seal_complete(&mut self, segment_id: u64) -> std::io::Result<()> {
284 let header = WalHeader {
285 record_type: wal_record_compat::to_disk_byte(WalRecordType::Compaction),
286 key: segment_id,
287 dim: 0,
288 checksum: 0,
289 };
290
291 let header_bytes = unsafe {
292 std::slice::from_raw_parts(
293 &header as *const WalHeader as *const u8,
294 std::mem::size_of::<WalHeader>(),
295 )
296 };
297 self.writer.write_all(header_bytes)?;
298 self.sync()?;
299
300 Ok(())
301 }
302
303 pub fn sync(&mut self) -> std::io::Result<()> {
305 self.writer.flush()?;
306 self.writer.get_ref().sync_all()?;
307 self.pending = 0;
308 self.syncs.fetch_add(1, Ordering::Relaxed);
309 Ok(())
310 }
311
312 fn compute_checksum(&self, key: VectorKey, vector: &[f32]) -> u32 {
313 let mut hasher = crc32fast::Hasher::new();
314 hasher.update(&(key as u32).to_le_bytes());
315 for &x in vector {
316 hasher.update(&x.to_le_bytes());
317 }
318 hasher.finalize()
319 }
320
321 pub fn stats(&self) -> WalStats {
323 WalStats {
324 writes: self.writes.load(Ordering::Relaxed),
325 syncs: self.syncs.load(Ordering::Relaxed),
326 position: self.position,
327 }
328 }
329}
330
/// Point-in-time counters reported by [`WriteAheadLog::stats`].
#[derive(Clone, Debug)]
pub struct WalStats {
    /// Insert records written since the log was opened.
    pub writes: u64,
    /// Flush-and-fsync operations performed since the log was opened.
    pub syncs: u64,
    /// Byte length of the log as tracked by the writer.
    pub position: u64,
}
338
339pub struct MutableSegment {
345 vectors: HashMap<VectorKey, (u32, Vec<f32>)>,
347
348 keys: Vec<VectorKey>,
350
351 #[allow(dead_code)]
353 dim: usize,
354
355 capacity: usize,
357}
358
359impl MutableSegment {
360 pub fn new(dim: usize, capacity: usize) -> Self {
362 Self {
363 vectors: HashMap::with_capacity(capacity),
364 keys: Vec::with_capacity(capacity),
365 dim,
366 capacity,
367 }
368 }
369
370 pub fn insert(&mut self, key: VectorKey, vector: Vec<f32>) -> bool {
372 if self.vectors.len() >= self.capacity {
373 return false;
374 }
375
376 let index = self.keys.len() as u32;
377 self.vectors.insert(key, (index, vector));
378 self.keys.push(key);
379 true
380 }
381
382 pub fn is_full(&self) -> bool {
384 self.vectors.len() >= self.capacity
385 }
386
387 pub fn len(&self) -> usize {
389 self.vectors.len()
390 }
391
392 pub fn is_empty(&self) -> bool {
394 self.vectors.is_empty()
395 }
396
397 pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
399 self.vectors.get(&key).map(|(_, v)| v.as_slice())
400 }
401
402 pub fn drain(&mut self) -> Vec<(VectorKey, Vec<f32>)> {
404 let result: Vec<_> = self
405 .keys
406 .drain(..)
407 .filter_map(|k| self.vectors.remove(&k).map(|(_, v)| (k, v)))
408 .collect();
409 result
410 }
411}
412
413pub struct SealedSegment {
419 pub id: u64,
421
422 pub data: Vec<f32>,
424
425 pub key_to_index: HashMap<VectorKey, u32>,
427
428 pub index_to_key: Vec<VectorKey>,
430
431 pub dim: usize,
433
434 pub build_time_ns: u64,
436}
437
438impl SealedSegment {
439 pub fn len(&self) -> usize {
441 self.index_to_key.len()
442 }
443
444 pub fn is_empty(&self) -> bool {
446 self.index_to_key.is_empty()
447 }
448
449 pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
451 self.key_to_index.get(&key).map(|&idx| {
452 let start = idx as usize * self.dim;
453 &self.data[start..start + self.dim]
454 })
455 }
456
457 pub fn get_by_index(&self, index: u32) -> Option<&[f32]> {
459 if (index as usize) < self.index_to_key.len() {
460 let start = index as usize * self.dim;
461 Some(&self.data[start..start + self.dim])
462 } else {
463 None
464 }
465 }
466}
467
468struct BuildTask {
474 segment_id: u64,
476
477 vectors: Vec<(VectorKey, Vec<f32>)>,
479
480 #[allow(dead_code)]
482 dim: usize,
483}
484
485#[allow(dead_code)]
487struct BuildResult {
488 segment: SealedSegment,
489}
490
491pub struct AsyncLsmManager {
497 config: LsmConfig,
499
500 wal: Mutex<WriteAheadLog>,
502
503 mutable: RwLock<MutableSegment>,
505
506 sealed: RwLock<Vec<Arc<SealedSegment>>>,
508
509 pending_builds: Mutex<Vec<BuildTask>>,
511
512 workers: Mutex<Vec<JoinHandle<()>>>,
514
515 shutdown: Arc<AtomicBool>,
517
518 segment_counter: AtomicU64,
520
521 stats: LsmStats,
523}
524
525impl AsyncLsmManager {
526 pub fn new(config: LsmConfig) -> std::io::Result<Self> {
528 std::fs::create_dir_all(&config.wal_path)?;
530
531 let wal_file = config.wal_path.join("wal.log");
532 let wal = WriteAheadLog::open(&wal_file, config.wal_batch_size)?;
533
534 let mutable = MutableSegment::new(config.dim, config.mutable_capacity);
535
536 let shutdown = Arc::new(AtomicBool::new(false));
537
538 Ok(Self {
539 config,
540 wal: Mutex::new(wal),
541 mutable: RwLock::new(mutable),
542 sealed: RwLock::new(Vec::new()),
543 pending_builds: Mutex::new(Vec::new()),
544 workers: Mutex::new(Vec::new()),
545 shutdown,
546 segment_counter: AtomicU64::new(0),
547 stats: LsmStats::default(),
548 })
549 }
550
551 pub fn insert(&self, key: VectorKey, vector: Vec<f32>) -> Result<(), LsmError> {
553 {
555 let mut wal = self.wal.lock().unwrap();
556 wal.write_insert(key, &vector)?;
557 }
558
559 {
561 let mut mutable = self.mutable.write().unwrap();
562
563 if mutable.is_full() {
564 drop(mutable);
566 self.seal_async()?;
567 mutable = self.mutable.write().unwrap();
568 }
569
570 if !mutable.insert(key, vector) {
571 return Err(LsmError::SegmentFull);
572 }
573 }
574
575 self.stats.inserts.fetch_add(1, Ordering::Relaxed);
576
577 Ok(())
578 }
579
580 pub fn insert_batch(&self, items: Vec<(VectorKey, Vec<f32>)>) -> Result<(), LsmError> {
582 {
584 let mut wal = self.wal.lock().unwrap();
585 for (key, vector) in &items {
586 wal.write_insert(*key, vector)?;
587 }
588 wal.sync()?;
589 }
590
591 let mut mutable = self.mutable.write().unwrap();
593
594 for (key, vector) in items {
595 if mutable.is_full() {
596 drop(mutable);
598 self.seal_async()?;
599 mutable = self.mutable.write().unwrap();
600 }
601
602 mutable.insert(key, vector);
603 self.stats.inserts.fetch_add(1, Ordering::Relaxed);
604 }
605
606 Ok(())
607 }
608
609 pub fn seal_async(&self) -> Result<u64, LsmError> {
611 let segment_id = self.segment_counter.fetch_add(1, Ordering::Relaxed);
612
613 {
615 let mut wal = self.wal.lock().unwrap();
616 wal.write_seal_start(segment_id)?;
617 }
618
619 let vectors = {
621 let mut mutable = self.mutable.write().unwrap();
622 let vectors = mutable.drain();
623
624 *mutable = MutableSegment::new(self.config.dim, self.config.mutable_capacity);
626
627 vectors
628 };
629
630 if vectors.is_empty() {
631 return Ok(segment_id);
632 }
633
634 let task = BuildTask {
636 segment_id,
637 vectors,
638 dim: self.config.dim,
639 };
640
641 {
642 let mut pending = self.pending_builds.lock().unwrap();
643 pending.push(task);
644 }
645
646 self.ensure_worker_running();
648
649 self.stats.seals.fetch_add(1, Ordering::Relaxed);
650
651 Ok(segment_id)
652 }
653
654 pub fn seal_blocking(&self) -> Result<Arc<SealedSegment>, LsmError> {
656 let segment_id = self.segment_counter.fetch_add(1, Ordering::Relaxed);
657
658 let vectors = {
660 let mut mutable = self.mutable.write().unwrap();
661 let vectors = mutable.drain();
662
663 *mutable = MutableSegment::new(self.config.dim, self.config.mutable_capacity);
665
666 vectors
667 };
668
669 if vectors.is_empty() {
670 return Err(LsmError::EmptySegment);
671 }
672
673 let segment = self.build_segment(segment_id, vectors);
675 let segment = Arc::new(segment);
676
677 {
679 let mut sealed = self.sealed.write().unwrap();
680 sealed.push(Arc::clone(&segment));
681 }
682
683 {
685 let mut wal = self.wal.lock().unwrap();
686 wal.write_seal_complete(segment_id)?;
687 }
688
689 self.stats.seals.fetch_add(1, Ordering::Relaxed);
690
691 Ok(segment)
692 }
693
694 pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorKey, f32)> {
696 let mut results = Vec::new();
697
698 {
700 let mutable = self.mutable.read().unwrap();
701 for &key in &mutable.keys {
702 if let Some(vector) = mutable.get(key) {
703 let dist = l2_squared(query, vector);
704 results.push((key, dist));
705 }
706 }
707 }
708
709 {
711 let sealed = self.sealed.read().unwrap();
712 for segment in sealed.iter() {
713 for (i, &key) in segment.index_to_key.iter().enumerate() {
714 if let Some(vector) = segment.get_by_index(i as u32) {
715 let dist = l2_squared(query, vector);
716 results.push((key, dist));
717 }
718 }
719 }
720 }
721
722 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
724 results.truncate(k);
725
726 results
727 }
728
729 pub fn get(&self, key: VectorKey) -> Option<Vec<f32>> {
731 {
733 let mutable = self.mutable.read().unwrap();
734 if let Some(v) = mutable.get(key) {
735 return Some(v.to_vec());
736 }
737 }
738
739 {
741 let sealed = self.sealed.read().unwrap();
742 for segment in sealed.iter().rev() {
743 if let Some(v) = segment.get(key) {
744 return Some(v.to_vec());
745 }
746 }
747 }
748
749 None
750 }
751
752 pub fn flush(&self) -> Result<(), LsmError> {
754 loop {
756 let task = {
757 let mut pending = self.pending_builds.lock().unwrap();
758 pending.pop()
759 };
760
761 match task {
762 Some(task) => {
763 let segment = self.build_segment(task.segment_id, task.vectors);
764 let segment = Arc::new(segment);
765
766 let mut sealed = self.sealed.write().unwrap();
767 sealed.push(segment);
768
769 let mut wal = self.wal.lock().unwrap();
770 wal.write_seal_complete(task.segment_id)?;
771 }
772 None => break,
773 }
774 }
775
776 let mut wal = self.wal.lock().unwrap();
778 wal.sync()?;
779
780 Ok(())
781 }
782
783 pub fn stats(&self) -> LsmManagerStats {
785 let mutable_len = self.mutable.read().unwrap().len();
786 let sealed_count = self.sealed.read().unwrap().len();
787 let pending = self.pending_builds.lock().unwrap().len();
788
789 LsmManagerStats {
790 inserts: self.stats.inserts.load(Ordering::Relaxed),
791 seals: self.stats.seals.load(Ordering::Relaxed),
792 mutable_vectors: mutable_len,
793 sealed_segments: sealed_count,
794 pending_builds: pending,
795 }
796 }
797
798 fn build_segment(&self, segment_id: u64, vectors: Vec<(VectorKey, Vec<f32>)>) -> SealedSegment {
799 let start = std::time::Instant::now();
800 let dim = self.config.dim;
801
802 let mut data = Vec::with_capacity(vectors.len() * dim);
803 let mut key_to_index = HashMap::with_capacity(vectors.len());
804 let mut index_to_key = Vec::with_capacity(vectors.len());
805
806 for (i, (key, vector)) in vectors.into_iter().enumerate() {
807 data.extend_from_slice(&vector);
808 key_to_index.insert(key, i as u32);
809 index_to_key.push(key);
810 }
811
812 SealedSegment {
813 id: segment_id,
814 data,
815 key_to_index,
816 index_to_key,
817 dim,
818 build_time_ns: start.elapsed().as_nanos() as u64,
819 }
820 }
821
822 fn ensure_worker_running(&self) {
823 }
826}
827
828impl Drop for AsyncLsmManager {
829 fn drop(&mut self) {
830 self.shutdown.store(true, Ordering::Release);
831
832 let _ = self.flush();
834
835 let mut workers = self.workers.lock().unwrap();
837 for handle in workers.drain(..) {
838 let _ = handle.join();
839 }
840 }
841}
842
/// Internal operation counters, updated with relaxed atomics.
#[derive(Default)]
struct LsmStats {
    /// Vectors accepted by `insert` / `insert_batch`.
    inserts: AtomicU64,
    /// Seals that produced (or queued) a segment.
    seals: AtomicU64,
}
849
/// Snapshot returned by [`AsyncLsmManager::stats`].
#[derive(Clone, Debug)]
pub struct LsmManagerStats {
    /// Vectors accepted so far.
    pub inserts: u64,
    /// Seals performed so far.
    pub seals: u64,
    /// Vectors currently in the mutable segment.
    pub mutable_vectors: usize,
    /// Number of sealed segments.
    pub sealed_segments: usize,
    /// Segments queued but not yet built.
    pub pending_builds: usize,
}
859
/// Errors returned by [`AsyncLsmManager`] operations.
#[derive(Debug)]
pub enum LsmError {
    /// Underlying WAL or filesystem failure.
    Io(std::io::Error),
    /// The mutable segment rejected a vector.
    SegmentFull,
    /// A blocking seal found no vectors to seal.
    EmptySegment,
    /// The requested key is absent.
    /// NOTE(review): not constructed by any code in this file.
    KeyNotFound,
}
868
869impl From<std::io::Error> for LsmError {
870 fn from(e: std::io::Error) -> Self {
871 LsmError::Io(e)
872 }
873}
874
875impl std::fmt::Display for LsmError {
876 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
877 match self {
878 LsmError::Io(e) => write!(f, "IO error: {}", e),
879 LsmError::SegmentFull => write!(f, "segment full"),
880 LsmError::EmptySegment => write!(f, "empty segment"),
881 LsmError::KeyNotFound => write!(f, "key not found"),
882 }
883 }
884}
885
886impl std::error::Error for LsmError {}
887
/// Squared Euclidean distance between `a` and `b`.
///
/// The slices are zipped, so only the first `min(a.len(), b.len())` components contribute.
fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    let diffs = a.iter().zip(b).map(|(x, y)| x - y);
    diffs.map(|d| d * d).sum()
}
902
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a manager for 4-dimensional vectors whose WAL lives under `dir`.
    fn manager_in(dir: &std::path::Path, mutable_capacity: usize) -> AsyncLsmManager {
        AsyncLsmManager::new(LsmConfig {
            dim: 4,
            mutable_capacity,
            wal_path: dir.to_path_buf(),
            ..Default::default()
        })
        .unwrap()
    }

    #[test]
    fn test_wal_basic() {
        let dir = tempdir().unwrap();
        let mut wal = WriteAheadLog::open(&dir.path().join("wal.log"), 10).unwrap();

        wal.write_insert(42, &[1.0, 2.0, 3.0, 4.0]).unwrap();
        wal.sync().unwrap();

        let stats = wal.stats();
        assert_eq!(stats.writes, 1);
        assert!(stats.position > 0);
    }

    #[test]
    fn test_mutable_segment() {
        let mut segment = MutableSegment::new(4, 10);
        for (key, vector) in [(1, vec![1.0, 2.0, 3.0, 4.0]), (2, vec![5.0, 6.0, 7.0, 8.0])] {
            segment.insert(key, vector);
        }

        assert_eq!(segment.len(), 2);
        assert_eq!(segment.get(1).unwrap(), &[1.0, 2.0, 3.0, 4.0]);

        assert_eq!(segment.drain().len(), 2);
        assert!(segment.is_empty());
    }

    #[test]
    fn test_lsm_manager_basic() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path(), 10);

        manager.insert(1, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        manager.insert(2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();

        assert_eq!(manager.get(1).unwrap(), vec![1.0, 2.0, 3.0, 4.0]);

        let stats = manager.stats();
        assert_eq!(stats.inserts, 2);
        assert_eq!(stats.mutable_vectors, 2);
    }

    #[test]
    fn test_lsm_seal_blocking() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path(), 10);

        manager.insert(1, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        manager.insert(2, vec![0.0, 1.0, 0.0, 0.0]).unwrap();

        let segment = manager.seal_blocking().unwrap();
        assert_eq!(segment.len(), 2);
        assert!(manager.get(1).is_some());

        let stats = manager.stats();
        assert_eq!(stats.seals, 1);
        assert_eq!(stats.sealed_segments, 1);
        assert_eq!(stats.mutable_vectors, 0);
    }

    #[test]
    fn test_lsm_search() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path(), 100);

        manager.insert(1, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        manager.insert(2, vec![0.0, 1.0, 0.0, 0.0]).unwrap();
        manager.insert(3, vec![0.5, 0.5, 0.0, 0.0]).unwrap();

        let results = manager.search(&[1.0, 0.0, 0.0, 0.0], 2);

        assert_eq!(results.len(), 2);
        let (best_key, best_dist) = results[0];
        assert_eq!(best_key, 1);
        assert!(best_dist < 0.01);
    }

    #[test]
    fn test_lsm_batch_insert() {
        let dir = tempdir().unwrap();
        let manager = manager_in(dir.path(), 100);

        let batch: Vec<_> = (0..10u64).map(|i| (i, vec![i as f32; 4])).collect();
        manager.insert_batch(batch).unwrap();

        assert_eq!(manager.stats().inserts, 10);

        for i in 0..10u64 {
            let v = manager.get(i).unwrap();
            assert_eq!(v[0], i as f32);
        }
    }
}