Skip to main content

sea_orm/driver/
sqlx_mysql.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, MySql, MySqlPool,
8    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9    pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18    executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Connector for MySQL databases, built on top of [sqlx::mysql].
///
/// A field-less marker type: its associated functions are the entry points
/// that produce a [DatabaseConnection].
#[derive(Debug)]
pub struct SqlxMySqlConnector;
26
27/// Defines a sqlx MySQL pool
28#[derive(Clone)]
29pub struct SqlxMySqlPoolConnection {
30    pub(crate) pool: MySqlPool,
31    metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxMySqlPoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<MySqlPool> for SqlxMySqlPoolConnection {
41    fn from(pool: MySqlPool) -> Self {
42        SqlxMySqlPoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<MySqlPool> for DatabaseConnection {
50    fn from(pool: MySqlPool) -> Self {
51        DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
52    }
53}
54
55impl SqlxMySqlConnector {
56    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the MySQL database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut sqlx_opts = options
65            .url
66            .parse::<MySqlConnectOptions>()
67            .map_err(sqlx_error_to_conn_err)?;
68        use sqlx::ConnectOptions;
69        if !options.sqlx_logging {
70            sqlx_opts = sqlx_opts.disable_statement_logging();
71        } else {
72            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
73            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
74                sqlx_opts = sqlx_opts.log_slow_statements(
75                    options.sqlx_slow_statements_logging_level,
76                    options.sqlx_slow_statements_logging_threshold,
77                );
78            }
79        }
80
81        if let Some(f) = &options.mysql_opts_fn {
82            sqlx_opts = f(sqlx_opts);
83        }
84        let after_connect = options.after_connect.clone();
85        let connect_lazy = options.connect_lazy;
86        let mysql_pool_opts_fn = options.mysql_pool_opts_fn.clone();
87        let mut pool_options = options.sqlx_pool_options();
88        if let Some(f) = &mysql_pool_opts_fn {
89            pool_options = f(pool_options);
90        }
91        let pool = if connect_lazy {
92            pool_options.connect_lazy_with(sqlx_opts)
93        } else {
94            pool_options
95                .connect_with(sqlx_opts)
96                .await
97                .map_err(sqlx_error_to_conn_err)?
98        };
99
100        let conn: DatabaseConnection =
101            DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
102                pool,
103                metric_callback: None,
104            })
105            .into();
106
107        if let Some(cb) = after_connect {
108            cb(conn.clone()).await?;
109        }
110
111        Ok(conn)
112    }
113}
114
115impl SqlxMySqlConnector {
116    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
117    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
118        DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
119            pool,
120            metric_callback: None,
121        })
122        .into()
123    }
124}
125
126impl SqlxMySqlPoolConnection {
127    /// Execute a [Statement] on a MySQL backend
128    #[instrument(level = "trace")]
129    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
130        debug_print!("{}", stmt);
131
132        let query = sqlx_query(&stmt);
133        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
134        crate::metric::metric!(self.metric_callback, &stmt, {
135            match query.execute(&mut *conn).await {
136                Ok(res) => Ok(res.into()),
137                Err(err) => Err(sqlx_error_to_exec_err(err)),
138            }
139        })
140    }
141
142    /// Execute an unprepared SQL statement on a MySQL backend
143    #[instrument(level = "trace")]
144    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
145        debug_print!("{}", sql);
146
147        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
148        match conn.execute(sql).await {
149            Ok(res) => Ok(res.into()),
150            Err(err) => Err(sqlx_error_to_exec_err(err)),
151        }
152    }
153
154    /// Get one result from a SQL query. Returns [Option::None] if no match was found
155    #[instrument(level = "trace")]
156    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
157        debug_print!("{}", stmt);
158
159        let query = sqlx_query(&stmt);
160        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
161        crate::metric::metric!(self.metric_callback, &stmt, {
162            match query.fetch_one(&mut *conn).await {
163                Ok(row) => Ok(Some(row.into())),
164                Err(err) => match err {
165                    sqlx::Error::RowNotFound => Ok(None),
166                    _ => Err(sqlx_error_to_query_err(err)),
167                },
168            }
169        })
170    }
171
172    /// Get the results of a query returning them as a Vec<[QueryResult]>
173    #[instrument(level = "trace")]
174    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
175        debug_print!("{}", stmt);
176
177        let query = sqlx_query(&stmt);
178        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
179        crate::metric::metric!(self.metric_callback, &stmt, {
180            match query.fetch_all(&mut *conn).await {
181                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
182                Err(err) => Err(sqlx_error_to_query_err(err)),
183            }
184        })
185    }
186
187    /// Stream the results of executing a SQL query
188    #[instrument(level = "trace")]
189    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
190        debug_print!("{}", stmt);
191
192        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
193        Ok(QueryStream::from((
194            conn,
195            stmt,
196            self.metric_callback.clone(),
197        )))
198    }
199
200    /// Bundle a set of SQL statements that execute together.
201    #[instrument(level = "trace")]
202    pub async fn begin(
203        &self,
204        isolation_level: Option<IsolationLevel>,
205        access_mode: Option<AccessMode>,
206    ) -> Result<DatabaseTransaction, DbErr> {
207        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
208        DatabaseTransaction::new_mysql(
209            conn,
210            self.metric_callback.clone(),
211            isolation_level,
212            access_mode,
213        )
214        .await
215    }
216
217    /// Create a MySQL transaction
218    #[instrument(level = "trace", skip(callback))]
219    pub async fn transaction<F, T, E>(
220        &self,
221        callback: F,
222        isolation_level: Option<IsolationLevel>,
223        access_mode: Option<AccessMode>,
224    ) -> Result<T, TransactionError<E>>
225    where
226        F: for<'b> FnOnce(
227                &'b DatabaseTransaction,
228            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
229            + Send,
230        T: Send,
231        E: std::fmt::Display + std::fmt::Debug + Send,
232    {
233        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
234        let transaction = DatabaseTransaction::new_mysql(
235            conn,
236            self.metric_callback.clone(),
237            isolation_level,
238            access_mode,
239        )
240        .await
241        .map_err(|e| TransactionError::Connection(e))?;
242        transaction.run(callback).await
243    }
244
245    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
246    where
247        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
248    {
249        self.metric_callback = Some(Arc::new(callback));
250    }
251
252    /// Checks if a connection to the database is still valid.
253    pub async fn ping(&self) -> Result<(), DbErr> {
254        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
255        match conn.ping().await {
256            Ok(_) => Ok(()),
257            Err(err) => Err(sqlx_error_to_conn_err(err)),
258        }
259    }
260
261    /// Explicitly close the MySQL connection.
262    /// See [`Self::close_by_ref`] for usage with references.
263    pub async fn close(self) -> Result<(), DbErr> {
264        self.close_by_ref().await
265    }
266
267    /// Explicitly close the MySQL connection
268    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
269        self.pool.close().await;
270        Ok(())
271    }
272}
273
274impl From<MySqlRow> for QueryResult {
275    fn from(row: MySqlRow) -> QueryResult {
276        QueryResult {
277            row: QueryResultRow::SqlxMySql(row),
278        }
279    }
280}
281
282impl From<MySqlQueryResult> for ExecResult {
283    fn from(result: MySqlQueryResult) -> ExecResult {
284        ExecResult {
285            result: ExecResultHolder::SqlxMySql(result),
286        }
287    }
288}
289
290pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
291    let values = stmt
292        .values
293        .as_ref()
294        .map_or(Values(Vec::new()), |values| values.clone());
295    sqlx::query_with(&stmt.sql, SqlxValues(values))
296}
297
298pub(crate) async fn set_transaction_config(
299    conn: &mut PoolConnection<MySql>,
300    isolation_level: Option<IsolationLevel>,
301    access_mode: Option<AccessMode>,
302) -> Result<(), DbErr> {
303    let mut settings = Vec::new();
304
305    if let Some(isolation_level) = isolation_level {
306        settings.push(format!("ISOLATION LEVEL {isolation_level}"));
307    }
308
309    if let Some(access_mode) = access_mode {
310        settings.push(access_mode.to_string());
311    }
312
313    if !settings.is_empty() {
314        let stmt = Statement {
315            sql: format!("SET TRANSACTION {}", settings.join(", ")),
316            values: None,
317            db_backend: DbBackend::MySql,
318        };
319        let query = sqlx_query(&stmt);
320        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
321    }
322    Ok(())
323}
324
325impl
326    From<(
327        PoolConnection<sqlx::MySql>,
328        Statement,
329        Option<crate::metric::Callback>,
330    )> for crate::QueryStream
331{
332    fn from(
333        (conn, stmt, metric_callback): (
334            PoolConnection<sqlx::MySql>,
335            Statement,
336            Option<crate::metric::Callback>,
337        ),
338    ) -> Self {
339        crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
340    }
341}
342
343impl crate::DatabaseTransaction {
344    pub(crate) async fn new_mysql(
345        inner: PoolConnection<sqlx::MySql>,
346        metric_callback: Option<crate::metric::Callback>,
347        isolation_level: Option<IsolationLevel>,
348        access_mode: Option<AccessMode>,
349    ) -> Result<crate::DatabaseTransaction, DbErr> {
350        Self::begin(
351            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
352            crate::DbBackend::MySql,
353            metric_callback,
354            isolation_level,
355            access_mode,
356            None,
357        )
358        .await
359    }
360}
361
/// Converts a sqlx MySQL row into a [`crate::ProxyRow`].
///
/// Each column is decoded according to the name of its sqlx type info and stored
/// under the column's name.
///
/// # Panics
///
/// Panics when a column fails to decode as the type chosen for it, or when the
/// column's type name matches none of the arms below.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html
    // https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    let values = row
        .columns()
        .iter()
        .map(|c| {
            let name = c.name().to_string();
            let idx = c.ordinal();
            let type_name = c.type_info().name();

            let value = match type_name {
                "TINYINT(1)" | "BOOLEAN" => {
                    Value::Bool(row.try_get(idx).expect("Failed to get boolean"))
                }
                "TINYINT UNSIGNED" => Value::TinyUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned tiny integer"),
                ),
                "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned small integer"),
                ),
                "INT UNSIGNED" => {
                    Value::Unsigned(row.try_get(idx).expect("Failed to get unsigned integer"))
                }
                "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned big integer"),
                ),
                "TINYINT" => Value::TinyInt(row.try_get(idx).expect("Failed to get tiny integer")),
                "SMALLINT" => {
                    Value::SmallInt(row.try_get(idx).expect("Failed to get small integer"))
                }
                "INT" => Value::Int(row.try_get(idx).expect("Failed to get integer")),
                "MEDIUMINT" | "BIGINT" => {
                    Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                }
                "FLOAT" => Value::Float(row.try_get(idx).expect("Failed to get float")),
                "DOUBLE" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                | "LONGBLOB" => Value::Bytes(
                    row.try_get::<Option<Vec<u8>>, _>(idx)
                        .expect("Failed to get bytes")
                        .map(Box::new),
                ),

                "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                    Value::String(
                        row.try_get::<Option<String>, _>(idx)
                            .expect("Failed to get string")
                            .map(Box::new),
                    )
                }

                // For the temporal types, chrono takes precedence when both the chrono
                // and time features are enabled.
                #[cfg(feature = "with-chrono")]
                "TIMESTAMP" => Value::ChronoDateTimeUtc(
                    row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(idx)
                        .expect("Failed to get timestamp")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIMESTAMP" => Value::TimeDateTime(
                    row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                        .expect("Failed to get timestamp")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATE" => Value::ChronoDate(
                    row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                        .expect("Failed to get date")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATE" => Value::TimeDate(
                    row.try_get::<Option<time::Date>, _>(idx)
                        .expect("Failed to get date")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "TIME" => Value::ChronoTime(
                    row.try_get::<Option<chrono::NaiveTime>, _>(idx)
                        .expect("Failed to get time")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIME" => Value::TimeTime(
                    row.try_get::<Option<time::Time>, _>(idx)
                        .expect("Failed to get time")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATETIME" => Value::ChronoDateTime(
                    row.try_get::<Option<chrono::NaiveDateTime>, _>(idx)
                        .expect("Failed to get datetime")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATETIME" => Value::TimeDateTime(
                    row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                        .expect("Failed to get datetime")
                        .map(Box::new),
                ),

                // NOTE(review): YEAR is decoded as a date type here — confirm that the
                // sqlx MySQL driver accepts this decoding for YEAR columns.
                #[cfg(feature = "with-chrono")]
                "YEAR" => Value::ChronoDate(
                    row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                        .expect("Failed to get year")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "YEAR" => Value::TimeDate(
                    row.try_get::<Option<time::Date>, _>(idx)
                        .expect("Failed to get year")
                        .map(Box::new),
                ),

                "ENUM" | "SET" | "GEOMETRY" => Value::String(
                    row.try_get::<Option<String>, _>(idx)
                        .expect("Failed to get serialized string")
                        .map(Box::new),
                ),

                // bigdecimal takes precedence over rust_decimal when both are enabled.
                #[cfg(feature = "with-bigdecimal")]
                "DECIMAL" => Value::BigDecimal(
                    row.try_get::<Option<bigdecimal::BigDecimal>, _>(idx)
                        .expect("Failed to get decimal")
                        .map(Box::new),
                ),
                #[cfg(all(
                    feature = "with-rust_decimal",
                    not(feature = "with-bigdecimal")
                ))]
                "DECIMAL" => Value::Decimal(
                    row.try_get::<Option<rust_decimal::Decimal>, _>(idx)
                        .expect("Failed to get decimal")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-json")]
                "JSON" => Value::Json(
                    row.try_get::<Option<serde_json::Value>, _>(idx)
                        .expect("Failed to get json")
                        .map(Box::new),
                ),

                other => unreachable!("Unknown column type: {}", other),
            };

            (name, value)
        })
        .collect();

    crate::ProxyRow { values }
}