Skip to main content

sqlrite/sql/
mod.rs

1pub mod db;
2pub mod executor;
3pub mod pager;
4pub mod parser;
5// pub mod tokenizer;
6
7use parser::create::CreateQuery;
8use parser::insert::InsertQuery;
9use parser::select::SelectQuery;
10
11use sqlparser::ast::Statement;
12use sqlparser::dialect::SQLiteDialect;
13use sqlparser::parser::{Parser, ParserError};
14
15use crate::error::{Result, SQLRiteError};
16use crate::sql::db::database::Database;
17use crate::sql::db::table::Table;
18
/// Coarse classification of a raw SQL string by its first space-delimited word.
///
/// Matching is exact and case-sensitive: only the lowercase keywords `insert`,
/// `update`, `delete`, `create` and `select` are recognised. Anything else,
/// including uppercase keywords or input with leading whitespace, becomes
/// [`SQLCommand::Unknown`]. Every variant carries the full original text.
#[derive(Debug, PartialEq)]
pub enum SQLCommand {
    /// Text whose first word is `insert`.
    Insert(String),
    /// Text whose first word is `delete`.
    Delete(String),
    /// Text whose first word is `update`.
    Update(String),
    /// Text whose first word is `create` (any `create ...`, not only tables).
    CreateTable(String),
    /// Text whose first word is `select`.
    Select(String),
    /// Text whose first word is none of the above.
    Unknown(String),
}

impl SQLCommand {
    /// Takes ownership of `command` and wraps it in the variant selected by
    /// its leading word (the text before the first `' '`).
    pub fn new(command: String) -> SQLCommand {
        // `split` always yields at least one (possibly empty) piece, so the
        // fallback only exists to avoid indexing.
        let keyword = command.split(' ').next().unwrap_or("");
        match keyword {
            "insert" => SQLCommand::Insert(command),
            "delete" => SQLCommand::Delete(command),
            "update" => SQLCommand::Update(command),
            "create" => SQLCommand::CreateTable(command),
            "select" => SQLCommand::Select(command),
            _ => SQLCommand::Unknown(command),
        }
    }
}
42
43/// Performs initial parsing of SQL Statement using sqlparser-rs
44pub fn process_command(query: &str, db: &mut Database) -> Result<String> {
45    let dialect = SQLiteDialect {};
46    let message: String;
47    let mut ast = Parser::parse_sql(&dialect, query).map_err(SQLRiteError::from)?;
48
49    if ast.len() > 1 {
50        return Err(SQLRiteError::SqlError(ParserError::ParserError(format!(
51            "Expected a single query statement, but there are {}",
52            ast.len()
53        ))));
54    }
55
56    // Comment-only or whitespace-only input parses to an empty Vec<Statement>.
57    // Return a benign status rather than panicking on `pop().unwrap()`. Callers
58    // (REPL, Tauri app) treat this as a no-op with no disk write triggered.
59    let Some(query) = ast.pop() else {
60        return Ok("No statement to execute.".to_string());
61    };
62
63    // Transaction boundary statements are routed to Database-level
64    // handlers before we even inspect the rest of the AST. They don't
65    // mutate table data directly, so they short-circuit the
66    // is_write_statement / auto-save path.
67    match &query {
68        Statement::StartTransaction { .. } => {
69            db.begin_transaction()?;
70            return Ok(String::from("BEGIN"));
71        }
72        Statement::Commit { .. } => {
73            if !db.in_transaction() {
74                return Err(SQLRiteError::General(
75                    "cannot COMMIT: no transaction is open".to_string(),
76                ));
77            }
78            // Flush accumulated in-memory changes to disk. If the save
79            // fails we auto-rollback the in-memory state to the
80            // pre-BEGIN snapshot and surface a combined error. Leaving
81            // the transaction open after a failed COMMIT would be
82            // unsafe: auto-save on any subsequent non-transactional
83            // statement would silently publish partial mid-transaction
84            // work. Auto-rollback keeps the disk-plus-memory pair
85            // coherent — the user loses their in-flight work on a disk
86            // error, but that's the only safe outcome.
87            if let Some(path) = db.source_path.clone() {
88                if let Err(save_err) = pager::save_database(db, &path) {
89                    let _ = db.rollback_transaction();
90                    return Err(SQLRiteError::General(format!(
91                        "COMMIT failed — transaction rolled back: {save_err}"
92                    )));
93                }
94            }
95            db.commit_transaction()?;
96            return Ok(String::from("COMMIT"));
97        }
98        Statement::Rollback { .. } => {
99            db.rollback_transaction()?;
100            return Ok(String::from("ROLLBACK"));
101        }
102        _ => {}
103    }
104
105    // Statements that mutate state — trigger auto-save on success. Read-only
106    // SELECTs skip the save entirely to avoid pointless file writes.
107    let is_write_statement = matches!(
108        &query,
109        Statement::CreateTable(_)
110            | Statement::CreateIndex(_)
111            | Statement::Insert(_)
112            | Statement::Update(_)
113            | Statement::Delete(_)
114    );
115
116    // Early-reject mutations on a read-only database before they touch
117    // in-memory state. Phase 4e: without this, a user running INSERT
118    // on a `--readonly` REPL would see the row appear in the printed
119    // table, and then the auto-save would fail — leaving the in-memory
120    // Database visibly diverged from disk.
121    if is_write_statement && db.is_read_only() {
122        return Err(SQLRiteError::General(
123            "cannot execute: database is opened read-only".to_string(),
124        ));
125    }
126
127    // Initialy only implementing some basic SQL Statements
128    match query {
129        Statement::CreateTable(_) => {
130            let create_query = CreateQuery::new(&query);
131            match create_query {
132                Ok(payload) => {
133                    let table_name = payload.table_name.clone();
134                    if table_name == pager::MASTER_TABLE_NAME {
135                        return Err(SQLRiteError::General(format!(
136                            "'{}' is a reserved name used by the internal schema catalog",
137                            pager::MASTER_TABLE_NAME
138                        )));
139                    }
140                    // Checking if table already exists, after parsing CREATE TABLE query
141                    match db.contains_table(table_name.to_string()) {
142                        true => {
143                            return Err(SQLRiteError::Internal(
144                                "Cannot create, table already exists.".to_string(),
145                            ));
146                        }
147                        false => {
148                            let table = Table::new(payload);
149                            let _ = table.print_table_schema();
150                            db.tables.insert(table_name.to_string(), table);
151                            // Iterate over everything.
152                            // for (table_name, _) in &db.tables {
153                            //     println!("{}" , table_name);
154                            // }
155                            message = String::from("CREATE TABLE Statement executed.");
156                        }
157                    }
158                }
159                Err(err) => return Err(err),
160            }
161        }
162        Statement::Insert(_) => {
163            let insert_query = InsertQuery::new(&query);
164            match insert_query {
165                Ok(payload) => {
166                    let table_name = payload.table_name;
167                    let columns = payload.columns;
168                    let values = payload.rows;
169
170                    // println!("table_name = {:?}\n cols = {:?}\n vals = {:?}", table_name, columns, values);
171                    // Checking if Table exists in Database
172                    match db.contains_table(table_name.to_string()) {
173                        true => {
174                            let db_table = db.get_table_mut(table_name.to_string()).unwrap();
175                            // Checking if columns on INSERT query exist on Table
176                            match columns
177                                .iter()
178                                .all(|column| db_table.contains_column(column.to_string()))
179                            {
180                                true => {
181                                    for value in &values {
182                                        // Checking if number of columns in query are the same as number of values
183                                        if columns.len() != value.len() {
184                                            return Err(SQLRiteError::Internal(format!(
185                                                "{} values for {} columns",
186                                                value.len(),
187                                                columns.len()
188                                            )));
189                                        }
190                                        db_table
191                                            .validate_unique_constraint(&columns, value)
192                                            .map_err(|err| {
193                                                SQLRiteError::Internal(format!(
194                                                    "Unique key constraint violation: {err}"
195                                                ))
196                                            })?;
197                                        db_table.insert_row(&columns, value)?;
198                                    }
199                                }
200                                false => {
201                                    return Err(SQLRiteError::Internal(
202                                        "Cannot insert, some of the columns do not exist"
203                                            .to_string(),
204                                    ));
205                                }
206                            }
207                            db_table.print_table_data();
208                        }
209                        false => {
210                            return Err(SQLRiteError::Internal("Table doesn't exist".to_string()));
211                        }
212                    }
213                }
214                Err(err) => return Err(err),
215            }
216
217            message = String::from("INSERT Statement executed.")
218        }
219        Statement::Query(_) => {
220            let select_query = SelectQuery::new(&query)?;
221            let (rendered, rows) = executor::execute_select(select_query, db)?;
222            // Print the result table above the status message so the REPL shows both.
223            print!("{rendered}");
224            message = format!(
225                "SELECT Statement executed. {rows} row{s} returned.",
226                s = if rows == 1 { "" } else { "s" }
227            );
228        }
229        Statement::Delete(_) => {
230            let rows = executor::execute_delete(&query, db)?;
231            message = format!(
232                "DELETE Statement executed. {rows} row{s} deleted.",
233                s = if rows == 1 { "" } else { "s" }
234            );
235        }
236        Statement::Update(_) => {
237            let rows = executor::execute_update(&query, db)?;
238            message = format!(
239                "UPDATE Statement executed. {rows} row{s} updated.",
240                s = if rows == 1 { "" } else { "s" }
241            );
242        }
243        Statement::CreateIndex(_) => {
244            let name = executor::execute_create_index(&query, db)?;
245            message = format!("CREATE INDEX '{name}' executed.");
246        }
247        _ => {
248            return Err(SQLRiteError::NotImplemented(
249                "SQL Statement not supported yet.".to_string(),
250            ));
251        }
252    };
253
254    // Auto-save: if the database is backed by a file AND no explicit
255    // transaction is open AND the statement changed state, flush to
256    // disk before returning. Inside a `BEGIN … COMMIT` block the
257    // mutations accumulate in memory (protected by the ROLLBACK
258    // snapshot) and land on disk in one shot when COMMIT runs.
259    //
260    // A failed save surfaces as an error — the in-memory state already
261    // mutated, so the caller should know disk is out of sync. The
262    // Pager held on `db` diffs against its last-committed snapshot,
263    // so only pages whose bytes actually changed are written.
264    if is_write_statement && db.source_path.is_some() && !db.in_transaction() {
265        let path = db.source_path.clone().unwrap();
266        pager::save_database(db, &path)?;
267    }
268
269    Ok(message)
270}
271
272#[cfg(test)]
273mod tests {
274    use super::*;
275
276    /// Builds a `users(id INTEGER PK, name TEXT, age INTEGER)` table populated
277    /// with three rows, for use in executor-level tests.
278    fn seed_users_table() -> Database {
279        let mut db = Database::new("tempdb".to_string());
280        process_command(
281            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER);",
282            &mut db,
283        )
284        .expect("create table");
285        process_command(
286            "INSERT INTO users (name, age) VALUES ('alice', 30);",
287            &mut db,
288        )
289        .expect("insert alice");
290        process_command("INSERT INTO users (name, age) VALUES ('bob', 25);", &mut db)
291            .expect("insert bob");
292        process_command(
293            "INSERT INTO users (name, age) VALUES ('carol', 40);",
294            &mut db,
295        )
296        .expect("insert carol");
297        db
298    }
299
300    #[test]
301    fn process_command_select_all_test() {
302        let mut db = seed_users_table();
303        let response = process_command("SELECT * FROM users;", &mut db).expect("select");
304        assert!(response.contains("3 rows returned"));
305    }
306
307    #[test]
308    fn process_command_select_where_test() {
309        let mut db = seed_users_table();
310        let response =
311            process_command("SELECT name FROM users WHERE age > 25;", &mut db).expect("select");
312        assert!(response.contains("2 rows returned"));
313    }
314
315    #[test]
316    fn process_command_select_eq_string_test() {
317        let mut db = seed_users_table();
318        let response =
319            process_command("SELECT name FROM users WHERE name = 'bob';", &mut db).expect("select");
320        assert!(response.contains("1 row returned"));
321    }
322
323    #[test]
324    fn process_command_select_limit_test() {
325        let mut db = seed_users_table();
326        let response = process_command("SELECT * FROM users ORDER BY age ASC LIMIT 2;", &mut db)
327            .expect("select");
328        assert!(response.contains("2 rows returned"));
329    }
330
331    #[test]
332    fn process_command_select_unknown_table_test() {
333        let mut db = Database::new("tempdb".to_string());
334        let result = process_command("SELECT * FROM nope;", &mut db);
335        assert!(result.is_err());
336    }
337
338    #[test]
339    fn process_command_select_unknown_column_test() {
340        let mut db = seed_users_table();
341        let result = process_command("SELECT height FROM users;", &mut db);
342        assert!(result.is_err());
343    }
344
345    #[test]
346    fn process_command_insert_test() {
347        // Creating temporary database
348        let mut db = Database::new("tempdb".to_string());
349
350        // Creating temporary table for testing purposes
351        let query_statement = "CREATE TABLE users (
352            id INTEGER PRIMARY KEY,
353            name TEXT
354        );";
355        let dialect = SQLiteDialect {};
356        let mut ast = Parser::parse_sql(&dialect, query_statement).unwrap();
357        if ast.len() > 1 {
358            panic!("Expected a single query statement, but there are more then 1.")
359        }
360        let query = ast.pop().unwrap();
361        let create_query = CreateQuery::new(&query).unwrap();
362
363        // Inserting table into database
364        db.tables.insert(
365            create_query.table_name.to_string(),
366            Table::new(create_query),
367        );
368
369        // Inserting data into table
370        let insert_query = String::from("INSERT INTO users (name) Values ('josh');");
371        match process_command(&insert_query, &mut db) {
372            Ok(response) => assert_eq!(response, "INSERT Statement executed."),
373            Err(err) => {
374                eprintln!("Error: {}", err);
375                assert!(false)
376            }
377        };
378    }
379
380    #[test]
381    fn process_command_insert_no_pk_test() {
382        // Creating temporary database
383        let mut db = Database::new("tempdb".to_string());
384
385        // Creating temporary table for testing purposes
386        let query_statement = "CREATE TABLE users (
387            name TEXT
388        );";
389        let dialect = SQLiteDialect {};
390        let mut ast = Parser::parse_sql(&dialect, query_statement).unwrap();
391        if ast.len() > 1 {
392            panic!("Expected a single query statement, but there are more then 1.")
393        }
394        let query = ast.pop().unwrap();
395        let create_query = CreateQuery::new(&query).unwrap();
396
397        // Inserting table into database
398        db.tables.insert(
399            create_query.table_name.to_string(),
400            Table::new(create_query),
401        );
402
403        // Inserting data into table
404        let insert_query = String::from("INSERT INTO users (name) Values ('josh');");
405        match process_command(&insert_query, &mut db) {
406            Ok(response) => assert_eq!(response, "INSERT Statement executed."),
407            Err(err) => {
408                eprintln!("Error: {}", err);
409                assert!(false)
410            }
411        };
412    }
413
414    #[test]
415    fn process_command_delete_where_test() {
416        let mut db = seed_users_table();
417        let response =
418            process_command("DELETE FROM users WHERE name = 'bob';", &mut db).expect("delete");
419        assert!(response.contains("1 row deleted"));
420
421        let remaining = process_command("SELECT * FROM users;", &mut db).expect("select");
422        assert!(remaining.contains("2 rows returned"));
423    }
424
425    #[test]
426    fn process_command_delete_all_test() {
427        let mut db = seed_users_table();
428        let response = process_command("DELETE FROM users;", &mut db).expect("delete");
429        assert!(response.contains("3 rows deleted"));
430    }
431
432    #[test]
433    fn process_command_update_where_test() {
434        use crate::sql::db::table::Value;
435
436        let mut db = seed_users_table();
437        let response = process_command("UPDATE users SET age = 99 WHERE name = 'bob';", &mut db)
438            .expect("update");
439        assert!(response.contains("1 row updated"));
440
441        // Confirm the cell was actually rewritten.
442        let users = db.get_table("users".to_string()).unwrap();
443        let bob_rowid = users
444            .rowids()
445            .into_iter()
446            .find(|r| users.get_value("name", *r) == Some(Value::Text("bob".to_string())))
447            .expect("bob row must exist");
448        assert_eq!(users.get_value("age", bob_rowid), Some(Value::Integer(99)));
449    }
450
451    #[test]
452    fn process_command_update_unique_violation_test() {
453        let mut db = seed_users_table();
454        // `name` is not UNIQUE in the seed — reinforce with an explicit unique column.
455        process_command(
456            "CREATE TABLE tags (id INTEGER PRIMARY KEY, label TEXT UNIQUE);",
457            &mut db,
458        )
459        .unwrap();
460        process_command("INSERT INTO tags (label) VALUES ('a');", &mut db).unwrap();
461        process_command("INSERT INTO tags (label) VALUES ('b');", &mut db).unwrap();
462
463        let result = process_command("UPDATE tags SET label = 'a' WHERE label = 'b';", &mut db);
464        assert!(result.is_err(), "expected UNIQUE violation, got {result:?}");
465    }
466
467    #[test]
468    fn process_command_insert_type_mismatch_returns_error_test() {
469        // Previously this panicked in parse::<i32>().unwrap(); now it should return an error cleanly.
470        let mut db = Database::new("tempdb".to_string());
471        process_command(
472            "CREATE TABLE items (id INTEGER PRIMARY KEY, qty INTEGER);",
473            &mut db,
474        )
475        .unwrap();
476        let result = process_command("INSERT INTO items (qty) VALUES ('not a number');", &mut db);
477        assert!(result.is_err(), "expected error, got {result:?}");
478    }
479
480    #[test]
481    fn process_command_insert_missing_integer_returns_error_test() {
482        // Non-PK INTEGER without a value should error (not panic on "Null".parse()).
483        let mut db = Database::new("tempdb".to_string());
484        process_command(
485            "CREATE TABLE items (id INTEGER PRIMARY KEY, qty INTEGER);",
486            &mut db,
487        )
488        .unwrap();
489        let result = process_command("INSERT INTO items (id) VALUES (1);", &mut db);
490        assert!(result.is_err(), "expected error, got {result:?}");
491    }
492
493    #[test]
494    fn process_command_update_arith_test() {
495        use crate::sql::db::table::Value;
496
497        let mut db = seed_users_table();
498        process_command("UPDATE users SET age = age + 1;", &mut db).expect("update +1");
499
500        let users = db.get_table("users".to_string()).unwrap();
501        let mut ages: Vec<i64> = users
502            .rowids()
503            .into_iter()
504            .filter_map(|r| match users.get_value("age", r) {
505                Some(Value::Integer(n)) => Some(n),
506                _ => None,
507            })
508            .collect();
509        ages.sort();
510        assert_eq!(ages, vec![26, 31, 41]); // 25+1, 30+1, 40+1
511    }
512
513    #[test]
514    fn process_command_select_arithmetic_where_test() {
515        let mut db = seed_users_table();
516        // age * 2 > 55  →  only ages > 27.5  →  alice(30) + carol(40)
517        let response =
518            process_command("SELECT name FROM users WHERE age * 2 > 55;", &mut db).expect("select");
519        assert!(response.contains("2 rows returned"));
520    }
521
522    #[test]
523    fn process_command_divide_by_zero_test() {
524        let mut db = seed_users_table();
525        let result = process_command("SELECT age / 0 FROM users;", &mut db);
526        // Projection only supports bare columns, so this errors earlier; still shouldn't panic.
527        assert!(result.is_err());
528    }
529
530    #[test]
531    fn process_command_unsupported_statement_test() {
532        let mut db = Database::new("tempdb".to_string());
533        // Nothing in Phase 1 handles DROP.
534        let result = process_command("DROP TABLE users;", &mut db);
535        assert!(result.is_err());
536    }
537
538    #[test]
539    fn empty_input_is_a_noop_not_a_panic() {
540        // Regression for: desktop app pre-fills the textarea with a
541        // comment-only placeholder, and hitting Run used to panic because
542        // sqlparser produced zero statements and pop().unwrap() exploded.
543        let mut db = Database::new("t".to_string());
544        for input in ["", "   ", "-- just a comment", "-- comment\n-- another"] {
545            let result = process_command(input, &mut db);
546            assert!(result.is_ok(), "input {input:?} should not error");
547            let msg = result.unwrap();
548            assert!(msg.contains("No statement"), "got: {msg:?}");
549        }
550    }
551
552    #[test]
553    fn create_index_adds_explicit_index() {
554        let mut db = seed_users_table();
555        let response = process_command("CREATE INDEX users_age_idx ON users (age);", &mut db)
556            .expect("create index");
557        assert!(response.contains("users_age_idx"));
558
559        // The index should now be attached to the users table.
560        let users = db.get_table("users".to_string()).unwrap();
561        let idx = users
562            .index_by_name("users_age_idx")
563            .expect("index should exist after CREATE INDEX");
564        assert_eq!(idx.column_name, "age");
565        assert!(!idx.is_unique);
566    }
567
568    #[test]
569    fn create_unique_index_rejects_duplicate_existing_values() {
570        let mut db = seed_users_table();
571        // `name` is already UNIQUE (auto-indexed); insert a duplicate-age row
572        // first so CREATE UNIQUE INDEX on age catches the conflict.
573        process_command("INSERT INTO users (name, age) VALUES ('dan', 30);", &mut db).unwrap();
574        let result = process_command(
575            "CREATE UNIQUE INDEX users_age_unique ON users (age);",
576            &mut db,
577        );
578        assert!(
579            result.is_err(),
580            "expected unique-index failure, got {result:?}"
581        );
582    }
583
584    #[test]
585    fn where_eq_on_indexed_column_uses_index_probe() {
586        // Build a table big enough that a full scan would be expensive,
587        // then rely on the index-probe fast path. This test verifies
588        // correctness (right rows returned); the perf win is implicit.
589        let mut db = Database::new("t".to_string());
590        process_command(
591            "CREATE TABLE big (id INTEGER PRIMARY KEY, tag TEXT);",
592            &mut db,
593        )
594        .unwrap();
595        process_command("CREATE INDEX big_tag_idx ON big (tag);", &mut db).unwrap();
596        for i in 1..=100 {
597            let tag = if i % 3 == 0 { "hot" } else { "cold" };
598            process_command(&format!("INSERT INTO big (tag) VALUES ('{tag}');"), &mut db).unwrap();
599        }
600        let response =
601            process_command("SELECT id FROM big WHERE tag = 'hot';", &mut db).expect("select");
602        // 1..=100 has 33 multiples of 3.
603        assert!(
604            response.contains("33 rows returned"),
605            "response was {response:?}"
606        );
607    }
608
609    #[test]
610    fn where_eq_on_indexed_column_inside_parens_uses_index_probe() {
611        let mut db = seed_users_table();
612        let response = process_command("SELECT name FROM users WHERE (name = 'bob');", &mut db)
613            .expect("select");
614        assert!(response.contains("1 row returned"));
615    }
616
617    #[test]
618    fn where_eq_literal_first_side_uses_index_probe() {
619        let mut db = seed_users_table();
620        // `'bob' = name` should hit the same path as `name = 'bob'`.
621        let response =
622            process_command("SELECT name FROM users WHERE 'bob' = name;", &mut db).expect("select");
623        assert!(response.contains("1 row returned"));
624    }
625
626    #[test]
627    fn non_equality_where_still_falls_back_to_full_scan() {
628        // Sanity: range predicates bypass the optimizer and the full-scan
629        // path still returns correct results.
630        let mut db = seed_users_table();
631        let response =
632            process_command("SELECT name FROM users WHERE age > 28;", &mut db).expect("select");
633        assert!(response.contains("2 rows returned"));
634    }
635
636    // -------------------------------------------------------------------
637    // Phase 4f — Transactions (BEGIN / COMMIT / ROLLBACK)
638    // -------------------------------------------------------------------
639
640    #[test]
641    fn rollback_restores_pre_begin_in_memory_state() {
642        // In-memory DB (no pager): BEGIN, insert a row, ROLLBACK.
643        // The row must disappear from the live tables HashMap.
644        let mut db = seed_users_table();
645        let before = db.get_table("users".to_string()).unwrap().rowids().len();
646        assert_eq!(before, 3);
647
648        process_command("BEGIN;", &mut db).expect("BEGIN");
649        assert!(db.in_transaction());
650        process_command("INSERT INTO users (name, age) VALUES ('dan', 50);", &mut db)
651            .expect("INSERT inside txn");
652        // Mid-transaction read sees the new row.
653        let mid = db.get_table("users".to_string()).unwrap().rowids().len();
654        assert_eq!(mid, 4);
655
656        process_command("ROLLBACK;", &mut db).expect("ROLLBACK");
657        assert!(!db.in_transaction());
658        let after = db.get_table("users".to_string()).unwrap().rowids().len();
659        assert_eq!(
660            after, 3,
661            "ROLLBACK should have restored the pre-BEGIN state"
662        );
663    }
664
665    #[test]
666    fn commit_keeps_mutations_and_clears_txn_flag() {
667        let mut db = seed_users_table();
668        process_command("BEGIN;", &mut db).expect("BEGIN");
669        process_command("INSERT INTO users (name, age) VALUES ('dan', 50);", &mut db)
670            .expect("INSERT inside txn");
671        process_command("COMMIT;", &mut db).expect("COMMIT");
672        assert!(!db.in_transaction());
673        let after = db.get_table("users".to_string()).unwrap().rowids().len();
674        assert_eq!(after, 4);
675    }
676
677    #[test]
678    fn rollback_undoes_update_and_delete_side_by_side() {
679        use crate::sql::db::table::Value;
680
681        let mut db = seed_users_table();
682        process_command("BEGIN;", &mut db).unwrap();
683        process_command("UPDATE users SET age = 999;", &mut db).unwrap();
684        process_command("DELETE FROM users WHERE name = 'bob';", &mut db).unwrap();
685        // Mid-txn: one row gone, others have age=999.
686        let users = db.get_table("users".to_string()).unwrap();
687        assert_eq!(users.rowids().len(), 2);
688        for r in users.rowids() {
689            assert_eq!(users.get_value("age", r), Some(Value::Integer(999)));
690        }
691
692        process_command("ROLLBACK;", &mut db).unwrap();
693        let users = db.get_table("users".to_string()).unwrap();
694        assert_eq!(users.rowids().len(), 3);
695        // Original ages {30, 25, 40} — none should be 999.
696        for r in users.rowids() {
697            assert_ne!(users.get_value("age", r), Some(Value::Integer(999)));
698        }
699    }
700
701    #[test]
702    fn nested_begin_is_rejected() {
703        let mut db = seed_users_table();
704        process_command("BEGIN;", &mut db).unwrap();
705        let err = process_command("BEGIN;", &mut db).unwrap_err();
706        assert!(
707            format!("{err}").contains("already open"),
708            "nested BEGIN should error; got: {err}"
709        );
710        // Still in the original transaction; a ROLLBACK clears it.
711        assert!(db.in_transaction());
712        process_command("ROLLBACK;", &mut db).unwrap();
713    }
714
715    #[test]
716    fn orphan_commit_and_rollback_are_rejected() {
717        let mut db = seed_users_table();
718        let commit_err = process_command("COMMIT;", &mut db).unwrap_err();
719        assert!(format!("{commit_err}").contains("no transaction"));
720        let rollback_err = process_command("ROLLBACK;", &mut db).unwrap_err();
721        assert!(format!("{rollback_err}").contains("no transaction"));
722    }
723
724    #[test]
725    fn error_inside_transaction_keeps_txn_open() {
726        // A bad INSERT inside a txn doesn't commit or abort automatically —
727        // the user can still ROLLBACK. SQLite's implicit-rollback behavior
728        // isn't modeled here.
729        let mut db = seed_users_table();
730        process_command("BEGIN;", &mut db).unwrap();
731        let err = process_command("INSERT INTO nope (x) VALUES (1);", &mut db);
732        assert!(err.is_err());
733        assert!(db.in_transaction(), "txn should stay open after error");
734        process_command("ROLLBACK;", &mut db).unwrap();
735    }
736
737    /// Builds a file-backed Database at a unique temp path, with the
738    /// schema seeded and `source_path` set so subsequent process_command
739    /// calls auto-save. Returns (path, db). Drop the db before deleting
740    /// the files.
741    fn seed_file_backed(name: &str, schema: &str) -> (std::path::PathBuf, Database) {
742        use crate::sql::pager::{open_database, save_database};
743        let mut p = std::env::temp_dir();
744        let pid = std::process::id();
745        let nanos = std::time::SystemTime::now()
746            .duration_since(std::time::UNIX_EPOCH)
747            .map(|d| d.as_nanos())
748            .unwrap_or(0);
749        p.push(format!("sqlrite-txn-{name}-{pid}-{nanos}.sqlrite"));
750
751        // Seed the file, then reopen to get a source_path-attached db
752        // (save_database alone doesn't attach a fresh pager to a db
753        // whose source_path was None before the call).
754        {
755            let mut seed = Database::new("t".to_string());
756            process_command(schema, &mut seed).unwrap();
757            save_database(&mut seed, &p).unwrap();
758        }
759        let db = open_database(&p, "t".to_string()).unwrap();
760        (p, db)
761    }
762
763    fn cleanup_file(path: &std::path::Path) {
764        let _ = std::fs::remove_file(path);
765        let mut wal = path.as_os_str().to_owned();
766        wal.push("-wal");
767        let _ = std::fs::remove_file(std::path::PathBuf::from(wal));
768    }
769
770    #[test]
771    fn begin_commit_rollback_round_trip_through_disk() {
772        // File-backed DB: commit inside a transaction must actually
773        // persist. ROLLBACK inside a *later* transaction must not
774        // un-do the previously-committed changes.
775        use crate::sql::pager::open_database;
776
777        let (path, mut db) = seed_file_backed(
778            "roundtrip",
779            "CREATE TABLE notes (id INTEGER PRIMARY KEY, body TEXT);",
780        );
781
782        // Transaction 1: insert two rows, commit.
783        process_command("BEGIN;", &mut db).unwrap();
784        process_command("INSERT INTO notes (body) VALUES ('a');", &mut db).unwrap();
785        process_command("INSERT INTO notes (body) VALUES ('b');", &mut db).unwrap();
786        process_command("COMMIT;", &mut db).unwrap();
787
788        // Transaction 2: insert another, roll back.
789        process_command("BEGIN;", &mut db).unwrap();
790        process_command("INSERT INTO notes (body) VALUES ('c');", &mut db).unwrap();
791        process_command("ROLLBACK;", &mut db).unwrap();
792
793        drop(db); // release pager lock
794
795        let reopened = open_database(&path, "t".to_string()).unwrap();
796        let notes = reopened.get_table("notes".to_string()).unwrap();
797        assert_eq!(notes.rowids().len(), 2, "committed rows should survive");
798
799        drop(reopened);
800        cleanup_file(&path);
801    }
802
803    #[test]
804    fn write_inside_transaction_does_not_autosave() {
805        // File-backed DB: writes inside BEGIN/…/COMMIT must NOT hit
806        // the WAL until COMMIT. We prove it by checking the WAL file
807        // size before vs during the transaction.
808        let (path, mut db) =
809            seed_file_backed("noas", "CREATE TABLE t (id INTEGER PRIMARY KEY, x TEXT);");
810
811        let mut wal_path = path.as_os_str().to_owned();
812        wal_path.push("-wal");
813        let wal_path = std::path::PathBuf::from(wal_path);
814        let frames_before = std::fs::metadata(&wal_path).unwrap().len();
815
816        process_command("BEGIN;", &mut db).unwrap();
817        process_command("INSERT INTO t (x) VALUES ('a');", &mut db).unwrap();
818        process_command("INSERT INTO t (x) VALUES ('b');", &mut db).unwrap();
819
820        // Mid-transaction: WAL must be unchanged — no auto-save fired.
821        let frames_mid = std::fs::metadata(&wal_path).unwrap().len();
822        assert_eq!(
823            frames_before, frames_mid,
824            "WAL should not grow during an open transaction"
825        );
826
827        process_command("COMMIT;", &mut db).unwrap();
828
829        drop(db); // release pager lock
830        let fresh = crate::sql::pager::open_database(&path, "t".to_string()).unwrap();
831        assert_eq!(
832            fresh.get_table("t".to_string()).unwrap().rowids().len(),
833            2,
834            "COMMIT should have persisted both inserted rows"
835        );
836        drop(fresh);
837        cleanup_file(&path);
838    }
839
840    #[test]
841    fn rollback_undoes_create_table() {
842        // Schema DDL inside a txn: ROLLBACK must make the new table
843        // disappear. The txn snapshot captures db.tables as of BEGIN,
844        // and ROLLBACK reassigns tables from that snapshot, so a table
845        // created mid-transaction has no entry in the snapshot.
846        let mut db = seed_users_table();
847        assert_eq!(db.tables.len(), 1);
848
849        process_command("BEGIN;", &mut db).unwrap();
850        process_command(
851            "CREATE TABLE dropme (id INTEGER PRIMARY KEY, x TEXT);",
852            &mut db,
853        )
854        .unwrap();
855        process_command("INSERT INTO dropme (x) VALUES ('stuff');", &mut db).unwrap();
856        assert_eq!(db.tables.len(), 2);
857
858        process_command("ROLLBACK;", &mut db).unwrap();
859        assert_eq!(
860            db.tables.len(),
861            1,
862            "CREATE TABLE should have been rolled back"
863        );
864        assert!(db.get_table("dropme".to_string()).is_err());
865    }
866
867    #[test]
868    fn rollback_restores_secondary_index_state() {
869        // Phase 4f edge case: rolling back an INSERT on a UNIQUE-indexed
870        // column must also clean up the index, otherwise a re-insert of
871        // the same value would spuriously collide.
872        let mut db = Database::new("t".to_string());
873        process_command(
874            "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE);",
875            &mut db,
876        )
877        .unwrap();
878        process_command("INSERT INTO users (email) VALUES ('a@x');", &mut db).unwrap();
879
880        process_command("BEGIN;", &mut db).unwrap();
881        process_command("INSERT INTO users (email) VALUES ('b@x');", &mut db).unwrap();
882        // Inside the txn: the index now contains both 'a@x' and 'b@x'.
883        process_command("ROLLBACK;", &mut db).unwrap();
884
885        // Re-inserting 'b@x' after rollback must succeed — if the index
886        // wasn't properly restored, it would think 'b@x' is still a
887        // collision and fail with a UNIQUE violation.
888        let reinsert = process_command("INSERT INTO users (email) VALUES ('b@x');", &mut db);
889        assert!(
890            reinsert.is_ok(),
891            "re-insert after rollback should succeed, got {reinsert:?}"
892        );
893    }
894
895    #[test]
896    fn rollback_restores_last_rowid_counter() {
897        // Rowids allocated inside a rolled-back transaction should be
898        // reusable. The snapshot restores Table::last_rowid, so the
899        // next insert picks up where the pre-BEGIN state left off.
900        use crate::sql::db::table::Value;
901
902        let mut db = seed_users_table(); // 3 rows, last_rowid = 3
903        let pre = db.get_table("users".to_string()).unwrap().last_rowid;
904
905        process_command("BEGIN;", &mut db).unwrap();
906        process_command("INSERT INTO users (name, age) VALUES ('d', 50);", &mut db).unwrap(); // would be rowid 4
907        process_command("INSERT INTO users (name, age) VALUES ('e', 60);", &mut db).unwrap(); // would be rowid 5
908        process_command("ROLLBACK;", &mut db).unwrap();
909
910        let post = db.get_table("users".to_string()).unwrap().last_rowid;
911        assert_eq!(pre, post, "last_rowid must roll back with the snapshot");
912
913        // Confirm: the next insert reuses rowid pre+1.
914        process_command("INSERT INTO users (name, age) VALUES ('d', 50);", &mut db).unwrap();
915        let users = db.get_table("users".to_string()).unwrap();
916        let d_rowid = users
917            .rowids()
918            .into_iter()
919            .find(|r| users.get_value("name", *r) == Some(Value::Text("d".into())))
920            .expect("d row must exist");
921        assert_eq!(d_rowid, pre + 1);
922    }
923
924    #[test]
925    fn commit_on_in_memory_db_clears_txn_without_pager_call() {
926        // In-memory DB (no source_path): COMMIT must still work — just
927        // no disk flush. Covers the `if let Some(path) = …` branch
928        // where the guard falls through without calling save_database.
929        let mut db = seed_users_table(); // no source_path
930        assert!(db.source_path.is_none());
931
932        process_command("BEGIN;", &mut db).unwrap();
933        process_command("INSERT INTO users (name, age) VALUES ('z', 99);", &mut db).unwrap();
934        process_command("COMMIT;", &mut db).unwrap();
935
936        assert!(!db.in_transaction());
937        assert_eq!(db.get_table("users".to_string()).unwrap().rowids().len(), 4);
938    }
939
940    #[test]
941    fn failed_commit_auto_rolls_back_in_memory_state() {
942        // Data-safety regression: on COMMIT save failure we must auto-
943        // rollback the in-memory state. Otherwise, any subsequent
944        // non-transactional statement would auto-save the partial
945        // mid-transaction work, silently publishing uncommitted
946        // changes to disk.
947        //
948        // We simulate a save failure by making the WAL sidecar path
949        // unavailable mid-transaction: after BEGIN, we take an
950        // exclusive OS lock on the WAL via a second File handle,
951        // forcing the next save to fail when it tries to append.
952        //
953        // Simpler repro: point source_path at a directory (not a file).
954        // `OpenOptions::open` will fail with EISDIR on save.
955        use crate::sql::pager::save_database;
956
957        // Seed a file-backed db.
958        let (path, mut db) = seed_file_backed(
959            "failcommit",
960            "CREATE TABLE notes (id INTEGER PRIMARY KEY, body TEXT);",
961        );
962
963        // Prime one committed row so we have a baseline.
964        process_command("INSERT INTO notes (body) VALUES ('before');", &mut db).unwrap();
965
966        // Open a new txn and add a row.
967        process_command("BEGIN;", &mut db).unwrap();
968        process_command("INSERT INTO notes (body) VALUES ('inflight');", &mut db).unwrap();
969        assert_eq!(
970            db.get_table("notes".to_string()).unwrap().rowids().len(),
971            2,
972            "inflight row visible mid-txn"
973        );
974
975        // Swap source_path to a path that will fail on open. A
976        // directory is a reliable failure mode — Pager::open on a
977        // directory errors with an I/O error.
978        let orig_source = db.source_path.clone();
979        let orig_pager = db.pager.take();
980        db.source_path = Some(std::env::temp_dir());
981
982        let commit_result = process_command("COMMIT;", &mut db);
983        assert!(commit_result.is_err(), "commit must fail");
984        let err_str = format!("{}", commit_result.unwrap_err());
985        assert!(
986            err_str.contains("COMMIT failed") && err_str.contains("rolled back"),
987            "error must surface auto-rollback; got: {err_str}"
988        );
989
990        // Auto-rollback fired: the inflight row is gone, the txn flag
991        // is cleared, and a follow-up non-txn statement won't leak
992        // stale state.
993        assert!(
994            !db.in_transaction(),
995            "txn must be cleared after auto-rollback"
996        );
997        assert_eq!(
998            db.get_table("notes".to_string()).unwrap().rowids().len(),
999            1,
1000            "inflight row must be rolled back"
1001        );
1002
1003        // Restore the real source_path + pager and verify a clean
1004        // subsequent write goes through.
1005        db.source_path = orig_source;
1006        db.pager = orig_pager;
1007        process_command("INSERT INTO notes (body) VALUES ('after');", &mut db).unwrap();
1008        drop(db);
1009
1010        // Reopen and assert only 'before' + 'after' landed on disk.
1011        let reopened = crate::sql::pager::open_database(&path, "t".to_string()).unwrap();
1012        let notes = reopened.get_table("notes".to_string()).unwrap();
1013        assert_eq!(notes.rowids().len(), 2);
1014        // Ensure no leaked save_database partial happened.
1015        let _ = save_database; // silence unused-import lint if any
1016        drop(reopened);
1017        cleanup_file(&path);
1018    }
1019
1020    #[test]
1021    fn begin_on_read_only_is_rejected() {
1022        use crate::sql::pager::{open_database_read_only, save_database};
1023
1024        let path = {
1025            let mut p = std::env::temp_dir();
1026            let pid = std::process::id();
1027            let nanos = std::time::SystemTime::now()
1028                .duration_since(std::time::UNIX_EPOCH)
1029                .map(|d| d.as_nanos())
1030                .unwrap_or(0);
1031            p.push(format!("sqlrite-txn-ro-{pid}-{nanos}.sqlrite"));
1032            p
1033        };
1034        {
1035            let mut seed = Database::new("t".to_string());
1036            process_command("CREATE TABLE t (id INTEGER PRIMARY KEY);", &mut seed).unwrap();
1037            save_database(&mut seed, &path).unwrap();
1038        }
1039
1040        let mut ro = open_database_read_only(&path, "t".to_string()).unwrap();
1041        let err = process_command("BEGIN;", &mut ro).unwrap_err();
1042        assert!(
1043            format!("{err}").contains("read-only"),
1044            "BEGIN on RO db should surface read-only; got: {err}"
1045        );
1046        assert!(!ro.in_transaction());
1047
1048        let _ = std::fs::remove_file(&path);
1049        let mut wal = path.as_os_str().to_owned();
1050        wal.push("-wal");
1051        let _ = std::fs::remove_file(std::path::PathBuf::from(wal));
1052    }
1053
1054    #[test]
1055    fn read_only_database_rejects_mutations_before_touching_state() {
1056        // Phase 4e end-to-end: a `--readonly` caller that runs INSERT
1057        // must error *before* the row is added to the in-memory table.
1058        // Otherwise the user sees a rendered result table with the
1059        // phantom row, followed by the auto-save error — UX rot and a
1060        // state-drift risk.
1061        use crate::sql::pager::open_database_read_only;
1062
1063        let mut seed = Database::new("t".to_string());
1064        process_command(
1065            "CREATE TABLE notes (id INTEGER PRIMARY KEY, body TEXT);",
1066            &mut seed,
1067        )
1068        .unwrap();
1069        process_command("INSERT INTO notes (body) VALUES ('alpha');", &mut seed).unwrap();
1070
1071        let path = {
1072            let mut p = std::env::temp_dir();
1073            let pid = std::process::id();
1074            let nanos = std::time::SystemTime::now()
1075                .duration_since(std::time::UNIX_EPOCH)
1076                .map(|d| d.as_nanos())
1077                .unwrap_or(0);
1078            p.push(format!("sqlrite-ro-reject-{pid}-{nanos}.sqlrite"));
1079            p
1080        };
1081        crate::sql::pager::save_database(&mut seed, &path).unwrap();
1082        drop(seed);
1083
1084        let mut ro = open_database_read_only(&path, "t".to_string()).unwrap();
1085        let notes_before = ro.get_table("notes".to_string()).unwrap().rowids().len();
1086
1087        for stmt in [
1088            "INSERT INTO notes (body) VALUES ('beta');",
1089            "UPDATE notes SET body = 'x';",
1090            "DELETE FROM notes;",
1091            "CREATE TABLE more (id INTEGER PRIMARY KEY);",
1092            "CREATE INDEX notes_body ON notes (body);",
1093        ] {
1094            let err = process_command(stmt, &mut ro).unwrap_err();
1095            assert!(
1096                format!("{err}").contains("read-only"),
1097                "stmt {stmt:?} should surface a read-only error; got: {err}"
1098            );
1099        }
1100
1101        // Nothing mutated: same row count as before, and SELECTs still work.
1102        let notes_after = ro.get_table("notes".to_string()).unwrap().rowids().len();
1103        assert_eq!(notes_before, notes_after);
1104        let sel = process_command("SELECT * FROM notes;", &mut ro).expect("select on RO must work");
1105        assert!(sel.contains("1 row returned"));
1106
1107        // Cleanup.
1108        drop(ro);
1109        let _ = std::fs::remove_file(&path);
1110        let mut wal = path.as_os_str().to_owned();
1111        wal.push("-wal");
1112        let _ = std::fs::remove_file(std::path::PathBuf::from(wal));
1113    }
1114}