1use std::collections::HashSet;
52use std::ptr;
53use std::sync::atomic::{AtomicPtr, AtomicU64, AtomicUsize, Ordering};
54
55use dashmap::DashMap;
56use parking_lot::Mutex;
57
58use sochdb_core::{Result, SochDBError};
59
/// Number of hazard-pointer slots in each per-thread `HazardRecord`.
const HP_PER_THREAD: usize = 2;

/// Number of hazard records a `LockFreeMemTable` allocates for its `HazardDomain`.
const MAX_THREADS: usize = 128;

/// Length of the retired list at which `HazardDomain::retire` attempts reclamation.
const RECLAMATION_THRESHOLD: usize = 64;
68
/// Largest payload, in bytes, that [`ValueStorage`] keeps inline instead of on the heap.
pub const INLINE_VALUE_SIZE: usize = 56;

/// Payload of a single version: short values live inline, longer ones are boxed,
/// and deletions are represented by a marker variant.
#[repr(C)]
pub enum ValueStorage {
    /// Value of at most [`INLINE_VALUE_SIZE`] bytes stored in place.
    Inline {
        /// Number of meaningful bytes at the front of `data`.
        len: u8,
        /// Fixed-size buffer; [`ValueStorage::new`] zero-fills the bytes past `len`.
        data: [u8; INLINE_VALUE_SIZE],
    },
    /// Value longer than [`INLINE_VALUE_SIZE`] bytes, owned on the heap.
    Heap(Box<[u8]>),
    /// Deletion marker; carries no bytes.
    Tombstone,
}

impl std::fmt::Debug for ValueStorage {
    /// Prints the variant name and, for value-carrying variants, the payload length
    /// (never the payload itself).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.is_tombstone() {
            return f.write_str("Tombstone");
        }
        let variant = if self.is_inline() { "Inline" } else { "Heap" };
        write!(f, "{}(len={})", variant, self.len())
    }
}

impl ValueStorage {
    /// Builds storage for `value`.
    ///
    /// `None` becomes [`ValueStorage::Tombstone`]; slices of up to
    /// [`INLINE_VALUE_SIZE`] bytes are copied inline; anything longer is copied
    /// into a heap allocation.
    #[inline]
    pub fn new(value: Option<&[u8]>) -> Self {
        let Some(bytes) = value else {
            return ValueStorage::Tombstone;
        };

        if bytes.len() > INLINE_VALUE_SIZE {
            return ValueStorage::Heap(Box::from(bytes));
        }

        let mut data = [0u8; INLINE_VALUE_SIZE];
        data[..bytes.len()].copy_from_slice(bytes);
        ValueStorage::Inline {
            // Cannot truncate: the length was just checked against INLINE_VALUE_SIZE (56).
            len: bytes.len() as u8,
            data,
        }
    }

    /// Returns the stored bytes, or `None` for a tombstone.
    #[inline]
    pub fn as_bytes(&self) -> Option<&[u8]> {
        let bytes: &[u8] = match self {
            ValueStorage::Tombstone => return None,
            ValueStorage::Heap(boxed) => &boxed[..],
            ValueStorage::Inline { len, data } => &data[..usize::from(*len)],
        };
        Some(bytes)
    }

    /// Returns `true` if this is the deletion marker.
    #[inline]
    pub fn is_tombstone(&self) -> bool {
        match self {
            ValueStorage::Tombstone => true,
            ValueStorage::Inline { .. } | ValueStorage::Heap(_) => false,
        }
    }

    /// Returns `true` if the value is stored inline.
    #[inline]
    pub fn is_inline(&self) -> bool {
        match self {
            ValueStorage::Inline { .. } => true,
            ValueStorage::Heap(_) | ValueStorage::Tombstone => false,
        }
    }

    /// Length of the stored value in bytes; a tombstone reports 0.
    #[inline]
    pub fn len(&self) -> usize {
        self.as_bytes().map_or(0, <[u8]>::len)
    }

    /// Returns `true` when [`len`](Self::len) is 0, which includes tombstones.
    #[inline]
    pub fn is_empty(&self) -> bool {
        let stored = self.len();
        stored == 0
    }
}
169
170#[derive(Debug)]
175pub struct LockFreeVersion {
176 pub storage: ValueStorage,
178 pub txn_id: u64,
180 pub commit_ts: AtomicU64,
182 pub next: AtomicPtr<LockFreeVersion>,
184}
185
186impl LockFreeVersion {
187 #[inline]
189 pub fn new_from_slice(value: Option<&[u8]>, txn_id: u64) -> Self {
190 Self {
191 storage: ValueStorage::new(value),
192 txn_id,
193 commit_ts: AtomicU64::new(0),
194 next: AtomicPtr::new(ptr::null_mut()),
195 }
196 }
197
198 pub fn new(value: Option<Vec<u8>>, txn_id: u64) -> Self {
200 Self::new_from_slice(value.as_deref(), txn_id)
201 }
202
203 #[inline]
205 pub fn get_value(&self) -> Option<&[u8]> {
206 self.storage.as_bytes()
207 }
208
209 #[inline]
213 pub fn value_cloned(&self) -> Option<Vec<u8>> {
214 self.storage.as_bytes().map(|v| v.to_vec())
215 }
216
217 #[inline]
219 pub fn is_committed(&self) -> bool {
220 self.commit_ts.load(Ordering::Acquire) > 0
221 }
222
223 #[inline]
225 pub fn get_commit_ts(&self) -> u64 {
226 self.commit_ts.load(Ordering::Acquire)
227 }
228
229 #[inline]
231 pub fn set_commit_ts(&self, ts: u64) {
232 self.commit_ts.store(ts, Ordering::Release);
233 }
234
235 #[inline]
237 pub fn is_inline(&self) -> bool {
238 self.storage.is_inline()
239 }
240}
241
242pub struct LockFreeVersionChain {
244 head: AtomicPtr<LockFreeVersion>,
246}
247
248impl Default for LockFreeVersionChain {
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
254impl LockFreeVersionChain {
255 pub fn new() -> Self {
257 Self {
258 head: AtomicPtr::new(ptr::null_mut()),
259 }
260 }
261
262 pub fn add_uncommitted(&self, value: Option<Vec<u8>>, txn_id: u64) -> Result<()> {
266 let new_version = Box::into_raw(Box::new(LockFreeVersion::new(value, txn_id)));
267
268 loop {
269 let head = self.head.load(Ordering::Acquire);
270
271 if !head.is_null() {
273 let head_ref = unsafe { &*head };
274 if !head_ref.is_committed() && head_ref.txn_id != txn_id {
275 unsafe {
277 drop(Box::from_raw(new_version));
278 }
279 return Err(SochDBError::Internal("Write-write conflict".into()));
280 }
281 }
282
283 unsafe {
285 (*new_version).next.store(head, Ordering::Release);
286 }
287
288 match self
290 .head
291 .compare_exchange(head, new_version, Ordering::AcqRel, Ordering::Acquire)
292 {
293 Ok(_) => return Ok(()),
294 Err(_) => continue, }
296 }
297 }
298
299 pub fn commit(&self, txn_id: u64, commit_ts: u64) -> bool {
301 let mut current = self.head.load(Ordering::Acquire);
302
303 while !current.is_null() {
304 let version = unsafe { &*current };
305 if version.txn_id == txn_id && !version.is_committed() {
306 version.set_commit_ts(commit_ts);
307 return true;
308 }
309 current = version.next.load(Ordering::Acquire);
310 }
311
312 false
313 }
314
315 pub fn read_at(
320 &self,
321 snapshot_ts: u64,
322 current_txn_id: Option<u64>,
323 ) -> Option<&LockFreeVersion> {
324 let mut current = self.head.load(Ordering::Acquire);
325
326 while !current.is_null() {
327 let version = unsafe { &*current };
328
329 if let Some(txn_id) = current_txn_id
331 && version.txn_id == txn_id
332 && !version.is_committed()
333 {
334 return Some(version);
335 }
336
337 let commit_ts = version.get_commit_ts();
339 if commit_ts > 0 && commit_ts < snapshot_ts {
340 return Some(version);
341 }
342
343 current = version.next.load(Ordering::Acquire);
344 }
345
346 None
347 }
348
349 pub fn has_write_conflict(&self, my_txn_id: u64) -> bool {
351 let current = self.head.load(Ordering::Acquire);
352
353 if !current.is_null() {
354 let version = unsafe { &*current };
355 return !version.is_committed() && version.txn_id != my_txn_id;
356 }
357
358 false
359 }
360}
361
362#[repr(C, align(64))]
366struct HazardRecord {
367 hazard: [AtomicPtr<LockFreeVersion>; HP_PER_THREAD],
369 active: AtomicU64,
371}
372
373impl HazardRecord {
374 const fn new() -> Self {
375 Self {
376 hazard: [
377 AtomicPtr::new(ptr::null_mut()),
378 AtomicPtr::new(ptr::null_mut()),
379 ],
380 active: AtomicU64::new(0),
381 }
382 }
383
384 fn try_acquire(&self, thread_id: u64) -> bool {
386 self.active
387 .compare_exchange(0, thread_id, Ordering::AcqRel, Ordering::Acquire)
388 .is_ok()
389 }
390
391 #[allow(dead_code)]
393 fn release(&self) {
394 for hp in &self.hazard {
396 hp.store(ptr::null_mut(), Ordering::Release);
397 }
398 self.active.store(0, Ordering::Release);
399 }
400}
401
402pub struct HazardDomain {
404 records: Vec<HazardRecord>,
406 retired: Mutex<Vec<*mut LockFreeVersion>>,
408}
409
410impl HazardDomain {
411 pub fn new(max_threads: usize) -> Self {
413 let mut records = Vec::with_capacity(max_threads);
414 for _ in 0..max_threads {
415 records.push(HazardRecord::new());
416 }
417
418 Self {
419 records,
420 retired: Mutex::new(Vec::with_capacity(RECLAMATION_THRESHOLD * 2)),
421 }
422 }
423
424 fn get_record(&self) -> Option<&HazardRecord> {
426 let thread_id = thread_id::get() as u64;
427
428 for record in &self.records {
430 if record.active.load(Ordering::Acquire) == thread_id {
431 return Some(record);
432 }
433 }
434
435 self.records
437 .iter()
438 .find(|record| record.try_acquire(thread_id))
439 }
440
441 #[inline]
443 pub fn protect(&self, ptr: *mut LockFreeVersion, slot: usize) -> bool {
444 if let Some(record) = self.get_record()
445 && slot < HP_PER_THREAD
446 {
447 record.hazard[slot].store(ptr, Ordering::Release);
448 std::sync::atomic::fence(Ordering::SeqCst);
449 return true;
450 }
451 false
452 }
453
454 #[inline]
456 pub fn clear(&self, slot: usize) {
457 if let Some(record) = self.get_record()
458 && slot < HP_PER_THREAD
459 {
460 record.hazard[slot].store(ptr::null_mut(), Ordering::Release);
461 }
462 }
463
464 pub fn retire(&self, ptr: *mut LockFreeVersion) {
466 let mut retired = self.retired.lock();
467 retired.push(ptr);
468
469 if retired.len() >= RECLAMATION_THRESHOLD {
471 self.try_reclaim(&mut retired);
472 }
473 }
474
475 fn try_reclaim(&self, retired: &mut Vec<*mut LockFreeVersion>) {
477 let mut protected: HashSet<usize> = HashSet::new();
479
480 for record in &self.records {
481 if record.active.load(Ordering::Acquire) != 0 {
482 for hp in &record.hazard {
483 let ptr = hp.load(Ordering::Acquire);
484 if !ptr.is_null() {
485 protected.insert(ptr as usize);
486 }
487 }
488 }
489 }
490
491 let mut still_retired = Vec::new();
493 for ptr in retired.drain(..) {
494 if protected.contains(&(ptr as usize)) {
495 still_retired.push(ptr);
496 } else {
497 unsafe {
499 drop(Box::from_raw(ptr));
500 }
501 }
502 }
503
504 *retired = still_retired;
505 }
506}
507
508impl Drop for HazardDomain {
509 fn drop(&mut self) {
510 let mut retired = self.retired.lock();
512 for ptr in retired.drain(..) {
513 unsafe {
514 drop(Box::from_raw(ptr));
515 }
516 }
517 }
518}
519
/// Small process-wide thread numbering, assigned lazily per thread.
mod thread_id {
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Next id to hand out. Starts at 1, so 0 is not handed out and stays free
    /// for `HazardRecord` to use as its "unclaimed" marker.
    static COUNTER: AtomicUsize = AtomicUsize::new(1);

    /// Reserves a fresh id from the global counter.
    fn allocate() -> usize {
        COUNTER.fetch_add(1, Ordering::Relaxed)
    }

    thread_local! {
        /// Id of the current thread, allocated on first access.
        static CURRENT: usize = allocate();
    }

    /// Returns the calling thread's id; repeated calls on one thread return the
    /// same value.
    pub fn get() -> usize {
        CURRENT.with(|&id| id)
    }
}
534
535pub struct LockFreeMemTable {
537 data: DashMap<Vec<u8>, LockFreeVersionChain>,
539 hazard_domain: HazardDomain,
541 size_bytes: AtomicUsize,
543}
544
545impl LockFreeMemTable {
546 pub fn new() -> Self {
548 Self {
549 data: DashMap::new(),
550 hazard_domain: HazardDomain::new(MAX_THREADS),
551 size_bytes: AtomicUsize::new(0),
552 }
553 }
554
555 pub fn read(&self, key: &[u8], snapshot_ts: u64, txn_id: Option<u64>) -> Option<Vec<u8>> {
560 let chain = self.data.get(key)?;
561
562 if let Some(version) = chain.read_at(snapshot_ts, txn_id) {
564 let ptr = version as *const LockFreeVersion as *mut LockFreeVersion;
566 self.hazard_domain.protect(ptr, 0);
567
568 let result = version.value_cloned();
571
572 self.hazard_domain.clear(0);
574
575 result
576 } else {
577 None
578 }
579 }
580
581 #[inline]
595 pub fn read_with<F, R>(
596 &self,
597 key: &[u8],
598 snapshot_ts: u64,
599 txn_id: Option<u64>,
600 f: F,
601 ) -> Option<R>
602 where
603 F: FnOnce(&[u8]) -> R,
604 {
605 let chain = self.data.get(key)?;
606
607 if let Some(version) = chain.read_at(snapshot_ts, txn_id) {
608 let ptr = version as *const LockFreeVersion as *mut LockFreeVersion;
610 self.hazard_domain.protect(ptr, 0);
611
612 let result = version.get_value().map(f);
614
615 self.hazard_domain.clear(0);
617
618 result
619 } else {
620 None
621 }
622 }
623
624 pub fn write(&self, key: Vec<u8>, value: Option<Vec<u8>>, txn_id: u64) -> Result<()> {
626 let value_size = value.as_ref().map(|v| v.len()).unwrap_or(0);
627
628 let chain = self.data.entry(key.clone()).or_default();
630
631 chain.add_uncommitted(value, txn_id)?;
633
634 self.size_bytes
636 .fetch_add(key.len() + value_size + 64, Ordering::Relaxed);
637
638 Ok(())
639 }
640
641 pub fn commit(&self, txn_id: u64, commit_ts: u64, keys: &[Vec<u8>]) {
643 for key in keys {
644 if let Some(chain) = self.data.get(key) {
645 chain.commit(txn_id, commit_ts);
646 }
647 }
648 }
649
650 pub fn has_write_conflict(&self, key: &[u8], txn_id: u64) -> bool {
652 if let Some(chain) = self.data.get(key) {
653 chain.has_write_conflict(txn_id)
654 } else {
655 false
656 }
657 }
658
659 pub fn size_bytes(&self) -> usize {
661 self.size_bytes.load(Ordering::Relaxed)
662 }
663
664 pub fn len(&self) -> usize {
666 self.data.len()
667 }
668
669 pub fn is_empty(&self) -> bool {
671 self.data.is_empty()
672 }
673}
674
// SAFETY: these impls are written by hand because `HazardDomain` stores
// `*mut LockFreeVersion` values in its retired list, and raw pointers are
// neither `Send` nor `Sync`. Within this file that list is only pushed to,
// scanned and drained while its mutex is held.
// NOTE(review): soundness also relies on every retired pointer being uniquely
// owned by the list (unlinked from its chain, retired once) — confirm this at
// each `retire` call site.
unsafe impl Send for LockFreeMemTable {}
unsafe impl Sync for LockFreeMemTable {}
680
681impl Default for LockFreeMemTable {
682 fn default() -> Self {
683 Self::new()
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// A transaction sees its own pending write, other transactions do not, and
    /// the value becomes generally visible once committed.
    #[test]
    fn test_basic_write_read() {
        let table = LockFreeMemTable::new();
        let key = b"key1".to_vec();
        let expected = Some(b"value1".to_vec());

        table.write(key.clone(), expected.clone(), 1).unwrap();

        // Own uncommitted write is visible.
        assert_eq!(table.read(&key, 100, Some(1)), expected);

        // Another transaction must not see it.
        assert!(table.read(&key, 100, Some(2)).is_none());

        // Visible to everyone after commit.
        table.commit(1, 50, &[key.clone()]);
        assert_eq!(table.read(&key, 100, None), expected);
    }

    /// Readers get the newest version committed before their snapshot.
    #[test]
    fn test_snapshot_isolation() {
        let table = LockFreeMemTable::new();
        let key = b"key".to_vec();

        table.write(key.clone(), Some(b"v1".to_vec()), 1).unwrap();
        table.commit(1, 10, &[key.clone()]);

        table.write(key.clone(), Some(b"v2".to_vec()), 2).unwrap();
        table.commit(2, 20, &[key.clone()]);

        // Snapshot between the two commits sees the first value.
        assert_eq!(table.read(&key, 15, None), Some(b"v1".to_vec()));

        // Snapshot after both commits sees the second value.
        assert_eq!(table.read(&key, 25, None), Some(b"v2".to_vec()));
    }

    /// A pending write blocks other writers but not the same transaction.
    #[test]
    fn test_write_conflict() {
        let table = LockFreeMemTable::new();
        let key = b"key".to_vec();

        table.write(key.clone(), Some(b"v1".to_vec()), 1).unwrap();

        let foreign = table.write(key.clone(), Some(b"v2".to_vec()), 2);
        assert!(foreign.is_err());

        let own = table.write(key.clone(), Some(b"v1_updated".to_vec()), 1);
        assert!(own.is_ok());
    }

    /// Eight threads read the same hundred committed keys concurrently.
    #[test]
    fn test_concurrent_reads() {
        let table = Arc::new(LockFreeMemTable::new());

        let keys: Vec<Vec<u8>> = (0..100)
            .map(|i| format!("key{}", i).into_bytes())
            .collect();

        for (i, key) in keys.iter().enumerate() {
            let payload = format!("value{}", i).into_bytes();
            table.write(key.clone(), Some(payload), 1).unwrap();
        }
        table.commit(1, 10, &keys);

        let mut readers = Vec::with_capacity(8);
        for t in 0..8 {
            let mt = Arc::clone(&table);
            readers.push(thread::spawn(move || {
                for i in 0..100 {
                    let key = format!("key{}", i).into_bytes();
                    let expected = format!("value{}", i).into_bytes();
                    let val = mt.read(&key, 100, None);
                    assert_eq!(val, Some(expected), "Thread {} failed at key{}", t, i);
                }
            }));
        }

        for reader in readers {
            reader.join().unwrap();
        }
    }

    /// Short values are inline, long values are boxed, `None` is a tombstone.
    #[test]
    fn test_inline_storage() {
        let short = b"small".to_vec();
        let inline_version = LockFreeVersion::new(Some(short.clone()), 1);
        assert!(inline_version.is_inline(), "Small values should be inline");
        assert_eq!(inline_version.get_value(), Some(short.as_slice()));

        let long = vec![42u8; 100];
        let heap_version = LockFreeVersion::new(Some(long.clone()), 2);
        assert!(!heap_version.is_inline(), "Large values should be on heap");
        assert_eq!(heap_version.get_value(), Some(long.as_slice()));

        let deleted = LockFreeVersion::new(None, 3);
        assert!(deleted.storage.is_tombstone());
        assert_eq!(deleted.get_value(), None);
    }

    /// The inline/heap boundary sits exactly at `INLINE_VALUE_SIZE`.
    #[test]
    fn test_inline_threshold() {
        let at_limit = vec![0u8; INLINE_VALUE_SIZE];
        let version = LockFreeVersion::new(Some(at_limit.clone()), 1);
        assert!(version.is_inline(), "Values at threshold should be inline");

        let over_limit = vec![0u8; INLINE_VALUE_SIZE + 1];
        let version = LockFreeVersion::new(Some(over_limit), 2);
        assert!(
            !version.is_inline(),
            "Values over threshold should be on heap"
        );
    }

    /// `read_with` hands the stored bytes to the callback and returns its result.
    #[test]
    fn test_read_with_callback() {
        let table = LockFreeMemTable::new();
        let key = b"key1".to_vec();

        table
            .write(key.clone(), Some(b"value1".to_vec()), 1)
            .unwrap();
        table.commit(1, 10, &[key.clone()]);

        let len = table.read_with(&key, 100, None, |v| v.len());
        assert_eq!(len, Some(6));

        let matches = table.read_with(&key, 100, None, |v| v == b"value1");
        assert_eq!(matches, Some(true));
    }
}