sqlx_mock/
lib.rs

1use sqlx::{
2    migrate::{MigrationSource, Migrator},
3    Connection, Executor, PgConnection, PgPool,
4};
5use std::{path::Path, thread};
6use tokio::runtime::Runtime;
7use uuid::Uuid;
8
/// Handle to a throwaway PostgreSQL database used by tests.
///
/// [`TestPostgres::new`] creates a database named `test_<uuid>` on the given
/// server and runs the supplied migrations against it. Dropping the handle
/// terminates the remaining connections to that database and drops it.
///
/// `Clone` is intentionally not derived: each clone would try to drop the same
/// database again when it goes out of scope.
#[derive(Debug)]
pub struct TestPostgres {
    /// URL of the PostgreSQL server, without a database name appended.
    pub server_url: String,
    /// Name of the database owned by this handle.
    pub dbname: String,
}
13
14impl TestPostgres {
15    pub fn new<S>(server_url: String, source: S) -> Self
16    where
17        S: MigrationSource<'static> + Send + Sync + 'static,
18    {
19        // create a random db name
20        let dbname = format!("test_{}", Uuid::new_v4());
21        let dbname_cloned = dbname.clone();
22        let tpg = Self { server_url, dbname };
23        // config.db.dbname = database_name.clone();
24        let server_url = tpg.server_url();
25        let url = tpg.url();
26
27        thread::spawn(move || {
28            let rt = Runtime::new().unwrap();
29            // create database
30            // server_url is not append a database name
31            rt.block_on(async move {
32                let mut conn = PgConnection::connect(&server_url).await.unwrap();
33                conn.execute(format!(r#"CREATE Database "{}""#, dbname_cloned).as_str())
34                    .await
35                    .unwrap();
36                let mut conn = PgConnection::connect(&url).await.unwrap();
37                // migrate database
38                let m = Migrator::new(source).await.unwrap();
39                m.run(&mut conn).await.unwrap();
40            });
41        })
42        .join()
43        .expect("Thread panicked");
44
45        tpg
46    }
47
48    pub fn server_url(&self) -> String {
49        self.server_url.clone()
50    }
51
52    pub fn url(&self) -> String {
53        format!("{}/{}", self.server_url, self.dbname)
54    }
55
56    pub async fn get_pool(&self) -> PgPool {
57        PgPool::connect(&self.url()).await.unwrap()
58    }
59}
60
61impl Drop for TestPostgres {
62    fn drop(&mut self) {
63        // server_url is not append a database name
64        let url = self.server_url();
65        let dbname = self.dbname.clone();
66        println!("url is: {}", url);
67        println!("database_name is: {}", dbname);
68        thread::spawn(move || {
69            let rt = Runtime::new().unwrap();
70            // drop database
71            rt.block_on(async move {
72                let mut conn = PgConnection::connect(&url).await.unwrap();
73                // terminate all connections
74                sqlx::query(&format!(
75                    r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity
76                 WHERE pid <> pg_backend_pid() AND datname = '{}'"#,
77                    dbname
78                ))
79                .execute(&mut conn)
80                .await
81                .expect("Terminate all other connections");
82
83                // drop database
84                conn.execute(format!(r#"DROP Database "{}""#, dbname).as_str())
85                    .await
86                    .expect("Error while querying the drop database");
87            });
88        })
89        .join()
90        .expect("failed to drop database");
91    }
92}
93
94impl Default for TestPostgres {
95    fn default() -> Self {
96        Self::new(
97            "postgres://postgres:postgres@localhost:5432".to_string(),
98            Path::new("./migrations"),
99        )
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one row through a database built by `TestPostgres::default()`.
    ///
    /// NOTE(review): relies on the migrations in `./migrations` defining a
    /// `todos` table with `id` and `title` columns — confirm against that directory.
    #[tokio::test]
    async fn test_postgres_can_create_and_drop() {
        let pg = TestPostgres::default();
        let pool = pg.get_pool().await;

        // Write a single row into the new database.
        sqlx::query("insert into todos (title) values ('test')")
            .execute(&pool)
            .await
            .unwrap();

        // Read it back; the target tuple type is taken from the annotation.
        let (row_id, row_title): (i32, String) = sqlx::query_as("select id, title from todos")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(row_id, 1);
        assert_eq!(row_title, "test");
    }
}
123}