sqlx_db_tester/
postgres.rs1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4 migrate::{MigrationSource, Migrator},
5 Connection, Executor, PgConnection, PgPool,
6};
7use std::{path::Path, thread};
8use tokio::runtime::Runtime;
9use uuid::Uuid;
10
11#[derive(Debug)]
12pub struct TestPg {
13 pub server_url: String,
14 pub dbname: String,
15}
16
17impl TestPg {
18 pub fn new<S>(database_url: String, migrations: S) -> Self
19 where
20 S: MigrationSource<'static> + Send + Sync + 'static,
21 {
22 let simple = Uuid::new_v4().simple();
23 let (server_url, dbname) = parse_postgres_url(&database_url);
24 let dbname = match dbname {
25 Some(db_name) => format!("{}_test_{}", db_name, simple),
26 None => format!("test_{}", simple),
27 };
28 let dbname_cloned = dbname.clone();
29
30 let tdb = Self { server_url, dbname };
31
32 let url = tdb.url();
33
34 thread::spawn(move || {
36 let rt = Runtime::new().unwrap();
37 rt.block_on(async move {
38 let mut conn = PgConnection::connect(&database_url)
40 .await
41 .unwrap_or_else(|_| panic!("Error while connecting to {}", database_url));
42 conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
43 .await
44 .unwrap();
45
46 let mut conn = PgConnection::connect(&url)
48 .await
49 .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
50 let m = Migrator::new(migrations).await.unwrap();
51 m.run(&mut conn).await.unwrap();
52 });
53 })
54 .join()
55 .expect("failed to create database");
56
57 tdb
58 }
59
60 pub fn server_url(&self) -> String {
61 self.server_url.clone()
62 }
63
64 pub fn url(&self) -> String {
65 format!("{}/{}", self.server_url, self.dbname)
66 }
67
68 pub async fn get_pool(&self) -> PgPool {
69 let url = self.url();
70 PgPool::connect(&url)
71 .await
72 .unwrap_or_else(|_| panic!("Error while connecting to {}", url))
73 }
74
75 pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
76 let pool = self.get_pool().await;
77 let path = filename.canonicalize()?;
78 let mut conn = pool.acquire().await?;
79 let sql = format!(
80 "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
81 table,
82 fields.join(","),
83 path.display()
84 );
85 conn.execute(sql.as_str()).await?;
86 Ok(())
89 }
90
91 pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
92 let mut rdr = csv::Reader::from_reader(csv.as_bytes());
93 let headers = rdr.headers()?.iter().join(",");
94 let mut tx = self.get_pool().await.begin().await?;
95 for result in rdr.records() {
96 let record = result?;
97 let sql = format!(
98 "INSERT INTO {} ({}) VALUES ({})",
99 table,
100 headers,
101 record.iter().map(|v| format!("'{v}'")).join(",")
102 );
103 tx.execute(sql.as_str()).await?;
104 }
105 tx.commit().await?;
106 Ok(())
107 }
108}
109
110impl Drop for TestPg {
111 fn drop(&mut self) {
112 let server_url = &self.server_url;
113 let database_url = format!("{server_url}/postgres");
114 let dbname = self.dbname.clone();
115 thread::spawn(move || {
116 let rt = Runtime::new().unwrap();
117 rt.block_on(async move {
118 let mut conn = PgConnection::connect(&database_url).await
119 .unwrap_or_else(|_| panic!("Error while connecting to {}", database_url));
120 sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
122 .execute( &mut conn)
123 .await
124 .expect("Terminate all other connections");
125 conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
126 .await
127 .expect("Error while querying the drop database");
128 });
129 })
130 .join()
131 .expect("failed to drop database");
132 }
133}
134
135impl Default for TestPg {
136 fn default() -> Self {
137 Self::new(
138 "postgres://postgres:postgres@localhost:5432".to_string(),
139 Path::new("./fixtures/migrations"),
140 )
141 }
142}
143
/// Splits a PostgreSQL connection URL into its server part and optional database name.
///
/// Both the `postgres://` and `postgresql://` schemes are accepted, and the caller's
/// scheme is kept in the returned server URL. Input without either scheme is treated
/// as `postgres://`. Any query string (`?sslmode=...`) is discarded so that it ends up
/// in neither the server URL nor the database name. A missing or empty path yields
/// `None` for the database name.
///
/// ```text
/// postgres://u:p@host/app?sslmode=disable -> ("postgres://u:p@host", Some("app"))
/// postgresql://u@host:5432                -> ("postgresql://u@host:5432", None)
/// ```
fn parse_postgres_url(url: &str) -> (String, Option<String>) {
    let (scheme, rest) = if let Some(rest) = url.strip_prefix("postgresql://") {
        ("postgresql://", rest)
    } else if let Some(rest) = url.strip_prefix("postgres://") {
        ("postgres://", rest)
    } else {
        ("postgres://", url)
    };

    let without_query = rest.split_once('?').map_or(rest, |(before, _)| before);

    let mut segments = without_query.split('/');
    // `split` always yields at least one item, so this is the authority part.
    let authority = segments.next().unwrap_or_default();
    let dbname = segments
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_owned);

    (format!("{scheme}{authority}"), dbname)
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::env;

    /// Reads the (single expected) row of `todos` after a test's setup step.
    async fn first_todo(pool: &sqlx::PgPool) -> (i32, String) {
        sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = TestPg::default();
        let pool = tdb.get_pool().await;
        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();
        assert_eq!(first_todo(&pool).await, (1, "test".to_string()));
    }

    // Ignored by default: `COPY ... FROM '<file>'` needs the fixture to be readable
    // by the database server process itself.
    #[ignore]
    #[tokio::test]
    async fn test_postgres_should_load_csv() -> Result<()> {
        let fixture = env::current_dir()?.join("fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv("todos", &["title"], &fixture).await?;
        let pool = tdb.get_pool().await;
        assert_eq!(first_todo(&pool).await, (1, "hello world".to_string()));
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let data = include_str!("../fixtures/todos.csv");
        let tdb = TestPg::default();
        tdb.load_csv_data("todos", data).await?;
        let pool = tdb.get_pool().await;
        assert_eq!(first_todo(&pool).await, (1, "hello world".to_string()));
        Ok(())
    }

    #[test]
    fn test_with_dbname() {
        let (server, db) = parse_postgres_url("postgres://testuser:1@localhost/pureya");
        assert_eq!(server, "postgres://testuser:1@localhost");
        assert_eq!(db, Some("pureya".to_string()));
    }

    #[test]
    fn test_without_dbname() {
        let (server, db) = parse_postgres_url("postgres://testuser:1@localhost");
        assert_eq!(server, "postgres://testuser:1@localhost");
        assert_eq!(db, None);
    }
}