sea_orm/driver/
sqlx_mysql.rs

1use futures::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
8    pool::PoolConnection,
9    Connection, Executor, MySql, MySqlPool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    debug_print, error::*, executor::*, AccessMode, ConnectOptions, DatabaseConnection,
17    DatabaseTransaction, DbBackend, IsolationLevel, QueryStream, Statement, TransactionError,
18};
19
20use super::sqlx_common::*;
21
/// Connector for the [sqlx::mysql] driver: its associated functions recognise
/// MySQL connection strings and build a [DatabaseConnection] backed by a sqlx pool.
#[derive(Debug)]
pub struct SqlxMySqlConnector;
25
26/// Defines a sqlx MySQL pool
27#[derive(Clone)]
28pub struct SqlxMySqlPoolConnection {
29    pub(crate) pool: MySqlPool,
30    metric_callback: Option<crate::metric::Callback>,
31}
32
33impl std::fmt::Debug for SqlxMySqlPoolConnection {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
36    }
37}
38
39impl SqlxMySqlConnector {
40    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
41    pub fn accepts(string: &str) -> bool {
42        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
43    }
44
45    /// Add configuration options for the MySQL database
46    #[instrument(level = "trace")]
47    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
48        let mut opt = options
49            .url
50            .parse::<MySqlConnectOptions>()
51            .map_err(sqlx_error_to_conn_err)?;
52        use sqlx::ConnectOptions;
53        if !options.sqlx_logging {
54            opt = opt.disable_statement_logging();
55        } else {
56            opt = opt.log_statements(options.sqlx_logging_level);
57            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
58                opt = opt.log_slow_statements(
59                    options.sqlx_slow_statements_logging_level,
60                    options.sqlx_slow_statements_logging_threshold,
61                );
62            }
63        }
64        let pool = if options.connect_lazy {
65            options.sqlx_pool_options().connect_lazy_with(opt)
66        } else {
67            options
68                .sqlx_pool_options()
69                .connect_with(opt)
70                .await
71                .map_err(sqlx_error_to_conn_err)?
72        };
73        Ok(DatabaseConnection::SqlxMySqlPoolConnection(
74            SqlxMySqlPoolConnection {
75                pool,
76                metric_callback: None,
77            },
78        ))
79    }
80}
81
82impl SqlxMySqlConnector {
83    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
84    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
85        DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
86            pool,
87            metric_callback: None,
88        })
89    }
90}
91
92impl SqlxMySqlPoolConnection {
93    /// Execute a [Statement] on a MySQL backend
94    #[instrument(level = "trace")]
95    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
96        debug_print!("{}", stmt);
97
98        let query = sqlx_query(&stmt);
99        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
100        crate::metric::metric!(self.metric_callback, &stmt, {
101            match query.execute(&mut *conn).await {
102                Ok(res) => Ok(res.into()),
103                Err(err) => Err(sqlx_error_to_exec_err(err)),
104            }
105        })
106    }
107
108    /// Execute an unprepared SQL statement on a MySQL backend
109    #[instrument(level = "trace")]
110    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
111        debug_print!("{}", sql);
112
113        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
114        match conn.execute(sql).await {
115            Ok(res) => Ok(res.into()),
116            Err(err) => Err(sqlx_error_to_exec_err(err)),
117        }
118    }
119
120    /// Get one result from a SQL query. Returns [Option::None] if no match was found
121    #[instrument(level = "trace")]
122    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
123        debug_print!("{}", stmt);
124
125        let query = sqlx_query(&stmt);
126        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
127        crate::metric::metric!(self.metric_callback, &stmt, {
128            match query.fetch_one(&mut *conn).await {
129                Ok(row) => Ok(Some(row.into())),
130                Err(err) => match err {
131                    sqlx::Error::RowNotFound => Ok(None),
132                    _ => Err(sqlx_error_to_query_err(err)),
133                },
134            }
135        })
136    }
137
138    /// Get the results of a query returning them as a Vec<[QueryResult]>
139    #[instrument(level = "trace")]
140    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
141        debug_print!("{}", stmt);
142
143        let query = sqlx_query(&stmt);
144        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
145        crate::metric::metric!(self.metric_callback, &stmt, {
146            match query.fetch_all(&mut *conn).await {
147                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
148                Err(err) => Err(sqlx_error_to_query_err(err)),
149            }
150        })
151    }
152
153    /// Stream the results of executing a SQL query
154    #[instrument(level = "trace")]
155    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
156        debug_print!("{}", stmt);
157
158        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
159        Ok(QueryStream::from((
160            conn,
161            stmt,
162            self.metric_callback.clone(),
163        )))
164    }
165
166    /// Bundle a set of SQL statements that execute together.
167    #[instrument(level = "trace")]
168    pub async fn begin(
169        &self,
170        isolation_level: Option<IsolationLevel>,
171        access_mode: Option<AccessMode>,
172    ) -> Result<DatabaseTransaction, DbErr> {
173        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
174        DatabaseTransaction::new_mysql(
175            conn,
176            self.metric_callback.clone(),
177            isolation_level,
178            access_mode,
179        )
180        .await
181    }
182
183    /// Create a MySQL transaction
184    #[instrument(level = "trace", skip(callback))]
185    pub async fn transaction<F, T, E>(
186        &self,
187        callback: F,
188        isolation_level: Option<IsolationLevel>,
189        access_mode: Option<AccessMode>,
190    ) -> Result<T, TransactionError<E>>
191    where
192        F: for<'b> FnOnce(
193                &'b DatabaseTransaction,
194            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
195            + Send,
196        T: Send,
197        E: std::error::Error + Send,
198    {
199        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
200        let transaction = DatabaseTransaction::new_mysql(
201            conn,
202            self.metric_callback.clone(),
203            isolation_level,
204            access_mode,
205        )
206        .await
207        .map_err(|e| TransactionError::Connection(e))?;
208        transaction.run(callback).await
209    }
210
211    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
212    where
213        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
214    {
215        self.metric_callback = Some(Arc::new(callback));
216    }
217
218    /// Checks if a connection to the database is still valid.
219    pub async fn ping(&self) -> Result<(), DbErr> {
220        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
221        match conn.ping().await {
222            Ok(_) => Ok(()),
223            Err(err) => Err(sqlx_error_to_conn_err(err)),
224        }
225    }
226
227    /// Explicitly close the MySQL connection
228    pub async fn close(self) -> Result<(), DbErr> {
229        self.pool.close().await;
230        Ok(())
231    }
232}
233
234impl From<MySqlRow> for QueryResult {
235    fn from(row: MySqlRow) -> QueryResult {
236        QueryResult {
237            row: QueryResultRow::SqlxMySql(row),
238        }
239    }
240}
241
242impl From<MySqlQueryResult> for ExecResult {
243    fn from(result: MySqlQueryResult) -> ExecResult {
244        ExecResult {
245            result: ExecResultHolder::SqlxMySql(result),
246        }
247    }
248}
249
250pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
251    let values = stmt
252        .values
253        .as_ref()
254        .map_or(Values(Vec::new()), |values| values.clone());
255    sqlx::query_with(&stmt.sql, SqlxValues(values))
256}
257
258pub(crate) async fn set_transaction_config(
259    conn: &mut PoolConnection<MySql>,
260    isolation_level: Option<IsolationLevel>,
261    access_mode: Option<AccessMode>,
262) -> Result<(), DbErr> {
263    if let Some(isolation_level) = isolation_level {
264        let stmt = Statement {
265            sql: format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}"),
266            values: None,
267            db_backend: DbBackend::MySql,
268        };
269        let query = sqlx_query(&stmt);
270        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
271    }
272    if let Some(access_mode) = access_mode {
273        let stmt = Statement {
274            sql: format!("SET TRANSACTION {access_mode}"),
275            values: None,
276            db_backend: DbBackend::MySql,
277        };
278        let query = sqlx_query(&stmt);
279        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
280    }
281    Ok(())
282}
283
284impl
285    From<(
286        PoolConnection<sqlx::MySql>,
287        Statement,
288        Option<crate::metric::Callback>,
289    )> for crate::QueryStream
290{
291    fn from(
292        (conn, stmt, metric_callback): (
293            PoolConnection<sqlx::MySql>,
294            Statement,
295            Option<crate::metric::Callback>,
296        ),
297    ) -> Self {
298        crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
299    }
300}
301
302impl crate::DatabaseTransaction {
303    pub(crate) async fn new_mysql(
304        inner: PoolConnection<sqlx::MySql>,
305        metric_callback: Option<crate::metric::Callback>,
306        isolation_level: Option<IsolationLevel>,
307        access_mode: Option<AccessMode>,
308    ) -> Result<crate::DatabaseTransaction, DbErr> {
309        Self::begin(
310            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
311            crate::DbBackend::MySql,
312            metric_callback,
313            isolation_level,
314            access_mode,
315        )
316        .await
317    }
318}
319
/// Convert a sqlx MySQL row into a `ProxyRow`.
///
/// Each column becomes a `(name, Value)` pair, where the `Value` variant is
/// chosen from the column's reported type name. Cells are decoded as
/// `Option<T>`, so a SQL NULL produces the `None` form of the matching variant
/// instead of a decode failure.
///
/// # Panics
///
/// Panics if a non-NULL cell cannot be decoded as the Rust type chosen for its
/// column type, or if the column's type name is not one of those handled below.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html
    // https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    /// Box the decoded cell for the `Value` variants that store `Option<Box<T>>`.
    fn boxed<T>(value: Option<T>) -> Option<Box<T>> {
        value.map(Box::new)
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    // Every `try_get` below is inferred as `Option<T>` from the
                    // `Value` variant it feeds, which is what makes NULL decode
                    // to `None`.
                    match c.type_info().name() {
                        "TINYINT(1)" | "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }
                        "TINYINT UNSIGNED" => Value::TinyUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned tiny integer"),
                        ),
                        "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned small integer"),
                        ),
                        "INT UNSIGNED" => Value::Unsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned integer"),
                        ),
                        "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned big integer"),
                        ),
                        "TINYINT" => Value::TinyInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get tiny integer"),
                        ),
                        "SMALLINT" => Value::SmallInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get small integer"),
                        ),
                        "INT" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }
                        "MEDIUMINT" | "BIGINT" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),
                        "FLOAT" => {
                            Value::Float(row.try_get(c.ordinal()).expect("Failed to get float"))
                        }
                        "DOUBLE" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                        | "LONGBLOB" => Value::Bytes(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get bytes"),
                        )),

                        "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                            Value::String(boxed(
                                row.try_get(c.ordinal()).expect("Failed to get string"),
                            ))
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIMESTAMP" => Value::ChronoDateTimeUtc(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get timestamp"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIMESTAMP" => Value::TimeDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get timestamp"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get date"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get date"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get time"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get time"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get datetime"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get datetime"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "YEAR" => Value::ChronoDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get year"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "YEAR" => Value::TimeDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get year"),
                        )),

                        "ENUM" | "SET" | "GEOMETRY" => Value::String(boxed(
                            row.try_get(c.ordinal())
                                .expect("Failed to get serialized string"),
                        )),

                        #[cfg(feature = "with-bigdecimal")]
                        "DECIMAL" => Value::BigDecimal(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get decimal"),
                        )),
                        #[cfg(all(
                            feature = "with-rust_decimal",
                            not(feature = "with-bigdecimal")
                        ))]
                        "DECIMAL" => Value::Decimal(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get decimal"),
                        )),

                        #[cfg(feature = "with-json")]
                        "JSON" => Value::Json(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get json"),
                        )),

                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}
458}