1use std::collections::{BTreeMap, HashMap, HashSet};
37use std::sync::atomic::{AtomicU64, Ordering};
38
/// One step recorded in a transaction history.
///
/// `Read` and `Write` do not name a transaction themselves; the history that
/// stores them attributes each one to a transaction based on the surrounding
/// `Begin` / `Commit` / `Abort` markers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TxnOp {
    /// Start of the transaction identified by `txn_id`.
    Begin { txn_id: u64 },
    /// A read of `key`.
    Read {
        key: Vec<u8>,
        /// Value associated with the read (`None` for "no value").
        /// NOTE(review): nothing in this file inspects this field — confirm
        /// with callers whether it is the observed or the expected value.
        expected: Option<Vec<u8>>,
    },
    /// A write of `value` to `key`.
    Write { key: Vec<u8>, value: Vec<u8> },
    /// Successful end of the transaction identified by `txn_id`.
    Commit { txn_id: u64 },
    /// Rollback of the transaction identified by `txn_id`.
    Abort { txn_id: u64 },
}
60
61#[derive(Debug, Default)]
63pub struct TxnHistory {
64 operations: Vec<TxnOp>,
66 commit_order: Vec<u64>,
68 aborted: HashSet<u64>,
70}
71
72impl TxnHistory {
73 pub fn push(&mut self, op: TxnOp) {
75 match &op {
76 TxnOp::Commit { txn_id } => {
77 self.commit_order.push(*txn_id);
78 }
79 TxnOp::Abort { txn_id } => {
80 self.aborted.insert(*txn_id);
81 }
82 _ => {}
83 }
84 self.operations.push(op);
85 }
86
87 pub fn is_serializable(&self) -> Result<bool, SerializabilityError> {
94 let graph = self.build_dependency_graph()?;
95
96 Ok(!graph.has_cycle())
98 }
99
100 fn build_dependency_graph(&self) -> Result<DependencyGraph, SerializabilityError> {
102 let mut graph = DependencyGraph::new();
103
104 let mut txn_writes: HashMap<u64, HashSet<Vec<u8>>> = HashMap::new();
106 let mut txn_reads: HashMap<u64, HashSet<Vec<u8>>> = HashMap::new();
107 let mut current_txn: Option<u64> = None;
108
109 for op in &self.operations {
110 match op {
111 TxnOp::Begin { txn_id } => {
112 current_txn = Some(*txn_id);
113 graph.add_node(*txn_id);
114 }
115 TxnOp::Read { key, .. } => {
116 if let Some(txn_id) = current_txn {
117 txn_reads.entry(txn_id).or_default().insert(key.clone());
118 }
119 }
120 TxnOp::Write { key, .. } => {
121 if let Some(txn_id) = current_txn {
122 txn_writes.entry(txn_id).or_default().insert(key.clone());
123 }
124 }
125 TxnOp::Commit { .. } | TxnOp::Abort { .. } => {
126 current_txn = None;
127 }
128 }
129 }
130
131 let committed: Vec<_> = self.commit_order.iter().copied().collect();
133 let empty_set: HashSet<Vec<u8>> = HashSet::new();
134 for (i, &t1) in committed.iter().enumerate() {
135 for &t2 in &committed[i + 1..] {
136 let t1_writes = txn_writes.get(&t1).unwrap_or(&empty_set);
137 let t2_writes = txn_writes.get(&t2).unwrap_or(&empty_set);
138 let t1_reads = txn_reads.get(&t1).unwrap_or(&empty_set);
139 let t2_reads = txn_reads.get(&t2).unwrap_or(&empty_set);
140
141 if !t1_writes.is_disjoint(t2_writes) {
143 graph.add_edge(t1, t2);
144 }
145
146 if !t1_writes.is_disjoint(t2_reads) {
148 graph.add_edge(t1, t2);
149 }
150
151 if !t1_reads.is_disjoint(t2_writes) {
153 graph.add_edge(t1, t2);
154 }
155 }
156 }
157
158 Ok(graph)
159 }
160}
161
/// Errors that a serializability check can report.
#[derive(Debug)]
pub enum SerializabilityError {
    /// The history is malformed; the payload is a human-readable reason.
    InvalidHistory(String),
    /// A dependency cycle was found; the payload lists transaction ids.
    /// NOTE(review): nothing in this file constructs this variant — confirm
    /// the intended ordering of the ids with whoever produces it.
    CycleDetected(Vec<u64>),
}

impl std::fmt::Display for SerializabilityError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidHistory(reason) => {
                write!(f, "invalid transaction history: {}", reason)
            }
            Self::CycleDetected(txns) => {
                write!(f, "dependency cycle among transactions {:?}", txns)
            }
        }
    }
}

// Lets the error flow through `?` into `Box<dyn Error>` and similar sinks.
impl std::error::Error for SerializabilityError {}
168
/// Directed graph over transaction ids, used to look for dependency cycles.
#[derive(Debug, Default)]
struct DependencyGraph {
    /// Adjacency sets: `edges[a]` holds every `b` for which `a -> b` exists.
    edges: HashMap<u64, HashSet<u64>>,
    /// Every vertex known to the graph.
    nodes: HashSet<u64>,
}

impl DependencyGraph {
    /// Creates an empty graph.
    fn new() -> Self {
        Self::default()
    }

    /// Registers `node` as a vertex (idempotent).
    fn add_node(&mut self, node: u64) {
        self.nodes.insert(node);
        self.edges.entry(node).or_default();
    }

    /// Adds the edge `from -> to`.
    ///
    /// Both endpoints are registered as vertices, so cycle detection sees the
    /// edge even when the caller never invoked [`Self::add_node`] for them.
    fn add_edge(&mut self, from: u64, to: u64) {
        self.add_node(to);
        self.nodes.insert(from);
        self.edges.entry(from).or_default().insert(to);
    }

    /// Returns `true` when the graph contains a directed cycle (including a
    /// self-loop).
    ///
    /// Depth-first search with an explicit stack, so long dependency chains
    /// cannot overflow the call stack.
    fn has_cycle(&self) -> bool {
        /// Visit state of a vertex; a vertex absent from the map is unvisited.
        enum Mark {
            /// On the current DFS path.
            InProgress,
            /// Fully explored; no cycle reachable through it.
            Finished,
        }

        let mut marks: HashMap<u64, Mark> = HashMap::with_capacity(self.nodes.len());
        let successors = |node: u64| self.edges.get(&node).into_iter().flatten();

        for &root in &self.nodes {
            if marks.contains_key(&root) {
                continue;
            }

            marks.insert(root, Mark::InProgress);
            // Each frame is a vertex on the current path plus its unexplored successors.
            let mut path = vec![(root, successors(root))];

            while let Some(frame) = path.last_mut() {
                let current = frame.0;
                match frame.1.next().copied() {
                    Some(next) => match marks.get(&next) {
                        // Back edge to a vertex on the current path.
                        Some(Mark::InProgress) => return true,
                        Some(Mark::Finished) => {}
                        None => {
                            marks.insert(next, Mark::InProgress);
                            path.push((next, successors(next)));
                        }
                    },
                    None => {
                        marks.insert(current, Mark::Finished);
                        path.pop();
                    }
                }
            }
        }

        false
    }
}
240
/// Named places at which a simulated crash can be injected.
///
/// The meaning of each point is carried by its name; the probes that pass
/// these values to [`CrashSimulator::maybe_crash`] live outside this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrashPoint {
    BeforeWalWrite,
    AfterWalWriteBeforeFsync,
    AfterFsyncBeforeDataWrite,
    AfterDataWrite,
    DuringCheckpoint,
}

/// Decides, probe by probe, when a scheduled crash should fire.
pub struct CrashSimulator {
    /// The only crash point that is currently armed, if any.
    crash_at: Option<CrashPoint>,
    /// Matching probes left before the crash fires; `0` means "spent".
    countdown: AtomicU64,
    /// Crash points that have fired, in firing order.
    triggered: std::sync::Mutex<Vec<CrashPoint>>,
}

impl CrashSimulator {
    /// Creates a simulator with no crash scheduled.
    pub fn new() -> Self {
        Self {
            crash_at: None,
            countdown: AtomicU64::new(u64::MAX),
            triggered: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Arms `point` so that the `after_ops`-th matching probe fires.
    ///
    /// Replaces any previously scheduled crash. With `after_ops == 0` the
    /// crash never fires.
    pub fn schedule_crash(&mut self, point: CrashPoint, after_ops: u64) {
        self.crash_at = Some(point);
        self.countdown.store(after_ops, Ordering::SeqCst);
    }

    /// Probe for `point`; returns `true` exactly once, on the probe that
    /// brings the countdown from 1 to 0.
    ///
    /// Probes for any other point return `false` and do not consume the
    /// countdown.
    pub fn maybe_crash(&self, point: CrashPoint) -> bool {
        if self.crash_at != Some(point) {
            return false;
        }

        // Saturating decrement: once the counter reaches zero it stays there
        // instead of wrapping around to `u64::MAX`.
        let fired = self
            .countdown
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |left| left.checked_sub(1))
            == Ok(1);

        if fired {
            self.triggered
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .push(point);
        }
        fired
    }

    /// Returns a copy of the crash points that have fired so far.
    pub fn triggered_crashes(&self) -> Vec<CrashPoint> {
        self.triggered
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }
}

impl Default for CrashSimulator {
    fn default() -> Self {
        Self::new()
    }
}
308
/// In-memory reference model of a transactional key-value store.
///
/// Writes are staged per transaction and become visible to everyone only when
/// the transaction commits.
#[derive(Debug, Default)]
pub struct KvModel {
    /// Committed key-value pairs.
    data: BTreeMap<Vec<u8>, Vec<u8>>,
    /// Id that the next call to `begin` will hand out.
    next_txn: u64,
    /// Staged (uncommitted) writes of every open transaction.
    active_txns: HashMap<u64, HashMap<Vec<u8>, Vec<u8>>>,
}

impl KvModel {
    /// Creates an empty model with no open transactions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Opens a new transaction and returns its id; ids are handed out
    /// sequentially.
    pub fn begin(&mut self) -> u64 {
        let id = self.next_txn;
        self.next_txn = id + 1;
        self.active_txns.insert(id, HashMap::new());
        id
    }

    /// Reads `key` as seen by `txn_id`: the transaction's own staged write if
    /// there is one, otherwise the committed value.
    ///
    /// An unknown `txn_id` simply sees the committed state.
    pub fn read(&self, txn_id: u64, key: &[u8]) -> Option<Vec<u8>> {
        self.active_txns
            .get(&txn_id)
            .and_then(|staged| staged.get(key))
            .or_else(|| self.data.get(key))
            .cloned()
    }

    /// Stages `value` for `key` inside `txn_id`.
    ///
    /// The write is ignored when `txn_id` is not an open transaction.
    pub fn write(&mut self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) {
        let staged = match self.active_txns.get_mut(&txn_id) {
            Some(staged) => staged,
            None => return,
        };
        staged.insert(key, value);
    }

    /// Publishes every staged write of `txn_id` and closes the transaction.
    ///
    /// Returns `false` (and changes nothing) when `txn_id` is not open.
    pub fn commit(&mut self, txn_id: u64) -> bool {
        match self.active_txns.remove(&txn_id) {
            Some(staged) => {
                self.data.extend(staged);
                true
            }
            None => false,
        }
    }

    /// Discards the staged writes of `txn_id` and closes the transaction.
    ///
    /// Returns `false` when `txn_id` is not open.
    pub fn abort(&mut self, txn_id: u64) -> bool {
        let discarded = self.active_txns.remove(&txn_id);
        discarded.is_some()
    }

    /// Returns a copy of the committed state.
    pub fn snapshot(&self) -> BTreeMap<Vec<u8>, Vec<u8>> {
        self.data.clone()
    }
}
375
376pub struct TestOracle<T> {
378 model: KvModel,
380 sut: T,
382 discrepancies: Vec<Discrepancy>,
384}
385
386#[derive(Debug, Clone)]
388pub struct Discrepancy {
389 pub operation: String,
390 pub expected: Option<Vec<u8>>,
391 pub actual: Option<Vec<u8>>,
392 pub key: Vec<u8>,
393}
394
395impl<T> TestOracle<T> {
396 pub fn new(sut: T) -> Self {
398 Self {
399 model: KvModel::new(),
400 sut,
401 discrepancies: Vec::new(),
402 }
403 }
404
405 pub fn model(&mut self) -> &mut KvModel {
407 &mut self.model
408 }
409
410 pub fn sut(&mut self) -> &mut T {
412 &mut self.sut
413 }
414
415 pub fn record_discrepancy(&mut self, discrepancy: Discrepancy) {
417 self.discrepancies.push(discrepancy);
418 }
419
420 pub fn has_discrepancies(&self) -> bool {
422 !self.discrepancies.is_empty()
423 }
424
425 pub fn discrepancies(&self) -> &[Discrepancy] {
427 &self.discrepancies
428 }
429}
430
/// Checks whether a recorded history of reads and writes against a single
/// register is linearizable.
#[derive(Debug)]
pub struct LinearizabilityChecker {
    /// Recorded operations, in the order they were added.
    history: Vec<LinearOp>,
}

/// One completed operation on the register.
#[derive(Debug, Clone)]
pub struct LinearOp {
    pub op_type: LinearOpType,
    /// Time at which the operation was invoked.
    pub start: u64,
    /// Time at which the operation returned.
    pub end: u64,
    /// For a write: the value written. For a read: the value observed.
    pub value: Option<Vec<u8>>,
}

/// Kind of a [`LinearOp`].
#[derive(Debug, Clone, Copy)]
pub enum LinearOpType {
    Read,
    Write,
}

impl LinearizabilityChecker {
    /// Creates a checker with an empty history.
    pub fn new() -> Self {
        Self {
            history: Vec::new(),
        }
    }

    /// Records `op`; operations may be added in any order.
    pub fn add(&mut self, op: LinearOp) {
        self.history.push(op);
    }

    /// Returns `true` when the history is linearizable.
    ///
    /// The register initially holds `None`. The history is linearizable when
    /// all operations can be put into one total order such that
    ///
    /// * an operation that returned strictly before another one was invoked
    ///   comes first (operations that merely touch, `a.end == b.start`, count
    ///   as concurrent), and
    /// * every read observes the value of the closest preceding write, or
    ///   `None` when no write precedes it.
    ///
    /// The search tries every admissible order and remembers configurations
    /// already proven infeasible. It is exact, but its cost grows
    /// exponentially with the number of mutually concurrent operations, and
    /// it recurses once per operation.
    pub fn is_linearizable(&self) -> bool {
        let mut placed = vec![false; self.history.len()];
        let mut dead_ends = HashSet::new();
        let initial: Option<Vec<u8>> = None;
        self.can_place_rest(&mut placed, self.history.len(), &initial, &mut dead_ends)
    }

    /// Tries to linearize the operations not yet marked in `placed`, starting
    /// from a register that currently holds `register`.
    ///
    /// `dead_ends` memoises (placed set, register value) pairs from which no
    /// valid completion exists.
    fn can_place_rest(
        &self,
        placed: &mut [bool],
        remaining: usize,
        register: &Option<Vec<u8>>,
        dead_ends: &mut HashSet<(Vec<bool>, Option<Vec<u8>>)>,
    ) -> bool {
        if remaining == 0 {
            return true;
        }

        let config = (placed.to_vec(), register.clone());
        if dead_ends.contains(&config) {
            return false;
        }

        // An operation may go next only if no other pending operation already
        // returned before it was invoked. `end` is clamped to `start` so a
        // malformed record cannot block itself.
        let mut earliest_response = u64::MAX;
        for (op, &done) in self.history.iter().zip(placed.iter()) {
            if !done {
                earliest_response = earliest_response.min(op.end.max(op.start));
            }
        }

        for idx in 0..self.history.len() {
            if placed[idx] {
                continue;
            }
            let op = &self.history[idx];
            if op.start > earliest_response {
                continue;
            }

            let next_register = match op.op_type {
                LinearOpType::Write => &op.value,
                LinearOpType::Read if op.value == *register => register,
                // A read that disagrees with the register cannot go here.
                LinearOpType::Read => continue,
            };

            placed[idx] = true;
            let feasible = self.can_place_rest(placed, remaining - 1, next_register, dead_ends);
            placed[idx] = false;
            if feasible {
                return true;
            }
        }

        dead_ends.insert(config);
        false
    }
}

impl Default for LinearizabilityChecker {
    fn default() -> Self {
        Self::new()
    }
}
515
#[cfg(test)]
mod tests {
    use super::*;

    /// Two committed transactions that touch different keys never conflict.
    #[test]
    fn test_serializable_history_no_conflict() {
        let mut history = TxnHistory::default();

        for &(txn_id, key, value) in [(1u64, "x", "1"), (2u64, "y", "2")].iter() {
            history.push(TxnOp::Begin { txn_id });
            history.push(TxnOp::Write {
                key: key.as_bytes().to_vec(),
                value: value.as_bytes().to_vec(),
            });
            history.push(TxnOp::Commit { txn_id });
        }

        assert!(history.is_serializable().unwrap());
    }

    /// A transaction sees its own write, and later transactions see it after commit.
    #[test]
    fn test_kv_model_basic() {
        let key = b"key1".to_vec();
        let value = b"value1".to_vec();
        let mut model = KvModel::new();

        let writer = model.begin();
        model.write(writer, key.clone(), value.clone());
        assert_eq!(model.read(writer, &key), Some(value.clone()));

        assert!(model.commit(writer));

        let reader = model.begin();
        assert_eq!(model.read(reader, &key), Some(value));
    }

    /// A crash scheduled after three probes fires on the third probe only.
    #[test]
    fn test_crash_simulator() {
        let target = CrashPoint::AfterWalWriteBeforeFsync;
        let mut sim = CrashSimulator::new();
        sim.schedule_crash(target, 3);

        let outcomes: Vec<bool> = (0..4).map(|_| sim.maybe_crash(target)).collect();

        assert_eq!(outcomes, [false, false, true, false]);
        assert_eq!(sim.triggered_crashes().len(), 1);
    }

    /// A read that starts after a write finished and returns its value is fine.
    #[test]
    fn test_linearizability_simple() {
        let observed = Some(b"1".to_vec());
        let write = LinearOp {
            op_type: LinearOpType::Write,
            start: 0,
            end: 2,
            value: observed.clone(),
        };
        let read = LinearOp {
            op_type: LinearOpType::Read,
            start: 3,
            end: 4,
            value: observed,
        };

        let mut checker = LinearizabilityChecker::new();
        checker.add(write);
        checker.add(read);

        assert!(checker.is_linearizable());
    }

    /// A chain has no cycle until the closing edge is added.
    #[test]
    fn test_dependency_graph_cycle_detection() {
        let mut graph = DependencyGraph::new();
        for node in 1..=3 {
            graph.add_node(node);
        }

        for &(from, to) in [(1u64, 2u64), (2, 3)].iter() {
            graph.add_edge(from, to);
        }
        assert!(!graph.has_cycle());

        graph.add_edge(3, 1);
        assert!(graph.has_cycle());
    }
}
619
#[cfg(test)]
mod property_tests {
    use super::*;
    use crate::wal_segment::{SegmentConfig, WalSegmentManager};
    use proptest::prelude::*;
    use tempfile::tempdir;

    /// Keys drawn from a tiny alphabet so generated writes frequently collide.
    fn key_strategy() -> impl Strategy<Value = Vec<u8>> {
        let alphabet: Vec<Vec<u8>> = ["a", "b", "c", "d"]
            .iter()
            .map(|key| key.as_bytes().to_vec())
            .collect();
        prop::sample::select(alphabet)
    }

    /// Up to seven `(key, value)` writes with values of at most three bytes.
    fn writes_strategy() -> impl Strategy<Value = Vec<(Vec<u8>, Vec<u8>)>> {
        let value = prop::collection::vec(any::<u8>(), 0..4);
        prop::collection::vec((key_strategy(), value), 0..8)
    }

    proptest! {
        /// After a commit, the snapshot equals the last write per key and a new
        /// transaction reads exactly those values.
        #[test]
        fn prop_commit_is_all_or_nothing(writes in writes_strategy()) {
            let mut model = KvModel::new();
            let writer = model.begin();
            for (key, value) in &writes {
                model.write(writer, key.clone(), value.clone());
            }
            prop_assert!(model.commit(writer));

            // Later writes to the same key replace earlier ones.
            let expected: std::collections::BTreeMap<Vec<u8>, Vec<u8>> =
                writes.iter().cloned().collect();
            prop_assert_eq!(model.snapshot(), expected.clone());

            let reader = model.begin();
            for (key, value) in &expected {
                prop_assert_eq!(model.read(reader, key), Some(value.clone()));
            }
        }

        /// An aborted transaction changes neither the snapshot nor later reads.
        #[test]
        fn prop_abort_leaves_no_trace(writes in writes_strategy()) {
            let mut model = KvModel::new();
            let before = model.snapshot();

            let doomed = model.begin();
            for (key, value) in &writes {
                model.write(doomed, key.clone(), value.clone());
            }
            prop_assert!(model.abort(doomed));

            prop_assert_eq!(model.snapshot(), before);

            let reader = model.begin();
            for (key, _) in &writes {
                prop_assert_eq!(model.read(reader, key), None);
            }
        }

        /// Uncommitted writes are visible to the writer only.
        #[test]
        fn prop_no_dirty_reads(k in key_strategy(), v in prop::collection::vec(any::<u8>(), 1..4)) {
            let mut model = KvModel::new();
            let observer = model.begin();
            let writer = model.begin();

            model.write(writer, k.clone(), v.clone());

            prop_assert_eq!(model.read(writer, &k), Some(v.clone()));
            prop_assert_eq!(model.read(observer, &k), None);

            prop_assert!(model.commit(writer));
            let later = model.begin();
            prop_assert_eq!(model.read(later, &k), Some(v));
        }

        /// The scheduled crash fires exactly once, on probe number `after_ops`,
        /// and probes for other crash points never fire.
        #[test]
        fn prop_crash_injection_is_exact(after_ops in 1u64..12, probes in 1u64..20) {
            let mut sim = CrashSimulator::new();
            sim.schedule_crash(CrashPoint::AfterWalWriteBeforeFsync, after_ops);

            let mut firings: Vec<u64> = Vec::new();
            for probe in 1..=probes {
                prop_assert!(!sim.maybe_crash(CrashPoint::BeforeWalWrite));
                if sim.maybe_crash(CrashPoint::AfterWalWriteBeforeFsync) {
                    firings.push(probe);
                }
            }

            if probes >= after_ops {
                prop_assert_eq!(firings, vec![after_ops]);
                prop_assert_eq!(sim.triggered_crashes().len(), 1);
            } else {
                prop_assert!(firings.is_empty());
                prop_assert!(sim.triggered_crashes().is_empty());
            }
        }
    }

    /// Between 1 and 39 WAL payloads of 1 to 47 bytes each.
    fn wal_payloads_strategy() -> impl Strategy<Value = Vec<Vec<u8>>> {
        let payload = prop::collection::vec(any::<u8>(), 1..48);
        prop::collection::vec(payload, 1..40)
    }

    proptest! {
        // Each case touches the filesystem, so run fewer cases than the default.
        #![proptest_config(ProptestConfig { cases: 24, ..ProptestConfig::default() })]

        /// Everything appended before shutdown is recovered, in order, by a
        /// manager reopened on the same directory.
        #[test]
        fn prop_wal_recovers_all_entries_in_order(payloads in wal_payloads_strategy()) {
            let dir = tempdir().unwrap();
            // NOTE(review): the small max size presumably forces several
            // segments — confirm against `SegmentConfig`.
            let config = SegmentConfig::default()
                .with_wal_dir(dir.path())
                .with_max_size(256);

            // Phase 1: append every payload, then shut down cleanly.
            {
                let writer = WalSegmentManager::new(config.clone()).unwrap();
                for (index, payload) in payloads.iter().enumerate() {
                    let lsn = writer.append(payload).unwrap();
                    prop_assert_eq!(lsn, index as u64);
                }
                writer.shutdown().unwrap();
            }

            // Phase 2: reopen and replay from LSN 0.
            {
                let reopened = WalSegmentManager::new(config).unwrap();
                let mut replay = reopened.recovery_iterator(0);
                let mut recovered: Vec<Vec<u8>> = Vec::new();
                while let Some(entry) = replay.next_entry().unwrap() {
                    recovered.push(entry.data);
                }
                prop_assert_eq!(recovered, payloads);
            }
        }
    }
}