sqlx_mock_db_tester/
postgres.rs1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::migrate::{MigrationSource, Migrator};
4use sqlx::{Connection, Executor, PgConnection, PgPool};
5use std::path::Path;
6use std::thread;
7use tokio::runtime::Runtime;
8use uuid::Uuid;
9
/// A throw-away PostgreSQL database for a single test.
///
/// Holds the address of the server and the name of the database created on it; the
/// database itself is created and migrated by `TestPg::new`.
#[derive(Debug)]
pub struct TestPg {
    /// Base connection URL of the server, without a database path
    /// (e.g. `postgres://user:pass@host:port`).
    pub server_url: String,
    /// Name of the database created for this test on `server_url`.
    pub dbname: String,
}
15
16impl TestPg {
17 pub fn new<S>(server_url: String, migrations: S) -> Self
18 where
19 S: MigrationSource<'static> + Send + Sync + 'static,
20 {
21 let uuid = Uuid::new_v4();
22 let dbname = format!("test_{uuid}");
23 let dbname_cloned = dbname.clone();
24
25 let tdb = Self { server_url, dbname };
26
27 let server_url = tdb.server_url();
28 let url = tdb.url();
29
30 thread::spawn(move || {
32 let rt = Runtime::new().unwrap();
33 rt.block_on(async move {
34 let mut conn = PgConnection::connect(&server_url).await.unwrap();
36 conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
37 .await
38 .unwrap();
39
40 let mut conn = PgConnection::connect(&url).await.unwrap();
42 let m = Migrator::new(migrations).await.unwrap();
43 m.run(&mut conn).await.unwrap();
44 });
45 })
46 .join()
47 .expect("failed to create database.");
48
49 tdb
50 }
51
52 pub fn server_url(&self) -> String {
53 self.server_url.clone()
54 }
55
56 pub fn url(&self) -> String {
57 format!("{}/{}", self.server_url, self.dbname)
58 }
59
60 pub async fn get_pool(&self) -> PgPool {
61 PgPool::connect(&self.url()).await.unwrap()
62 }
63
64 pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
65 let pool = self.get_pool().await;
66 let path = filename.canonicalize()?;
67 let mut conn = pool.acquire().await?;
68 let sql = format!(
69 "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
70 table,
71 fields.join(","),
72 path.display()
73 );
74 conn.execute(sql.as_str()).await?;
75 Ok(())
78 }
79
80 pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
81 let mut rdr = csv::Reader::from_reader(csv.as_bytes());
82 let headers = rdr.headers()?.iter().join(",");
83 let mut tx = self.get_pool().await.begin().await?;
84 for result in rdr.records() {
85 let record = result?;
86 let sql = format!(
87 "INSERT INTO {} ({}) VALUES ({})",
88 table,
89 headers,
90 record.iter().map(|v| format!("'{v}'")).join(",")
91 );
92 tx.execute(sql.as_str()).await?;
93 }
94 tx.commit().await?;
95
96 Ok(())
97 }
98}
99
100impl Drop for TestPg {
101 fn drop(&mut self) {
102 let server_url = self.server_url();
103 let dbname = self.dbname.clone();
104 thread::spawn(move || {
105 let rt = Runtime::new().unwrap();
106 rt.block_on(async move {
107 let mut conn = PgConnection::connect(&server_url).await.unwrap();
108 sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
110 .execute( &mut conn)
111 .await
112 .expect("Terminate all other connections");
113 conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
114 .await
115 .expect("Error while querying the drop database");
116 });
117 })
118 .join()
119 .expect("failed to drop database");
120 }
121}
122
123impl Default for TestPg {
124 fn default() -> Self {
125 Self::new(
126 "postgres://postgres:postgres@localhost:5432".to_string(),
127 Path::new("./fixtures/migrations"),
128 )
129 }
130}
131
#[cfg(test)]
mod tests {
    use crate::TestPg;
    use anyhow::Result;
    use std::path::Path;

    /// Server used when the `TEST_PG_SERVER_URL` environment variable is not set.
    const DEFAULT_SERVER_URL: &str = "postgres://postgres:ReservationP@home:15432";

    /// Creates a migrated throw-away database for a single test.
    ///
    /// Set `TEST_PG_SERVER_URL` to run the suite against a different server.
    fn new_test_db() -> TestPg {
        let server_url = std::env::var("TEST_PG_SERVER_URL")
            .unwrap_or_else(|_| DEFAULT_SERVER_URL.to_string());
        TestPg::new(server_url, Path::new("./fixtures/migrations"))
    }

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = new_test_db();
        let pool = tdb.get_pool().await;
        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title from todos")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv() -> Result<()> {
        let filename = Path::new("./fixtures/todos.csv");
        let tdb = new_test_db();
        if let Err(e) = tdb.load_csv("todos", &["title"], filename).await {
            // `COPY ... FROM '<file>'` reads the file on the database server. Presumably this
            // error means the server cannot see the fixture (e.g. a remote server), so the
            // test is skipped rather than failed.
            if e.to_string().contains("No such file or directory") {
                return Ok(());
            }
            return Err(e);
        }
        let pool = tdb.get_pool().await;
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let csv = include_str!("../fixtures/todos.csv");
        let tdb = new_test_db();
        tdb.load_csv_data("todos", csv).await?;
        let pool = tdb.get_pool().await;
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }
}