1use crate::Result;
4use crate::database::SqliteDatabase;
5use crate::error::Error;
6use crate::write_guard::WriteGuard;
7use sqlx::Sqlite;
8use sqlx::pool::PoolConnection;
9use sqlx::sqlite::SqliteConnection;
10use std::ops::{Deref, DerefMut};
11use std::sync::Arc;
12
13#[derive(Clone)]
15pub struct AttachedSpec {
16 pub database: Arc<SqliteDatabase>,
18 pub schema_name: String,
20 pub mode: AttachedMode,
22}
23
/// Access mode requested for an attached database.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttachedMode {
    /// The attached database is only read from; no writer guard is acquired
    /// for it.
    ///
    /// NOTE(review): the `ATTACH` statement issued in this file is the same
    /// for both modes, so this mode only affects locking here — confirm
    /// whether read-only access is enforced elsewhere.
    ReadOnly,
    /// The attached database will be written to.
    ///
    /// `acquire_writer_with_attached` acquires the attached database's writer
    /// for this mode; `acquire_reader_with_attached` rejects it with
    /// `Error::CannotAttachReadWriteToReader`.
    ReadWrite,
}
32
33#[must_use = "if unused, the attached connection and locks are immediately dropped"]
39#[derive(Debug)]
40pub struct AttachedReadConnection {
41 conn: PoolConnection<Sqlite>,
42 #[allow(dead_code)]
47 held_writers: Vec<WriteGuard>,
48 #[allow(dead_code)]
50 schema_names: Vec<String>,
51}
52
53impl AttachedReadConnection {
54 pub(crate) fn new(
55 conn: PoolConnection<Sqlite>,
56 held_writers: Vec<WriteGuard>,
57 schema_names: Vec<String>,
58 ) -> Self {
59 Self {
60 conn,
61 held_writers,
62 schema_names,
63 }
64 }
65
66 pub async fn detach_all(mut self) -> Result<()> {
72 for schema_name in &self.schema_names {
73 let detach_sql = format!("DETACH DATABASE {}", schema_name);
74 sqlx::query(&detach_sql).execute(&mut *self.conn).await?;
75 }
76 Ok(())
77 }
78}
79
80impl Deref for AttachedReadConnection {
81 type Target = SqliteConnection;
82
83 fn deref(&self) -> &Self::Target {
84 &self.conn
85 }
86}
87
88impl DerefMut for AttachedReadConnection {
89 fn deref_mut(&mut self) -> &mut Self::Target {
90 &mut self.conn
91 }
92}
93
94impl Drop for AttachedReadConnection {
95 fn drop(&mut self) {
96 }
101}
102
103#[must_use = "if unused, the write guard and locks are immediately dropped"]
109#[derive(Debug)]
110pub struct AttachedWriteGuard {
111 writer: WriteGuard,
112 #[allow(dead_code)]
117 held_writers: Vec<WriteGuard>,
118 #[allow(dead_code)]
120 schema_names: Vec<String>,
121}
122
123impl AttachedWriteGuard {
124 pub(crate) fn new(
125 writer: WriteGuard,
126 held_writers: Vec<WriteGuard>,
127 schema_names: Vec<String>,
128 ) -> Self {
129 Self {
130 writer,
131 held_writers,
132 schema_names,
133 }
134 }
135
136 pub async fn detach_all(mut self) -> Result<()> {
142 for schema_name in &self.schema_names {
143 let detach_sql = format!("DETACH DATABASE {}", schema_name);
144 sqlx::query(&detach_sql).execute(&mut *self.writer).await?;
145 }
146 Ok(())
147 }
148}
149
150impl Deref for AttachedWriteGuard {
151 type Target = SqliteConnection;
152
153 fn deref(&self) -> &Self::Target {
154 &self.writer
155 }
156}
157
158impl DerefMut for AttachedWriteGuard {
159 fn deref_mut(&mut self) -> &mut Self::Target {
160 &mut self.writer
161 }
162}
163
164impl Drop for AttachedWriteGuard {
165 fn drop(&mut self) {
166 }
171}
172
/// Returns `true` when `name` is safe to splice into `ATTACH`/`DETACH` SQL.
///
/// A valid name is non-empty, starts with an ASCII letter or `_`, and
/// otherwise contains only ASCII letters, ASCII digits, and `_`.
fn is_valid_schema_name(name: &str) -> bool {
    let mut rest = name.chars();
    match rest.next() {
        // The first character may be anything allowed except a digit.
        Some(first) if first == '_' || first.is_ascii_alphabetic() => {
            rest.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
        }
        // Empty string, leading digit, or a disallowed leading character.
        _ => false,
    }
}
191
192pub async fn acquire_reader_with_attached(
214 main_db: &SqliteDatabase,
215 mut specs: Vec<AttachedSpec>,
216) -> Result<AttachedReadConnection> {
217 let mut conn = main_db.read_pool()?.acquire().await?;
219
220 specs.sort_by(|a, b| a.database.path_str().cmp(&b.database.path_str()));
225
226 use std::collections::HashSet;
230 let mut seen_paths = HashSet::new();
231 for spec in &specs {
232 let path = spec.database.path_str();
233 if !seen_paths.insert(path.clone()) {
234 return Err(Error::DuplicateAttachedDatabase(path));
235 }
236 }
237
238 let mut schema_names = Vec::new();
239
240 for spec in specs {
241 if !is_valid_schema_name(&spec.schema_name) {
243 return Err(Error::InvalidSchemaName(spec.schema_name.clone()));
244 }
245
246 if spec.mode == AttachedMode::ReadWrite {
248 return Err(Error::CannotAttachReadWriteToReader);
249 }
250
251 let path = spec.database.path_str();
254 let escaped_path = path.replace("'", "''");
255 let attach_sql = format!("ATTACH DATABASE '{}' AS {}", escaped_path, spec.schema_name);
256 sqlx::query(&attach_sql).execute(&mut *conn).await?;
257
258 schema_names.push(spec.schema_name);
259 }
260
261 Ok(AttachedReadConnection::new(conn, Vec::new(), schema_names))
262}
263
264pub async fn acquire_writer_with_attached(
289 main_db: &SqliteDatabase,
290 specs: Vec<AttachedSpec>,
291) -> Result<AttachedWriteGuard> {
292 for spec in &specs {
294 if !is_valid_schema_name(&spec.schema_name) {
295 return Err(Error::InvalidSchemaName(spec.schema_name.clone()));
296 }
297 }
298
299 let main_path = main_db.path_str();
306
307 let mut db_entries: Vec<(String, &SqliteDatabase)> = vec![(main_path.clone(), main_db)];
309
310 for spec in &specs {
311 if spec.mode == AttachedMode::ReadWrite {
312 db_entries.push((spec.database.path_str(), &*spec.database));
313 }
314 }
315
316 use std::collections::HashSet;
320 let mut seen_paths = HashSet::new();
321 for (path, _) in &db_entries {
322 if !seen_paths.insert(path.as_str()) {
323 return Err(Error::DuplicateAttachedDatabase(path.clone()));
324 }
325 }
326
327 db_entries.sort_by(|a, b| a.0.cmp(&b.0));
329
330 let main_writer_idx = db_entries
332 .iter()
333 .position(|(path, _)| path == &main_path)
334 .expect("main database must be in the list");
335
336 let mut all_writers = Vec::new();
338 for (_, db) in &db_entries {
339 all_writers.push(db.acquire_writer().await?);
340 }
341
342 let mut writer = all_writers.remove(main_writer_idx);
344 let held_writers = all_writers;
345
346 let mut schema_names = Vec::new();
348
349 for spec in specs {
350 let path = spec.database.path_str();
351 let escaped_path = path.replace("'", "''");
352 let attach_sql = format!("ATTACH DATABASE '{}' AS {}", escaped_path, spec.schema_name);
353 sqlx::query(&attach_sql).execute(&mut *writer).await?;
354
355 schema_names.push(spec.schema_name);
356 }
357
358 Ok(AttachedWriteGuard::new(writer, held_writers, schema_names))
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SqliteDatabase;
    use sqlx::Row;
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Creates a database file `name` inside `temp_dir` holding one table
    /// (named after `name` with `.db` removed) with a single `'test_data'` row.
    async fn create_test_db(name: &str, temp_dir: &TempDir) -> Arc<SqliteDatabase> {
        let path = temp_dir.path().join(name);
        let db = SqliteDatabase::connect(&path, None).await.unwrap();
        let table = name.replace(".db", "");

        let mut writer = db.acquire_writer().await.unwrap();
        sqlx::query(&format!(
            "CREATE TABLE IF NOT EXISTS {} (id INTEGER PRIMARY KEY, value TEXT)",
            table
        ))
        .execute(&mut *writer)
        .await
        .unwrap();

        sqlx::query(&format!(
            "INSERT INTO {} (value) VALUES ('test_data')",
            table
        ))
        .execute(&mut *writer)
        .await
        .unwrap();

        db
    }

    /// Builds an [`AttachedSpec`] sharing ownership of `database`.
    fn attach_spec(
        database: &Arc<SqliteDatabase>,
        schema_name: &str,
        mode: AttachedMode,
    ) -> AttachedSpec {
        AttachedSpec {
            database: Arc::clone(database),
            schema_name: schema_name.to_string(),
            mode,
        }
    }

    #[tokio::test]
    async fn test_attach_readonly_to_reader() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let other_db = create_test_db("other.db", &temp_dir).await;

        let specs = vec![attach_spec(&other_db, "other", AttachedMode::ReadOnly)];

        let mut conn = acquire_reader_with_attached(&main_db, specs).await.unwrap();

        let row = sqlx::query("SELECT value FROM other.other LIMIT 1")
            .fetch_one(&mut *conn)
            .await
            .unwrap();

        let value: String = row.get(0);
        assert_eq!(value, "test_data");
    }

    #[tokio::test]
    async fn test_attach_readonly_to_writer() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let other_db = create_test_db("other.db", &temp_dir).await;

        let specs = vec![attach_spec(&other_db, "other", AttachedMode::ReadOnly)];

        let mut conn = acquire_writer_with_attached(&main_db, specs).await.unwrap();

        let row = sqlx::query("SELECT value FROM other.other LIMIT 1")
            .fetch_one(&mut *conn)
            .await
            .unwrap();

        let value: String = row.get(0);
        assert_eq!(value, "test_data");
    }

    #[tokio::test]
    async fn test_attach_readwrite_to_writer() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let other_db = create_test_db("other.db", &temp_dir).await;

        let specs = vec![attach_spec(&other_db, "other", AttachedMode::ReadWrite)];

        let mut conn = acquire_writer_with_attached(&main_db, specs).await.unwrap();

        sqlx::query("INSERT INTO other.other (value) VALUES ('new_data')")
            .execute(&mut *conn)
            .await
            .unwrap();

        let row = sqlx::query("SELECT value FROM other.other WHERE value = 'new_data'")
            .fetch_one(&mut *conn)
            .await
            .unwrap();

        let value: String = row.get(0);
        assert_eq!(value, "new_data");
    }

    #[tokio::test]
    async fn test_attach_readwrite_to_reader_fails() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let other_db = create_test_db("other.db", &temp_dir).await;

        let specs = vec![attach_spec(&other_db, "other", AttachedMode::ReadWrite)];

        let result = acquire_reader_with_attached(&main_db, specs).await;
        assert!(matches!(
            result,
            Err(Error::CannotAttachReadWriteToReader)
        ));
    }

    #[tokio::test]
    async fn test_attach_multiple_databases() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let db1 = create_test_db("db1.db", &temp_dir).await;
        let db2 = create_test_db("db2.db", &temp_dir).await;

        let specs = vec![
            attach_spec(&db1, "db1", AttachedMode::ReadOnly),
            attach_spec(&db2, "db2", AttachedMode::ReadOnly),
        ];

        let mut conn = acquire_reader_with_attached(&main_db, specs).await.unwrap();

        let row1 = sqlx::query("SELECT value FROM db1.db1 LIMIT 1")
            .fetch_one(&mut *conn)
            .await
            .unwrap();

        let value1: String = row1.get(0);
        assert_eq!(value1, "test_data");

        let row2 = sqlx::query("SELECT value FROM db2.db2 LIMIT 1")
            .fetch_one(&mut *conn)
            .await
            .unwrap();

        let value2: String = row2.get(0);
        assert_eq!(value2, "test_data");
    }

    #[tokio::test]
    async fn test_attached_database_in_readwrite_mode_holds_writer_lock() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let other_db = create_test_db("other.db", &temp_dir).await;

        let specs = vec![attach_spec(&other_db, "other", AttachedMode::ReadWrite)];

        // Keep the guard alive so the attached database's writer stays held.
        let _guard = acquire_writer_with_attached(&main_db, specs).await.unwrap();

        let acquire_result = tokio::time::timeout(
            std::time::Duration::from_millis(100),
            other_db.acquire_writer(),
        )
        .await;

        assert!(
            acquire_result.is_err(),
            "Expected timeout acquiring writer that's already held"
        );
    }

    #[tokio::test]
    async fn test_locks_released_on_drop() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let other_db = create_test_db("other.db", &temp_dir).await;

        let specs = vec![attach_spec(&other_db, "other", AttachedMode::ReadWrite)];

        {
            // `let _ =` drops the guard immediately, releasing every writer.
            let _ = acquire_writer_with_attached(&main_db, specs).await.unwrap();
        }

        let writer = other_db.acquire_writer().await;
        assert!(
            writer.is_ok(),
            "Writer should be available after attached connection dropped"
        );
    }

    #[tokio::test]
    async fn test_cross_database_join_query() {
        let temp_dir = TempDir::new().unwrap();

        let main_db = SqliteDatabase::connect(temp_dir.path().join("main.db"), None)
            .await
            .unwrap();

        let mut writer = main_db.acquire_writer().await.unwrap();
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&mut *writer)
            .await
            .unwrap();

        sqlx::query("INSERT INTO users (id, name) VALUES (1, 'Alice')")
            .execute(&mut *writer)
            .await
            .unwrap();

        drop(writer);

        let orders_db = SqliteDatabase::connect(temp_dir.path().join("orders.db"), None)
            .await
            .unwrap();

        let mut writer = orders_db.acquire_writer().await.unwrap();
        sqlx::query("CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER, total REAL)")
            .execute(&mut *writer)
            .await
            .unwrap();

        sqlx::query("INSERT INTO orders (id, user_id, total) VALUES (100, 1, 99.99)")
            .execute(&mut *writer)
            .await
            .unwrap();

        drop(writer);

        let specs = vec![attach_spec(&orders_db, "orders", AttachedMode::ReadOnly)];

        let mut conn = acquire_reader_with_attached(&main_db, specs).await.unwrap();

        let row = sqlx::query(
            "SELECT u.name, o.total FROM main.users u JOIN orders.orders o ON u.id = o.user_id",
        )
        .fetch_one(&mut *conn)
        .await
        .unwrap();

        let name: String = row.get(0);
        let total: f64 = row.get(1);
        assert_eq!(name, "Alice");
        assert_eq!(total, 99.99);
    }

    #[tokio::test]
    async fn test_sorting_attached_databases_prevents_deadlock() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let db_a = create_test_db("a.db", &temp_dir).await;
        let db_z = create_test_db("z.db", &temp_dir).await;

        // Deliberately listed in reverse path order.
        let specs = vec![
            attach_spec(&db_z, "z", AttachedMode::ReadWrite),
            attach_spec(&db_a, "a", AttachedMode::ReadWrite),
        ];

        let result = acquire_writer_with_attached(&main_db, specs).await;
        assert!(
            result.is_ok(),
            "Attachment should succeed with sorted acquisition order"
        );
    }

    #[tokio::test]
    async fn test_attaching_same_databases_in_different_order_concurrently_no_deadlock() {
        let temp_dir = TempDir::new().unwrap();
        let db_a = create_test_db("a.db", &temp_dir).await;
        let db_b = create_test_db("b.db", &temp_dir).await;

        let db_a_clone = db_a.clone();
        let db_b_clone = db_b.clone();

        // Task 1: main = a, attached = b.
        let task1 = tokio::spawn(async move {
            let specs = vec![attach_spec(&db_b_clone, "b_schema", AttachedMode::ReadWrite)];
            let guard = acquire_writer_with_attached(&db_a_clone, specs).await?;
            drop(guard);
            Ok::<_, crate::Error>(())
        });

        // Task 2: main = b, attached = a (the opposite order).
        let task2 = tokio::spawn(async move {
            let specs = vec![attach_spec(&db_a, "a_schema", AttachedMode::ReadWrite)];
            let guard = acquire_writer_with_attached(&db_b, specs).await?;
            drop(guard);
            Ok::<_, crate::Error>(())
        });

        let timeout_duration = std::time::Duration::from_secs(5);
        let result =
            tokio::time::timeout(timeout_duration, async { tokio::try_join!(task1, task2) }).await;

        assert!(
            result.is_ok(),
            "Should complete without deadlock within {} seconds",
            timeout_duration.as_secs()
        );

        let (res1, res2) = result.unwrap().unwrap();
        assert!(res1.is_ok() && res2.is_ok(), "Both tasks should succeed");
    }

    #[tokio::test]
    async fn test_invalid_schema_names_rejected() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let other_db = create_test_db("other.db", &temp_dir).await;

        let invalid_names = vec![
            "",
            "123invalid",
            "schema-name",
            "schema name",
            "schema;DROP TABLE users",
            "schema'--",
            "schema/*comment*/",
        ];

        for invalid_name in invalid_names {
            let specs = vec![attach_spec(&other_db, invalid_name, AttachedMode::ReadOnly)];

            let result = acquire_reader_with_attached(&main_db, specs).await;
            assert!(
                matches!(result, Err(Error::InvalidSchemaName(_))),
                "Expected InvalidSchemaName error for '{}'",
                invalid_name
            );
        }
    }

    #[tokio::test]
    async fn test_duplicate_attached_database_rejected() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;
        let other_db = create_test_db("other.db", &temp_dir).await;

        // Same database twice under different schema names.
        let specs = vec![
            attach_spec(&other_db, "other1", AttachedMode::ReadWrite),
            attach_spec(&other_db, "other2", AttachedMode::ReadWrite),
        ];

        let result = acquire_writer_with_attached(&main_db, specs).await;
        assert!(
            matches!(result, Err(Error::DuplicateAttachedDatabase(_))),
            "Should reject duplicate attached database"
        );
    }

    #[tokio::test]
    async fn test_main_db_in_attached_list_rejected() {
        let temp_dir = TempDir::new().unwrap();
        let main_db = create_test_db("main.db", &temp_dir).await;

        let specs = vec![attach_spec(&main_db, "main_copy", AttachedMode::ReadWrite)];

        let result = acquire_writer_with_attached(&main_db, specs).await;
        assert!(
            matches!(result, Err(Error::DuplicateAttachedDatabase(_))),
            "Should reject attaching main database to itself"
        );
    }

    #[tokio::test]
    async fn test_path_with_single_quotes() {
        let temp_dir = TempDir::new().unwrap();

        // A directory name containing a single quote exercises path escaping.
        let quoted_dir = temp_dir.path().join("user's_data");
        std::fs::create_dir(&quoted_dir).unwrap();

        let main_db = SqliteDatabase::connect(temp_dir.path().join("main.db"), None)
            .await
            .unwrap();

        let other_path = quoted_dir.join("other.db");
        let other_db = SqliteDatabase::connect(&other_path, None).await.unwrap();

        let specs = vec![attach_spec(&other_db, "other", AttachedMode::ReadOnly)];

        let result = acquire_reader_with_attached(&main_db, specs).await;
        assert!(
            result.is_ok(),
            "Should attach database with single quote in path"
        );
    }
}