// rustrails_support/database.rs

1use std::cell::RefCell;
2
3use sea_orm::{Database, DatabaseConnection};
4
5use crate::runtime;
6
/// Panic message used when the thread-local connection is read before
/// `establish` has stored one on the current thread.
const DATABASE_NOT_ESTABLISHED: &str = concat!(
    "rustrails_support::database::establish() must be called on this thread ",
    "before accessing the database connection"
);
8
9thread_local! {
10    static DB_CONNECTION: RefCell<Option<DatabaseConnection>> = const { RefCell::new(None) };
11}
12
13/// Errors returned while establishing the thread-local database connection.
14#[derive(Debug, thiserror::Error)]
15pub enum DatabaseError {
16    /// The database connection could not be established.
17    #[error("database connection failed: {0}")]
18    ConnectionFailed(#[from] sea_orm::DbErr),
19}
20
21/// Establishes and stores a thread-local SeaORM database connection.
22///
23/// Any previously stored connection on this thread is replaced.
24pub fn establish(url: &str) -> Result<(), DatabaseError> {
25    let connection = runtime::block_on(Database::connect(url))?;
26    DB_CONNECTION.with(|cell| {
27        *cell.borrow_mut() = Some(connection);
28    });
29    Ok(())
30}
31
32/// Returns a clone of the thread-local database connection.
33///
34/// Panics when no connection has been established on the current thread.
35pub fn db() -> DatabaseConnection {
36    DB_CONNECTION.with(|cell| {
37        cell.borrow()
38            .as_ref()
39            .cloned()
40            .unwrap_or_else(|| panic!("{DATABASE_NOT_ESTABLISHED}"))
41    })
42}
43
44/// Borrows the thread-local database connection for the duration of a closure.
45///
46/// Panics when no connection has been established on the current thread.
47pub fn with_db<F, R>(f: F) -> R
48where
49    F: FnOnce(&DatabaseConnection) -> R,
50{
51    DB_CONNECTION.with(|cell| {
52        let borrow = cell.borrow();
53        let connection = borrow
54            .as_ref()
55            .unwrap_or_else(|| panic!("{DATABASE_NOT_ESTABLISHED}"));
56        f(connection)
57    })
58}
59
60/// Returns `true` when the current thread has an established database connection.
61pub fn is_established() -> bool {
62    DB_CONNECTION.with(|cell| cell.borrow().is_some())
63}
64
#[cfg(test)]
mod tests {
    use std::{any::Any, thread};

    use sea_orm::{
        ConnectionTrait, DatabaseBackend,
        sea_query::{Alias, ColumnDef, Expr, Query, Table},
    };

    use super::{db, establish, is_established, with_db};
    use crate::runtime;

    /// Runs `test` on a freshly spawned thread and returns its result.
    ///
    /// A new thread starts with an empty thread-local connection slot, so tests
    /// cannot see each other's connections. A panic on the worker thread is
    /// re-raised on the calling thread so the surrounding test still fails.
    fn run_isolated<R>(test: impl FnOnce() -> R + Send + 'static) -> R
    where
        R: Send + 'static,
    {
        match thread::spawn(test).join() {
            Ok(result) => result,
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Extracts the message text from a panic payload (`String` or `&str`).
    fn panic_message(payload: Box<dyn Any + Send>) -> String {
        if let Some(message) = payload.downcast_ref::<String>() {
            message.clone()
        } else if let Some(message) = payload.downcast_ref::<&str>() {
            (*message).to_owned()
        } else {
            "non-string panic payload".to_owned()
        }
    }

    /// Builds a `SELECT id FROM replacement_check LIMIT 1` probe. It can only
    /// succeed against a database that contains the `replacement_check` table.
    fn replacement_check_probe() -> sea_orm::sea_query::SelectStatement {
        Query::select()
            .expr(Expr::col(Alias::new("id")))
            .from(Alias::new("replacement_check"))
            .limit(1)
            .to_owned()
    }

    #[test]
    fn establish_connects_to_in_memory_sqlite() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            assert!(is_established());
        });
    }

    #[test]
    fn db_panics_before_establish() {
        let message = run_isolated(|| {
            let panic = std::panic::catch_unwind(db)
                .expect_err("db should panic before establish is called");
            panic_message(panic)
        });

        assert!(message.contains("database::establish() must be called on this thread"));
    }

    #[test]
    fn db_returns_a_usable_connection_after_establish() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            runtime::block_on(async {
                db().ping()
                    .await
                    .expect("stored connection should respond to ping");
            });
        });
    }

    #[test]
    fn with_db_passes_the_connection_into_the_closure() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            let backend = with_db(|connection| connection.get_database_backend());
            assert_eq!(backend, DatabaseBackend::Sqlite);
        });
    }

    #[test]
    fn establish_twice_replaces_the_connection() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("first sqlite in-memory connection should succeed");
            runtime::block_on(async {
                db().execute(
                    &Table::create()
                        .table(Alias::new("replacement_check"))
                        .col(
                            ColumnDef::new(Alias::new("id"))
                                .integer()
                                .not_null()
                                .primary_key(),
                        )
                        .to_owned(),
                )
                .await
                .expect("table creation should succeed");

                // Positive control: the probe must work against the first
                // database. Otherwise the failure asserted below could come
                // from a malformed probe rather than from the replacement.
                db().query_one(&replacement_check_probe())
                    .await
                    .expect("probe query should succeed before the connection is replaced");
            });

            establish("sqlite::memory:")
                .expect("second sqlite in-memory connection should succeed");
            let query_result =
                runtime::block_on(async { db().query_one(&replacement_check_probe()).await });

            assert!(
                query_result.is_err(),
                "the replacement connection should not see the table created on the first one"
            );
        });
    }

    #[test]
    fn is_established_reflects_connection_state() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            assert!(!is_established());
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            assert!(is_established());
        });
    }

    #[test]
    fn database_error_displays_the_underlying_failure() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            let error = establish("not-a-valid-database-url")
                .expect_err("invalid database URLs should fail");
            assert!(error.to_string().starts_with("database connection failed:"));
        });
    }

    #[test]
    fn with_db_can_return_a_computed_value() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            let is_sqlite =
                with_db(|connection| connection.get_database_backend() == DatabaseBackend::Sqlite);
            assert!(is_sqlite);
        });
    }
}