1use std::collections::{BTreeMap, HashMap, HashSet};
37use std::sync::atomic::{AtomicU64, Ordering};
38
/// A single event in a recorded transaction history.
///
/// `Read` and `Write` carry no transaction id: they are attributed to the
/// innermost transaction that has begun and not yet committed or aborted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TxnOp {
    /// Start of transaction `txn_id`.
    Begin { txn_id: u64 },
    /// A read of `key`. `expected` is recorded but not consulted by
    /// [`TxnHistory::is_serializable`].
    Read { key: Vec<u8>, expected: Option<Vec<u8>> },
    /// A write of `value` to `key`.
    Write { key: Vec<u8>, value: Vec<u8> },
    /// Commit of transaction `txn_id`.
    Commit { txn_id: u64 },
    /// Abort of transaction `txn_id`; its reads and writes never create conflicts.
    Abort { txn_id: u64 },
}

/// An append-only log of transaction operations that can be checked for
/// conflict serializability.
#[derive(Debug, Default)]
pub struct TxnHistory {
    /// Every recorded operation, in the order it was pushed.
    operations: Vec<TxnOp>,
    /// Transaction ids in the order their `Commit` was pushed.
    commit_order: Vec<u64>,
    /// Ids of transactions that pushed an `Abort`. Not consulted by the
    /// serializability check, which only considers `commit_order`.
    aborted: HashSet<u64>,
}

impl TxnHistory {
    /// Appends `op` to the history, also recording commits and aborts.
    pub fn push(&mut self, op: TxnOp) {
        match &op {
            TxnOp::Commit { txn_id } => self.commit_order.push(*txn_id),
            TxnOp::Abort { txn_id } => {
                self.aborted.insert(*txn_id);
            }
            _ => {}
        }
        self.operations.push(op);
    }

    /// Reports whether the committed transactions in this history are
    /// conflict serializable.
    ///
    /// A precedence graph gets an edge `Ti -> Tj` whenever an operation of
    /// `Ti` precedes, in history order, a conflicting operation of `Tj` on the
    /// same key (at least one of the two is a write) and both transactions
    /// committed. The history is serializable iff that graph is acyclic.
    ///
    /// Reads and writes are attributed to the innermost open transaction;
    /// operations issued while no transaction is open are ignored.
    ///
    /// # Errors
    /// The current implementation never returns `Err`.
    pub fn is_serializable(&self) -> Result<bool, SerializabilityError> {
        let graph = self.build_dependency_graph()?;
        Ok(!graph.has_cycle())
    }

    /// Builds the conflict (precedence) graph described in `is_serializable`.
    fn build_dependency_graph(&self) -> Result<DependencyGraph, SerializabilityError> {
        let committed: HashSet<u64> = self.commit_order.iter().copied().collect();
        let mut graph = DependencyGraph::new();

        // Transactions that have begun but not finished, innermost last. A
        // stack (rather than a single slot) lets an outer transaction resume
        // after an inner one commits, which is how interleaving is expressed.
        let mut active: Vec<u64> = Vec::new();
        // Per key: committed transactions that have (read, written) it so far.
        let mut accessors: HashMap<&[u8], (HashSet<u64>, HashSet<u64>)> = HashMap::new();

        for op in &self.operations {
            let (key, is_write) = match op {
                TxnOp::Begin { txn_id } => {
                    graph.add_node(*txn_id);
                    active.push(*txn_id);
                    continue;
                }
                TxnOp::Commit { txn_id } | TxnOp::Abort { txn_id } => {
                    active.retain(|t| t != txn_id);
                    continue;
                }
                TxnOp::Read { key, .. } => (key, false),
                TxnOp::Write { key, .. } => (key, true),
            };

            // Only committed transactions participate in the conflict graph.
            let txn = match active.last() {
                Some(&t) if committed.contains(&t) => t,
                _ => continue,
            };

            let (readers, writers) = accessors.entry(key.as_slice()).or_default();
            // Every operation conflicts with earlier writes of the key...
            for &earlier in writers.iter() {
                if earlier != txn {
                    graph.add_edge(earlier, txn);
                }
            }
            if is_write {
                // ...and a write also conflicts with earlier reads.
                for &earlier in readers.iter() {
                    if earlier != txn {
                        graph.add_edge(earlier, txn);
                    }
                }
                writers.insert(txn);
            } else {
                readers.insert(txn);
            }
        }

        Ok(graph)
    }
}

/// Errors that [`TxnHistory::is_serializable`] may report.
#[derive(Debug)]
pub enum SerializabilityError {
    /// The history is malformed; the string describes why.
    InvalidHistory(String),
    /// A dependency cycle among the listed transaction ids.
    CycleDetected(Vec<u64>),
}

/// Directed graph over transaction ids, used for cycle detection.
#[derive(Debug, Default)]
struct DependencyGraph {
    /// Adjacency sets: `edges[a]` contains every `b` with an edge `a -> b`.
    edges: HashMap<u64, HashSet<u64>>,
    /// Every vertex, including those without edges.
    nodes: HashSet<u64>,
}

impl DependencyGraph {
    /// Creates an empty graph.
    fn new() -> Self {
        Self::default()
    }

    /// Adds `node` as a vertex (idempotent).
    fn add_node(&mut self, node: u64) {
        self.nodes.insert(node);
        self.edges.entry(node).or_default();
    }

    /// Adds the edge `from -> to`, registering both endpoints as vertices so
    /// that `has_cycle` visits them.
    fn add_edge(&mut self, from: u64, to: u64) {
        self.add_node(from);
        self.add_node(to);
        self.edges.entry(from).or_default().insert(to);
    }

    /// Returns `true` if the graph contains a directed cycle (self-loops included).
    ///
    /// Uses an explicit-stack three-colour DFS, so long dependency chains
    /// cannot overflow the call stack.
    fn has_cycle(&self) -> bool {
        #[derive(Clone, Copy, PartialEq, Eq)]
        enum Color {
            White,
            Gray,
            Black,
        }

        let mut color: HashMap<u64, Color> =
            self.nodes.iter().map(|&n| (n, Color::White)).collect();
        let no_neighbors: HashSet<u64> = HashSet::new();

        for &root in &self.nodes {
            if color.get(&root) != Some(&Color::White) {
                continue;
            }
            color.insert(root, Color::Gray);
            // Current DFS path: each vertex with its not-yet-visited successors.
            let mut path = vec![(root, self.edges.get(&root).unwrap_or(&no_neighbors).iter())];

            while let Some(top) = path.last_mut() {
                let node = top.0;
                match top.1.next() {
                    Some(&next) => match color.get(&next).copied().unwrap_or(Color::White) {
                        // A successor still on the path closes a cycle.
                        Color::Gray => return true,
                        Color::White => {
                            color.insert(next, Color::Gray);
                            path.push((
                                next,
                                self.edges.get(&next).unwrap_or(&no_neighbors).iter(),
                            ));
                        }
                        Color::Black => {}
                    },
                    None => {
                        color.insert(node, Color::Black);
                        path.pop();
                    }
                }
            }
        }

        false
    }
}
225
/// Named injection points that callers pass to [`CrashSimulator::maybe_crash`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrashPoint {
    BeforeWalWrite,
    AfterWalWriteBeforeFsync,
    AfterFsyncBeforeDataWrite,
    AfterDataWrite,
    DuringCheckpoint,
}

/// Fires a single simulated crash on the N-th time a chosen [`CrashPoint`] is reached.
pub struct CrashSimulator {
    /// The only point that can fire; `None` means nothing is scheduled.
    crash_at: Option<CrashPoint>,
    /// Remaining hits of `crash_at` before the crash fires; never drops below zero.
    countdown: AtomicU64,
    /// Every crash that has fired, in firing order.
    triggered: std::sync::Mutex<Vec<CrashPoint>>,
}

impl CrashSimulator {
    /// Creates a simulator with no crash scheduled.
    pub fn new() -> Self {
        Self {
            crash_at: None,
            countdown: AtomicU64::new(u64::MAX),
            triggered: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Schedules a crash at `point` on its `after_ops`-th hit (1-based),
    /// replacing any earlier schedule. `after_ops == 0` arms nothing.
    /// Crashes that already fired stay in [`Self::triggered_crashes`].
    pub fn schedule_crash(&mut self, point: CrashPoint, after_ops: u64) {
        self.crash_at = Some(point);
        self.countdown.store(after_ops, Ordering::SeqCst);
    }

    /// Records one hit of `point`; returns `true` exactly when this hit is the
    /// scheduled crash. Hits of any other point are ignored.
    pub fn maybe_crash(&self, point: CrashPoint) -> bool {
        if self.crash_at != Some(point) {
            return false;
        }

        // `checked_sub` leaves the counter at zero once it is exhausted. A plain
        // `fetch_sub` would wrap to `u64::MAX` after the crash fired, leaving
        // the counter meaningless and eventually re-arming the crash.
        let fires = self
            .countdown
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |left| left.checked_sub(1))
            == Ok(1);

        if fires {
            self.crash_log().push(point);
        }
        fires
    }

    /// Returns a copy of every crash that has fired, in firing order.
    pub fn triggered_crashes(&self) -> Vec<CrashPoint> {
        self.crash_log().clone()
    }

    /// Locks the crash log, recovering from poisoning. The guard is only ever
    /// held for a single `push` or `clone`, so a poisoned lock cannot hide a
    /// partially updated log.
    fn crash_log(&self) -> std::sync::MutexGuard<'_, Vec<CrashPoint>> {
        self.triggered
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for CrashSimulator {
    fn default() -> Self {
        Self::new()
    }
}
293
/// In-memory reference model of a transactional key-value store.
///
/// Writes are buffered per transaction and only reach the shared data on
/// `commit`. A read prefers the transaction's own buffered write and falls
/// back to committed data (also for ids that are not open).
#[derive(Debug, Default)]
pub struct KvModel {
    /// Committed key/value pairs, ordered by key.
    data: BTreeMap<Vec<u8>, Vec<u8>>,
    /// Id that the next `begin` call hands out.
    next_txn: u64,
    /// Buffered writes of every open transaction, keyed by transaction id.
    active_txns: HashMap<u64, HashMap<Vec<u8>, Vec<u8>>>,
}

impl KvModel {
    /// Creates an empty model whose first transaction id is 0.
    pub fn new() -> Self {
        Self::default()
    }

    /// Opens a new transaction with an empty write buffer and returns its id.
    pub fn begin(&mut self) -> u64 {
        let id = self.next_txn;
        self.next_txn = id + 1;
        self.active_txns.insert(id, HashMap::new());
        id
    }

    /// Returns the value of `key` as seen by `txn_id`: its own buffered write
    /// if any, otherwise the committed value.
    pub fn read(&self, txn_id: u64, key: &[u8]) -> Option<Vec<u8>> {
        self.active_txns
            .get(&txn_id)
            .and_then(|pending| pending.get(key))
            .or_else(|| self.data.get(key))
            .cloned()
    }

    /// Buffers a write in `txn_id`; silently does nothing if it is not open.
    pub fn write(&mut self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) {
        if let Some(pending) = self.active_txns.get_mut(&txn_id) {
            pending.insert(key, value);
        }
    }

    /// Applies `txn_id`'s buffered writes to committed data and closes it.
    /// Returns `false` if the transaction was not open.
    pub fn commit(&mut self, txn_id: u64) -> bool {
        match self.active_txns.remove(&txn_id) {
            Some(pending) => {
                self.data.extend(pending);
                true
            }
            None => false,
        }
    }

    /// Discards `txn_id`'s buffered writes. Returns `false` if it was not open.
    pub fn abort(&mut self, txn_id: u64) -> bool {
        matches!(self.active_txns.remove(&txn_id), Some(_))
    }

    /// Returns a copy of the committed data.
    pub fn snapshot(&self) -> BTreeMap<Vec<u8>, Vec<u8>> {
        self.data.clone()
    }
}
360
361pub struct TestOracle<T> {
363 model: KvModel,
365 sut: T,
367 discrepancies: Vec<Discrepancy>,
369}
370
371#[derive(Debug, Clone)]
373pub struct Discrepancy {
374 pub operation: String,
375 pub expected: Option<Vec<u8>>,
376 pub actual: Option<Vec<u8>>,
377 pub key: Vec<u8>,
378}
379
380impl<T> TestOracle<T> {
381 pub fn new(sut: T) -> Self {
383 Self {
384 model: KvModel::new(),
385 sut,
386 discrepancies: Vec::new(),
387 }
388 }
389
390 pub fn model(&mut self) -> &mut KvModel {
392 &mut self.model
393 }
394
395 pub fn sut(&mut self) -> &mut T {
397 &mut self.sut
398 }
399
400 pub fn record_discrepancy(&mut self, discrepancy: Discrepancy) {
402 self.discrepancies.push(discrepancy);
403 }
404
405 pub fn has_discrepancies(&self) -> bool {
407 !self.discrepancies.is_empty()
408 }
409
410 pub fn discrepancies(&self) -> &[Discrepancy] {
412 &self.discrepancies
413 }
414}
415
/// Checks whether a history of reads and writes on a single register is
/// linearizable.
#[derive(Debug)]
pub struct LinearizabilityChecker {
    /// Operations in the order they were added.
    history: Vec<LinearOp>,
}

/// One completed operation on the register, with its real-time interval.
#[derive(Debug, Clone)]
pub struct LinearOp {
    /// Whether the operation read or wrote the register.
    pub op_type: LinearOpType,
    /// Invocation time. Intervals are closed: an operation ending at `t` is
    /// concurrent with one starting at `t`.
    pub start: u64,
    /// Response time; must be `>= start`.
    pub end: u64,
    /// For a write, the value written; for a read, the value returned.
    /// `None` is the register's initial (unset) value.
    pub value: Option<Vec<u8>>,
}

/// Kind of a [`LinearOp`].
#[derive(Debug, Clone, Copy)]
pub enum LinearOpType {
    Read,
    Write,
}

impl LinearizabilityChecker {
    /// Creates a checker with an empty history.
    pub fn new() -> Self {
        Self { history: Vec::new() }
    }

    /// Appends a completed operation to the history.
    pub fn add(&mut self, op: LinearOp) {
        self.history.push(op);
    }

    /// Returns `true` if some total order of the recorded operations both
    /// respects real time and is a legal sequential register history.
    ///
    /// Respecting real time means that an op which ends before another starts
    /// comes first. A legal history is one where every read returns the latest
    /// written value, or `None` before any write.
    ///
    /// This is an exact backtracking search over linearization points (in the
    /// style of Wing & Gong) with memoization of explored states. Its cost grows
    /// with how many operations are mutually concurrent and is close to linear
    /// for histories with little overlap. A history containing an op with
    /// `end < start` is rejected as malformed.
    pub fn is_linearizable(&self) -> bool {
        // One level of the depth-first search over partial linearizations.
        struct Frame {
            // Index (into the start-sorted ops) of the first op not yet linearized.
            first_pending: usize,
            // Register value after the ops linearized so far.
            value: Option<Vec<u8>>,
            // Next index to try as the following linearization point.
            cursor: usize,
            // Exclusive upper bound of indices allowed to go next.
            limit: usize,
        }

        // End of the window of ops that start no later than `ops[first]` ends.
        // Every op linearized while `first` is pending lies inside it.
        fn window_end(ops: &[&LinearOp], first: usize) -> usize {
            let bound = ops[first].end;
            ops.partition_point(|op| op.start <= bound)
        }

        // Exclusive end of the index range that may be linearized next. An op
        // may go next only if no pending op finished strictly before it began.
        // Ops outside the window end after `ops[first]`, so they never hold
        // the earliest pending end.
        fn candidate_limit(ops: &[&LinearOp], linearized: &[bool], first: usize) -> usize {
            let earliest_end = (first..window_end(ops, first))
                .filter(|&j| !linearized[j])
                .map(|j| ops[j].end)
                .min()
                .unwrap_or(ops[first].end);
            ops.partition_point(|op| op.start <= earliest_end)
        }

        // Linearized indices after `first`; together with `first` they identify
        // the full set of linearized ops compactly.
        fn linearized_after(ops: &[&LinearOp], linearized: &[bool], first: usize) -> Vec<usize> {
            (first + 1..window_end(ops, first))
                .filter(|&j| linearized[j])
                .collect()
        }

        let mut ops: Vec<&LinearOp> = self.history.iter().collect();
        if ops.iter().any(|op| op.end < op.start) {
            return false;
        }
        ops.sort_by_key(|op| op.start);
        let total = ops.len();
        if total == 0 {
            return true;
        }

        let mut linearized = vec![false; total];
        // States already entered. Revisiting one means it was fully explored
        // without success, because a success returns immediately.
        let mut visited: HashSet<(usize, Vec<usize>, Option<Vec<u8>>)> = HashSet::new();
        // The op linearized to enter each frame above the root.
        let mut chosen: Vec<usize> = Vec::new();
        let mut stack = vec![Frame {
            first_pending: 0,
            value: None,
            cursor: 0,
            limit: candidate_limit(&ops, &linearized, 0),
        }];

        while let Some(frame) = stack.last_mut() {
            // Find the next op that may legally be linearized from this state.
            let mut step = None;
            while frame.cursor < frame.limit {
                let i = frame.cursor;
                frame.cursor += 1;
                if linearized[i] {
                    continue;
                }
                let op = ops[i];
                match op.op_type {
                    LinearOpType::Write => {
                        step = Some((i, op.value.clone()));
                        break;
                    }
                    LinearOpType::Read if op.value == frame.value => {
                        step = Some((i, frame.value.clone()));
                        break;
                    }
                    LinearOpType::Read => {}
                }
            }
            let first_pending = frame.first_pending;

            match step {
                None => {
                    // Dead end: undo the op that led here and backtrack.
                    stack.pop();
                    if let Some(i) = chosen.pop() {
                        linearized[i] = false;
                    }
                }
                Some((i, value)) => {
                    linearized[i] = true;
                    let mut first = first_pending;
                    while first < total && linearized[first] {
                        first += 1;
                    }
                    if first == total {
                        return true;
                    }
                    let state = (first, linearized_after(&ops, &linearized, first), value.clone());
                    if visited.insert(state) {
                        chosen.push(i);
                        let limit = candidate_limit(&ops, &linearized, first);
                        stack.push(Frame {
                            first_pending: first,
                            value,
                            cursor: first,
                            limit,
                        });
                    } else {
                        linearized[i] = false;
                    }
                }
            }
        }

        false
    }
}

impl Default for LinearizabilityChecker {
    fn default() -> Self {
        Self::new()
    }
}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns a string literal into an owned byte vector for keys and values.
    fn bytes(s: &str) -> Vec<u8> {
        s.as_bytes().to_vec()
    }

    #[test]
    fn test_serializable_history_no_conflict() {
        // Two transactions that touch disjoint keys cannot conflict.
        let script = vec![
            TxnOp::Begin { txn_id: 1 },
            TxnOp::Write { key: bytes("x"), value: bytes("1") },
            TxnOp::Commit { txn_id: 1 },
            TxnOp::Begin { txn_id: 2 },
            TxnOp::Write { key: bytes("y"), value: bytes("2") },
            TxnOp::Commit { txn_id: 2 },
        ];
        let mut history = TxnHistory::default();
        for op in script {
            history.push(op);
        }

        assert!(history.is_serializable().unwrap());
    }

    #[test]
    fn test_kv_model_basic() {
        let mut model = KvModel::new();

        let writer = model.begin();
        model.write(writer, bytes("key1"), bytes("value1"));
        // The writer sees its own uncommitted write...
        assert_eq!(model.read(writer, b"key1"), Some(bytes("value1")));
        assert!(model.commit(writer));

        // ...and later transactions see it once committed.
        let reader = model.begin();
        assert_eq!(model.read(reader, b"key1"), Some(bytes("value1")));
    }

    #[test]
    fn test_crash_simulator() {
        let point = CrashPoint::AfterWalWriteBeforeFsync;
        let mut sim = CrashSimulator::new();
        sim.schedule_crash(point, 3);

        // Only the third hit fires; the fourth does not fire again.
        let outcomes: Vec<bool> = (0..4).map(|_| sim.maybe_crash(point)).collect();
        assert_eq!(outcomes, vec![false, false, true, false]);

        assert_eq!(sim.triggered_crashes().len(), 1);
    }

    #[test]
    fn test_linearizability_simple() {
        let one = Some(bytes("1"));
        let mut checker = LinearizabilityChecker::new();

        // A write that completes before a read begins must be visible to it.
        checker.add(LinearOp {
            op_type: LinearOpType::Write,
            start: 0,
            end: 2,
            value: one.clone(),
        });
        checker.add(LinearOp {
            op_type: LinearOpType::Read,
            start: 3,
            end: 4,
            value: one,
        });

        assert!(checker.is_linearizable());
    }

    #[test]
    fn test_dependency_graph_cycle_detection() {
        let mut graph = DependencyGraph::new();
        for node in 1..=3 {
            graph.add_node(node);
        }

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        // A simple chain is acyclic...
        assert!(!graph.has_cycle());

        // ...until the back edge closes the loop.
        graph.add_edge(3, 1);
        assert!(graph.has_cycle());
    }
}