sea_orm/driver/sqlx_mysql.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
8    pool::PoolConnection,
9    Connection, Executor, MySql, MySqlPool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    debug_print, error::*, executor::*, AccessMode, ConnectOptions, DatabaseConnection,
17    DatabaseTransaction, DbBackend, IsolationLevel, QueryStream, Statement, TransactionError,
18};
19
20use super::sqlx_common::*;
21
/// Defines the [sqlx::mysql] connector.
///
/// A stateless unit type: its associated functions check and parse MySQL
/// connection URLs and build [DatabaseConnection] values backed by a [MySqlPool].
#[derive(Debug)]
pub struct SqlxMySqlConnector;
25
26/// Defines a sqlx MySQL pool
27#[derive(Clone)]
28pub struct SqlxMySqlPoolConnection {
29    pub(crate) pool: MySqlPool,
30    metric_callback: Option<crate::metric::Callback>,
31}
32
33impl std::fmt::Debug for SqlxMySqlPoolConnection {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
36    }
37}
38
39impl From<MySqlPool> for SqlxMySqlPoolConnection {
40    fn from(pool: MySqlPool) -> Self {
41        SqlxMySqlPoolConnection {
42            pool,
43            metric_callback: None,
44        }
45    }
46}
47
48impl From<MySqlPool> for DatabaseConnection {
49    fn from(pool: MySqlPool) -> Self {
50        DatabaseConnection::SqlxMySqlPoolConnection(pool.into())
51    }
52}
53
54impl SqlxMySqlConnector {
55    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
56    pub fn accepts(string: &str) -> bool {
57        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
58    }
59
60    /// Add configuration options for the MySQL database
61    #[instrument(level = "trace")]
62    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
63        let mut sqlx_opts = options
64            .url
65            .parse::<MySqlConnectOptions>()
66            .map_err(sqlx_error_to_conn_err)?;
67        use sqlx::ConnectOptions;
68        if !options.sqlx_logging {
69            sqlx_opts = sqlx_opts.disable_statement_logging();
70        } else {
71            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
72            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
73                sqlx_opts = sqlx_opts.log_slow_statements(
74                    options.sqlx_slow_statements_logging_level,
75                    options.sqlx_slow_statements_logging_threshold,
76                );
77            }
78        }
79        if let Some(f) = &options.mysql_opts_fn {
80            sqlx_opts = f(sqlx_opts);
81        }
82        let pool = if options.connect_lazy {
83            options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
84        } else {
85            options
86                .sqlx_pool_options()
87                .connect_with(sqlx_opts)
88                .await
89                .map_err(sqlx_error_to_conn_err)?
90        };
91        Ok(DatabaseConnection::SqlxMySqlPoolConnection(
92            SqlxMySqlPoolConnection {
93                pool,
94                metric_callback: None,
95            },
96        ))
97    }
98}
99
100impl SqlxMySqlConnector {
101    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
102    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
103        DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
104            pool,
105            metric_callback: None,
106        })
107    }
108}
109
110impl SqlxMySqlPoolConnection {
111    /// Execute a [Statement] on a MySQL backend
112    #[instrument(level = "trace")]
113    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
114        debug_print!("{}", stmt);
115
116        let query = sqlx_query(&stmt);
117        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
118        crate::metric::metric!(self.metric_callback, &stmt, {
119            match query.execute(&mut *conn).await {
120                Ok(res) => Ok(res.into()),
121                Err(err) => Err(sqlx_error_to_exec_err(err)),
122            }
123        })
124    }
125
126    /// Execute an unprepared SQL statement on a MySQL backend
127    #[instrument(level = "trace")]
128    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
129        debug_print!("{}", sql);
130
131        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
132        match conn.execute(sql).await {
133            Ok(res) => Ok(res.into()),
134            Err(err) => Err(sqlx_error_to_exec_err(err)),
135        }
136    }
137
138    /// Get one result from a SQL query. Returns [Option::None] if no match was found
139    #[instrument(level = "trace")]
140    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
141        debug_print!("{}", stmt);
142
143        let query = sqlx_query(&stmt);
144        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
145        crate::metric::metric!(self.metric_callback, &stmt, {
146            match query.fetch_one(&mut *conn).await {
147                Ok(row) => Ok(Some(row.into())),
148                Err(err) => match err {
149                    sqlx::Error::RowNotFound => Ok(None),
150                    _ => Err(sqlx_error_to_query_err(err)),
151                },
152            }
153        })
154    }
155
156    /// Get the results of a query returning them as a Vec<[QueryResult]>
157    #[instrument(level = "trace")]
158    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
159        debug_print!("{}", stmt);
160
161        let query = sqlx_query(&stmt);
162        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
163        crate::metric::metric!(self.metric_callback, &stmt, {
164            match query.fetch_all(&mut *conn).await {
165                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
166                Err(err) => Err(sqlx_error_to_query_err(err)),
167            }
168        })
169    }
170
171    /// Stream the results of executing a SQL query
172    #[instrument(level = "trace")]
173    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
174        debug_print!("{}", stmt);
175
176        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
177        Ok(QueryStream::from((
178            conn,
179            stmt,
180            self.metric_callback.clone(),
181        )))
182    }
183
184    /// Bundle a set of SQL statements that execute together.
185    #[instrument(level = "trace")]
186    pub async fn begin(
187        &self,
188        isolation_level: Option<IsolationLevel>,
189        access_mode: Option<AccessMode>,
190    ) -> Result<DatabaseTransaction, DbErr> {
191        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
192        DatabaseTransaction::new_mysql(
193            conn,
194            self.metric_callback.clone(),
195            isolation_level,
196            access_mode,
197        )
198        .await
199    }
200
201    /// Create a MySQL transaction
202    #[instrument(level = "trace", skip(callback))]
203    pub async fn transaction<F, T, E>(
204        &self,
205        callback: F,
206        isolation_level: Option<IsolationLevel>,
207        access_mode: Option<AccessMode>,
208    ) -> Result<T, TransactionError<E>>
209    where
210        F: for<'b> FnOnce(
211                &'b DatabaseTransaction,
212            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
213            + Send,
214        T: Send,
215        E: std::fmt::Display + std::fmt::Debug + Send,
216    {
217        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
218        let transaction = DatabaseTransaction::new_mysql(
219            conn,
220            self.metric_callback.clone(),
221            isolation_level,
222            access_mode,
223        )
224        .await
225        .map_err(|e| TransactionError::Connection(e))?;
226        transaction.run(callback).await
227    }
228
229    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
230    where
231        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
232    {
233        self.metric_callback = Some(Arc::new(callback));
234    }
235
236    /// Checks if a connection to the database is still valid.
237    pub async fn ping(&self) -> Result<(), DbErr> {
238        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
239        match conn.ping().await {
240            Ok(_) => Ok(()),
241            Err(err) => Err(sqlx_error_to_conn_err(err)),
242        }
243    }
244
245    /// Explicitly close the MySQL connection.
246    /// See [`Self::close_by_ref`] for usage with references.
247    pub async fn close(self) -> Result<(), DbErr> {
248        self.close_by_ref().await
249    }
250
251    /// Explicitly close the MySQL connection
252    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
253        self.pool.close().await;
254        Ok(())
255    }
256}
257
258impl From<MySqlRow> for QueryResult {
259    fn from(row: MySqlRow) -> QueryResult {
260        QueryResult {
261            row: QueryResultRow::SqlxMySql(row),
262        }
263    }
264}
265
266impl From<MySqlQueryResult> for ExecResult {
267    fn from(result: MySqlQueryResult) -> ExecResult {
268        ExecResult {
269            result: ExecResultHolder::SqlxMySql(result),
270        }
271    }
272}
273
274pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
275    let values = stmt
276        .values
277        .as_ref()
278        .map_or(Values(Vec::new()), |values| values.clone());
279    sqlx::query_with(&stmt.sql, SqlxValues(values))
280}
281
282pub(crate) async fn set_transaction_config(
283    conn: &mut PoolConnection<MySql>,
284    isolation_level: Option<IsolationLevel>,
285    access_mode: Option<AccessMode>,
286) -> Result<(), DbErr> {
287    let mut settings = Vec::new();
288
289    if let Some(isolation_level) = isolation_level {
290        settings.push(format!("ISOLATION LEVEL {isolation_level}"));
291    }
292
293    if let Some(access_mode) = access_mode {
294        settings.push(access_mode.to_string());
295    }
296
297    if !settings.is_empty() {
298        let stmt = Statement {
299            sql: format!("SET TRANSACTION {}", settings.join(", ")),
300            values: None,
301            db_backend: DbBackend::MySql,
302        };
303        let query = sqlx_query(&stmt);
304        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
305    }
306    Ok(())
307}
308
309impl
310    From<(
311        PoolConnection<sqlx::MySql>,
312        Statement,
313        Option<crate::metric::Callback>,
314    )> for crate::QueryStream
315{
316    fn from(
317        (conn, stmt, metric_callback): (
318            PoolConnection<sqlx::MySql>,
319            Statement,
320            Option<crate::metric::Callback>,
321        ),
322    ) -> Self {
323        crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
324    }
325}
326
327impl crate::DatabaseTransaction {
328    pub(crate) async fn new_mysql(
329        inner: PoolConnection<sqlx::MySql>,
330        metric_callback: Option<crate::metric::Callback>,
331        isolation_level: Option<IsolationLevel>,
332        access_mode: Option<AccessMode>,
333    ) -> Result<crate::DatabaseTransaction, DbErr> {
334        Self::begin(
335            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
336            crate::DbBackend::MySql,
337            metric_callback,
338            isolation_level,
339            access_mode,
340        )
341        .await
342    }
343}
344
#[cfg(feature = "proxy")]
/// Converts every column of a sqlx MySQL row into a `(column name, Value)` pair
/// of a [crate::ProxyRow], choosing the `Value` variant from the column's type name.
///
/// # Panics
/// Panics if a column cannot be decoded as the Rust type chosen for its SQL type.
/// Also panics if the type name is not handled here; that includes types whose
/// arms are compiled out because the matching feature is disabled.
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    // The match keys below are the type names reported by sqlx-mysql's `TypeInfo::name()`:
    // https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html
    // https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    let values = row
        .columns()
        .iter()
        .map(|column| {
            let ordinal = column.ordinal();

            // Decode the current column as `Option<$ty>` (SQL NULL => `None`), panicking
            // with `$msg` when sqlx refuses to decode it as that type.
            macro_rules! decode {
                ($ty:ty, $msg:literal) => {
                    row.try_get::<Option<$ty>, _>(ordinal).expect($msg)
                };
            }

            let type_name = column.type_info().name();
            let value = match type_name {
                "TINYINT(1)" | "BOOLEAN" => Value::Bool(decode!(bool, "Failed to get boolean")),
                "TINYINT UNSIGNED" => {
                    Value::TinyUnsigned(decode!(u8, "Failed to get unsigned tiny integer"))
                }
                "SMALLINT UNSIGNED" => {
                    Value::SmallUnsigned(decode!(u16, "Failed to get unsigned small integer"))
                }
                "INT UNSIGNED" => Value::Unsigned(decode!(u32, "Failed to get unsigned integer")),
                "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => {
                    Value::BigUnsigned(decode!(u64, "Failed to get unsigned big integer"))
                }
                "TINYINT" => Value::TinyInt(decode!(i8, "Failed to get tiny integer")),
                "SMALLINT" => Value::SmallInt(decode!(i16, "Failed to get small integer")),
                "INT" => Value::Int(decode!(i32, "Failed to get integer")),
                "MEDIUMINT" | "BIGINT" => Value::BigInt(decode!(i64, "Failed to get big integer")),
                "FLOAT" => Value::Float(decode!(f32, "Failed to get float")),
                "DOUBLE" => Value::Double(decode!(f64, "Failed to get double")),

                "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                | "LONGBLOB" => Value::Bytes(decode!(Vec<u8>, "Failed to get bytes").map(Box::new)),

                "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                    Value::String(decode!(String, "Failed to get string").map(Box::new))
                }

                #[cfg(feature = "with-chrono")]
                "TIMESTAMP" => Value::ChronoDateTimeUtc(
                    decode!(chrono::DateTime<chrono::Utc>, "Failed to get timestamp")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIMESTAMP" => Value::TimeDateTime(
                    decode!(time::PrimitiveDateTime, "Failed to get timestamp").map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATE" => {
                    Value::ChronoDate(decode!(chrono::NaiveDate, "Failed to get date").map(Box::new))
                }
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATE" => Value::TimeDate(decode!(time::Date, "Failed to get date").map(Box::new)),

                #[cfg(feature = "with-chrono")]
                "TIME" => {
                    Value::ChronoTime(decode!(chrono::NaiveTime, "Failed to get time").map(Box::new))
                }
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIME" => Value::TimeTime(decode!(time::Time, "Failed to get time").map(Box::new)),

                #[cfg(feature = "with-chrono")]
                "DATETIME" => Value::ChronoDateTime(
                    decode!(chrono::NaiveDateTime, "Failed to get datetime").map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATETIME" => Value::TimeDateTime(
                    decode!(time::PrimitiveDateTime, "Failed to get datetime").map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "YEAR" => {
                    Value::ChronoDate(decode!(chrono::NaiveDate, "Failed to get year").map(Box::new))
                }
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "YEAR" => Value::TimeDate(decode!(time::Date, "Failed to get year").map(Box::new)),

                "ENUM" | "SET" | "GEOMETRY" => Value::String(
                    decode!(String, "Failed to get serialized string").map(Box::new),
                ),

                #[cfg(feature = "with-bigdecimal")]
                "DECIMAL" => Value::BigDecimal(
                    decode!(bigdecimal::BigDecimal, "Failed to get decimal").map(Box::new),
                ),
                #[cfg(all(
                    feature = "with-rust_decimal",
                    not(feature = "with-bigdecimal")
                ))]
                "DECIMAL" => Value::Decimal(
                    decode!(rust_decimal::Decimal, "Failed to get decimal").map(Box::new),
                ),

                #[cfg(feature = "with-json")]
                "JSON" => {
                    Value::Json(decode!(serde_json::Value, "Failed to get json").map(Box::new))
                }

                _ => unreachable!("Unknown column type: {}", type_name),
            };

            (column.name().to_string(), value)
        })
        .collect();

    crate::ProxyRow { values }
}