vapor_cli/
transactions.rs

//! # Explicit Transaction Management
//!
//! This module provides a stateful manager for handling database transactions explicitly.
//! It is designed to be used in interactive contexts like a REPL or shell, where users
//! can manually begin, commit, or roll back transactions.
//!
//! ## Core Components:
//! - `TransactionManager`: A thread-safe struct that tracks the current transaction state.
//! - `TransactionState`: An enum representing whether a transaction is `Active` or `None`.
//!
//! The manager ensures that users cannot start a new transaction while one is already
//! active and provides clear feedback about the transaction status. It also intercepts
//! transaction-related SQL keywords (`BEGIN`, `COMMIT`, `ROLLBACK`) to manage state correctly,
//! and checks that a table exists before executing a `DROP` command against it.

use anyhow::{Context, Result};
use rusqlite::Connection;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

/// Represents the current state of a database transaction.
///
/// This is a fieldless `Copy` enum, so it also derives `Eq` and `Hash`. That makes it usable
/// as a map/set key and in `assert_eq!` comparisons without extra boilerplate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionState {
    /// No transaction is currently active.
    None,
    /// A transaction is active and awaiting a `COMMIT` or `ROLLBACK`.
    Active,
}

28/// Manages the state of database transactions in a thread-safe manner.
29///
30/// This struct wraps the `TransactionState` in an `Arc<Mutex<>>` to allow it to be
31/// shared across different parts of the application, such as between the REPL and
32/// other command handlers, while preventing race conditions.
33pub struct TransactionManager {
34    state: Arc<Mutex<TransactionState>>,
35}
36
37impl TransactionManager {
38    /// Creates a new `TransactionManager` with an initial state of `None`.
39    pub fn new() -> Self {
40        Self {
41            state: Arc::new(Mutex::new(TransactionState::None)),
42        }
43    }
44
45    /// Begins a new database transaction.
46    ///
47    /// If a transaction is already active, it prints a warning and does nothing.
48    /// Otherwise, it executes a `BEGIN` statement and sets the state to `Active`.
49    ///
50    /// # Arguments
51    /// * `conn` - A reference to the `rusqlite::Connection`.
52    pub fn begin_transaction(&self, conn: &Connection) -> Result<()> {
53        let mut state = self.state.lock().unwrap();
54
55        match *state {
56            TransactionState::Active => {
57                println!("Warning: Transaction already active. Use COMMIT or ROLLBACK first.");
58                return Ok(());
59            }
60            TransactionState::None => {
61                conn.execute("BEGIN", [])?;
62                *state = TransactionState::Active;
63                println!("Transaction started.");
64            }
65        }
66
67        Ok(())
68    }
69
70    /// Commits the active database transaction.
71    ///
72    /// If no transaction is active, it prints a message and does nothing.
73    /// Otherwise, it executes a `COMMIT` statement and resets the state to `None`.
74    ///
75    /// # Arguments
76    /// * `conn` - A reference to the `rusqlite::Connection`.
77    pub fn commit_transaction(&self, conn: &Connection) -> Result<()> {
78        let mut state = self.state.lock().unwrap();
79
80        match *state {
81            TransactionState::None => {
82                println!("No active transaction to commit.");
83                return Ok(());
84            }
85            TransactionState::Active => {
86                conn.execute("COMMIT", [])?;
87                *state = TransactionState::None;
88                println!("Transaction committed.");
89            }
90        }
91
92        Ok(())
93    }
94
95    /// Rolls back the active database transaction.
96    ///
97    /// If no transaction is active, it prints a message and does nothing.
98    /// Otherwise, it executes a `ROLLBACK` statement and resets the state to `None`.
99    ///
100    /// # Arguments
101    /// * `conn` - A reference to the `rusqlite::Connection`.
102    pub fn rollback_transaction(&self, conn: &Connection) -> Result<()> {
103        let mut state = self.state.lock().unwrap();
104
105        match *state {
106            TransactionState::None => {
107                println!("No active transaction to rollback.");
108                return Ok(());
109            }
110            TransactionState::Active => {
111                conn.execute("ROLLBACK", [])?;
112                *state = TransactionState::None;
113                println!("Transaction rolled back.");
114            }
115        }
116
117        Ok(())
118    }
119
120    /// Checks if a transaction is currently active.
121    ///
122    /// # Returns
123    /// `true` if the transaction state is `Active`, `false` otherwise.
124    pub fn is_active(&self) -> bool {
125        matches!(*self.state.lock().unwrap(), TransactionState::Active)
126    }
127
128    /// Prints the current transaction status to the console.
129    pub fn show_status(&self) {
130        let state = self.state.lock().unwrap();
131        match *state {
132            TransactionState::None => println!("No active transaction."),
133            TransactionState::Active => println!("Transaction is active."),
134        }
135    }
136
137    /// Intercepts and handles transaction-related SQL commands.
138    ///
139    /// This method checks if the input SQL string matches known transaction control
140    /// statements (`BEGIN`, `COMMIT`, `ROLLBACK`) or a `DROP` command. If a match is found,
141    /// it calls the appropriate `TransactionManager` method and returns `Ok(true)`.
142    /// For `DROP`, it adds extra validation.
143    ///
144    /// If the command is not a recognized transaction command, it returns `Ok(false)`,
145    /// indicating that the command should be executed as a standard SQL query.
146    ///
147    /// # Arguments
148    /// * `conn` - A reference to the `rusqlite::Connection`.
149    /// * `sql` - The SQL command string to be processed.
150    ///
151    /// # Returns
152    /// A `Result<bool>` which is `Ok(true)` if the command was handled, or `Ok(false)` if not.
153    pub fn handle_sql_command(&self, conn: &Connection, sql: &str) -> Result<bool> {
154        let sql_lower = sql.to_lowercase().trim().to_string();
155
156        match sql_lower.as_str() {
157            "begin" | "begin transaction" => {
158                self.begin_transaction(conn)?;
159                Ok(true) // Command was handled
160            }
161            "commit" | "commit transaction" => {
162                self.commit_transaction(conn)?;
163                Ok(true) // Command was handled
164            }
165            "rollback" | "rollback transaction" => {
166                self.rollback_transaction(conn)?;
167                Ok(true) // Command was handled
168            }
169            _ => {
170                // Handle DROP commands
171                if sql_lower.starts_with("drop") {
172                    let parts: Vec<&str> = sql_lower.split_whitespace().collect();
173                    if parts.len() < 2 {
174                        println!("Usage: DROP TABLE table_name; or DROP table_name;");
175                        return Ok(true);
176                    }
177
178                    let table_name = if parts[1] == "table" {
179                        if parts.len() < 3 {
180                            println!("Usage: DROP TABLE table_name;");
181                            return Ok(true);
182                        }
183                        parts[2].trim_end_matches(';')
184                    } else {
185                        parts[1].trim_end_matches(';')
186                    };
187
188                    // Verify table exists before dropping
189                    let mut stmt = conn
190                        .prepare(
191                            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
192                        )
193                        .context("Failed to prepare table existence check")?;
194
195                    let count: i64 = stmt
196                        .query_row(rusqlite::params![table_name], |row| row.get(0))
197                        .with_context(||
198                            format!("Failed to check if table '{}' exists", table_name)
199                        )?;
200
201                    if count == 0 {
202                        println!("Table '{}' does not exist", table_name);
203                        return Ok(true);
204                    }
205
206                    // Execute the DROP command
207                    conn.execute(&format!("DROP TABLE {}", table_name), [])
208                        .with_context(|| format!("Failed to drop table '{}'", table_name))?;
209
210                    println!("Table '{}' dropped successfully", table_name);
211                    return Ok(true);
212                }
213                Ok(false) // Command was not handled
214            }
215        }
216    }
217}