sqlx_db_tester/
mysql.rs

1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4    Connection, Executor, MySqlConnection, MySqlPool,
5    migrate::{MigrationSource, Migrator},
6};
7use std::{path::Path, thread};
8use tokio::runtime::Runtime;
9use uuid::Uuid;
10
/// A throw-away MySQL database for a single test.
///
/// [`TestMySql::new`] creates a uniquely named database on the server and runs the given
/// migrations against it; dropping the value issues `DROP DATABASE` for that name.
#[derive(Debug)]
pub struct TestMySql {
    /// Server part of the connection URL with no database path, e.g. `mysql://user:pw@host:port`.
    pub server_url: String,
    /// Name of the per-test database created on that server.
    pub dbname: String,
}
16
17impl TestMySql {
18    pub fn new<S>(database_url: String, migrations: S) -> Self
19    where
20        S: MigrationSource<'static> + Send + Sync + 'static,
21    {
22        let simple = Uuid::new_v4().simple();
23        let (server_url, dbname) = parse_mysql_url(&database_url);
24        let dbname = match dbname {
25            Some(db_name) => format!("{db_name}_test_{simple}"),
26            None => format!("test_{simple}"),
27        };
28        let dbname_cloned = dbname.clone();
29        let server_url_cloned = server_url.clone();
30
31        let tdb = Self { server_url, dbname };
32
33        let url = tdb.url();
34
35        // create database dbname
36        thread::spawn(move || {
37            let rt = Runtime::new().unwrap();
38            rt.block_on(async move {
39                // use server url to create database
40                // For MySQL, we always connect to the mysql system database to create a new database
41                let create_db_url = format!("{server_url_cloned}/mysql");
42                let mut conn = MySqlConnection::connect(&create_db_url)
43                    .await
44                    .unwrap_or_else(|_| panic!("Error while connecting to {create_db_url}"));
45                conn.execute(format!(r#"CREATE DATABASE `{dbname_cloned}`"#).as_str())
46                    .await
47                    .unwrap();
48
49                // now connect to test database for migration
50                let mut conn = MySqlConnection::connect(&url)
51                    .await
52                    .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
53                let m = Migrator::new(migrations).await.unwrap();
54                m.run(&mut conn).await.unwrap();
55            });
56        })
57        .join()
58        .expect("failed to create database");
59
60        tdb
61    }
62
63    pub fn server_url(&self) -> String {
64        self.server_url.clone()
65    }
66
67    pub fn url(&self) -> String {
68        format!("{}/{}", self.server_url, self.dbname)
69    }
70
71    pub async fn get_pool(&self) -> MySqlPool {
72        let url = self.url();
73        MySqlPool::connect(&url)
74            .await
75            .unwrap_or_else(|_| panic!("Error while connecting to {url}"))
76    }
77
78    pub async fn load_csv(&self, table: &str, _fields: &[&str], filename: &Path) -> Result<()> {
79        // For MySQL, we read the file and use load_csv_data since LOAD DATA LOCAL INFILE
80        // requires complex setup and the file needs to be accessible from the MySQL process
81        let csv_content = std::fs::read_to_string(filename)?;
82        self.load_csv_data(table, &csv_content).await
83    }
84
85    pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
86        let mut rdr = csv::Reader::from_reader(csv.as_bytes());
87        let headers = rdr.headers()?.iter().join(",");
88        let mut tx = self.get_pool().await.begin().await?;
89        for result in rdr.records() {
90            let record = result?;
91            let sql = format!(
92                "INSERT INTO {} ({}) VALUES ({})",
93                table,
94                headers,
95                record.iter().map(|v| format!("'{v}'")).join(",")
96            );
97            tx.execute(sql.as_str()).await?;
98        }
99        tx.commit().await?;
100        Ok(())
101    }
102}
103
104impl Drop for TestMySql {
105    fn drop(&mut self) {
106        let server_url = &self.server_url;
107        let database_url = format!("{server_url}/mysql");
108        let dbname = self.dbname.clone();
109        thread::spawn(move || {
110            let rt = Runtime::new().unwrap();
111            rt.block_on(async move {
112                let mut conn = MySqlConnection::connect(&database_url)
113                    .await
114                    .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
115                conn.execute(format!(r#"DROP DATABASE `{dbname}`"#).as_str())
116                    .await
117                    .expect("Error while querying the drop database");
118            });
119        })
120        .join()
121        .expect("failed to drop database");
122    }
123}
124
125impl Default for TestMySql {
126    fn default() -> Self {
127        Self::new(
128            "mysql://root:password@127.0.0.1:3307".to_string(),
129            Path::new("./fixtures/mysql_migrations"),
130        )
131    }
132}
133
/// Splits a MySQL connection URL into its server part and optional database name.
///
/// `mysql://user:pw@host:3306/db?ssl-mode=disabled` yields
/// `("mysql://user:pw@host:3306", Some("db"))`.
///
/// - A missing `mysql://` scheme is tolerated; the returned server part always carries it.
/// - Only the first path segment is taken as the database name; an empty one yields `None`.
/// - Any query string is discarded, so connection options can never leak into the database
///   name or the server part.
fn parse_mysql_url(url: &str) -> (String, Option<String>) {
    let rest = url.trim_start_matches("mysql://");
    let rest = rest.split_once('?').map_or(rest, |(before, _)| before);

    let (host, path) = rest.split_once('/').unwrap_or((rest, ""));
    let server_url = format!("mysql://{host}");

    let dbname = path
        .split('/')
        .next()
        .filter(|segment| !segment.is_empty())
        .map(String::from);

    (server_url, dbname)
}
148
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::env;

    /// Fetches the first `(id, title)` row of `todos`, panicking if the query fails.
    async fn first_todo(pool: &sqlx::MySqlPool) -> (i32, String) {
        sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    #[ignore = "requires MySQL server running on 127.0.0.1:3307"]
    async fn test_mysql_should_create_and_drop() {
        let tdb = TestMySql::default();
        let pool = tdb.get_pool().await;

        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();

        assert_eq!(first_todo(&pool).await, (1, "test".to_string()));
    }

    #[tokio::test]
    #[ignore = "requires MySQL server running on 127.0.0.1:3307"]
    async fn test_mysql_should_load_csv() -> Result<()> {
        let filename = env::current_dir()?.join("fixtures/todos.csv");
        let tdb = TestMySql::default();
        tdb.load_csv("todos", &["title"], &filename).await?;

        let pool = tdb.get_pool().await;
        assert_eq!(first_todo(&pool).await, (1, "hello world".to_string()));
        Ok(())
    }

    #[tokio::test]
    #[ignore = "requires MySQL server running on 127.0.0.1:3307"]
    async fn test_mysql_should_load_csv_data() -> Result<()> {
        let csv = include_str!("../fixtures/todos.csv");
        let tdb = TestMySql::default();
        tdb.load_csv_data("todos", csv).await?;

        let pool = tdb.get_pool().await;
        assert_eq!(first_todo(&pool).await, (1, "hello world".to_string()));
        Ok(())
    }

    #[test]
    fn test_with_dbname() {
        assert_eq!(
            parse_mysql_url("mysql://testuser:1@localhost/testdb"),
            (
                "mysql://testuser:1@localhost".to_string(),
                Some("testdb".to_string())
            )
        );
    }

    #[test]
    fn test_without_dbname() {
        assert_eq!(
            parse_mysql_url("mysql://testuser:1@localhost"),
            ("mysql://testuser:1@localhost".to_string(), None)
        );
    }
}