sqlx_db_tester/postgres.rs

1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4    Connection, Executor, PgConnection, PgPool,
5    migrate::{MigrationSource, Migrator},
6};
7use std::{path::Path, thread};
8use tokio::runtime::Runtime;
9use uuid::Uuid;
10
/// A throwaway PostgreSQL database for tests.
///
/// Constructing a `TestPg` (via [`TestPg::new`], [`TestPg::default`] or
/// `TestPgBuilder::build`) creates a uniquely named database on the server and
/// runs migrations against it. Dropping it terminates any remaining sessions on
/// that database and drops it again.
///
/// The type deliberately does not implement `Clone`: two handles to the same
/// database would both try to drop it.
#[derive(Debug)]
pub struct TestPg {
    /// Server URL without a database name, e.g. `postgres://user:pass@localhost:5432`.
    pub server_url: String,
    /// Name of the generated test database on that server.
    pub dbname: String,
    /// Extensions that were installed before the migrations ran.
    // Only recorded so it shows up in `Debug` output; no code reads it yet,
    // which is why the `dead_code` lint is silenced here.
    #[allow(dead_code)]
    extensions: Vec<String>,
}
18
/// Builder for creating a [`TestPg`] instance with custom configuration.
///
/// Create one with `TestPgBuilder::new`, optionally configure extensions with
/// `TestPgBuilder::with_extensions`, then call `TestPgBuilder::build` to create
/// and migrate the database.
///
/// The `MigrationSource` bound on `S` lives on the `impl` block rather than on
/// the struct, so the type can be named without repeating that bound.
#[derive(Debug)]
#[must_use = "a TestPgBuilder does nothing until `build()` is called"]
pub struct TestPgBuilder<S> {
    /// Connection URL handed to `TestPg`. A database name in its path becomes
    /// the prefix of the generated test database name.
    database_url: String,
    /// Migrations to run against the freshly created database.
    migrations: S,
    /// Extensions to create before the migrations run.
    extensions: Vec<String>,
}
28
29impl<S> TestPgBuilder<S>
30where
31    S: MigrationSource<'static> + Send + Sync + 'static,
32{
33    /// Create a new TestPgBuilder with the given database URL and migrations.
34    pub fn new(database_url: String, migrations: S) -> Self {
35        Self {
36            database_url,
37            migrations,
38            extensions: vec![],
39        }
40    }
41
42    /// Add a list of PostgreSQL extensions to be installed before running migrations.
43    ///
44    /// # Example
45    /// ```no_run
46    /// use sqlx_db_tester::TestPgBuilder;
47    /// use std::path::Path;
48    ///
49    /// let tdb = TestPgBuilder::new(
50    ///     "postgres://postgres:postgres@localhost:5432".to_string(),
51    ///     Path::new("./fixtures/migrations")
52    /// )
53    /// .with_extensions(vec!["uuid-ossp".to_string(), "postgis".to_string()])
54    /// .build();
55    /// ```
56    pub fn with_extensions(mut self, extensions: Vec<String>) -> Self {
57        self.extensions = extensions;
58        self
59    }
60
61    /// Build and initialize the test database with the configured settings.
62    pub fn build(self) -> TestPg {
63        TestPg::new_with_extensions(self.database_url, self.migrations, self.extensions)
64    }
65}
66
67impl TestPg {
68    pub fn new<S>(database_url: String, migrations: S) -> Self
69    where
70        S: MigrationSource<'static> + Send + Sync + 'static,
71    {
72        Self::new_with_extensions(database_url, migrations, vec![])
73    }
74
75    fn new_with_extensions<S>(database_url: String, migrations: S, extensions: Vec<String>) -> Self
76    where
77        S: MigrationSource<'static> + Send + Sync + 'static,
78    {
79        let simple = Uuid::new_v4().simple();
80        let (server_url, dbname) = parse_postgres_url(&database_url);
81        let dbname = match dbname {
82            Some(db_name) => format!("{db_name}_test_{simple}"),
83            None => format!("test_{simple}"),
84        };
85        let dbname_cloned = dbname.clone();
86        let extensions_cloned = extensions.clone();
87
88        let tdb = Self {
89            server_url,
90            dbname,
91            extensions,
92        };
93
94        let url = tdb.url();
95
96        // create database dbname
97        thread::spawn(move || {
98            let rt = Runtime::new().unwrap();
99            rt.block_on(async move {
100                // use server url to create database
101                let mut conn = PgConnection::connect(&database_url)
102                    .await
103                    .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
104                conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
105                    .await
106                    .unwrap();
107
108                // now connect to test database for migration
109                let mut conn = PgConnection::connect(&url)
110                    .await
111                    .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
112
113                // create extensions before running migrations
114                for ext in &extensions_cloned {
115                    conn.execute(format!(r#"CREATE EXTENSION IF NOT EXISTS "{ext}""#).as_str())
116                        .await
117                        .unwrap_or_else(|_| panic!("Error while creating extension {ext}"));
118                }
119
120                let m = Migrator::new(migrations).await.unwrap();
121                m.run(&mut conn).await.unwrap();
122            });
123        })
124        .join()
125        .expect("failed to create database");
126
127        tdb
128    }
129
130    pub fn server_url(&self) -> String {
131        self.server_url.clone()
132    }
133
134    pub fn url(&self) -> String {
135        format!("{}/{}", self.server_url, self.dbname)
136    }
137
138    pub async fn get_pool(&self) -> PgPool {
139        let url = self.url();
140        PgPool::connect(&url)
141            .await
142            .unwrap_or_else(|_| panic!("Error while connecting to {url}"))
143    }
144
145    pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
146        let pool = self.get_pool().await;
147        let path = filename.canonicalize()?;
148        let mut conn = pool.acquire().await?;
149        let sql = format!(
150            "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
151            table,
152            fields.join(","),
153            path.display()
154        );
155        conn.execute(sql.as_str()).await?;
156        // copy csv
157
158        Ok(())
159    }
160
161    pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
162        let mut rdr = csv::Reader::from_reader(csv.as_bytes());
163        let headers = rdr.headers()?.iter().join(",");
164        let mut tx = self.get_pool().await.begin().await?;
165        for result in rdr.records() {
166            let record = result?;
167            let sql = format!(
168                "INSERT INTO {} ({}) VALUES ({})",
169                table,
170                headers,
171                record.iter().map(|v| format!("'{v}'")).join(",")
172            );
173            tx.execute(sql.as_str()).await?;
174        }
175        tx.commit().await?;
176        Ok(())
177    }
178}
179
180impl Drop for TestPg {
181    fn drop(&mut self) {
182        let server_url = &self.server_url;
183        let database_url = format!("{server_url}/postgres");
184        let dbname = self.dbname.clone();
185        thread::spawn(move || {
186            let rt = Runtime::new().unwrap();
187            rt.block_on(async move {
188                    let mut conn = PgConnection::connect(&database_url).await
189                    .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
190                    // terminate existing connections
191                    sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
192                    .execute( &mut conn)
193                    .await
194                    .expect("Terminate all other connections");
195                    conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
196                        .await
197                        .expect("Error while querying the drop database");
198                });
199            })
200            .join()
201            .expect("failed to drop database");
202    }
203}
204
205impl Default for TestPg {
206    fn default() -> Self {
207        Self::new(
208            "postgres://postgres:postgres@localhost:5432".to_string(),
209            Path::new("./fixtures/migrations"),
210        )
211    }
212}
213
/// Split a PostgreSQL connection URL into its server part and optional database name.
///
/// Returns `(server_url, dbname)`:
/// - `server_url` is the scheme plus authority (`user:pass@host:port`), with
///   no trailing slash.
/// - `dbname` is the first path segment, or `None` when the path is missing
///   or empty.
///
/// Both `postgres://` and `postgresql://` are accepted, and the scheme that was
/// used is kept. Input without a recognised scheme is treated as `postgres://`.
///
/// A query string (`?sslmode=...`) or fragment is discarded so it cannot leak
/// into the database name. Note that such connection options are therefore
/// not carried over to `server_url`.
fn parse_postgres_url(url: &str) -> (String, Option<String>) {
    const SCHEMES: [&str; 2] = ["postgres://", "postgresql://"];
    let (scheme, rest) = SCHEMES
        .iter()
        .find_map(|&scheme| url.strip_prefix(scheme).map(|rest| (scheme, rest)))
        .unwrap_or(("postgres://", url));

    // Everything from the first `?` or `#` onwards is query/fragment, not path.
    let rest = match rest.find(|c: char| c == '?' || c == '#') {
        Some(end) => &rest[..end],
        None => rest,
    };

    let (authority, path) = rest.split_once('/').unwrap_or((rest, ""));
    let dbname = path
        .split('/')
        .next()
        .filter(|segment| !segment.is_empty())
        .map(str::to_string);

    (format!("{scheme}{authority}"), dbname)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestPgBuilder;
    use anyhow::Result;
    use std::env;

    /// Read back the `(id, title)` row the tests expect to find in `todos`.
    async fn first_todo(pool: &PgPool) -> (i32, String) {
        sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;

        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();

        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }

    #[tokio::test]
    #[ignore = "github action postgres server can't be used for this test"]
    async fn test_postgres_should_load_csv() -> Result<()> {
        let filename = env::current_dir()?.join("fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv("todos", &["title"], &filename).await?;

        let pool = tdb.get_pool().await;
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let csv = include_str!("../fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv_data("todos", csv).await?;

        let pool = tdb.get_pool().await;
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_with_extensions() {
        let tdb = TestPgBuilder::new(
            "postgres://postgres:postgres@localhost:5432".to_string(),
            Path::new("./fixtures/migrations"),
        )
        .with_extensions(vec!["uuid-ossp".to_string()])
        .build();
        let pool = tdb.get_pool().await;

        // uuid_generate_v4() only exists once uuid-ossp has been installed.
        let generated = sqlx::query_scalar::<_, String>("SELECT uuid_generate_v4()::text")
            .fetch_one(&pool)
            .await;
        assert!(generated.is_ok(), "uuid-ossp extension should be available");
    }

    #[test]
    fn test_with_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost/pureya");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname, Some("pureya".to_string()));
    }

    #[test]
    fn test_without_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname, None);
    }
}
323}