sqlx_mock_db_tester/
postgres.rs

1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::migrate::{MigrationSource, Migrator};
4use sqlx::{Connection, Executor, PgConnection, PgPool};
5use std::path::Path;
6use std::thread;
7use tokio::runtime::Runtime;
8use uuid::Uuid;
9
/// A throw-away PostgreSQL database for tests.
///
/// Each value owns one uniquely named database on the server at `server_url`.
#[derive(Debug)]
pub struct TestPg {
    /// Connection URL of the PostgreSQL server, without a database path
    /// (e.g. `postgres://user:pass@host:5432`).
    pub server_url: String,
    /// Name of the test database owned by this value.
    pub dbname: String,
}
15
16impl TestPg {
17    pub fn new<S>(server_url: String, migrations: S) -> Self
18    where
19        S: MigrationSource<'static> + Send + Sync + 'static,
20    {
21        let uuid = Uuid::new_v4();
22        let dbname = format!("test_{uuid}");
23        let dbname_cloned = dbname.clone();
24
25        let tdb = Self { server_url, dbname };
26
27        let server_url = tdb.server_url();
28        let url = tdb.url();
29
30        // create database dbname
31        thread::spawn(move || {
32            let rt = Runtime::new().unwrap();
33            rt.block_on(async move {
34                // use server url to create database dbname
35                let mut conn = PgConnection::connect(&server_url).await.unwrap();
36                conn.execute(format!(r#"CREATE DATABASE "{dbname_cloned}""#).as_str())
37                    .await
38                    .unwrap();
39
40                // now connect to test database for migration
41                let mut conn = PgConnection::connect(&url).await.unwrap();
42                let m = Migrator::new(migrations).await.unwrap();
43                m.run(&mut conn).await.unwrap();
44            });
45        })
46        .join()
47        .expect("failed to create database.");
48
49        tdb
50    }
51
52    pub fn server_url(&self) -> String {
53        self.server_url.clone()
54    }
55
56    pub fn url(&self) -> String {
57        format!("{}/{}", self.server_url, self.dbname)
58    }
59
60    pub async fn get_pool(&self) -> PgPool {
61        PgPool::connect(&self.url()).await.unwrap()
62    }
63
64    pub async fn load_csv(&self, table: &str, fields: &[&str], filename: &Path) -> Result<()> {
65        let pool = self.get_pool().await;
66        let path = filename.canonicalize()?;
67        let mut conn = pool.acquire().await?;
68        let sql = format!(
69            "COPY {} ({}) FROM '{}' DELIMITER ',' CSV HEADER;",
70            table,
71            fields.join(","),
72            path.display()
73        );
74        conn.execute(sql.as_str()).await?;
75        // copy csv
76
77        Ok(())
78    }
79
80    pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
81        let mut rdr = csv::Reader::from_reader(csv.as_bytes());
82        let headers = rdr.headers()?.iter().join(",");
83        let mut tx = self.get_pool().await.begin().await?;
84        for result in rdr.records() {
85            let record = result?;
86            let sql = format!(
87                "INSERT INTO {} ({}) VALUES ({})",
88                table,
89                headers,
90                record.iter().map(|v| format!("'{v}'")).join(",")
91            );
92            tx.execute(sql.as_str()).await?;
93        }
94        tx.commit().await?;
95
96        Ok(())
97    }
98}
99
100impl Drop for TestPg {
101    fn drop(&mut self) {
102        let server_url = self.server_url();
103        let dbname = self.dbname.clone();
104        thread::spawn(move || {
105          let rt = Runtime::new().unwrap();
106            rt.block_on(async move {
107                let mut conn = PgConnection::connect(&server_url).await.unwrap();
108                // terminate existing connections
109                sqlx::query(&format!(r#"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid() AND datname = '{dbname}'"#))
110                    .execute( &mut conn)
111                    .await
112                    .expect("Terminate all other connections");
113                conn.execute(format!(r#"DROP DATABASE "{dbname}""#).as_str())
114                    .await
115                    .expect("Error while querying the drop database");
116            });
117        })
118            .join()
119            .expect("failed to drop database");
120    }
121}
122
123impl Default for TestPg {
124    fn default() -> Self {
125        Self::new(
126            "postgres://postgres:postgres@localhost:5432".to_string(),
127            Path::new("./fixtures/migrations"),
128        )
129    }
130}
131
#[cfg(test)]
mod tests {
    use crate::TestPg;
    use anyhow::Result;
    use std::path::Path;

    /// Server used when the `TEST_PG_SERVER_URL` environment variable is not set.
    const DEFAULT_SERVER_URL: &str = "postgres://postgres:ReservationP@home:15432";

    /// Creates a fresh, migrated test database for a single test.
    ///
    /// The server URL can be overridden with `TEST_PG_SERVER_URL`, so the suite can run
    /// against a server other than the built-in default.
    fn new_test_pg() -> TestPg {
        let server_url = std::env::var("TEST_PG_SERVER_URL")
            .unwrap_or_else(|_| DEFAULT_SERVER_URL.to_string());
        TestPg::new(server_url, Path::new("./fixtures/migrations"))
    }

    #[tokio::test]
    async fn test_postgres_should_create_and_drop() {
        let tdb = new_test_pg();
        let pool = tdb.get_pool().await;
        // insert todo
        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();
        // get todo
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title from todos")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(id, 1);
        assert_eq!(title, "test")
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv() -> Result<()> {
        // COPY FROM is read by the database server: ./fixtures/todos.csv must exist at the
        // same absolute path on the server host, otherwise this test is skipped.
        let filename = Path::new("./fixtures/todos.csv");
        let tdb = new_test_pg();
        if let Err(e) = tdb.load_csv("todos", &["title"], filename).await {
            if e.to_string().contains("No such file or directory") {
                return Ok(());
            } else {
                panic!("{e}");
            }
        }
        let pool = tdb.get_pool().await;
        // get todo
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    async fn test_postgres_should_load_csv_data() -> Result<()> {
        let csv = include_str!("../fixtures/todos.csv");
        let tdb = new_test_pg();
        tdb.load_csv_data("todos", csv).await?;
        let pool = tdb.get_pool().await;
        // get todo
        let (id, title) = sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }
}