testkit_postgres/
tokio_postgres.rs

1use crate::PostgresError;
2use crate::{TransactionManager, TransactionTrait};
3use async_trait::async_trait;
4use std::fmt::Debug;
5use std::str::FromStr;
6use std::sync::Arc;
7use testkit_core::{
8    DatabaseBackend, DatabaseConfig, DatabaseName, DatabasePool, TestDatabaseConnection,
9    TestDatabaseInstance,
10};
11use url::Url;
12
13/// A connection to a PostgreSQL database using tokio-postgres
14#[derive(Clone)]
15pub struct PostgresConnection {
16    client: Arc<deadpool_postgres::Client>,
17    connection_string: String,
18}
19
20impl PostgresConnection {
21    /// Create a new direct connection without using a pool
22    pub async fn connect(connection_string: impl Into<String>) -> Result<Self, PostgresError> {
23        let connection_string = connection_string.into();
24
25        // Parse connection config
26        let pg_config = tokio_postgres::config::Config::from_str(&connection_string)
27            .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
28
29        // Create a minimal pool manager
30        let mgr_config = deadpool_postgres::ManagerConfig {
31            recycling_method: deadpool_postgres::RecyclingMethod::Fast,
32        };
33        let mgr =
34            deadpool_postgres::Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
35
36        // Create a minimal pool with a single connection
37        let pool = deadpool_postgres::Pool::builder(mgr)
38            .max_size(1)
39            .build()
40            .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
41
42        // Get a client
43        let client = pool
44            .get()
45            .await
46            .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
47
48        Ok(Self {
49            client: Arc::new(client),
50            connection_string,
51        })
52    }
53
54    /// Execute a function with a direct connection and automatically close it after use
55    /// This is the most efficient way to perform a one-off database operation
56    pub async fn with_connection<F, R, E>(
57        connection_string: impl Into<String>,
58        operation: F,
59    ) -> Result<R, PostgresError>
60    where
61        F: FnOnce(&PostgresConnection) -> futures::future::BoxFuture<'_, Result<R, E>>,
62        E: std::error::Error + Send + Sync + 'static,
63    {
64        // Create a connection
65        let conn = Self::connect(connection_string).await?;
66
67        // Run the operation
68        let result = operation(&conn)
69            .await
70            .map_err(|e| PostgresError::QueryError(e.to_string()))?;
71
72        // Connection will be dropped automatically when it goes out of scope
73        Ok(result)
74    }
75
76    /// Get a reference to the underlying database client
77    pub fn client(&self) -> &deadpool_postgres::Client {
78        &self.client
79    }
80}
81
82impl TestDatabaseConnection for PostgresConnection {
83    fn connection_string(&self) -> String {
84        self.connection_string.clone()
85    }
86}
87
88/// A connection pool for PostgreSQL using deadpool-postgres
89#[derive(Clone)]
90pub struct PostgresPool {
91    pool: Arc<deadpool_postgres::Pool>,
92    connection_string: String,
93}
94
95#[async_trait]
96impl DatabasePool for PostgresPool {
97    type Connection = PostgresConnection;
98    type Error = PostgresError;
99
100    async fn acquire(&self) -> Result<Self::Connection, Self::Error> {
101        // Get a connection from the pool
102        let client = self
103            .pool
104            .get()
105            .await
106            .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
107
108        // Return a new PostgresConnection
109        Ok(PostgresConnection {
110            client: Arc::new(client),
111            connection_string: self.connection_string.clone(),
112        })
113    }
114
115    async fn release(&self, _conn: Self::Connection) -> Result<(), Self::Error> {
116        // The deadpool automatically handles connection release when the client is dropped
117        Ok(())
118    }
119
120    fn connection_string(&self) -> String {
121        self.connection_string.clone()
122    }
123}
124
125/// A PostgreSQL database backend using tokio-postgres
126#[derive(Clone, Debug)]
127pub struct PostgresBackend {
128    config: DatabaseConfig,
129}
130
131#[async_trait]
132impl DatabaseBackend for PostgresBackend {
133    type Connection = PostgresConnection;
134    type Pool = PostgresPool;
135    type Error = PostgresError;
136
137    async fn new(config: DatabaseConfig) -> Result<Self, Self::Error> {
138        // Validate the config
139        if config.admin_url.is_empty() || config.user_url.is_empty() {
140            return Err(PostgresError::ConfigError(
141                "Admin and user URLs must be provided".into(),
142            ));
143        }
144
145        Ok(Self { config })
146    }
147
148    /// Create a new connection pool for the given database
149    async fn create_pool(
150        &self,
151        name: &DatabaseName,
152        _config: &DatabaseConfig,
153    ) -> Result<Self::Pool, Self::Error> {
154        // Create connection config from the URL
155        let connection_string = self.connection_string(name);
156        let pg_config = tokio_postgres::config::Config::from_str(&connection_string)
157            .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
158
159        // Create deadpool manager
160        let mgr_config = deadpool_postgres::ManagerConfig {
161            recycling_method: deadpool_postgres::RecyclingMethod::Fast,
162        };
163        let mgr =
164            deadpool_postgres::Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
165
166        // Create the pool
167        let pool = deadpool_postgres::Pool::builder(mgr)
168            .max_size(20)
169            .build()
170            .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
171
172        Ok(PostgresPool {
173            pool: Arc::new(pool),
174            connection_string,
175        })
176    }
177
178    /// Create a single connection to the given database
179    /// This is useful for cases where a full pool is not needed
180    async fn connect(&self, name: &DatabaseName) -> Result<Self::Connection, Self::Error> {
181        let connection_string = self.connection_string(name);
182
183        // Use the direct connection method we defined on PostgresConnection
184        // This is more efficient as it avoids pool overhead for one-off connections
185        PostgresConnection::connect(connection_string).await
186    }
187
188    /// Create a single connection using a connection string directly
189    async fn connect_with_string(
190        &self,
191        connection_string: &str,
192    ) -> Result<Self::Connection, Self::Error> {
193        // Use the direct connection method we defined on PostgresConnection
194        // This is more efficient as it avoids pool overhead for one-off connections
195        PostgresConnection::connect(connection_string).await
196    }
197
198    async fn create_database(
199        &self,
200        _pool: &Self::Pool,
201        name: &DatabaseName,
202    ) -> Result<(), Self::Error> {
203        // Create admin connection to create the database
204        let _admin_config = tokio_postgres::config::Config::from_str(&self.config.admin_url)
205            .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
206
207        let (client, connection) =
208            tokio_postgres::connect(&self.config.admin_url, tokio_postgres::NoTls)
209                .await
210                .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
211
212        // Spawn the connection handler
213        tokio::spawn(async move { if let Err(_e) = connection.await {} });
214
215        // Create the database
216        let db_name = name.as_str();
217        let create_query = format!("CREATE DATABASE \"{}\"", db_name);
218
219        client
220            .execute(&create_query, &[])
221            .await
222            .map_err(|e| PostgresError::DatabaseCreationError(e.to_string()))?;
223
224        Ok(())
225    }
226
227    fn drop_database(&self, name: &DatabaseName) -> Result<(), Self::Error> {
228        // Parse the admin URL to extract connection parameters
229        let url = match Url::parse(&self.config.admin_url) {
230            Ok(url) => url,
231            Err(e) => {
232                tracing::error!("Failed to parse admin URL: {}", e);
233                return Err(PostgresError::ConfigError(e.to_string()));
234            }
235        };
236
237        let database_name = name.as_str();
238        let test_user = url.username();
239
240        // Format the connection string for the admin database
241        let database_host = format!(
242            "{}://{}:{}@{}:{}",
243            url.scheme(),
244            test_user,
245            url.password().unwrap_or(""),
246            url.host_str().unwrap_or("postgres"),
247            url.port().unwrap_or(5432)
248        );
249
250        // First, terminate all connections to the database
251        let output = std::process::Command::new("psql")
252            .arg(&database_host)
253            .arg("-c")
254            .arg(format!("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}' AND pid <> pg_backend_pid();", database_name))
255            .output();
256
257        if let Err(e) = output {
258            tracing::warn!(
259                "Failed to terminate connections to database {}: {}",
260                database_name,
261                e
262            );
263            // Continue with drop attempt even if termination fails
264        }
265
266        // Now drop the database
267        let output = std::process::Command::new("psql")
268            .arg(&database_host)
269            .arg("-c")
270            .arg(format!("DROP DATABASE IF EXISTS \"{}\";", database_name))
271            .output();
272
273        match output {
274            Ok(output) => {
275                if output.status.success() {
276                    tracing::info!("Successfully dropped database {}", name);
277                    Ok(())
278                } else {
279                    let stderr = String::from_utf8_lossy(&output.stderr);
280                    tracing::error!("Failed to drop database {}: {}", name, stderr);
281                    Err(PostgresError::DatabaseDropError(stderr.to_string()))
282                }
283            }
284            Err(e) => {
285                tracing::error!("Failed to execute psql command to drop {}: {}", name, e);
286                Err(PostgresError::DatabaseDropError(e.to_string()))
287            }
288        }
289    }
290
291    fn connection_string(&self, name: &DatabaseName) -> String {
292        // Parse the base URL and replace the database name
293        let base_url = &self.config.user_url;
294
295        // Simple string replacement to change the database name
296        if let Some(db_pos) = base_url.rfind('/') {
297            let (prefix, _) = base_url.split_at(db_pos + 1);
298            return format!("{}{}", prefix, name.as_str());
299        }
300
301        // Fallback
302        format!("postgres://postgres/{}", name.as_str())
303    }
304}
305
306/// A PostgreSQL transaction using tokio-postgres
307pub struct PostgresTransaction {
308    client: Arc<deadpool_postgres::Client>,
309}
310
311#[async_trait]
312impl TransactionTrait for PostgresTransaction {
313    type Error = PostgresError;
314
315    async fn commit(&mut self) -> Result<(), Self::Error> {
316        self.client
317            .execute("COMMIT", &[])
318            .await
319            .map_err(|e| PostgresError::TransactionError(e.to_string()))?;
320        Ok(())
321    }
322
323    async fn rollback(&mut self) -> Result<(), Self::Error> {
324        self.client
325            .execute("ROLLBACK", &[])
326            .await
327            .map_err(|e| PostgresError::TransactionError(e.to_string()))?;
328        Ok(())
329    }
330}
331
332/// Implementation of TransactionManager for PostgreSQL
333#[async_trait]
334impl TransactionManager for TestDatabaseInstance<PostgresBackend> {
335    type Error = PostgresError;
336    type Tx = PostgresTransaction;
337    type Connection = PostgresConnection;
338
339    async fn begin_transaction(&mut self) -> Result<Self::Tx, Self::Error> {
340        // Get a connection from the pool
341        let pool = &self.pool;
342        let client = pool.acquire().await?;
343
344        // Begin transaction
345        client
346            .client
347            .execute("BEGIN", &[])
348            .await
349            .map_err(|e| PostgresError::TransactionError(e.to_string()))?;
350
351        Ok(PostgresTransaction {
352            client: Arc::clone(&client.client),
353        })
354    }
355
356    async fn commit_transaction(tx: &mut Self::Tx) -> Result<(), Self::Error> {
357        tx.commit().await
358    }
359
360    async fn rollback_transaction(tx: &mut Self::Tx) -> Result<(), Self::Error> {
361        tx.rollback().await
362    }
363}
364
365/// Create a new PostgreSQL backend from environment variables
366///
367/// This function can be used to create a backend that can be passed into `with_database()`
368///
369/// # Example
370/// ```no_run
371/// use testkit_postgres::postgres_backend;
372/// use testkit_core::with_database;
373///
374/// async fn test() {
375///     let backend = postgres_backend().await.unwrap();
376///     let context = with_database(backend)
377///         .execute()
378///         .await
379///         .unwrap();
380/// }
381/// ```
382pub async fn postgres_backend() -> Result<PostgresBackend, PostgresError> {
383    let config = DatabaseConfig::default();
384    PostgresBackend::new(config).await
385}
386
387/// Create a new PostgreSQL backend with a custom config
388///
389/// This function can be used to create a backend that can be passed into `with_database()`
390///
391/// # Example
392/// ```no_run
393/// use testkit_postgres::{postgres_backend_with_config, DatabaseConfig};
394/// use testkit_core::with_database;
395///
396/// async fn test() {
397///     let config = DatabaseConfig::new("postgres://admin@postgres/postgres", "postgres://user@postgres/postgres");
398///     let backend = postgres_backend_with_config(config).await.unwrap();
399///     let context = with_database(backend)
400///         .execute()
401///         .await
402///         .unwrap();
403/// }
404/// ```
405pub async fn postgres_backend_with_config(
406    config: DatabaseConfig,
407) -> Result<PostgresBackend, PostgresError> {
408    PostgresBackend::new(config).await
409}