sea_orm/driver/
sqlx_mysql.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, MySql, MySqlPool,
8    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9    pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18    executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Defines the [sqlx::mysql] connector
///
/// A field-less marker type. Its associated functions check whether a URL
/// targets MySQL (`accepts`) and build a [DatabaseConnection] backed by a sqlx
/// MySQL pool (`connect`, `from_sqlx_mysql_pool`).
#[derive(Debug)]
pub struct SqlxMySqlConnector;
26
/// Defines a sqlx MySQL pool
///
/// Every query issued through this type acquires a connection from `pool`.
#[derive(Clone)]
pub struct SqlxMySqlPoolConnection {
    /// The sqlx pool that connections are acquired from for each operation.
    pub(crate) pool: MySqlPool,
    /// Optional metric hook forwarded to `metric!`, streams and transactions.
    /// Every constructor in this module sets it to `None`; it only becomes
    /// `Some` through `set_metric_callback`.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxMySqlPoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<MySqlPool> for SqlxMySqlPoolConnection {
41    fn from(pool: MySqlPool) -> Self {
42        SqlxMySqlPoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<MySqlPool> for DatabaseConnection {
50    fn from(pool: MySqlPool) -> Self {
51        DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
52    }
53}
54
55impl SqlxMySqlConnector {
56    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the MySQL database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut sqlx_opts = options
65            .url
66            .parse::<MySqlConnectOptions>()
67            .map_err(sqlx_error_to_conn_err)?;
68        use sqlx::ConnectOptions;
69        if !options.sqlx_logging {
70            sqlx_opts = sqlx_opts.disable_statement_logging();
71        } else {
72            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
73            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
74                sqlx_opts = sqlx_opts.log_slow_statements(
75                    options.sqlx_slow_statements_logging_level,
76                    options.sqlx_slow_statements_logging_threshold,
77                );
78            }
79        }
80
81        if let Some(f) = &options.mysql_opts_fn {
82            sqlx_opts = f(sqlx_opts);
83        }
84
85        let after_connect = options.after_connect.clone();
86
87        let pool = if options.connect_lazy {
88            options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
89        } else {
90            options
91                .sqlx_pool_options()
92                .connect_with(sqlx_opts)
93                .await
94                .map_err(sqlx_error_to_conn_err)?
95        };
96
97        let conn: DatabaseConnection =
98            DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
99                pool,
100                metric_callback: None,
101            })
102            .into();
103
104        if let Some(cb) = after_connect {
105            cb(conn.clone()).await?;
106        }
107
108        Ok(conn)
109    }
110}
111
112impl SqlxMySqlConnector {
113    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
114    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
115        DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
116            pool,
117            metric_callback: None,
118        })
119        .into()
120    }
121}
122
123impl SqlxMySqlPoolConnection {
124    /// Execute a [Statement] on a MySQL backend
125    #[instrument(level = "trace")]
126    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
127        debug_print!("{}", stmt);
128
129        let query = sqlx_query(&stmt);
130        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
131        crate::metric::metric!(self.metric_callback, &stmt, {
132            match query.execute(&mut *conn).await {
133                Ok(res) => Ok(res.into()),
134                Err(err) => Err(sqlx_error_to_exec_err(err)),
135            }
136        })
137    }
138
139    /// Execute an unprepared SQL statement on a MySQL backend
140    #[instrument(level = "trace")]
141    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
142        debug_print!("{}", sql);
143
144        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
145        match conn.execute(sql).await {
146            Ok(res) => Ok(res.into()),
147            Err(err) => Err(sqlx_error_to_exec_err(err)),
148        }
149    }
150
151    /// Get one result from a SQL query. Returns [Option::None] if no match was found
152    #[instrument(level = "trace")]
153    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
154        debug_print!("{}", stmt);
155
156        let query = sqlx_query(&stmt);
157        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
158        crate::metric::metric!(self.metric_callback, &stmt, {
159            match query.fetch_one(&mut *conn).await {
160                Ok(row) => Ok(Some(row.into())),
161                Err(err) => match err {
162                    sqlx::Error::RowNotFound => Ok(None),
163                    _ => Err(sqlx_error_to_query_err(err)),
164                },
165            }
166        })
167    }
168
169    /// Get the results of a query returning them as a Vec<[QueryResult]>
170    #[instrument(level = "trace")]
171    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
172        debug_print!("{}", stmt);
173
174        let query = sqlx_query(&stmt);
175        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
176        crate::metric::metric!(self.metric_callback, &stmt, {
177            match query.fetch_all(&mut *conn).await {
178                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
179                Err(err) => Err(sqlx_error_to_query_err(err)),
180            }
181        })
182    }
183
184    /// Stream the results of executing a SQL query
185    #[instrument(level = "trace")]
186    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
187        debug_print!("{}", stmt);
188
189        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
190        Ok(QueryStream::from((
191            conn,
192            stmt,
193            self.metric_callback.clone(),
194        )))
195    }
196
197    /// Bundle a set of SQL statements that execute together.
198    #[instrument(level = "trace")]
199    pub async fn begin(
200        &self,
201        isolation_level: Option<IsolationLevel>,
202        access_mode: Option<AccessMode>,
203    ) -> Result<DatabaseTransaction, DbErr> {
204        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
205        DatabaseTransaction::new_mysql(
206            conn,
207            self.metric_callback.clone(),
208            isolation_level,
209            access_mode,
210        )
211        .await
212    }
213
214    /// Create a MySQL transaction
215    #[instrument(level = "trace", skip(callback))]
216    pub async fn transaction<F, T, E>(
217        &self,
218        callback: F,
219        isolation_level: Option<IsolationLevel>,
220        access_mode: Option<AccessMode>,
221    ) -> Result<T, TransactionError<E>>
222    where
223        F: for<'b> FnOnce(
224                &'b DatabaseTransaction,
225            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
226            + Send,
227        T: Send,
228        E: std::fmt::Display + std::fmt::Debug + Send,
229    {
230        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
231        let transaction = DatabaseTransaction::new_mysql(
232            conn,
233            self.metric_callback.clone(),
234            isolation_level,
235            access_mode,
236        )
237        .await
238        .map_err(|e| TransactionError::Connection(e))?;
239        transaction.run(callback).await
240    }
241
242    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
243    where
244        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
245    {
246        self.metric_callback = Some(Arc::new(callback));
247    }
248
249    /// Checks if a connection to the database is still valid.
250    pub async fn ping(&self) -> Result<(), DbErr> {
251        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
252        match conn.ping().await {
253            Ok(_) => Ok(()),
254            Err(err) => Err(sqlx_error_to_conn_err(err)),
255        }
256    }
257
258    /// Explicitly close the MySQL connection.
259    /// See [`Self::close_by_ref`] for usage with references.
260    pub async fn close(self) -> Result<(), DbErr> {
261        self.close_by_ref().await
262    }
263
264    /// Explicitly close the MySQL connection
265    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
266        self.pool.close().await;
267        Ok(())
268    }
269}
270
271impl From<MySqlRow> for QueryResult {
272    fn from(row: MySqlRow) -> QueryResult {
273        QueryResult {
274            row: QueryResultRow::SqlxMySql(row),
275        }
276    }
277}
278
279impl From<MySqlQueryResult> for ExecResult {
280    fn from(result: MySqlQueryResult) -> ExecResult {
281        ExecResult {
282            result: ExecResultHolder::SqlxMySql(result),
283        }
284    }
285}
286
287pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
288    let values = stmt
289        .values
290        .as_ref()
291        .map_or(Values(Vec::new()), |values| values.clone());
292    sqlx::query_with(&stmt.sql, SqlxValues(values))
293}
294
295pub(crate) async fn set_transaction_config(
296    conn: &mut PoolConnection<MySql>,
297    isolation_level: Option<IsolationLevel>,
298    access_mode: Option<AccessMode>,
299) -> Result<(), DbErr> {
300    let mut settings = Vec::new();
301
302    if let Some(isolation_level) = isolation_level {
303        settings.push(format!("ISOLATION LEVEL {isolation_level}"));
304    }
305
306    if let Some(access_mode) = access_mode {
307        settings.push(access_mode.to_string());
308    }
309
310    if !settings.is_empty() {
311        let stmt = Statement {
312            sql: format!("SET TRANSACTION {}", settings.join(", ")),
313            values: None,
314            db_backend: DbBackend::MySql,
315        };
316        let query = sqlx_query(&stmt);
317        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
318    }
319    Ok(())
320}
321
322impl
323    From<(
324        PoolConnection<sqlx::MySql>,
325        Statement,
326        Option<crate::metric::Callback>,
327    )> for crate::QueryStream
328{
329    fn from(
330        (conn, stmt, metric_callback): (
331            PoolConnection<sqlx::MySql>,
332            Statement,
333            Option<crate::metric::Callback>,
334        ),
335    ) -> Self {
336        crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
337    }
338}
339
340impl crate::DatabaseTransaction {
341    pub(crate) async fn new_mysql(
342        inner: PoolConnection<sqlx::MySql>,
343        metric_callback: Option<crate::metric::Callback>,
344        isolation_level: Option<IsolationLevel>,
345        access_mode: Option<AccessMode>,
346    ) -> Result<crate::DatabaseTransaction, DbErr> {
347        Self::begin(
348            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
349            crate::DbBackend::MySql,
350            metric_callback,
351            isolation_level,
352            access_mode,
353        )
354        .await
355    }
356}
357
/// Convert a sqlx MySQL row into a [`crate::ProxyRow`].
///
/// Each column's name is paired with a [`sea_query::Value`]. The variant is
/// chosen from the column's MySQL type name as reported by sqlx's
/// `TypeInfo::name`. The decoders target `Option<_>` payloads, so SQL `NULL`
/// should come out as the variant's `None` payload.
///
/// Date/time columns need `with-chrono` or `with-time`; chrono wins when both
/// are enabled. `DECIMAL` needs `with-bigdecimal` or `with-rust_decimal`;
/// bigdecimal wins when both are enabled. `JSON` needs `with-json`.
///
/// # Panics
///
/// - If a column's type name has no mapping here. This includes the
///   feature-gated types above when their feature is disabled.
/// - If sqlx cannot decode a value into the Rust type selected for its column.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html
    // https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let idx = c.ordinal();
                let type_name = c.type_info().name();
                let value = match type_name {
                    "TINYINT(1)" | "BOOLEAN" => {
                        Value::Bool(row.try_get(idx).expect("Failed to get boolean"))
                    }
                    "TINYINT UNSIGNED" => Value::TinyUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned tiny integer"),
                    ),
                    "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned small integer"),
                    ),
                    "INT UNSIGNED" => Value::Unsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned integer"),
                    ),
                    "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned big integer"),
                    ),
                    "TINYINT" => {
                        Value::TinyInt(row.try_get(idx).expect("Failed to get tiny integer"))
                    }
                    "SMALLINT" => {
                        Value::SmallInt(row.try_get(idx).expect("Failed to get small integer"))
                    }
                    "INT" => Value::Int(row.try_get(idx).expect("Failed to get integer")),
                    "MEDIUMINT" | "BIGINT" => {
                        Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                    }
                    "FLOAT" => Value::Float(row.try_get(idx).expect("Failed to get float")),
                    "DOUBLE" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                    "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                    | "LONGBLOB" => Value::Bytes(
                        row.try_get::<Option<Vec<u8>>, _>(idx)
                            .expect("Failed to get bytes")
                            .map(Box::new),
                    ),

                    "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                        Value::String(
                            row.try_get::<Option<String>, _>(idx)
                                .expect("Failed to get string")
                                .map(Box::new),
                        )
                    }

                    #[cfg(feature = "with-chrono")]
                    "TIMESTAMP" => Value::ChronoDateTimeUtc(
                        row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIMESTAMP" => Value::TimeDateTime(
                        row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => Value::ChronoDate(
                        row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => Value::TimeDate(
                        row.try_get::<Option<time::Date>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "TIME" => Value::ChronoTime(
                        row.try_get::<Option<chrono::NaiveTime>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => Value::TimeTime(
                        row.try_get::<Option<time::Time>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTime(
                        row.try_get::<Option<chrono::NaiveDateTime>, _>(idx)
                            .expect("Failed to get datetime")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTime(
                        row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                            .expect("Failed to get datetime")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "YEAR" => Value::ChronoDate(
                        row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                            .expect("Failed to get year")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "YEAR" => Value::TimeDate(
                        row.try_get::<Option<time::Date>, _>(idx)
                            .expect("Failed to get year")
                            .map(Box::new),
                    ),

                    "ENUM" | "SET" | "GEOMETRY" => Value::String(
                        row.try_get::<Option<String>, _>(idx)
                            .expect("Failed to get serialized string")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-bigdecimal")]
                    "DECIMAL" => Value::BigDecimal(
                        row.try_get::<Option<bigdecimal::BigDecimal>, _>(idx)
                            .expect("Failed to get decimal")
                            .map(Box::new),
                    ),
                    #[cfg(all(
                        feature = "with-rust_decimal",
                        not(feature = "with-bigdecimal")
                    ))]
                    "DECIMAL" => Value::Decimal(
                        row.try_get::<Option<rust_decimal::Decimal>, _>(idx)
                            .expect("Failed to get decimal")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-json")]
                    "JSON" => Value::Json(
                        row.try_get::<Option<serde_json::Value>, _>(idx)
                            .expect("Failed to get json")
                            .map(Box::new),
                    ),

                    // Reachable at runtime: a type with no mapping above, or one
                    // whose arm is compiled out because its feature is disabled.
                    // Hence `panic!`, not `unreachable!`.
                    _ => panic!("Unknown column type: {type_name}"),
                };
                (c.name().to_string(), value)
            })
            .collect(),
    }
}