sqlx_db_tester_fb/
lib.rs

1use sqlx::migrate::Migrator;
2use sqlx::{Connection, Executor, PgConnection, PgPool};
3use std::path::Path;
4use std::thread;
5use tokio::runtime::Runtime;
6use uuid::Uuid;
7
/// A disposable Postgres database for tests.
///
/// [`TestDb::new`] creates a uniquely named database and runs migrations on it;
/// dropping the value terminates remaining connections to it and drops it.
pub struct TestDb {
    pub host: String,
    pub port: u16,
    pub user: String,
    pub password: String,
    pub dbname: String,
}

/// Hand-written instead of derived so the password never appears in `{:?}`
/// output (logs, `dbg!`, assertion failures). An empty password is still shown
/// as `""`, so it stays visible whether a password was configured at all.
impl std::fmt::Debug for TestDb {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let password = if self.password.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("TestDb")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            .field("password", &password)
            .field("dbname", &self.dbname)
            .finish()
    }
}
16
17impl TestDb {
18    pub fn new(
19        host: impl Into<String>,
20        port: u16,
21        user: impl Into<String>,
22        password: impl Into<String>,
23        migration_path: impl Into<String>,
24    ) -> Self {
25        let user = user.into();
26        let password = password.into();
27        let host = host.into();
28
29        let uuid = Uuid::new_v4();
30        let dbname = format!("test_{}", uuid);
31        let dbname_cloned = dbname.clone();
32
33        let tdb = Self {
34            host,
35            port,
36            user,
37            password,
38            dbname,
39        };
40
41        let server_url = tdb.server_url();
42
43        let url = tdb.url();
44        let migration_path = migration_path.into();
45        let path = migration_path;
46
47        // create database dbname
48        thread::spawn(move || {
49            let rt = Runtime::new().unwrap();
50            rt.block_on(async move {
51                // use another connection to create database
52                let mut conn = PgConnection::connect(&server_url).await.unwrap();
53                conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname_cloned).as_str())
54                    .await
55                    .unwrap();
56
57                // now connect to test database for migration
58                let mut conn = PgConnection::connect(&url).await.unwrap();
59                let m = Migrator::new(Path::new(&path)).await.unwrap();
60                m.run(&mut conn).await.unwrap();
61            });
62        })
63        .join()
64        .expect("Failed to create database.");
65
66        tdb
67    }
68
69    pub fn server_url(&self) -> String {
70        if self.password.is_empty() {
71            format!(
72                "postgres://{}@{}:{}/{}",
73                self.user, self.host, self.port, "postgres"
74            )
75        } else {
76            format!(
77                "postgres://{}:{}@{}:{}/{}",
78                self.user, self.password, self.host, self.port, "postgres"
79            )
80        }
81    }
82
83    pub fn url(&self) -> String {
84        if self.password.is_empty() {
85            format!(
86                "postgres://{}@{}:{}/{}",
87                self.user, self.host, self.port, self.dbname
88            )
89        } else {
90            format!(
91                "postgres://{}:{}@{}:{}/{}",
92                self.user, self.password, self.host, self.port, self.dbname
93            )
94        }
95    }
96
97    pub async fn get_pool(&self) -> PgPool {
98        sqlx::postgres::PgPoolOptions::new()
99            .max_connections(5)
100            .connect(&self.url())
101            .await
102            .unwrap()
103    }
104}
105
106impl Drop for TestDb {
107    fn drop(&mut self) {
108        let server_url = self.server_url();
109        let dbname = self.dbname.clone();
110        println!("server_url: {}", server_url);
111        println!("Dropping database: {}", dbname);
112
113        thread::spawn(move || {
114            let rt = Runtime::new().unwrap();
115            rt.block_on(async move {
116                let mut conn = sqlx::PgConnection::connect(&server_url).await.unwrap();
117
118                // terminate existing connections
119                sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{}'"#, dbname))
120                    .execute(&mut conn)
121                    .await
122                    .expect("Terminate all other connections");
123
124
125                conn.execute(format!(r#"DROP DATABASE "{}""#, dbname).as_str())
126                    .await
127                    .expect("Error while querying the drop database")
128            });
129        })
130        .join()
131        .expect("Failed to drop database.")
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check: a freshly migrated test database accepts a `todo`
    /// row and hands it back. Requires a reachable Postgres server with the
    /// credentials below and a `./migrations` directory.
    #[tokio::test]
    async fn test_db_should_create_and_drop() {
        let db = TestDb::new("127.0.0.1", 5432, "admin", "P@ssw0rd", "./migrations");
        let pool = db.get_pool().await;

        // Write a single todo.
        let insert_result = sqlx::query("INSERT INTO todo (title) VALUES ('test')")
            .execute(&pool)
            .await;
        insert_result.unwrap();

        // Read it back.
        let fetched: (i32, String) = sqlx::query_as("SELECT id, title FROM todo")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(fetched.0, 1);
        assert_eq!(fetched.1, "test");
    }
}