1use async_trait::async_trait;
2use parking_lot::Mutex;
3use std::fmt::{Debug, Display};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use uuid::Uuid;
8
/// Connection settings used when creating and reaching test databases.
///
/// Build one explicitly with `DatabaseConfig::new`, or from environment
/// variables with `DatabaseConfig::from_env`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DatabaseConfig {
    /// URL read from `ADMIN_DATABASE_URL` by `from_env`; falls back to the user URL.
    /// NOTE(review): presumably used for administrative work (create/drop database) — confirm.
    pub admin_url: String,
    /// URL read from `DATABASE_URL` by `from_env`.
    pub user_url: String,
    /// Optional connection limit; `new` leaves it as `None`.
    /// NOTE(review): not read in this file; backends receive the whole config in `create_pool`.
    pub max_connections: Option<usize>,
}
19
20impl Default for DatabaseConfig {
21 fn default() -> Self {
22 Self::from_env().unwrap_or_else(|e| {
23 panic!("Failed to create DatabaseConfig: {}", e);
24 })
25 }
26}
27
28impl DatabaseConfig {
29 pub fn new(admin_url: impl Into<String>, user_url: impl Into<String>) -> Self {
31 Self {
32 admin_url: admin_url.into(),
33 user_url: user_url.into(),
34 max_connections: None,
35 }
36 }
37
38 pub fn from_env() -> std::result::Result<Self, std::env::VarError> {
41 #[cfg(feature = "dotenvy")]
42 let _ = dotenvy::from_filename(".env");
43 let user_url = std::env::var("DATABASE_URL")?;
44 let admin_url = std::env::var("ADMIN_DATABASE_URL").unwrap_or(user_url.clone());
45 Ok(Self::new(admin_url, user_url))
46 }
47}
48
/// Name of a test database, wrapped so it cannot be confused with other strings.
///
/// Construct one with `DatabaseName::new`; read it back with `as_str` or `Display`.
#[derive(Debug, Clone)]
pub struct DatabaseName(String);
52
53impl DatabaseName {
54 pub fn new(prefix: Option<&str>) -> Self {
56 let uuid = Uuid::new_v4();
57 let safe_uuid = uuid.to_string().replace('-', "_");
58 Self(format!("{}_{}", prefix.unwrap_or("testkit"), safe_uuid))
59 }
60
61 pub fn as_str(&self) -> &str {
63 &self.0
64 }
65}
66
67impl Display for DatabaseName {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.0)
70 }
71}
72
/// A connection that can report its connection string.
///
/// Required of every `DatabasePool::Connection` type.
pub trait TestDatabaseConnection {
    /// Returns this connection's connection string as an owned `String`.
    fn connection_string(&self) -> String;
}
76
77#[async_trait]
78pub trait DatabasePool: Send + Sync + Clone {
79 type Connection: Send + Sync + TestDatabaseConnection;
80 type Error: Send + Sync + From<String> + Display + Debug;
81
82 async fn acquire(&self) -> Result<Self::Connection, Self::Error>;
83 async fn release(&self, conn: Self::Connection) -> Result<(), Self::Error>;
84 fn connection_string(&self) -> String;
85}
86
87#[async_trait]
89pub trait DatabaseBackend: Send + Sync + Clone + Debug {
90 type Connection: Send + Sync + Clone;
91 type Pool: Send + Sync + DatabasePool<Connection = Self::Connection, Error = Self::Error>;
92 type Error: Send + Sync + Clone + From<String> + Display + Debug;
93
94 async fn new(config: DatabaseConfig) -> Result<Self, Self::Error>;
95
96 async fn create_pool(
98 &self,
99 name: &DatabaseName,
100 config: &DatabaseConfig,
101 ) -> Result<Self::Pool, Self::Error>;
102
103 async fn connect(&self, name: &DatabaseName) -> Result<Self::Connection, Self::Error> {
106 let connection_string = self.connection_string(name);
108 self.connect_with_string(&connection_string).await
109 }
110
111 async fn connect_with_string(
114 &self,
115 connection_string: &str,
116 ) -> Result<Self::Connection, Self::Error>;
117
118 async fn create_database(
120 &self,
121 pool: &Self::Pool,
122 name: &DatabaseName,
123 ) -> Result<(), Self::Error>;
124
125 fn drop_database(&self, name: &DatabaseName) -> Result<(), Self::Error>;
127
128 fn connection_string(&self, name: &DatabaseName) -> String;
130}
131
132#[derive(Clone)]
135pub struct TestDatabaseInstance<B>
136where
137 B: DatabaseBackend + 'static + Clone + Debug + Send + Sync,
138{
139 pub backend: B,
141 pub pool: B::Pool,
143 pub db_name: DatabaseName,
145 pub connection_pool: Option<Arc<Mutex<Vec<B::Connection>>>>,
147}
148
149impl<B> Debug for TestDatabaseInstance<B>
150where
151 B: DatabaseBackend + 'static + Clone + Debug + Send + Sync,
152{
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 write!(
155 f,
156 "TestDatabaseInstance {{ backend: {:?}, db_name: {:?} }}",
157 self.backend, self.db_name
158 )
159 }
160}
161
162impl<B> TestDatabaseInstance<B>
163where
164 B: DatabaseBackend + 'static + Clone + Debug + Send + Sync,
165{
166 pub async fn new(backend: B, config: DatabaseConfig) -> Result<Self, B::Error> {
168 let db_name = DatabaseName::new(None);
170
171 tracing::debug!("Creating connection pool for database: {}", db_name);
172 let pool = backend.create_pool(&db_name, &config).await?;
173
174 tracing::debug!("Creating database: {}", db_name);
175 backend.create_database(&pool, &db_name).await?;
176
177 let inst = Self {
178 backend,
179 pool,
180 db_name,
181 connection_pool: None,
182 };
183
184 Ok(inst)
185 }
186
187 pub async fn new_with_name(
189 backend: B,
190 config: DatabaseConfig,
191 db_name: DatabaseName,
192 ) -> Result<Self, B::Error> {
193 tracing::debug!("Creating connection pool for database: {}", db_name);
194 let pool = backend.create_pool(&db_name, &config).await?;
195
196 tracing::debug!("Creating database: {}", db_name);
197 backend.create_database(&pool, &db_name).await?;
198
199 let inst = Self {
200 backend,
201 pool,
202 db_name,
203 connection_pool: None,
204 };
205
206 Ok(inst)
207 }
208
209 pub fn backend(&self) -> &B {
211 &self.backend
212 }
213
214 pub fn name(&self) -> &DatabaseName {
216 &self.db_name
217 }
218
219 pub async fn connect(&self) -> Result<B::Connection, B::Error> {
222 self.backend.connect(&self.db_name).await
223 }
224
225 pub async fn with_connection<F, R, E>(&self, operation: F) -> Result<R, B::Error>
228 where
229 F: FnOnce(&B::Connection) -> Pin<Box<dyn Future<Output = Result<R, E>> + Send>> + Send,
230 E: std::error::Error + Send + Sync + 'static,
231 B::Error: From<E>,
232 {
233 let conn = self.connect().await?;
235
236 let result = operation(&conn).await.map_err(|e| B::Error::from(e))?;
238
239 Ok(result)
241 }
242
243 pub async fn acquire_connection(
245 &self,
246 ) -> Result<<B::Pool as DatabasePool>::Connection, B::Error> {
247 let conn = match &self.connection_pool {
248 Some(pool) => {
249 let mut guard = pool.lock();
250 let conn = guard
251 .pop()
252 .ok_or(B::Error::from("No connection available".to_string()))?;
253 drop(guard);
254 conn
255 }
256 None => self.pool.acquire().await?,
257 };
258
259 Ok(conn)
260 }
261
262 pub async fn release_connection(
264 &self,
265 conn: <B::Pool as DatabasePool>::Connection,
266 ) -> Result<(), B::Error> {
267 if let Some(pool) = &self.connection_pool {
268 pool.lock().push(conn);
269 }
270
271 Ok(())
272 }
273
274 pub async fn setup<F, Fut>(&self, setup_fn: F) -> Result<(), B::Error>
277 where
278 F: FnOnce(&mut <B::Pool as DatabasePool>::Connection) -> Fut + Send,
279 Fut: std::future::Future<Output = Result<(), B::Error>> + Send,
280 {
281 let mut conn = self.acquire_connection().await?;
283
284 let result = setup_fn(&mut conn).await;
286
287 if let Some(pool) = &self.connection_pool {
289 pool.lock().push(conn);
290 }
291
292 result
293 }
294}
295
296impl<B> Drop for TestDatabaseInstance<B>
297where
298 B: DatabaseBackend + Clone + Debug + Send + Sync + 'static,
299{
300 fn drop(&mut self) {
301 let name = self.db_name.clone();
302
303 if let Err(err) = self.backend.drop_database(&name) {
304 tracing::error!("Failed to drop database {}: {}", name, err);
305 } else {
306 tracing::info!("Successfully dropped database {} during Drop", name);
307 }
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    // None of these tests await anything, so plain `#[test]` is enough; no async
    // runtime is required.
    #[test]
    fn test_database_name() {
        let name = DatabaseName::new(None);
        assert_ne!(name.as_str(), "");
        assert!(name.as_str().starts_with("testkit_"));
        // Hyphens from the UUID are replaced with underscores.
        assert!(!name.as_str().contains('-'));
        // "testkit_" followed by a 36-character hyphenated UUID.
        assert_eq!(name.as_str().len(), "testkit_".len() + 36);
        assert_eq!(name.to_string(), name.as_str());
    }

    #[test]
    fn test_database_name_custom_prefix() {
        let name = DatabaseName::new(Some("myapp"));
        assert!(name.as_str().starts_with("myapp_"));
    }

    #[test]
    fn test_database_names_are_unique() {
        assert_ne!(DatabaseName::new(None).as_str(), DatabaseName::new(None).as_str());
    }

    #[test]
    fn test_database_config_new() {
        let config = DatabaseConfig::new("admin", "user");
        assert_eq!(config.admin_url, "admin");
        assert_eq!(config.user_url, "user");
        assert_eq!(config.max_connections, None);
    }
}