// switchgear_testing/db.rs

1use std::path::Path;
2use std::thread;
3
4pub struct TestMysqlDatabase {
5    username: String,
6    db_name: String,
7    connection_url: String,
8    addr: String,
9}
10
11impl TestMysqlDatabase {
12    pub fn new(
13        username: &str,
14        db_name: &str,
15        addr: &str,
16        ssl: bool,
17        ssl_ca: Option<&Path>,
18    ) -> Self {
19        let addr_c = addr.to_string();
20        let username_c = username.to_string();
21        let db_name_c = db_name.to_string();
22        let _ = thread::spawn(move || {
23            let rt = match tokio::runtime::Runtime::new() {
24                Ok(rt) => rt,
25                Err(_) => return,
26            };
27
28            rt.block_on(async {
29                use sqlx::mysql::MySqlPoolOptions;
30
31                let pool = match MySqlPoolOptions::new()
32                    .connect(&format!("mysql://{username_c}:mysql@{addr_c}/mysql"))
33                    .await
34                {
35                    Ok(pool) => pool,
36                    Err(_) => return,
37                };
38
39                let _ = sqlx::query(&format!("CREATE DATABASE {db_name_c}"))
40                    .execute(&pool)
41                    .await;
42            });
43        })
44        .join();
45
46        let ssl = if ssl { "?ssl-mode=VERIFY_IDENTITY" } else { "" };
47
48        let ssl_ca = match (!ssl.is_empty(), ssl_ca) {
49            (true, Some(ssl_ca)) => format!("&ssl-ca={}", ssl_ca.to_string_lossy()),
50            (_, _) => "".to_string(),
51        };
52
53        let connection_url = format!("mysql://{username}:mysql@{addr}/{db_name}{ssl}{ssl_ca}");
54        Self {
55            username: username.to_string(),
56            db_name: db_name.to_string(),
57            connection_url,
58            addr: addr.to_string(),
59        }
60    }
61
62    pub fn connection_url(&self) -> &str {
63        &self.connection_url
64    }
65
66    pub fn database_name(&self) -> &str {
67        &self.db_name
68    }
69
70    pub fn address(&self) -> &str {
71        &self.addr
72    }
73}
74
75impl Drop for TestMysqlDatabase {
76    fn drop(&mut self) {
77        let username = self.username.clone();
78        let db_name = self.db_name.clone();
79        let addr = self.addr.clone();
80        let _ = thread::spawn(move || {
81            let rt = match tokio::runtime::Runtime::new() {
82                Ok(rt) => rt,
83                Err(_) => return,
84            };
85
86            rt.block_on(async {
87                use sqlx::mysql::MySqlPoolOptions;
88
89                let pool = match MySqlPoolOptions::new()
90                    .connect(&format!("mysql://{username}:mysql@{addr}/mysql"))
91                    .await
92                {
93                    Ok(pool) => pool,
94                    Err(_) => return,
95                };
96
97                let _ = sqlx::query(&format!("DROP DATABASE IF EXISTS {db_name}"))
98                    .execute(&pool)
99                    .await;
100            });
101        })
102        .join();
103    }
104}
105
106pub struct TestPostgresDatabase {
107    username: String,
108    db_name: String,
109    connection_url: String,
110    addr: String,
111}
112
113impl TestPostgresDatabase {
114    pub fn new(
115        username: &str,
116        db_name: &str,
117        addr: &str,
118        ssl: bool,
119        ssl_root_cert: Option<&Path>,
120    ) -> Self {
121        let username_c = username.to_string();
122        let db_name_c = db_name.to_string();
123        let addr_c = addr.to_string();
124        let _ = thread::spawn(move || {
125            let rt = match tokio::runtime::Runtime::new() {
126                Ok(rt) => rt,
127                Err(_) => return,
128            };
129
130            rt.block_on(async {
131                use sqlx::postgres::PgPoolOptions;
132
133                let pool = match PgPoolOptions::new()
134                    .connect(&format!(
135                        "postgres://{username_c}:postgres@{addr_c}/postgres"
136                    ))
137                    .await
138                {
139                    Ok(pool) => pool,
140                    Err(_) => return,
141                };
142
143                let _ = sqlx::query(&format!("CREATE DATABASE {db_name_c}"))
144                    .execute(&pool)
145                    .await;
146            });
147        })
148        .join();
149
150        let ssl = if ssl { "?sslmode=verify-full" } else { "" };
151
152        let ssl_root_cert = match (!ssl.is_empty(), ssl_root_cert) {
153            (true, Some(ssl_root_cert)) => {
154                format!("&sslrootcert={}", ssl_root_cert.to_string_lossy())
155            }
156            (_, _) => "".to_string(),
157        };
158
159        let connection_url =
160            format!("postgres://{username}:postgres@{addr}/{db_name}{ssl}{ssl_root_cert}");
161
162        Self {
163            username: username.to_string(),
164            db_name: db_name.to_string(),
165            connection_url,
166            addr: addr.to_string(),
167        }
168    }
169
170    pub fn connection_url(&self) -> &str {
171        &self.connection_url
172    }
173
174    pub fn database_name(&self) -> &str {
175        &self.db_name
176    }
177
178    pub fn address(&self) -> &str {
179        &self.addr
180    }
181}
182
183impl Drop for TestPostgresDatabase {
184    fn drop(&mut self) {
185        let username = self.username.clone();
186        let db_name = self.db_name.clone();
187        let addr = self.addr.clone();
188        let _ = thread::spawn(move || {
189            let rt = match tokio::runtime::Runtime::new() {
190                Ok(rt) => rt,
191                Err(_) => return,
192            };
193
194            rt.block_on(async {
195                use sqlx::postgres::PgPoolOptions;
196
197                let pool = match PgPoolOptions::new()
198                    .connect(&format!("postgres://{username}:postgres@{addr}/postgres"))
199                    .await
200                {
201                    Ok(pool) => pool,
202                    Err(_) => return,
203                };
204
205                let _ = sqlx::query(&format!("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{db_name}' AND pid <>  pg_backend_pid()"))
206                    .execute(&pool).await;
207
208                let _ = sqlx::query(&format!("DROP DATABASE IF EXISTS {db_name}"))
209                    .execute(&pool)
210                    .await;
211            });
212        })
213            .join();
214    }
215}