// testkit_core/testdb/test_database.rs

1use async_trait::async_trait;
2use parking_lot::Mutex;
3use std::fmt::{Debug, Display};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use uuid::Uuid;
8
/// Configuration for database connections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DatabaseConfig {
    /// Connection string for admin operations (schema changes, etc.)
    pub admin_url: String,
    /// Connection string for regular operations
    pub user_url: String,
    /// Maximum number of connections to the database
    pub max_connections: Option<usize>,
}

impl Default for DatabaseConfig {
    /// Builds a configuration from the environment via [`DatabaseConfig::from_env`].
    ///
    /// # Panics
    ///
    /// Panics if `DATABASE_URL` is unset or not valid Unicode.
    fn default() -> Self {
        Self::from_env().unwrap_or_else(|e| {
            panic!("Failed to create DatabaseConfig: {}", e);
        })
    }
}

impl DatabaseConfig {
    /// Create a new configuration with explicit connection strings.
    ///
    /// `max_connections` is left unset (`None`).
    pub fn new(admin_url: impl Into<String>, user_url: impl Into<String>) -> Self {
        Self {
            admin_url: admin_url.into(),
            user_url: user_url.into(),
            max_connections: None,
        }
    }

    /// Get a configuration from environment variables.
    ///
    /// `DATABASE_URL` supplies the user connection string. `ADMIN_DATABASE_URL`
    /// supplies the admin one; if it is missing (or not valid Unicode), the user
    /// URL is reused for admin operations. With the `dotenvy` feature enabled, a
    /// `.env` file is loaded first and any error loading it is ignored.
    ///
    /// # Errors
    ///
    /// Returns the [`std::env::VarError`] for `DATABASE_URL` when that variable
    /// is unset or not valid Unicode.
    pub fn from_env() -> std::result::Result<Self, std::env::VarError> {
        #[cfg(feature = "dotenvy")]
        let _ = dotenvy::from_filename(".env");
        let user_url = std::env::var("DATABASE_URL")?;
        // Clone the user URL only when the admin URL is actually unavailable.
        let admin_url = std::env::var("ADMIN_DATABASE_URL").unwrap_or_else(|_| user_url.clone());
        Ok(Self::new(admin_url, user_url))
    }
}

49/// A unique database name
50#[derive(Debug, Clone)]
51pub struct DatabaseName(String);
52
53impl DatabaseName {
54    /// Create a new unique database name with an optional prefix
55    pub fn new(prefix: Option<&str>) -> Self {
56        let uuid = Uuid::new_v4();
57        let safe_uuid = uuid.to_string().replace('-', "_");
58        Self(format!("{}_{}", prefix.unwrap_or("testkit"), safe_uuid))
59    }
60
61    /// Get the database name as a string
62    pub fn as_str(&self) -> &str {
63        &self.0
64    }
65}
66
67impl Display for DatabaseName {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "{}", self.0)
70    }
71}
72
/// A connection handle (required of [`DatabasePool::Connection`]) that can
/// report a connection string.
pub trait TestDatabaseConnection {
    /// Returns the connection string associated with this connection.
    fn connection_string(&self) -> String;
}

77#[async_trait]
78pub trait DatabasePool: Send + Sync + Clone {
79    type Connection: Send + Sync + TestDatabaseConnection;
80    type Error: Send + Sync + From<String> + Display + Debug;
81
82    async fn acquire(&self) -> Result<Self::Connection, Self::Error>;
83    async fn release(&self, conn: Self::Connection) -> Result<(), Self::Error>;
84    fn connection_string(&self) -> String;
85}
86
87/// Trait defining a test database abstraction
88#[async_trait]
89pub trait DatabaseBackend: Send + Sync + Clone + Debug {
90    type Connection: Send + Sync + Clone;
91    type Pool: Send + Sync + DatabasePool<Connection = Self::Connection, Error = Self::Error>;
92    type Error: Send + Sync + Clone + From<String> + Display + Debug;
93
94    async fn new(config: DatabaseConfig) -> Result<Self, Self::Error>;
95
96    /// Create a new connection pool for the given database
97    async fn create_pool(
98        &self,
99        name: &DatabaseName,
100        config: &DatabaseConfig,
101    ) -> Result<Self::Pool, Self::Error>;
102
103    /// Create a single connection to the given database
104    /// This is useful for cases where a full pool is not needed
105    async fn connect(&self, name: &DatabaseName) -> Result<Self::Connection, Self::Error> {
106        // Default implementation connects using the connection string for the given database name
107        let connection_string = self.connection_string(name);
108        self.connect_with_string(&connection_string).await
109    }
110
111    /// Create a single connection using a connection string directly
112    /// This is useful for connecting to databases that may not have been created by TestKit
113    async fn connect_with_string(
114        &self,
115        connection_string: &str,
116    ) -> Result<Self::Connection, Self::Error>;
117
118    /// Create a new database with the given name
119    async fn create_database(
120        &self,
121        pool: &Self::Pool,
122        name: &DatabaseName,
123    ) -> Result<(), Self::Error>;
124
125    /// Drop a database with the given name
126    fn drop_database(&self, name: &DatabaseName) -> Result<(), Self::Error>;
127
128    /// Get the connection string for the given database
129    fn connection_string(&self, name: &DatabaseName) -> String;
130}
131
132/// A test database that handles setup, connections, and cleanup
133/// TODO: Create a TestManager that can handle connection pooling and cleanup
134#[derive(Clone)]
135pub struct TestDatabaseInstance<B>
136where
137    B: DatabaseBackend + 'static + Clone + Debug + Send + Sync,
138{
139    /// The database backend
140    pub backend: B,
141    /// The connection pool
142    pub pool: B::Pool,
143    /// The database name
144    pub db_name: DatabaseName,
145    /// The connection pool
146    pub connection_pool: Option<Arc<Mutex<Vec<B::Connection>>>>,
147}
148
149impl<B> Debug for TestDatabaseInstance<B>
150where
151    B: DatabaseBackend + 'static + Clone + Debug + Send + Sync,
152{
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        write!(
155            f,
156            "TestDatabaseInstance {{ backend: {:?}, db_name: {:?} }}",
157            self.backend, self.db_name
158        )
159    }
160}
161
162impl<B> TestDatabaseInstance<B>
163where
164    B: DatabaseBackend + 'static + Clone + Debug + Send + Sync,
165{
166    /// Create a new test database with the given backend
167    pub async fn new(backend: B, config: DatabaseConfig) -> Result<Self, B::Error> {
168        // Generate unique name
169        let db_name = DatabaseName::new(None);
170
171        tracing::debug!("Creating connection pool for database: {}", db_name);
172        let pool = backend.create_pool(&db_name, &config).await?;
173
174        tracing::debug!("Creating database: {}", db_name);
175        backend.create_database(&pool, &db_name).await?;
176
177        let inst = Self {
178            backend,
179            pool,
180            db_name,
181            connection_pool: None,
182        };
183
184        Ok(inst)
185    }
186
187    /// Create a new test database with the given backend and specific name
188    pub async fn new_with_name(
189        backend: B,
190        config: DatabaseConfig,
191        db_name: DatabaseName,
192    ) -> Result<Self, B::Error> {
193        tracing::debug!("Creating connection pool for database: {}", db_name);
194        let pool = backend.create_pool(&db_name, &config).await?;
195
196        tracing::debug!("Creating database: {}", db_name);
197        backend.create_database(&pool, &db_name).await?;
198
199        let inst = Self {
200            backend,
201            pool,
202            db_name,
203            connection_pool: None,
204        };
205
206        Ok(inst)
207    }
208
209    /// Returns a reference to the backend
210    pub fn backend(&self) -> &B {
211        &self.backend
212    }
213
214    /// Returns a reference to the database name
215    pub fn name(&self) -> &DatabaseName {
216        &self.db_name
217    }
218
219    /// Create a single connection to the database without using the pool
220    /// This is useful for cases where a single connection is needed for a specific operation
221    pub async fn connect(&self) -> Result<B::Connection, B::Error> {
222        self.backend.connect(&self.db_name).await
223    }
224
225    /// Execute a function with a one-off connection and automatically close it after use
226    /// This is the most efficient way to perform a one-off database operation
227    pub async fn with_connection<F, R, E>(&self, operation: F) -> Result<R, B::Error>
228    where
229        F: FnOnce(&B::Connection) -> Pin<Box<dyn Future<Output = Result<R, E>> + Send>> + Send,
230        E: std::error::Error + Send + Sync + 'static,
231        B::Error: From<E>,
232    {
233        // Create a connection
234        let conn = self.connect().await?;
235
236        // Run the operation
237        let result = operation(&conn).await.map_err(|e| B::Error::from(e))?;
238
239        // Connection will be dropped automatically when it goes out of scope
240        Ok(result)
241    }
242
243    /// Get a connection from the pool or acquire a new one
244    pub async fn acquire_connection(
245        &self,
246    ) -> Result<<B::Pool as DatabasePool>::Connection, B::Error> {
247        let conn = match &self.connection_pool {
248            Some(pool) => {
249                let mut guard = pool.lock();
250                let conn = guard
251                    .pop()
252                    .ok_or(B::Error::from("No connection available".to_string()))?;
253                drop(guard);
254                conn
255            }
256            None => self.pool.acquire().await?,
257        };
258
259        Ok(conn)
260    }
261
262    /// Release a connection back to the pool
263    pub async fn release_connection(
264        &self,
265        conn: <B::Pool as DatabasePool>::Connection,
266    ) -> Result<(), B::Error> {
267        if let Some(pool) = &self.connection_pool {
268            pool.lock().push(conn);
269        }
270
271        Ok(())
272    }
273
274    /// Setup the database with a function
275    /// The connection handling approach needs to match the expected B::Connection type
276    pub async fn setup<F, Fut>(&self, setup_fn: F) -> Result<(), B::Error>
277    where
278        F: FnOnce(&mut <B::Pool as DatabasePool>::Connection) -> Fut + Send,
279        Fut: std::future::Future<Output = Result<(), B::Error>> + Send,
280    {
281        // Get a connection from the pool
282        let mut conn = self.acquire_connection().await?;
283
284        // Call the setup function with a mutable reference to the connection
285        let result = setup_fn(&mut conn).await;
286
287        // Return the connection to the pool if we have one
288        if let Some(pool) = &self.connection_pool {
289            pool.lock().push(conn);
290        }
291
292        result
293    }
294}
295
296impl<B> Drop for TestDatabaseInstance<B>
297where
298    B: DatabaseBackend + Clone + Debug + Send + Sync + 'static,
299{
300    fn drop(&mut self) {
301        let name = self.db_name.clone();
302
303        if let Err(err) = self.backend.drop_database(&name) {
304            tracing::error!("Failed to drop database {}: {}", name, err);
305        } else {
306            tracing::info!("Successfully dropped database {} during Drop", name);
307        }
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    // `DatabaseName::new` is synchronous, so a plain `#[test]` suffices; no async
    // runtime is needed.
    #[test]
    fn test_database_name() {
        let name = DatabaseName::new(None);
        assert_ne!(name.as_str(), "");
        assert!(name.as_str().starts_with("testkit_"));
        assert!(!name.as_str().contains('-'));
    }

    #[test]
    fn test_database_name_prefix_and_uniqueness() {
        let a = DatabaseName::new(Some("custom"));
        let b = DatabaseName::new(Some("custom"));
        assert!(a.as_str().starts_with("custom_"));
        assert_ne!(a.as_str(), b.as_str());
        assert_eq!(a.to_string(), a.as_str());
    }

    #[test]
    fn test_database_config_new() {
        let config = DatabaseConfig::new("admin", "user");
        assert_eq!(config.admin_url, "admin");
        assert_eq!(config.user_url, "user");
        assert_eq!(config.max_connections, None);
    }
}