1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4 Connection, Executor, PgConnection, PgPool,
5 migrate::{MigrationSource, Migrator},
6};
7use std::{path::Path, thread};
8use tokio::runtime::Runtime;
9use uuid::Uuid;
10
/// Handle to a throwaway PostgreSQL database created for a single test.
///
/// Dropping the handle terminates any remaining connections to the database
/// and drops it.
#[derive(Debug)]
pub struct TestPg {
    /// Server part of the connection string without a database name
    /// (e.g. `postgres://user:pass@host:port`).
    pub server_url: String,
    /// Name of the generated test database.
    pub dbname: String,
    /// Extensions requested when the database was created. Nothing reads this
    /// after construction (it is only visible through `Debug`), hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    extensions: Vec<String>,
}
18
/// Builder for [`TestPg`] that can request PostgreSQL extensions to be created
/// before the migrations run.
///
/// Trait bounds on `S` live on the `impl` block rather than on the struct, so
/// the type can be named without repeating them.
pub struct TestPgBuilder<S> {
    /// Connection string of the server, optionally including a base database name.
    database_url: String,
    /// Migration source applied to the freshly created database.
    migrations: S,
    /// Extensions to create with `CREATE EXTENSION IF NOT EXISTS`.
    extensions: Vec<String>,
}
28
29impl<S> TestPgBuilder<S>
30where
31 S: MigrationSource<'static> + Send + Sync + 'static,
32{
33 pub fn new(database_url: String, migrations: S) -> Self {
35 Self {
36 database_url,
37 migrations,
38 extensions: vec![],
39 }
40 }
41
42 pub fn with_extensions(mut self, extensions: Vec<String>) -> Self {
57 self.extensions = extensions;
58 self
59 }
60
61 pub fn build(self) -> TestPg {
63 TestPg::new_with_extensions(self.database_url, self.migrations, self.extensions)
64 }
65}
66
67impl TestPg {
68 pub fn new<S>(database_url: String, migrations: S) -> Self
69 where
70 S: MigrationSource<'static> + Send + Sync + 'static,
71 {
72 Self::new_with_extensions(database_url, migrations, vec![])
73 }
74
75 fn new_with_extensions<S>(database_url: String, migrations: S, extensions: Vec<String>) -> Self
76 where
77 S: MigrationSource<'static> + Send + Sync + 'static,
78 {
79 let simple = Uuid::new_v4().simple();
80 let (server_url, dbname) = parse_postgres_url(&database_url);
81 let dbname = match dbname {
82 Some(db_name) => format!("{db_name}_test_{simple}"),
83 None => format!("test_{simple}"),
84 };
85 let dbname_cloned = dbname.clone();
86 let extensions_cloned = extensions.clone();
87
88 let tdb = Self {
89 server_url,
90 dbname,
91 extensions,
92 };
93
94 let url = tdb.url();
95
96 thread::spawn(move || {
98 let rt = Runtime::new().unwrap();
99 rt.block_on(async move {
100 let mut conn = PgConnection::connect(&database_url)
102 .await
103 .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
104 conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
105 .await
106 .unwrap();
107
108 let mut conn = PgConnection::connect(&url)
110 .await
111 .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
112
113 for ext in &extensions_cloned {
115 conn.execute(format!(r#"CREATE EXTENSION IF NOT EXISTS "{ext}""#).as_str())
116 .await
117 .unwrap_or_else(|_| panic!("Error while creating extension {ext}"));
118 }
119
120 let m = Migrator::new(migrations).await.unwrap();
121 m.run(&mut conn).await.unwrap();
122 });
123 })
124 .join()
125 .expect("failed to create database");
126
127 tdb
128 }
129
130 pub fn server_url(&self) -> String {
131 self.server_url.clone()
132 }
133
134 pub fn url(&self) -> String {
135 format!("{}/{}", self.server_url, self.dbname)
136 }
137
138 pub async fn get_pool(&self) -> PgPool {
139 let url = self.url();
140 PgPool::connect(&url)
141 .await
142 .unwrap_or_else(|_| panic!("Error while connecting to {url}"))
143 }
144
145 pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
146 let pool = self.get_pool().await;
147 let path = filename.canonicalize()?;
148 let mut conn = pool.acquire().await?;
149 let sql = format!(
150 "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
151 table,
152 fields.join(","),
153 path.display()
154 );
155 conn.execute(sql.as_str()).await?;
156 Ok(())
159 }
160
161 pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
162 let mut rdr = csv::Reader::from_reader(csv.as_bytes());
163 let headers = rdr.headers()?.iter().join(",");
164 let mut tx = self.get_pool().await.begin().await?;
165 for result in rdr.records() {
166 let record = result?;
167 let sql = format!(
168 "INSERT INTO {} ({}) VALUES ({})",
169 table,
170 headers,
171 record.iter().map(|v| format!("'{v}'")).join(",")
172 );
173 tx.execute(sql.as_str()).await?;
174 }
175 tx.commit().await?;
176 Ok(())
177 }
178}
179
180impl Drop for TestPg {
181 fn drop(&mut self) {
182 let server_url = &self.server_url;
183 let database_url = format!("{server_url}/postgres");
184 let dbname = self.dbname.clone();
185 thread::spawn(move || {
186 let rt = Runtime::new().unwrap();
187 rt.block_on(async move {
188 let mut conn = PgConnection::connect(&database_url).await
189 .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
190 sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
192 .execute( &mut conn)
193 .await
194 .expect("Terminate all other connections");
195 conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
196 .await
197 .expect("Error while querying the drop database");
198 });
199 })
200 .join()
201 .expect("failed to drop database");
202 }
203}
204
205impl Default for TestPg {
206 fn default() -> Self {
207 Self::new(
208 "postgres://postgres:postgres@localhost:5432".to_string(),
209 Path::new("./fixtures/migrations"),
210 )
211 }
212}
213
/// Splits a PostgreSQL connection URL into its server part and its database name.
///
/// Both the `postgres://` and `postgresql://` schemes are accepted; input without
/// either scheme is treated as scheme-less. The returned server URL always uses the
/// `postgres://` scheme. The database name is the first path segment, cut at any
/// further `/` or at a `?` query string, so options such as `?sslmode=disable` do
/// not leak into the name. An empty database name yields `None`.
fn parse_postgres_url(url: &str) -> (String, Option<String>) {
    let without_scheme = url
        .strip_prefix("postgres://")
        .or_else(|| url.strip_prefix("postgresql://"))
        .unwrap_or(url);

    let (authority, path) = match without_scheme.split_once('/') {
        Some((authority, path)) => (authority, Some(path)),
        None => (without_scheme, None),
    };
    let server_url = format!("postgres://{authority}");

    let dbname = path
        .and_then(|p| p.split(['/', '?']).next())
        .filter(|name| !name.is_empty())
        .map(str::to_string);

    (server_url, dbname)
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::env;

    /// Server used by tests that build their own `TestPg` instead of `TestPg::default()`.
    const SERVER_URL: &str = "postgres://postgres:postgres@localhost:5432";

    /// Reads the first `(id, title)` row back from the `todos` fixture table.
    async fn first_todo(pool: &PgPool) -> (i32, String) {
        sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }

    #[tokio::test]
    #[ignore = "github action postgres server can't be used for this test"]
    async fn test_postgres_should_load_csv() -> Result<()> {
        let filename = env::current_dir()?.join("fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv("todos", &["title"], &filename).await?;
        let pool = tdb.get_pool().await;
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let csv = include_str!("../fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv_data("todos", csv).await?;
        let pool = tdb.get_pool().await;
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_with_extensions() {
        let tdb = TestPgBuilder::new(SERVER_URL.to_string(), Path::new("./fixtures/migrations"))
            .with_extensions(vec!["uuid-ossp".to_string()])
            .build();
        let pool = tdb.get_pool().await;

        let generated = sqlx::query_scalar::<_, String>("SELECT uuid_generate_v4()::text")
            .fetch_one(&pool)
            .await;

        assert!(generated.is_ok(), "uuid-ossp extension should be available");
    }

    #[test]
    fn test_with_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost/pureya");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname, Some("pureya".to_string()));
    }

    #[test]
    fn test_without_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname, None);
    }
}