xsqlx_db_tester/
lib.rs

1use sqlx::{migrate::Migrator, Connection, Executor, PgConnection, PgPool};
2use std::{path::Path, thread};
3use tokio::runtime::Runtime;
4use uuid::Uuid;
5
/// Handle to a throwaway Postgres database used by tests.
///
/// The database is created by [`TestDB::new`] and dropped again when the
/// handle goes out of scope. `Clone` is intentionally not derived: every copy
/// would try to drop the same database.
#[derive(Debug)]
pub struct TestDB {
    /// URL of the Postgres server, without a database name; the database
    /// name is appended to it to build the full connection URL.
    pub server_url: String,
    /// Name of the database owned by this handle (`testdb_<uuid>`).
    pub dbname: String,
}
10
11impl TestDB {
12    pub fn new(server_url: impl Into<String>, miration_path: impl Into<String>) -> TestDB {
13        let uuid = Uuid::new_v4();
14        let dbname = format!("testdb_{}", uuid);
15        let tdb = TestDB {
16            server_url: server_url.into(),
17            dbname: dbname.clone(),
18        };
19
20        let server_url = tdb.server_url();
21        let url = tdb.url();
22        let migration_path = miration_path.into();
23
24        // create database with dbname
25        thread::spawn(move || {
26            let rt = Runtime::new().unwrap();
27            rt.block_on(async move {
28                // use server url to create database
29                let mut conn = PgConnection::connect(&server_url).await.unwrap();
30                conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname.clone()).as_ref())
31                    .await
32                    .unwrap_or_else(|_| panic!("Failed to create database {}", dbname));
33
34                // create a new connection for migration
35                let mut conn = PgConnection::connect(&url).await.unwrap();
36                let m = Migrator::new(Path::new(&migration_path)).await.unwrap();
37                m.run(&mut conn)
38                    .await
39                    .unwrap_or_else(|_| panic!("Failed to migrate"));
40            });
41        })
42        .join()
43        .expect("Failed to execute database operation");
44
45        tdb
46    }
47
48    pub fn url(&self) -> String {
49        format!("{}/{}", self.server_url.clone(), self.dbname.clone())
50    }
51
52    pub fn server_url(&self) -> String {
53        self.server_url.clone()
54    }
55
56    pub async fn get_pool(&self) -> PgPool {
57        PgPool::connect(&self.url()).await.unwrap()
58    }
59}
60
61impl Drop for TestDB {
62    fn drop(&mut self) {
63        let url = self.server_url();
64        let dbname = self.dbname.clone();
65        thread::spawn(move || {
66            let rt = Runtime::new().unwrap();
67            rt.block_on(async move {
68                let mut conn = PgConnection::connect(&url).await.unwrap();
69                // terminate all other connections
70                sqlx::query(&format!(
71                    r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity
72                    WHERE pid <> pg_backend_pid() AND datname = '{}'"#,
73                    dbname
74                ))
75                .execute(&mut conn)
76                .await
77                .expect("Terminate all other connections");
78
79                // drop test database
80                conn.execute(format!(r#"DROP DATABASE "{}""#, dbname).as_str())
81                    .await
82                    .expect("Error while querying the drop database");
83            });
84        })
85        .join()
86        .expect("failed to drop database");
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    use sqlx::Row;

    /// Creates a test database, connects to it and runs a trivial query.
    /// The database is dropped again when `tdb` goes out of scope at the end
    /// of the test.
    #[tokio::test]
    async fn test_db_should_create_and_drop() {
        let tdb = TestDB::new(
            "postgres://postgres:postgres@localhost:5432",
            "./fitures/migrations",
        );

        // `conn` is declared after `tdb`, so it is closed before the
        // database itself is dropped.
        let mut conn = PgConnection::connect(&tdb.url()).await.unwrap();
        let value = sqlx::query("SELECT 1")
            .fetch_one(&mut conn)
            .await
            .expect("Failed to query")
            .get::<i32, _>(0);

        assert_eq!(value, 1);
    }
}