1use crate::error::{KernelError, KernelResult, TransactionErrorKind};
24use crate::wal::LogSequenceNumber;
25use parking_lot::{Mutex, RwLock};
26use std::collections::HashMap;
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::time::{Duration, Instant};
29
/// Identifier handed out to each transaction by [`TxnManager::begin`] /
/// [`TxnManager::begin_with_isolation`].
pub type TransactionId = u64;
32
33pub type Timestamp = u64;
35
/// Isolation level requested when a transaction begins.
///
/// [`IsolationLevel::SnapshotIsolation`] is the default. Within this file's
/// transaction manager only [`IsolationLevel::Serializable`] changes behavior
/// (read tracking and the commit-time validation hook); the other levels are
/// recorded on the transaction but not otherwise distinguished.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IsolationLevel {
    /// Read-uncommitted isolation.
    ReadUncommitted,
    /// Read-committed isolation.
    ReadCommitted,
    /// Repeatable-read isolation.
    RepeatableRead,
    /// Snapshot isolation; the default level.
    #[default]
    SnapshotIsolation,
    /// Serializable isolation.
    Serializable,
}
51
/// Lifecycle state of a transaction tracked by the transaction manager.
///
/// `Active` and `Preparing` are the states from which a transaction may still
/// be committed or aborted; `Committed` and `Aborted` are final.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransactionState {
    /// Started and not yet finished.
    Active,
    /// In a preparing phase; still eligible to commit or abort.
    Preparing,
    /// Successfully committed.
    Committed,
    /// Rolled back.
    Aborted,
}
64
65#[derive(Debug)]
67struct TransactionInfo {
68 id: TransactionId,
70 state: TransactionState,
72 snapshot_ts: Timestamp,
74 commit_ts: Option<Timestamp>,
76 isolation: IsolationLevel,
78 start_time: Instant,
80 last_lsn: Option<LogSequenceNumber>,
82 read_set: Vec<(u32, u64)>, write_set: Vec<(u32, u64)>, }
87
88pub struct TxnManager {
92 next_txn_id: AtomicU64,
94 current_ts: AtomicU64,
96 active_txns: RwLock<HashMap<TransactionId, TransactionInfo>>,
98 timeout: Duration,
100 commit_lock: Mutex<()>,
102}
103
104impl Default for TxnManager {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl TxnManager {
111 pub fn new() -> Self {
113 Self::with_timeout(Duration::from_secs(60))
114 }
115
116 pub fn with_timeout(timeout: Duration) -> Self {
118 Self {
119 next_txn_id: AtomicU64::new(1),
120 current_ts: AtomicU64::new(1),
121 active_txns: RwLock::new(HashMap::new()),
122 timeout,
123 commit_lock: Mutex::new(()),
124 }
125 }
126
127 pub fn begin(&self) -> TransactionId {
129 self.begin_with_isolation(IsolationLevel::default())
130 }
131
132 pub fn begin_with_isolation(&self, isolation: IsolationLevel) -> TransactionId {
134 let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
135 let snapshot_ts = self.current_ts.load(Ordering::SeqCst);
136
137 let info = TransactionInfo {
138 id: txn_id,
139 state: TransactionState::Active,
140 snapshot_ts,
141 commit_ts: None,
142 isolation,
143 start_time: Instant::now(),
144 last_lsn: None,
145 read_set: Vec::new(),
146 write_set: Vec::new(),
147 };
148
149 self.active_txns.write().insert(txn_id, info);
150 txn_id
151 }
152
153 pub fn commit(&self, txn_id: TransactionId) -> KernelResult<Timestamp> {
155 let _guard = self.commit_lock.lock();
157
158 let mut txns = self.active_txns.write();
159
160 let (current_state, isolation, read_set, write_set) = {
162 let info = txns.get(&txn_id).ok_or(KernelError::Transaction {
163 kind: TransactionErrorKind::NotFound(txn_id),
164 })?;
165 (
166 info.state,
167 info.isolation,
168 info.read_set.clone(),
169 info.write_set.clone(),
170 )
171 };
172
173 match current_state {
174 TransactionState::Active | TransactionState::Preparing => {
175 if isolation == IsolationLevel::Serializable {
177 self.check_serialization_conflicts_cloned(&read_set, &write_set)?;
178 }
179
180 let info = txns.get_mut(&txn_id).unwrap();
182
183 let commit_ts = self.current_ts.fetch_add(1, Ordering::SeqCst);
185 info.commit_ts = Some(commit_ts);
186 info.state = TransactionState::Committed;
187
188 Ok(commit_ts)
189 }
190 TransactionState::Committed => Err(KernelError::Transaction {
191 kind: TransactionErrorKind::AlreadyCommitted,
192 }),
193 TransactionState::Aborted => Err(KernelError::Transaction {
194 kind: TransactionErrorKind::AlreadyAborted,
195 }),
196 }
197 }
198
199 pub fn abort(&self, txn_id: TransactionId) -> KernelResult<()> {
201 let mut txns = self.active_txns.write();
202 let info = txns.get_mut(&txn_id).ok_or(KernelError::Transaction {
203 kind: TransactionErrorKind::NotFound(txn_id),
204 })?;
205
206 match info.state {
207 TransactionState::Active | TransactionState::Preparing => {
208 info.state = TransactionState::Aborted;
209 Ok(())
210 }
211 TransactionState::Committed => Err(KernelError::Transaction {
212 kind: TransactionErrorKind::AlreadyCommitted,
213 }),
214 TransactionState::Aborted => Ok(()), }
216 }
217
218 pub fn is_active(&self, txn_id: TransactionId) -> bool {
220 self.active_txns
221 .read()
222 .get(&txn_id)
223 .map(|info| info.state == TransactionState::Active)
224 .unwrap_or(false)
225 }
226
227 pub fn snapshot_ts(&self, txn_id: TransactionId) -> KernelResult<Timestamp> {
229 self.active_txns
230 .read()
231 .get(&txn_id)
232 .map(|info| info.snapshot_ts)
233 .ok_or(KernelError::Transaction {
234 kind: TransactionErrorKind::NotFound(txn_id),
235 })
236 }
237
238 pub fn record_read(&self, txn_id: TransactionId, table_id: u32, row_id: u64) {
240 if let Some(info) = self.active_txns.write().get_mut(&txn_id)
241 && info.isolation == IsolationLevel::Serializable
242 {
243 info.read_set.push((table_id, row_id));
244 }
245 }
246
247 pub fn record_write(&self, txn_id: TransactionId, table_id: u32, row_id: u64) {
249 if let Some(info) = self.active_txns.write().get_mut(&txn_id) {
250 info.write_set.push((table_id, row_id));
251 }
252 }
253
254 pub fn set_last_lsn(&self, txn_id: TransactionId, lsn: LogSequenceNumber) {
256 if let Some(info) = self.active_txns.write().get_mut(&txn_id) {
257 info.last_lsn = Some(lsn);
258 }
259 }
260
261 pub fn min_active_snapshot(&self) -> Option<Timestamp> {
263 self.active_txns
264 .read()
265 .values()
266 .filter(|info| info.state == TransactionState::Active)
267 .map(|info| info.snapshot_ts)
268 .min()
269 }
270
271 pub fn active_count(&self) -> usize {
273 self.active_txns
274 .read()
275 .values()
276 .filter(|info| info.state == TransactionState::Active)
277 .count()
278 }
279
280 pub fn cleanup(&self, retention: Duration) {
282 let now = Instant::now();
283 self.active_txns.write().retain(|_, info| {
284 if info.state == TransactionState::Active {
286 return true;
287 }
288 now.duration_since(info.start_time) < retention
290 });
291 }
292
293 pub fn check_timeouts(&self) -> Vec<TransactionId> {
295 let now = Instant::now();
296 self.active_txns
297 .read()
298 .values()
299 .filter(|info| {
300 info.state == TransactionState::Active
301 && now.duration_since(info.start_time) > self.timeout
302 })
303 .map(|info| info.id)
304 .collect()
305 }
306
307 #[allow(dead_code)]
309 fn check_serialization_conflicts(
310 &self,
311 txn: &TransactionInfo,
312 _all_txns: &HashMap<TransactionId, TransactionInfo>,
313 ) -> KernelResult<()> {
314 let _ = txn;
320 Ok(())
321 }
322
323 fn check_serialization_conflicts_cloned(
325 &self,
326 _read_set: &[(u32, u64)],
327 _write_set: &[(u32, u64)],
328 ) -> KernelResult<()> {
329 Ok(())
335 }
336
337 pub fn current_timestamp(&self) -> Timestamp {
339 self.current_ts.load(Ordering::SeqCst)
340 }
341
342 pub fn restore(&self, next_txn_id: TransactionId, current_ts: Timestamp) {
344 self.next_txn_id.store(next_txn_id, Ordering::SeqCst);
345 self.current_ts.store(current_ts, Ordering::SeqCst);
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{KernelError, TransactionErrorKind};
    use std::time::Duration;

    #[test]
    fn test_begin_commit() {
        let mgr = TxnManager::new();

        let txn1 = mgr.begin();
        assert!(mgr.is_active(txn1));

        let commit_ts = mgr.commit(txn1).unwrap();
        assert!(!mgr.is_active(txn1));
        assert!(commit_ts > 0);
    }

    #[test]
    fn test_begin_abort() {
        let mgr = TxnManager::new();

        let txn1 = mgr.begin();
        assert!(mgr.is_active(txn1));

        mgr.abort(txn1).unwrap();
        assert!(!mgr.is_active(txn1));
    }

    #[test]
    fn test_snapshot_isolation() {
        let mgr = TxnManager::new();

        let txn1 = mgr.begin();
        let ts1 = mgr.snapshot_ts(txn1).unwrap();
        let commit_ts = mgr.commit(txn1).unwrap();

        let txn2 = mgr.begin();
        let ts2 = mgr.snapshot_ts(txn2).unwrap();

        assert!(ts2 >= ts1);
        // A transaction started after a commit snapshots past that commit.
        assert!(ts2 > commit_ts);
    }

    #[test]
    fn test_double_commit_fails() {
        let mgr = TxnManager::new();
        let txn1 = mgr.begin();

        mgr.commit(txn1).unwrap();
        let err = mgr.commit(txn1).unwrap_err();
        assert!(matches!(
            err,
            KernelError::Transaction {
                kind: TransactionErrorKind::AlreadyCommitted,
                ..
            }
        ));
    }

    #[test]
    fn test_commit_after_abort_fails() {
        let mgr = TxnManager::new();
        let txn = mgr.begin();

        mgr.abort(txn).unwrap();
        let err = mgr.commit(txn).unwrap_err();
        assert!(matches!(
            err,
            KernelError::Transaction {
                kind: TransactionErrorKind::AlreadyAborted,
                ..
            }
        ));
    }

    #[test]
    fn test_abort_after_commit_fails() {
        let mgr = TxnManager::new();
        let txn = mgr.begin();

        mgr.commit(txn).unwrap();
        let err = mgr.abort(txn).unwrap_err();
        assert!(matches!(
            err,
            KernelError::Transaction {
                kind: TransactionErrorKind::AlreadyCommitted,
                ..
            }
        ));
    }

    #[test]
    fn test_abort_is_idempotent() {
        let mgr = TxnManager::new();
        let txn = mgr.begin();

        mgr.abort(txn).unwrap();
        mgr.abort(txn).unwrap();
        assert!(!mgr.is_active(txn));
    }

    #[test]
    fn test_unknown_transaction_is_not_found() {
        let mgr = TxnManager::new();

        assert!(!mgr.is_active(999));
        for err in [
            mgr.commit(999).unwrap_err(),
            mgr.abort(999).unwrap_err(),
            mgr.snapshot_ts(999).unwrap_err(),
        ] {
            assert!(matches!(
                err,
                KernelError::Transaction {
                    kind: TransactionErrorKind::NotFound(999),
                    ..
                }
            ));
        }
    }

    #[test]
    fn test_commit_timestamps_increase() {
        let mgr = TxnManager::new();
        let a = mgr.begin();
        let b = mgr.begin();

        let ts_a = mgr.commit(a).unwrap();
        let ts_b = mgr.commit(b).unwrap();
        assert!(ts_b > ts_a);
        assert!(mgr.current_timestamp() > ts_b);
    }

    #[test]
    fn test_min_active_snapshot() {
        let mgr = TxnManager::new();
        assert_eq!(mgr.min_active_snapshot(), None);

        let txn1 = mgr.begin();
        // Advance the clock so txn2 gets a strictly newer snapshot; otherwise
        // both snapshots coincide and the assertions below prove nothing.
        let filler = mgr.begin();
        mgr.commit(filler).unwrap();
        let txn2 = mgr.begin();

        let ts1 = mgr.snapshot_ts(txn1).unwrap();
        let ts2 = mgr.snapshot_ts(txn2).unwrap();
        assert!(ts1 < ts2);

        assert_eq!(mgr.min_active_snapshot(), Some(ts1));

        mgr.commit(txn1).unwrap();
        assert_eq!(mgr.min_active_snapshot(), Some(ts2));

        mgr.abort(txn2).unwrap();
        assert_eq!(mgr.min_active_snapshot(), None);
    }

    #[test]
    fn test_active_count() {
        let mgr = TxnManager::new();
        let a = mgr.begin();
        let _b = mgr.begin();
        assert_eq!(mgr.active_count(), 2);

        mgr.abort(a).unwrap();
        assert_eq!(mgr.active_count(), 1);
    }

    #[test]
    fn test_cleanup_keeps_active_transactions() {
        let mgr = TxnManager::new();
        let running = mgr.begin();
        let committed = mgr.begin();
        mgr.commit(committed).unwrap();
        let aborted = mgr.begin();
        mgr.abort(aborted).unwrap();

        mgr.cleanup(Duration::ZERO);

        assert!(mgr.is_active(running));
        assert!(mgr.snapshot_ts(committed).is_err());
        assert!(mgr.snapshot_ts(aborted).is_err());
    }

    #[test]
    fn test_check_timeouts() {
        let relaxed = TxnManager::with_timeout(Duration::from_secs(3600));
        relaxed.begin();
        assert!(relaxed.check_timeouts().is_empty());

        let strict = TxnManager::with_timeout(Duration::ZERO);
        let txn = strict.begin();
        std::thread::sleep(Duration::from_millis(2));
        assert_eq!(strict.check_timeouts(), vec![txn]);

        strict.commit(txn).unwrap();
        assert!(strict.check_timeouts().is_empty());
    }

    #[test]
    fn test_restore_sets_counters() {
        let mgr = TxnManager::new();
        mgr.restore(100, 50);

        let txn = mgr.begin();
        assert_eq!(txn, 100);
        assert_eq!(mgr.snapshot_ts(txn).unwrap(), 50);
        assert_eq!(mgr.current_timestamp(), 50);

        assert_eq!(mgr.commit(txn).unwrap(), 50);
        assert_eq!(mgr.current_timestamp(), 51);
    }
}