Skip to main content

sql_rs/storage/transaction.rs

1use crate::{SqlRsError, Result};
2use std::collections::HashMap;
3
/// One write recorded inside a [`Transaction`].
///
/// A transaction keeps these in insertion order; they can be read back via
/// [`Transaction::operations`].
#[derive(Clone, Debug)]
pub enum TransactionOp {
    /// A write of `value` under `key`.
    Put { key: Vec<u8>, value: Vec<u8> },
    /// A removal of `key`.
    Delete { key: Vec<u8> },
}
9
/// Lifecycle state of a [`Transaction`].
///
/// A transaction starts out `Active`. A successful commit or rollback moves it
/// to `Committed` or `Aborted` respectively, and neither of those can be left
/// again (both operations require the transaction to be `Active`).
///
/// The enum carries no data, so it is `Copy` and `Eq`/`Hash`, which allows
/// cheap by-value comparisons and use as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionState {
    /// Still accepting operations; may be committed or rolled back.
    Active,
    /// Finished through a successful commit.
    Committed,
    /// Finished through a rollback.
    Aborted,
}
16
17pub struct Transaction {
18    id: u64,
19    state: TransactionState,
20    operations: Vec<TransactionOp>,
21    undo_log: HashMap<Vec<u8>, Option<Vec<u8>>>,
22}
23
24impl Transaction {
25    pub fn new(id: u64) -> Self {
26        Self {
27            id,
28            state: TransactionState::Active,
29            operations: Vec::new(),
30            undo_log: HashMap::new(),
31        }
32    }
33
34    pub fn id(&self) -> u64 {
35        self.id
36    }
37
38    pub fn state(&self) -> &TransactionState {
39        &self.state
40    }
41
42    pub fn is_active(&self) -> bool {
43        self.state == TransactionState::Active
44    }
45
46    pub fn add_operation(&mut self, op: TransactionOp, old_value: Option<Vec<u8>>) -> Result<()> {
47        if !self.is_active() {
48            return Err(SqlRsError::InvalidOperation(
49                "Transaction is not active".to_string(),
50            ));
51        }
52
53        match &op {
54            TransactionOp::Put { key, .. } | TransactionOp::Delete { key } => {
55                if !self.undo_log.contains_key(key) {
56                    self.undo_log.insert(key.clone(), old_value);
57                }
58            }
59        }
60
61        self.operations.push(op);
62        Ok(())
63    }
64
65    pub fn commit(&mut self) -> Result<()> {
66        if !self.is_active() {
67            return Err(SqlRsError::InvalidOperation(
68                "Transaction is not active".to_string(),
69            ));
70        }
71
72        self.state = TransactionState::Committed;
73        self.undo_log.clear();
74        Ok(())
75    }
76
77    pub fn rollback(&mut self) -> Result<Vec<(Vec<u8>, Option<Vec<u8>>)>> {
78        if !self.is_active() {
79            return Err(SqlRsError::InvalidOperation(
80                "Transaction is not active".to_string(),
81            ));
82        }
83
84        self.state = TransactionState::Aborted;
85
86        let undo_ops: Vec<(Vec<u8>, Option<Vec<u8>>)> = self
87            .undo_log
88            .iter()
89            .map(|(k, v)| (k.clone(), v.clone()))
90            .collect();
91
92        self.undo_log.clear();
93        self.operations.clear();
94
95        Ok(undo_ops)
96    }
97
98    pub fn operations(&self) -> &[TransactionOp] {
99        &self.operations
100    }
101}
102
103pub struct TransactionManager {
104    next_id: u64,
105    active_transactions: HashMap<u64, Transaction>,
106}
107
108impl TransactionManager {
109    pub fn new() -> Self {
110        Self {
111            next_id: 1,
112            active_transactions: HashMap::new(),
113        }
114    }
115
116    pub fn begin(&mut self) -> u64 {
117        let id = self.next_id;
118        self.next_id += 1;
119
120        let tx = Transaction::new(id);
121        self.active_transactions.insert(id, tx);
122
123        id
124    }
125
126    pub fn get(&self, id: u64) -> Option<&Transaction> {
127        self.active_transactions.get(&id)
128    }
129
130    pub fn get_mut(&mut self, id: u64) -> Option<&mut Transaction> {
131        self.active_transactions.get_mut(&id)
132    }
133
134    pub fn commit(&mut self, id: u64) -> Result<()> {
135        let tx = self
136            .active_transactions
137            .get_mut(&id)
138            .ok_or_else(|| SqlRsError::NotFound(format!("Transaction {} not found", id)))?;
139
140        tx.commit()?;
141        self.active_transactions.remove(&id);
142
143        Ok(())
144    }
145
146    pub fn rollback(&mut self, id: u64) -> Result<Vec<(Vec<u8>, Option<Vec<u8>>)>> {
147        let tx = self
148            .active_transactions
149            .get_mut(&id)
150            .ok_or_else(|| SqlRsError::NotFound(format!("Transaction {} not found", id)))?;
151
152        let undo_ops = tx.rollback()?;
153        self.active_transactions.remove(&id);
154
155        Ok(undo_ops)
156    }
157
158    pub fn active_count(&self) -> usize {
159        self.active_transactions.len()
160    }
161}
162
163impl Default for TransactionManager {
164    fn default() -> Self {
165        Self::new()
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Put` operation from byte-string literals.
    fn put(key: &[u8], value: &[u8]) -> TransactionOp {
        TransactionOp::Put {
            key: key.to_vec(),
            value: value.to_vec(),
        }
    }

    #[test]
    fn test_transaction_lifecycle() {
        let mut tx = Transaction::new(1);
        assert!(tx.is_active());

        tx.add_operation(put(b"key1", b"value1"), None).unwrap();

        tx.commit().unwrap();
        assert_eq!(tx.state(), &TransactionState::Committed);
        // Committing keeps the operation history.
        assert_eq!(tx.operations().len(), 1);
    }

    #[test]
    fn test_transaction_rollback() {
        let mut tx = Transaction::new(1);

        tx.add_operation(put(b"key1", b"value1"), Some(b"old_value".to_vec()))
            .unwrap();

        let undo_ops = tx.rollback().unwrap();
        assert_eq!(undo_ops.len(), 1);
        assert_eq!(tx.state(), &TransactionState::Aborted);
    }

    #[test]
    fn test_undo_log_keeps_first_old_value() {
        let mut tx = Transaction::new(1);

        tx.add_operation(put(b"key1", b"v1"), Some(b"orig".to_vec()))
            .unwrap();
        tx.add_operation(put(b"key1", b"v2"), Some(b"v1".to_vec()))
            .unwrap();
        tx.add_operation(
            TransactionOp::Delete {
                key: b"key2".to_vec(),
            },
            Some(b"old2".to_vec()),
        )
        .unwrap();
        tx.add_operation(put(b"key3", b"new"), None).unwrap();
        assert_eq!(tx.operations().len(), 4);

        // Undo order is unspecified; sort for a stable comparison.
        let mut undo_ops = tx.rollback().unwrap();
        undo_ops.sort();
        assert_eq!(
            undo_ops,
            vec![
                (b"key1".to_vec(), Some(b"orig".to_vec())),
                (b"key2".to_vec(), Some(b"old2".to_vec())),
                (b"key3".to_vec(), None),
            ]
        );
        assert!(tx.operations().is_empty());
    }

    #[test]
    fn test_finished_transaction_rejects_changes() {
        let mut committed = Transaction::new(7);
        committed.commit().unwrap();
        assert!(committed.add_operation(put(b"k", b"v"), None).is_err());
        assert!(committed.commit().is_err());
        assert!(committed.rollback().is_err());
        assert_eq!(committed.state(), &TransactionState::Committed);

        let mut aborted = Transaction::new(8);
        aborted.rollback().unwrap();
        assert!(aborted.add_operation(put(b"k", b"v"), None).is_err());
        assert!(aborted.commit().is_err());
        assert_eq!(aborted.state(), &TransactionState::Aborted);
    }

    #[test]
    fn test_transaction_manager() {
        let mut mgr = TransactionManager::new();

        let tx_id = mgr.begin();
        assert_eq!(mgr.active_count(), 1);

        mgr.commit(tx_id).unwrap();
        assert_eq!(mgr.active_count(), 0);
    }

    #[test]
    fn test_multiple_transactions() {
        let mut mgr = TransactionManager::new();

        let tx1 = mgr.begin();
        let tx2 = mgr.begin();

        assert_eq!(mgr.active_count(), 2);

        mgr.commit(tx1).unwrap();
        assert_eq!(mgr.active_count(), 1);

        mgr.rollback(tx2).unwrap();
        assert_eq!(mgr.active_count(), 0);
    }

    #[test]
    fn test_manager_unknown_id() {
        let mut mgr = TransactionManager::new();
        assert!(mgr.get(42).is_none());
        assert!(mgr.commit(42).is_err());
        assert!(mgr.rollback(42).is_err());
    }

    #[test]
    fn test_manager_ids_and_undo() {
        let mut mgr = TransactionManager::new();

        let a = mgr.begin();
        let b = mgr.begin();
        assert_eq!((a, b), (1, 2));
        assert_eq!(mgr.get(a).map(Transaction::id), Some(a));

        mgr.commit(a).unwrap();
        assert!(mgr.get(a).is_none());
        assert!(mgr.commit(a).is_err());

        mgr.get_mut(b)
            .unwrap()
            .add_operation(put(b"k", b"v"), Some(b"old".to_vec()))
            .unwrap();
        let undo = mgr.rollback(b).unwrap();
        assert_eq!(undo, vec![(b"k".to_vec(), Some(b"old".to_vec()))]);
        assert_eq!(mgr.active_count(), 0);
    }
}