1#![cfg(feature = "with-mysql-async")]
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use tokio::sync::Mutex;
7
8use mysql_async::{Conn, Opts, Pool, prelude::*};
9use testkit_core::{
10 DatabaseBackend, DatabaseConfig, DatabaseName, DatabasePool, TestDatabaseConnection,
11};
12
13use crate::error::MySqlError;
14
15#[derive(Clone)]
17pub struct MySqlConnection {
18 conn: Arc<Mutex<Conn>>,
20 connection_string: String,
22}
23
24pub struct MySqlTransaction {
26 conn: Arc<Mutex<Conn>>,
28 completed: bool,
30}
31
32impl MySqlTransaction {
33 pub(crate) fn new(conn: Arc<Mutex<Conn>>) -> Self {
35 Self {
36 conn,
37 completed: false,
38 }
39 }
40
41 pub async fn execute<Q: AsRef<str>>(
43 &self,
44 query: Q,
45 params: mysql_async::Params,
46 ) -> Result<(), MySqlError> {
47 let mut conn_guard = self.conn.lock().await;
48 conn_guard
49 .exec_drop(query.as_ref(), params)
50 .await
51 .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))?;
52
53 Ok(())
54 }
55
56 pub async fn commit(mut self) -> Result<(), MySqlError> {
58 self.completed = true;
60
61 let mut conn_guard = self.conn.lock().await;
63 conn_guard
64 .exec_drop("COMMIT", ())
65 .await
66 .map_err(|e| MySqlError::TransactionError(e.to_string()))?;
67
68 Ok(())
69 }
70
71 pub async fn rollback(mut self) -> Result<(), MySqlError> {
73 self.completed = true;
75
76 let mut conn_guard = self.conn.lock().await;
78 conn_guard
79 .exec_drop("ROLLBACK", ())
80 .await
81 .map_err(|e| MySqlError::TransactionError(e.to_string()))?;
82
83 Ok(())
84 }
85}
86
87impl Drop for MySqlTransaction {
89 fn drop(&mut self) {
90 if !self.completed {
91 tracing::warn!("MySQL transaction was not committed or rolled back explicitly");
94 }
95 }
96}
97
98impl MySqlConnection {
99 pub async fn connect(connection_string: String) -> Result<Self, MySqlError> {
101 let opts = Opts::from_url(&connection_string)
103 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
104
105 let conn = Conn::new(opts)
107 .await
108 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
109
110 Ok(Self {
111 conn: Arc::new(Mutex::new(conn)),
112 connection_string,
113 })
114 }
115
116 pub async fn client(&self) -> Arc<Mutex<Conn>> {
118 self.conn.clone()
119 }
120
121 pub async fn query_drop<Q: AsRef<str>>(&self, query: Q) -> Result<(), MySqlError> {
123 let mut conn_guard = self.conn.lock().await;
124 conn_guard
125 .query_drop(query.as_ref())
126 .await
127 .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))
128 }
129
130 pub async fn exec_drop<Q: AsRef<str>, P: Into<mysql_async::Params> + Send>(
132 &self,
133 query: Q,
134 params: P,
135 ) -> Result<(), MySqlError> {
136 let mut conn_guard = self.conn.lock().await;
137 conn_guard
138 .exec_drop(query.as_ref(), params)
139 .await
140 .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))
141 }
142
143 pub async fn begin_transaction(&self) -> Result<MySqlTransaction, MySqlError> {
145 let mut conn_guard = self.conn.lock().await;
146 conn_guard
147 .exec_drop("BEGIN", ())
148 .await
149 .map_err(|e| MySqlError::TransactionError(e.to_string()))?;
150
151 Ok(MySqlTransaction::new(self.conn.clone()))
153 }
154
155 pub fn connection_string(&self) -> &str {
157 &self.connection_string
158 }
159
160 pub async fn query_map<T, F, Q>(&self, query: Q, f: F) -> Result<Vec<T>, MySqlError>
162 where
163 Q: AsRef<str>,
164 F: FnMut(mysql_async::Row) -> T + Send + 'static,
165 T: Send + 'static,
166 {
167 let mut conn_guard = self.conn.lock().await;
168 conn_guard
169 .query_map(query.as_ref(), f)
170 .await
171 .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))
172 }
173
174 pub async fn query_first<T: FromRow + Send + 'static, Q: AsRef<str>>(
176 &self,
177 query: Q,
178 ) -> Result<T, MySqlError> {
179 let mut conn_guard = self.conn.lock().await;
180 conn_guard
181 .query_first(query.as_ref())
182 .await
183 .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))?
184 .ok_or_else(|| MySqlError::QueryExecutionError("No rows returned".to_string()))
185 }
186
187 pub async fn select_database(&self, database_name: &str) -> Result<(), MySqlError> {
189 let use_stmt = format!("USE `{}`", database_name);
190 self.query_drop(use_stmt).await
191 }
192}
193
194#[derive(Clone)]
196pub struct MySqlPool {
197 pub pool: Arc<Pool>,
199 pub connection_string: String,
201}
202
203#[async_trait]
204impl DatabasePool for MySqlPool {
205 type Connection = MySqlConnection;
206 type Error = MySqlError;
207
208 async fn acquire(&self) -> Result<Self::Connection, Self::Error> {
209 let conn = self
211 .pool
212 .get_conn()
213 .await
214 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
215
216 let mysql_conn = MySqlConnection {
218 conn: Arc::new(Mutex::new(conn)),
219 connection_string: self.connection_string.clone(),
220 };
221
222 if let Some(db_name) = self.connection_string.split('/').last() {
224 if !db_name.is_empty() {
225 mysql_conn.select_database(db_name).await?;
226 }
227 }
228
229 Ok(mysql_conn)
230 }
231
232 async fn release(&self, conn: Self::Connection) -> Result<(), Self::Error> {
233 let _conn_guard = conn.conn.lock().await;
234 Ok(())
237 }
238
239 fn connection_string(&self) -> String {
240 self.connection_string.clone()
241 }
242}
243
244#[derive(Clone, Debug)]
246pub struct MySqlBackend {
247 config: DatabaseConfig,
248}
249
250#[async_trait]
251impl DatabaseBackend for MySqlBackend {
252 type Connection = MySqlConnection;
253 type Pool = MySqlPool;
254 type Error = MySqlError;
255
256 async fn new(config: DatabaseConfig) -> Result<Self, Self::Error> {
257 if config.admin_url.is_empty() || config.user_url.is_empty() {
259 return Err(MySqlError::ConfigError(
260 "Admin and user URLs must be provided".into(),
261 ));
262 }
263
264 Ok(Self { config })
265 }
266
267 async fn create_pool(
269 &self,
270 name: &DatabaseName,
271 _config: &DatabaseConfig,
272 ) -> Result<Self::Pool, Self::Error> {
273 let connection_string = self.connection_string(name);
275 let opts = Opts::from_url(&connection_string)
276 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
277
278 let pool = Pool::new(opts);
279
280 Ok(MySqlPool {
281 pool: Arc::new(pool),
282 connection_string,
283 })
284 }
285
286 async fn connect(&self, name: &DatabaseName) -> Result<Self::Connection, Self::Error> {
288 let connection_string = self.connection_string(name);
289 MySqlConnection::connect(connection_string).await
290 }
291
292 async fn connect_with_string(
294 &self,
295 connection_string: &str,
296 ) -> Result<Self::Connection, Self::Error> {
297 MySqlConnection::connect(connection_string.to_string()).await
298 }
299
300 async fn create_database(
301 &self,
302 pool: &Self::Pool,
303 name: &DatabaseName,
304 ) -> Result<(), Self::Error> {
305 let opts = Opts::from_url(&self.config.admin_url)
307 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
308
309 let mut conn = Conn::new(opts)
310 .await
311 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
312
313 let db_name = name.as_str();
315 let create_query = format!("CREATE DATABASE `{}`", db_name);
316
317 conn.query_drop(create_query)
318 .await
319 .map_err(|e| MySqlError::DatabaseCreationError(e.to_string()))?;
320
321 let pool_conn = pool
324 .pool
325 .get_conn()
326 .await
327 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
328
329 let mysql_conn = MySqlConnection {
331 conn: Arc::new(Mutex::new(pool_conn)),
332 connection_string: pool.connection_string.clone(),
333 };
334
335 mysql_conn.select_database(db_name).await?;
337
338 drop(mysql_conn);
340
341 Ok(())
342 }
343
344 fn drop_database(&self, name: &DatabaseName) -> Result<(), Self::Error> {
346 let admin_url = self.config.admin_url.clone();
351 let db_name = name.as_str().to_string();
352
353 tokio::spawn(async move {
357 match Opts::from_url(&admin_url) {
359 Ok(opts) => {
360 if let Ok(mut conn) = Conn::new(opts).await {
361 let drop_query = format!("DROP DATABASE IF EXISTS `{}`", db_name);
362 let _ = conn.query_drop(drop_query).await;
363 tracing::info!("Database {} dropped successfully", db_name);
364 }
365 }
366 Err(e) => {
367 tracing::error!(
369 "Failed to parse MySQL connection URL for database drop: {}",
370 e
371 );
372 }
373 }
374 });
375
376 Ok(())
378 }
379
380 fn connection_string(&self, name: &DatabaseName) -> String {
381 let mut url = url::Url::parse(&self.config.user_url).expect("Invalid database URL");
383
384 let db_name = name.as_str();
386
387 let mut path_segments = url.path_segments_mut().expect("Cannot modify URL path");
389 path_segments.clear();
390 path_segments.push(db_name);
391 drop(path_segments);
392
393 url.to_string()
395 }
396}
397
398impl MySqlBackend {
399 pub async fn clean_database(&self, name: &DatabaseName) -> Result<(), MySqlError> {
402 let opts = Opts::from_url(&self.config.admin_url)
404 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
405
406 let mut conn = Conn::new(opts)
407 .await
408 .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
409
410 let db_name = name.as_str();
412 let drop_query = format!("DROP DATABASE IF EXISTS `{}`", db_name);
413
414 conn.query_drop(drop_query)
415 .await
416 .map_err(|e| MySqlError::DatabaseDropError(e.to_string()))?;
417
418 tracing::info!("Database {} cleaned up successfully", db_name);
419
420 Ok(())
421 }
422}
423
424pub async fn mysql_backend_with_config(config: DatabaseConfig) -> Result<MySqlBackend, MySqlError> {
426 MySqlBackend::new(config).await
427}
428
429#[async_trait]
430impl TestDatabaseConnection for MySqlConnection {
431 fn connection_string(&self) -> String {
432 self.connection_string.clone()
433 }
434}
435
436#[async_trait]
438pub trait MysqlTransaction: Send + Sync {
439 async fn commit(self) -> Result<(), MySqlError>;
441
442 async fn rollback(self) -> Result<(), MySqlError>;
444}
445
446#[async_trait]
448#[allow(unused)]
449pub trait TransactionTrait: Send + Sync {
450 type Error: std::error::Error + Send + Sync;
452
453 async fn commit(self) -> Result<(), Self::Error>;
455
456 async fn rollback(self) -> Result<(), Self::Error>;
458}
459
460#[async_trait]
462impl MysqlTransaction for MySqlTransaction {
463 async fn commit(self) -> Result<(), MySqlError> {
464 self.commit().await
465 }
466
467 async fn rollback(self) -> Result<(), MySqlError> {
468 self.rollback().await
469 }
470}
471
472#[async_trait]
474impl TransactionTrait for MySqlTransaction {
475 type Error = MySqlError;
476
477 async fn commit(self) -> Result<(), Self::Error> {
478 MysqlTransaction::commit(self).await
479 }
480
481 async fn rollback(self) -> Result<(), Self::Error> {
482 MysqlTransaction::rollback(self).await
483 }
484}