Skip to main content

sqlrite/sql/db/
database.rs

1use crate::error::{Result, SQLRiteError};
2use crate::sql::db::table::Table;
3use crate::sql::pager::pager::{AccessMode, Pager};
4use std::collections::HashMap;
5use std::path::PathBuf;
6
7/// Snapshot of the mutable in-memory state taken at `BEGIN` time so
8/// `ROLLBACK` can restore it. See `begin_transaction`, `rollback_transaction`.
9/// `tables` is deep-cloned (the `Table::deep_clone` helper reallocates
10/// the `Arc<Mutex<_>>` row storage so snapshot and live state don't
11/// share a map).
12#[derive(Debug)]
13pub struct TxnSnapshot {
14    pub(crate) tables: HashMap<String, Table>,
15}
16
17/// The database is represented by this structure.assert_eq!
18#[derive(Debug)]
19pub struct Database {
20    /// Name of this database. (schema name, not filename)
21    pub db_name: String,
22    /// HashMap of tables in this database
23    pub tables: HashMap<String, Table>,
24    /// If `Some`, every committing SQL statement auto-flushes the DB to
25    /// this path. `None` → transient in-memory mode (the default; the
26    /// REPL only enters persistent mode after `.open FILE`).
27    pub source_path: Option<PathBuf>,
28    /// Long-lived pager attached when the database is file-backed. Keeps
29    /// an in-memory snapshot of every page so auto-saves can diff
30    /// against the last-committed state and skip rewriting unchanged
31    /// pages. `None` means "in-memory only" or "not yet opened".
32    pub pager: Option<Pager>,
33    /// Active transaction state (Phase 4f). `Some` between `BEGIN` and
34    /// the matching `COMMIT` / `ROLLBACK`. While set:
35    /// - auto-save is suppressed (mutations stay in-memory)
36    /// - nested `BEGIN` is rejected
37    /// - `ROLLBACK` restores `tables` from the snapshot
38    pub txn: Option<TxnSnapshot>,
39}
40
41impl Database {
42    /// Creates an empty in-memory `Database`.
43    ///
44    /// # Examples
45    ///
46    /// ```
47    /// use sqlrite::Database;
48    /// let mut db = Database::new("my_db".to_string());
49    /// ```
50    pub fn new(db_name: String) -> Self {
51        Database {
52            db_name,
53            tables: HashMap::new(),
54            source_path: None,
55            pager: None,
56            txn: None,
57        }
58    }
59
60    /// Returns true if the database contains a table with the specified key as a table name.
61    ///
62    pub fn contains_table(&self, table_name: String) -> bool {
63        self.tables.contains_key(&table_name)
64    }
65
66    /// Returns an immutable reference of `sql::db::table::Table` if the database contains a
67    /// table with the specified key as a table name.
68    ///
69    pub fn get_table(&self, table_name: String) -> Result<&Table> {
70        if let Some(table) = self.tables.get(&table_name) {
71            Ok(table)
72        } else {
73            Err(SQLRiteError::General(String::from("Table not found.")))
74        }
75    }
76
77    /// Returns an mutable reference of `sql::db::table::Table` if the database contains a
78    /// table with the specified key as a table name.
79    ///
80    pub fn get_table_mut(&mut self, table_name: String) -> Result<&mut Table> {
81        if let Some(table) = self.tables.get_mut(&table_name) {
82            Ok(table)
83        } else {
84            Err(SQLRiteError::General(String::from("Table not found.")))
85        }
86    }
87
88    /// Returns `true` if this database is attached to a file and that
89    /// file was opened in [`AccessMode::ReadOnly`]. In-memory databases
90    /// (no pager) and read-write file-backed databases both return
91    /// `false`. Callers use this to reject mutating SQL at the
92    /// dispatcher level so the in-memory tables don't drift away from
93    /// disk on a would-be INSERT / UPDATE / DELETE.
94    pub fn is_read_only(&self) -> bool {
95        self.pager
96            .as_ref()
97            .is_some_and(|p| p.access_mode() == AccessMode::ReadOnly)
98    }
99
100    /// Returns `true` while a `BEGIN … COMMIT`/`ROLLBACK` block is open.
101    pub fn in_transaction(&self) -> bool {
102        self.txn.is_some()
103    }
104
105    /// Starts a transaction: snapshots every table deep-cloned so that
106    /// a later `rollback_transaction` can restore the pre-BEGIN state.
107    /// Nested transactions are rejected — explicit savepoints are not
108    /// on this phase's roadmap. Errors on a read-only database.
109    pub fn begin_transaction(&mut self) -> Result<()> {
110        if self.in_transaction() {
111            return Err(SQLRiteError::General(
112                "cannot BEGIN: a transaction is already open".to_string(),
113            ));
114        }
115        if self.is_read_only() {
116            return Err(SQLRiteError::General(
117                "cannot BEGIN: database is opened read-only".to_string(),
118            ));
119        }
120        let snapshot = TxnSnapshot {
121            tables: self
122                .tables
123                .iter()
124                .map(|(k, v)| (k.clone(), v.deep_clone()))
125                .collect(),
126        };
127        self.txn = Some(snapshot);
128        Ok(())
129    }
130
131    /// Drops the transaction snapshot and returns it for the caller to
132    /// discard. The in-memory `tables` state is the new committed state;
133    /// the caller is responsible for flushing to disk via the pager.
134    /// Errors if no transaction is open.
135    pub fn commit_transaction(&mut self) -> Result<()> {
136        if self.txn.is_none() {
137            return Err(SQLRiteError::General(
138                "cannot COMMIT: no transaction is open".to_string(),
139            ));
140        }
141        self.txn = None;
142        Ok(())
143    }
144
145    /// Restores `tables` from the transaction snapshot and clears it.
146    /// Errors if no transaction is open.
147    pub fn rollback_transaction(&mut self) -> Result<()> {
148        let Some(snapshot) = self.txn.take() else {
149            return Err(SQLRiteError::General(
150                "cannot ROLLBACK: no transaction is open".to_string(),
151            ));
152        };
153        self.tables = snapshot.tables;
154        Ok(())
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::create::CreateQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Schema shared by the tests below: a four-column `contacts` table.
    const CONTACTS_SQL: &str = "CREATE TABLE contacts (
        id INTEGER PRIMARY KEY,
        first_name TEXT NOT NULL,
        last_name TEXT NOT NULL,
        email TEXT NOT NULL UNIQUE
    );";

    /// Parses `sql`, which must contain exactly one `CREATE TABLE`
    /// statement, and inserts the resulting table into `db` under its
    /// declared name.
    fn add_table(db: &mut Database, sql: &str) {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).unwrap();
        assert_eq!(
            ast.len(),
            1,
            "Expected a single query statement, but got {}.",
            ast.len()
        );
        let query = ast.pop().unwrap();

        let create_query = CreateQuery::new(&query).unwrap();
        let table_name = create_query.table_name.to_string();
        db.tables.insert(table_name, Table::new(create_query));
    }

    /// Builds an in-memory database that already holds the `contacts` table.
    fn contacts_db() -> Database {
        let mut db = Database::new(String::from("my_db"));
        add_table(&mut db, CONTACTS_SQL);
        db
    }

    #[test]
    fn new_database_create_test() {
        let db_name = String::from("my_db");
        let db = Database::new(db_name.clone());
        assert_eq!(db.db_name, db_name);
        assert!(db.tables.is_empty());
        assert!(!db.in_transaction());
        // No pager attached, so the database cannot be read-only.
        assert!(!db.is_read_only());
    }

    #[test]
    fn contains_table_test() {
        let db = contacts_db();
        assert!(db.contains_table("contacts".to_string()));
        assert!(!db.contains_table("missing".to_string()));
    }

    #[test]
    fn get_table_test() {
        let mut db = contacts_db();

        let table = db.get_table(String::from("contacts")).unwrap();
        assert_eq!(table.columns.len(), 4);

        let table = db.get_table_mut(String::from("contacts")).unwrap();
        table.last_rowid += 1;
        assert_eq!(table.columns.len(), 4);
        assert_eq!(table.last_rowid, 1);
    }

    #[test]
    fn get_table_missing_test() {
        let mut db = Database::new(String::from("my_db"));
        assert!(db.get_table(String::from("missing")).is_err());
        assert!(db.get_table_mut(String::from("missing")).is_err());
    }

    #[test]
    fn transaction_state_machine_test() {
        let mut db = Database::new(String::from("my_db"));

        // COMMIT / ROLLBACK without BEGIN are rejected.
        assert!(db.commit_transaction().is_err());
        assert!(db.rollback_transaction().is_err());

        db.begin_transaction().unwrap();
        assert!(db.in_transaction());

        // A nested BEGIN is rejected and leaves the outer transaction open.
        assert!(db.begin_transaction().is_err());
        assert!(db.in_transaction());

        db.commit_transaction().unwrap();
        assert!(!db.in_transaction());
    }

    #[test]
    fn commit_keeps_changes_test() {
        let mut db = Database::new(String::from("my_db"));
        db.begin_transaction().unwrap();
        add_table(&mut db, CONTACTS_SQL);
        db.commit_transaction().unwrap();

        assert!(!db.in_transaction());
        assert!(db.contains_table("contacts".to_string()));
        // Nothing is left to roll back to once the transaction is committed.
        assert!(db.rollback_transaction().is_err());
        assert!(db.contains_table("contacts".to_string()));
    }

    #[test]
    fn rollback_restores_dropped_table_test() {
        let mut db = contacts_db();
        db.begin_transaction().unwrap();
        assert!(db.tables.remove("contacts").is_some());
        assert!(!db.contains_table("contacts".to_string()));

        db.rollback_transaction().unwrap();
        assert!(!db.in_transaction());
        let table = db.get_table(String::from("contacts")).unwrap();
        assert_eq!(table.columns.len(), 4);
    }

    #[test]
    fn rollback_restores_mutated_table_test() {
        let mut db = contacts_db();
        let before = db.get_table(String::from("contacts")).unwrap().last_rowid;

        db.begin_transaction().unwrap();
        db.get_table_mut(String::from("contacts")).unwrap().last_rowid += 5;
        db.rollback_transaction().unwrap();

        // The snapshot is a deep copy, so the in-transaction mutation is gone.
        let after = db.get_table(String::from("contacts")).unwrap().last_rowid;
        assert_eq!(after, before);
    }

    #[test]
    fn rollback_discards_table_created_in_transaction_test() {
        let mut db = Database::new(String::from("my_db"));
        db.begin_transaction().unwrap();
        add_table(&mut db, CONTACTS_SQL);
        assert!(db.contains_table("contacts".to_string()));

        db.rollback_transaction().unwrap();
        assert!(!db.contains_table("contacts".to_string()));
        assert!(db.tables.is_empty());
    }
}