Skip to main content

sea_orm/driver/sqlx_mysql.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, MySql, MySqlPool,
8    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9    pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18    executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Connector for MySQL databases, built on [sqlx::mysql].
///
/// This is a stateless unit type; it only groups the associated functions used to
/// open or wrap a MySQL [DatabaseConnection].
#[derive(Debug)]
pub struct SqlxMySqlConnector;
26
27/// Defines a sqlx MySQL pool
28#[derive(Clone)]
29pub struct SqlxMySqlPoolConnection {
30    pub(crate) pool: MySqlPool,
31    metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxMySqlPoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<MySqlPool> for SqlxMySqlPoolConnection {
41    fn from(pool: MySqlPool) -> Self {
42        SqlxMySqlPoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<MySqlPool> for DatabaseConnection {
50    fn from(pool: MySqlPool) -> Self {
51        DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
52    }
53}
54
55impl SqlxMySqlConnector {
56    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the MySQL database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut sqlx_opts = options
65            .url
66            .parse::<MySqlConnectOptions>()
67            .map_err(sqlx_error_to_conn_err)?;
68        use sqlx::ConnectOptions;
69        if !options.sqlx_logging {
70            sqlx_opts = sqlx_opts.disable_statement_logging();
71        } else {
72            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
73            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
74                sqlx_opts = sqlx_opts.log_slow_statements(
75                    options.sqlx_slow_statements_logging_level,
76                    options.sqlx_slow_statements_logging_threshold,
77                );
78            }
79        }
80
81        if let Some(f) = &options.mysql_opts_fn {
82            sqlx_opts = f(sqlx_opts);
83        }
84
85        let after_connect = options.after_connect.clone();
86
87        let pool = if options.connect_lazy {
88            options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
89        } else {
90            options
91                .sqlx_pool_options()
92                .connect_with(sqlx_opts)
93                .await
94                .map_err(sqlx_error_to_conn_err)?
95        };
96
97        let conn: DatabaseConnection =
98            DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
99                pool,
100                metric_callback: None,
101            })
102            .into();
103
104        if let Some(cb) = after_connect {
105            cb(conn.clone()).await?;
106        }
107
108        Ok(conn)
109    }
110}
111
112impl SqlxMySqlConnector {
113    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
114    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
115        DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
116            pool,
117            metric_callback: None,
118        })
119        .into()
120    }
121}
122
123impl SqlxMySqlPoolConnection {
124    /// Execute a [Statement] on a MySQL backend
125    #[instrument(level = "trace")]
126    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
127        debug_print!("{}", stmt);
128
129        let query = sqlx_query(&stmt);
130        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
131        crate::metric::metric!(self.metric_callback, &stmt, {
132            match query.execute(&mut *conn).await {
133                Ok(res) => Ok(res.into()),
134                Err(err) => Err(sqlx_error_to_exec_err(err)),
135            }
136        })
137    }
138
139    /// Execute an unprepared SQL statement on a MySQL backend
140    #[instrument(level = "trace")]
141    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
142        debug_print!("{}", sql);
143
144        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
145        match conn.execute(sql).await {
146            Ok(res) => Ok(res.into()),
147            Err(err) => Err(sqlx_error_to_exec_err(err)),
148        }
149    }
150
151    /// Get one result from a SQL query. Returns [Option::None] if no match was found
152    #[instrument(level = "trace")]
153    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
154        debug_print!("{}", stmt);
155
156        let query = sqlx_query(&stmt);
157        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
158        crate::metric::metric!(self.metric_callback, &stmt, {
159            match query.fetch_one(&mut *conn).await {
160                Ok(row) => Ok(Some(row.into())),
161                Err(err) => match err {
162                    sqlx::Error::RowNotFound => Ok(None),
163                    _ => Err(sqlx_error_to_query_err(err)),
164                },
165            }
166        })
167    }
168
169    /// Get the results of a query returning them as a Vec<[QueryResult]>
170    #[instrument(level = "trace")]
171    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
172        debug_print!("{}", stmt);
173
174        let query = sqlx_query(&stmt);
175        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
176        crate::metric::metric!(self.metric_callback, &stmt, {
177            match query.fetch_all(&mut *conn).await {
178                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
179                Err(err) => Err(sqlx_error_to_query_err(err)),
180            }
181        })
182    }
183
184    /// Stream the results of executing a SQL query
185    #[instrument(level = "trace")]
186    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
187        debug_print!("{}", stmt);
188
189        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
190        Ok(QueryStream::from((
191            conn,
192            stmt,
193            self.metric_callback.clone(),
194        )))
195    }
196
197    /// Bundle a set of SQL statements that execute together.
198    #[instrument(level = "trace")]
199    pub async fn begin(
200        &self,
201        isolation_level: Option<IsolationLevel>,
202        access_mode: Option<AccessMode>,
203    ) -> Result<DatabaseTransaction, DbErr> {
204        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
205        DatabaseTransaction::new_mysql(
206            conn,
207            self.metric_callback.clone(),
208            isolation_level,
209            access_mode,
210        )
211        .await
212    }
213
214    /// Create a MySQL transaction
215    #[instrument(level = "trace", skip(callback))]
216    pub async fn transaction<F, T, E>(
217        &self,
218        callback: F,
219        isolation_level: Option<IsolationLevel>,
220        access_mode: Option<AccessMode>,
221    ) -> Result<T, TransactionError<E>>
222    where
223        F: for<'b> FnOnce(
224                &'b DatabaseTransaction,
225            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
226            + Send,
227        T: Send,
228        E: std::fmt::Display + std::fmt::Debug + Send,
229    {
230        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
231        let transaction = DatabaseTransaction::new_mysql(
232            conn,
233            self.metric_callback.clone(),
234            isolation_level,
235            access_mode,
236        )
237        .await
238        .map_err(|e| TransactionError::Connection(e))?;
239        transaction.run(callback).await
240    }
241
242    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
243    where
244        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
245    {
246        self.metric_callback = Some(Arc::new(callback));
247    }
248
249    /// Checks if a connection to the database is still valid.
250    pub async fn ping(&self) -> Result<(), DbErr> {
251        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
252        match conn.ping().await {
253            Ok(_) => Ok(()),
254            Err(err) => Err(sqlx_error_to_conn_err(err)),
255        }
256    }
257
258    /// Explicitly close the MySQL connection.
259    /// See [`Self::close_by_ref`] for usage with references.
260    pub async fn close(self) -> Result<(), DbErr> {
261        self.close_by_ref().await
262    }
263
264    /// Explicitly close the MySQL connection
265    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
266        self.pool.close().await;
267        Ok(())
268    }
269}
270
271impl From<MySqlRow> for QueryResult {
272    fn from(row: MySqlRow) -> QueryResult {
273        QueryResult {
274            row: QueryResultRow::SqlxMySql(row),
275        }
276    }
277}
278
279impl From<MySqlQueryResult> for ExecResult {
280    fn from(result: MySqlQueryResult) -> ExecResult {
281        ExecResult {
282            result: ExecResultHolder::SqlxMySql(result),
283        }
284    }
285}
286
287pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
288    let values = stmt
289        .values
290        .as_ref()
291        .map_or(Values(Vec::new()), |values| values.clone());
292    sqlx::query_with(&stmt.sql, SqlxValues(values))
293}
294
295pub(crate) async fn set_transaction_config(
296    conn: &mut PoolConnection<MySql>,
297    isolation_level: Option<IsolationLevel>,
298    access_mode: Option<AccessMode>,
299) -> Result<(), DbErr> {
300    let mut settings = Vec::new();
301
302    if let Some(isolation_level) = isolation_level {
303        settings.push(format!("ISOLATION LEVEL {isolation_level}"));
304    }
305
306    if let Some(access_mode) = access_mode {
307        settings.push(access_mode.to_string());
308    }
309
310    if !settings.is_empty() {
311        let stmt = Statement {
312            sql: format!("SET TRANSACTION {}", settings.join(", ")),
313            values: None,
314            db_backend: DbBackend::MySql,
315        };
316        let query = sqlx_query(&stmt);
317        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
318    }
319    Ok(())
320}
321
322impl
323    From<(
324        PoolConnection<sqlx::MySql>,
325        Statement,
326        Option<crate::metric::Callback>,
327    )> for crate::QueryStream
328{
329    fn from(
330        (conn, stmt, metric_callback): (
331            PoolConnection<sqlx::MySql>,
332            Statement,
333            Option<crate::metric::Callback>,
334        ),
335    ) -> Self {
336        crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
337    }
338}
339
340impl crate::DatabaseTransaction {
341    pub(crate) async fn new_mysql(
342        inner: PoolConnection<sqlx::MySql>,
343        metric_callback: Option<crate::metric::Callback>,
344        isolation_level: Option<IsolationLevel>,
345        access_mode: Option<AccessMode>,
346    ) -> Result<crate::DatabaseTransaction, DbErr> {
347        Self::begin(
348            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
349            crate::DbBackend::MySql,
350            metric_callback,
351            isolation_level,
352            access_mode,
353            None,
354        )
355        .await
356    }
357}
358
/// Convert a sqlx MySQL row into a [`crate::ProxyRow`] of `(column name, Value)` pairs.
///
/// The [`sea_query::Value`] variant for each column is chosen from its MySQL type name.
///
/// Type-name and decoding references:
/// - <https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html>
/// - <https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html>
///
/// # Panics
///
/// - If a column value cannot be decoded into the Rust type selected for its MySQL type.
/// - If the column's type name has no mapping below. This includes:
///   - `TIMESTAMP`/`DATE`/`TIME`/`DATETIME`/`YEAR` unless `with-chrono` or `with-time`
///     is enabled;
///   - `DECIMAL` unless `with-bigdecimal` or `with-rust_decimal` is enabled;
///   - `JSON` unless `with-json` is enabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    let values = row
        .columns()
        .iter()
        .map(|c| {
            let idx = c.ordinal();
            let type_name = c.type_info().name();
            let value = match type_name {
                "TINYINT(1)" | "BOOLEAN" => {
                    Value::Bool(row.try_get(idx).expect("Failed to get boolean"))
                }
                "TINYINT UNSIGNED" => Value::TinyUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned tiny integer"),
                ),
                "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned small integer"),
                ),
                "INT UNSIGNED" => {
                    Value::Unsigned(row.try_get(idx).expect("Failed to get unsigned integer"))
                }
                "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned big integer"),
                ),
                "TINYINT" => Value::TinyInt(row.try_get(idx).expect("Failed to get tiny integer")),
                "SMALLINT" => {
                    Value::SmallInt(row.try_get(idx).expect("Failed to get small integer"))
                }
                "INT" => Value::Int(row.try_get(idx).expect("Failed to get integer")),
                "MEDIUMINT" | "BIGINT" => {
                    Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                }
                "FLOAT" => Value::Float(row.try_get(idx).expect("Failed to get float")),
                "DOUBLE" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                | "LONGBLOB" => Value::Bytes(
                    row.try_get::<Option<Vec<u8>>, _>(idx)
                        .expect("Failed to get bytes")
                        .map(Box::new),
                ),

                "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                    Value::String(
                        row.try_get::<Option<String>, _>(idx)
                            .expect("Failed to get string")
                            .map(Box::new),
                    )
                }

                #[cfg(feature = "with-chrono")]
                "TIMESTAMP" => Value::ChronoDateTimeUtc(
                    row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(idx)
                        .expect("Failed to get timestamp")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIMESTAMP" => Value::TimeDateTime(
                    row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                        .expect("Failed to get timestamp")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATE" => Value::ChronoDate(
                    row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                        .expect("Failed to get date")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATE" => Value::TimeDate(
                    row.try_get::<Option<time::Date>, _>(idx)
                        .expect("Failed to get date")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "TIME" => Value::ChronoTime(
                    row.try_get::<Option<chrono::NaiveTime>, _>(idx)
                        .expect("Failed to get time")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIME" => Value::TimeTime(
                    row.try_get::<Option<time::Time>, _>(idx)
                        .expect("Failed to get time")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATETIME" => Value::ChronoDateTime(
                    row.try_get::<Option<chrono::NaiveDateTime>, _>(idx)
                        .expect("Failed to get datetime")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATETIME" => Value::TimeDateTime(
                    row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                        .expect("Failed to get datetime")
                        .map(Box::new),
                ),

                // NOTE(review): YEAR is decoded as a date type here; confirm sqlx-mysql
                // actually supports decoding YEAR into these types.
                #[cfg(feature = "with-chrono")]
                "YEAR" => Value::ChronoDate(
                    row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                        .expect("Failed to get year")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "YEAR" => Value::TimeDate(
                    row.try_get::<Option<time::Date>, _>(idx)
                        .expect("Failed to get year")
                        .map(Box::new),
                ),

                "ENUM" | "SET" | "GEOMETRY" => Value::String(
                    row.try_get::<Option<String>, _>(idx)
                        .expect("Failed to get serialized string")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-bigdecimal")]
                "DECIMAL" => Value::BigDecimal(
                    row.try_get::<Option<bigdecimal::BigDecimal>, _>(idx)
                        .expect("Failed to get decimal")
                        .map(Box::new),
                ),
                #[cfg(all(
                    feature = "with-rust_decimal",
                    not(feature = "with-bigdecimal")
                ))]
                "DECIMAL" => Value::Decimal(
                    row.try_get::<Option<rust_decimal::Decimal>, _>(idx)
                        .expect("Failed to get decimal")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-json")]
                "JSON" => Value::Json(
                    row.try_get::<Option<serde_json::Value>, _>(idx)
                        .expect("Failed to get json")
                        .map(Box::new),
                ),

                // The column type comes from the query result, not from an internal
                // invariant. An unmapped type is a real runtime condition (including
                // feature-gated types whose feature is off), so report it with
                // `panic!` rather than `unreachable!`.
                _ => panic!("Unknown column type: {type_name}"),
            };
            (c.name().to_string(), value)
        })
        .collect();

    crate::ProxyRow { values }
}