// unistore_sqlite/transaction.rs

//! Transaction support.
//!
//! Responsibility: provide RAII-style transaction management.
4
5use crate::connection::Connection;
6use crate::error::SqliteError;
7use crate::types::Param;
8
9/// 事务隔离级别
/// Transaction mode keyword placed in `BEGIN <mode> TRANSACTION`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Default mode (`DEFERRED`).
    #[default]
    Deferred,
    /// `IMMEDIATE`: acquire the write lock immediately.
    Immediate,
    /// `EXCLUSIVE`: take an exclusive lock.
    Exclusive,
}

impl IsolationLevel {
    /// Returns the SQL keyword for this mode (e.g. `"IMMEDIATE"`).
    pub fn as_sql(&self) -> &'static str {
        match *self {
            IsolationLevel::Exclusive => "EXCLUSIVE",
            IsolationLevel::Immediate => "IMMEDIATE",
            IsolationLevel::Deferred => "DEFERRED",
        }
    }
}
31
32/// 事务状态
/// Lifecycle state of a transaction handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
    /// Still open; statements may be executed.
    Active,
    /// Finished by a successful `COMMIT`.
    Committed,
    /// Finished by `ROLLBACK`.
    RolledBack,
}
42
43/// 事务包装器
44///
45/// 使用 RAII 模式,离开作用域时自动回滚未提交的事务
46pub struct Transaction<'a> {
47    conn: &'a Connection,
48    state: TransactionState,
49}
50
51impl<'a> Transaction<'a> {
52    /// 开始新事务
53    pub fn begin(conn: &'a Connection) -> Result<Self, SqliteError> {
54        Self::begin_with_isolation(conn, IsolationLevel::Deferred)
55    }
56
57    /// 开始事务(指定隔离级别)
58    pub fn begin_with_isolation(
59        conn: &'a Connection,
60        isolation: IsolationLevel,
61    ) -> Result<Self, SqliteError> {
62        conn.execute_batch(&format!("BEGIN {} TRANSACTION", isolation.as_sql()))
63            .map_err(|e| SqliteError::TransactionFailed(format!("BEGIN failed: {}", e)))?;
64
65        Ok(Self {
66            conn,
67            state: TransactionState::Active,
68        })
69    }
70
71    /// 开始立即事务(推荐用于写操作)
72    pub fn begin_immediate(conn: &'a Connection) -> Result<Self, SqliteError> {
73        Self::begin_with_isolation(conn, IsolationLevel::Immediate)
74    }
75
76    /// 获取事务状态
77    pub fn state(&self) -> TransactionState {
78        self.state
79    }
80
81    /// 判断事务是否活跃
82    pub fn is_active(&self) -> bool {
83        self.state == TransactionState::Active
84    }
85
86    /// 执行 SQL
87    pub fn execute(&self, sql: &str, params: &[Param]) -> Result<usize, SqliteError> {
88        if !self.is_active() {
89            return Err(SqliteError::TransactionFailed(
90                "Transaction is not active".to_string(),
91            ));
92        }
93        self.conn.execute(sql, params)
94    }
95
96    /// 执行批量 SQL
97    pub fn execute_batch(&self, sql: &str) -> Result<(), SqliteError> {
98        if !self.is_active() {
99            return Err(SqliteError::TransactionFailed(
100                "Transaction is not active".to_string(),
101            ));
102        }
103        self.conn.execute_batch(sql)
104    }
105
106    /// 提交事务
107    pub fn commit(mut self) -> Result<(), SqliteError> {
108        if !self.is_active() {
109            return Err(SqliteError::TransactionFailed(
110                "Transaction is not active".to_string(),
111            ));
112        }
113
114        self.conn
115            .execute_batch("COMMIT")
116            .map_err(|e| SqliteError::TransactionFailed(format!("COMMIT failed: {}", e)))?;
117
118        self.state = TransactionState::Committed;
119        Ok(())
120    }
121
122    /// 回滚事务
123    pub fn rollback(mut self) -> Result<(), SqliteError> {
124        if !self.is_active() {
125            return Ok(()); // 已经不活跃,无需回滚
126        }
127
128        self.conn
129            .execute_batch("ROLLBACK")
130            .map_err(|e| SqliteError::TransactionFailed(format!("ROLLBACK failed: {}", e)))?;
131
132        self.state = TransactionState::RolledBack;
133        Ok(())
134    }
135
136    /// 创建保存点
137    pub fn savepoint(&self, name: &str) -> Result<Savepoint<'_, 'a>, SqliteError> {
138        if !self.is_active() {
139            return Err(SqliteError::TransactionFailed(
140                "Transaction is not active".to_string(),
141            ));
142        }
143        Savepoint::new(self, name)
144    }
145}
146
147impl<'a> Drop for Transaction<'a> {
148    fn drop(&mut self) {
149        // 如果事务仍然活跃,自动回滚
150        if self.is_active() {
151            let _ = self.conn.execute_batch("ROLLBACK");
152            self.state = TransactionState::RolledBack;
153        }
154    }
155}
156
157/// 保存点
158pub struct Savepoint<'t, 'c> {
159    tx: &'t Transaction<'c>,
160    name: String,
161    released: bool,
162}
163
164impl<'t, 'c> Savepoint<'t, 'c> {
165    /// 创建保存点
166    fn new(tx: &'t Transaction<'c>, name: &str) -> Result<Self, SqliteError> {
167        tx.execute_batch(&format!("SAVEPOINT {}", name))?;
168        Ok(Self {
169            tx,
170            name: name.to_string(),
171            released: false,
172        })
173    }
174
175    /// 释放保存点(提交到父事务)
176    pub fn release(mut self) -> Result<(), SqliteError> {
177        self.tx.execute_batch(&format!("RELEASE SAVEPOINT {}", self.name))?;
178        self.released = true;
179        Ok(())
180    }
181
182    /// 回滚到保存点
183    pub fn rollback(mut self) -> Result<(), SqliteError> {
184        self.tx
185            .execute_batch(&format!("ROLLBACK TO SAVEPOINT {}", self.name))?;
186        self.released = true;
187        Ok(())
188    }
189}
190
191impl<'t, 'c> Drop for Savepoint<'t, 'c> {
192    fn drop(&mut self) {
193        // 如果没有明确释放或回滚,自动回滚
194        if !self.released {
195            let _ = self
196                .tx
197                .conn
198                .execute_batch(&format!("ROLLBACK TO SAVEPOINT {}", self.name));
199        }
200    }
201}
202
203/// 便捷的事务执行函数
204///
205/// 自动处理提交和回滚
206pub fn with_transaction<F, T>(conn: &Connection, f: F) -> Result<T, SqliteError>
207where
208    F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
209{
210    let tx = Transaction::begin(conn)?;
211    match f(&tx) {
212        Ok(result) => {
213            tx.commit()?;
214            Ok(result)
215        }
216        Err(e) => {
217            tx.rollback()?;
218            Err(e)
219        }
220    }
221}
222
223/// 便捷的立即事务执行函数
224pub fn with_immediate_transaction<F, T>(conn: &Connection, f: F) -> Result<T, SqliteError>
225where
226    F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
227{
228    let tx = Transaction::begin_immediate(conn)?;
229    match f(&tx) {
230        Ok(result) => {
231            tx.commit()?;
232            Ok(result)
233        }
234        Err(e) => {
235            tx.rollback()?;
236            Err(e)
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    const INSERT_SQL: &str = "INSERT INTO test (value) VALUES (?)";

    /// Opens an in-memory database containing an empty `test` table.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch("CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)")
            .unwrap();
        conn
    }

    /// Number of rows currently visible in the `test` table.
    fn row_count(conn: &Connection) -> usize {
        conn.query("SELECT * FROM test", &[]).unwrap().len()
    }

    /// Inserts a single row holding `value` through `tx`.
    fn insert_value(tx: &Transaction<'_>, value: i32) -> Result<usize, SqliteError> {
        tx.execute(INSERT_SQL, &[value.into()])
    }

    #[test]
    fn test_commit() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        insert_value(&tx, 1).unwrap();
        tx.commit().unwrap();

        assert_eq!(row_count(&conn), 1);
    }

    #[test]
    fn test_rollback() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        insert_value(&tx, 1).unwrap();
        tx.rollback().unwrap();

        assert_eq!(row_count(&conn), 0);
    }

    #[test]
    fn test_auto_rollback_on_drop() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        insert_value(&tx, 1).unwrap();
        // Neither commit nor rollback: dropping the guard must roll back.
        drop(tx);

        assert_eq!(row_count(&conn), 0);
    }

    #[test]
    fn test_with_transaction() {
        let conn = setup_test_db();

        // Success path: the closure's value is returned and the insert is committed.
        let answer = with_transaction(&conn, |tx| {
            insert_value(tx, 1)?;
            Ok(42)
        });
        assert_eq!(answer.unwrap(), 42);
        assert_eq!(row_count(&conn), 1);

        // Failure path: the error propagates and the second insert is rolled back.
        let failed: Result<i32, SqliteError> = with_transaction(&conn, |tx| {
            insert_value(tx, 2)?;
            Err(SqliteError::Internal("test error".to_string()))
        });
        assert!(failed.is_err());
        assert_eq!(row_count(&conn), 1);
    }

    #[test]
    fn test_savepoint() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        insert_value(&tx, 1).unwrap();

        let sp = tx.savepoint("sp1").unwrap();
        insert_value(&tx, 2).unwrap();
        sp.rollback().unwrap(); // undo only the second insert

        tx.commit().unwrap();

        // Only the insert made before the savepoint survives.
        assert_eq!(row_count(&conn), 1);
    }
}