1use crate::PostgresError;
2use crate::{TransactionManager, TransactionTrait};
3use async_trait::async_trait;
4use std::fmt::Debug;
5use std::str::FromStr;
6use std::sync::Arc;
7use testkit_core::{
8 DatabaseBackend, DatabaseConfig, DatabaseName, DatabasePool, TestDatabaseConnection,
9 TestDatabaseInstance,
10};
11use url::Url;
12
13#[derive(Clone)]
15pub struct PostgresConnection {
16 client: Arc<deadpool_postgres::Client>,
17 connection_string: String,
18}
19
20impl PostgresConnection {
21 pub async fn connect(connection_string: impl Into<String>) -> Result<Self, PostgresError> {
23 let connection_string = connection_string.into();
24
25 let pg_config = tokio_postgres::config::Config::from_str(&connection_string)
27 .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
28
29 let mgr_config = deadpool_postgres::ManagerConfig {
31 recycling_method: deadpool_postgres::RecyclingMethod::Fast,
32 };
33 let mgr =
34 deadpool_postgres::Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
35
36 let pool = deadpool_postgres::Pool::builder(mgr)
38 .max_size(1)
39 .build()
40 .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
41
42 let client = pool
44 .get()
45 .await
46 .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
47
48 Ok(Self {
49 client: Arc::new(client),
50 connection_string,
51 })
52 }
53
54 pub async fn with_connection<F, R, E>(
57 connection_string: impl Into<String>,
58 operation: F,
59 ) -> Result<R, PostgresError>
60 where
61 F: FnOnce(&PostgresConnection) -> futures::future::BoxFuture<'_, Result<R, E>>,
62 E: std::error::Error + Send + Sync + 'static,
63 {
64 let conn = Self::connect(connection_string).await?;
66
67 let result = operation(&conn)
69 .await
70 .map_err(|e| PostgresError::QueryError(e.to_string()))?;
71
72 Ok(result)
74 }
75
76 pub fn client(&self) -> &deadpool_postgres::Client {
78 &self.client
79 }
80}
81
82impl TestDatabaseConnection for PostgresConnection {
83 fn connection_string(&self) -> String {
84 self.connection_string.clone()
85 }
86}
87
88#[derive(Clone)]
90pub struct PostgresPool {
91 pool: Arc<deadpool_postgres::Pool>,
92 connection_string: String,
93}
94
95#[async_trait]
96impl DatabasePool for PostgresPool {
97 type Connection = PostgresConnection;
98 type Error = PostgresError;
99
100 async fn acquire(&self) -> Result<Self::Connection, Self::Error> {
101 let client = self
103 .pool
104 .get()
105 .await
106 .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
107
108 Ok(PostgresConnection {
110 client: Arc::new(client),
111 connection_string: self.connection_string.clone(),
112 })
113 }
114
115 async fn release(&self, _conn: Self::Connection) -> Result<(), Self::Error> {
116 Ok(())
118 }
119
120 fn connection_string(&self) -> String {
121 self.connection_string.clone()
122 }
123}
124
125#[derive(Clone, Debug)]
127pub struct PostgresBackend {
128 config: DatabaseConfig,
129}
130
131#[async_trait]
132impl DatabaseBackend for PostgresBackend {
133 type Connection = PostgresConnection;
134 type Pool = PostgresPool;
135 type Error = PostgresError;
136
137 async fn new(config: DatabaseConfig) -> Result<Self, Self::Error> {
138 if config.admin_url.is_empty() || config.user_url.is_empty() {
140 return Err(PostgresError::ConfigError(
141 "Admin and user URLs must be provided".into(),
142 ));
143 }
144
145 Ok(Self { config })
146 }
147
148 async fn create_pool(
150 &self,
151 name: &DatabaseName,
152 _config: &DatabaseConfig,
153 ) -> Result<Self::Pool, Self::Error> {
154 let connection_string = self.connection_string(name);
156 let pg_config = tokio_postgres::config::Config::from_str(&connection_string)
157 .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
158
159 let mgr_config = deadpool_postgres::ManagerConfig {
161 recycling_method: deadpool_postgres::RecyclingMethod::Fast,
162 };
163 let mgr =
164 deadpool_postgres::Manager::from_config(pg_config, tokio_postgres::NoTls, mgr_config);
165
166 let pool = deadpool_postgres::Pool::builder(mgr)
168 .max_size(20)
169 .build()
170 .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
171
172 Ok(PostgresPool {
173 pool: Arc::new(pool),
174 connection_string,
175 })
176 }
177
178 async fn connect(&self, name: &DatabaseName) -> Result<Self::Connection, Self::Error> {
181 let connection_string = self.connection_string(name);
182
183 PostgresConnection::connect(connection_string).await
186 }
187
188 async fn connect_with_string(
190 &self,
191 connection_string: &str,
192 ) -> Result<Self::Connection, Self::Error> {
193 PostgresConnection::connect(connection_string).await
196 }
197
198 async fn create_database(
199 &self,
200 _pool: &Self::Pool,
201 name: &DatabaseName,
202 ) -> Result<(), Self::Error> {
203 let _admin_config = tokio_postgres::config::Config::from_str(&self.config.admin_url)
205 .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
206
207 let (client, connection) =
208 tokio_postgres::connect(&self.config.admin_url, tokio_postgres::NoTls)
209 .await
210 .map_err(|e| PostgresError::ConnectionError(e.to_string()))?;
211
212 tokio::spawn(async move { if let Err(_e) = connection.await {} });
214
215 let db_name = name.as_str();
217 let create_query = format!("CREATE DATABASE \"{}\"", db_name);
218
219 client
220 .execute(&create_query, &[])
221 .await
222 .map_err(|e| PostgresError::DatabaseCreationError(e.to_string()))?;
223
224 Ok(())
225 }
226
227 fn drop_database(&self, name: &DatabaseName) -> Result<(), Self::Error> {
228 let url = match Url::parse(&self.config.admin_url) {
230 Ok(url) => url,
231 Err(e) => {
232 tracing::error!("Failed to parse admin URL: {}", e);
233 return Err(PostgresError::ConfigError(e.to_string()));
234 }
235 };
236
237 let database_name = name.as_str();
238 let test_user = url.username();
239
240 let database_host = format!(
242 "{}://{}:{}@{}:{}",
243 url.scheme(),
244 test_user,
245 url.password().unwrap_or(""),
246 url.host_str().unwrap_or("postgres"),
247 url.port().unwrap_or(5432)
248 );
249
250 let output = std::process::Command::new("psql")
252 .arg(&database_host)
253 .arg("-c")
254 .arg(format!("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}' AND pid <> pg_backend_pid();", database_name))
255 .output();
256
257 if let Err(e) = output {
258 tracing::warn!(
259 "Failed to terminate connections to database {}: {}",
260 database_name,
261 e
262 );
263 }
265
266 let output = std::process::Command::new("psql")
268 .arg(&database_host)
269 .arg("-c")
270 .arg(format!("DROP DATABASE IF EXISTS \"{}\";", database_name))
271 .output();
272
273 match output {
274 Ok(output) => {
275 if output.status.success() {
276 tracing::info!("Successfully dropped database {}", name);
277 Ok(())
278 } else {
279 let stderr = String::from_utf8_lossy(&output.stderr);
280 tracing::error!("Failed to drop database {}: {}", name, stderr);
281 Err(PostgresError::DatabaseDropError(stderr.to_string()))
282 }
283 }
284 Err(e) => {
285 tracing::error!("Failed to execute psql command to drop {}: {}", name, e);
286 Err(PostgresError::DatabaseDropError(e.to_string()))
287 }
288 }
289 }
290
291 fn connection_string(&self, name: &DatabaseName) -> String {
292 let base_url = &self.config.user_url;
294
295 if let Some(db_pos) = base_url.rfind('/') {
297 let (prefix, _) = base_url.split_at(db_pos + 1);
298 return format!("{}{}", prefix, name.as_str());
299 }
300
301 format!("postgres://postgres/{}", name.as_str())
303 }
304}
305
306pub struct PostgresTransaction {
308 client: Arc<deadpool_postgres::Client>,
309}
310
311#[async_trait]
312impl TransactionTrait for PostgresTransaction {
313 type Error = PostgresError;
314
315 async fn commit(&mut self) -> Result<(), Self::Error> {
316 self.client
317 .execute("COMMIT", &[])
318 .await
319 .map_err(|e| PostgresError::TransactionError(e.to_string()))?;
320 Ok(())
321 }
322
323 async fn rollback(&mut self) -> Result<(), Self::Error> {
324 self.client
325 .execute("ROLLBACK", &[])
326 .await
327 .map_err(|e| PostgresError::TransactionError(e.to_string()))?;
328 Ok(())
329 }
330}
331
332#[async_trait]
334impl TransactionManager for TestDatabaseInstance<PostgresBackend> {
335 type Error = PostgresError;
336 type Tx = PostgresTransaction;
337 type Connection = PostgresConnection;
338
339 async fn begin_transaction(&mut self) -> Result<Self::Tx, Self::Error> {
340 let pool = &self.pool;
342 let client = pool.acquire().await?;
343
344 client
346 .client
347 .execute("BEGIN", &[])
348 .await
349 .map_err(|e| PostgresError::TransactionError(e.to_string()))?;
350
351 Ok(PostgresTransaction {
352 client: Arc::clone(&client.client),
353 })
354 }
355
356 async fn commit_transaction(tx: &mut Self::Tx) -> Result<(), Self::Error> {
357 tx.commit().await
358 }
359
360 async fn rollback_transaction(tx: &mut Self::Tx) -> Result<(), Self::Error> {
361 tx.rollback().await
362 }
363}
364
365pub async fn postgres_backend() -> Result<PostgresBackend, PostgresError> {
383 let config = DatabaseConfig::default();
384 PostgresBackend::new(config).await
385}
386
387pub async fn postgres_backend_with_config(
406 config: DatabaseConfig,
407) -> Result<PostgresBackend, PostgresError> {
408 PostgresBackend::new(config).await
409}