sqlx_db_tester/
postgres.rs

1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4    migrate::{MigrationSource, Migrator},
5    Connection, Executor, PgConnection, PgPool,
6};
7use std::{path::Path, thread};
8use tokio::runtime::Runtime;
9use uuid::Uuid;
10
/// A throw-away PostgreSQL database for tests.
///
/// Constructing one (see `TestPg::new`) creates a uniquely named database on the server and
/// runs the given migrations on it; dropping it terminates remaining sessions and drops the
/// database again.
#[derive(Debug)]
pub struct TestPg {
    /// Server part of the connection URL (scheme and authority, without a database name).
    pub server_url: String,
    /// Name of the generated test database on that server.
    pub dbname: String,
}
16
17impl TestPg {
18    pub fn new<S>(database_url: String, migrations: S) -> Self
19    where
20        S: MigrationSource<'static> + Send + Sync + 'static,
21    {
22        let simple = Uuid::new_v4().simple();
23        let (server_url, dbname) = parse_postgres_url(&database_url);
24        let dbname = match dbname {
25            Some(db_name) => format!("{}_test_{}", db_name, simple),
26            None => format!("test_{}", simple),
27        };
28        let dbname_cloned = dbname.clone();
29
30        let tdb = Self { server_url, dbname };
31
32        let url = tdb.url();
33
34        // create database dbname
35        thread::spawn(move || {
36            let rt = Runtime::new().unwrap();
37            rt.block_on(async move {
38                // use server url to create database
39                let mut conn = PgConnection::connect(&database_url)
40                    .await
41                    .unwrap_or_else(|_| panic!("Error while connecting to {}", database_url));
42                conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
43                    .await
44                    .unwrap();
45
46                // now connect to test database for migration
47                let mut conn = PgConnection::connect(&url)
48                    .await
49                    .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
50                let m = Migrator::new(migrations).await.unwrap();
51                m.run(&mut conn).await.unwrap();
52            });
53        })
54        .join()
55        .expect("failed to create database");
56
57        tdb
58    }
59
60    pub fn server_url(&self) -> String {
61        self.server_url.clone()
62    }
63
64    pub fn url(&self) -> String {
65        format!("{}/{}", self.server_url, self.dbname)
66    }
67
68    pub async fn get_pool(&self) -> PgPool {
69        let url = self.url();
70        PgPool::connect(&url)
71            .await
72            .unwrap_or_else(|_| panic!("Error while connecting to {}", url))
73    }
74
75    pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
76        let pool = self.get_pool().await;
77        let path = filename.canonicalize()?;
78        let mut conn = pool.acquire().await?;
79        let sql = format!(
80            "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
81            table,
82            fields.join(","),
83            path.display()
84        );
85        conn.execute(sql.as_str()).await?;
86        // copy csv
87
88        Ok(())
89    }
90
91    pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
92        let mut rdr = csv::Reader::from_reader(csv.as_bytes());
93        let headers = rdr.headers()?.iter().join(",");
94        let mut tx = self.get_pool().await.begin().await?;
95        for result in rdr.records() {
96            let record = result?;
97            let sql = format!(
98                "INSERT INTO {} ({}) VALUES ({})",
99                table,
100                headers,
101                record.iter().map(|v| format!("'{v}'")).join(",")
102            );
103            tx.execute(sql.as_str()).await?;
104        }
105        tx.commit().await?;
106        Ok(())
107    }
108}
109
110impl Drop for TestPg {
111    fn drop(&mut self) {
112        let server_url = &self.server_url;
113        let database_url = format!("{server_url}/postgres");
114        let dbname = self.dbname.clone();
115        thread::spawn(move || {
116            let rt = Runtime::new().unwrap();
117            rt.block_on(async move {
118                    let mut conn = PgConnection::connect(&database_url).await
119                    .unwrap_or_else(|_| panic!("Error while connecting to {}", database_url));
120                    // terminate existing connections
121                    sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
122                    .execute( &mut conn)
123                    .await
124                    .expect("Terminate all other connections");
125                    conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
126                        .await
127                        .expect("Error while querying the drop database");
128                });
129            })
130            .join()
131            .expect("failed to drop database");
132    }
133}
134
135impl Default for TestPg {
136    fn default() -> Self {
137        Self::new(
138            "postgres://postgres:postgres@localhost:5432".to_string(),
139            Path::new("./fixtures/migrations"),
140        )
141    }
142}
143
/// Splits a PostgreSQL connection URL into its server part and an optional database name.
///
/// The returned server URL keeps the scheme (`postgres://` or `postgresql://`; `postgres://`
/// is assumed when none is given) followed by the authority (`user:password@host:port`).
/// The database name is the first path segment, or `None` when the path is missing or empty.
/// Query parameters such as `?sslmode=...` are dropped, so they never leak into a database
/// name derived from the result.
///
/// ```text
/// postgres://user:pw@localhost:5432/app?sslmode=disable
///   -> ("postgres://user:pw@localhost:5432", Some("app"))
/// ```
fn parse_postgres_url(url: &str) -> (String, Option<String>) {
    const SCHEMES: [&str; 2] = ["postgres://", "postgresql://"];

    // `strip_prefix` removes the scheme exactly once (unlike `trim_start_matches`).
    let (scheme, rest) = SCHEMES
        .iter()
        .find_map(|&scheme| url.strip_prefix(scheme).map(|rest| (scheme, rest)))
        .unwrap_or(("postgres://", url));

    // Discard the query string before separating the authority from the path.
    let rest = rest.split_once('?').map_or(rest, |(before, _)| before);
    let (authority, path) = rest.split_once('/').unwrap_or((rest, ""));

    let dbname = path
        .split('/')
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_string);

    (format!("{scheme}{authority}"), dbname)
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::env;

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        // insert todo
        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();
        // get todo
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }

    // `load_csv` issues `COPY ... FROM '<file>'`, which the PostgreSQL *server* executes by
    // opening the path on its own filesystem. That is most likely why this test fails on
    // GitHub Actions: the file exists for the test process, but the server (presumably
    // running in a separate service container) cannot see it. It should pass against a
    // server running on the same host as the tests.
    #[ignore]
    #[tokio::test]
    async fn test_postgres_should_load_csv() -> Result<()> {
        let filename = env::current_dir()?.join("fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv("todos", &["title"], &filename).await?;
        let pool = tdb.get_pool().await;
        // get todo
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let csv = include_str!("../fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv_data("todos", csv).await?;
        let pool = tdb.get_pool().await;
        // get todo
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[test]
    fn test_with_dbname() {
        let url = "postgres://testuser:1@localhost/pureya";
        let (server_url, dbname) = parse_postgres_url(url);
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname, Some("pureya".to_string()));
    }

    #[test]
    fn test_without_dbname() {
        let url = "postgres://testuser:1@localhost";
        let (server_url, dbname) = parse_postgres_url(url);
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname, None);
    }
}
232}