1use serde::{Deserialize, Serialize};
27use std::collections::HashSet;
28use std::sync::atomic::{AtomicU64, Ordering};
29
/// Transaction identifier; `TransactionManager` hands these out starting at 1.
pub type TxnId = u64;
32
/// Lifecycle state of a transaction.
///
/// This type is embedded in `AriesTransactionEntry`, which is bincode-encoded inside
/// checkpoint records. Serde's derive encodes unit variants by declaration index, so
/// append new variants at the end instead of reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TxnState {
    /// Begun and not yet committed or aborted.
    Active,
    /// Set by `TransactionManager::mark_committed`.
    Committed,
    /// Set by `TransactionManager::mark_aborted` / `mark_conflict_abort`.
    Aborted,
}
40
/// Kind of a WAL record.
///
/// The explicit discriminant is the tag byte written first by `TxnWalEntry::to_bytes`
/// (`record_type as u8`) and decoded by the `TryFrom<u8>` impl; keep both in sync when
/// adding a variant. Declaration order is separately what serde's derived impls encode
/// as the variant index (e.g. in bincode), which is why newer variants such as
/// `CheckpointEnd` and `PageUpdate` are appended rather than placed by tag value —
/// do not reorder.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WalRecordType {
    /// Row write of `key` in `table`; built by `TxnWalEntry::new_data` / `new_aries_data`.
    Data = 0x01,
    /// Start of a transaction; built by `TxnWalEntry::new_begin`.
    TxnBegin = 0x10,
    /// Transaction commit; built by `TxnWalEntry::new_commit`.
    TxnCommit = 0x11,
    /// Transaction abort; built by `TxnWalEntry::new_abort`.
    TxnAbort = 0x12,
    /// Checkpoint marker. NOTE(review): no constructor in this file emits it.
    Checkpoint = 0x20,
    /// Schema change. NOTE(review): no constructor in this file emits it.
    SchemaChange = 0x30,
    /// ARIES compensation log record; built by `TxnWalEntry::new_clr`.
    CompensationLogRecord = 0x40,
    /// End of a checkpoint; `value` holds bincode-encoded `AriesCheckpointData`.
    CheckpointEnd = 0x21,
    /// Page-level update. NOTE(review): no constructor in this file emits it.
    PageUpdate = 0x02,
}
64
65impl TryFrom<u8> for WalRecordType {
66 type Error = ();
67
68 fn try_from(value: u8) -> Result<Self, Self::Error> {
69 match value {
70 0x01 => Ok(WalRecordType::Data),
71 0x10 => Ok(WalRecordType::TxnBegin),
72 0x11 => Ok(WalRecordType::TxnCommit),
73 0x12 => Ok(WalRecordType::TxnAbort),
74 0x20 => Ok(WalRecordType::Checkpoint),
75 0x21 => Ok(WalRecordType::CheckpointEnd),
76 0x30 => Ok(WalRecordType::SchemaChange),
77 0x40 => Ok(WalRecordType::CompensationLogRecord),
78 0x02 => Ok(WalRecordType::PageUpdate),
79 _ => Err(()),
80 }
81 }
82}
83
/// Log sequence number of a WAL record.
pub type Lsn = u64;
91
/// Identifier of a storage page, as referenced by WAL records and the dirty page table.
pub type PageId = u64;
94
/// One row of the ARIES transaction table captured in a checkpoint.
///
/// Field order is part of the bincode encoding used by `TxnWalEntry::new_checkpoint_end`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AriesTransactionEntry {
    /// Transaction this row describes.
    pub txn_id: TxnId,
    /// State of the transaction at checkpoint time.
    pub state: TxnState,
    /// LSN of the most recent record written by this transaction.
    pub last_lsn: Lsn,
    /// Next LSN to undo for this transaction, if any (presumably the ARIES UndoNxtLSN).
    pub undo_next_lsn: Option<Lsn>,
}
107
/// One row of the ARIES dirty page table captured in a checkpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AriesDirtyPageEntry {
    /// Page that was dirty at checkpoint time.
    pub page_id: PageId,
    /// Recovery LSN: presumably the first LSN that dirtied the page since it was last
    /// flushed (standard ARIES RecLSN) — TODO confirm against the writer.
    pub rec_lsn: Lsn,
}
116
/// Payload of a `CheckpointEnd` record, stored bincode-encoded in `TxnWalEntry::value`.
///
/// Field order is part of that bincode encoding; append new fields only.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AriesCheckpointData {
    /// Transaction table at checkpoint time.
    pub active_transactions: Vec<AriesTransactionEntry>,
    /// Dirty page table at checkpoint time.
    pub dirty_pages: Vec<AriesDirtyPageEntry>,
    /// LSN of the record that began this checkpoint.
    pub begin_checkpoint_lsn: Lsn,
}
127
/// A write buffered inside a `Transaction` until commit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TxnWrite {
    /// Row key.
    pub key: Vec<u8>,
    /// New value, or `None` for a delete (see `Transaction::delete`).
    pub value: Option<Vec<u8>>,
    /// Table the key belongs to.
    pub table: String,
}
138
/// A `(table, key)` pair read by a transaction; hashable so it can live in the read set.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct TxnRead {
    /// Row key that was read.
    pub key: Vec<u8>,
    /// Table the key belongs to.
    pub table: String,
}
145
/// A single write-ahead-log record.
///
/// Only `record_type` through `checksum` take part in the compact binary encoding
/// (`TxnWalEntry::to_bytes` / `from_bytes`) and in the CRC32 checksum. The ARIES fields
/// (`lsn` onward) are neither written by `to_bytes` nor covered by the checksum, and
/// `from_bytes` resets them to `0` / `None`. Their `#[serde(default)]` lets a
/// self-describing serde format fill them in when absent (presumably for records that
/// predate ARIES support).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TxnWalEntry {
    /// Kind of record; also the first byte of the binary encoding.
    pub record_type: WalRecordType,
    /// Owning transaction; `0` for `CheckpointEnd` records.
    pub txn_id: TxnId,
    /// Caller-supplied timestamp, presumably microseconds (per the `_us` suffix).
    pub timestamp_us: u64,
    /// Row key for data/CLR records; `None` for control records.
    pub key: Option<Vec<u8>>,
    /// Row value (`None` presumably for deletes); bincode checkpoint payload for
    /// `CheckpointEnd`.
    pub value: Option<Vec<u8>>,
    /// Table name for data/CLR records.
    pub table: Option<String>,
    /// CRC32 (`crc32fast`) over the binary encoding of the fields above.
    pub checksum: u32,
    /// Log sequence number; constructors leave it `0` — presumably assigned when the
    /// record is appended to the log (TODO confirm).
    #[serde(default)]
    pub lsn: Lsn,
    /// Previous record of the same transaction, if any.
    #[serde(default)]
    pub prev_lsn: Option<Lsn>,
    /// Page touched by this record, for page-oriented (ARIES) records.
    #[serde(default)]
    pub page_id: Option<PageId>,
    /// Opaque undo image supplied to `new_aries_data`.
    #[serde(default)]
    pub undo_info: Option<Vec<u8>>,
    /// For CLRs: the next LSN of this transaction that still needs undoing.
    #[serde(default)]
    pub undo_next_lsn: Option<Lsn>,
}
185
186impl TxnWalEntry {
187 pub fn new_begin(txn_id: TxnId, timestamp_us: u64) -> Self {
188 Self {
189 record_type: WalRecordType::TxnBegin,
190 txn_id,
191 timestamp_us,
192 key: None,
193 value: None,
194 table: None,
195 checksum: 0,
196 lsn: 0,
197 prev_lsn: None,
198 page_id: None,
199 undo_info: None,
200 undo_next_lsn: None,
201 }
202 }
203
204 pub fn new_commit(txn_id: TxnId, timestamp_us: u64) -> Self {
205 Self {
206 record_type: WalRecordType::TxnCommit,
207 txn_id,
208 timestamp_us,
209 key: None,
210 value: None,
211 table: None,
212 checksum: 0,
213 lsn: 0,
214 prev_lsn: None,
215 page_id: None,
216 undo_info: None,
217 undo_next_lsn: None,
218 }
219 }
220
221 pub fn new_abort(txn_id: TxnId, timestamp_us: u64) -> Self {
222 Self {
223 record_type: WalRecordType::TxnAbort,
224 txn_id,
225 timestamp_us,
226 key: None,
227 value: None,
228 table: None,
229 checksum: 0,
230 lsn: 0,
231 prev_lsn: None,
232 page_id: None,
233 undo_info: None,
234 undo_next_lsn: None,
235 }
236 }
237
238 pub fn new_data(
239 txn_id: TxnId,
240 timestamp_us: u64,
241 table: String,
242 key: Vec<u8>,
243 value: Option<Vec<u8>>,
244 ) -> Self {
245 Self {
246 record_type: WalRecordType::Data,
247 txn_id,
248 timestamp_us,
249 key: Some(key),
250 value,
251 table: Some(table),
252 checksum: 0,
253 lsn: 0,
254 prev_lsn: None,
255 page_id: None,
256 undo_info: None,
257 undo_next_lsn: None,
258 }
259 }
260
261 #[allow(clippy::too_many_arguments)]
263 pub fn new_aries_data(
264 txn_id: TxnId,
265 timestamp_us: u64,
266 table: String,
267 key: Vec<u8>,
268 value: Option<Vec<u8>>,
269 page_id: PageId,
270 prev_lsn: Option<Lsn>,
271 undo_info: Option<Vec<u8>>,
272 ) -> Self {
273 Self {
274 record_type: WalRecordType::Data,
275 txn_id,
276 timestamp_us,
277 key: Some(key),
278 value,
279 table: Some(table),
280 checksum: 0,
281 lsn: 0, prev_lsn,
283 page_id: Some(page_id),
284 undo_info,
285 undo_next_lsn: None,
286 }
287 }
288
289 #[allow(clippy::too_many_arguments)]
295 pub fn new_clr(
296 txn_id: TxnId,
297 timestamp_us: u64,
298 table: String,
299 key: Vec<u8>,
300 value: Option<Vec<u8>>,
301 page_id: PageId,
302 prev_lsn: Lsn,
303 undo_next_lsn: Lsn,
304 ) -> Self {
305 Self {
306 record_type: WalRecordType::CompensationLogRecord,
307 txn_id,
308 timestamp_us,
309 key: Some(key),
310 value,
311 table: Some(table),
312 checksum: 0,
313 lsn: 0,
314 prev_lsn: Some(prev_lsn),
315 page_id: Some(page_id),
316 undo_info: None, undo_next_lsn: Some(undo_next_lsn),
318 }
319 }
320
321 pub fn new_checkpoint_end(
323 timestamp_us: u64,
324 checkpoint_data: AriesCheckpointData,
325 ) -> Result<Self, String> {
326 let data = bincode::serialize(&checkpoint_data)
327 .map_err(|e| format!("Failed to serialize checkpoint data: {}", e))?;
328 Ok(Self {
329 record_type: WalRecordType::CheckpointEnd,
330 txn_id: 0,
331 timestamp_us,
332 key: None,
333 value: Some(data),
334 table: None,
335 checksum: 0,
336 lsn: 0,
337 prev_lsn: None,
338 page_id: None,
339 undo_info: None,
340 undo_next_lsn: None,
341 })
342 }
343
344 pub fn get_checkpoint_data(&self) -> Option<AriesCheckpointData> {
346 if self.record_type != WalRecordType::CheckpointEnd {
347 return None;
348 }
349 self.value
350 .as_ref()
351 .and_then(|data| bincode::deserialize(data).ok())
352 }
353
354 pub fn compute_checksum(&mut self) {
356 let data = self.serialize_for_checksum();
357 self.checksum = crc32fast::hash(&data);
358 }
359
360 pub fn verify_checksum(&self) -> bool {
362 let data = self.serialize_for_checksum();
363 crc32fast::hash(&data) == self.checksum
364 }
365
366 fn serialize_for_checksum(&self) -> Vec<u8> {
367 let mut buf = Vec::new();
369 buf.push(self.record_type as u8);
370 buf.extend(&self.txn_id.to_le_bytes());
371 buf.extend(&self.timestamp_us.to_le_bytes());
372 if let Some(ref key) = self.key {
373 buf.extend(&(key.len() as u32).to_le_bytes());
374 buf.extend(key);
375 } else {
376 buf.extend(&0u32.to_le_bytes());
377 }
378 if let Some(ref value) = self.value {
379 buf.extend(&(value.len() as u32).to_le_bytes());
380 buf.extend(value);
381 } else {
382 buf.extend(&0u32.to_le_bytes());
383 }
384 if let Some(ref table) = self.table {
385 buf.extend(&(table.len() as u32).to_le_bytes());
386 buf.extend(table.as_bytes());
387 } else {
388 buf.extend(&0u32.to_le_bytes());
389 }
390 buf
391 }
392
393 pub fn to_bytes(&self) -> Vec<u8> {
395 let mut buf = self.serialize_for_checksum();
396 buf.extend(&self.checksum.to_le_bytes());
397 buf
398 }
399
400 pub fn from_bytes(data: &[u8]) -> Result<Self, String> {
409 if data.len() < 21 {
411 return Err(format!(
412 "WAL entry too short: {} bytes, need at least 21",
413 data.len()
414 ));
415 }
416
417 let record_type = WalRecordType::try_from(data[0])
418 .map_err(|_| format!("Invalid WAL record type: {}", data[0]))?;
419
420 let txn_id = u64::from_le_bytes(
421 data[1..9]
422 .try_into()
423 .map_err(|_| "Failed to parse txn_id: slice too short")?,
424 );
425 let timestamp_us = u64::from_le_bytes(
426 data[9..17]
427 .try_into()
428 .map_err(|_| "Failed to parse timestamp: slice too short")?,
429 );
430
431 let mut offset = 17;
432
433 if offset + 4 > data.len() {
435 return Err(format!(
436 "WAL entry truncated at key_len: offset {} + 4 > {}",
437 offset,
438 data.len()
439 ));
440 }
441 let key_len = u32::from_le_bytes(
442 data[offset..offset + 4]
443 .try_into()
444 .map_err(|_| "Failed to parse key_len")?,
445 ) as usize;
446 offset += 4;
447
448 if offset + key_len > data.len() {
449 return Err(format!(
450 "WAL entry truncated at key: need {} bytes at offset {}, have {}",
451 key_len,
452 offset,
453 data.len()
454 ));
455 }
456 let key = if key_len > 0 {
457 Some(data[offset..offset + key_len].to_vec())
458 } else {
459 None
460 };
461 offset += key_len;
462
463 if offset + 4 > data.len() {
465 return Err(format!(
466 "WAL entry truncated at value_len: offset {} + 4 > {}",
467 offset,
468 data.len()
469 ));
470 }
471 let value_len = u32::from_le_bytes(
472 data[offset..offset + 4]
473 .try_into()
474 .map_err(|_| "Failed to parse value_len")?,
475 ) as usize;
476 offset += 4;
477
478 if offset + value_len > data.len() {
479 return Err(format!(
480 "WAL entry truncated at value: need {} bytes at offset {}, have {}",
481 value_len,
482 offset,
483 data.len()
484 ));
485 }
486 let value = if value_len > 0 {
487 Some(data[offset..offset + value_len].to_vec())
488 } else {
489 None
490 };
491 offset += value_len;
492
493 if offset + 4 > data.len() {
495 return Err(format!(
496 "WAL entry truncated at table_len: offset {} + 4 > {}",
497 offset,
498 data.len()
499 ));
500 }
501 let table_len = u32::from_le_bytes(
502 data[offset..offset + 4]
503 .try_into()
504 .map_err(|_| "Failed to parse table_len")?,
505 ) as usize;
506 offset += 4;
507
508 if offset + table_len > data.len() {
509 return Err(format!(
510 "WAL entry truncated at table: need {} bytes at offset {}, have {}",
511 table_len,
512 offset,
513 data.len()
514 ));
515 }
516 let table = if table_len > 0 {
517 Some(
518 String::from_utf8(data[offset..offset + table_len].to_vec())
519 .map_err(|e| format!("Invalid UTF-8 in table name: {}", e))?,
520 )
521 } else {
522 None
523 };
524 offset += table_len;
525
526 if offset + 4 > data.len() {
528 return Err(format!(
529 "WAL entry truncated at checksum: offset {} + 4 > {}",
530 offset,
531 data.len()
532 ));
533 }
534 let checksum = u32::from_le_bytes(
535 data[offset..offset + 4]
536 .try_into()
537 .map_err(|_| "Failed to parse checksum")?,
538 );
539
540 let entry = Self {
541 record_type,
542 txn_id,
543 timestamp_us,
544 key,
545 value,
546 table,
547 checksum,
548 lsn: 0,
550 prev_lsn: None,
551 page_id: None,
552 undo_info: None,
553 undo_next_lsn: None,
554 };
555
556 if !entry.verify_checksum() {
558 return Err(format!(
559 "WAL entry checksum mismatch for txn_id {}: expected valid checksum, got {}",
560 entry.txn_id, entry.checksum
561 ));
562 }
563
564 Ok(entry)
565 }
566}
567
/// An in-flight transaction: identity, timestamps, buffered writes and read set.
#[derive(Debug)]
pub struct Transaction {
    /// Identifier assigned by `TransactionManager::begin_with_isolation`.
    pub id: TxnId,
    /// Current lifecycle state; starts as `Active`.
    pub state: TxnState,
    /// Start timestamp drawn from the manager's timestamp counter.
    pub start_ts: u64,
    /// Commit timestamp, set by `TransactionManager::mark_committed`.
    pub commit_ts: Option<u64>,
    /// Buffered writes in issue order; `get_local` scans them newest-first.
    pub writes: Vec<TxnWrite>,
    /// Keys recorded via `record_read`.
    pub read_set: HashSet<TxnRead>,
    /// Requested isolation level (stored here; not enforced in this file).
    pub isolation: IsolationLevel,
}
586
/// Isolation level requested for a transaction.
///
/// This file only records the level on `Transaction`; enforcement presumably lives in
/// the storage/commit path (TODO confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IsolationLevel {
    ReadCommitted,
    /// Default level used by `TransactionManager::begin`.
    #[default]
    SnapshotIsolation,
    Serializable,
}
598
599impl Transaction {
600 pub fn new(id: TxnId, start_ts: u64, isolation: IsolationLevel) -> Self {
602 Self {
603 id,
604 state: TxnState::Active,
605 start_ts,
606 commit_ts: None,
607 writes: Vec::new(),
608 read_set: HashSet::new(),
609 isolation,
610 }
611 }
612
613 pub fn put(&mut self, table: &str, key: Vec<u8>, value: Vec<u8>) {
615 self.writes.push(TxnWrite {
616 key,
617 value: Some(value),
618 table: table.to_string(),
619 });
620 }
621
622 pub fn delete(&mut self, table: &str, key: Vec<u8>) {
624 self.writes.push(TxnWrite {
625 key,
626 value: None,
627 table: table.to_string(),
628 });
629 }
630
631 pub fn record_read(&mut self, table: &str, key: Vec<u8>) {
633 self.read_set.insert(TxnRead {
634 key,
635 table: table.to_string(),
636 });
637 }
638
639 pub fn get_local(&self, table: &str, key: &[u8]) -> Option<&TxnWrite> {
641 self.writes
642 .iter()
643 .rev()
644 .find(|w| w.table == table && w.key == key)
645 }
646
647 pub fn is_read_only(&self) -> bool {
649 self.writes.is_empty()
650 }
651}
652
/// Snapshot of the counters maintained by `TransactionManager`.
#[derive(Debug, Clone, Default)]
pub struct TxnStats {
    /// Transactions begun but not yet marked committed/aborted (saturates at 0).
    pub active_count: u64,
    /// Calls to `mark_committed`.
    pub committed_count: u64,
    /// Aborts of any kind, including conflict aborts.
    pub aborted_count: u64,
    /// Aborts recorded via `mark_conflict_abort` (a subset of `aborted_count`).
    pub conflict_aborts: u64,
}
661
/// Hands out transaction ids and timestamps and tracks transaction statistics.
pub struct TransactionManager {
    /// Next id to hand out; ids start at 1.
    next_txn_id: AtomicU64,
    /// Single counter shared by start and commit timestamps, so all timestamps are unique
    /// and totally ordered.
    timestamp_counter: AtomicU64,
    /// Watermark advanced explicitly via `advance_watermark`; never decreases.
    committed_watermark: AtomicU64,
    /// Counters returned by `stats`.
    stats: parking_lot::RwLock<TxnStats>,
}
683
684impl TransactionManager {
685 pub fn new() -> Self {
686 Self {
687 next_txn_id: AtomicU64::new(1),
688 timestamp_counter: AtomicU64::new(1),
689 committed_watermark: AtomicU64::new(0),
690 stats: parking_lot::RwLock::new(TxnStats::default()),
691 }
692 }
693
694 pub fn begin(&self) -> Transaction {
696 self.begin_with_isolation(IsolationLevel::default())
697 }
698
699 pub fn begin_with_isolation(&self, isolation: IsolationLevel) -> Transaction {
701 let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
702 let start_ts = self.timestamp_counter.fetch_add(1, Ordering::SeqCst);
703
704 {
705 let mut stats = self.stats.write();
706 stats.active_count += 1;
707 }
708
709 Transaction::new(txn_id, start_ts, isolation)
710 }
711
712 pub fn get_commit_ts(&self) -> u64 {
714 self.timestamp_counter.fetch_add(1, Ordering::SeqCst)
715 }
716
717 pub fn mark_committed(&self, txn: &mut Transaction) {
719 txn.state = TxnState::Committed;
720 txn.commit_ts = Some(self.get_commit_ts());
721
722 let mut stats = self.stats.write();
723 stats.active_count = stats.active_count.saturating_sub(1);
724 stats.committed_count += 1;
725 }
726
727 pub fn mark_aborted(&self, txn: &mut Transaction) {
729 txn.state = TxnState::Aborted;
730
731 let mut stats = self.stats.write();
732 stats.active_count = stats.active_count.saturating_sub(1);
733 stats.aborted_count += 1;
734 }
735
736 pub fn mark_conflict_abort(&self, txn: &mut Transaction) {
738 self.mark_aborted(txn);
739
740 let mut stats = self.stats.write();
741 stats.conflict_aborts += 1;
742 }
743
744 pub fn oldest_active_ts(&self) -> u64 {
746 self.committed_watermark.load(Ordering::SeqCst)
747 }
748
749 pub fn advance_watermark(&self, new_watermark: u64) {
751 self.committed_watermark
752 .fetch_max(new_watermark, Ordering::SeqCst);
753 }
754
755 pub fn stats(&self) -> TxnStats {
757 self.stats.read().clone()
758 }
759}
760
761impl Default for TransactionManager {
762 fn default() -> Self {
763 Self::new()
764 }
765}
766
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transaction_lifecycle() {
        let mgr = TransactionManager::new();

        let mut txn = mgr.begin();
        assert_eq!(txn.state, TxnState::Active);
        assert!(txn.is_read_only());

        txn.put("users", vec![1], vec![2, 3, 4]);
        assert!(!txn.is_read_only());

        mgr.mark_committed(&mut txn);
        assert_eq!(txn.state, TxnState::Committed);
        assert!(txn.commit_ts.is_some());
    }

    #[test]
    fn test_read_your_writes() {
        let mgr = TransactionManager::new();
        let mut txn = mgr.begin();

        txn.put("users", vec![1], vec![10, 20]);
        // A later write to the same key must shadow the earlier one.
        txn.put("users", vec![1], vec![30, 40]);

        let local = txn.get_local("users", &[1]);
        assert!(local.is_some());
        assert_eq!(local.unwrap().value, Some(vec![30, 40]));
    }

    #[test]
    fn test_wal_entry_serialization() {
        let mut entry = TxnWalEntry::new_data(
            42,
            1234567890,
            "users".to_string(),
            vec![1, 2, 3],
            Some(vec![4, 5, 6]),
        );
        entry.compute_checksum();

        let bytes = entry.to_bytes();
        let parsed = TxnWalEntry::from_bytes(&bytes).unwrap();

        assert_eq!(parsed.txn_id, 42);
        assert_eq!(parsed.timestamp_us, 1234567890);
        assert_eq!(parsed.table, Some("users".to_string()));
        assert_eq!(parsed.key, Some(vec![1, 2, 3]));
        assert_eq!(parsed.value, Some(vec![4, 5, 6]));
        assert!(parsed.verify_checksum());
    }

    #[test]
    fn test_control_record_roundtrip() {
        let mut entry = TxnWalEntry::new_commit(7, 99);
        entry.compute_checksum();

        let bytes = entry.to_bytes();
        // 17-byte header + three empty length prefixes + 4-byte checksum.
        assert_eq!(bytes.len(), 33);

        let parsed = TxnWalEntry::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.record_type, WalRecordType::TxnCommit);
        assert_eq!(parsed.txn_id, 7);
        assert_eq!(parsed.timestamp_us, 99);
        assert!(parsed.key.is_none());
        assert!(parsed.value.is_none());
        assert!(parsed.table.is_none());
    }

    #[test]
    fn test_transaction_stats() {
        let mgr = TransactionManager::new();

        let mut txn1 = mgr.begin();
        let mut txn2 = mgr.begin();

        assert_eq!(mgr.stats().active_count, 2);

        mgr.mark_committed(&mut txn1);
        assert_eq!(mgr.stats().committed_count, 1);

        mgr.mark_aborted(&mut txn2);
        assert_eq!(mgr.stats().aborted_count, 1);
        assert_eq!(mgr.stats().active_count, 0);
    }

    #[test]
    fn test_conflict_abort_stats() {
        let mgr = TransactionManager::new();
        let mut txn = mgr.begin();

        mgr.mark_conflict_abort(&mut txn);

        assert_eq!(txn.state, TxnState::Aborted);
        let stats = mgr.stats();
        assert_eq!(stats.active_count, 0);
        assert_eq!(stats.aborted_count, 1);
        assert_eq!(stats.conflict_aborts, 1);
    }

    #[test]
    fn test_watermark_is_monotonic() {
        let mgr = TransactionManager::new();
        assert_eq!(mgr.oldest_active_ts(), 0);

        mgr.advance_watermark(10);
        mgr.advance_watermark(5);
        assert_eq!(mgr.oldest_active_ts(), 10);
    }

    #[test]
    fn test_wal_record_type_tag_roundtrip() {
        let all = [
            WalRecordType::Data,
            WalRecordType::PageUpdate,
            WalRecordType::TxnBegin,
            WalRecordType::TxnCommit,
            WalRecordType::TxnAbort,
            WalRecordType::Checkpoint,
            WalRecordType::CheckpointEnd,
            WalRecordType::SchemaChange,
            WalRecordType::CompensationLogRecord,
        ];
        for ty in all {
            assert_eq!(WalRecordType::try_from(ty as u8), Ok(ty));
        }
        assert_eq!(WalRecordType::try_from(0x00), Err(()));
        assert_eq!(WalRecordType::try_from(0xFF), Err(()));
    }

    #[test]
    fn test_checkpoint_end_roundtrip() {
        let data = AriesCheckpointData {
            active_transactions: vec![AriesTransactionEntry {
                txn_id: 5,
                state: TxnState::Active,
                last_lsn: 40,
                undo_next_lsn: Some(30),
            }],
            dirty_pages: vec![AriesDirtyPageEntry {
                page_id: 9,
                rec_lsn: 12,
            }],
            begin_checkpoint_lsn: 7,
        };

        let entry = TxnWalEntry::new_checkpoint_end(1_000, data).unwrap();
        assert_eq!(entry.record_type, WalRecordType::CheckpointEnd);
        assert_eq!(entry.txn_id, 0);

        let decoded = entry
            .get_checkpoint_data()
            .expect("checkpoint payload should decode");
        assert_eq!(decoded.begin_checkpoint_lsn, 7);
        assert_eq!(decoded.active_transactions.len(), 1);
        assert_eq!(decoded.active_transactions[0].txn_id, 5);
        assert_eq!(decoded.active_transactions[0].state, TxnState::Active);
        assert_eq!(decoded.active_transactions[0].undo_next_lsn, Some(30));
        assert_eq!(decoded.dirty_pages.len(), 1);
        assert_eq!(decoded.dirty_pages[0].page_id, 9);
        assert_eq!(decoded.dirty_pages[0].rec_lsn, 12);

        // Records of any other type never yield checkpoint data.
        assert!(TxnWalEntry::new_begin(1, 1).get_checkpoint_data().is_none());
    }

    #[test]
    fn test_wal_entry_error_too_short() {
        // Shorter than the 21-byte minimum (header + key length prefix).
        let short_data = vec![0u8; 10];
        let result = TxnWalEntry::from_bytes(&short_data);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("too short"));
    }

    #[test]
    fn test_wal_entry_error_invalid_record_type() {
        let mut data = vec![0u8; 30];
        // 255 is not a valid WalRecordType tag.
        data[0] = 255;

        let result = TxnWalEntry::from_bytes(&data);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid WAL record type"));
    }

    #[test]
    fn test_wal_entry_error_truncated_key() {
        let mut entry =
            TxnWalEntry::new_data(1, 100, "test".to_string(), vec![1, 2], Some(vec![3, 4]));
        entry.compute_checksum();
        let mut bytes = entry.to_bytes();

        // Overwrite key_len (bytes 17..21) with a length far beyond the buffer.
        let huge_len: u32 = 10000;
        bytes[17..21].copy_from_slice(&huge_len.to_le_bytes());

        let result = TxnWalEntry::from_bytes(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("truncated at key"));
    }

    #[test]
    fn test_wal_entry_error_truncated_checksum() {
        let mut entry = TxnWalEntry::new_begin(1, 1);
        entry.compute_checksum();
        let mut bytes = entry.to_bytes();
        bytes.pop();

        let err = TxnWalEntry::from_bytes(&bytes).unwrap_err();
        assert!(err.contains("truncated at checksum"));
    }

    #[test]
    fn test_wal_entry_error_corrupted_checksum() {
        let mut entry = TxnWalEntry::new_data(
            42,
            1234567890,
            "users".to_string(),
            vec![1, 2, 3],
            Some(vec![4, 5, 6]),
        );
        entry.compute_checksum();

        let mut bytes = entry.to_bytes();
        let len = bytes.len();
        // Flip the last byte, which belongs to the stored checksum.
        bytes[len - 1] ^= 0xFF;

        let result = TxnWalEntry::from_bytes(&bytes);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("checksum mismatch"));
    }

    #[test]
    fn test_wal_entry_error_invalid_utf8_table() {
        let mut entry = TxnWalEntry::new_data(1, 100, "test".to_string(), vec![1], Some(vec![2]));
        entry.compute_checksum();
        let mut bytes = entry.to_bytes();

        // header + key_len + key(1) + value_len + value(1) + table_len
        let table_start = 17 + 4 + 1 + 4 + 1 + 4;
        // 0xFF can never appear in valid UTF-8.
        bytes[table_start] = 0xFF;

        let result = TxnWalEntry::from_bytes(&bytes);
        assert!(result.is_err());
        // The UTF-8 check runs before checksum verification, so it is the reported error.
        assert!(result.unwrap_err().contains("Invalid UTF-8 in table name"));
    }
}