sea_orm/driver/
sqlx_mysql.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
8    pool::PoolConnection,
9    Connection, Executor, MySql, MySqlPool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    debug_print, error::*, executor::*, AccessMode, ConnectOptions, DatabaseConnection,
17    DatabaseTransaction, DbBackend, IsolationLevel, QueryStream, Statement, TransactionError,
18};
19
20use super::sqlx_common::*;
21
/// Defines the [sqlx::mysql] connector
///
/// This is a unit type: it carries no state and only namespaces the
/// associated functions that recognise `mysql://` URLs and build a
/// [DatabaseConnection] backed by a sqlx MySQL pool.
#[derive(Debug)]
pub struct SqlxMySqlConnector;
25
/// Defines a sqlx MySQL pool
///
/// `Clone` is derived, so cloning clones both fields.
#[derive(Clone)]
pub struct SqlxMySqlPoolConnection {
    /// The sqlx pool from which every operation acquires its connection.
    pub(crate) pool: MySqlPool,
    /// Optional metric callback; `None` until `set_metric_callback` installs one.
    metric_callback: Option<crate::metric::Callback>,
}
32
33impl std::fmt::Debug for SqlxMySqlPoolConnection {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
36    }
37}
38
39impl From<MySqlPool> for SqlxMySqlPoolConnection {
40    fn from(pool: MySqlPool) -> Self {
41        SqlxMySqlPoolConnection {
42            pool,
43            metric_callback: None,
44        }
45    }
46}
47
48impl From<MySqlPool> for DatabaseConnection {
49    fn from(pool: MySqlPool) -> Self {
50        DatabaseConnection::SqlxMySqlPoolConnection(pool.into())
51    }
52}
53
54impl SqlxMySqlConnector {
55    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
56    pub fn accepts(string: &str) -> bool {
57        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
58    }
59
60    /// Add configuration options for the MySQL database
61    #[instrument(level = "trace")]
62    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
63        let mut opt = options
64            .url
65            .parse::<MySqlConnectOptions>()
66            .map_err(sqlx_error_to_conn_err)?;
67        use sqlx::ConnectOptions;
68        if !options.sqlx_logging {
69            opt = opt.disable_statement_logging();
70        } else {
71            opt = opt.log_statements(options.sqlx_logging_level);
72            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
73                opt = opt.log_slow_statements(
74                    options.sqlx_slow_statements_logging_level,
75                    options.sqlx_slow_statements_logging_threshold,
76                );
77            }
78        }
79        let pool = if options.connect_lazy {
80            options.sqlx_pool_options().connect_lazy_with(opt)
81        } else {
82            options
83                .sqlx_pool_options()
84                .connect_with(opt)
85                .await
86                .map_err(sqlx_error_to_conn_err)?
87        };
88        Ok(DatabaseConnection::SqlxMySqlPoolConnection(
89            SqlxMySqlPoolConnection {
90                pool,
91                metric_callback: None,
92            },
93        ))
94    }
95}
96
97impl SqlxMySqlConnector {
98    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
99    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
100        DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
101            pool,
102            metric_callback: None,
103        })
104    }
105}
106
107impl SqlxMySqlPoolConnection {
108    /// Execute a [Statement] on a MySQL backend
109    #[instrument(level = "trace")]
110    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
111        debug_print!("{}", stmt);
112
113        let query = sqlx_query(&stmt);
114        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
115        crate::metric::metric!(self.metric_callback, &stmt, {
116            match query.execute(&mut *conn).await {
117                Ok(res) => Ok(res.into()),
118                Err(err) => Err(sqlx_error_to_exec_err(err)),
119            }
120        })
121    }
122
123    /// Execute an unprepared SQL statement on a MySQL backend
124    #[instrument(level = "trace")]
125    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
126        debug_print!("{}", sql);
127
128        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
129        match conn.execute(sql).await {
130            Ok(res) => Ok(res.into()),
131            Err(err) => Err(sqlx_error_to_exec_err(err)),
132        }
133    }
134
135    /// Get one result from a SQL query. Returns [Option::None] if no match was found
136    #[instrument(level = "trace")]
137    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
138        debug_print!("{}", stmt);
139
140        let query = sqlx_query(&stmt);
141        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
142        crate::metric::metric!(self.metric_callback, &stmt, {
143            match query.fetch_one(&mut *conn).await {
144                Ok(row) => Ok(Some(row.into())),
145                Err(err) => match err {
146                    sqlx::Error::RowNotFound => Ok(None),
147                    _ => Err(sqlx_error_to_query_err(err)),
148                },
149            }
150        })
151    }
152
153    /// Get the results of a query returning them as a Vec<[QueryResult]>
154    #[instrument(level = "trace")]
155    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
156        debug_print!("{}", stmt);
157
158        let query = sqlx_query(&stmt);
159        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160        crate::metric::metric!(self.metric_callback, &stmt, {
161            match query.fetch_all(&mut *conn).await {
162                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
163                Err(err) => Err(sqlx_error_to_query_err(err)),
164            }
165        })
166    }
167
168    /// Stream the results of executing a SQL query
169    #[instrument(level = "trace")]
170    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
171        debug_print!("{}", stmt);
172
173        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
174        Ok(QueryStream::from((
175            conn,
176            stmt,
177            self.metric_callback.clone(),
178        )))
179    }
180
181    /// Bundle a set of SQL statements that execute together.
182    #[instrument(level = "trace")]
183    pub async fn begin(
184        &self,
185        isolation_level: Option<IsolationLevel>,
186        access_mode: Option<AccessMode>,
187    ) -> Result<DatabaseTransaction, DbErr> {
188        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
189        DatabaseTransaction::new_mysql(
190            conn,
191            self.metric_callback.clone(),
192            isolation_level,
193            access_mode,
194        )
195        .await
196    }
197
198    /// Create a MySQL transaction
199    #[instrument(level = "trace", skip(callback))]
200    pub async fn transaction<F, T, E>(
201        &self,
202        callback: F,
203        isolation_level: Option<IsolationLevel>,
204        access_mode: Option<AccessMode>,
205    ) -> Result<T, TransactionError<E>>
206    where
207        F: for<'b> FnOnce(
208                &'b DatabaseTransaction,
209            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
210            + Send,
211        T: Send,
212        E: std::error::Error + Send,
213    {
214        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
215        let transaction = DatabaseTransaction::new_mysql(
216            conn,
217            self.metric_callback.clone(),
218            isolation_level,
219            access_mode,
220        )
221        .await
222        .map_err(|e| TransactionError::Connection(e))?;
223        transaction.run(callback).await
224    }
225
226    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
227    where
228        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
229    {
230        self.metric_callback = Some(Arc::new(callback));
231    }
232
233    /// Checks if a connection to the database is still valid.
234    pub async fn ping(&self) -> Result<(), DbErr> {
235        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
236        match conn.ping().await {
237            Ok(_) => Ok(()),
238            Err(err) => Err(sqlx_error_to_conn_err(err)),
239        }
240    }
241
242    /// Explicitly close the MySQL connection.
243    /// See [`Self::close_by_ref`] for usage with references.
244    pub async fn close(self) -> Result<(), DbErr> {
245        self.close_by_ref().await
246    }
247
248    /// Explicitly close the MySQL connection
249    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
250        self.pool.close().await;
251        Ok(())
252    }
253}
254
255impl From<MySqlRow> for QueryResult {
256    fn from(row: MySqlRow) -> QueryResult {
257        QueryResult {
258            row: QueryResultRow::SqlxMySql(row),
259        }
260    }
261}
262
263impl From<MySqlQueryResult> for ExecResult {
264    fn from(result: MySqlQueryResult) -> ExecResult {
265        ExecResult {
266            result: ExecResultHolder::SqlxMySql(result),
267        }
268    }
269}
270
271pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
272    let values = stmt
273        .values
274        .as_ref()
275        .map_or(Values(Vec::new()), |values| values.clone());
276    sqlx::query_with(&stmt.sql, SqlxValues(values))
277}
278
279pub(crate) async fn set_transaction_config(
280    conn: &mut PoolConnection<MySql>,
281    isolation_level: Option<IsolationLevel>,
282    access_mode: Option<AccessMode>,
283) -> Result<(), DbErr> {
284    let mut settings = Vec::new();
285
286    if let Some(isolation_level) = isolation_level {
287        settings.push(format!("ISOLATION LEVEL {isolation_level}"));
288    }
289
290    if let Some(access_mode) = access_mode {
291        settings.push(access_mode.to_string());
292    }
293
294    if !settings.is_empty() {
295        let stmt = Statement {
296            sql: format!("SET TRANSACTION {}", settings.join(", ")),
297            values: None,
298            db_backend: DbBackend::MySql,
299        };
300        let query = sqlx_query(&stmt);
301        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
302    }
303    Ok(())
304}
305
306impl
307    From<(
308        PoolConnection<sqlx::MySql>,
309        Statement,
310        Option<crate::metric::Callback>,
311    )> for crate::QueryStream
312{
313    fn from(
314        (conn, stmt, metric_callback): (
315            PoolConnection<sqlx::MySql>,
316            Statement,
317            Option<crate::metric::Callback>,
318        ),
319    ) -> Self {
320        crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
321    }
322}
323
324impl crate::DatabaseTransaction {
325    pub(crate) async fn new_mysql(
326        inner: PoolConnection<sqlx::MySql>,
327        metric_callback: Option<crate::metric::Callback>,
328        isolation_level: Option<IsolationLevel>,
329        access_mode: Option<AccessMode>,
330    ) -> Result<crate::DatabaseTransaction, DbErr> {
331        Self::begin(
332            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
333            crate::DbBackend::MySql,
334            metric_callback,
335            isolation_level,
336            access_mode,
337        )
338        .await
339    }
340}
341
/// Converts a sqlx MySQL row into a [`crate::ProxyRow`].
///
/// Each column is mapped to a `(name, Value)` pair; the `Value` variant is
/// chosen from the column's sqlx type name. Columns are decoded as
/// `Option<T>`, so a SQL `NULL` becomes the matching `Value` variant holding
/// `None` instead of failing to decode.
///
/// # Panics
///
/// Panics if a column value cannot be decoded into the type selected for it,
/// or if the column's type name is not covered by any arm below.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html
    // https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    /// Boxes a decoded nullable value. Being generic, it lets the decode
    /// target type be inferred from the `Value` variant it is passed to.
    fn boxed<T>(value: Option<T>) -> Option<Box<T>> {
        value.map(Box::new)
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        "TINYINT(1)" | "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }
                        "TINYINT UNSIGNED" => Value::TinyUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned tiny integer"),
                        ),
                        "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned small integer"),
                        ),
                        "INT UNSIGNED" => Value::Unsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned integer"),
                        ),
                        "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned big integer"),
                        ),
                        "TINYINT" => Value::TinyInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get tiny integer"),
                        ),
                        "SMALLINT" => Value::SmallInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get small integer"),
                        ),
                        "INT" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }
                        "MEDIUMINT" | "BIGINT" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),
                        "FLOAT" => {
                            Value::Float(row.try_get(c.ordinal()).expect("Failed to get float"))
                        }
                        "DOUBLE" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                        | "LONGBLOB" => Value::Bytes(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get bytes"),
                        )),

                        "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                            Value::String(boxed(
                                row.try_get(c.ordinal()).expect("Failed to get string"),
                            ))
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIMESTAMP" => Value::ChronoDateTimeUtc(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get timestamp"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIMESTAMP" => Value::TimeDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get timestamp"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get date"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get date"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get time"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get time"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get datetime"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get datetime"),
                        )),

                        // NOTE(review): YEAR is decoded as a date type here; it is
                        // unclear whether sqlx accepts YEAR columns for date
                        // decoding — confirm, otherwise this arm panics.
                        #[cfg(feature = "with-chrono")]
                        "YEAR" => Value::ChronoDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get year"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "YEAR" => Value::TimeDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get year"),
                        )),

                        "ENUM" | "SET" | "GEOMETRY" => Value::String(boxed(
                            row.try_get(c.ordinal())
                                .expect("Failed to get serialized string"),
                        )),

                        #[cfg(feature = "with-bigdecimal")]
                        "DECIMAL" => Value::BigDecimal(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get decimal"),
                        )),
                        #[cfg(all(
                            feature = "with-rust_decimal",
                            not(feature = "with-bigdecimal")
                        ))]
                        "DECIMAL" => Value::Decimal(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get decimal"),
                        )),

                        #[cfg(feature = "with-json")]
                        "JSON" => Value::Json(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get json"),
                        )),

                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}