testkit_mysql/
mysql_async.rs

1#![cfg(feature = "with-mysql-async")]
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use tokio::sync::Mutex;
7
8use mysql_async::{Conn, Opts, Pool, prelude::*};
9use testkit_core::{
10    DatabaseBackend, DatabaseConfig, DatabaseName, DatabasePool, TestDatabaseConnection,
11};
12
13use crate::error::MySqlError;
14
/// A MySQL connection backed by `mysql_async`.
///
/// The underlying [`Conn`] is stored behind an `Arc<Mutex<_>>`, so every clone
/// of this value refers to the same connection handle.
#[derive(Clone)]
pub struct MySqlConnection {
    /// Shared, mutex-guarded handle to the database connection.
    conn: Arc<Mutex<Conn>>,
    /// The connection string recorded when this value was constructed.
    connection_string: String,
}
23
/// A handle for a transaction opened on a shared MySQL connection.
pub struct MySqlTransaction {
    // Shared handle to the connection the transaction statements are sent on.
    conn: Arc<Mutex<Conn>>,
    // Set to `true` by `commit`/`rollback`; checked when the value is dropped.
    completed: bool,
}
31
32impl MySqlTransaction {
33    // Create a new transaction
34    pub(crate) fn new(conn: Arc<Mutex<Conn>>) -> Self {
35        Self {
36            conn,
37            completed: false,
38        }
39    }
40
41    /// Execute a query within this transaction
42    pub async fn execute<Q: AsRef<str>>(
43        &self,
44        query: Q,
45        params: mysql_async::Params,
46    ) -> Result<(), MySqlError> {
47        let mut conn_guard = self.conn.lock().await;
48        conn_guard
49            .exec_drop(query.as_ref(), params)
50            .await
51            .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))?;
52
53        Ok(())
54    }
55
56    /// Commit the transaction
57    pub async fn commit(mut self) -> Result<(), MySqlError> {
58        // Mark the transaction as completed
59        self.completed = true;
60
61        // Commit the transaction
62        let mut conn_guard = self.conn.lock().await;
63        conn_guard
64            .exec_drop("COMMIT", ())
65            .await
66            .map_err(|e| MySqlError::TransactionError(e.to_string()))?;
67
68        Ok(())
69    }
70
71    /// Rollback the transaction
72    pub async fn rollback(mut self) -> Result<(), MySqlError> {
73        // Mark the transaction as completed
74        self.completed = true;
75
76        // Rollback the transaction
77        let mut conn_guard = self.conn.lock().await;
78        conn_guard
79            .exec_drop("ROLLBACK", ())
80            .await
81            .map_err(|e| MySqlError::TransactionError(e.to_string()))?;
82
83        Ok(())
84    }
85}
86
87// Ensure transaction is rolled back if dropped without explicit commit/rollback
88impl Drop for MySqlTransaction {
89    fn drop(&mut self) {
90        if !self.completed {
91            // We need to rollback the transaction if it's not completed
92            // This is a sync function, so we can't use async/await here
93            tracing::warn!("MySQL transaction was not committed or rolled back explicitly");
94        }
95    }
96}
97
98impl MySqlConnection {
99    /// Create a new connection to the database
100    pub async fn connect(connection_string: String) -> Result<Self, MySqlError> {
101        // Create connection options from the URL
102        let opts = Opts::from_url(&connection_string)
103            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
104
105        // Connect to the database
106        let conn = Conn::new(opts)
107            .await
108            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
109
110        Ok(Self {
111            conn: Arc::new(Mutex::new(conn)),
112            connection_string,
113        })
114    }
115
116    /// Get a reference to the connection
117    pub async fn client(&self) -> Arc<Mutex<Conn>> {
118        self.conn.clone()
119    }
120
121    /// Execute a query directly
122    pub async fn query_drop<Q: AsRef<str>>(&self, query: Q) -> Result<(), MySqlError> {
123        let mut conn_guard = self.conn.lock().await;
124        conn_guard
125            .query_drop(query.as_ref())
126            .await
127            .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))
128    }
129
130    /// Execute a parameterized query directly
131    pub async fn exec_drop<Q: AsRef<str>, P: Into<mysql_async::Params> + Send>(
132        &self,
133        query: Q,
134        params: P,
135    ) -> Result<(), MySqlError> {
136        let mut conn_guard = self.conn.lock().await;
137        conn_guard
138            .exec_drop(query.as_ref(), params)
139            .await
140            .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))
141    }
142
143    /// Start a transaction
144    pub async fn begin_transaction(&self) -> Result<MySqlTransaction, MySqlError> {
145        let mut conn_guard = self.conn.lock().await;
146        conn_guard
147            .exec_drop("BEGIN", ())
148            .await
149            .map_err(|e| MySqlError::TransactionError(e.to_string()))?;
150
151        // Create the transaction with a clone of the connection
152        Ok(MySqlTransaction::new(self.conn.clone()))
153    }
154
155    /// Get the connection string
156    pub fn connection_string(&self) -> &str {
157        &self.connection_string
158    }
159
160    /// Execute a query and map the results
161    pub async fn query_map<T, F, Q>(&self, query: Q, f: F) -> Result<Vec<T>, MySqlError>
162    where
163        Q: AsRef<str>,
164        F: FnMut(mysql_async::Row) -> T + Send + 'static,
165        T: Send + 'static,
166    {
167        let mut conn_guard = self.conn.lock().await;
168        conn_guard
169            .query_map(query.as_ref(), f)
170            .await
171            .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))
172    }
173
174    /// Execute a query and return the first result
175    pub async fn query_first<T: FromRow + Send + 'static, Q: AsRef<str>>(
176        &self,
177        query: Q,
178    ) -> Result<T, MySqlError> {
179        let mut conn_guard = self.conn.lock().await;
180        conn_guard
181            .query_first(query.as_ref())
182            .await
183            .map_err(|e| MySqlError::QueryExecutionError(e.to_string()))?
184            .ok_or_else(|| MySqlError::QueryExecutionError("No rows returned".to_string()))
185    }
186
187    /// Select a specific database
188    pub async fn select_database(&self, database_name: &str) -> Result<(), MySqlError> {
189        let use_stmt = format!("USE `{}`", database_name);
190        self.query_drop(use_stmt).await
191    }
192}
193
/// A MySQL connection pool backed by `mysql_async`.
///
/// Cloning shares the same underlying [`Pool`] through the `Arc`.
#[derive(Clone)]
pub struct MySqlPool {
    /// Shared handle to the `mysql_async` pool.
    pub pool: Arc<Pool>,
    /// The connection string associated with this pool.
    pub connection_string: String,
}
202
203#[async_trait]
204impl DatabasePool for MySqlPool {
205    type Connection = MySqlConnection;
206    type Error = MySqlError;
207
208    async fn acquire(&self) -> Result<Self::Connection, Self::Error> {
209        // Get a connection from the pool
210        let conn = self
211            .pool
212            .get_conn()
213            .await
214            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
215
216        // Create the MySqlConnection
217        let mysql_conn = MySqlConnection {
218            conn: Arc::new(Mutex::new(conn)),
219            connection_string: self.connection_string.clone(),
220        };
221
222        // Extract the database name from the connection string and select it
223        if let Some(db_name) = self.connection_string.split('/').last() {
224            if !db_name.is_empty() {
225                mysql_conn.select_database(db_name).await?;
226            }
227        }
228
229        Ok(mysql_conn)
230    }
231
232    async fn release(&self, conn: Self::Connection) -> Result<(), Self::Error> {
233        let _conn_guard = conn.conn.lock().await;
234        // Just drop the connection - the pool will handle returning it
235        // Since MySQL Async connections don't have a close method with no args
236        Ok(())
237    }
238
239    fn connection_string(&self) -> String {
240        self.connection_string.clone()
241    }
242}
243
/// A MySQL database backend built on `mysql_async`.
#[derive(Clone, Debug)]
pub struct MySqlBackend {
    /// Configuration supplying the admin and user connection URLs.
    config: DatabaseConfig,
}
249
250#[async_trait]
251impl DatabaseBackend for MySqlBackend {
252    type Connection = MySqlConnection;
253    type Pool = MySqlPool;
254    type Error = MySqlError;
255
256    async fn new(config: DatabaseConfig) -> Result<Self, Self::Error> {
257        // Validate the config
258        if config.admin_url.is_empty() || config.user_url.is_empty() {
259            return Err(MySqlError::ConfigError(
260                "Admin and user URLs must be provided".into(),
261            ));
262        }
263
264        Ok(Self { config })
265    }
266
267    /// Create a new connection pool for the given database
268    async fn create_pool(
269        &self,
270        name: &DatabaseName,
271        _config: &DatabaseConfig,
272    ) -> Result<Self::Pool, Self::Error> {
273        // Create connection options from the URL
274        let connection_string = self.connection_string(name);
275        let opts = Opts::from_url(&connection_string)
276            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
277
278        let pool = Pool::new(opts);
279
280        Ok(MySqlPool {
281            pool: Arc::new(pool),
282            connection_string,
283        })
284    }
285
286    /// Create a single connection to the given database
287    async fn connect(&self, name: &DatabaseName) -> Result<Self::Connection, Self::Error> {
288        let connection_string = self.connection_string(name);
289        MySqlConnection::connect(connection_string).await
290    }
291
292    /// Create a single connection using a connection string directly
293    async fn connect_with_string(
294        &self,
295        connection_string: &str,
296    ) -> Result<Self::Connection, Self::Error> {
297        MySqlConnection::connect(connection_string.to_string()).await
298    }
299
300    async fn create_database(
301        &self,
302        pool: &Self::Pool,
303        name: &DatabaseName,
304    ) -> Result<(), Self::Error> {
305        // Create admin connection to create the database
306        let opts = Opts::from_url(&self.config.admin_url)
307            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
308
309        let mut conn = Conn::new(opts)
310            .await
311            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
312
313        // Create the database
314        let db_name = name.as_str();
315        let create_query = format!("CREATE DATABASE `{}`", db_name);
316
317        conn.query_drop(create_query)
318            .await
319            .map_err(|e| MySqlError::DatabaseCreationError(e.to_string()))?;
320
321        // Get a connection from the pool and select the database
322        // This ensures connections from this pool will be connected to the right database
323        let pool_conn = pool
324            .pool
325            .get_conn()
326            .await
327            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
328
329        // Create a MySqlConnection to use our select_database method
330        let mysql_conn = MySqlConnection {
331            conn: Arc::new(Mutex::new(pool_conn)),
332            connection_string: pool.connection_string.clone(),
333        };
334
335        // Select the database for all future connections from this pool
336        mysql_conn.select_database(db_name).await?;
337
338        // Release the connection back to the pool
339        drop(mysql_conn);
340
341        Ok(())
342    }
343
344    /// Drop a database with the given name
345    fn drop_database(&self, name: &DatabaseName) -> Result<(), Self::Error> {
346        // For mysql-async, we can't use async directly in a sync function
347        // but we shouldn't create a new runtime either.
348        // Instead, use a background task to drop the database.
349
350        let admin_url = self.config.admin_url.clone();
351        let db_name = name.as_str().to_string();
352
353        // Spawn a detached task to drop the database
354        // This approach allows us to drop the database without blocking
355        // or creating a new runtime within an existing runtime
356        tokio::spawn(async move {
357            // Create admin connection to drop the database
358            match Opts::from_url(&admin_url) {
359                Ok(opts) => {
360                    if let Ok(mut conn) = Conn::new(opts).await {
361                        let drop_query = format!("DROP DATABASE IF EXISTS `{}`", db_name);
362                        let _ = conn.query_drop(drop_query).await;
363                        tracing::info!("Database {} dropped successfully", db_name);
364                    }
365                }
366                Err(e) => {
367                    // Log the error but don't fail the test
368                    tracing::error!(
369                        "Failed to parse MySQL connection URL for database drop: {}",
370                        e
371                    );
372                }
373            }
374        });
375
376        // Return OK immediately - the database drop happens in the background
377        Ok(())
378    }
379
380    fn connection_string(&self, name: &DatabaseName) -> String {
381        // Parse the user URL
382        let mut url = url::Url::parse(&self.config.user_url).expect("Invalid database URL");
383
384        // Update the path to include the database name, replacing any existing database
385        let db_name = name.as_str();
386
387        // Set the path to just the database name, clearing any existing path
388        let mut path_segments = url.path_segments_mut().expect("Cannot modify URL path");
389        path_segments.clear();
390        path_segments.push(db_name);
391        drop(path_segments);
392
393        // Return the full connection string with the new database name
394        url.to_string()
395    }
396}
397
398impl MySqlBackend {
399    /// Clean the database explicitly - this is a blocking call
400    /// This should be called at the end of tests to ensure databases are cleaned up
401    pub async fn clean_database(&self, name: &DatabaseName) -> Result<(), MySqlError> {
402        // Create admin connection to drop the database
403        let opts = Opts::from_url(&self.config.admin_url)
404            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
405
406        let mut conn = Conn::new(opts)
407            .await
408            .map_err(|e| MySqlError::ConnectionError(e.to_string()))?;
409
410        // Drop the database
411        let db_name = name.as_str();
412        let drop_query = format!("DROP DATABASE IF EXISTS `{}`", db_name);
413
414        conn.query_drop(drop_query)
415            .await
416            .map_err(|e| MySqlError::DatabaseDropError(e.to_string()))?;
417
418        tracing::info!("Database {} cleaned up successfully", db_name);
419
420        Ok(())
421    }
422}
423
424/// Helper function to create a MySQL backend with a configuration
425pub async fn mysql_backend_with_config(config: DatabaseConfig) -> Result<MySqlBackend, MySqlError> {
426    MySqlBackend::new(config).await
427}
428
429#[async_trait]
430impl TestDatabaseConnection for MySqlConnection {
431    fn connection_string(&self) -> String {
432        self.connection_string.clone()
433    }
434}
435
/// Trait for MySQL transaction operations.
///
/// Both methods take `self` by value, so a transaction value can be finished
/// at most once through this trait.
#[async_trait]
pub trait MysqlTransaction: Send + Sync {
    /// Commit the transaction, consuming it.
    async fn commit(self) -> Result<(), MySqlError>;

    /// Roll the transaction back, consuming it.
    async fn rollback(self) -> Result<(), MySqlError>;
}
445
/// Generic transaction trait, parameterized over the error type it reports.
///
/// Both methods take `self` by value, so a transaction value can be finished
/// at most once through this trait.
#[async_trait]
#[allow(unused)]
pub trait TransactionTrait: Send + Sync {
    /// Error returned when committing or rolling back fails.
    type Error: std::error::Error + Send + Sync;

    /// Commit the transaction, consuming it.
    async fn commit(self) -> Result<(), Self::Error>;

    /// Roll the transaction back, consuming it.
    async fn rollback(self) -> Result<(), Self::Error>;
}
459
460// Implement MysqlTransaction for MySqlTransaction
461#[async_trait]
462impl MysqlTransaction for MySqlTransaction {
463    async fn commit(self) -> Result<(), MySqlError> {
464        self.commit().await
465    }
466
467    async fn rollback(self) -> Result<(), MySqlError> {
468        self.rollback().await
469    }
470}
471
472// Implement TransactionTrait for MySqlTransaction
473#[async_trait]
474impl TransactionTrait for MySqlTransaction {
475    type Error = MySqlError;
476
477    async fn commit(self) -> Result<(), Self::Error> {
478        MysqlTransaction::commit(self).await
479    }
480
481    async fn rollback(self) -> Result<(), Self::Error> {
482        MysqlTransaction::rollback(self).await
483    }
484}