1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use std::{path::Path, thread};

use sqlx::{migrate::Migrator, Connection, Executor, PgConnection, PgPool};
use tokio::runtime::Runtime;
use uuid::Uuid;

/// A throw-away PostgreSQL database for tests.
///
/// [`TestDB::new`] creates a database with a random `test-<uuid>` name on the
/// given server and runs the migrations from a directory against it; dropping
/// the value drops the database again.
pub struct TestDB {
    /// Name of the temporary database (`test-<uuid>`).
    dbname: String,
    /// Host of the PostgreSQL server.
    host: String,
    /// Port of the PostgreSQL server.
    port: u16,
    /// User name used to connect to the server (stored unencoded).
    username: String,
    /// Password for `username` (stored unencoded); an empty string means the
    /// URL carries no password at all.
    password: String,
}

impl TestDB {
    /// Creates a fresh database on the given server and applies migrations to it.
    ///
    /// The database is named `test-<random uuid>`. The work runs on a dedicated
    /// OS thread with its own Tokio runtime, which lets this synchronous
    /// constructor be called from inside an async test (calling `block_on` on a
    /// runtime from within another runtime would panic).
    ///
    /// # Panics
    ///
    /// Panics if the server is unreachable, the database cannot be created, or
    /// the migrations in `migration_path` cannot be loaded or applied.
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        migration_path: impl Into<String>,
    ) -> Self {
        let migration_path = migration_path.into();

        // Random name so that several test databases can coexist on one server.
        let config = TestDB {
            dbname: format!("test-{}", Uuid::new_v4()),
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
        };
        let server_url = config.server_url();
        let url = config.url();
        let dbname = config.dbname.clone();

        thread::spawn(move || {
            let rt = Runtime::new().unwrap();
            rt.block_on(async {
                // Connect to the server itself (no database path) to create ours.
                let mut conn = PgConnection::connect(&server_url).await.unwrap();

                // `r#"..."#` is a raw string literal, so the inner double quotes need
                // no escaping (if the text itself contained `"#`, more `#`s could be
                // used, e.g. `r##"..."##`, as long as both ends match). The
                // identifier is double-quoted because `test-<uuid>` contains `-`.
                conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname).as_str())
                    .await
                    .unwrap();

                // Apply the migrations to the newly created database.
                let mut conn = PgConnection::connect(&url).await.unwrap();
                let migrator = Migrator::new(Path::new(&migration_path)).await.unwrap();
                migrator.run(&mut conn).await.unwrap();
            })
        })
        .join()
        .expect("failed to create database");

        config
    }

    /// Returns the URL of the PostgreSQL server itself, without a database path:
    /// `postgres://user[:password]@host:port`.
    ///
    /// The username and password are percent-encoded so that reserved characters
    /// such as `@`, `:`, `/`, `#` or `%` in the credentials cannot corrupt the
    /// URL structure. An empty password is omitted from the URL.
    pub fn server_url(&self) -> String {
        let username = Self::encode_userinfo(&self.username);
        if self.password.is_empty() {
            format!("postgres://{}@{}:{}", username, self.host, self.port)
        } else {
            format!(
                "postgres://{}:{}@{}:{}",
                username,
                Self::encode_userinfo(&self.password),
                self.host,
                self.port
            )
        }
    }

    /// Returns the connection URL of the temporary database
    /// (the [`server_url`](Self::server_url) followed by `/<dbname>`).
    pub fn url(&self) -> String {
        format!("{}/{}", self.server_url(), self.dbname)
    }

    /// Opens a connection pool (at most 5 connections) to the temporary database.
    ///
    /// # Panics
    ///
    /// Panics if the pool cannot connect to the database.
    pub async fn get_pool(&self) -> PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&self.url())
            .await
            .unwrap()
    }

    /// Percent-encodes `raw` for the userinfo part of a URL.
    ///
    /// Every byte except the RFC 3986 "unreserved" characters
    /// (`A-Z a-z 0-9 - . _ ~`) is written as `%XX`, so values made only of
    /// unreserved characters come back unchanged.
    fn encode_userinfo(raw: &str) -> String {
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                    encoded.push(char::from(byte));
                }
                _ => encoded.push_str(&format!("%{:02X}", byte)),
            }
        }
        encoded
    }
}

impl Drop for TestDB {
    /// Removes the temporary database.
    ///
    /// Other sessions still connected to it (for example idle connections of a
    /// pool returned by [`TestDB::get_pool`]) are terminated first, because
    /// PostgreSQL refuses to drop a database that others are connected to. The
    /// work runs on a separate thread with its own Tokio runtime so that dropping
    /// a `TestDB` inside an async test does not block on a nested runtime.
    ///
    /// # Panics
    ///
    /// Panics if cleanup fails, unless the current thread is already panicking
    /// (e.g. a failed assertion, or a failure inside `TestDB::new`). Panicking
    /// again while unwinding would abort the process and hide the original error,
    /// so in that case the failure is only reported on stderr.
    fn drop(&mut self) {
        let server_url = self.server_url();
        let db_name = self.dbname.clone();

        let result = thread::spawn(move || {
            let rt = Runtime::new().unwrap();
            rt.block_on(async {
                let mut conn = PgConnection::connect(&server_url).await.unwrap();

                // Terminate every other session on the target database so the DROP
                // below can succeed. The name comes from the `test-<uuid>` value
                // generated in `TestDB::new`, so interpolating it into SQL is safe.
                sqlx::query(&format!(
                    r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{}'"#,
                    db_name
                ))
                .execute(&mut conn)
                .await
                .expect("Terminate all other connections");

                // IF EXISTS: `TestDB::new` may have failed before CREATE DATABASE ran.
                conn.execute(format!(r#"DROP DATABASE IF EXISTS "{}""#, db_name).as_str())
                    .await
                    .expect("Error while querying the drop database");
            })
        })
        .join();

        if result.is_err() && thread::panicking() {
            // The cleanup thread has already printed its own panic message.
            eprintln!("fail to drop database \"{}\"", self.dbname);
            return;
        }
        result.expect("fail to drop database");
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a migrated throw-away database, round-trips one row through the
    /// `todos` table, and then lets the `TestDB` go out of scope so it is dropped.
    ///
    /// Requires a PostgreSQL server on `localhost:5432` reachable as user
    /// `postgres` with password `123456`.
    #[tokio::test]
    async fn test_db_should_create_and_drop() {
        // The migrations path is relative to Cargo.toml (the package root).
        let test_db = TestDB::new("localhost", 5432, "postgres", "123456", "./migrations");
        let pool = test_db.get_pool().await;

        let insert = sqlx::query("INSERT INTO todos (title) VALUES ('test')");
        insert.execute(&pool).await.unwrap();

        let row: (i32, String) = sqlx::query_as("SELECT id, title from todos")
            .fetch_one(&pool)
            .await
            .unwrap();

        // A fresh database, so the single inserted row gets id 1.
        assert_eq!(row, (1, "test".to_owned()));
    }
}