// sqlx_db_ssk_tester/lib.rs

1use sqlx::{migrate::Migrator, Connection, Executor, PgConnection, PgPool};
2use std::{path::Path, thread};
3use tokio::runtime::Runtime;
4use uuid::Uuid;
5
/// Connection parameters plus the name of a throwaway Postgres database for tests.
///
/// Instances are normally created with [`TestDb::new`], which also creates the
/// database on the server; dropping the instance removes that database again.
pub struct TestDb {
    /// Host name or address of the Postgres server.
    pub host: String,
    /// TCP port the Postgres server listens on.
    pub port: u16,
    /// Role name placed in the connection URL.
    pub user: String,
    /// Password for `user`; when empty, no password is placed in the connection URL.
    pub password: String,
    /// Name of this instance's database (`test_<uuid>` when created via `new`).
    pub dbname: String,
}
13
14impl TestDb {
15    pub fn new(
16        host: impl Into<String>,
17        port: u16,
18        user: impl Into<String>,
19        password: impl Into<String>,
20        migration_path: impl Into<String>,
21    ) -> Self {
22        let host = host.into();
23        let user = user.into();
24        let password = password.into();
25
26        let uuid = Uuid::new_v4();
27        let dbname = format!("test_{}", uuid);
28        let dbname_clone = dbname.clone();
29        let tdb = Self {
30            host,
31            port,
32            user,
33            password,
34            dbname,
35        };
36
37        let server_url = tdb.server_url();
38
39        let url = tdb.url();
40        let migration_path = migration_path.into();
41
42        thread::spawn(move || {
43            let rt = Runtime::new().unwrap();
44            rt.block_on(async move {
45                let mut conn = PgConnection::connect(&server_url).await.unwrap();
46                conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname_clone).as_str())
47                    .await
48                    .unwrap();
49
50                let mut conn = PgConnection::connect(&url).await.unwrap();
51                let m = Migrator::new(Path::new(&migration_path)).await.unwrap();
52                m.run(&mut conn).await.unwrap();
53            });
54        })
55        .join()
56        .expect("Failed to create test database");
57
58        tdb
59    }
60
61    pub fn server_url(&self) -> String {
62        if self.password.is_empty() {
63            format!("postgres://{}@{}:{}", self.user, self.host, self.port)
64        } else {
65            format!(
66                "postgres://{}:{}@{}:{}",
67                self.user, self.password, self.host, self.port
68            )
69        }
70    }
71
72    pub fn url(&self) -> String {
73        format!("{}/{}", self.server_url(), self.dbname)
74    }
75
76    pub async fn get_pool(&self) -> PgPool {
77        sqlx::postgres::PgPoolOptions::new()
78            .max_connections(5)
79            .connect(&self.url())
80            .await
81            .unwrap()
82    }
83}
84
85impl Drop for TestDb {
86    fn drop(&mut self) {
87        let server_url = self.server_url();
88        let db_name = self.dbname.clone();
89        thread::spawn(move || {
90            let  rt = Runtime::new().unwrap();
91            rt.block_on(async move {
92                let mut conn = PgConnection::connect(&server_url).await.unwrap();
93                // terminate existing connections
94                sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE  pid <> pg_backend_pid() and datname = '{}'"#,db_name))
95                    .execute(&mut conn)
96                    .await
97                    .expect("Terminating connections failed");
98                conn.execute(format!(r#"DROP DATABASE "{}""#, db_name).as_str())
99                    .await
100                    .expect("Error while dropping database");
101            });
102        })
103        .join()
104        .expect("Failed to join thread");
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    // Integration test: needs a Postgres server at localhost:15432 accepting the
    // credentials below, and a `./migrations` directory that (presumably) creates
    // a `todos` table with an int4 `id` and a text `title` column.
    #[tokio::test]
    async fn test_db_should_create_and_drop() {
        let test_db = TestDb::new("localhost", 15432, "postgres", "7cOPpA7dnc", "./migrations");
        let pool = test_db.get_pool().await;
        println!("Pool: {:?}", pool);

        // Each TestDb gets its own fresh database, so this is the only row and
        // the test expects it to receive id 1.
        let insert_sql = "INSERT INTO todos (title) VALUES ('test')";
        sqlx::query(insert_sql).execute(&pool).await.unwrap();

        let select_sql = "SELECT id, title from todos";
        let row = sqlx::query_as::<_, (i32, String)>(select_sql)
            .fetch_one(&pool)
            .await
            .unwrap();
        let (todo_id, todo_title) = row;

        assert_eq!(todo_id, 1);
        assert_eq!(todo_title, "test");
    }
}
132}