Skip to main content

sea_orm/driver/
sqlx_mysql.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, MySql, MySqlPool,
8    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9    pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    DbBackend, IsolationLevel, Statement, TransactionError, debug_print, error::*, executor::*,
18};
19
20use super::sqlx_common::*;
21
22#[cfg(feature = "stream")]
23use crate::QueryStream;
24
/// Stateless marker type grouping the MySQL connection entry points built on [sqlx::mysql].
///
/// Its associated functions validate a connection URI and build a
/// `DatabaseConnection` backed by a sqlx MySQL pool.
#[derive(Debug)]
pub struct SqlxMySqlConnector;
28
29/// Defines a sqlx MySQL pool
30#[derive(Clone)]
31pub struct SqlxMySqlPoolConnection {
32    pub(crate) pool: MySqlPool,
33    metric_callback: Option<crate::metric::Callback>,
34    pub(crate) record_stmt_in_spans: bool,
35}
36
37impl std::fmt::Debug for SqlxMySqlPoolConnection {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
40    }
41}
42
43impl From<MySqlPool> for SqlxMySqlPoolConnection {
44    fn from(pool: MySqlPool) -> Self {
45        SqlxMySqlPoolConnection {
46            pool,
47            metric_callback: None,
48            record_stmt_in_spans: true,
49        }
50    }
51}
52
53impl From<MySqlPool> for DatabaseConnection {
54    fn from(pool: MySqlPool) -> Self {
55        DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
56    }
57}
58
59impl SqlxMySqlConnector {
60    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
61    pub fn accepts(string: &str) -> bool {
62        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
63    }
64
65    /// Add configuration options for the MySQL database
66    #[instrument(level = "trace")]
67    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
68        let record_stmt_in_spans = options.get_record_stmt_in_spans();
69        let mut sqlx_opts = options
70            .url
71            .parse::<MySqlConnectOptions>()
72            .map_err(sqlx_error_to_conn_err)?;
73        use sqlx::ConnectOptions;
74        if !options.sqlx_logging {
75            sqlx_opts = sqlx_opts.disable_statement_logging();
76        } else {
77            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
78            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
79                sqlx_opts = sqlx_opts.log_slow_statements(
80                    options.sqlx_slow_statements_logging_level,
81                    options.sqlx_slow_statements_logging_threshold,
82                );
83            }
84        }
85
86        if let Some(f) = &options.mysql_opts_fn {
87            sqlx_opts = f(sqlx_opts);
88        }
89        let after_connect = options.after_connect.clone();
90        let connect_lazy = options.connect_lazy;
91        let mysql_pool_opts_fn = options.mysql_pool_opts_fn.clone();
92        let mut pool_options = options.sqlx_pool_options();
93        if let Some(f) = &mysql_pool_opts_fn {
94            pool_options = f(pool_options);
95        }
96        let pool = if connect_lazy {
97            pool_options.connect_lazy_with(sqlx_opts)
98        } else {
99            pool_options
100                .connect_with(sqlx_opts)
101                .await
102                .map_err(sqlx_error_to_conn_err)?
103        };
104
105        let conn: DatabaseConnection =
106            DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
107                pool,
108                metric_callback: None,
109                record_stmt_in_spans,
110            })
111            .into();
112
113        if let Some(cb) = after_connect {
114            cb(conn.clone()).await?;
115        }
116
117        Ok(conn)
118    }
119}
120
121impl SqlxMySqlConnector {
122    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
123    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
124        DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
125            pool,
126            metric_callback: None,
127            record_stmt_in_spans: true,
128        })
129        .into()
130    }
131}
132
133impl SqlxMySqlPoolConnection {
134    /// Execute a [Statement] on a MySQL backend
135    #[instrument(level = "trace", skip(stmt))]
136    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
137        debug_print!("{}", stmt);
138
139        let query = sqlx_query(&stmt);
140        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
141        crate::metric::metric!(self.metric_callback, &stmt, {
142            match query.execute(&mut *conn).await {
143                Ok(res) => Ok(res.into()),
144                Err(err) => Err(sqlx_error_to_exec_err(err)),
145            }
146        })
147    }
148
149    /// Execute an unprepared SQL statement on a MySQL backend
150    #[instrument(level = "trace", skip(sql))]
151    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
152        debug_print!("{}", sql);
153
154        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
155        match conn.execute(sqlx::AssertSqlSafe(sql.to_owned())).await {
156            Ok(res) => Ok(res.into()),
157            Err(err) => Err(sqlx_error_to_exec_err(err)),
158        }
159    }
160
161    /// Get one result from a SQL query. Returns [Option::None] if no match was found
162    #[instrument(level = "trace", skip(stmt))]
163    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
164        debug_print!("{}", stmt);
165
166        let query = sqlx_query(&stmt);
167        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
168        crate::metric::metric!(self.metric_callback, &stmt, {
169            match query.fetch_one(&mut *conn).await {
170                Ok(row) => Ok(Some(row.into())),
171                Err(err) => match err {
172                    sqlx::Error::RowNotFound => Ok(None),
173                    _ => Err(sqlx_error_to_query_err(err)),
174                },
175            }
176        })
177    }
178
179    /// Get the results of a query returning them as a Vec<[QueryResult]>
180    #[instrument(level = "trace", skip(stmt))]
181    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
182        debug_print!("{}", stmt);
183
184        let query = sqlx_query(&stmt);
185        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
186        crate::metric::metric!(self.metric_callback, &stmt, {
187            match query.fetch_all(&mut *conn).await {
188                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
189                Err(err) => Err(sqlx_error_to_query_err(err)),
190            }
191        })
192    }
193
194    /// Stream the results of executing a SQL query
195    #[instrument(level = "trace", skip(stmt))]
196    #[cfg(feature = "stream")]
197    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
198        debug_print!("{}", stmt);
199
200        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
201        Ok(QueryStream::from((
202            conn,
203            stmt,
204            self.metric_callback.clone(),
205        )))
206    }
207
208    /// Bundle a set of SQL statements that execute together.
209    #[instrument(level = "trace")]
210    pub async fn begin(
211        &self,
212        isolation_level: Option<IsolationLevel>,
213        access_mode: Option<AccessMode>,
214    ) -> Result<DatabaseTransaction, DbErr> {
215        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
216        DatabaseTransaction::new_mysql(
217            conn,
218            self.metric_callback.clone(),
219            self.record_stmt_in_spans,
220            isolation_level,
221            access_mode,
222        )
223        .await
224    }
225
226    /// Create a MySQL transaction
227    #[instrument(level = "trace", skip(callback))]
228    pub async fn transaction<F, T, E>(
229        &self,
230        callback: F,
231        isolation_level: Option<IsolationLevel>,
232        access_mode: Option<AccessMode>,
233    ) -> Result<T, TransactionError<E>>
234    where
235        F: for<'b> FnOnce(
236                &'b DatabaseTransaction,
237            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
238            + Send,
239        T: Send,
240        E: std::fmt::Display + std::fmt::Debug + Send,
241    {
242        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
243        let transaction = DatabaseTransaction::new_mysql(
244            conn,
245            self.metric_callback.clone(),
246            self.record_stmt_in_spans,
247            isolation_level,
248            access_mode,
249        )
250        .await
251        .map_err(|e| TransactionError::Connection(e))?;
252        transaction.run(callback).await
253    }
254
255    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
256    where
257        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
258    {
259        self.metric_callback = Some(Arc::new(callback));
260    }
261
262    /// Checks if a connection to the database is still valid.
263    pub async fn ping(&self) -> Result<(), DbErr> {
264        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
265        match conn.ping().await {
266            Ok(_) => Ok(()),
267            Err(err) => Err(sqlx_error_to_conn_err(err)),
268        }
269    }
270
271    /// Explicitly close the MySQL connection.
272    /// See [`Self::close_by_ref`] for usage with references.
273    pub async fn close(self) -> Result<(), DbErr> {
274        self.close_by_ref().await
275    }
276
277    /// Explicitly close the MySQL connection
278    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
279        self.pool.close().await;
280        Ok(())
281    }
282}
283
284impl From<MySqlRow> for QueryResult {
285    fn from(row: MySqlRow) -> QueryResult {
286        QueryResult {
287            row: QueryResultRow::SqlxMySql(row),
288        }
289    }
290}
291
292impl From<MySqlQueryResult> for ExecResult {
293    fn from(result: MySqlQueryResult) -> ExecResult {
294        ExecResult {
295            result: ExecResultHolder::SqlxMySql(result),
296        }
297    }
298}
299
300pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
301    let values = stmt
302        .values
303        .as_ref()
304        .map_or(Values(Vec::new()), |values| values.clone());
305    sqlx::query_with(sqlx::AssertSqlSafe(stmt.sql.as_str()), SqlxValues(values))
306}
307
308pub(crate) async fn set_transaction_config(
309    conn: &mut PoolConnection<MySql>,
310    isolation_level: Option<IsolationLevel>,
311    access_mode: Option<AccessMode>,
312) -> Result<(), DbErr> {
313    let mut settings = Vec::new();
314
315    if let Some(isolation_level) = isolation_level {
316        settings.push(format!("ISOLATION LEVEL {isolation_level}"));
317    }
318
319    if let Some(access_mode) = access_mode {
320        settings.push(access_mode.to_string());
321    }
322
323    if !settings.is_empty() {
324        let stmt = Statement {
325            sql: format!("SET TRANSACTION {}", settings.join(", ")),
326            values: None,
327            db_backend: DbBackend::MySql,
328        };
329        let query = sqlx_query(&stmt);
330        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
331    }
332    Ok(())
333}
334
#[cfg(feature = "stream")]
impl
    From<(
        PoolConnection<sqlx::MySql>,
        Statement,
        Option<crate::metric::Callback>,
    )> for crate::QueryStream
{
    /// Build a query stream from a pooled MySQL connection, the statement to
    /// run on it, and an optional metric callback.
    fn from(
        parts: (
            PoolConnection<sqlx::MySql>,
            Statement,
            Option<crate::metric::Callback>,
        ),
    ) -> Self {
        let (conn, stmt, metric_callback) = parts;
        let inner = crate::InnerConnection::MySql(conn);
        Self::build(stmt, inner, metric_callback)
    }
}
353
354impl crate::DatabaseTransaction {
355    pub(crate) async fn new_mysql(
356        inner: PoolConnection<sqlx::MySql>,
357        metric_callback: Option<crate::metric::Callback>,
358        record_stmt_in_spans: bool,
359        isolation_level: Option<IsolationLevel>,
360        access_mode: Option<AccessMode>,
361    ) -> Result<crate::DatabaseTransaction, DbErr> {
362        Self::begin(
363            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
364            crate::DbBackend::MySql,
365            metric_callback,
366            record_stmt_in_spans,
367            isolation_level,
368            access_mode,
369            None,
370        )
371        .await
372    }
373}
374
/// Convert a sqlx MySQL row into a `ProxyRow`, mapping every column to a
/// `(column name, sea_query::Value)` pair chosen by the column's type name.
///
/// Temporal, decimal and JSON columns are only handled when the matching
/// `with-*` feature is enabled; when both the chrono and time features are on,
/// the chrono variants are used.
///
/// # Panics
/// Panics when a column value cannot be decoded as the Rust type selected for
/// its type name, or when the column's type name is not handled below.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html
    // https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    let values = row
        .columns()
        .iter()
        .map(|column| {
            let ordinal = column.ordinal();
            let name = column.name().to_string();
            let value = match column.type_info().name() {
                "TINYINT(1)" | "BOOLEAN" => {
                    Value::Bool(row.try_get(ordinal).expect("Failed to get boolean"))
                }
                "TINYINT UNSIGNED" => Value::TinyUnsigned(
                    row.try_get(ordinal)
                        .expect("Failed to get unsigned tiny integer"),
                ),
                "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                    row.try_get(ordinal)
                        .expect("Failed to get unsigned small integer"),
                ),
                "INT UNSIGNED" => Value::Unsigned(
                    row.try_get(ordinal)
                        .expect("Failed to get unsigned integer"),
                ),
                "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                    row.try_get(ordinal)
                        .expect("Failed to get unsigned big integer"),
                ),
                "TINYINT" => {
                    Value::TinyInt(row.try_get(ordinal).expect("Failed to get tiny integer"))
                }
                "SMALLINT" => {
                    Value::SmallInt(row.try_get(ordinal).expect("Failed to get small integer"))
                }
                "INT" => Value::Int(row.try_get(ordinal).expect("Failed to get integer")),
                "MEDIUMINT" | "BIGINT" => {
                    Value::BigInt(row.try_get(ordinal).expect("Failed to get big integer"))
                }
                "FLOAT" => Value::Float(row.try_get(ordinal).expect("Failed to get float")),
                "DOUBLE" => Value::Double(row.try_get(ordinal).expect("Failed to get double")),

                // Binary-like columns are surfaced as raw bytes.
                "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                | "LONGBLOB" => Value::Bytes(
                    row.try_get::<Option<Vec<u8>>, _>(ordinal)
                        .expect("Failed to get bytes")
                        .map(Box::new),
                ),

                "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                    Value::String(
                        row.try_get::<Option<String>, _>(ordinal)
                            .expect("Failed to get string")
                            .map(Box::new),
                    )
                }

                #[cfg(feature = "with-chrono")]
                "TIMESTAMP" => Value::ChronoDateTimeUtc(
                    row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(ordinal)
                        .expect("Failed to get timestamp")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIMESTAMP" => Value::TimeDateTime(
                    row.try_get::<Option<time::PrimitiveDateTime>, _>(ordinal)
                        .expect("Failed to get timestamp")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATE" => Value::ChronoDate(
                    row.try_get::<Option<chrono::NaiveDate>, _>(ordinal)
                        .expect("Failed to get date")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATE" => Value::TimeDate(
                    row.try_get::<Option<time::Date>, _>(ordinal)
                        .expect("Failed to get date")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "TIME" => Value::ChronoTime(
                    row.try_get::<Option<chrono::NaiveTime>, _>(ordinal)
                        .expect("Failed to get time")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIME" => Value::TimeTime(
                    row.try_get::<Option<time::Time>, _>(ordinal)
                        .expect("Failed to get time")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATETIME" => Value::ChronoDateTime(
                    row.try_get::<Option<chrono::NaiveDateTime>, _>(ordinal)
                        .expect("Failed to get datetime")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATETIME" => Value::TimeDateTime(
                    row.try_get::<Option<time::PrimitiveDateTime>, _>(ordinal)
                        .expect("Failed to get datetime")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "YEAR" => Value::ChronoDate(
                    row.try_get::<Option<chrono::NaiveDate>, _>(ordinal)
                        .expect("Failed to get year")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "YEAR" => Value::TimeDate(
                    row.try_get::<Option<time::Date>, _>(ordinal)
                        .expect("Failed to get year")
                        .map(Box::new),
                ),

                // These types are read back in their string form.
                "ENUM" | "SET" | "GEOMETRY" => Value::String(
                    row.try_get::<Option<String>, _>(ordinal)
                        .expect("Failed to get serialized string")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-bigdecimal")]
                "DECIMAL" => Value::BigDecimal(
                    row.try_get::<Option<bigdecimal::BigDecimal>, _>(ordinal)
                        .expect("Failed to get decimal")
                        .map(Box::new),
                ),
                #[cfg(all(
                    feature = "with-rust_decimal",
                    not(feature = "with-bigdecimal")
                ))]
                "DECIMAL" => Value::Decimal(
                    row.try_get::<Option<rust_decimal::Decimal>, _>(ordinal)
                        .expect("Failed to get decimal")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-json")]
                "JSON" => Value::Json(
                    row.try_get::<Option<serde_json::Value>, _>(ordinal)
                        .expect("Failed to get json")
                        .map(Box::new),
                ),

                unknown => unreachable!("Unknown column type: {}", unknown),
            };
            (name, value)
        })
        .collect();

    crate::ProxyRow { values }
}