1use std::sync::Arc;
52
53use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
54use vibesql_ast::Statement;
55use vibesql_storage::Database;
56
57use crate::{
58 errors::ExecutorError,
59 select::{SelectExecutor, SelectResult},
60};
61
/// Errors produced by the read-only query path ([`ReadOnlyQuery::query`]).
#[derive(Debug)]
pub enum ReadOnlyError {
    /// The SQL parsed, but it is not a `SELECT`. `statement_type` is a label
    /// for the rejected statement kind (e.g. `"INSERT"`, `"DROP TABLE"`).
    NotReadOnly { statement_type: String },
    /// The SQL text could not be parsed; holds a textual rendering of the parser error.
    ParseError(String),
    /// The statement was accepted as a `SELECT` but failed while executing.
    ExecutionError(ExecutorError),
}
72
73impl std::fmt::Display for ReadOnlyError {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 ReadOnlyError::NotReadOnly { statement_type } => {
77 write!(
78 f,
79 "{} is not allowed in read-only mode. Only SELECT queries are permitted.",
80 statement_type
81 )
82 }
83 ReadOnlyError::ParseError(msg) => write!(f, "SQL parse error: {}", msg),
84 ReadOnlyError::ExecutionError(e) => write!(f, "Execution error: {:?}", e),
85 }
86 }
87}
88
/// Marks `ReadOnlyError` as a standard error type (e.g. usable as `Box<dyn Error>`).
/// Uses the default `source()`, which returns `None`, even for `ExecutionError`.
impl std::error::Error for ReadOnlyError {}
90
91impl From<ExecutorError> for ReadOnlyError {
92 fn from(e: ExecutorError) -> Self {
93 ReadOnlyError::ExecutionError(e)
94 }
95}
96
/// Executes SQL through `&self` while refusing anything that is not a `SELECT`.
pub trait ReadOnlyQuery {
    /// Parses `sql` and, if it is a `SELECT`, executes it and returns the rows
    /// together with their column names.
    ///
    /// # Errors
    ///
    /// The `Database` implementation in this module returns:
    /// - [`ReadOnlyError::ParseError`] when `sql` does not parse;
    /// - [`ReadOnlyError::NotReadOnly`] for any non-`SELECT` statement;
    /// - [`ReadOnlyError::ExecutionError`] when the `SELECT` fails to execute.
    fn query(&self, sql: &str) -> Result<SelectResult, ReadOnlyError>;
}
139
140impl ReadOnlyQuery for Database {
141 fn query(&self, sql: &str) -> Result<SelectResult, ReadOnlyError> {
142 let statement = vibesql_parser::Parser::parse_sql(sql)
144 .map_err(|e| ReadOnlyError::ParseError(format!("{:?}", e)))?;
145
146 match &statement {
148 Statement::Select(select_stmt) => {
149 let executor = SelectExecutor::new(self);
150 executor.execute_with_columns(select_stmt.as_ref()).map_err(ReadOnlyError::from)
151 }
152 Statement::Insert(_) => {
153 Err(ReadOnlyError::NotReadOnly { statement_type: "INSERT".to_string() })
154 }
155 Statement::Update(_) => {
156 Err(ReadOnlyError::NotReadOnly { statement_type: "UPDATE".to_string() })
157 }
158 Statement::Delete(_) => {
159 Err(ReadOnlyError::NotReadOnly { statement_type: "DELETE".to_string() })
160 }
161 Statement::CreateTable(_) => {
162 Err(ReadOnlyError::NotReadOnly { statement_type: "CREATE TABLE".to_string() })
163 }
164 Statement::DropTable(_) => {
165 Err(ReadOnlyError::NotReadOnly { statement_type: "DROP TABLE".to_string() })
166 }
167 Statement::CreateIndex(_) => {
168 Err(ReadOnlyError::NotReadOnly { statement_type: "CREATE INDEX".to_string() })
169 }
170 Statement::DropIndex(_) => {
171 Err(ReadOnlyError::NotReadOnly { statement_type: "DROP INDEX".to_string() })
172 }
173 Statement::CreateView(_) => {
174 Err(ReadOnlyError::NotReadOnly { statement_type: "CREATE VIEW".to_string() })
175 }
176 Statement::DropView(_) => {
177 Err(ReadOnlyError::NotReadOnly { statement_type: "DROP VIEW".to_string() })
178 }
179 Statement::AlterTable(_) => {
180 Err(ReadOnlyError::NotReadOnly { statement_type: "ALTER TABLE".to_string() })
181 }
182 Statement::TruncateTable(_) => {
183 Err(ReadOnlyError::NotReadOnly { statement_type: "TRUNCATE".to_string() })
184 }
185 Statement::BeginTransaction(_) => {
186 Err(ReadOnlyError::NotReadOnly { statement_type: "BEGIN TRANSACTION".to_string() })
187 }
188 Statement::Commit(_) => {
189 Err(ReadOnlyError::NotReadOnly { statement_type: "COMMIT".to_string() })
190 }
191 Statement::Rollback(_) => {
192 Err(ReadOnlyError::NotReadOnly { statement_type: "ROLLBACK".to_string() })
193 }
194 _ => {
195 Err(ReadOnlyError::NotReadOnly {
197 statement_type: format!("{:?}", std::mem::discriminant(&statement)),
198 })
199 }
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    use super::*;

    /// Builds a database with a case-insensitive `users(id, name)` table
    /// (primary key `id`) holding three rows: (1, Alice), (2, Bob), (3, Charlie).
    fn create_test_db() -> Database {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new(
                "name".to_string(),
                DataType::Varchar { max_length: Some(100) },
                true,
            ),
        ];
        let schema =
            TableSchema::with_primary_key("users".to_string(), columns, vec!["id".to_string()]);
        db.create_table(schema).unwrap();

        let row1 =
            Row::new(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("Alice"))]);
        let row2 =
            Row::new(vec![SqlValue::Integer(2), SqlValue::Varchar(arcstr::ArcStr::from("Bob"))]);
        let row3 = Row::new(vec![
            SqlValue::Integer(3),
            SqlValue::Varchar(arcstr::ArcStr::from("Charlie")),
        ]);

        db.insert_row("users", row1).unwrap();
        db.insert_row("users", row2).unwrap();
        db.insert_row("users", row3).unwrap();

        db
    }

    #[test]
    fn test_query_select_all() {
        let db = create_test_db();

        let result = db.query("SELECT * FROM users").unwrap();
        assert_eq!(result.rows.len(), 3);
        assert_eq!(result.columns.len(), 2);
    }

    #[test]
    fn test_query_select_with_where() {
        let db = create_test_db();

        let result = db.query("SELECT * FROM users WHERE id = 1").unwrap();
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0].values[0], SqlValue::Integer(1));
        assert_eq!(result.rows[0].values[1], SqlValue::Varchar(arcstr::ArcStr::from("Alice")));
    }

    #[test]
    fn test_query_select_specific_columns() {
        let db = create_test_db();

        let result = db.query("SELECT name FROM users WHERE id = 2").unwrap();
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.columns.len(), 1);
        // Identifiers are case-insensitive here, so compare the column name lower-cased.
        assert_eq!(result.columns[0].to_lowercase(), "name");
        assert_eq!(result.rows[0].values[0], SqlValue::Varchar(arcstr::ArcStr::from("Bob")));
    }

    #[test]
    fn test_query_select_count() {
        let db = create_test_db();

        let result = db.query("SELECT COUNT(*) FROM users").unwrap();
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0].values[0], SqlValue::Integer(3));
    }

    #[test]
    fn test_query_rejects_insert() {
        let db = create_test_db();

        let result = db.query("INSERT INTO users (id, name) VALUES (4, 'David')");
        assert!(matches!(
            result,
            Err(ReadOnlyError::NotReadOnly { statement_type }) if statement_type == "INSERT"
        ));
    }

    #[test]
    fn test_query_rejects_update() {
        let db = create_test_db();

        let result = db.query("UPDATE users SET name = 'Alicia' WHERE id = 1");
        assert!(matches!(
            result,
            Err(ReadOnlyError::NotReadOnly { statement_type }) if statement_type == "UPDATE"
        ));
    }

    #[test]
    fn test_query_rejects_delete() {
        let db = create_test_db();

        let result = db.query("DELETE FROM users WHERE id = 1");
        assert!(matches!(
            result,
            Err(ReadOnlyError::NotReadOnly { statement_type }) if statement_type == "DELETE"
        ));
    }

    #[test]
    fn test_query_rejects_create_table() {
        let db = create_test_db();

        let result = db.query("CREATE TABLE test (id INT)");
        assert!(matches!(
            result,
            Err(ReadOnlyError::NotReadOnly { statement_type }) if statement_type == "CREATE TABLE"
        ));
    }

    #[test]
    fn test_query_rejects_drop_table() {
        let db = create_test_db();

        let result = db.query("DROP TABLE users");
        assert!(matches!(
            result,
            Err(ReadOnlyError::NotReadOnly { statement_type }) if statement_type == "DROP TABLE"
        ));
    }

    #[test]
    fn test_query_rejects_truncate() {
        let db = create_test_db();

        let result = db.query("TRUNCATE TABLE users");
        assert!(matches!(
            result,
            Err(ReadOnlyError::NotReadOnly { statement_type }) if statement_type == "TRUNCATE"
        ));
    }

    #[test]
    fn test_not_read_only_display_message() {
        let db = create_test_db();

        let err = db
            .query("UPDATE users SET name = 'Alicia' WHERE id = 1")
            .err()
            .expect("UPDATE must be rejected");
        assert_eq!(
            err.to_string(),
            "UPDATE is not allowed in read-only mode. Only SELECT queries are permitted."
        );
    }

    #[test]
    fn test_query_parse_error() {
        let db = create_test_db();

        let result = db.query("SELEKT * FROM users");
        assert!(matches!(result, Err(ReadOnlyError::ParseError(_))));
    }

    #[test]
    fn test_parse_error_display_prefix() {
        let db = create_test_db();

        let err = db.query("SELEKT * FROM users").err().expect("malformed SQL must be rejected");
        assert!(err.to_string().starts_with("SQL parse error: "));
    }

    #[test]
    fn test_query_execution_error_table_not_found() {
        let db = create_test_db();

        let result = db.query("SELECT * FROM nonexistent");
        assert!(matches!(result, Err(ReadOnlyError::ExecutionError(_))));
    }

    #[test]
    fn test_query_with_order_by() {
        let db = create_test_db();

        let result = db.query("SELECT * FROM users ORDER BY id DESC").unwrap();
        assert_eq!(result.rows.len(), 3);
        assert_eq!(result.rows[0].values[0], SqlValue::Integer(3));
        assert_eq!(result.rows[2].values[0], SqlValue::Integer(1));
    }

    #[test]
    fn test_query_with_limit() {
        let db = create_test_db();

        let result = db.query("SELECT * FROM users LIMIT 2").unwrap();
        assert_eq!(result.rows.len(), 2);
    }

    #[test]
    fn test_query_with_aggregation() {
        let db = create_test_db();

        let result = db.query("SELECT COUNT(*), MAX(id), MIN(id) FROM users").unwrap();
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0].values[0], SqlValue::Integer(3));
        assert_eq!(result.rows[0].values[1], SqlValue::Integer(3));
        assert_eq!(result.rows[0].values[2], SqlValue::Integer(1));
    }

    #[test]
    fn test_query_immutability() {
        let db = create_test_db();

        let count_before = db.query("SELECT COUNT(*) FROM users").unwrap();
        assert_eq!(count_before.rows[0].values[0], SqlValue::Integer(3));

        // Rejected mutations must leave both the rows and the table intact.
        assert!(db.query("INSERT INTO users (id, name) VALUES (4, 'David')").is_err());
        assert!(db.query("DELETE FROM users WHERE id = 1").is_err());
        assert!(db.query("DROP TABLE users").is_err());

        let by_id = db.query("SELECT * FROM users WHERE id = 1").unwrap();
        let names = db.query("SELECT name FROM users").unwrap();
        let count_after = db.query("SELECT COUNT(*) FROM users").unwrap();

        assert_eq!(by_id.rows.len(), 1);
        assert_eq!(names.rows.len(), 3);
        assert_eq!(count_after.rows[0].values[0], SqlValue::Integer(3));
    }
}
417
/// A cloneable handle to a [`Database`] guarded by a `tokio::sync::RwLock`.
///
/// Cloning only clones the inner `Arc`, so all clones share the same database
/// and lock.
#[derive(Clone)]
pub struct SharedDatabase {
    // Shared ownership of the lock-protected database.
    inner: Arc<RwLock<Database>>,
}
460
461impl SharedDatabase {
462 pub fn new(db: Database) -> Self {
464 Self { inner: Arc::new(RwLock::new(db)) }
465 }
466
467 pub fn from_arc(inner: Arc<RwLock<Database>>) -> Self {
472 Self { inner }
473 }
474
475 pub fn into_inner(self) -> Arc<RwLock<Database>> {
480 self.inner
481 }
482
483 pub fn as_arc(&self) -> &Arc<RwLock<Database>> {
485 &self.inner
486 }
487
488 pub async fn read(&self) -> RwLockReadGuard<'_, Database> {
493 self.inner.read().await
494 }
495
496 pub async fn write(&self) -> RwLockWriteGuard<'_, Database> {
501 self.inner.write().await
502 }
503
504 pub async fn query(&self, sql: &str) -> Result<SelectResult, ReadOnlyError> {
524 let guard = self.read().await;
525 guard.query(sql)
526 }
527}
528
529impl Default for SharedDatabase {
530 fn default() -> Self {
531 Self::new(Database::new())
532 }
533}
534
535impl std::fmt::Debug for SharedDatabase {
536 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
537 f.debug_struct("SharedDatabase").field("inner", &"Arc<RwLock<Database>>").finish()
538 }
539}
540
#[cfg(test)]
mod shared_database_tests {
    use super::*;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_storage::Row;
    use vibesql_types::{DataType, SqlValue};

    /// Builds a `SharedDatabase` with a case-insensitive `users(id, name)` table
    /// holding (1, Alice) and (2, Bob).
    ///
    /// Setup is entirely synchronous, so this is a plain `fn`.
    fn create_shared_test_db() -> SharedDatabase {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new(
                "name".to_string(),
                DataType::Varchar { max_length: Some(100) },
                true,
            ),
        ];
        let schema =
            TableSchema::with_primary_key("users".to_string(), columns, vec!["id".to_string()]);
        db.create_table(schema).unwrap();

        let row1 =
            Row::new(vec![SqlValue::Integer(1), SqlValue::Varchar(arcstr::ArcStr::from("Alice"))]);
        let row2 =
            Row::new(vec![SqlValue::Integer(2), SqlValue::Varchar(arcstr::ArcStr::from("Bob"))]);

        db.insert_row("users", row1).unwrap();
        db.insert_row("users", row2).unwrap();

        SharedDatabase::new(db)
    }

    #[tokio::test]
    async fn test_shared_query() {
        let db = create_shared_test_db();

        let result = db.query("SELECT * FROM users").await.unwrap();
        assert_eq!(result.rows.len(), 2);
    }

    #[tokio::test]
    async fn test_shared_query_with_filter() {
        let db = create_shared_test_db();

        let result = db.query("SELECT * FROM users WHERE id = 1").await.unwrap();
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0].values[0], SqlValue::Integer(1));
    }

    #[tokio::test]
    async fn test_shared_query_rejects_mutations() {
        let db = create_shared_test_db();

        let result = db.query("INSERT INTO users VALUES (3, 'Charlie')").await;
        assert!(matches!(result, Err(ReadOnlyError::NotReadOnly { .. })));
    }

    #[tokio::test]
    async fn test_concurrent_reads() {
        let db = create_shared_test_db();

        // Each task gets its own clone of the handle; all share one database.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let db = db.clone();
                tokio::spawn(async move {
                    let result = db.query("SELECT COUNT(*) FROM users").await.unwrap();
                    result.rows[0].values[0].clone()
                })
            })
            .collect();

        for handle in handles {
            assert_eq!(handle.await.unwrap(), SqlValue::Integer(2));
        }
    }

    #[tokio::test]
    async fn test_query_while_read_guard_held() {
        let db = create_shared_test_db();

        // A second reader must not be blocked by an outstanding read guard.
        let guard = db.read().await;
        let result = db.query("SELECT * FROM users").await.unwrap();
        assert_eq!(result.rows.len(), 2);
        drop(guard);
    }

    #[tokio::test]
    async fn test_read_write_isolation() {
        let db = create_shared_test_db();

        let result_before = db.query("SELECT COUNT(*) FROM users").await.unwrap();
        assert_eq!(result_before.rows[0].values[0], SqlValue::Integer(2));

        {
            // Scope the write guard so it is released before the next query.
            let mut guard = db.write().await;
            let row = Row::new(vec![
                SqlValue::Integer(3),
                SqlValue::Varchar(arcstr::ArcStr::from("Charlie")),
            ]);
            guard.insert_row("users", row).unwrap();
        }

        let result_after = db.query("SELECT COUNT(*) FROM users").await.unwrap();
        assert_eq!(result_after.rows[0].values[0], SqlValue::Integer(3));
    }

    #[tokio::test]
    async fn test_default_is_empty() {
        let db = SharedDatabase::default();

        let result = db.query("SELECT * FROM users").await;
        assert!(matches!(result, Err(ReadOnlyError::ExecutionError(_))));
    }

    #[test]
    fn test_from_arc() {
        let inner = Arc::new(RwLock::new(Database::new()));
        let db = SharedDatabase::from_arc(inner.clone());

        assert!(Arc::ptr_eq(db.as_arc(), &inner));
    }

    #[test]
    fn test_into_inner_returns_same_arc() {
        let db = create_shared_test_db();
        let shared = db.as_arc().clone();

        let inner = db.into_inner();
        assert!(Arc::ptr_eq(&inner, &shared));
    }

    #[test]
    fn test_debug_does_not_expose_contents() {
        let db = create_shared_test_db();

        assert_eq!(format!("{:?}", db), "SharedDatabase { inner: \"Arc<RwLock<Database>>\" }");
    }
}