1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4 Connection, Executor, PgConnection, PgPool,
5 migrate::{MigrationSource, Migrator},
6};
7use std::{
8 fs,
9 path::{Path, PathBuf},
10 thread,
11};
12use tokio::runtime::Runtime;
13use uuid::Uuid;
14
/// Handle to a throw-away PostgreSQL database created for a test.
///
/// The database is created and migrated (optionally with extensions and seed
/// scripts) when the handle is constructed, and it is dropped from the server
/// again when the handle itself is dropped.
#[derive(Debug)]
pub struct TestPg {
    /// Server part of the connection URL, without a database name
    /// (for example `postgres://user:pass@localhost:5432`).
    pub server_url: String,
    /// Generated name of the test database on that server.
    pub dbname: String,
    /// Extensions requested when the database was created. Only recorded for
    /// `Debug` output; nothing reads it back, hence the lint allowance.
    #[allow(dead_code)]
    extensions: Vec<String>,
}
22
/// Builder for [`TestPg`] that allows configuring extensions and seed scripts
/// before the test database is created.
///
/// `S` is the migration source. The `MigrationSource` bound is enforced on the
/// `impl` block that constructs and builds the value, not on the type itself,
/// so the type can be named without repeating the bounds.
pub struct TestPgBuilder<S> {
    /// Connection URL of the server (optionally including a database name,
    /// which is used as the prefix of the generated test database name).
    database_url: String,
    /// Migrations applied to the freshly created database.
    migrations: S,
    /// Extensions to create before the migrations run.
    extensions: Vec<String>,
    /// Optional directory of `*.sql` seed scripts run after the migrations.
    seeds_path: Option<PathBuf>,
}
33
34impl<S> TestPgBuilder<S>
35where
36 S: MigrationSource<'static> + Send + Sync + 'static,
37{
38 pub fn new(database_url: String, migrations: S) -> Self {
40 Self {
41 database_url,
42 migrations,
43 extensions: vec![],
44 seeds_path: None,
45 }
46 }
47
48 pub fn with_extensions(mut self, extensions: Vec<String>) -> Self {
63 self.extensions = extensions;
64 self
65 }
66
67 pub fn with_seeds<P: AsRef<Path>>(mut self, seeds_path: P) -> Self {
86 self.seeds_path = Some(seeds_path.as_ref().to_path_buf());
87 self
88 }
89
90 pub fn build(self) -> TestPg {
92 TestPg::new_with_config(
93 self.database_url,
94 self.migrations,
95 self.extensions,
96 self.seeds_path,
97 )
98 }
99}
100
101impl TestPg {
102 pub fn new<S>(database_url: String, migrations: S) -> Self
103 where
104 S: MigrationSource<'static> + Send + Sync + 'static,
105 {
106 Self::new_with_config(database_url, migrations, vec![], None)
107 }
108
109 #[allow(dead_code)]
110 fn new_with_extensions<S>(database_url: String, migrations: S, extensions: Vec<String>) -> Self
111 where
112 S: MigrationSource<'static> + Send + Sync + 'static,
113 {
114 Self::new_with_config(database_url, migrations, extensions, None)
115 }
116
117 fn new_with_config<S>(
118 database_url: String,
119 migrations: S,
120 extensions: Vec<String>,
121 seeds_path: Option<PathBuf>,
122 ) -> Self
123 where
124 S: MigrationSource<'static> + Send + Sync + 'static,
125 {
126 let simple = Uuid::new_v4().simple();
127 let (server_url, dbname) = parse_postgres_url(&database_url);
128 let dbname = match dbname {
129 Some(db_name) => format!("{db_name}_test_{simple}"),
130 None => format!("test_{simple}"),
131 };
132 let dbname_cloned = dbname.clone();
133 let extensions_cloned = extensions.clone();
134
135 let tdb = Self {
136 server_url,
137 dbname,
138 extensions,
139 };
140
141 let url = tdb.url();
142
143 thread::spawn(move || {
145 let rt = Runtime::new().unwrap();
146 rt.block_on(async move {
147 let mut conn = PgConnection::connect(&database_url)
149 .await
150 .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
151 conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
152 .await
153 .unwrap();
154
155 let mut conn = PgConnection::connect(&url)
157 .await
158 .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
159
160 for ext in &extensions_cloned {
162 conn.execute(format!(r#"CREATE EXTENSION IF NOT EXISTS "{ext}""#).as_str())
163 .await
164 .unwrap_or_else(|_| panic!("Error while creating extension {ext}"));
165 }
166
167 let m = Migrator::new(migrations).await.unwrap();
168 m.run(&mut conn).await.unwrap();
169
170 if let Some(seeds_dir) = seeds_path {
172 run_seeds(&mut conn, &seeds_dir).await.unwrap();
173 }
174 });
175 })
176 .join()
177 .expect("failed to create database");
178
179 tdb
180 }
181
182 pub fn server_url(&self) -> String {
183 self.server_url.clone()
184 }
185
186 pub fn url(&self) -> String {
187 format!("{}/{}", self.server_url, self.dbname)
188 }
189
190 pub async fn get_pool(&self) -> PgPool {
191 let url = self.url();
192 PgPool::connect(&url)
193 .await
194 .unwrap_or_else(|_| panic!("Error while connecting to {url}"))
195 }
196
197 pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
198 let pool = self.get_pool().await;
199 let path = filename.canonicalize()?;
200 let mut conn = pool.acquire().await?;
201 let sql = format!(
202 "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
203 table,
204 fields.join(","),
205 path.display()
206 );
207 conn.execute(sql.as_str()).await?;
208 Ok(())
211 }
212
213 pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
214 let mut rdr = csv::Reader::from_reader(csv.as_bytes());
215 let headers = rdr.headers()?.iter().join(",");
216 let mut tx = self.get_pool().await.begin().await?;
217 for result in rdr.records() {
218 let record = result?;
219 let sql = format!(
220 "INSERT INTO {} ({}) VALUES ({})",
221 table,
222 headers,
223 record.iter().map(|v| format!("'{v}'")).join(",")
224 );
225 tx.execute(sql.as_str()).await?;
226 }
227 tx.commit().await?;
228 Ok(())
229 }
230}
231
232impl Drop for TestPg {
233 fn drop(&mut self) {
234 let server_url = &self.server_url;
235 let database_url = format!("{server_url}/postgres");
236 let dbname = self.dbname.clone();
237 thread::spawn(move || {
238 let rt = Runtime::new().unwrap();
239 rt.block_on(async move {
240 let mut conn = PgConnection::connect(&database_url).await
241 .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
242 sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
244 .execute( &mut conn)
245 .await
246 .expect("Terminate all other connections");
247 conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
248 .await
249 .expect("Error while querying the drop database");
250 });
251 })
252 .join()
253 .expect("failed to drop database");
254 }
255}
256
257impl Default for TestPg {
258 fn default() -> Self {
259 Self::new(
260 "postgres://postgres:postgres@localhost:5432".to_string(),
261 Path::new("./fixtures/migrations"),
262 )
263 }
264}
265
266async fn run_seeds(conn: &mut PgConnection, seeds_dir: &Path) -> Result<()> {
271 if !seeds_dir.exists() {
272 return Ok(());
273 }
274
275 let mut seed_files = Vec::new();
276
277 for entry in fs::read_dir(seeds_dir)? {
279 let entry = entry?;
280 let path = entry.path();
281
282 if path.is_file()
283 && path.extension().is_some_and(|ext| ext == "sql")
284 && let Some(filename) = path.file_name().and_then(|n| n.to_str())
285 {
286 if let Some(timestamp) = filename.split('_').next() {
288 seed_files.push((timestamp.to_string(), path));
289 }
290 }
291 }
292
293 seed_files.sort_by(|a, b| a.0.cmp(&b.0));
295
296 for (_timestamp, path) in seed_files {
298 let sql = fs::read_to_string(&path)?;
299 conn.execute(sql.as_str()).await?;
300 }
301
302 Ok(())
303}
304
/// Splits a PostgreSQL connection URL into its server part and the optional
/// database name.
///
/// * `postgres://user:pw@host:5432/app` → `("postgres://user:pw@host:5432", Some("app"))`
/// * `postgres://user:pw@host:5432`     → `("postgres://user:pw@host:5432", None)`
///
/// Scheme handling:
/// * Both `postgres://` and `postgresql://` are recognised and kept in the
///   returned server URL.
/// * Input without either scheme is treated as `postgres://`.
///
/// Database-name handling:
/// * Only the first path segment is the database name.
/// * A query string (`?sslmode=...`) or fragment after it is stripped.
/// * An empty name (e.g. a trailing `/`) yields `None`.
///
/// The query string is not carried over to the server URL.
fn parse_postgres_url(url: &str) -> (String, Option<String>) {
    const SCHEMES: [&str; 2] = ["postgresql://", "postgres://"];
    let (scheme, rest) = SCHEMES
        .into_iter()
        .find_map(|scheme| url.strip_prefix(scheme).map(|rest| (scheme, rest)))
        .unwrap_or(("postgres://", url));

    // Everything before the first `/` is user info + host + port.
    let (authority, path) = rest.split_once('/').unwrap_or((rest, ""));
    let dbname = path
        .split(|c| matches!(c, '/' | '?' | '#'))
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_string);

    (format!("{scheme}{authority}"), dbname)
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::env;

    /// Server used by the database-backed tests. It has no database name, so
    /// each test gets a fresh `test_<uuid>` database.
    const SERVER_URL: &str = "postgres://postgres:postgres@localhost:5432";

    /// Migration directory shared by the database-backed tests.
    const MIGRATIONS_DIR: &str = "./fixtures/migrations";

    /// Builder pointed at [`SERVER_URL`] with the shared migrations.
    fn builder() -> TestPgBuilder<&'static Path> {
        TestPgBuilder::new(SERVER_URL.to_string(), Path::new(MIGRATIONS_DIR))
    }

    /// Fetches the first `(id, title)` row returned from `todos`.
    async fn first_todo(pool: &PgPool) -> (i32, String) {
        sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }

    #[tokio::test]
    #[ignore = "github action postgres server can't be used for this test"]
    async fn test_postgres_should_load_csv() -> Result<()> {
        let filename = env::current_dir()?.join("fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv("todos", &["title"], &filename).await?;
        let pool = tdb.get_pool().await;
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let csv = include_str!("../fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv_data("todos", csv).await?;
        let pool = tdb.get_pool().await;
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_with_extensions() {
        let tdb = builder()
            .with_extensions(vec!["uuid-ossp".to_string()])
            .build();
        let pool = tdb.get_pool().await;

        let generated = sqlx::query_scalar::<_, String>("SELECT uuid_generate_v4()::text")
            .fetch_one(&pool)
            .await;

        assert!(generated.is_ok(), "uuid-ossp extension should be available");
    }

    #[test]
    fn test_with_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost/pureya");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname, Some("pureya".to_string()));
    }

    #[test]
    fn test_without_dbname() {
        let (server_url, dbname) = parse_postgres_url("postgres://testuser:1@localhost");
        assert_eq!(server_url, "postgres://testuser:1@localhost");
        assert_eq!(dbname, None);
    }

    #[tokio::test]
    async fn test_postgres_with_seeds() {
        let tdb = builder().with_seeds(Path::new("./fixtures/seeds")).build();
        let pool = tdb.get_pool().await;

        let todos = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos ORDER BY id")
            .fetch_all(&pool)
            .await
            .unwrap();

        let titles: Vec<&str> = todos.iter().map(|(_, title)| title.as_str()).collect();
        assert_eq!(todos.len(), 3);
        assert_eq!(titles[0], "First seeded todo");
        assert_eq!(titles[1], "Second seeded todo");
        assert_eq!(titles[2], "Third seeded todo");
    }
}