// sqlx_database_tester_wsf/lib.rs

use std::{fmt, path::Path, thread};

use sqlx::{migrate::Migrator, Connection, Executor, PgConnection, PgPool};
use tokio::runtime::Runtime;
use uuid::Uuid;

/// A throw-away PostgreSQL database for tests.
///
/// `TestDB::new` creates a database named `test-<uuid>` and runs the
/// migrations against it; dropping the value removes the database again.
pub struct TestDB {
    /// Name of the generated database (`test-<uuid v4>`).
    pub dbname: String,
    host: String,
    port: u16,
    username: String,
    password: String,
}

/// Hand-written instead of derived so the password never shows up in logs,
/// `dbg!` output or assertion failure messages.
impl fmt::Debug for TestDB {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TestDB")
            .field("dbname", &self.dbname)
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}

15impl TestDB {
16    pub fn new(
17        host: impl Into<String>,
18        port: u16,
19        username: impl Into<String>,
20        password: impl Into<String>,
21        migration_path: impl Into<String>,
22    ) -> Self {
23        let host = host.into();
24        let username = username.into();
25        let password = password.into();
26        let migration_path = migration_path.into();
27
28        // create random database
29        let uuid = Uuid::new_v4();
30        let dbname = format!("test-{}", uuid);
31        let dbname_cloned = dbname.clone();
32        let config = TestDB {
33            dbname,
34            host,
35            port,
36            username,
37            password,
38        };
39        let server_url = config.server_url();
40        let url = config.url();
41
42        thread::spawn(move || {
43            let rt = Runtime::new().unwrap();
44            rt.block_on(async {
45                let mut conn = PgConnection::connect(&server_url).await.unwrap();
46
47                // r# # 创建原始字符串字面量 不需要进行转义,如果字符串包含#字符, 可以使用 r#### #### 方式,只要保证首尾#数量相同
48                conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname_cloned).as_str())
49                    .await
50                    .unwrap();
51
52                // execute migration on new database
53                let mut conn = PgConnection::connect(&url).await.unwrap();
54
55                let migrator = Migrator::new(Path::new(&migration_path)).await.unwrap();
56                migrator.run(&mut conn).await.unwrap();
57            })
58        })
59        .join()
60        .expect("failed to create database");
61
62        config
63    }
64
65    pub fn server_url(&self) -> String {
66        if self.password.is_empty() {
67            format!("postgres://{}@{}:{}", self.username, self.host, self.port)
68        } else {
69            format!(
70                "postgres://{}:{}@{}:{}",
71                self.username, self.password, self.host, self.port
72            )
73        }
74    }
75
76    pub fn url(&self) -> String {
77        format!("{}/{}", self.server_url(), self.dbname)
78    }
79
80    pub async fn get_pool(&self) -> PgPool {
81        sqlx::postgres::PgPoolOptions::new()
82            .max_connections(5)
83            .connect(&self.url())
84            .await
85            .unwrap()
86    }
87}
88
89impl Drop for TestDB {
90    fn drop(&mut self) {
91        let server_url = self.server_url();
92        let db_name = self.dbname.clone();
93
94        thread::spawn(move || {
95            let rt = Runtime::new().unwrap();
96        rt.block_on(async {
97          let mut conn = PgConnection::connect(&server_url).await.unwrap();
98
99          // terminate existing connections
100          sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{}'"#, db_name))
101          .execute(&mut conn)
102          .await
103          .expect("Terminate all other connections");
104
105          conn.execute(format!(r#"DROP DATABASE "{}""#, db_name).as_str()).await.expect("Error while querying the drop database");
106        })}).join().expect("fail to drop database");
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check against a PostgreSQL server on localhost:5432 with the
    /// credentials below. The migrations must create a `todos` table with
    /// `id` and `title` columns. The database is dropped when `tdb` goes out
    /// of scope at the end of the test.
    #[tokio::test]
    async fn test_db_should_create_and_drop() {
        // The migrations path is relative to the crate root (the directory
        // that contains Cargo.toml).
        let tdb = TestDB::new("localhost", 5432, "postgres", "123456", "./migrations");
        let pool = tdb.get_pool().await;

        let insert = sqlx::query("INSERT INTO todos (title) VALUES ('test')");
        insert.execute(&pool).await.unwrap();

        let row: (i32, String) = sqlx::query_as("SELECT id, title from todos")
            .fetch_one(&pool)
            .await
            .unwrap();
        let (id, title) = row;

        assert_eq!(1, id);
        assert_eq!("test", title);
    }
}
132}