sqlx_postgres_tester/
lib.rs

1use std::{path::Path, thread};
2
3use sqlx::{migrate::Migrator, Connection, Executor, PgConnection, PgPool};
4use tokio::runtime::Runtime;
5use uuid::Uuid;
6
/// A throw-away Postgres database for tests.
///
/// Holds the server URL used to reach Postgres and the name of the database created on it.
pub struct TestPg {
    /// URL of the Postgres server, without a database name.
    server_url: String,
    /// Name of the database created for this test instance.
    pub dbname: String,
}
11
12impl Default for TestPg {
13    fn default() -> Self {
14        let server_url = "postgres://postgres:postgres@localhost:5432";
15        let migration_path = "./migrations";
16
17        Self::new(server_url, migration_path)
18    }
19}
20
21impl TestPg {
22    pub fn new(server_url: impl Into<String>, migration_path: impl Into<String>) -> Self {
23        let server_url = server_url.into();
24        let uuid = Uuid::new_v4();
25        let dbname = format!("test_{}", uuid);
26
27        let tdb = Self {
28            server_url,
29            dbname: dbname.clone(),
30        };
31
32        let server_url = tdb.server_url();
33        let url = tdb.url();
34        let migration_path = migration_path.into();
35
36        // create database dbname
37        thread::spawn(move || {
38            let rt = Runtime::new().unwrap();
39            rt.block_on(async move {
40                // use server url to create database
41                let mut conn = PgConnection::connect(&server_url).await.unwrap();
42                conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname).as_str())
43                    .await
44                    .expect("Error while create database");
45
46                let mut conn = PgConnection::connect(&url).await.unwrap();
47                let m = Migrator::new(Path::new(&migration_path)).await.unwrap();
48                m.run(&mut conn).await.unwrap();
49            });
50        })
51        .join()
52        .expect("failed to create database");
53
54        tdb
55    }
56
57    pub fn url(&self) -> String {
58        format!("{}/{}", self.server_url(), self.dbname)
59    }
60
61    pub fn server_url(&self) -> String {
62        self.server_url.clone()
63    }
64
65    pub async fn get_pool(&self) -> PgPool {
66        sqlx::postgres::PgPoolOptions::new()
67            .max_connections(5)
68            .connect(&self.url())
69            .await
70            .unwrap()
71    }
72}
73
74impl Drop for TestPg {
75    fn drop(&mut self) {
76        let server_url = self.server_url();
77        let dbname = self.dbname.clone();
78
79        thread::spawn(move || {
80            let rt = Runtime::new().unwrap();
81            rt.block_on(async move {
82                let mut conn = sqlx::PgConnection::connect(&server_url).await.unwrap();
83                // terminate existing connections
84                sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{}'"#, dbname))
85                    .execute(&mut conn)
86                    .await
87                    .expect("Terminate all other connections");
88
89                conn
90                    .execute(format!(r#"DROP DATABASE "{}""#, dbname).as_str())
91                    .await
92                    .expect("Error while drop database");
93            });
94        })
95        .join()
96        .expect("failed to Drop database");
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts one row into a freshly created database and reads it back.
    ///
    /// Needs a reachable Postgres server at the `TestPg::default()` URL and a `todos`
    /// table (with an auto-assigned `id`) created by the `./migrations` scripts.
    #[tokio::test]
    async fn test_db_create_and_drop_should_work() {
        let test_db = TestPg::default();
        let pool = test_db.get_pool().await;

        let insert = sqlx::query("INSERT INTO todos (title) VALUES('todo1')");
        insert.execute(&pool).await.unwrap();

        let row: (i32, String) = sqlx::query_as("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(row.0, 1);
        assert_eq!(row.1, "todo1");
    }
}