1use sqlx::migrate::Migrator;
2use sqlx::{Connection, Executor, PgConnection, PgPool};
3use std::path::Path;
4use std::thread;
5use tokio::runtime::Runtime;
6use uuid::Uuid;
7
/// Connection settings for a single throwaway PostgreSQL test database.
///
/// The fields are exactly the pieces that go into the connection URL:
/// `postgres://user[:password]@host:port/dbname`.
#[derive(Debug)]
pub struct TestDb {
    /// Host name or IP address of the PostgreSQL server.
    pub host: String,
    /// TCP port of the PostgreSQL server.
    pub port: u16,
    /// Role used for every connection.
    pub user: String,
    /// Password of `user`; an empty string means "no password".
    pub password: String,
    /// Name of the temporary database.
    pub dbname: String,
}
16
17impl TestDb {
18 pub fn new(
19 host: impl Into<String>,
20 port: u16,
21 user: impl Into<String>,
22 password: impl Into<String>,
23 migration_path: impl Into<String>,
24 ) -> Self {
25 let user = user.into();
26 let password = password.into();
27 let host = host.into();
28
29 let uuid = Uuid::new_v4();
30 let dbname = format!("test_{}", uuid);
31 let dbname_cloned = dbname.clone();
32
33 let tdb = Self {
34 host,
35 port,
36 user,
37 password,
38 dbname,
39 };
40
41 let server_url = tdb.server_url();
42
43 let url = tdb.url();
44 let migration_path = migration_path.into();
45 let path = migration_path;
46
47 thread::spawn(move || {
49 let rt = Runtime::new().unwrap();
50 rt.block_on(async move {
51 let mut conn = PgConnection::connect(&server_url).await.unwrap();
53 conn.execute(format!(r#"CREATE DATABASE "{}""#, dbname_cloned).as_str())
54 .await
55 .unwrap();
56
57 let mut conn = PgConnection::connect(&url).await.unwrap();
59 let m = Migrator::new(Path::new(&path)).await.unwrap();
60 m.run(&mut conn).await.unwrap();
61 });
62 })
63 .join()
64 .expect("Failed to create database.");
65
66 tdb
67 }
68
69 pub fn server_url(&self) -> String {
70 if self.password.is_empty() {
71 format!(
72 "postgres://{}@{}:{}/{}",
73 self.user, self.host, self.port, "postgres"
74 )
75 } else {
76 format!(
77 "postgres://{}:{}@{}:{}/{}",
78 self.user, self.password, self.host, self.port, "postgres"
79 )
80 }
81 }
82
83 pub fn url(&self) -> String {
84 if self.password.is_empty() {
85 format!(
86 "postgres://{}@{}:{}/{}",
87 self.user, self.host, self.port, self.dbname
88 )
89 } else {
90 format!(
91 "postgres://{}:{}@{}:{}/{}",
92 self.user, self.password, self.host, self.port, self.dbname
93 )
94 }
95 }
96
97 pub async fn get_pool(&self) -> PgPool {
98 sqlx::postgres::PgPoolOptions::new()
99 .max_connections(5)
100 .connect(&self.url())
101 .await
102 .unwrap()
103 }
104}
105
106impl Drop for TestDb {
107 fn drop(&mut self) {
108 let server_url = self.server_url();
109 let dbname = self.dbname.clone();
110 println!("server_url: {}", server_url);
111 println!("Dropping database: {}", dbname);
112
113 thread::spawn(move || {
114 let rt = Runtime::new().unwrap();
115 rt.block_on(async move {
116 let mut conn = sqlx::PgConnection::connect(&server_url).await.unwrap();
117
118 sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{}'"#, dbname))
120 .execute(&mut conn)
121 .await
122 .expect("Terminate all other connections");
123
124
125 conn.execute(format!(r#"DROP DATABASE "{}""#, dbname).as_str())
126 .await
127 .expect("Error while querying the drop database")
128 });
129 })
130 .join()
131 .expect("Failed to drop database.")
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check against a live server on 127.0.0.1:5432: the database
    /// is created and migrated, a row can be written to and read back from the
    /// `todo` table, and the database is dropped again when the `TestDb` value
    /// goes out of scope at the end of the test.
    #[tokio::test]
    async fn test_db_should_create_and_drop() {
        let db = TestDb::new("127.0.0.1", 5432, "admin", "P@ssw0rd", "./migrations");
        let pool = db.get_pool().await;

        // Write a single row ...
        let insert = sqlx::query("INSERT INTO todo (title) VALUES ('test')");
        insert.execute(&pool).await.unwrap();

        // ... and read it back; in a fresh database it is the first row.
        let (id, title): (i32, String) = sqlx::query_as("SELECT id, title FROM todo")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }
}