rustrails_support/
database.rs1use std::cell::RefCell;
2
3use sea_orm::{Database, DatabaseConnection};
4
5use crate::runtime;
6
/// Panic message reported when the thread-local connection is read before
/// [`establish`] has stored one on the current thread.
const DATABASE_NOT_ESTABLISHED: &str = concat!(
    "rustrails_support::database::establish() must be called on this thread ",
    "before accessing the database connection"
);
8
thread_local! {
    /// Per-thread slot holding the connection stored by [`establish`].
    ///
    /// Starts out as `None`, and is only filled once `establish` succeeds on the
    /// current thread. Every thread has its own independent slot, so a connection
    /// established on one thread is not visible from another.
    static DB_CONNECTION: RefCell<Option<DatabaseConnection>> = const { RefCell::new(None) };
}
12
13#[derive(Debug, thiserror::Error)]
15pub enum DatabaseError {
16 #[error("database connection failed: {0}")]
18 ConnectionFailed(#[from] sea_orm::DbErr),
19}
20
21pub fn establish(url: &str) -> Result<(), DatabaseError> {
25 let connection = runtime::block_on(Database::connect(url))?;
26 DB_CONNECTION.with(|cell| {
27 *cell.borrow_mut() = Some(connection);
28 });
29 Ok(())
30}
31
32pub fn db() -> DatabaseConnection {
36 DB_CONNECTION.with(|cell| {
37 cell.borrow()
38 .as_ref()
39 .cloned()
40 .unwrap_or_else(|| panic!("{DATABASE_NOT_ESTABLISHED}"))
41 })
42}
43
44pub fn with_db<F, R>(f: F) -> R
48where
49 F: FnOnce(&DatabaseConnection) -> R,
50{
51 DB_CONNECTION.with(|cell| {
52 let borrow = cell.borrow();
53 let connection = borrow
54 .as_ref()
55 .unwrap_or_else(|| panic!("{DATABASE_NOT_ESTABLISHED}"));
56 f(connection)
57 })
58}
59
60pub fn is_established() -> bool {
62 DB_CONNECTION.with(|cell| cell.borrow().is_some())
63}
64
#[cfg(test)]
mod tests {
    // The connection lives in a thread-local, so every test body runs through
    // `run_isolated`, which spawns a fresh thread. Each test therefore starts with an
    // empty connection slot, independent of what other tests have established.
    use std::{any::Any, thread};

    use sea_orm::{
        ConnectionTrait, DatabaseBackend,
        sea_query::{Alias, ColumnDef, Expr, Query, Table},
    };

    use super::{db, establish, is_established, with_db};
    use crate::runtime;

    /// Runs `test` on a newly spawned thread and returns its result.
    ///
    /// A panic inside `test` is re-raised on the calling thread with `resume_unwind`,
    /// so assertion failures inside the isolated thread still fail the `#[test]`.
    fn run_isolated<R>(test: impl FnOnce() -> R + Send + 'static) -> R
    where
        R: Send + 'static,
    {
        match thread::spawn(test).join() {
            Ok(result) => result,
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Extracts a readable message from a panic payload.
    ///
    /// Covers the `String` and `&str` payload types that `panic!` typically produces,
    /// and falls back to a placeholder for anything else.
    fn panic_message(payload: Box<dyn Any + Send>) -> String {
        if let Some(message) = payload.downcast_ref::<String>() {
            message.clone()
        } else if let Some(message) = payload.downcast_ref::<&str>() {
            (*message).to_owned()
        } else {
            "non-string panic payload".to_owned()
        }
    }

    // Tests that connect keep the value of `runtime::init_runtime()` alive in
    // `_runtime` for the whole body; presumably `establish` / `runtime::block_on`
    // depend on it (NOTE(review): confirm against the `runtime` module).

    /// `establish` succeeds for an in-memory SQLite URL and marks the thread as connected.
    #[test]
    fn establish_connects_to_in_memory_sqlite() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            assert!(is_established());
        });
    }

    /// `db` panics with the "must be called on this thread" message when no connection
    /// has been established on the (fresh) thread.
    #[test]
    fn db_panics_before_establish() {
        let message = run_isolated(|| {
            let panic = std::panic::catch_unwind(db)
                .expect_err("db should panic before establish is called");
            panic_message(panic)
        });

        assert!(message.contains("database::establish() must be called on this thread"));
    }

    /// The clone returned by `db` is a working connection (it answers `ping`).
    #[test]
    fn db_returns_a_usable_connection_after_establish() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            runtime::block_on(async {
                db().ping()
                    .await
                    .expect("stored connection should respond to ping");
            });
        });
    }

    /// `with_db` hands the stored connection to the closure; its backend is SQLite.
    #[test]
    fn with_db_passes_the_connection_into_the_closure() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            let backend = with_db(|connection| connection.get_database_backend());
            assert_eq!(backend, DatabaseBackend::Sqlite);
        });
    }

    /// A second `establish` must swap in a brand-new connection. The table created
    /// through the first in-memory SQLite connection is expected to be absent from the
    /// second one, so querying it must fail.
    #[test]
    fn establish_twice_replaces_the_connection() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("first sqlite in-memory connection should succeed");
            runtime::block_on(async {
                db().execute(
                    &Table::create()
                        .table(Alias::new("replacement_check"))
                        .col(
                            ColumnDef::new(Alias::new("id"))
                                .integer()
                                .not_null()
                                .primary_key(),
                        )
                        .to_owned(),
                )
                .await
                .expect("table creation should succeed");
            });

            // Replace the connection; `db()` below must now return the new one.
            establish("sqlite::memory:")
                .expect("second sqlite in-memory connection should succeed");
            let query_result = runtime::block_on(async {
                db().query_one(
                    &Query::select()
                        .expr(Expr::col(Alias::new("id")))
                        .from(Alias::new("replacement_check"))
                        .limit(1)
                        .to_owned(),
                )
                .await
            });

            assert!(query_result.is_err());
        });
    }

    /// `is_established` is false on a fresh thread and true after a successful `establish`.
    #[test]
    fn is_established_reflects_connection_state() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            assert!(!is_established());
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            assert!(is_established());
        });
    }

    /// A URL with no supported scheme fails. The resulting `DatabaseError` renders with
    /// the "database connection failed:" prefix from its `#[error]` attribute.
    #[test]
    fn database_error_displays_the_underlying_failure() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            let error = establish("not-a-valid-database-url")
                .expect_err("invalid database URLs should fail");
            assert!(error.to_string().starts_with("database connection failed:"));
        });
    }

    /// `with_db` returns whatever value the closure computes.
    #[test]
    fn with_db_can_return_a_computed_value() {
        run_isolated(|| {
            let _runtime = runtime::init_runtime();
            establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
            let is_sqlite =
                with_db(|connection| connection.get_database_backend() == DatabaseBackend::Sqlite);
            assert!(is_sqlite);
        });
    }
}