vibesql_storage/database/
session.rs

// ============================================================================
// Database Session Management (Security, Session Variables, SQL Mode,
// PRAGMA Settings, sqlite_stat1 Storage, Reserved Rowids)
// ============================================================================

5use super::core::Database;
6
7impl Database {
8    // ============================================================================
9    // Security and Role Management
10    // ============================================================================
11
12    /// Set the current session role for privilege checks
13    pub fn set_role(&mut self, role: Option<String>) {
14        self.lifecycle.set_role(role);
15    }
16
17    /// Get the current session role (defaults to "PUBLIC" if not set)
18    pub fn get_current_role(&self) -> String {
19        self.lifecycle.current_role().map(|s| s.to_string()).unwrap_or_else(|| "PUBLIC".to_string())
20    }
21
22    /// Check if security enforcement is enabled
23    pub fn is_security_enabled(&self) -> bool {
24        self.lifecycle.is_security_enabled()
25    }
26
27    /// Disable security checks (for testing)
28    pub fn disable_security(&mut self) {
29        self.lifecycle.disable_security();
30    }
31
32    /// Enable security checks
33    pub fn enable_security(&mut self) {
34        self.lifecycle.enable_security();
35    }
36
37    // ============================================================================
38    // Session Variables
39    // ============================================================================
40
41    /// Set a session variable (MySQL-style @variable)
42    pub fn set_session_variable(&mut self, name: &str, value: vibesql_types::SqlValue) {
43        self.metadata.set_session_variable(name, value);
44    }
45
46    /// Get a session variable value
47    pub fn get_session_variable(&self, name: &str) -> Option<&vibesql_types::SqlValue> {
48        self.metadata.get_session_variable(name)
49    }
50
51    /// Clear all session variables
52    pub fn clear_session_variables(&mut self) {
53        self.metadata.clear_session_variables();
54    }
55
56    // ============================================================================
57    // SQL Mode
58    // ============================================================================
59
60    /// Get the current SQL compatibility mode
61    pub fn sql_mode(&self) -> vibesql_types::SqlMode {
62        self.sql_mode.clone()
63    }
64
65    // ============================================================================
66    // PRAGMA Settings (SQLite compatibility)
67    // ============================================================================
68
69    /// Get the full_column_names PRAGMA setting
70    ///
71    /// When ON, column names in result sets use "table.column" format
72    pub fn full_column_names(&self) -> bool {
73        match self.get_session_variable("FULL_COLUMN_NAMES") {
74            Some(vibesql_types::SqlValue::Integer(n)) => *n != 0,
75            _ => false, // Default: OFF
76        }
77    }
78
79    /// Set the full_column_names PRAGMA setting
80    pub fn set_full_column_names(&mut self, value: bool) {
81        self.set_session_variable(
82            "FULL_COLUMN_NAMES",
83            vibesql_types::SqlValue::Integer(if value { 1 } else { 0 }),
84        );
85    }
86
87    /// Get the short_column_names PRAGMA setting
88    ///
89    /// When ON (default), column names use just the column name (e.g., "f1")
90    /// When OFF, column names may include expression text
91    pub fn short_column_names(&self) -> bool {
92        match self.get_session_variable("SHORT_COLUMN_NAMES") {
93            Some(vibesql_types::SqlValue::Integer(n)) => *n != 0,
94            _ => true, // Default: ON
95        }
96    }
97
98    /// Set the short_column_names PRAGMA setting
99    pub fn set_short_column_names(&mut self, value: bool) {
100        self.set_session_variable(
101            "SHORT_COLUMN_NAMES",
102            vibesql_types::SqlValue::Integer(if value { 1 } else { 0 }),
103        );
104    }
105
106    /// Get the case_sensitive_like PRAGMA setting
107    ///
108    /// When OFF (default), LIKE comparisons are case-insensitive for ASCII letters (A-Z = a-z).
109    /// When ON, LIKE comparisons are case-sensitive (strict byte-for-byte matching).
110    ///
111    /// This matches SQLite's default behavior where LIKE is case-insensitive for ASCII.
112    pub fn case_sensitive_like(&self) -> bool {
113        match self.get_session_variable("CASE_SENSITIVE_LIKE") {
114            Some(vibesql_types::SqlValue::Integer(n)) => *n != 0,
115            _ => false, // Default: OFF (case-insensitive LIKE)
116        }
117    }
118
119    /// Set the case_sensitive_like PRAGMA setting
120    pub fn set_case_sensitive_like(&mut self, value: bool) {
121        self.set_session_variable(
122            "CASE_SENSITIVE_LIKE",
123            vibesql_types::SqlValue::Integer(if value { 1 } else { 0 }),
124        );
125    }
126
127    /// Get the reverse_unordered_selects PRAGMA setting
128    ///
129    /// When ON, the order of output rows from SELECT statements that do not have
130    /// an ORDER BY clause is reversed. This is useful for testing to ensure that
131    /// applications do not depend on an implicit row ordering.
132    pub fn reverse_unordered_selects(&self) -> bool {
133        match self.get_session_variable("REVERSE_UNORDERED_SELECTS") {
134            Some(vibesql_types::SqlValue::Integer(n)) => *n != 0,
135            _ => false, // Default: OFF
136        }
137    }
138
139    /// Set the reverse_unordered_selects PRAGMA setting
140    pub fn set_reverse_unordered_selects(&mut self, value: bool) {
141        self.set_session_variable(
142            "REVERSE_UNORDERED_SELECTS",
143            vibesql_types::SqlValue::Integer(if value { 1 } else { 0 }),
144        );
145    }
146
147    // ============================================================================
148    // SQLite stat1 Storage (SQLite Compatibility)
149    // ============================================================================
150
151    /// Insert a sqlite_stat1 entry
152    ///
153    /// This allows manual insertion of statistics for query optimizer tuning,
154    /// matching SQLite's behavior where users can INSERT INTO sqlite_stat1.
155    pub fn insert_sqlite_stat1(
156        &mut self,
157        table_name: String,
158        index_name: Option<String>,
159        stat: String,
160    ) {
161        self.metadata.insert_sqlite_stat1(table_name, index_name, stat);
162    }
163
164    /// Get a sqlite_stat1 entry
165    pub fn get_sqlite_stat1(
166        &self,
167        table_name: &str,
168        index_name: Option<&str>,
169    ) -> Option<&String> {
170        self.metadata.get_sqlite_stat1(table_name, index_name)
171    }
172
173    /// Get all sqlite_stat1 entries
174    pub fn get_all_sqlite_stat1(
175        &self,
176    ) -> &std::collections::HashMap<(String, Option<String>), String> {
177        self.metadata.get_all_sqlite_stat1()
178    }
179
180    /// Delete a sqlite_stat1 entry
181    pub fn delete_sqlite_stat1(&mut self, table_name: &str, index_name: Option<&str>) {
182        self.metadata.delete_sqlite_stat1(table_name, index_name);
183    }
184
185    /// Clear all sqlite_stat1 entries
186    pub fn clear_sqlite_stat1(&mut self) {
187        self.metadata.clear_sqlite_stat1();
188    }
189
190    // ============================================================================
191    // Reserved Rowids (SQLite REPLACE semantics)
192    // ============================================================================
193
194    /// Reserve a rowid for a table during REPLACE operations
195    ///
196    /// During REPLACE INTO, SQLite allocates the rowid for the new row BEFORE
197    /// firing BEFORE DELETE triggers. Any INSERT within those triggers that
198    /// tries to allocate the same rowid will fail with a UNIQUE constraint
199    /// violation on rowid.
200    ///
201    /// # Arguments
202    /// * `table_name` - The table name (case-insensitive)
203    /// * `rowid` - The rowid to reserve
204    /// * `is_explicit` - True if the rowid comes from an explicit INTEGER PRIMARY KEY
205    ///   value, false if it's auto-allocated. This affects how conflicts are handled
206    ///   in AFTER DELETE triggers.
207    pub fn reserve_rowid(&mut self, table_name: &str, rowid: u64, is_explicit: bool) {
208        self.reserved_rowids.insert(table_name.to_lowercase(), (rowid, is_explicit));
209    }
210
211    /// Release a reserved rowid after REPLACE completes
212    pub fn release_reserved_rowid(&mut self, table_name: &str) {
213        self.reserved_rowids.remove(&table_name.to_lowercase());
214    }
215
216    /// Check if a rowid is reserved for a table and get the reservation details
217    ///
218    /// Returns Some((rowid, is_explicit)) if a rowid is reserved, None otherwise.
219    pub fn get_reserved_rowid_info(&self, table_name: &str) -> Option<(u64, bool)> {
220        self.reserved_rowids.get(&table_name.to_lowercase()).copied()
221    }
222
223    /// Check if a rowid is reserved for a table
224    pub fn is_rowid_reserved(&self, table_name: &str, rowid: u64) -> bool {
225        self.reserved_rowids
226            .get(&table_name.to_lowercase())
227            .map(|(r, _)| *r == rowid)
228            .unwrap_or(false)
229    }
230
231    /// Get the reserved rowid for a table, if any
232    pub fn get_reserved_rowid(&self, table_name: &str) -> Option<u64> {
233        self.reserved_rowids.get(&table_name.to_lowercase()).map(|(r, _)| *r)
234    }
235
236    // ============================================================================
237    // SQL Mode
238    // ============================================================================
239
240    /// Set the SQL compatibility mode at runtime
241    ///
242    /// This allows changing the SQL dialect (MySQL, SQLite, etc.) during a session.
243    /// The `@@sql_mode` session variable is automatically updated to reflect the change.
244    ///
245    /// # Example
246    /// ```rust
247    /// use vibesql_storage::Database;
248    /// use vibesql_types::{MySqlModeFlags, SqlMode};
249    ///
250    /// let mut db = Database::new();
251    /// // Default is MySQL (for SQLLogicTest compatibility)
252    /// assert!(matches!(db.sql_mode(), SqlMode::MySQL { .. }));
253    ///
254    /// db.set_sql_mode(SqlMode::SQLite);
255    /// assert!(matches!(db.sql_mode(), SqlMode::SQLite));
256    /// ```
257    pub fn set_sql_mode(&mut self, mode: vibesql_types::SqlMode) {
258        self.sql_mode = mode.clone();
259
260        // Update the @@sql_mode session variable to reflect the new mode
261        let mode_string = match &mode {
262            vibesql_types::SqlMode::MySQL { flags } => {
263                // Build MySQL mode string from flags
264                let mut modes = Vec::new();
265                if flags.strict_mode {
266                    modes.push("STRICT_TRANS_TABLES");
267                }
268                if flags.pipes_as_concat {
269                    modes.push("PIPES_AS_CONCAT");
270                }
271                if flags.ansi_quotes {
272                    modes.push("ANSI_QUOTES");
273                }
274                // Add common MySQL defaults if no specific flags are set
275                if modes.is_empty() {
276                    "NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION"
277                        .to_string()
278                } else {
279                    modes.join(",")
280                }
281            }
282            vibesql_types::SqlMode::SQLite => "SQLITE".to_string(),
283        };
284
285        self.metadata.set_session_variable(
286            "SQL_MODE",
287            vibesql_types::SqlValue::Varchar(arcstr::ArcStr::from(mode_string.as_str())),
288        );
289    }
290}
291
#[cfg(test)]
mod tests {
    use vibesql_types::{MySqlModeFlags, SqlMode, SqlValue};

    use super::*;

    /// Exact text the default MySQL flags render into `@@sql_mode`.
    const MYSQL_DEFAULT_MODE_TEXT: &str =
        "NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION";

    /// Read the `SQL_MODE` session variable as text.
    ///
    /// Panics if the variable is missing or is not a `Varchar`.
    fn sql_mode_text(db: &Database) -> String {
        match db.get_session_variable("SQL_MODE") {
            Some(SqlValue::Varchar(mode_str)) => mode_str.to_string(),
            _ => panic!("Expected SQL_MODE to be a Varchar"),
        }
    }

    #[test]
    fn test_set_sql_mode_changes_mode() {
        let mut db = Database::new();

        // Default is MySQL (for SQLLogicTest compatibility - dolthub corpus was regenerated against
        // MySQL 8.x)
        assert!(matches!(db.sql_mode(), SqlMode::MySQL { .. }));

        // Change to SQLite
        db.set_sql_mode(SqlMode::SQLite);
        assert!(matches!(db.sql_mode(), SqlMode::SQLite));

        // Change back to MySQL
        db.set_sql_mode(SqlMode::MySQL { flags: MySqlModeFlags::default() });
        assert!(matches!(db.sql_mode(), SqlMode::MySQL { .. }));
    }

    #[test]
    fn test_set_sql_mode_updates_session_variable() {
        let mut db = Database::new();

        // Each switch must rewrite the variable, not just the first one.
        db.set_sql_mode(SqlMode::SQLite);
        assert_eq!(sql_mode_text(&db), "SQLITE");

        db.set_sql_mode(SqlMode::MySQL { flags: MySqlModeFlags::default() });
        assert_eq!(sql_mode_text(&db), MYSQL_DEFAULT_MODE_TEXT);

        db.set_sql_mode(SqlMode::SQLite);
        assert_eq!(sql_mode_text(&db), "SQLITE");
    }

    #[test]
    fn test_set_sql_mode_mysql_with_flags() {
        let mut db = Database::new();

        db.set_sql_mode(SqlMode::MySQL {
            flags: MySqlModeFlags {
                pipes_as_concat: true,
                ansi_quotes: true,
                strict_mode: true,
                ..Default::default()
            },
        });

        // Enabled flags are listed in a fixed order, and the defaults are not appended.
        assert_eq!(sql_mode_text(&db), "STRICT_TRANS_TABLES,PIPES_AS_CONCAT,ANSI_QUOTES");
    }

    #[test]
    fn test_set_sql_mode_mysql_single_flag() {
        let mut db = Database::new();

        db.set_sql_mode(SqlMode::MySQL {
            flags: MySqlModeFlags { pipes_as_concat: true, ..Default::default() },
        });

        assert_eq!(sql_mode_text(&db), "PIPES_AS_CONCAT");
    }

    #[test]
    fn test_set_sql_mode_mysql_default_flags() {
        let mut db = Database::new();

        // Default flags (all false) fall back to the common MySQL defaults.
        db.set_sql_mode(SqlMode::MySQL { flags: MySqlModeFlags::default() });

        assert_eq!(sql_mode_text(&db), MYSQL_DEFAULT_MODE_TEXT);
    }

    #[test]
    fn test_sql_mode_getter_and_variable_agree_after_switch() {
        let mut db = Database::new();

        // Start in MySQL mode (default)
        assert!(matches!(db.sql_mode(), SqlMode::MySQL { .. }));

        db.set_sql_mode(SqlMode::SQLite);

        // Both the typed getter and the session variable reflect the new mode,
        // and repeated reads are stable.
        assert!(matches!(db.sql_mode(), SqlMode::SQLite));
        assert!(matches!(db.sql_mode(), SqlMode::SQLite));
        assert_eq!(sql_mode_text(&db), "SQLITE");
    }

    #[test]
    fn test_pragma_flags_round_trip() {
        let mut db = Database::new();

        db.set_full_column_names(true);
        assert!(db.full_column_names());
        db.set_full_column_names(false);
        assert!(!db.full_column_names());

        db.set_short_column_names(false);
        assert!(!db.short_column_names());
        db.set_short_column_names(true);
        assert!(db.short_column_names());

        db.set_case_sensitive_like(true);
        assert!(db.case_sensitive_like());
        db.set_case_sensitive_like(false);
        assert!(!db.case_sensitive_like());

        db.set_reverse_unordered_selects(true);
        assert!(db.reverse_unordered_selects());
        db.set_reverse_unordered_selects(false);
        assert!(!db.reverse_unordered_selects());
    }

    #[test]
    fn test_pragma_flags_fall_back_to_defaults_when_unset() {
        let mut db = Database::new();

        db.set_full_column_names(true);
        db.set_short_column_names(false);
        db.set_case_sensitive_like(true);
        db.set_reverse_unordered_selects(true);
        db.clear_session_variables();

        assert!(!db.full_column_names());
        assert!(db.short_column_names());
        assert!(!db.case_sensitive_like());
        assert!(!db.reverse_unordered_selects());
    }

    #[test]
    fn test_reserved_rowid_is_case_insensitive() {
        let mut db = Database::new();

        db.reserve_rowid("Items", 7, true);
        assert_eq!(db.get_reserved_rowid_info("ITEMS"), Some((7, true)));
        assert_eq!(db.get_reserved_rowid("items"), Some(7));
        assert!(db.is_rowid_reserved("iTeMs", 7));
        assert!(!db.is_rowid_reserved("items", 8));
        assert!(!db.is_rowid_reserved("other", 7));

        db.release_reserved_rowid("ITEMS");
        assert_eq!(db.get_reserved_rowid_info("items"), None);
        assert!(!db.is_rowid_reserved("items", 7));
    }
}