1use byteorder::{LittleEndian, ReadBytesExt};
52use parking_lot::Mutex;
53use sochdb_core::{Result, SochDBError, WalRecordType};
54use std::cell::Cell;
55use std::collections::HashSet;
56use std::fs::{File, OpenOptions};
57use std::io::{BufReader, BufWriter, Read, Write};
58use std::path::{Path, PathBuf};
59use std::sync::atomic::{AtomicU64, Ordering};
60use std::time::{Instant, SystemTime, UNIX_EPOCH};
61
/// How long, in nanoseconds, a thread's cached wall-clock sample is reused
/// before `cached_timestamp_us` reads the system clock again (1 ms).
const CACHE_VALIDITY_NS: u64 = 1_000_000;

thread_local! {
    /// Per-thread `(monotonic anchor, wall-clock microseconds at the anchor)`.
    ///
    /// The wall-clock half starts at `0`, which marks the cache as never filled.
    static TS_CACHE: Cell<(Instant, u64)> = Cell::new((Instant::now(), 0));
}

/// Returns the current time as microseconds since the UNIX epoch, using a
/// per-thread cache to avoid reading the system clock on every call.
///
/// While the cached sample is younger than `CACHE_VALIDITY_NS`, the result is
/// the cached wall-clock value plus the monotonic time elapsed since it was
/// taken. Otherwise the system clock is read and the cache is refreshed.
///
/// # Panics
///
/// Panics if the system clock reports a time before the UNIX epoch.
#[inline(always)]
pub fn cached_timestamp_us() -> u64 {
    TS_CACHE.with(|cache| {
        let (anchor, anchor_ts) = cache.get();
        let elapsed_ns = anchor.elapsed().as_nanos() as u64;

        // The thread-local is created lazily with a wall-clock value of 0, so
        // a freshly created cache looks "young". Without the `anchor_ts != 0`
        // guard the first calls on every thread would return ~0 instead of the
        // real time.
        if anchor_ts != 0 && elapsed_ns < CACHE_VALIDITY_NS {
            anchor_ts + elapsed_ns / 1000
        } else {
            let now_us = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("system clock set before UNIX epoch (1970-01-01)")
                .as_micros() as u64;
            cache.set((Instant::now(), now_us));
            now_us
        }
    })
}
112
/// Bytes that precede the key in a serialized record: length prefix (u32),
/// record type (u8), transaction id (u64), timestamp (u64), key length (u32)
/// and value length (u32).
const RECORD_HEADER_SIZE: usize = std::mem::size_of::<u32>()
    + std::mem::size_of::<u8>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u32>();

/// Bytes occupied by the trailing CRC32 of every record.
const CHECKSUM_SIZE: usize = std::mem::size_of::<u32>();

/// Initial capacity, in bytes, of a `TxnWalBuffer` created with `new` (32 KiB).
const DEFAULT_TXN_BUFFER_CAPACITY: usize = 1 << 15;
121
122#[derive(Debug)]
159pub struct TxnWalBuffer {
160 txn_id: u64,
162 buffer: Vec<u8>,
164 entry_count: usize,
166}
167
168impl TxnWalBuffer {
169 #[inline]
171 pub fn new(txn_id: u64) -> Self {
172 Self {
173 txn_id,
174 buffer: Vec::with_capacity(DEFAULT_TXN_BUFFER_CAPACITY),
175 entry_count: 0,
176 }
177 }
178
179 #[inline]
181 pub fn with_capacity(txn_id: u64, capacity: usize) -> Self {
182 Self {
183 txn_id,
184 buffer: Vec::with_capacity(capacity),
185 entry_count: 0,
186 }
187 }
188
189 #[inline]
196 pub fn append(&mut self, key: &[u8], value: &[u8]) {
197 let timestamp_us = cached_timestamp_us();
199
200 let total_len = RECORD_HEADER_SIZE + key.len() + value.len() + CHECKSUM_SIZE;
201 let entry_start = self.buffer.len();
202
203 self.buffer.extend_from_slice(&[0u8; 4]);
205
206 let mut hasher = crc32fast::Hasher::new();
207
208 let record_type_byte = WalRecordType::Data as u8;
210 self.buffer.push(record_type_byte);
211 hasher.update(&[record_type_byte]);
212
213 let txn_bytes = self.txn_id.to_le_bytes();
215 self.buffer.extend_from_slice(&txn_bytes);
216 hasher.update(&txn_bytes);
217
218 let ts_bytes = timestamp_us.to_le_bytes();
220 self.buffer.extend_from_slice(&ts_bytes);
221 hasher.update(&ts_bytes);
222
223 let key_len_bytes = (key.len() as u32).to_le_bytes();
225 self.buffer.extend_from_slice(&key_len_bytes);
226 hasher.update(&key_len_bytes);
227
228 let val_len_bytes = (value.len() as u32).to_le_bytes();
230 self.buffer.extend_from_slice(&val_len_bytes);
231 hasher.update(&val_len_bytes);
232
233 self.buffer.extend_from_slice(key);
235 hasher.update(key);
236
237 self.buffer.extend_from_slice(value);
239 hasher.update(value);
240
241 self.buffer
243 .extend_from_slice(&hasher.finalize().to_le_bytes());
244
245 let content_len = (total_len - 4) as u32;
247 self.buffer[entry_start..entry_start + 4].copy_from_slice(&content_len.to_le_bytes());
248
249 self.entry_count += 1;
250 }
251
252 #[inline]
260 pub fn flush_to_wal(&self, wal: &TxnWal) -> Result<u64> {
261 wal.flush_buffer(self)
262 }
263
264 #[inline]
266 pub fn clear(&mut self) {
267 self.buffer.clear();
268 self.entry_count = 0;
269 }
270
271 #[inline]
273 pub fn entry_count(&self) -> usize {
274 self.entry_count
275 }
276
277 #[inline]
279 pub fn bytes_buffered(&self) -> usize {
280 self.buffer.len()
281 }
282
283 #[inline]
285 pub fn is_empty(&self) -> bool {
286 self.buffer.is_empty()
287 }
288}
289
290#[derive(Debug, Clone)]
292pub struct TxnWalEntry {
293 pub record_type: WalRecordType,
295 pub txn_id: u64,
297 pub timestamp_us: u64,
299 pub key: Vec<u8>,
301 pub value: Vec<u8>,
303}
304
305impl TxnWalEntry {
306 pub fn data(txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Self {
308 Self {
309 record_type: WalRecordType::Data,
310 txn_id,
311 timestamp_us: Self::now_us(),
312 key,
313 value,
314 }
315 }
316
317 pub fn txn_begin(txn_id: u64) -> Self {
319 Self {
320 record_type: WalRecordType::TxnBegin,
321 txn_id,
322 timestamp_us: Self::now_us(),
323 key: Vec::new(),
324 value: Vec::new(),
325 }
326 }
327
328 pub fn txn_commit(txn_id: u64) -> Self {
330 Self {
331 record_type: WalRecordType::TxnCommit,
332 txn_id,
333 timestamp_us: Self::now_us(),
334 key: Vec::new(),
335 value: Vec::new(),
336 }
337 }
338
339 pub fn txn_abort(txn_id: u64) -> Self {
341 Self {
342 record_type: WalRecordType::TxnAbort,
343 txn_id,
344 timestamp_us: Self::now_us(),
345 key: Vec::new(),
346 value: Vec::new(),
347 }
348 }
349
350 pub fn checkpoint(txn_id: u64) -> Self {
352 Self {
353 record_type: WalRecordType::Checkpoint,
354 txn_id,
355 timestamp_us: Self::now_us(),
356 key: Vec::new(),
357 value: Vec::new(),
358 }
359 }
360
361 pub fn schema_change(txn_id: u64, schema_data: Vec<u8>) -> Self {
363 Self {
364 record_type: WalRecordType::SchemaChange,
365 txn_id,
366 timestamp_us: Self::now_us(),
367 key: Vec::new(),
368 value: schema_data,
369 }
370 }
371
372 #[inline]
374 fn now_us() -> u64 {
375 cached_timestamp_us()
376 }
377
378 pub fn checksum(&self) -> u32 {
384 let mut hasher = crc32fast::Hasher::new();
385 hasher.update(&[self.record_type as u8]);
386 hasher.update(&self.txn_id.to_le_bytes());
387 hasher.update(&self.timestamp_us.to_le_bytes());
388 hasher.update(&(self.key.len() as u32).to_le_bytes());
389 hasher.update(&(self.value.len() as u32).to_le_bytes());
390 hasher.update(&self.key);
391 hasher.update(&self.value);
392 hasher.finalize()
393 }
394
395 pub fn to_bytes(&self) -> Vec<u8> {
400 let total_len = RECORD_HEADER_SIZE + self.key.len() + self.value.len() + CHECKSUM_SIZE;
401 let mut buf = Vec::with_capacity(total_len);
402 let mut hasher = crc32fast::Hasher::new();
403
404 let content_len = (total_len - 4) as u32;
406 buf.extend_from_slice(&content_len.to_le_bytes());
407
408 let record_type_byte = self.record_type as u8;
410 buf.push(record_type_byte);
411 hasher.update(&[record_type_byte]);
412
413 let txn_bytes = self.txn_id.to_le_bytes();
415 buf.extend_from_slice(&txn_bytes);
416 hasher.update(&txn_bytes);
417
418 let ts_bytes = self.timestamp_us.to_le_bytes();
420 buf.extend_from_slice(&ts_bytes);
421 hasher.update(&ts_bytes);
422
423 let key_len_bytes = (self.key.len() as u32).to_le_bytes();
425 buf.extend_from_slice(&key_len_bytes);
426 hasher.update(&key_len_bytes);
427
428 let val_len_bytes = (self.value.len() as u32).to_le_bytes();
430 buf.extend_from_slice(&val_len_bytes);
431 hasher.update(&val_len_bytes);
432
433 buf.extend_from_slice(&self.key);
435 hasher.update(&self.key);
436
437 buf.extend_from_slice(&self.value);
439 hasher.update(&self.value);
440
441 buf.extend_from_slice(&hasher.finalize().to_le_bytes());
443
444 buf
445 }
446
447 pub fn from_reader<R: Read>(reader: &mut R) -> Result<Self> {
454 let content_len = reader.read_u32::<LittleEndian>()?;
456 if content_len < (RECORD_HEADER_SIZE - 4 + CHECKSUM_SIZE) as u32 {
457 return Err(SochDBError::Corruption("WAL entry too short".into()));
458 }
459
460 let record_type_byte = reader.read_u8()?;
462 let record_type = WalRecordType::try_from(record_type_byte).map_err(|_| {
463 SochDBError::Corruption(format!("Invalid record type: {}", record_type_byte))
464 })?;
465
466 let txn_id = reader.read_u64::<LittleEndian>()?;
468
469 let timestamp_us = reader.read_u64::<LittleEndian>()?;
471
472 let key_len = reader.read_u32::<LittleEndian>()? as usize;
474
475 let value_len = reader.read_u32::<LittleEndian>()? as usize;
477
478 let mut key = vec![0u8; key_len];
480 reader.read_exact(&mut key)?;
481
482 let mut value = vec![0u8; value_len];
484 reader.read_exact(&mut value)?;
485
486 let stored_checksum = reader.read_u32::<LittleEndian>()?;
488
489 let entry = Self {
490 record_type,
491 txn_id,
492 timestamp_us,
493 key,
494 value,
495 };
496
497 if entry.checksum() != stored_checksum {
499 return Err(SochDBError::Corruption(format!(
500 "WAL checksum mismatch for txn_id {}: expected {}, got {}",
501 txn_id,
502 entry.checksum(),
503 stored_checksum
504 )));
505 }
506
507 Ok(entry)
508 }
509}
510
/// Append-only, file-backed write-ahead log with transaction markers.
pub struct TxnWal {
    /// Location of the log file; reopened read-only for replay and recovery.
    path: PathBuf,
    /// Buffered append handle; the mutex serializes all writers.
    writer: Mutex<BufWriter<File>>,
    /// Next transaction id handed out by `alloc_txn_id`.
    next_txn_id: AtomicU64,
    /// Count of records seen during startup scan plus records appended since.
    sequence: AtomicU64,
    /// Bytes handed to the writer since the last `sync` or `truncate`.
    bytes_since_sync: AtomicU64,
    /// Timestamp reused by `get_cached_timestamp` between refreshes.
    cached_timestamp_us: AtomicU64,
}
527
528impl TxnWal {
529 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
531 let path = path.as_ref().to_path_buf();
532
533 if let Some(parent) = path.parent() {
535 std::fs::create_dir_all(parent)?;
536 }
537
538 let file = OpenOptions::new()
539 .create(true)
540 .append(true)
541 .read(true)
542 .open(&path)?;
543
544 let now_us = cached_timestamp_us();
547
548 let wal = Self {
549 path,
550 writer: Mutex::new(BufWriter::with_capacity(256 * 1024, file)),
551 next_txn_id: AtomicU64::new(1),
552 sequence: AtomicU64::new(0),
553 bytes_since_sync: AtomicU64::new(0),
554 cached_timestamp_us: AtomicU64::new(now_us),
555 };
556
557 wal.recover_state()?;
559
560 Ok(wal)
561 }
562
563 fn recover_state(&self) -> Result<()> {
570 let file = File::open(&self.path)?;
571 let mut reader = BufReader::new(file);
572 let mut count: u64 = 0;
573
574 let our_pid = std::process::id() as u64;
579 let pid_base = our_pid << 32;
580 let mut max_our_counter: u64 = 0;
581
582 loop {
583 match TxnWalEntry::from_reader(&mut reader) {
584 Ok(entry) => {
585 count += 1;
586 let entry_pid = entry.txn_id >> 32;
588 if entry_pid == our_pid {
589 let entry_counter = entry.txn_id & 0xFFFF_FFFF;
590 if entry_counter > max_our_counter {
591 max_our_counter = entry_counter;
592 }
593 }
594 }
595 Err(SochDBError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
596 break;
597 }
598 Err(_) => {
599 break;
601 }
602 }
603 }
604
605 let next_id = pid_base + max_our_counter + 1;
611
612 self.next_txn_id.store(next_id, Ordering::SeqCst);
613 self.sequence.store(count, Ordering::SeqCst);
614
615 Ok(())
616 }
617
618 #[inline]
623 fn get_cached_timestamp(&self) -> u64 {
624 let cached = self.cached_timestamp_us.load(Ordering::Relaxed);
626
627 let seq = self.sequence.load(Ordering::Relaxed);
630 if seq & 0x3FF == 0 {
631 let now_us = cached_timestamp_us();
633 self.cached_timestamp_us.store(now_us, Ordering::Relaxed);
634 return now_us;
635 }
636
637 cached
638 }
639
640 pub fn append(&self, entry: &TxnWalEntry) -> Result<u64> {
656 let bytes = entry.to_bytes();
657 let mut writer = self.writer.lock();
658
659 writer.write_all(&bytes)?;
660 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
664 self.bytes_since_sync
665 .fetch_add(bytes.len() as u64, Ordering::Relaxed);
666
667 Ok(seq)
668 }
669
670 #[inline]
674 pub fn append_no_flush(&self, entry: &TxnWalEntry) -> Result<u64> {
675 let bytes = entry.to_bytes();
676 let mut writer = self.writer.lock();
677
678 writer.write_all(&bytes)?;
679 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
682 self.bytes_since_sync
683 .fetch_add(bytes.len() as u64, Ordering::Relaxed);
684
685 Ok(seq)
686 }
687
688 #[inline]
690 pub fn write_no_flush(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<u64> {
691 let entry = TxnWalEntry::data(txn_id, key, value);
692 self.append_no_flush(&entry)
693 }
694
695 #[inline]
700 pub fn write_no_flush_refs(&self, txn_id: u64, key: &[u8], value: &[u8]) -> Result<u64> {
701 let timestamp_us = self.get_cached_timestamp();
703
704 let total_len = RECORD_HEADER_SIZE + key.len() + value.len() + CHECKSUM_SIZE;
705 let mut hasher = crc32fast::Hasher::new();
706
707 let mut writer = self.writer.lock();
708
709 let content_len = (total_len - 4) as u32;
711 writer.write_all(&content_len.to_le_bytes())?;
712
713 let record_type_byte = WalRecordType::Data as u8;
715 writer.write_all(&[record_type_byte])?;
716 hasher.update(&[record_type_byte]);
717
718 let txn_bytes = txn_id.to_le_bytes();
720 writer.write_all(&txn_bytes)?;
721 hasher.update(&txn_bytes);
722
723 let ts_bytes = timestamp_us.to_le_bytes();
725 writer.write_all(&ts_bytes)?;
726 hasher.update(&ts_bytes);
727
728 let key_len_bytes = (key.len() as u32).to_le_bytes();
730 writer.write_all(&key_len_bytes)?;
731 hasher.update(&key_len_bytes);
732
733 let val_len_bytes = (value.len() as u32).to_le_bytes();
735 writer.write_all(&val_len_bytes)?;
736 hasher.update(&val_len_bytes);
737
738 writer.write_all(key)?;
740 hasher.update(key);
741
742 writer.write_all(value)?;
744 hasher.update(value);
745
746 writer.write_all(&hasher.finalize().to_le_bytes())?;
748
749 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
750 self.bytes_since_sync
751 .fetch_add(total_len as u64, Ordering::Relaxed);
752
753 Ok(seq)
754 }
755
756 pub fn flush(&self) -> Result<()> {
758 let mut writer = self.writer.lock();
759 writer.flush()?;
760 Ok(())
761 }
762
763 pub fn append_sync(&self, entry: &TxnWalEntry) -> Result<u64> {
765 let seq = self.append(entry)?;
766 self.sync()?;
767 Ok(seq)
768 }
769
770 pub fn sync(&self) -> Result<()> {
778 let mut writer = self.writer.lock();
779 writer.flush()?;
780 writer.get_ref().sync_all()?;
781 self.bytes_since_sync.store(0, Ordering::Relaxed);
782 Ok(())
783 }
784
785 #[inline]
799 pub fn flush_buffer(&self, buffer: &TxnWalBuffer) -> Result<u64> {
800 if buffer.is_empty() {
801 return Ok(0);
802 }
803
804 let mut writer = self.writer.lock();
805 writer.write_all(&buffer.buffer)?;
806
807 let seq = self
808 .sequence
809 .fetch_add(buffer.entry_count as u64, Ordering::SeqCst);
810 self.bytes_since_sync
811 .fetch_add(buffer.buffer.len() as u64, Ordering::Relaxed);
812
813 Ok(seq)
814 }
815
816 pub fn size_bytes(&self) -> u64 {
818 std::fs::metadata(&self.path).map(|m| m.len()).unwrap_or(0)
819 }
820
821 pub fn alloc_txn_id(&self) -> u64 {
823 self.next_txn_id.fetch_add(1, Ordering::SeqCst)
824 }
825
826 pub fn begin_transaction(&self) -> Result<u64> {
828 let txn_id = self.alloc_txn_id();
829 let entry = TxnWalEntry::txn_begin(txn_id);
830 self.append(&entry)?;
831 Ok(txn_id)
832 }
833
834 pub fn commit_transaction(&self, txn_id: u64) -> Result<()> {
850 self.flush()?;
852
853 let entry = TxnWalEntry::txn_commit(txn_id);
855 self.append_sync(&entry)?;
856 Ok(())
857 }
858
859 pub fn commit_durable_batch(&self, txn_ids: &[u64]) -> Result<()> {
877 for &txn_id in txn_ids {
879 let entry = TxnWalEntry::txn_commit(txn_id);
880 self.append_no_flush(&entry)?;
881 }
882
883 self.flush()?;
885 self.sync()?;
886 Ok(())
887 }
888
889 pub fn abort_transaction(&self, txn_id: u64) -> Result<()> {
891 let entry = TxnWalEntry::txn_abort(txn_id);
892 self.append(&entry)?;
893 Ok(())
894 }
895
896 pub fn write(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<u64> {
898 let entry = TxnWalEntry::data(txn_id, key, value);
899 self.append(&entry)
900 }
901
902 #[allow(clippy::type_complexity)]
906 pub fn replay_for_recovery(&self) -> Result<(Vec<(Vec<u8>, Vec<u8>)>, usize)> {
907 let file = File::open(&self.path)?;
908 let mut reader = BufReader::new(file);
909
910 let mut pending_writes: std::collections::HashMap<u64, Vec<(Vec<u8>, Vec<u8>)>> =
911 std::collections::HashMap::new();
912 let mut result = Vec::new();
913 let mut txn_count = 0;
914
915 loop {
920 match TxnWalEntry::from_reader(&mut reader) {
921 Ok(entry) => match entry.record_type {
922 WalRecordType::TxnBegin => {
923 pending_writes.insert(entry.txn_id, Vec::new());
924 }
925 WalRecordType::Data => {
926 pending_writes
930 .entry(entry.txn_id)
931 .or_insert_with(Vec::new)
932 .push((entry.key, entry.value));
933 }
934 WalRecordType::TxnCommit => {
935 if let Some(writes) = pending_writes.remove(&entry.txn_id) {
936 result.extend(writes);
937 txn_count += 1;
938 }
939 }
940 WalRecordType::TxnAbort => {
941 pending_writes.remove(&entry.txn_id);
942 }
943 _ => {}
944 },
945 Err(SochDBError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
946 break;
947 }
948 Err(_) => {
949 break;
950 }
951 }
952 }
953
954 Ok((result, txn_count))
955 }
956
957 pub fn replay<F>(&self, mut callback: F) -> Result<u64>
959 where
960 F: FnMut(TxnWalEntry) -> Result<()>,
961 {
962 let file = File::open(&self.path)?;
963 let mut reader = BufReader::new(file);
964 let mut count = 0u64;
965
966 loop {
967 match TxnWalEntry::from_reader(&mut reader) {
968 Ok(entry) => {
969 callback(entry)?;
970 count += 1;
971 }
972 Err(SochDBError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
973 break;
974 }
975 Err(e) => {
976 eprintln!("WAL replay warning: {:?}", e);
978 break;
979 }
980 }
981 }
982
983 Ok(count)
984 }
985
986 pub fn truncate(&self) -> Result<()> {
997 let mut writer = self.writer.lock();
998 writer.flush()?;
1000 let file = writer.get_ref();
1001 file.set_len(0)?;
1002 file.sync_all()?;
1003 self.sequence.store(0, Ordering::SeqCst);
1004 self.bytes_since_sync.store(0, Ordering::Relaxed);
1005 Ok(())
1006 }
1007
1008 pub fn write_checkpoint(&self) -> Result<u64> {
1010 let entry = TxnWalEntry::checkpoint(0);
1011 self.append_sync(&entry)
1012 }
1013
1014 pub fn append_clr(
1020 &self,
1021 txn_id: u64,
1022 _original_lsn: u64,
1023 undo_next_lsn: Option<u64>,
1024 undo_data: &[u8],
1025 ) -> Result<u64> {
1026 let key = undo_next_lsn.unwrap_or(0).to_le_bytes().to_vec();
1028 let entry = TxnWalEntry {
1029 record_type: WalRecordType::CompensationLogRecord,
1030 txn_id,
1031 timestamp_us: TxnWalEntry::now_us(),
1032 key, value: undo_data.to_vec(),
1034 };
1035 self.append(&entry)
1036 }
1037
1038 pub fn write_checkpoint_with_data(&self, checkpoint_data: &[u8]) -> Result<u64> {
1040 let entry = TxnWalEntry {
1041 record_type: WalRecordType::Checkpoint,
1042 txn_id: 0,
1043 timestamp_us: TxnWalEntry::now_us(),
1044 key: Vec::new(),
1045 value: checkpoint_data.to_vec(),
1046 };
1047 self.append_sync(&entry)
1048 }
1049
1050 pub fn write_checkpoint_end(&self, checkpoint_data: &[u8]) -> Result<u64> {
1052 let entry = TxnWalEntry {
1053 record_type: WalRecordType::CheckpointEnd,
1054 txn_id: 0,
1055 timestamp_us: TxnWalEntry::now_us(),
1056 key: Vec::new(),
1057 value: checkpoint_data.to_vec(),
1058 };
1059 self.append_sync(&entry)
1060 }
1061
1062 pub fn sequence(&self) -> u64 {
1064 self.sequence.load(Ordering::SeqCst)
1065 }
1066
1067 pub fn bytes_since_sync(&self) -> u64 {
1069 self.bytes_since_sync.load(Ordering::Relaxed)
1070 }
1071
1072 pub fn path(&self) -> &Path {
1074 &self.path
1075 }
1076}
1077
/// Snapshot of a `TxnWal`'s counters, as produced by `TxnWal::stats`.
#[derive(Debug, Clone, Default)]
pub struct TxnWalStats {
    /// Value of the WAL's record sequence counter.
    pub entries_written: u64,
    /// Bytes handed to the writer since the last sync.
    pub bytes_since_sync: u64,
    /// Transaction id the next allocation will return.
    pub next_txn_id: u64,
}
1088
/// WAL variant that stages records in per-shard in-memory buffers and copies
/// them into a single file when flushed.
#[allow(dead_code)]
pub struct ShardedWal {
    /// One staging buffer per shard, each behind its own mutex.
    shards: Vec<parking_lot::Mutex<WalShard>>,
    /// Number of shards; always a power of two so it can be used as a mask.
    num_shards: usize,
    /// Buffered append handle for the shared log file.
    central_writer: parking_lot::Mutex<BufWriter<File>>,
    /// Next transaction id handed out by `alloc_txn_id`.
    next_txn_id: AtomicU64,
    /// Number of records appended to any shard so far.
    sequence: AtomicU64,
    /// Location of the shared log file.
    path: PathBuf,
}
1116
1117struct WalShard {
1119 buffer: Vec<u8>,
1121 entry_count: usize,
1123}
1124
1125impl WalShard {
1126 fn new() -> Self {
1127 Self {
1128 buffer: Vec::with_capacity(64 * 1024), entry_count: 0,
1130 }
1131 }
1132
1133 fn append(&mut self, entry: &TxnWalEntry) {
1134 let bytes = entry.to_bytes();
1135 self.buffer.extend_from_slice(&bytes);
1136 self.entry_count += 1;
1137 }
1138
1139 fn is_empty(&self) -> bool {
1140 self.buffer.is_empty()
1141 }
1142
1143 fn drain(&mut self) -> Vec<u8> {
1144 self.entry_count = 0;
1145 std::mem::take(&mut self.buffer)
1146 }
1147}
1148
1149impl ShardedWal {
1150 pub fn new<P: AsRef<Path>>(path: P, num_shards: usize) -> Result<Self> {
1154 let path = path.as_ref().to_path_buf();
1155
1156 if let Some(parent) = path.parent() {
1157 std::fs::create_dir_all(parent)?;
1158 }
1159
1160 let file = std::fs::OpenOptions::new()
1161 .create(true)
1162 .append(true)
1163 .read(true)
1164 .open(&path)?;
1165
1166 let num_shards = num_shards.next_power_of_two();
1168 let shards: Vec<_> = (0..num_shards)
1169 .map(|_| parking_lot::Mutex::new(WalShard::new()))
1170 .collect();
1171
1172 Ok(Self {
1173 shards,
1174 num_shards,
1175 central_writer: parking_lot::Mutex::new(BufWriter::with_capacity(256 * 1024, file)),
1176 next_txn_id: AtomicU64::new(1),
1177 sequence: AtomicU64::new(0),
1178 path,
1179 })
1180 }
1181
1182 #[inline]
1184 fn shard_idx(&self, txn_id: u64) -> usize {
1185 (txn_id as usize) & (self.num_shards - 1)
1186 }
1187
1188 pub fn append(&self, entry: &TxnWalEntry) -> u64 {
1190 let shard_idx = self.shard_idx(entry.txn_id);
1191 let mut shard = self.shards[shard_idx].lock();
1192 shard.append(entry);
1193 self.sequence.fetch_add(1, Ordering::SeqCst)
1194 }
1195
1196 pub fn alloc_txn_id(&self) -> u64 {
1198 self.next_txn_id.fetch_add(1, Ordering::SeqCst)
1199 }
1200
1201 pub fn flush(&self) -> Result<()> {
1203 let mut central = self.central_writer.lock();
1204
1205 for shard in &self.shards {
1207 let mut shard_guard = shard.lock();
1208 if !shard_guard.is_empty() {
1209 let data = shard_guard.drain();
1210 central.write_all(&data)?;
1211 }
1212 }
1213
1214 central.flush()?;
1215 Ok(())
1216 }
1217
1218 pub fn sync(&self) -> Result<()> {
1220 self.flush()?;
1221 let central = self.central_writer.lock();
1222 central.get_ref().sync_all()?;
1223 Ok(())
1224 }
1225
1226 pub fn begin_transaction(&self) -> Result<u64> {
1228 let txn_id = self.alloc_txn_id();
1229 let entry = TxnWalEntry::txn_begin(txn_id);
1230 self.append(&entry);
1231 Ok(txn_id)
1232 }
1233
1234 pub fn write(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<u64> {
1236 let entry = TxnWalEntry::data(txn_id, key, value);
1237 Ok(self.append(&entry))
1238 }
1239
1240 pub fn commit_transaction(&self, txn_id: u64) -> Result<u64> {
1242 let entry = TxnWalEntry::txn_commit(txn_id);
1243 let seq = self.append(&entry);
1244 self.sync()?; Ok(seq)
1246 }
1247
1248 pub fn stats(&self) -> ShardedWalStats {
1250 let mut shard_entry_counts = Vec::with_capacity(self.num_shards);
1251 for shard in &self.shards {
1252 shard_entry_counts.push(shard.lock().entry_count);
1253 }
1254
1255 ShardedWalStats {
1256 num_shards: self.num_shards,
1257 total_entries: self.sequence.load(Ordering::SeqCst),
1258 next_txn_id: self.next_txn_id.load(Ordering::SeqCst),
1259 shard_entry_counts,
1260 }
1261 }
1262}
1263
/// Snapshot of a `ShardedWal`'s counters, as produced by `ShardedWal::stats`.
#[derive(Debug, Clone)]
pub struct ShardedWalStats {
    /// Number of shards (a power of two).
    pub num_shards: usize,
    /// Value of the record sequence counter.
    pub total_entries: u64,
    /// Transaction id the next allocation will return.
    pub next_txn_id: u64,
    /// Records currently staged in each shard, indexed by shard.
    pub shard_entry_counts: Vec<usize>,
}
1272
/// Figures collected by `TxnWal::crash_recovery` while scanning the log.
#[derive(Debug, Clone, Default)]
pub struct CrashRecoveryStats {
    /// Records decoded successfully.
    pub total_records: u64,
    /// Transactions whose writes were recovered.
    pub committed_txns: u64,
    /// Transactions that were begun but neither committed nor aborted.
    pub rolled_back_txns: u64,
    /// Distinct transaction ids that have a `TxnAbort` record.
    pub aborted_txns: u64,
    /// Key/value writes returned to the caller.
    pub recovered_writes: u64,
    /// Set to 1 when the scan stopped on an undecodable record rather than on
    /// end-of-file.
    pub torn_records: u64,
    /// Size of the log file when recovery started.
    pub bytes_read: u64,
    /// Wall time spent in recovery, in microseconds.
    pub recovery_duration_us: u64,
    /// Largest transaction id seen on any decoded record.
    pub max_txn_id: u64,
}
1295
1296impl TxnWal {
1297 pub fn stats(&self) -> TxnWalStats {
1299 TxnWalStats {
1300 entries_written: self.sequence.load(Ordering::SeqCst),
1301 bytes_since_sync: self.bytes_since_sync.load(Ordering::Relaxed),
1302 next_txn_id: self.next_txn_id.load(Ordering::SeqCst),
1303 }
1304 }
1305
1306 #[allow(clippy::type_complexity)]
1316 pub fn crash_recovery(&self) -> Result<(Vec<(Vec<u8>, Vec<u8>)>, CrashRecoveryStats)> {
1317 let start_time = std::time::Instant::now();
1318 let file = File::open(&self.path)?;
1319 let file_size = file.metadata()?.len();
1320 let mut reader = BufReader::new(file);
1321
1322 let mut stats = CrashRecoveryStats {
1323 bytes_read: file_size,
1324 ..Default::default()
1325 };
1326
1327 let mut committed_txns: HashSet<u64> = HashSet::new();
1328 let mut aborted_txns: HashSet<u64> = HashSet::new();
1329 let mut pending_writes: std::collections::HashMap<u64, Vec<(Vec<u8>, Vec<u8>)>> =
1330 std::collections::HashMap::new();
1331 let mut all_txns: HashSet<u64> = HashSet::new();
1332
1333 loop {
1335 match TxnWalEntry::from_reader(&mut reader) {
1336 Ok(entry) => {
1337 stats.total_records += 1;
1338 if entry.txn_id > stats.max_txn_id {
1339 stats.max_txn_id = entry.txn_id;
1340 }
1341
1342 match entry.record_type {
1343 WalRecordType::TxnBegin => {
1344 pending_writes.insert(entry.txn_id, Vec::new());
1345 all_txns.insert(entry.txn_id);
1346 }
1347 WalRecordType::Data => {
1348 if let Some(writes) = pending_writes.get_mut(&entry.txn_id) {
1349 writes.push((entry.key, entry.value));
1350 }
1351 }
1352 WalRecordType::TxnCommit => {
1353 committed_txns.insert(entry.txn_id);
1354 }
1355 WalRecordType::TxnAbort => {
1356 pending_writes.remove(&entry.txn_id);
1357 aborted_txns.insert(entry.txn_id);
1358 }
1359 _ => {}
1360 }
1361 }
1362 Err(SochDBError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
1363 break;
1365 }
1366 Err(_) => {
1367 stats.torn_records += 1;
1369 break;
1370 }
1371 }
1372 }
1373
1374 let mut result = Vec::new();
1376 for (txn_id, writes) in &pending_writes {
1377 if committed_txns.contains(txn_id) {
1378 stats.committed_txns += 1;
1379 stats.recovered_writes += writes.len() as u64;
1380 result.extend(writes.clone());
1381 }
1382 }
1383
1384 stats.aborted_txns = aborted_txns.len() as u64;
1386 stats.rolled_back_txns = all_txns.len() as u64 - stats.committed_txns - stats.aborted_txns;
1387
1388 stats.recovery_duration_us = start_time.elapsed().as_micros() as u64;
1389
1390 Ok((result, stats))
1391 }
1392}
1393
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A record survives serialization followed by decoding.
    #[test]
    fn test_wal_entry_roundtrip() {
        let entry = TxnWalEntry::data(42, b"key".to_vec(), b"value".to_vec());
        let bytes = entry.to_bytes();

        let mut cursor = std::io::Cursor::new(bytes);
        let recovered = TxnWalEntry::from_reader(&mut cursor).unwrap();

        assert_eq!(recovered.record_type, WalRecordType::Data);
        assert_eq!(recovered.txn_id, 42);
        assert_eq!(recovered.key, b"key");
        assert_eq!(recovered.value, b"value");
    }

    /// Writes of a committed transaction are returned, in order, after the
    /// log is reopened.
    #[test]
    fn test_wal_append_and_replay() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        {
            // First session: one committed transaction with two writes.
            let wal = TxnWal::new(&wal_path).unwrap();
            let txn_id = wal.begin_transaction().unwrap();
            wal.write(txn_id, b"k1".to_vec(), b"v1".to_vec()).unwrap();
            wal.write(txn_id, b"k2".to_vec(), b"v2".to_vec()).unwrap();
            wal.commit_transaction(txn_id).unwrap();
        }

        {
            // Second session: reopen and replay.
            let wal = TxnWal::new(&wal_path).unwrap();
            let (writes, txn_count) = wal.replay_for_recovery().unwrap();

            assert_eq!(txn_count, 1);
            assert_eq!(writes.len(), 2);
            assert_eq!(writes[0], (b"k1".to_vec(), b"v1".to_vec()));
            assert_eq!(writes[1], (b"k2".to_vec(), b"v2".to_vec()));
        }
    }

    /// Writes of a transaction that never committed are not replayed.
    #[test]
    fn test_uncommitted_transaction_rollback() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        {
            let wal = TxnWal::new(&wal_path).unwrap();

            // Committed transaction.
            let txn1 = wal.begin_transaction().unwrap();
            wal.write(txn1, b"committed".to_vec(), b"yes".to_vec())
                .unwrap();
            wal.commit_transaction(txn1).unwrap();

            // Transaction left open when the WAL is dropped.
            let txn2 = wal.begin_transaction().unwrap();
            wal.write(txn2, b"uncommitted".to_vec(), b"no".to_vec())
                .unwrap();
        }

        {
            let wal = TxnWal::new(&wal_path).unwrap();
            let (writes, txn_count) = wal.replay_for_recovery().unwrap();

            // Only the committed transaction is visible.
            assert_eq!(txn_count, 1);
            assert_eq!(writes.len(), 1);
            assert_eq!(writes[0], (b"committed".to_vec(), b"yes".to_vec()));
        }
    }

    /// Writes of an explicitly aborted transaction are not replayed.
    #[test]
    fn test_aborted_transaction() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        {
            let wal = TxnWal::new(&wal_path).unwrap();

            let txn = wal.begin_transaction().unwrap();
            wal.write(txn, b"aborted".to_vec(), b"data".to_vec())
                .unwrap();
            wal.abort_transaction(txn).unwrap();
        }

        {
            let wal = TxnWal::new(&wal_path).unwrap();
            let (writes, txn_count) = wal.replay_for_recovery().unwrap();

            assert_eq!(txn_count, 0);
            assert!(writes.is_empty());
        }
    }

    /// Flipping bits in the stored checksum makes decoding fail.
    #[test]
    fn test_checksum_validation() {
        let entry = TxnWalEntry::data(1, b"key".to_vec(), b"value".to_vec());
        let mut bytes = entry.to_bytes();

        // The last byte belongs to the trailing CRC32.
        let len = bytes.len();
        bytes[len - 1] ^= 0xFF;

        let mut cursor = std::io::Cursor::new(bytes);
        let result = TxnWalEntry::from_reader(&mut cursor);

        assert!(result.is_err());
    }

    /// `crash_recovery` classifies committed, aborted and unfinished
    /// transactions and returns only the committed writes.
    #[test]
    fn test_crash_recovery_with_stats() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        {
            let wal = TxnWal::new(&wal_path).unwrap();

            // Committed, two writes.
            let txn1 = wal.begin_transaction().unwrap();
            wal.write(txn1, b"k1".to_vec(), b"v1".to_vec()).unwrap();
            wal.write(txn1, b"k2".to_vec(), b"v2".to_vec()).unwrap();
            wal.commit_transaction(txn1).unwrap();

            // Aborted.
            let txn2 = wal.begin_transaction().unwrap();
            wal.write(txn2, b"aborted_key".to_vec(), b"aborted_val".to_vec())
                .unwrap();
            wal.abort_transaction(txn2).unwrap();

            // Committed, one write.
            let txn3 = wal.begin_transaction().unwrap();
            wal.write(txn3, b"k3".to_vec(), b"v3".to_vec()).unwrap();
            wal.commit_transaction(txn3).unwrap();

            // Left open when the WAL is dropped.
            let txn4 = wal.begin_transaction().unwrap();
            wal.write(txn4, b"uncommitted".to_vec(), b"data".to_vec())
                .unwrap();
        }

        {
            let wal = TxnWal::new(&wal_path).unwrap();
            let (writes, stats) = wal.crash_recovery().unwrap();

            // Two writes from txn1 plus one from txn3.
            assert_eq!(writes.len(), 3);
            assert_eq!(stats.committed_txns, 2);
            assert_eq!(stats.aborted_txns, 1);
            assert_eq!(stats.rolled_back_txns, 1);
            assert_eq!(stats.recovered_writes, 3);
            assert!(stats.recovery_duration_us > 0);
        }
    }

    /// Garbage appended after valid records is reported as a torn record and
    /// does not affect the records before it.
    #[test]
    fn test_torn_write_detection() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        {
            let wal = TxnWal::new(&wal_path).unwrap();
            let txn = wal.begin_transaction().unwrap();
            wal.write(txn, b"key".to_vec(), b"value".to_vec()).unwrap();
            wal.commit_transaction(txn).unwrap();
        }

        {
            use std::io::Write;
            let mut file = std::fs::OpenOptions::new()
                .append(true)
                .open(&wal_path)
                .unwrap();
            // Length prefix of 16 (below the minimum `from_reader` accepts)
            // followed by two stray bytes.
            file.write_all(&[0x10, 0x00, 0x00, 0x00, 0xFF, 0xFF])
                .unwrap();
        }

        {
            let wal = TxnWal::new(&wal_path).unwrap();
            let (writes, stats) = wal.crash_recovery().unwrap();

            // The valid transaction is still recovered.
            assert_eq!(writes.len(), 1);
            assert_eq!(stats.committed_txns, 1);
            assert_eq!(stats.torn_records, 1);
        }
    }

    /// Equal records hash equally, different payloads hash differently, and
    /// the checksum survives a serialize/decode round trip.
    #[test]
    fn test_crc32_determinism() {
        // Pin the timestamps so both entries are byte-for-byte identical.
        let mut entry1 = TxnWalEntry::data(42, b"key".to_vec(), b"value".to_vec());
        entry1.timestamp_us = 12345;
        let mut entry2 = TxnWalEntry::data(42, b"key".to_vec(), b"value".to_vec());
        entry2.timestamp_us = 12345;
        assert_eq!(entry1.checksum(), entry2.checksum());

        let mut entry3 = TxnWalEntry::data(42, b"key".to_vec(), b"different".to_vec());
        entry3.timestamp_us = 12345;
        assert_ne!(entry1.checksum(), entry3.checksum());

        let bytes = entry1.to_bytes();
        let mut cursor = std::io::Cursor::new(bytes);
        let recovered = TxnWalEntry::from_reader(&mut cursor).unwrap();
        assert_eq!(recovered.checksum(), entry1.checksum());
    }
}