Skip to main content

sochdb_kernel/
transaction.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Transaction Management
19//!
20//! Core transaction manager with MVCC support.
21//! This is the minimal ACID transaction implementation for the kernel.
22
23use crate::error::{KernelError, KernelResult, TransactionErrorKind};
24use crate::wal::LogSequenceNumber;
25use parking_lot::{Mutex, RwLock};
26use std::collections::HashMap;
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::time::{Duration, Instant};
29
/// Identifier handed out to each transaction by the manager's ID counter.
pub type TransactionId = u64;

/// Logical-clock value used for MVCC snapshots and commit timestamps.
pub type Timestamp = u64;
35
/// Isolation level requested for a transaction.
///
/// The default is [`IsolationLevel::SnapshotIsolation`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IsolationLevel {
    /// Read uncommitted: uncommitted changes are visible (rarely used).
    ReadUncommitted,
    /// Read committed: only committed changes are visible.
    ReadCommitted,
    /// Repeatable read: snapshot taken at the first read.
    RepeatableRead,
    /// Snapshot isolation: snapshot taken at transaction start.
    #[default]
    SnapshotIsolation,
    /// Serializable: intended to provide full serializability via SSI.
    Serializable,
}
51
/// Lifecycle state of a transaction.
///
/// `Active` and `Preparing` are in-flight states; `Committed` and `Aborted`
/// are terminal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransactionState {
    /// The transaction is running.
    Active,
    /// The transaction is preparing to commit.
    Preparing,
    /// The transaction has committed.
    Committed,
    /// The transaction has been aborted.
    Aborted,
}
64
65/// Transaction metadata
66#[derive(Debug)]
67struct TransactionInfo {
68    /// Transaction ID
69    id: TransactionId,
70    /// Transaction state
71    state: TransactionState,
72    /// Snapshot timestamp (for MVCC visibility)
73    snapshot_ts: Timestamp,
74    /// Commit timestamp (set on commit)
75    commit_ts: Option<Timestamp>,
76    /// Isolation level
77    isolation: IsolationLevel,
78    /// Start time (for timeout detection)
79    start_time: Instant,
80    /// Last LSN written by this transaction
81    last_lsn: Option<LogSequenceNumber>,
82    /// Read set (for SSI conflict detection)
83    read_set: Vec<(u32, u64)>, // (table_id, row_id)
84    /// Write set (for conflict detection)
85    write_set: Vec<(u32, u64)>, // (table_id, row_id)
86}
87
88/// Transaction manager
89///
90/// Manages transaction lifecycle and MVCC timestamps.
91pub struct TxnManager {
92    /// Next transaction ID
93    next_txn_id: AtomicU64,
94    /// Current timestamp (logical clock)
95    current_ts: AtomicU64,
96    /// Active transactions
97    active_txns: RwLock<HashMap<TransactionId, TransactionInfo>>,
98    /// Transaction timeout
99    timeout: Duration,
100    /// Lock for commit ordering
101    commit_lock: Mutex<()>,
102}
103
104impl Default for TxnManager {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110impl TxnManager {
111    /// Create a new transaction manager
112    pub fn new() -> Self {
113        Self::with_timeout(Duration::from_secs(60))
114    }
115
116    /// Create with custom timeout
117    pub fn with_timeout(timeout: Duration) -> Self {
118        Self {
119            next_txn_id: AtomicU64::new(1),
120            current_ts: AtomicU64::new(1),
121            active_txns: RwLock::new(HashMap::new()),
122            timeout,
123            commit_lock: Mutex::new(()),
124        }
125    }
126
127    /// Begin a new transaction with default isolation
128    pub fn begin(&self) -> TransactionId {
129        self.begin_with_isolation(IsolationLevel::default())
130    }
131
132    /// Begin a new transaction with specific isolation level
133    pub fn begin_with_isolation(&self, isolation: IsolationLevel) -> TransactionId {
134        let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
135        let snapshot_ts = self.current_ts.load(Ordering::SeqCst);
136
137        let info = TransactionInfo {
138            id: txn_id,
139            state: TransactionState::Active,
140            snapshot_ts,
141            commit_ts: None,
142            isolation,
143            start_time: Instant::now(),
144            last_lsn: None,
145            read_set: Vec::new(),
146            write_set: Vec::new(),
147        };
148
149        self.active_txns.write().insert(txn_id, info);
150        txn_id
151    }
152
153    /// Commit a transaction
154    pub fn commit(&self, txn_id: TransactionId) -> KernelResult<Timestamp> {
155        // Acquire commit lock for ordering
156        let _guard = self.commit_lock.lock();
157
158        let mut txns = self.active_txns.write();
159
160        // First check state and get necessary info
161        let (current_state, isolation, read_set, write_set) = {
162            let info = txns.get(&txn_id).ok_or(KernelError::Transaction {
163                kind: TransactionErrorKind::NotFound(txn_id),
164            })?;
165            (
166                info.state,
167                info.isolation,
168                info.read_set.clone(),
169                info.write_set.clone(),
170            )
171        };
172
173        match current_state {
174            TransactionState::Active | TransactionState::Preparing => {
175                // Check for SSI conflicts if serializable (using cloned data)
176                if isolation == IsolationLevel::Serializable {
177                    self.check_serialization_conflicts_cloned(&read_set, &write_set)?;
178                }
179
180                // Now get mutable reference and update
181                let info = txns.get_mut(&txn_id).unwrap();
182
183                // Allocate commit timestamp
184                let commit_ts = self.current_ts.fetch_add(1, Ordering::SeqCst);
185                info.commit_ts = Some(commit_ts);
186                info.state = TransactionState::Committed;
187
188                Ok(commit_ts)
189            }
190            TransactionState::Committed => Err(KernelError::Transaction {
191                kind: TransactionErrorKind::AlreadyCommitted,
192            }),
193            TransactionState::Aborted => Err(KernelError::Transaction {
194                kind: TransactionErrorKind::AlreadyAborted,
195            }),
196        }
197    }
198
199    /// Abort a transaction
200    pub fn abort(&self, txn_id: TransactionId) -> KernelResult<()> {
201        let mut txns = self.active_txns.write();
202        let info = txns.get_mut(&txn_id).ok_or(KernelError::Transaction {
203            kind: TransactionErrorKind::NotFound(txn_id),
204        })?;
205
206        match info.state {
207            TransactionState::Active | TransactionState::Preparing => {
208                info.state = TransactionState::Aborted;
209                Ok(())
210            }
211            TransactionState::Committed => Err(KernelError::Transaction {
212                kind: TransactionErrorKind::AlreadyCommitted,
213            }),
214            TransactionState::Aborted => Ok(()), // Idempotent
215        }
216    }
217
218    /// Check if a transaction is active
219    pub fn is_active(&self, txn_id: TransactionId) -> bool {
220        self.active_txns
221            .read()
222            .get(&txn_id)
223            .map(|info| info.state == TransactionState::Active)
224            .unwrap_or(false)
225    }
226
227    /// Get snapshot timestamp for a transaction
228    pub fn snapshot_ts(&self, txn_id: TransactionId) -> KernelResult<Timestamp> {
229        self.active_txns
230            .read()
231            .get(&txn_id)
232            .map(|info| info.snapshot_ts)
233            .ok_or(KernelError::Transaction {
234                kind: TransactionErrorKind::NotFound(txn_id),
235            })
236    }
237
238    /// Record a read operation (for SSI)
239    pub fn record_read(&self, txn_id: TransactionId, table_id: u32, row_id: u64) {
240        if let Some(info) = self.active_txns.write().get_mut(&txn_id)
241            && info.isolation == IsolationLevel::Serializable
242        {
243            info.read_set.push((table_id, row_id));
244        }
245    }
246
247    /// Record a write operation
248    pub fn record_write(&self, txn_id: TransactionId, table_id: u32, row_id: u64) {
249        if let Some(info) = self.active_txns.write().get_mut(&txn_id) {
250            info.write_set.push((table_id, row_id));
251        }
252    }
253
254    /// Update last LSN for a transaction
255    pub fn set_last_lsn(&self, txn_id: TransactionId, lsn: LogSequenceNumber) {
256        if let Some(info) = self.active_txns.write().get_mut(&txn_id) {
257            info.last_lsn = Some(lsn);
258        }
259    }
260
261    /// Get minimum active snapshot timestamp (for GC)
262    pub fn min_active_snapshot(&self) -> Option<Timestamp> {
263        self.active_txns
264            .read()
265            .values()
266            .filter(|info| info.state == TransactionState::Active)
267            .map(|info| info.snapshot_ts)
268            .min()
269    }
270
271    /// Get active transaction count
272    pub fn active_count(&self) -> usize {
273        self.active_txns
274            .read()
275            .values()
276            .filter(|info| info.state == TransactionState::Active)
277            .count()
278    }
279
280    /// Clean up completed transactions older than retention period
281    pub fn cleanup(&self, retention: Duration) {
282        let now = Instant::now();
283        self.active_txns.write().retain(|_, info| {
284            // Keep active transactions
285            if info.state == TransactionState::Active {
286                return true;
287            }
288            // Keep recently completed transactions
289            now.duration_since(info.start_time) < retention
290        });
291    }
292
293    /// Check for transactions that have timed out
294    pub fn check_timeouts(&self) -> Vec<TransactionId> {
295        let now = Instant::now();
296        self.active_txns
297            .read()
298            .values()
299            .filter(|info| {
300                info.state == TransactionState::Active
301                    && now.duration_since(info.start_time) > self.timeout
302            })
303            .map(|info| info.id)
304            .collect()
305    }
306
307    /// Check serialization conflicts for SSI
308    #[allow(dead_code)]
309    fn check_serialization_conflicts(
310        &self,
311        txn: &TransactionInfo,
312        _all_txns: &HashMap<TransactionId, TransactionInfo>,
313    ) -> KernelResult<()> {
314        // Simplified SSI check - in production this would track rw-dependencies
315        // and detect dangerous structures (two consecutive rw-antidependencies)
316        //
317        // For now, we just check for write-write conflicts
318        // Full SSI implementation is in sochdb-storage/src/ssi.rs
319        let _ = txn;
320        Ok(())
321    }
322
323    /// Check serialization conflicts for SSI (using cloned data to avoid borrow issues)
324    fn check_serialization_conflicts_cloned(
325        &self,
326        _read_set: &[(u32, u64)],
327        _write_set: &[(u32, u64)],
328    ) -> KernelResult<()> {
329        // Simplified SSI check - in production this would track rw-dependencies
330        // and detect dangerous structures (two consecutive rw-antidependencies)
331        //
332        // For now, we just check for write-write conflicts
333        // Full SSI implementation is in sochdb-storage/src/ssi.rs
334        Ok(())
335    }
336
337    /// Get current timestamp
338    pub fn current_timestamp(&self) -> Timestamp {
339        self.current_ts.load(Ordering::SeqCst)
340    }
341
342    /// Restore state after recovery
343    pub fn restore(&self, next_txn_id: TransactionId, current_ts: Timestamp) {
344        self.next_txn_id.store(next_txn_id, Ordering::SeqCst);
345        self.current_ts.store(current_ts, Ordering::SeqCst);
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Committing makes the transaction inactive and yields a timestamp.
    #[test]
    fn test_begin_commit() {
        let mgr = TxnManager::new();

        let txn1 = mgr.begin();
        assert!(mgr.is_active(txn1));

        let commit_ts = mgr.commit(txn1).unwrap();
        assert!(!mgr.is_active(txn1));
        assert!(commit_ts > 0);
    }

    /// Aborting makes the transaction inactive.
    #[test]
    fn test_begin_abort() {
        let mgr = TxnManager::new();

        let txn1 = mgr.begin();
        assert!(mgr.is_active(txn1));

        mgr.abort(txn1).unwrap();
        assert!(!mgr.is_active(txn1));
    }

    /// A transaction begun after a commit gets a strictly later snapshot.
    #[test]
    fn test_snapshot_isolation() {
        let mgr = TxnManager::new();

        let txn1 = mgr.begin();
        let ts1 = mgr.snapshot_ts(txn1).unwrap();

        // Commit txn1 to advance the logical clock
        mgr.commit(txn1).unwrap();

        let txn2 = mgr.begin();
        let ts2 = mgr.snapshot_ts(txn2).unwrap();

        // The commit bumped the clock, so txn2's snapshot must be later.
        assert!(ts2 > ts1);
    }

    /// A second commit is rejected specifically as `AlreadyCommitted`.
    #[test]
    fn test_double_commit_fails() {
        let mgr = TxnManager::new();
        let txn1 = mgr.begin();

        mgr.commit(txn1).unwrap();
        assert!(matches!(
            mgr.commit(txn1),
            Err(KernelError::Transaction {
                kind: TransactionErrorKind::AlreadyCommitted,
            })
        ));
    }

    /// The minimum snapshot tracks the oldest transaction still running.
    #[test]
    fn test_min_active_snapshot() {
        let mgr = TxnManager::new();

        let txn1 = mgr.begin();

        // Commit an unrelated transaction so the clock advances and txn2
        // gets a snapshot distinct from txn1's; otherwise the assertions
        // below could not tell the two apart.
        let bump = mgr.begin();
        mgr.commit(bump).unwrap();

        let txn2 = mgr.begin();

        let ts1 = mgr.snapshot_ts(txn1).unwrap();
        let ts2 = mgr.snapshot_ts(txn2).unwrap();
        assert!(ts1 < ts2);

        assert_eq!(mgr.min_active_snapshot(), Some(ts1));

        mgr.commit(txn1).unwrap();
        assert_eq!(mgr.min_active_snapshot(), Some(ts2));

        mgr.commit(txn2).unwrap();
        assert_eq!(mgr.min_active_snapshot(), None);
    }
}
416}