1use crate::connection::Connection;
6use crate::error::SqliteError;
7use crate::types::Param;
8
/// Locking mode requested when a transaction is opened with
/// `BEGIN <mode> TRANSACTION`.
///
/// Each variant maps to the keyword returned by [`IsolationLevel::as_sql`].
/// See the SQLite `BEGIN TRANSACTION` documentation for what each mode locks.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// `BEGIN DEFERRED TRANSACTION`; this is the `Default` variant.
    #[default]
    Deferred,
    /// `BEGIN IMMEDIATE TRANSACTION`.
    Immediate,
    /// `BEGIN EXCLUSIVE TRANSACTION`.
    Exclusive,
}
20
21impl IsolationLevel {
22 pub fn as_sql(&self) -> &'static str {
24 match self {
25 Self::Deferred => "DEFERRED",
26 Self::Immediate => "IMMEDIATE",
27 Self::Exclusive => "EXCLUSIVE",
28 }
29 }
30}
31
/// Lifecycle state of a [`Transaction`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum TransactionState {
    /// `BEGIN` succeeded and the transaction has not been finished yet.
    Active,
    /// Finished by a successful `COMMIT`.
    Committed,
    /// Finished by `ROLLBACK`, either explicitly or because the transaction
    /// was dropped while still active.
    RolledBack,
}
42
43pub struct Transaction<'a> {
47 conn: &'a Connection,
48 state: TransactionState,
49}
50
51impl<'a> Transaction<'a> {
52 pub fn begin(conn: &'a Connection) -> Result<Self, SqliteError> {
54 Self::begin_with_isolation(conn, IsolationLevel::Deferred)
55 }
56
57 pub fn begin_with_isolation(
59 conn: &'a Connection,
60 isolation: IsolationLevel,
61 ) -> Result<Self, SqliteError> {
62 conn.execute_batch(&format!("BEGIN {} TRANSACTION", isolation.as_sql()))
63 .map_err(|e| SqliteError::TransactionFailed(format!("BEGIN failed: {}", e)))?;
64
65 Ok(Self {
66 conn,
67 state: TransactionState::Active,
68 })
69 }
70
71 pub fn begin_immediate(conn: &'a Connection) -> Result<Self, SqliteError> {
73 Self::begin_with_isolation(conn, IsolationLevel::Immediate)
74 }
75
76 pub fn state(&self) -> TransactionState {
78 self.state
79 }
80
81 pub fn is_active(&self) -> bool {
83 self.state == TransactionState::Active
84 }
85
86 pub fn execute(&self, sql: &str, params: &[Param]) -> Result<usize, SqliteError> {
88 if !self.is_active() {
89 return Err(SqliteError::TransactionFailed(
90 "Transaction is not active".to_string(),
91 ));
92 }
93 self.conn.execute(sql, params)
94 }
95
96 pub fn execute_batch(&self, sql: &str) -> Result<(), SqliteError> {
98 if !self.is_active() {
99 return Err(SqliteError::TransactionFailed(
100 "Transaction is not active".to_string(),
101 ));
102 }
103 self.conn.execute_batch(sql)
104 }
105
106 pub fn commit(mut self) -> Result<(), SqliteError> {
108 if !self.is_active() {
109 return Err(SqliteError::TransactionFailed(
110 "Transaction is not active".to_string(),
111 ));
112 }
113
114 self.conn
115 .execute_batch("COMMIT")
116 .map_err(|e| SqliteError::TransactionFailed(format!("COMMIT failed: {}", e)))?;
117
118 self.state = TransactionState::Committed;
119 Ok(())
120 }
121
122 pub fn rollback(mut self) -> Result<(), SqliteError> {
124 if !self.is_active() {
125 return Ok(()); }
127
128 self.conn
129 .execute_batch("ROLLBACK")
130 .map_err(|e| SqliteError::TransactionFailed(format!("ROLLBACK failed: {}", e)))?;
131
132 self.state = TransactionState::RolledBack;
133 Ok(())
134 }
135
136 pub fn savepoint(&self, name: &str) -> Result<Savepoint<'_, 'a>, SqliteError> {
138 if !self.is_active() {
139 return Err(SqliteError::TransactionFailed(
140 "Transaction is not active".to_string(),
141 ));
142 }
143 Savepoint::new(self, name)
144 }
145}
146
147impl<'a> Drop for Transaction<'a> {
148 fn drop(&mut self) {
149 if self.is_active() {
151 let _ = self.conn.execute_batch("ROLLBACK");
152 self.state = TransactionState::RolledBack;
153 }
154 }
155}
156
157pub struct Savepoint<'t, 'c> {
159 tx: &'t Transaction<'c>,
160 name: String,
161 released: bool,
162}
163
164impl<'t, 'c> Savepoint<'t, 'c> {
165 fn new(tx: &'t Transaction<'c>, name: &str) -> Result<Self, SqliteError> {
167 tx.execute_batch(&format!("SAVEPOINT {}", name))?;
168 Ok(Self {
169 tx,
170 name: name.to_string(),
171 released: false,
172 })
173 }
174
175 pub fn release(mut self) -> Result<(), SqliteError> {
177 self.tx.execute_batch(&format!("RELEASE SAVEPOINT {}", self.name))?;
178 self.released = true;
179 Ok(())
180 }
181
182 pub fn rollback(mut self) -> Result<(), SqliteError> {
184 self.tx
185 .execute_batch(&format!("ROLLBACK TO SAVEPOINT {}", self.name))?;
186 self.released = true;
187 Ok(())
188 }
189}
190
191impl<'t, 'c> Drop for Savepoint<'t, 'c> {
192 fn drop(&mut self) {
193 if !self.released {
195 let _ = self
196 .tx
197 .conn
198 .execute_batch(&format!("ROLLBACK TO SAVEPOINT {}", self.name));
199 }
200 }
201}
202
203pub fn with_transaction<F, T>(conn: &Connection, f: F) -> Result<T, SqliteError>
207where
208 F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
209{
210 let tx = Transaction::begin(conn)?;
211 match f(&tx) {
212 Ok(result) => {
213 tx.commit()?;
214 Ok(result)
215 }
216 Err(e) => {
217 tx.rollback()?;
218 Err(e)
219 }
220 }
221}
222
223pub fn with_immediate_transaction<F, T>(conn: &Connection, f: F) -> Result<T, SqliteError>
225where
226 F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
227{
228 let tx = Transaction::begin_immediate(conn)?;
229 match f(&tx) {
230 Ok(result) => {
231 tx.commit()?;
232 Ok(result)
233 }
234 Err(e) => {
235 tx.rollback()?;
236 Err(e)
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameterised insert used by every test in this module.
    const INSERT_SQL: &str = "INSERT INTO test (value) VALUES (?)";

    /// Opens an in-memory database containing an empty `test` table.
    fn setup_test_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch("CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)")
            .unwrap();
        db
    }

    #[test]
    fn test_commit() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        tx.execute(INSERT_SQL, &[1i32.into()]).unwrap();
        tx.commit().unwrap();

        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_rollback() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        tx.execute(INSERT_SQL, &[1i32.into()]).unwrap();
        tx.rollback().unwrap();

        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 0);
    }

    #[test]
    fn test_auto_rollback_on_drop() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        tx.execute(INSERT_SQL, &[1i32.into()]).unwrap();
        // Neither commit nor rollback: dropping must undo the insert.
        drop(tx);

        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 0);
    }

    #[test]
    fn test_with_transaction() {
        let conn = setup_test_db();

        // Successful closure: its value is returned and the insert is committed.
        let committed = with_transaction(&conn, |tx| {
            tx.execute(INSERT_SQL, &[1i32.into()])?;
            Ok(42)
        });
        assert_eq!(committed.unwrap(), 42);
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);

        // Failing closure: the error surfaces and the second insert is discarded.
        let failed: Result<i32, SqliteError> = with_transaction(&conn, |tx| {
            tx.execute(INSERT_SQL, &[2i32.into()])?;
            Err(SqliteError::Internal("test error".to_string()))
        });
        assert!(failed.is_err());
        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_savepoint() {
        let conn = setup_test_db();

        let tx = Transaction::begin(&conn).unwrap();
        tx.execute(INSERT_SQL, &[1i32.into()]).unwrap();

        // Work done after the savepoint is undone by rolling back to it.
        let sp = tx.savepoint("sp1").unwrap();
        tx.execute(INSERT_SQL, &[2i32.into()]).unwrap();
        sp.rollback().unwrap();

        tx.commit().unwrap();

        let rows = conn.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(rows.len(), 1);
    }
}