sqlx_db_tester/
postgres.rs

1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4    Connection, Executor, PgConnection, PgPool,
5    migrate::{MigrationSource, Migrator},
6};
7use std::{
8    fs,
9    path::{Path, PathBuf},
10    thread,
11};
12use tokio::runtime::Runtime;
13use uuid::Uuid;
14
/// A uniquely named PostgreSQL database created for a single test.
///
/// Constructing a `TestPg` (via [`TestPg::new`], [`TestPg::default`] or
/// [`TestPgBuilder::build`]) creates the database and runs the migrations against
/// it. Dropping it terminates other sessions connected to that database and drops it.
#[derive(Debug)]
pub struct TestPg {
    /// Scheme plus everything before the database path of the original URL
    /// (e.g. `postgres://user:pass@host:port`); `url()` appends `/{dbname}` to it.
    pub server_url: String,
    /// Name of the generated test database.
    pub dbname: String,
    // Extensions requested at construction time. Stored but not read anywhere in
    // this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    extensions: Vec<String>,
}
22
/// Builder for creating a [`TestPg`] instance with custom configuration.
///
/// Start with [`TestPgBuilder::new`], configure it with the `with_*` methods and
/// finish with [`TestPgBuilder::build`], which actually creates the database.
// The `MigrationSource` bound lives on the `impl` block that needs it rather than on
// the struct, so merely naming or storing the type does not require the bound.
#[must_use = "a TestPgBuilder does nothing until `build()` is called"]
pub struct TestPgBuilder<S> {
    /// Connection URL (optionally including a database name) used to reach the server.
    database_url: String,
    /// Migrations applied to the freshly created test database.
    migrations: S,
    /// Extensions created with `CREATE EXTENSION IF NOT EXISTS` before migrations run.
    extensions: Vec<String>,
    /// Optional directory of seed SQL files executed after migrations.
    seeds_path: Option<PathBuf>,
}
33
34impl<S> TestPgBuilder<S>
35where
36    S: MigrationSource<'static> + Send + Sync + 'static,
37{
38    /// Create a new TestPgBuilder with the given database URL and migrations.
39    pub fn new(database_url: String, migrations: S) -> Self {
40        Self {
41            database_url,
42            migrations,
43            extensions: vec![],
44            seeds_path: None,
45        }
46    }
47
48    /// Add a list of PostgreSQL extensions to be installed before running migrations.
49    ///
50    /// # Example
51    /// ```no_run
52    /// use sqlx_db_tester::TestPgBuilder;
53    /// use std::path::Path;
54    ///
55    /// let tdb = TestPgBuilder::new(
56    ///     "postgres://postgres:postgres@localhost:5432".to_string(),
57    ///     Path::new("./fixtures/migrations")
58    /// )
59    /// .with_extensions(vec!["uuid-ossp".to_string(), "postgis".to_string()])
60    /// .build();
61    /// ```
62    pub fn with_extensions(mut self, extensions: Vec<String>) -> Self {
63        self.extensions = extensions;
64        self
65    }
66
67    /// Add a path to a directory containing seed SQL files.
68    ///
69    /// Seed files should be named with the pattern `<timestamp>_<description>.sql`
70    /// (e.g., `20240101120000_initial_data.sql`). They will be executed in
71    /// timestamp order after migrations are complete.
72    ///
73    /// # Example
74    /// ```no_run
75    /// use sqlx_db_tester::TestPgBuilder;
76    /// use std::path::Path;
77    ///
78    /// let tdb = TestPgBuilder::new(
79    ///     "postgres://postgres:postgres@localhost:5432".to_string(),
80    ///     Path::new("./fixtures/migrations")
81    /// )
82    /// .with_seeds(Path::new("./fixtures/seeds"))
83    /// .build();
84    /// ```
85    pub fn with_seeds<P: AsRef<Path>>(mut self, seeds_path: P) -> Self {
86        self.seeds_path = Some(seeds_path.as_ref().to_path_buf());
87        self
88    }
89
90    /// Build and initialize the test database with the configured settings.
91    pub fn build(self) -> TestPg {
92        TestPg::new_with_config(
93            self.database_url,
94            self.migrations,
95            self.extensions,
96            self.seeds_path,
97        )
98    }
99}
100
101impl TestPg {
102    pub fn new<S>(database_url: String, migrations: S) -> Self
103    where
104        S: MigrationSource<'static> + Send + Sync + 'static,
105    {
106        Self::new_with_config(database_url, migrations, vec![], None)
107    }
108
109    #[allow(dead_code)]
110    fn new_with_extensions<S>(database_url: String, migrations: S, extensions: Vec<String>) -> Self
111    where
112        S: MigrationSource<'static> + Send + Sync + 'static,
113    {
114        Self::new_with_config(database_url, migrations, extensions, None)
115    }
116
117    fn new_with_config<S>(
118        database_url: String,
119        migrations: S,
120        extensions: Vec<String>,
121        seeds_path: Option<PathBuf>,
122    ) -> Self
123    where
124        S: MigrationSource<'static> + Send + Sync + 'static,
125    {
126        let simple = Uuid::new_v4().simple();
127        let (server_url, dbname) = parse_postgres_url(&database_url);
128        let dbname = match dbname {
129            Some(db_name) => format!("{db_name}_test_{simple}"),
130            None => format!("test_{simple}"),
131        };
132        let dbname_cloned = dbname.clone();
133        let extensions_cloned = extensions.clone();
134
135        let tdb = Self {
136            server_url,
137            dbname,
138            extensions,
139        };
140
141        let url = tdb.url();
142
143        // create database dbname
144        thread::spawn(move || {
145            let rt = Runtime::new().unwrap();
146            rt.block_on(async move {
147                // use server url to create database
148                let mut conn = PgConnection::connect(&database_url)
149                    .await
150                    .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
151                conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
152                    .await
153                    .unwrap();
154
155                // now connect to test database for migration
156                let mut conn = PgConnection::connect(&url)
157                    .await
158                    .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
159
160                // create extensions before running migrations
161                for ext in &extensions_cloned {
162                    conn.execute(format!(r#"CREATE EXTENSION IF NOT EXISTS "{ext}""#).as_str())
163                        .await
164                        .unwrap_or_else(|_| panic!("Error while creating extension {ext}"));
165                }
166
167                let m = Migrator::new(migrations).await.unwrap();
168                m.run(&mut conn).await.unwrap();
169
170                // run seed files if provided
171                if let Some(seeds_dir) = seeds_path {
172                    run_seeds(&mut conn, &seeds_dir).await.unwrap();
173                }
174            });
175        })
176        .join()
177        .expect("failed to create database");
178
179        tdb
180    }
181
182    pub fn server_url(&self) -> String {
183        self.server_url.clone()
184    }
185
186    pub fn url(&self) -> String {
187        format!("{}/{}", self.server_url, self.dbname)
188    }
189
190    pub async fn get_pool(&self) -> PgPool {
191        let url = self.url();
192        PgPool::connect(&url)
193            .await
194            .unwrap_or_else(|_| panic!("Error while connecting to {url}"))
195    }
196
197    pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
198        let pool = self.get_pool().await;
199        let path = filename.canonicalize()?;
200        let mut conn = pool.acquire().await?;
201        let sql = format!(
202            "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
203            table,
204            fields.join(","),
205            path.display()
206        );
207        conn.execute(sql.as_str()).await?;
208        // copy csv
209
210        Ok(())
211    }
212
213    pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
214        let mut rdr = csv::Reader::from_reader(csv.as_bytes());
215        let headers = rdr.headers()?.iter().join(",");
216        let mut tx = self.get_pool().await.begin().await?;
217        for result in rdr.records() {
218            let record = result?;
219            let sql = format!(
220                "INSERT INTO {} ({}) VALUES ({})",
221                table,
222                headers,
223                record.iter().map(|v| format!("'{v}'")).join(",")
224            );
225            tx.execute(sql.as_str()).await?;
226        }
227        tx.commit().await?;
228        Ok(())
229    }
230}
231
232impl Drop for TestPg {
233    fn drop(&mut self) {
234        let server_url = &self.server_url;
235        let database_url = format!("{server_url}/postgres");
236        let dbname = self.dbname.clone();
237        thread::spawn(move || {
238            let rt = Runtime::new().unwrap();
239            rt.block_on(async move {
240                    let mut conn = PgConnection::connect(&database_url).await
241                    .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
242                    // terminate existing connections
243                    sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
244                    .execute( &mut conn)
245                    .await
246                    .expect("Terminate all other connections");
247                    conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
248                        .await
249                        .expect("Error while querying the drop database");
250                });
251            })
252            .join()
253            .expect("failed to drop database");
254    }
255}
256
257impl Default for TestPg {
258    fn default() -> Self {
259        Self::new(
260            "postgres://postgres:postgres@localhost:5432".to_string(),
261            Path::new("./fixtures/migrations"),
262        )
263    }
264}
265
266/// Discovers and runs seed SQL files from a directory.
267///
268/// Seed files should follow the naming pattern: `<timestamp>_<description>.sql`
269/// They will be executed in timestamp order.
270async fn run_seeds(conn: &mut PgConnection, seeds_dir: &Path) -> Result<()> {
271    if !seeds_dir.exists() {
272        return Ok(());
273    }
274
275    let mut seed_files = Vec::new();
276
277    // read all .sql files from the seeds directory
278    for entry in fs::read_dir(seeds_dir)? {
279        let entry = entry?;
280        let path = entry.path();
281
282        if path.is_file()
283            && path.extension().is_some_and(|ext| ext == "sql")
284            && let Some(filename) = path.file_name().and_then(|n| n.to_str())
285        {
286            // extract timestamp from filename (before first underscore)
287            if let Some(timestamp) = filename.split('_').next() {
288                seed_files.push((timestamp.to_string(), path));
289            }
290        }
291    }
292
293    // sort by timestamp
294    seed_files.sort_by(|a, b| a.0.cmp(&b.0));
295
296    // execute each seed file
297    for (_timestamp, path) in seed_files {
298        let sql = fs::read_to_string(&path)?;
299        conn.execute(sql.as_str()).await?;
300    }
301
302    Ok(())
303}
304
/// Split a PostgreSQL connection URL into its server part and optional database name.
///
/// The server part keeps the scheme together with the authority
/// (`user:password@host:port`). Both `postgres://` and `postgresql://` are
/// accepted; a URL without a scheme is treated as `postgres://`.
///
/// The database name is the first path segment. It is `None` when the path is
/// absent or empty.
///
/// Any `?query` part is dropped. `TestPg` rebuilds URLs as `{server_url}/{dbname}`,
/// and a query left in either piece would end up inside the host or database name.
fn parse_postgres_url(url: &str) -> (String, Option<String>) {
    const SCHEMES: [&str; 2] = ["postgresql://", "postgres://"];

    let (scheme, rest) = SCHEMES
        .iter()
        .find_map(|&scheme| url.strip_prefix(scheme).map(|rest| (scheme, rest)))
        .unwrap_or(("postgres://", url));

    // Strip `?key=value&...`; it cannot survive being re-joined as `{server}/{db}`.
    let rest = rest.split_once('?').map_or(rest, |(before, _)| before);

    let (authority, path) = rest.split_once('/').unwrap_or((rest, ""));
    let server_url = format!("{scheme}{authority}");

    let dbname = path
        .split('/')
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_string);

    (server_url, dbname)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestPgBuilder;
    use crate::postgres::TestPg;
    use anyhow::Result;
    use std::env;

    /// Fetch the single `(id, title)` row the insert/CSV tests expect in `todos`.
    async fn first_todo(pool: &PgPool) -> (i32, String) {
        sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    /// Builder pointed at the local server and the fixture migrations.
    fn local_builder() -> TestPgBuilder<&'static Path> {
        TestPgBuilder::new(
            "postgres://postgres:postgres@localhost:5432".to_string(),
            Path::new("./fixtures/migrations"),
        )
    }

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;

        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();

        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }

    #[tokio::test]
    #[ignore = "github action postgres server can't be used for this test"]
    async fn test_postgres_should_load_csv() -> Result<()> {
        let csv_path = env::current_dir()?.join("fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv("todos", &["title"], &csv_path).await?;

        let (id, title) = first_todo(&tdb.get_pool().await).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let tdb = TestPg::default();
        tdb.load_csv_data("todos", include_str!("../fixtures/todos.csv"))
            .await?;

        let (id, title) = first_todo(&tdb.get_pool().await).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_with_extensions() {
        let tdb = local_builder()
            .with_extensions(vec!["uuid-ossp".to_string()])
            .build();
        let pool = tdb.get_pool().await;

        // uuid_generate_v4() comes from uuid-ossp, so a successful call proves it is installed.
        let generated = sqlx::query_scalar::<_, String>("SELECT uuid_generate_v4()::text")
            .fetch_one(&pool)
            .await;

        assert!(generated.is_ok(), "uuid-ossp extension should be available");
    }

    #[test]
    fn test_with_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost/pureya");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname.as_deref(), Some("pureya"));
    }

    #[test]
    fn test_without_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert!(dbname.is_none());
    }

    #[tokio::test]
    async fn test_postgres_with_seeds() {
        let tdb = local_builder()
            .with_seeds(Path::new("./fixtures/seeds"))
            .build();
        let pool = tdb.get_pool().await;

        // Seeds must have been applied in timestamp order.
        let rows = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos ORDER BY id")
            .fetch_all(&pool)
            .await
            .unwrap();
        let titles: Vec<String> = rows.into_iter().map(|(_, title)| title).collect();

        assert_eq!(
            titles,
            ["First seeded todo", "Second seeded todo", "Third seeded todo"]
        );
    }
}