// sqlrite/sql/db/database.rs
use crate::error::{Result, SQLRiteError};
use crate::sql::db::table::Table;
use crate::sql::pager::pager::{AccessMode, Pager};
use std::collections::HashMap;
use std::path::PathBuf;
6
7#[derive(Debug)]
13pub struct TxnSnapshot {
14 pub(crate) tables: HashMap<String, Table>,
15}
16
17#[derive(Debug)]
19pub struct Database {
20 pub db_name: String,
22 pub tables: HashMap<String, Table>,
24 pub source_path: Option<PathBuf>,
28 pub pager: Option<Pager>,
33 pub txn: Option<TxnSnapshot>,
39}
40
41impl Database {
42 pub fn new(db_name: String) -> Self {
51 Database {
52 db_name,
53 tables: HashMap::new(),
54 source_path: None,
55 pager: None,
56 txn: None,
57 }
58 }
59
60 pub fn contains_table(&self, table_name: String) -> bool {
63 self.tables.contains_key(&table_name)
64 }
65
66 pub fn get_table(&self, table_name: String) -> Result<&Table> {
70 if let Some(table) = self.tables.get(&table_name) {
71 Ok(table)
72 } else {
73 Err(SQLRiteError::General(String::from("Table not found.")))
74 }
75 }
76
77 pub fn get_table_mut(&mut self, table_name: String) -> Result<&mut Table> {
81 if let Some(table) = self.tables.get_mut(&table_name) {
82 Ok(table)
83 } else {
84 Err(SQLRiteError::General(String::from("Table not found.")))
85 }
86 }
87
88 pub fn is_read_only(&self) -> bool {
95 self.pager
96 .as_ref()
97 .is_some_and(|p| p.access_mode() == AccessMode::ReadOnly)
98 }
99
100 pub fn in_transaction(&self) -> bool {
102 self.txn.is_some()
103 }
104
105 pub fn begin_transaction(&mut self) -> Result<()> {
110 if self.in_transaction() {
111 return Err(SQLRiteError::General(
112 "cannot BEGIN: a transaction is already open".to_string(),
113 ));
114 }
115 if self.is_read_only() {
116 return Err(SQLRiteError::General(
117 "cannot BEGIN: database is opened read-only".to_string(),
118 ));
119 }
120 let snapshot = TxnSnapshot {
121 tables: self
122 .tables
123 .iter()
124 .map(|(k, v)| (k.clone(), v.deep_clone()))
125 .collect(),
126 };
127 self.txn = Some(snapshot);
128 Ok(())
129 }
130
131 pub fn commit_transaction(&mut self) -> Result<()> {
136 if self.txn.is_none() {
137 return Err(SQLRiteError::General(
138 "cannot COMMIT: no transaction is open".to_string(),
139 ));
140 }
141 self.txn = None;
142 Ok(())
143 }
144
145 pub fn rollback_transaction(&mut self) -> Result<()> {
148 let Some(snapshot) = self.txn.take() else {
149 return Err(SQLRiteError::General(
150 "cannot ROLLBACK: no transaction is open".to_string(),
151 ));
152 };
153 self.tables = snapshot.tables;
154 Ok(())
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::create::CreateQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Builds a database named `my_db` holding a single `contacts` table
    /// parsed from a `CREATE TABLE` statement.
    fn database_with_contacts() -> Database {
        let mut db = Database::new(String::from("my_db"));

        let query_statement = "CREATE TABLE contacts (
            id INTEGER PRIMARY KEY,
            first_name TEXT NOT NULL,
            last_name TEXT NOT NULl,
            email TEXT NOT NULL UNIQUE
        );";
        let mut ast = Parser::parse_sql(&SQLiteDialect {}, query_statement).unwrap();
        assert!(
            ast.len() <= 1,
            "Expected a single query statement, but there are more then 1."
        );
        let statement = ast.pop().unwrap();

        let create_query = CreateQuery::new(&statement).unwrap();
        // Take an owned copy of the name before `create_query` is moved
        // into the table.
        let table_name = create_query.table_name.to_string();
        db.tables.insert(table_name, Table::new(create_query));
        db
    }

    #[test]
    fn new_database_create_test() {
        let db_name = String::from("my_db");
        let db = Database::new(db_name.to_string());
        assert_eq!(db.db_name, db_name);
    }

    #[test]
    fn contains_table_test() {
        let db = database_with_contacts();
        assert!(db.contains_table("contacts".to_string()));
    }

    #[test]
    fn get_table_test() {
        let mut db = database_with_contacts();

        let table = db.get_table(String::from("contacts")).unwrap();
        assert_eq!(table.columns.len(), 4);

        let table = db.get_table_mut(String::from("contacts")).unwrap();
        table.last_rowid += 1;
        assert_eq!(table.columns.len(), 4);
        assert_eq!(table.last_rowid, 1);
    }
}