sea_orm/driver/
sqlx_mysql.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, MySql, MySqlPool,
8    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9    pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18    executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Connector for MySQL databases backed by [sqlx::mysql].
///
/// This is a unit struct: it holds no state and only serves as the namespace for the
/// associated functions implemented on it.
#[derive(Debug)]
pub struct SqlxMySqlConnector;
26
27/// Defines a sqlx MySQL pool
28#[derive(Clone)]
29pub struct SqlxMySqlPoolConnection {
30    pub(crate) pool: MySqlPool,
31    metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxMySqlPoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<MySqlPool> for SqlxMySqlPoolConnection {
41    fn from(pool: MySqlPool) -> Self {
42        SqlxMySqlPoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<MySqlPool> for DatabaseConnection {
50    fn from(pool: MySqlPool) -> Self {
51        DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
52    }
53}
54
55impl SqlxMySqlConnector {
56    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the MySQL database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut sqlx_opts = options
65            .url
66            .parse::<MySqlConnectOptions>()
67            .map_err(sqlx_error_to_conn_err)?;
68        use sqlx::ConnectOptions;
69        if !options.sqlx_logging {
70            sqlx_opts = sqlx_opts.disable_statement_logging();
71        } else {
72            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
73            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
74                sqlx_opts = sqlx_opts.log_slow_statements(
75                    options.sqlx_slow_statements_logging_level,
76                    options.sqlx_slow_statements_logging_threshold,
77                );
78            }
79        }
80        if let Some(f) = &options.mysql_opts_fn {
81            sqlx_opts = f(sqlx_opts);
82        }
83        let pool = if options.connect_lazy {
84            options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
85        } else {
86            options
87                .sqlx_pool_options()
88                .connect_with(sqlx_opts)
89                .await
90                .map_err(sqlx_error_to_conn_err)?
91        };
92        Ok(
93            DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
94                pool,
95                metric_callback: None,
96            })
97            .into(),
98        )
99    }
100}
101
102impl SqlxMySqlConnector {
103    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
104    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
105        DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
106            pool,
107            metric_callback: None,
108        })
109        .into()
110    }
111}
112
113impl SqlxMySqlPoolConnection {
114    /// Execute a [Statement] on a MySQL backend
115    #[instrument(level = "trace")]
116    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
117        debug_print!("{}", stmt);
118
119        let query = sqlx_query(&stmt);
120        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
121        crate::metric::metric!(self.metric_callback, &stmt, {
122            match query.execute(&mut *conn).await {
123                Ok(res) => Ok(res.into()),
124                Err(err) => Err(sqlx_error_to_exec_err(err)),
125            }
126        })
127    }
128
129    /// Execute an unprepared SQL statement on a MySQL backend
130    #[instrument(level = "trace")]
131    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
132        debug_print!("{}", sql);
133
134        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
135        match conn.execute(sql).await {
136            Ok(res) => Ok(res.into()),
137            Err(err) => Err(sqlx_error_to_exec_err(err)),
138        }
139    }
140
141    /// Get one result from a SQL query. Returns [Option::None] if no match was found
142    #[instrument(level = "trace")]
143    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
144        debug_print!("{}", stmt);
145
146        let query = sqlx_query(&stmt);
147        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
148        crate::metric::metric!(self.metric_callback, &stmt, {
149            match query.fetch_one(&mut *conn).await {
150                Ok(row) => Ok(Some(row.into())),
151                Err(err) => match err {
152                    sqlx::Error::RowNotFound => Ok(None),
153                    _ => Err(sqlx_error_to_query_err(err)),
154                },
155            }
156        })
157    }
158
159    /// Get the results of a query returning them as a Vec<[QueryResult]>
160    #[instrument(level = "trace")]
161    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
162        debug_print!("{}", stmt);
163
164        let query = sqlx_query(&stmt);
165        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
166        crate::metric::metric!(self.metric_callback, &stmt, {
167            match query.fetch_all(&mut *conn).await {
168                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
169                Err(err) => Err(sqlx_error_to_query_err(err)),
170            }
171        })
172    }
173
174    /// Stream the results of executing a SQL query
175    #[instrument(level = "trace")]
176    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
177        debug_print!("{}", stmt);
178
179        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
180        Ok(QueryStream::from((
181            conn,
182            stmt,
183            self.metric_callback.clone(),
184        )))
185    }
186
187    /// Bundle a set of SQL statements that execute together.
188    #[instrument(level = "trace")]
189    pub async fn begin(
190        &self,
191        isolation_level: Option<IsolationLevel>,
192        access_mode: Option<AccessMode>,
193    ) -> Result<DatabaseTransaction, DbErr> {
194        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
195        DatabaseTransaction::new_mysql(
196            conn,
197            self.metric_callback.clone(),
198            isolation_level,
199            access_mode,
200        )
201        .await
202    }
203
204    /// Create a MySQL transaction
205    #[instrument(level = "trace", skip(callback))]
206    pub async fn transaction<F, T, E>(
207        &self,
208        callback: F,
209        isolation_level: Option<IsolationLevel>,
210        access_mode: Option<AccessMode>,
211    ) -> Result<T, TransactionError<E>>
212    where
213        F: for<'b> FnOnce(
214                &'b DatabaseTransaction,
215            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
216            + Send,
217        T: Send,
218        E: std::fmt::Display + std::fmt::Debug + Send,
219    {
220        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
221        let transaction = DatabaseTransaction::new_mysql(
222            conn,
223            self.metric_callback.clone(),
224            isolation_level,
225            access_mode,
226        )
227        .await
228        .map_err(|e| TransactionError::Connection(e))?;
229        transaction.run(callback).await
230    }
231
232    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
233    where
234        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
235    {
236        self.metric_callback = Some(Arc::new(callback));
237    }
238
239    /// Checks if a connection to the database is still valid.
240    pub async fn ping(&self) -> Result<(), DbErr> {
241        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
242        match conn.ping().await {
243            Ok(_) => Ok(()),
244            Err(err) => Err(sqlx_error_to_conn_err(err)),
245        }
246    }
247
248    /// Explicitly close the MySQL connection.
249    /// See [`Self::close_by_ref`] for usage with references.
250    pub async fn close(self) -> Result<(), DbErr> {
251        self.close_by_ref().await
252    }
253
254    /// Explicitly close the MySQL connection
255    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
256        self.pool.close().await;
257        Ok(())
258    }
259}
260
261impl From<MySqlRow> for QueryResult {
262    fn from(row: MySqlRow) -> QueryResult {
263        QueryResult {
264            row: QueryResultRow::SqlxMySql(row),
265        }
266    }
267}
268
269impl From<MySqlQueryResult> for ExecResult {
270    fn from(result: MySqlQueryResult) -> ExecResult {
271        ExecResult {
272            result: ExecResultHolder::SqlxMySql(result),
273        }
274    }
275}
276
277pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
278    let values = stmt
279        .values
280        .as_ref()
281        .map_or(Values(Vec::new()), |values| values.clone());
282    sqlx::query_with(&stmt.sql, SqlxValues(values))
283}
284
285pub(crate) async fn set_transaction_config(
286    conn: &mut PoolConnection<MySql>,
287    isolation_level: Option<IsolationLevel>,
288    access_mode: Option<AccessMode>,
289) -> Result<(), DbErr> {
290    let mut settings = Vec::new();
291
292    if let Some(isolation_level) = isolation_level {
293        settings.push(format!("ISOLATION LEVEL {isolation_level}"));
294    }
295
296    if let Some(access_mode) = access_mode {
297        settings.push(access_mode.to_string());
298    }
299
300    if !settings.is_empty() {
301        let stmt = Statement {
302            sql: format!("SET TRANSACTION {}", settings.join(", ")),
303            values: None,
304            db_backend: DbBackend::MySql,
305        };
306        let query = sqlx_query(&stmt);
307        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
308    }
309    Ok(())
310}
311
312impl
313    From<(
314        PoolConnection<sqlx::MySql>,
315        Statement,
316        Option<crate::metric::Callback>,
317    )> for crate::QueryStream
318{
319    fn from(
320        (conn, stmt, metric_callback): (
321            PoolConnection<sqlx::MySql>,
322            Statement,
323            Option<crate::metric::Callback>,
324        ),
325    ) -> Self {
326        crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
327    }
328}
329
330impl crate::DatabaseTransaction {
331    pub(crate) async fn new_mysql(
332        inner: PoolConnection<sqlx::MySql>,
333        metric_callback: Option<crate::metric::Callback>,
334        isolation_level: Option<IsolationLevel>,
335        access_mode: Option<AccessMode>,
336    ) -> Result<crate::DatabaseTransaction, DbErr> {
337        Self::begin(
338            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
339            crate::DbBackend::MySql,
340            metric_callback,
341            isolation_level,
342            access_mode,
343        )
344        .await
345    }
346}
347
/// Converts a sqlx MySQL row into a [`crate::ProxyRow`].
///
/// Every column becomes a `(column name, sea_query::Value)` pair. The `Value` variant is
/// picked from the column's sqlx type name (for example `"INT"` maps to `Value::Int` and
/// `"VARCHAR"` to `Value::String`). Date/time, decimal and JSON columns are only handled
/// when the corresponding crate feature is enabled; where two features offer a mapping
/// for the same type, `with-chrono` wins over `with-time` and `with-bigdecimal` wins over
/// `with-rust_decimal`.
///
/// # Panics
///
/// - if a column value cannot be decoded into the Rust type chosen for its type name
///   (each `try_get` result is unwrapped with `expect`);
/// - if the column's type name matches none of the handled names.
///
/// NOTE(review): values are decoded as non-`Option` types and always wrapped in `Some`,
/// so a SQL `NULL` presumably ends up in the first panic case -- confirm against sqlx.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html
    // https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    let values = row
        .columns()
        .iter()
        .map(|column| {
            let name = column.name().to_string();
            let idx = column.ordinal();
            let type_name = column.type_info().name();

            let value = match type_name {
                "TINYINT(1)" | "BOOLEAN" => {
                    Value::Bool(Some(row.try_get(idx).expect("Failed to get boolean")))
                }
                "TINYINT UNSIGNED" => Value::TinyUnsigned(Some(
                    row.try_get(idx)
                        .expect("Failed to get unsigned tiny integer"),
                )),
                "SMALLINT UNSIGNED" => Value::SmallUnsigned(Some(
                    row.try_get(idx)
                        .expect("Failed to get unsigned small integer"),
                )),
                "INT UNSIGNED" => Value::Unsigned(Some(
                    row.try_get(idx).expect("Failed to get unsigned integer"),
                )),
                "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(Some(
                    row.try_get(idx)
                        .expect("Failed to get unsigned big integer"),
                )),
                "TINYINT" => {
                    Value::TinyInt(Some(row.try_get(idx).expect("Failed to get tiny integer")))
                }
                "SMALLINT" => {
                    Value::SmallInt(Some(row.try_get(idx).expect("Failed to get small integer")))
                }
                "INT" => Value::Int(Some(row.try_get(idx).expect("Failed to get integer"))),
                "MEDIUMINT" | "BIGINT" => {
                    Value::BigInt(Some(row.try_get(idx).expect("Failed to get big integer")))
                }
                "FLOAT" => Value::Float(Some(row.try_get(idx).expect("Failed to get float"))),
                "DOUBLE" => Value::Double(Some(row.try_get(idx).expect("Failed to get double"))),

                "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                | "LONGBLOB" => Value::Bytes(Some(Box::new(
                    row.try_get(idx).expect("Failed to get bytes"),
                ))),

                "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                    Value::String(Some(Box::new(
                        row.try_get(idx).expect("Failed to get string"),
                    )))
                }

                #[cfg(feature = "with-chrono")]
                "TIMESTAMP" => Value::ChronoDateTimeUtc(Some(Box::new(
                    row.try_get(idx).expect("Failed to get timestamp"),
                ))),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIMESTAMP" => Value::TimeDateTime(Some(Box::new(
                    row.try_get(idx).expect("Failed to get timestamp"),
                ))),

                #[cfg(feature = "with-chrono")]
                "DATE" => Value::ChronoDate(Some(Box::new(
                    row.try_get(idx).expect("Failed to get date"),
                ))),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATE" => Value::TimeDate(Some(Box::new(
                    row.try_get(idx).expect("Failed to get date"),
                ))),

                #[cfg(feature = "with-chrono")]
                "TIME" => Value::ChronoTime(Some(Box::new(
                    row.try_get(idx).expect("Failed to get time"),
                ))),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIME" => Value::TimeTime(Some(Box::new(
                    row.try_get(idx).expect("Failed to get time"),
                ))),

                #[cfg(feature = "with-chrono")]
                "DATETIME" => Value::ChronoDateTime(Some(Box::new(
                    row.try_get(idx).expect("Failed to get datetime"),
                ))),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATETIME" => Value::TimeDateTime(Some(Box::new(
                    row.try_get(idx).expect("Failed to get datetime"),
                ))),

                #[cfg(feature = "with-chrono")]
                "YEAR" => Value::ChronoDate(Some(Box::new(
                    row.try_get(idx).expect("Failed to get year"),
                ))),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "YEAR" => Value::TimeDate(Some(Box::new(
                    row.try_get(idx).expect("Failed to get year"),
                ))),

                "ENUM" | "SET" | "GEOMETRY" => Value::String(Some(Box::new(
                    row.try_get(idx).expect("Failed to get serialized string"),
                ))),

                #[cfg(feature = "with-bigdecimal")]
                "DECIMAL" => Value::BigDecimal(Some(Box::new(
                    row.try_get(idx).expect("Failed to get decimal"),
                ))),
                #[cfg(all(
                    feature = "with-rust_decimal",
                    not(feature = "with-bigdecimal")
                ))]
                "DECIMAL" => Value::Decimal(Some(Box::new(
                    row.try_get(idx).expect("Failed to get decimal"),
                ))),

                #[cfg(feature = "with-json")]
                "JSON" => Value::Json(Some(Box::new(
                    row.try_get(idx).expect("Failed to get json"),
                ))),

                _ => unreachable!("Unknown column type: {}", type_name),
            };

            (name, value)
        })
        .collect();

    crate::ProxyRow { values }
}