1use std::collections::HashSet;
55use std::ptr;
56use std::sync::atomic::{AtomicPtr, AtomicU64, AtomicUsize, Ordering};
57
58use dashmap::DashMap;
59use parking_lot::Mutex;
60
61use sochdb_core::{Result, SochDBError};
62
/// Number of hazard-pointer slots in each `HazardRecord`; also the exclusive
/// upper bound for the `slot` argument of `HazardDomain::protect` / `clear`.
const HP_PER_THREAD: usize = 2;

/// Number of hazard records a `LockFreeMemTable` allocates for its
/// `HazardDomain`, i.e. how many distinct threads can claim a record.
const MAX_THREADS: usize = 128;

/// Length of the retired list at which `HazardDomain::retire` runs a
/// reclamation pass. The list is preallocated with twice this capacity.
const RECLAMATION_THRESHOLD: usize = 64;
71
/// Largest payload, in bytes, that [`ValueStorage::new`] keeps inline instead
/// of placing it in a separate heap allocation.
pub const INLINE_VALUE_SIZE: usize = 56;

/// Payload of a single version: a short inline buffer, a heap allocation, or
/// a tombstone carrying no bytes at all.
#[repr(C)]
pub enum ValueStorage {
    /// Payload of at most [`INLINE_VALUE_SIZE`] bytes stored in place.
    Inline {
        /// Number of meaningful bytes at the front of `data`.
        len: u8,
        /// Zero-initialised buffer; bytes past `len` are padding.
        data: [u8; INLINE_VALUE_SIZE],
    },
    /// Payload longer than [`INLINE_VALUE_SIZE`], held in its own allocation.
    Heap(Box<[u8]>),
    /// Produced when the storage is built from `None`; holds no bytes.
    Tombstone,
}

impl std::fmt::Debug for ValueStorage {
    /// Prints the variant name and, for byte-carrying variants, only the
    /// payload length (never the payload itself).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ValueStorage::Tombstone => f.write_str("Tombstone"),
            ValueStorage::Heap(bytes) => write!(f, "Heap(len={})", bytes.len()),
            ValueStorage::Inline { len: used, .. } => write!(f, "Inline(len={})", used),
        }
    }
}

impl ValueStorage {
    /// Builds storage for `value`.
    ///
    /// `None` becomes [`ValueStorage::Tombstone`]; a slice of up to
    /// [`INLINE_VALUE_SIZE`] bytes is copied into an inline buffer; anything
    /// longer is copied into a boxed slice.
    #[inline]
    pub fn new(value: Option<&[u8]>) -> Self {
        let bytes = match value {
            Some(bytes) => bytes,
            None => return ValueStorage::Tombstone,
        };

        if bytes.len() > INLINE_VALUE_SIZE {
            return ValueStorage::Heap(Box::<[u8]>::from(bytes));
        }

        let mut data = [0u8; INLINE_VALUE_SIZE];
        let (payload, _padding) = data.split_at_mut(bytes.len());
        payload.copy_from_slice(bytes);

        ValueStorage::Inline {
            // Cannot truncate: the length was just checked against INLINE_VALUE_SIZE.
            len: bytes.len() as u8,
            data,
        }
    }

    /// Returns the stored bytes, or `None` for a tombstone.
    #[inline]
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            ValueStorage::Tombstone => None,
            ValueStorage::Heap(boxed) => Some(&boxed[..]),
            ValueStorage::Inline { len, data } => Some(&data[..usize::from(*len)]),
        }
    }

    /// Returns `true` only for [`ValueStorage::Tombstone`].
    #[inline]
    pub fn is_tombstone(&self) -> bool {
        match self {
            ValueStorage::Tombstone => true,
            ValueStorage::Inline { .. } | ValueStorage::Heap(_) => false,
        }
    }

    /// Returns `true` only for [`ValueStorage::Inline`].
    #[inline]
    pub fn is_inline(&self) -> bool {
        match self {
            ValueStorage::Inline { .. } => true,
            ValueStorage::Heap(_) | ValueStorage::Tombstone => false,
        }
    }

    /// Number of payload bytes; a tombstone reports `0`.
    #[inline]
    pub fn len(&self) -> usize {
        match *self {
            ValueStorage::Tombstone => 0,
            ValueStorage::Inline { len, .. } => usize::from(len),
            ValueStorage::Heap(ref boxed) => boxed.len(),
        }
    }

    /// Returns `true` when [`len`](Self::len) is zero, which includes tombstones.
    #[inline]
    pub fn is_empty(&self) -> bool {
        let stored = self.len();
        stored == 0
    }
}
172
173#[derive(Debug)]
178pub struct LockFreeVersion {
179 pub storage: ValueStorage,
181 pub txn_id: u64,
183 pub commit_ts: AtomicU64,
185 pub next: AtomicPtr<LockFreeVersion>,
187}
188
189impl LockFreeVersion {
190 #[inline]
192 pub fn new_from_slice(value: Option<&[u8]>, txn_id: u64) -> Self {
193 Self {
194 storage: ValueStorage::new(value),
195 txn_id,
196 commit_ts: AtomicU64::new(0),
197 next: AtomicPtr::new(ptr::null_mut()),
198 }
199 }
200
201 pub fn new(value: Option<Vec<u8>>, txn_id: u64) -> Self {
203 Self::new_from_slice(value.as_deref(), txn_id)
204 }
205
206 #[inline]
208 pub fn get_value(&self) -> Option<&[u8]> {
209 self.storage.as_bytes()
210 }
211
212 #[inline]
216 pub fn value_cloned(&self) -> Option<Vec<u8>> {
217 self.storage.as_bytes().map(|v| v.to_vec())
218 }
219
220 #[inline]
222 pub fn is_committed(&self) -> bool {
223 self.commit_ts.load(Ordering::Acquire) > 0
224 }
225
226 #[inline]
228 pub fn get_commit_ts(&self) -> u64 {
229 self.commit_ts.load(Ordering::Acquire)
230 }
231
232 #[inline]
234 pub fn set_commit_ts(&self, ts: u64) {
235 self.commit_ts.store(ts, Ordering::Release);
236 }
237
238 #[inline]
240 pub fn is_inline(&self) -> bool {
241 self.storage.is_inline()
242 }
243}
244
245pub struct LockFreeVersionChain {
247 head: AtomicPtr<LockFreeVersion>,
249}
250
251impl Default for LockFreeVersionChain {
252 fn default() -> Self {
253 Self::new()
254 }
255}
256
257impl LockFreeVersionChain {
258 pub fn new() -> Self {
260 Self {
261 head: AtomicPtr::new(ptr::null_mut()),
262 }
263 }
264
265 pub fn add_uncommitted(&self, value: Option<Vec<u8>>, txn_id: u64) -> Result<()> {
269 let new_version = Box::into_raw(Box::new(LockFreeVersion::new(value, txn_id)));
270
271 loop {
272 let head = self.head.load(Ordering::Acquire);
273
274 if !head.is_null() {
276 let head_ref = unsafe { &*head };
277 if !head_ref.is_committed() && head_ref.txn_id != txn_id {
278 unsafe {
280 drop(Box::from_raw(new_version));
281 }
282 return Err(SochDBError::Internal("Write-write conflict".into()));
283 }
284 }
285
286 unsafe {
288 (*new_version).next.store(head, Ordering::Release);
289 }
290
291 match self
293 .head
294 .compare_exchange(head, new_version, Ordering::AcqRel, Ordering::Acquire)
295 {
296 Ok(_) => return Ok(()),
297 Err(_) => continue, }
299 }
300 }
301
302 pub fn commit(&self, txn_id: u64, commit_ts: u64) -> bool {
304 let mut current = self.head.load(Ordering::Acquire);
305
306 while !current.is_null() {
307 let version = unsafe { &*current };
308 if version.txn_id == txn_id && !version.is_committed() {
309 version.set_commit_ts(commit_ts);
310 return true;
311 }
312 current = version.next.load(Ordering::Acquire);
313 }
314
315 false
316 }
317
318 pub fn read_at(
323 &self,
324 snapshot_ts: u64,
325 current_txn_id: Option<u64>,
326 ) -> Option<&LockFreeVersion> {
327 let mut current = self.head.load(Ordering::Acquire);
328
329 while !current.is_null() {
330 let version = unsafe { &*current };
331
332 if let Some(txn_id) = current_txn_id
334 && version.txn_id == txn_id
335 && !version.is_committed()
336 {
337 return Some(version);
338 }
339
340 let commit_ts = version.get_commit_ts();
342 if commit_ts > 0 && commit_ts < snapshot_ts {
343 return Some(version);
344 }
345
346 current = version.next.load(Ordering::Acquire);
347 }
348
349 None
350 }
351
352 pub fn has_write_conflict(&self, my_txn_id: u64) -> bool {
354 let current = self.head.load(Ordering::Acquire);
355
356 if !current.is_null() {
357 let version = unsafe { &*current };
358 return !version.is_committed() && version.txn_id != my_txn_id;
359 }
360
361 false
362 }
363}
364
365#[repr(C, align(64))]
369struct HazardRecord {
370 hazard: [AtomicPtr<LockFreeVersion>; HP_PER_THREAD],
372 active: AtomicU64,
374}
375
376impl HazardRecord {
377 const fn new() -> Self {
378 Self {
379 hazard: [
380 AtomicPtr::new(ptr::null_mut()),
381 AtomicPtr::new(ptr::null_mut()),
382 ],
383 active: AtomicU64::new(0),
384 }
385 }
386
387 fn try_acquire(&self, thread_id: u64) -> bool {
389 self.active
390 .compare_exchange(0, thread_id, Ordering::AcqRel, Ordering::Acquire)
391 .is_ok()
392 }
393
394 #[allow(dead_code)]
396 fn release(&self) {
397 for hp in &self.hazard {
399 hp.store(ptr::null_mut(), Ordering::Release);
400 }
401 self.active.store(0, Ordering::Release);
402 }
403}
404
405pub struct HazardDomain {
407 records: Vec<HazardRecord>,
409 retired: Mutex<Vec<*mut LockFreeVersion>>,
411}
412
413impl HazardDomain {
414 pub fn new(max_threads: usize) -> Self {
416 let mut records = Vec::with_capacity(max_threads);
417 for _ in 0..max_threads {
418 records.push(HazardRecord::new());
419 }
420
421 Self {
422 records,
423 retired: Mutex::new(Vec::with_capacity(RECLAMATION_THRESHOLD * 2)),
424 }
425 }
426
427 fn get_record(&self) -> Option<&HazardRecord> {
429 let thread_id = thread_id::get() as u64;
430
431 for record in &self.records {
433 if record.active.load(Ordering::Acquire) == thread_id {
434 return Some(record);
435 }
436 }
437
438 self.records
440 .iter()
441 .find(|record| record.try_acquire(thread_id))
442 }
443
444 #[inline]
446 pub fn protect(&self, ptr: *mut LockFreeVersion, slot: usize) -> bool {
447 if let Some(record) = self.get_record()
448 && slot < HP_PER_THREAD
449 {
450 record.hazard[slot].store(ptr, Ordering::Release);
451 std::sync::atomic::fence(Ordering::SeqCst);
452 return true;
453 }
454 false
455 }
456
457 #[inline]
459 pub fn clear(&self, slot: usize) {
460 if let Some(record) = self.get_record()
461 && slot < HP_PER_THREAD
462 {
463 record.hazard[slot].store(ptr::null_mut(), Ordering::Release);
464 }
465 }
466
467 pub fn retire(&self, ptr: *mut LockFreeVersion) {
469 let mut retired = self.retired.lock();
470 retired.push(ptr);
471
472 if retired.len() >= RECLAMATION_THRESHOLD {
474 self.try_reclaim(&mut retired);
475 }
476 }
477
478 fn try_reclaim(&self, retired: &mut Vec<*mut LockFreeVersion>) {
480 let mut protected: HashSet<usize> = HashSet::new();
482
483 for record in &self.records {
484 if record.active.load(Ordering::Acquire) != 0 {
485 for hp in &record.hazard {
486 let ptr = hp.load(Ordering::Acquire);
487 if !ptr.is_null() {
488 protected.insert(ptr as usize);
489 }
490 }
491 }
492 }
493
494 let mut still_retired = Vec::new();
496 for ptr in retired.drain(..) {
497 if protected.contains(&(ptr as usize)) {
498 still_retired.push(ptr);
499 } else {
500 unsafe {
502 drop(Box::from_raw(ptr));
503 }
504 }
505 }
506
507 *retired = still_retired;
508 }
509}
510
511impl Drop for HazardDomain {
512 fn drop(&mut self) {
513 let mut retired = self.retired.lock();
515 for ptr in retired.drain(..) {
516 unsafe {
517 drop(Box::from_raw(ptr));
518 }
519 }
520 }
521}
522
/// Small process-wide numbering of threads, starting at 1.
mod thread_id {
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Next id to hand out; starts at 1 so that 0 is never assigned.
    static COUNTER: AtomicUsize = AtomicUsize::new(1);

    /// Takes the next unused id from the shared counter.
    fn allocate() -> usize {
        COUNTER.fetch_add(1, Ordering::Relaxed)
    }

    thread_local! {
        /// Id of the current thread, assigned lazily on first access.
        static ASSIGNED: usize = allocate();
    }

    /// Returns the calling thread's id; repeated calls on one thread agree.
    pub fn get() -> usize {
        ASSIGNED.with(|assigned| *assigned)
    }
}
537
538pub struct LockFreeMemTable {
540 data: DashMap<Vec<u8>, LockFreeVersionChain>,
542 hazard_domain: HazardDomain,
544 size_bytes: AtomicUsize,
546}
547
548impl LockFreeMemTable {
549 pub fn new() -> Self {
551 Self {
552 data: DashMap::new(),
553 hazard_domain: HazardDomain::new(MAX_THREADS),
554 size_bytes: AtomicUsize::new(0),
555 }
556 }
557
558 pub fn read(&self, key: &[u8], snapshot_ts: u64, txn_id: Option<u64>) -> Option<Vec<u8>> {
563 let chain = self.data.get(key)?;
564
565 if let Some(version) = chain.read_at(snapshot_ts, txn_id) {
567 let ptr = version as *const LockFreeVersion as *mut LockFreeVersion;
569 self.hazard_domain.protect(ptr, 0);
570
571 let result = version.value_cloned();
574
575 self.hazard_domain.clear(0);
577
578 result
579 } else {
580 None
581 }
582 }
583
584 #[inline]
598 pub fn read_with<F, R>(
599 &self,
600 key: &[u8],
601 snapshot_ts: u64,
602 txn_id: Option<u64>,
603 f: F,
604 ) -> Option<R>
605 where
606 F: FnOnce(&[u8]) -> R,
607 {
608 let chain = self.data.get(key)?;
609
610 if let Some(version) = chain.read_at(snapshot_ts, txn_id) {
611 let ptr = version as *const LockFreeVersion as *mut LockFreeVersion;
613 self.hazard_domain.protect(ptr, 0);
614
615 let result = version.get_value().map(f);
617
618 self.hazard_domain.clear(0);
620
621 result
622 } else {
623 None
624 }
625 }
626
627 pub fn write(&self, key: Vec<u8>, value: Option<Vec<u8>>, txn_id: u64) -> Result<()> {
629 let value_size = value.as_ref().map(|v| v.len()).unwrap_or(0);
630
631 let chain = self.data.entry(key.clone()).or_default();
633
634 chain.add_uncommitted(value, txn_id)?;
636
637 self.size_bytes
639 .fetch_add(key.len() + value_size + 64, Ordering::Relaxed);
640
641 Ok(())
642 }
643
644 pub fn commit(&self, txn_id: u64, commit_ts: u64, keys: &[Vec<u8>]) {
646 for key in keys {
647 if let Some(chain) = self.data.get(key) {
648 chain.commit(txn_id, commit_ts);
649 }
650 }
651 }
652
653 pub fn has_write_conflict(&self, key: &[u8], txn_id: u64) -> bool {
655 if let Some(chain) = self.data.get(key) {
656 chain.has_write_conflict(txn_id)
657 } else {
658 false
659 }
660 }
661
662 pub fn size_bytes(&self) -> usize {
664 self.size_bytes.load(Ordering::Relaxed)
665 }
666
667 pub fn len(&self) -> usize {
669 self.data.len()
670 }
671
672 pub fn is_empty(&self) -> bool {
674 self.data.is_empty()
675 }
676}
677
678unsafe impl Send for LockFreeMemTable {}
682unsafe impl Sync for LockFreeMemTable {}
683
684impl Default for LockFreeMemTable {
685 fn default() -> Self {
686 Self::new()
687 }
688}
689
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Key bytes used by the bulk tests.
    fn key_for(i: u32) -> Vec<u8> {
        format!("key{}", i).into_bytes()
    }

    /// Value bytes paired with `key_for(i)`.
    fn value_for(i: u32) -> Vec<u8> {
        format!("value{}", i).into_bytes()
    }

    /// A transaction sees its own pending write, others do not, and after
    /// commit the value is visible to plain snapshot reads.
    #[test]
    fn test_basic_write_read() {
        let table = LockFreeMemTable::new();
        let key = b"key1".to_vec();
        let payload = b"value1".to_vec();

        table.write(key.clone(), Some(payload.clone()), 1).unwrap();

        // Writer reads its own uncommitted version.
        assert_eq!(table.read(b"key1", 100, Some(1)), Some(payload.clone()));

        // A different transaction must not see it.
        assert!(table.read(b"key1", 100, Some(2)).is_none());

        // Once committed it is visible without a transaction id.
        table.commit(1, 50, &[key]);
        assert_eq!(table.read(b"key1", 100, None), Some(payload));
    }

    /// Readers get the newest version committed before their snapshot.
    #[test]
    fn test_snapshot_isolation() {
        let table = LockFreeMemTable::new();
        let keys = [b"key".to_vec()];

        for (txn, ts, payload) in [(1u64, 10u64, &b"v1"[..]), (2, 20, &b"v2"[..])] {
            table
                .write(keys[0].clone(), Some(payload.to_vec()), txn)
                .unwrap();
            table.commit(txn, ts, &keys);
        }

        let at_15 = table.read(b"key", 15, None);
        assert_eq!(at_15, Some(b"v1".to_vec()));

        let at_25 = table.read(b"key", 25, None);
        assert_eq!(at_25, Some(b"v2".to_vec()));
    }

    /// A second transaction cannot write over another's pending version, but
    /// the owning transaction may write again.
    #[test]
    fn test_write_conflict() {
        let table = LockFreeMemTable::new();

        table
            .write(b"key".to_vec(), Some(b"v1".to_vec()), 1)
            .unwrap();

        let foreign = table.write(b"key".to_vec(), Some(b"v2".to_vec()), 2);
        assert!(foreign.is_err());

        let own = table.write(b"key".to_vec(), Some(b"v1_updated".to_vec()), 1);
        assert!(own.is_ok());
    }

    /// Eight threads read one hundred committed keys concurrently.
    #[test]
    fn test_concurrent_reads() {
        let table = Arc::new(LockFreeMemTable::new());

        let keys: Vec<Vec<u8>> = (0..100).map(key_for).collect();
        for (i, key) in (0..100).zip(&keys) {
            table.write(key.clone(), Some(value_for(i)), 1).unwrap();
        }
        table.commit(1, 10, &keys);

        // Spawn every reader before joining any of them.
        let mut workers = Vec::with_capacity(8);
        for t in 0..8 {
            let mt = Arc::clone(&table);
            workers.push(thread::spawn(move || {
                for i in 0..100 {
                    let val = mt.read(&key_for(i), 100, None);
                    assert_eq!(val, Some(value_for(i)), "Thread {} failed at key{}", t, i);
                }
            }));
        }

        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Small values are inline, large ones are on the heap, `None` is a tombstone.
    #[test]
    fn test_inline_storage() {
        let small = b"small".to_vec();
        let inline_version = LockFreeVersion::new(Some(small.clone()), 1);
        assert!(inline_version.is_inline(), "Small values should be inline");
        assert_eq!(inline_version.get_value(), Some(small.as_slice()));

        let large = vec![42u8; 100];
        let heap_version = LockFreeVersion::new(Some(large.clone()), 2);
        assert!(!heap_version.is_inline(), "Large values should be on heap");
        assert_eq!(heap_version.get_value(), Some(large.as_slice()));

        let tombstone = LockFreeVersion::new(None, 3);
        assert!(tombstone.storage.is_tombstone());
        assert_eq!(tombstone.get_value(), None);
    }

    /// The inline/heap boundary sits exactly at `INLINE_VALUE_SIZE`.
    #[test]
    fn test_inline_threshold() {
        let at_limit = LockFreeVersion::new(Some(vec![0u8; INLINE_VALUE_SIZE]), 1);
        assert!(at_limit.is_inline(), "Values at threshold should be inline");

        let over_limit = LockFreeVersion::new(Some(vec![0u8; INLINE_VALUE_SIZE + 1]), 2);
        assert!(
            !over_limit.is_inline(),
            "Values over threshold should be on heap"
        );
    }

    /// `read_with` hands the stored bytes to the callback and returns its result.
    #[test]
    fn test_read_with_callback() {
        let table = LockFreeMemTable::new();

        table
            .write(b"key1".to_vec(), Some(b"value1".to_vec()), 1)
            .unwrap();
        table.commit(1, 10, &[b"key1".to_vec()]);

        let len = table.read_with(b"key1", 100, None, |v| v.len());
        assert_eq!(len, Some(6));

        let matches = table.read_with(b"key1", 100, None, |v| v == b"value1");
        assert_eq!(matches, Some(true));
    }
}