1use serde::{Deserialize, Serialize};
24use std::collections::HashSet;
25use std::sync::atomic::{AtomicU64, Ordering};
26
/// Numeric identifier of a transaction.
///
/// `TransactionManager::begin_with_isolation` allocates these from an atomic
/// counter that starts at 1; `TxnWalEntry::new_checkpoint_end` stores `0` in
/// the `txn_id` field of the record it builds.
pub type TxnId = u64;
29
/// Lifecycle state of a [`Transaction`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TxnState {
    /// State assigned by `Transaction::new`.
    Active,
    /// State assigned by `TransactionManager::mark_committed`.
    Committed,
    /// State assigned by `TransactionManager::mark_aborted`.
    Aborted,
}
37
/// One-byte tag identifying the kind of a WAL record.
///
/// The discriminant is the value `TxnWalEntry::to_bytes` writes as the first
/// byte of an encoded record, and the value `TryFrom<u8>` decodes, so the
/// numbers below are part of the on-disk format and must not be changed.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WalRecordType {
    /// Key/value write; built by `TxnWalEntry::new_data` and `new_aries_data`.
    Data = 0x01,
    /// Start of a transaction; built by `TxnWalEntry::new_begin`.
    TxnBegin = 0x10,
    /// Commit of a transaction; built by `TxnWalEntry::new_commit`.
    TxnCommit = 0x11,
    /// Abort of a transaction; built by `TxnWalEntry::new_abort`.
    TxnAbort = 0x12,
    /// NOTE(review): no constructor in this file emits this tag; presumably it
    /// marks the beginning of a checkpoint — confirm against the writer.
    Checkpoint = 0x20,
    /// NOTE(review): not produced anywhere in this file; meaning to be
    /// confirmed against the caller.
    SchemaChange = 0x30,
    /// Compensation log record; built by `TxnWalEntry::new_clr`.
    CompensationLogRecord = 0x40,
    /// End of a checkpoint, carrying serialized `AriesCheckpointData` in the
    /// record's `value`; built by `TxnWalEntry::new_checkpoint_end`.
    CheckpointEnd = 0x21,
    /// NOTE(review): not produced anywhere in this file; meaning to be
    /// confirmed against the caller.
    PageUpdate = 0x02,
}
61
62impl TryFrom<u8> for WalRecordType {
63 type Error = ();
64
65 fn try_from(value: u8) -> Result<Self, Self::Error> {
66 match value {
67 0x01 => Ok(WalRecordType::Data),
68 0x10 => Ok(WalRecordType::TxnBegin),
69 0x11 => Ok(WalRecordType::TxnCommit),
70 0x12 => Ok(WalRecordType::TxnAbort),
71 0x20 => Ok(WalRecordType::Checkpoint),
72 0x21 => Ok(WalRecordType::CheckpointEnd),
73 0x30 => Ok(WalRecordType::SchemaChange),
74 0x40 => Ok(WalRecordType::CompensationLogRecord),
75 0x02 => Ok(WalRecordType::PageUpdate),
76 _ => Err(()),
77 }
78 }
79}
80
/// Log sequence number of a WAL record.
///
/// Every constructor in this file initialises `TxnWalEntry::lsn` to `0`; how
/// real LSNs are assigned is not shown here.
pub type Lsn = u64;

/// Numeric identifier of a page referenced by WAL records and by the dirty
/// page entries of a checkpoint.
pub type PageId = u64;
91
/// One element of `AriesCheckpointData::active_transactions`.
///
/// Field order is part of the bincode encoding used by
/// `TxnWalEntry::new_checkpoint_end`; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AriesTransactionEntry {
    /// Transaction this entry describes.
    pub txn_id: TxnId,
    /// State of the transaction recorded in the checkpoint.
    pub state: TxnState,
    /// NOTE(review): presumably the LSN of the last record written by this
    /// transaction — confirm against the checkpoint writer.
    pub last_lsn: Lsn,
    /// NOTE(review): presumably the next LSN to undo for this transaction, if
    /// any — confirm against the recovery code.
    pub undo_next_lsn: Option<Lsn>,
}
104
/// One element of `AriesCheckpointData::dirty_pages`.
///
/// Field order is part of the bincode encoding used by
/// `TxnWalEntry::new_checkpoint_end`; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AriesDirtyPageEntry {
    /// Page this entry describes.
    pub page_id: PageId,
    /// NOTE(review): presumably the LSN of the earliest record that dirtied
    /// the page (ARIES "recLSN") — confirm against the checkpoint writer.
    pub rec_lsn: Lsn,
}
113
/// Payload of a `WalRecordType::CheckpointEnd` record.
///
/// `TxnWalEntry::new_checkpoint_end` serializes this with bincode into the
/// record's `value`, and `TxnWalEntry::get_checkpoint_data` decodes it again.
/// `Default` yields two empty lists and a `begin_checkpoint_lsn` of `0`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AriesCheckpointData {
    /// Transaction entries captured by the checkpoint.
    pub active_transactions: Vec<AriesTransactionEntry>,
    /// Dirty page entries captured by the checkpoint.
    pub dirty_pages: Vec<AriesDirtyPageEntry>,
    /// NOTE(review): presumably the LSN of the matching begin-checkpoint
    /// record — confirm against the checkpoint writer.
    pub begin_checkpoint_lsn: Lsn,
}
124
/// A write buffered inside a [`Transaction`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TxnWrite {
    /// Key being written.
    pub key: Vec<u8>,
    /// New value; `Transaction::put` stores `Some(value)` and
    /// `Transaction::delete` stores `None`.
    pub value: Option<Vec<u8>>,
    /// Name of the table the key belongs to.
    pub table: String,
}
135
/// A key observed by a [`Transaction`], as recorded by
/// `Transaction::record_read`.
///
/// `Hash` and `Eq` are derived so that values can be stored in the
/// transaction's `HashSet` read set, which collapses repeated reads of the
/// same `(table, key)` pair.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct TxnRead {
    /// Key that was read.
    pub key: Vec<u8>,
    /// Name of the table the key belongs to.
    pub table: String,
}
142
/// A single write-ahead-log record.
///
/// The compact binary form produced by `to_bytes` contains only
/// `record_type`, `txn_id`, `timestamp_us`, `key`, `value`, `table` and
/// `checksum`. The remaining fields are not written by `to_bytes`, and
/// `from_bytes` sets them to `0` / `None`. Those same fields are marked
/// `#[serde(default)]`, so serde input that lacks them still deserializes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TxnWalEntry {
    /// Kind of record.
    pub record_type: WalRecordType,
    /// Transaction the record belongs to (`0` for checkpoint-end records).
    pub txn_id: TxnId,
    /// Caller-supplied timestamp; the name suggests microseconds.
    pub timestamp_us: u64,
    /// Key affected by the record, if any.
    pub key: Option<Vec<u8>>,
    /// Value carried by the record, if any. For checkpoint-end records this
    /// holds the bincode-encoded `AriesCheckpointData`.
    pub value: Option<Vec<u8>>,
    /// Table affected by the record, if any.
    pub table: Option<String>,
    /// CRC32 (via `crc32fast`) over record type, transaction id, timestamp,
    /// key, value and table. Constructors leave it at `0` until
    /// `compute_checksum` is called.
    pub checksum: u32,
    /// LSN of this record; constructors initialise it to `0`.
    #[serde(default)]
    pub lsn: Lsn,
    /// NOTE(review): presumably the LSN of the previous record of the same
    /// transaction — confirm against the log writer.
    #[serde(default)]
    pub prev_lsn: Option<Lsn>,
    /// Page touched by the record; set by `new_aries_data` and `new_clr`.
    #[serde(default)]
    pub page_id: Option<PageId>,
    /// Opaque undo payload supplied to `new_aries_data`.
    #[serde(default)]
    pub undo_info: Option<Vec<u8>>,
    /// Next LSN to undo; set only by `new_clr`.
    #[serde(default)]
    pub undo_next_lsn: Option<Lsn>,
}
182
183impl TxnWalEntry {
184 pub fn new_begin(txn_id: TxnId, timestamp_us: u64) -> Self {
185 Self {
186 record_type: WalRecordType::TxnBegin,
187 txn_id,
188 timestamp_us,
189 key: None,
190 value: None,
191 table: None,
192 checksum: 0,
193 lsn: 0,
194 prev_lsn: None,
195 page_id: None,
196 undo_info: None,
197 undo_next_lsn: None,
198 }
199 }
200
201 pub fn new_commit(txn_id: TxnId, timestamp_us: u64) -> Self {
202 Self {
203 record_type: WalRecordType::TxnCommit,
204 txn_id,
205 timestamp_us,
206 key: None,
207 value: None,
208 table: None,
209 checksum: 0,
210 lsn: 0,
211 prev_lsn: None,
212 page_id: None,
213 undo_info: None,
214 undo_next_lsn: None,
215 }
216 }
217
218 pub fn new_abort(txn_id: TxnId, timestamp_us: u64) -> Self {
219 Self {
220 record_type: WalRecordType::TxnAbort,
221 txn_id,
222 timestamp_us,
223 key: None,
224 value: None,
225 table: None,
226 checksum: 0,
227 lsn: 0,
228 prev_lsn: None,
229 page_id: None,
230 undo_info: None,
231 undo_next_lsn: None,
232 }
233 }
234
235 pub fn new_data(
236 txn_id: TxnId,
237 timestamp_us: u64,
238 table: String,
239 key: Vec<u8>,
240 value: Option<Vec<u8>>,
241 ) -> Self {
242 Self {
243 record_type: WalRecordType::Data,
244 txn_id,
245 timestamp_us,
246 key: Some(key),
247 value,
248 table: Some(table),
249 checksum: 0,
250 lsn: 0,
251 prev_lsn: None,
252 page_id: None,
253 undo_info: None,
254 undo_next_lsn: None,
255 }
256 }
257
258 #[allow(clippy::too_many_arguments)]
260 pub fn new_aries_data(
261 txn_id: TxnId,
262 timestamp_us: u64,
263 table: String,
264 key: Vec<u8>,
265 value: Option<Vec<u8>>,
266 page_id: PageId,
267 prev_lsn: Option<Lsn>,
268 undo_info: Option<Vec<u8>>,
269 ) -> Self {
270 Self {
271 record_type: WalRecordType::Data,
272 txn_id,
273 timestamp_us,
274 key: Some(key),
275 value,
276 table: Some(table),
277 checksum: 0,
278 lsn: 0, prev_lsn,
280 page_id: Some(page_id),
281 undo_info,
282 undo_next_lsn: None,
283 }
284 }
285
286 #[allow(clippy::too_many_arguments)]
292 pub fn new_clr(
293 txn_id: TxnId,
294 timestamp_us: u64,
295 table: String,
296 key: Vec<u8>,
297 value: Option<Vec<u8>>,
298 page_id: PageId,
299 prev_lsn: Lsn,
300 undo_next_lsn: Lsn,
301 ) -> Self {
302 Self {
303 record_type: WalRecordType::CompensationLogRecord,
304 txn_id,
305 timestamp_us,
306 key: Some(key),
307 value,
308 table: Some(table),
309 checksum: 0,
310 lsn: 0,
311 prev_lsn: Some(prev_lsn),
312 page_id: Some(page_id),
313 undo_info: None, undo_next_lsn: Some(undo_next_lsn),
315 }
316 }
317
318 pub fn new_checkpoint_end(
320 timestamp_us: u64,
321 checkpoint_data: AriesCheckpointData,
322 ) -> Result<Self, String> {
323 let data = bincode::serialize(&checkpoint_data)
324 .map_err(|e| format!("Failed to serialize checkpoint data: {}", e))?;
325 Ok(Self {
326 record_type: WalRecordType::CheckpointEnd,
327 txn_id: 0,
328 timestamp_us,
329 key: None,
330 value: Some(data),
331 table: None,
332 checksum: 0,
333 lsn: 0,
334 prev_lsn: None,
335 page_id: None,
336 undo_info: None,
337 undo_next_lsn: None,
338 })
339 }
340
341 pub fn get_checkpoint_data(&self) -> Option<AriesCheckpointData> {
343 if self.record_type != WalRecordType::CheckpointEnd {
344 return None;
345 }
346 self.value
347 .as_ref()
348 .and_then(|data| bincode::deserialize(data).ok())
349 }
350
351 pub fn compute_checksum(&mut self) {
353 let data = self.serialize_for_checksum();
354 self.checksum = crc32fast::hash(&data);
355 }
356
357 pub fn verify_checksum(&self) -> bool {
359 let data = self.serialize_for_checksum();
360 crc32fast::hash(&data) == self.checksum
361 }
362
363 fn serialize_for_checksum(&self) -> Vec<u8> {
364 let mut buf = Vec::new();
366 buf.push(self.record_type as u8);
367 buf.extend(&self.txn_id.to_le_bytes());
368 buf.extend(&self.timestamp_us.to_le_bytes());
369 if let Some(ref key) = self.key {
370 buf.extend(&(key.len() as u32).to_le_bytes());
371 buf.extend(key);
372 } else {
373 buf.extend(&0u32.to_le_bytes());
374 }
375 if let Some(ref value) = self.value {
376 buf.extend(&(value.len() as u32).to_le_bytes());
377 buf.extend(value);
378 } else {
379 buf.extend(&0u32.to_le_bytes());
380 }
381 if let Some(ref table) = self.table {
382 buf.extend(&(table.len() as u32).to_le_bytes());
383 buf.extend(table.as_bytes());
384 } else {
385 buf.extend(&0u32.to_le_bytes());
386 }
387 buf
388 }
389
390 pub fn to_bytes(&self) -> Vec<u8> {
392 let mut buf = self.serialize_for_checksum();
393 buf.extend(&self.checksum.to_le_bytes());
394 buf
395 }
396
397 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
406 if data.len() < 21 {
408 return Err(format!(
409 "WAL entry too short: {} bytes, need at least 21",
410 data.len()
411 ));
412 }
413
414 let record_type = WalRecordType::try_from(data[0])
415 .map_err(|_| format!("Invalid WAL record type: {}", data[0]))?;
416
417 let txn_id = u64::from_le_bytes(
418 data[1..9]
419 .try_into()
420 .map_err(|_| "Failed to parse txn_id: slice too short")?,
421 );
422 let timestamp_us = u64::from_le_bytes(
423 data[9..17]
424 .try_into()
425 .map_err(|_| "Failed to parse timestamp: slice too short")?,
426 );
427
428 let mut offset = 17;
429
430 if offset + 4 > data.len() {
432 return Err(format!(
433 "WAL entry truncated at key_len: offset {} + 4 > {}",
434 offset,
435 data.len()
436 ));
437 }
438 let key_len = u32::from_le_bytes(
439 data[offset..offset + 4]
440 .try_into()
441 .map_err(|_| "Failed to parse key_len")?,
442 ) as usize;
443 offset += 4;
444
445 if offset + key_len > data.len() {
446 return Err(format!(
447 "WAL entry truncated at key: need {} bytes at offset {}, have {}",
448 key_len,
449 offset,
450 data.len()
451 ));
452 }
453 let key = if key_len > 0 {
454 Some(data[offset..offset + key_len].to_vec())
455 } else {
456 None
457 };
458 offset += key_len;
459
460 if offset + 4 > data.len() {
462 return Err(format!(
463 "WAL entry truncated at value_len: offset {} + 4 > {}",
464 offset,
465 data.len()
466 ));
467 }
468 let value_len = u32::from_le_bytes(
469 data[offset..offset + 4]
470 .try_into()
471 .map_err(|_| "Failed to parse value_len")?,
472 ) as usize;
473 offset += 4;
474
475 if offset + value_len > data.len() {
476 return Err(format!(
477 "WAL entry truncated at value: need {} bytes at offset {}, have {}",
478 value_len,
479 offset,
480 data.len()
481 ));
482 }
483 let value = if value_len > 0 {
484 Some(data[offset..offset + value_len].to_vec())
485 } else {
486 None
487 };
488 offset += value_len;
489
490 if offset + 4 > data.len() {
492 return Err(format!(
493 "WAL entry truncated at table_len: offset {} + 4 > {}",
494 offset,
495 data.len()
496 ));
497 }
498 let table_len = u32::from_le_bytes(
499 data[offset..offset + 4]
500 .try_into()
501 .map_err(|_| "Failed to parse table_len")?,
502 ) as usize;
503 offset += 4;
504
505 if offset + table_len > data.len() {
506 return Err(format!(
507 "WAL entry truncated at table: need {} bytes at offset {}, have {}",
508 table_len,
509 offset,
510 data.len()
511 ));
512 }
513 let table = if table_len > 0 {
514 Some(
515 String::from_utf8(data[offset..offset + table_len].to_vec())
516 .map_err(|e| format!("Invalid UTF-8 in table name: {}", e))?,
517 )
518 } else {
519 None
520 };
521 offset += table_len;
522
523 if offset + 4 > data.len() {
525 return Err(format!(
526 "WAL entry truncated at checksum: offset {} + 4 > {}",
527 offset,
528 data.len()
529 ));
530 }
531 let checksum = u32::from_le_bytes(
532 data[offset..offset + 4]
533 .try_into()
534 .map_err(|_| "Failed to parse checksum")?,
535 );
536
537 let entry = Self {
538 record_type,
539 txn_id,
540 timestamp_us,
541 key,
542 value,
543 table,
544 checksum,
545 lsn: 0,
547 prev_lsn: None,
548 page_id: None,
549 undo_info: None,
550 undo_next_lsn: None,
551 };
552
553 if !entry.verify_checksum() {
555 return Err(format!(
556 "WAL entry checksum mismatch for txn_id {}: expected valid checksum, got {}",
557 entry.txn_id, entry.checksum
558 ));
559 }
560
561 Ok(entry)
562 }
563}
564
/// In-memory state of one transaction: its identity, lifecycle state,
/// timestamps, buffered writes and recorded reads.
#[derive(Debug)]
pub struct Transaction {
    /// Identifier assigned when the transaction was created.
    pub id: TxnId,
    /// Current lifecycle state; starts as `TxnState::Active`.
    pub state: TxnState,
    /// Start timestamp supplied to `Transaction::new`.
    pub start_ts: u64,
    /// Commit timestamp; `None` until `TransactionManager::mark_committed`
    /// sets it.
    pub commit_ts: Option<u64>,
    /// Buffered writes in the order they were issued; `get_local` returns the
    /// most recent one for a given table and key.
    pub writes: Vec<TxnWrite>,
    /// Distinct `(table, key)` pairs passed to `record_read`.
    pub read_set: HashSet<TxnRead>,
    /// Isolation level requested for this transaction.
    pub isolation: IsolationLevel,
}
583
/// Isolation level requested for a [`Transaction`].
///
/// This file only stores the level on the transaction; enforcement is not
/// shown here. `Default` yields `SnapshotIsolation`, which is what
/// `TransactionManager::begin` uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IsolationLevel {
    /// Read-committed isolation.
    ReadCommitted,
    /// Snapshot isolation (the default).
    #[default]
    SnapshotIsolation,
    /// Serializable isolation.
    Serializable,
}
595
596impl Transaction {
597 pub fn new(id: TxnId, start_ts: u64, isolation: IsolationLevel) -> Self {
599 Self {
600 id,
601 state: TxnState::Active,
602 start_ts,
603 commit_ts: None,
604 writes: Vec::new(),
605 read_set: HashSet::new(),
606 isolation,
607 }
608 }
609
610 pub fn put(&mut self, table: &str, key: Vec<u8>, value: Vec<u8>) {
612 self.writes.push(TxnWrite {
613 key,
614 value: Some(value),
615 table: table.to_string(),
616 });
617 }
618
619 pub fn delete(&mut self, table: &str, key: Vec<u8>) {
621 self.writes.push(TxnWrite {
622 key,
623 value: None,
624 table: table.to_string(),
625 });
626 }
627
628 pub fn record_read(&mut self, table: &str, key: Vec<u8>) {
630 self.read_set.insert(TxnRead {
631 key,
632 table: table.to_string(),
633 });
634 }
635
636 pub fn get_local(&self, table: &str, key: &[u8]) -> Option<&TxnWrite> {
638 self.writes
639 .iter()
640 .rev()
641 .find(|w| w.table == table && w.key == key)
642 }
643
644 pub fn is_read_only(&self) -> bool {
646 self.writes.is_empty()
647 }
648}
649
/// Counters maintained by `TransactionManager`; `TransactionManager::stats`
/// returns a cloned snapshot. `Default` yields all zeros.
#[derive(Debug, Clone, Default)]
pub struct TxnStats {
    /// Transactions begun and not yet marked committed or aborted.
    pub active_count: u64,
    /// Transactions marked committed.
    pub committed_count: u64,
    /// Transactions marked aborted, including conflict aborts.
    pub aborted_count: u64,
    /// Aborts recorded through `TransactionManager::mark_conflict_abort`.
    pub conflict_aborts: u64,
}
658
/// Hands out transaction ids and timestamps and keeps aggregate statistics.
///
/// All state lives in atomics or behind a `parking_lot::RwLock`, and every
/// method takes `&self`.
pub struct TransactionManager {
    /// Next transaction id to hand out; starts at 1.
    next_txn_id: AtomicU64,
    /// Shared counter for start and commit timestamps; starts at 1.
    timestamp_counter: AtomicU64,
    /// Value returned by `oldest_active_ts`; only ever raised, through
    /// `advance_watermark`. Starts at 0.
    committed_watermark: AtomicU64,
    /// Aggregate counters, updated on begin / commit / abort.
    stats: parking_lot::RwLock<TxnStats>,
}
680
681impl TransactionManager {
682 pub fn new() -> Self {
683 Self {
684 next_txn_id: AtomicU64::new(1),
685 timestamp_counter: AtomicU64::new(1),
686 committed_watermark: AtomicU64::new(0),
687 stats: parking_lot::RwLock::new(TxnStats::default()),
688 }
689 }
690
691 pub fn begin(&self) -> Transaction {
693 self.begin_with_isolation(IsolationLevel::default())
694 }
695
696 pub fn begin_with_isolation(&self, isolation: IsolationLevel) -> Transaction {
698 let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
699 let start_ts = self.timestamp_counter.fetch_add(1, Ordering::SeqCst);
700
701 {
702 let mut stats = self.stats.write();
703 stats.active_count += 1;
704 }
705
706 Transaction::new(txn_id, start_ts, isolation)
707 }
708
709 pub fn get_commit_ts(&self) -> u64 {
711 self.timestamp_counter.fetch_add(1, Ordering::SeqCst)
712 }
713
714 pub fn mark_committed(&self, txn: &mut Transaction) {
716 txn.state = TxnState::Committed;
717 txn.commit_ts = Some(self.get_commit_ts());
718
719 let mut stats = self.stats.write();
720 stats.active_count = stats.active_count.saturating_sub(1);
721 stats.committed_count += 1;
722 }
723
724 pub fn mark_aborted(&self, txn: &mut Transaction) {
726 txn.state = TxnState::Aborted;
727
728 let mut stats = self.stats.write();
729 stats.active_count = stats.active_count.saturating_sub(1);
730 stats.aborted_count += 1;
731 }
732
733 pub fn mark_conflict_abort(&self, txn: &mut Transaction) {
735 self.mark_aborted(txn);
736
737 let mut stats = self.stats.write();
738 stats.conflict_aborts += 1;
739 }
740
741 pub fn oldest_active_ts(&self) -> u64 {
743 self.committed_watermark.load(Ordering::SeqCst)
744 }
745
746 pub fn advance_watermark(&self, new_watermark: u64) {
748 self.committed_watermark
749 .fetch_max(new_watermark, Ordering::SeqCst);
750 }
751
752 pub fn stats(&self) -> TxnStats {
754 self.stats.read().clone()
755 }
756}
757
758impl Default for TransactionManager {
759 fn default() -> Self {
760 Self::new()
761 }
762}
763
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh transaction is active and read-only; a write makes it
    /// read-write; committing sets the state and a commit timestamp.
    #[test]
    fn test_transaction_lifecycle() {
        let mgr = TransactionManager::new();

        let mut txn = mgr.begin();
        assert_eq!(txn.state, TxnState::Active);
        assert!(txn.is_read_only());

        txn.put("users", vec![1], vec![2, 3, 4]);
        assert!(!txn.is_read_only());

        mgr.mark_committed(&mut txn);
        assert_eq!(txn.state, TxnState::Committed);
        assert!(txn.commit_ts.is_some());
    }

    /// `get_local` returns the latest buffered write for a key.
    #[test]
    fn test_read_your_writes() {
        let mgr = TransactionManager::new();
        let mut txn = mgr.begin();

        txn.put("users", vec![1], vec![10, 20]);
        // Second write to the same key must shadow the first.
        txn.put("users", vec![1], vec![30, 40]);

        let local = txn
            .get_local("users", &[1])
            .expect("a key written in this transaction must be visible locally");
        assert_eq!(local.value, Some(vec![30, 40]));
    }

    /// A data record survives a `to_bytes` / `from_bytes` round trip.
    #[test]
    fn test_wal_entry_serialization() {
        let mut entry = TxnWalEntry::new_data(
            42,
            1234567890,
            "users".to_string(),
            vec![1, 2, 3],
            Some(vec![4, 5, 6]),
        );
        entry.compute_checksum();

        let bytes = entry.to_bytes();
        let parsed = TxnWalEntry::from_bytes(&bytes).unwrap();

        assert_eq!(parsed.txn_id, 42);
        assert_eq!(parsed.timestamp_us, 1234567890);
        assert_eq!(parsed.table, Some("users".to_string()));
        assert_eq!(parsed.key, Some(vec![1, 2, 3]));
        assert_eq!(parsed.value, Some(vec![4, 5, 6]));
        assert!(parsed.verify_checksum());
    }

    /// Begin / commit / abort keep the counters in step.
    #[test]
    fn test_transaction_stats() {
        let mgr = TransactionManager::new();

        let mut txn1 = mgr.begin();
        let mut txn2 = mgr.begin();

        assert_eq!(mgr.stats().active_count, 2);

        mgr.mark_committed(&mut txn1);
        assert_eq!(mgr.stats().committed_count, 1);

        mgr.mark_aborted(&mut txn2);
        assert_eq!(mgr.stats().aborted_count, 1);
        assert_eq!(mgr.stats().active_count, 0);
    }

    /// Input shorter than the minimum length is rejected up front.
    #[test]
    fn test_wal_entry_error_too_short() {
        let short_data = vec![0u8; 10];
        let result = TxnWalEntry::from_bytes(&short_data);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("too short"));
    }

    /// An unknown record-type byte is rejected.
    #[test]
    fn test_wal_entry_error_invalid_record_type() {
        let mut data = vec![0u8; 30];
        // 255 is not a valid `WalRecordType` discriminant.
        data[0] = 255;
        let result = TxnWalEntry::from_bytes(&data);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid WAL record type"));
    }

    /// A key length that exceeds the buffer is reported as truncation.
    #[test]
    fn test_wal_entry_error_truncated_key() {
        let mut entry =
            TxnWalEntry::new_data(1, 100, "test".to_string(), vec![1, 2], Some(vec![3, 4]));
        entry.compute_checksum();
        let mut bytes = entry.to_bytes();

        // Bytes 17..21 hold the key length; overwrite it with a value far
        // larger than the buffer.
        let huge_len: u32 = 10000;
        bytes[17..21].copy_from_slice(&huge_len.to_le_bytes());

        let result = TxnWalEntry::from_bytes(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("truncated at key"));
    }

    /// Flipping bits in the stored checksum is detected.
    #[test]
    fn test_wal_entry_error_corrupted_checksum() {
        let mut entry = TxnWalEntry::new_data(
            42,
            1234567890,
            "users".to_string(),
            vec![1, 2, 3],
            Some(vec![4, 5, 6]),
        );
        entry.compute_checksum();

        let mut bytes = entry.to_bytes();
        let len = bytes.len();
        // The checksum occupies the last four bytes of the encoding.
        bytes[len - 1] ^= 0xFF;
        let result = TxnWalEntry::from_bytes(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("checksum mismatch"));
    }

    /// A table name that is not valid UTF-8 is rejected with the UTF-8 error,
    /// not with some unrelated failure.
    #[test]
    fn test_wal_entry_error_invalid_utf8_table() {
        let mut entry = TxnWalEntry::new_data(1, 100, "test".to_string(), vec![1], Some(vec![2]));
        entry.compute_checksum();
        let mut bytes = entry.to_bytes();

        // header (17) + key_len (4) + key (1) + value_len (4) + value (1)
        // + table_len (4) = first byte of the table name.
        let table_start = 17 + 4 + 1 + 4 + 1 + 4;
        // 0xFF can never appear in well-formed UTF-8.
        bytes[table_start] = 0xFF;
        let result = TxnWalEntry::from_bytes(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid UTF-8"));
    }
}