// sqlx_db_tester/postgres.rs

1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4    Connection, Executor, PgConnection, PgPool,
5    migrate::{MigrationSource, Migrator},
6};
7use std::{path::Path, thread};
8use tokio::runtime::Runtime;
9use uuid::Uuid;
10
/// A throw-away PostgreSQL database intended for tests.
///
/// The database is created by [`TestPg::new`] and removed again when the value is dropped.
#[derive(Debug)]
pub struct TestPg {
    /// Server part of the connection URL (scheme, credentials, host, port) without a
    /// database name, e.g. `postgres://user:pass@localhost:5432`.
    pub server_url: String,
    /// Name of the uniquely named test database created on that server.
    pub dbname: String,
}

17impl TestPg {
18    pub fn new<S>(database_url: String, migrations: S) -> Self
19    where
20        S: MigrationSource<'static> + Send + Sync + 'static,
21    {
22        let simple = Uuid::new_v4().simple();
23        let (server_url, dbname) = parse_postgres_url(&database_url);
24        let dbname = match dbname {
25            Some(db_name) => format!("{db_name}_test_{simple}"),
26            None => format!("test_{simple}"),
27        };
28        let dbname_cloned = dbname.clone();
29
30        let tdb = Self { server_url, dbname };
31
32        let url = tdb.url();
33
34        // create database dbname
35        thread::spawn(move || {
36            let rt = Runtime::new().unwrap();
37            rt.block_on(async move {
38                // use server url to create database
39                let mut conn = PgConnection::connect(&database_url)
40                    .await
41                    .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
42                conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
43                    .await
44                    .unwrap();
45
46                // now connect to test database for migration
47                let mut conn = PgConnection::connect(&url)
48                    .await
49                    .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
50                let m = Migrator::new(migrations).await.unwrap();
51                m.run(&mut conn).await.unwrap();
52            });
53        })
54        .join()
55        .expect("failed to create database");
56
57        tdb
58    }
59
60    pub fn server_url(&self) -> String {
61        self.server_url.clone()
62    }
63
64    pub fn url(&self) -> String {
65        format!("{}/{}", self.server_url, self.dbname)
66    }
67
68    pub async fn get_pool(&self) -> PgPool {
69        let url = self.url();
70        PgPool::connect(&url)
71            .await
72            .unwrap_or_else(|_| panic!("Error while connecting to {url}"))
73    }
74
75    pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
76        let pool = self.get_pool().await;
77        let path = filename.canonicalize()?;
78        let mut conn = pool.acquire().await?;
79        let sql = format!(
80            "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
81            table,
82            fields.join(","),
83            path.display()
84        );
85        conn.execute(sql.as_str()).await?;
86        // copy csv
87
88        Ok(())
89    }
90
91    pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
92        let mut rdr = csv::Reader::from_reader(csv.as_bytes());
93        let headers = rdr.headers()?.iter().join(",");
94        let mut tx = self.get_pool().await.begin().await?;
95        for result in rdr.records() {
96            let record = result?;
97            let sql = format!(
98                "INSERT INTO {} ({}) VALUES ({})",
99                table,
100                headers,
101                record.iter().map(|v| format!("'{v}'")).join(",")
102            );
103            tx.execute(sql.as_str()).await?;
104        }
105        tx.commit().await?;
106        Ok(())
107    }
108}
109
110impl Drop for TestPg {
111    fn drop(&mut self) {
112        let server_url = &self.server_url;
113        let database_url = format!("{server_url}/postgres");
114        let dbname = self.dbname.clone();
115        thread::spawn(move || {
116            let rt = Runtime::new().unwrap();
117            rt.block_on(async move {
118                    let mut conn = PgConnection::connect(&database_url).await
119                    .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
120                    // terminate existing connections
121                    sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
122                    .execute( &mut conn)
123                    .await
124                    .expect("Terminate all other connections");
125                    conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
126                        .await
127                        .expect("Error while querying the drop database");
128                });
129            })
130            .join()
131            .expect("failed to drop database");
132    }
133}
134
135impl Default for TestPg {
136    fn default() -> Self {
137        Self::new(
138            "postgres://postgres:postgres@localhost:5432".to_string(),
139            Path::new("./fixtures/migrations"),
140        )
141    }
142}
143
/// Splits a PostgreSQL connection URL into its server part and an optional database name.
///
/// Both the `postgres://` and `postgresql://` schemes are recognised and preserved; input
/// without either scheme is treated as `postgres://`. Any `?query` string is discarded. For
/// example, `postgres://u@h:5432/app?sslmode=disable` yields
/// `("postgres://u@h:5432", Some("app"))`. An empty path (`postgres://u@h/`) gives `None`.
fn parse_postgres_url(url: &str) -> (String, Option<String>) {
    const SCHEMES: [&str; 2] = ["postgres://", "postgresql://"];
    let (scheme, rest) = SCHEMES
        .iter()
        .find_map(|&scheme| url.strip_prefix(scheme).map(|rest| (scheme, rest)))
        .unwrap_or(("postgres://", url));

    // Query parameters belong to neither the host nor the database name; keeping them
    // would corrupt the generated test database name and URL.
    let without_query = rest.split('?').next().unwrap_or_default();

    let mut segments = without_query.split('/');
    let host = segments.next().unwrap_or_default();
    let dbname = segments
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_string);

    (format!("{scheme}{host}"), dbname)
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::env;

    /// Reads the first row of the `todos` table as `(id, title)`, panicking on failure.
    async fn first_todo(pool: &PgPool) -> (i32, String) {
        sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;

        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();

        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv() -> Result<()> {
        let csv_path = env::current_dir()?.join("fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv("todos", &["title"], &csv_path).await?;
        let pool = tdb.get_pool().await;

        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let contents = include_str!("../fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv_data("todos", contents).await?;
        let pool = tdb.get_pool().await;

        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[test]
    fn test_with_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost/pureya");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname.as_deref(), Some("pureya"));
    }

    #[test]
    fn test_without_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert!(dbname.is_none());
    }
}