1use anyhow::Result;
2use itertools::Itertools;
3use sqlx::{
4 Connection, Executor, MySqlConnection, MySqlPool,
5 migrate::{MigrationSource, Migrator},
6};
7use std::{path::Path, thread};
8use tokio::runtime::Runtime;
9use uuid::Uuid;
10
/// A throwaway MySQL database for tests.
///
/// Built by [`TestMySql::new`] (or `Default`) with a unique name on the target server;
/// dropping the value drops that database again.
#[derive(Debug)]
pub struct TestMySql {
    /// Server URL without a database path, e.g. `mysql://user:pass@host:port`.
    pub server_url: String,
    /// Name of the temporary database created on that server.
    pub dbname: String,
}
16
17impl TestMySql {
18 pub fn new<S>(database_url: String, migrations: S) -> Self
19 where
20 S: MigrationSource<'static> + Send + Sync + 'static,
21 {
22 let simple = Uuid::new_v4().simple();
23 let (server_url, dbname) = parse_mysql_url(&database_url);
24 let dbname = match dbname {
25 Some(db_name) => format!("{db_name}_test_{simple}"),
26 None => format!("test_{simple}"),
27 };
28 let dbname_cloned = dbname.clone();
29 let server_url_cloned = server_url.clone();
30
31 let tdb = Self { server_url, dbname };
32
33 let url = tdb.url();
34
35 thread::spawn(move || {
37 let rt = Runtime::new().unwrap();
38 rt.block_on(async move {
39 let create_db_url = format!("{server_url_cloned}/mysql");
42 let mut conn = MySqlConnection::connect(&create_db_url)
43 .await
44 .unwrap_or_else(|_| panic!("Error while connecting to {create_db_url}"));
45 conn.execute(format!(r#"CREATE DATABASE `{dbname_cloned}`"#).as_str())
46 .await
47 .unwrap();
48
49 let mut conn = MySqlConnection::connect(&url)
51 .await
52 .unwrap_or_else(|_| panic!("Error while connecting to {}", &url));
53 let m = Migrator::new(migrations).await.unwrap();
54 m.run(&mut conn).await.unwrap();
55 });
56 })
57 .join()
58 .expect("failed to create database");
59
60 tdb
61 }
62
63 pub fn server_url(&self) -> String {
64 self.server_url.clone()
65 }
66
67 pub fn url(&self) -> String {
68 format!("{}/{}", self.server_url, self.dbname)
69 }
70
71 pub async fn get_pool(&self) -> MySqlPool {
72 let url = self.url();
73 MySqlPool::connect(&url)
74 .await
75 .unwrap_or_else(|_| panic!("Error while connecting to {url}"))
76 }
77
78 pub async fn load_csv(&self, table: &str, _fields: &[&str], filename: &Path) -> Result<()> {
79 let csv_content = std::fs::read_to_string(filename)?;
82 self.load_csv_data(table, &csv_content).await
83 }
84
85 pub async fn load_csv_data(&self, table: &str, csv: &str) -> Result<()> {
86 let mut rdr = csv::Reader::from_reader(csv.as_bytes());
87 let headers = rdr.headers()?.iter().join(",");
88 let mut tx = self.get_pool().await.begin().await?;
89 for result in rdr.records() {
90 let record = result?;
91 let sql = format!(
92 "INSERT INTO {} ({}) VALUES ({})",
93 table,
94 headers,
95 record.iter().map(|v| format!("'{v}'")).join(",")
96 );
97 tx.execute(sql.as_str()).await?;
98 }
99 tx.commit().await?;
100 Ok(())
101 }
102}
103
104impl Drop for TestMySql {
105 fn drop(&mut self) {
106 let server_url = &self.server_url;
107 let database_url = format!("{server_url}/mysql");
108 let dbname = self.dbname.clone();
109 thread::spawn(move || {
110 let rt = Runtime::new().unwrap();
111 rt.block_on(async move {
112 let mut conn = MySqlConnection::connect(&database_url)
113 .await
114 .unwrap_or_else(|_| panic!("Error while connecting to {database_url}"));
115 conn.execute(format!(r#"DROP DATABASE `{dbname}`"#).as_str())
116 .await
117 .expect("Error while querying the drop database");
118 });
119 })
120 .join()
121 .expect("failed to drop database");
122 }
123}
124
125impl Default for TestMySql {
126 fn default() -> Self {
127 Self::new(
128 "mysql://root:password@127.0.0.1:3307".to_string(),
129 Path::new("./fixtures/mysql_migrations"),
130 )
131 }
132}
133
/// Splits a MySQL connection URL into its server part and optional database name.
///
/// `"mysql://user:pass@host:3306/app"` becomes
/// `("mysql://user:pass@host:3306", Some("app"))`. A URL with no path, or an empty
/// one, yields `None` for the database name. The `mysql://` scheme is optional on
/// input, but the returned server URL always carries it.
///
/// Any query string (`?key=value...`) is discarded. Otherwise it would end up inside
/// the database name that `TestMySql` derives and appends to. As a consequence, such
/// connection options are not carried over to the returned server URL.
fn parse_mysql_url(url: &str) -> (String, Option<String>) {
    let rest = url.trim_start_matches("mysql://");
    let rest = rest.split_once('?').map_or(rest, |(before, _)| before);

    let (authority, path) = rest.split_once('/').unwrap_or((rest, ""));
    let server_url = format!("mysql://{authority}");

    // Only the first path segment names the database.
    let dbname = path
        .split('/')
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_string);

    (server_url, dbname)
}
148
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::env;

    /// Fetches the first `(id, title)` row of `todos`, panicking if the query fails.
    async fn first_todo(pool: &sqlx::MySqlPool) -> (i32, String) {
        sqlx::query_as::<_, (i32, String)>("SELECT id, title FROM todos")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    #[ignore = "requires MySQL server running on 127.0.0.1:3307"]
    async fn test_mysql_should_create_and_drop() {
        let db = TestMySql::default();
        let pool = db.get_pool().await;
        sqlx::query("INSERT INTO todos (title) VALUES ('test')")
            .execute(&pool)
            .await
            .unwrap();
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "test");
    }

    #[tokio::test]
    #[ignore = "requires MySQL server running on 127.0.0.1:3307"]
    async fn test_mysql_should_load_csv() -> Result<()> {
        let filename = env::current_dir()?.join("fixtures/todos.csv");
        let db = TestMySql::default();
        db.load_csv("todos", &["title"], &filename).await?;
        let pool = db.get_pool().await;
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[tokio::test]
    #[ignore = "requires MySQL server running on 127.0.0.1:3307"]
    async fn test_mysql_should_load_csv_data() -> Result<()> {
        let csv = include_str!("../fixtures/todos.csv");
        let db = TestMySql::default();
        db.load_csv_data("todos", csv).await?;
        let pool = db.get_pool().await;
        let (id, title) = first_todo(&pool).await;
        assert_eq!(id, 1);
        assert_eq!(title, "hello world");
        Ok(())
    }

    #[test]
    fn test_with_dbname() {
        let (server_url, dbname) = parse_mysql_url("mysql://testuser:1@localhost/testdb");
        assert_eq!(server_url, "mysql://testuser:1@localhost");
        assert_eq!(dbname.as_deref(), Some("testdb"));
    }

    #[test]
    fn test_without_dbname() {
        let (server_url, dbname) = parse_mysql_url("mysql://testuser:1@localhost");
        assert_eq!(server_url, "mysql://testuser:1@localhost");
        assert!(dbname.is_none());
    }
}