Skip to main content

sea_orm/driver/
sqlx_mysql.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, MySql, MySqlPool,
8    mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9    pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18    executor::*,
19};
20
21use super::sqlx_common::*;
22
23/// Defines the [sqlx::mysql] connector
24#[derive(Debug)]
25pub struct SqlxMySqlConnector;
26
27/// Defines a sqlx MySQL pool
28#[derive(Clone)]
29pub struct SqlxMySqlPoolConnection {
30    pub(crate) pool: MySqlPool,
31    metric_callback: Option<crate::metric::Callback>,
32    pub(crate) record_stmt_in_spans: bool,
33}
34
35impl std::fmt::Debug for SqlxMySqlPoolConnection {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
38    }
39}
40
41impl From<MySqlPool> for SqlxMySqlPoolConnection {
42    fn from(pool: MySqlPool) -> Self {
43        SqlxMySqlPoolConnection {
44            pool,
45            metric_callback: None,
46            record_stmt_in_spans: true,
47        }
48    }
49}
50
51impl From<MySqlPool> for DatabaseConnection {
52    fn from(pool: MySqlPool) -> Self {
53        DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
54    }
55}
56
57impl SqlxMySqlConnector {
58    /// Check if the URI provided corresponds to `mysql://` for a MySQL database
59    pub fn accepts(string: &str) -> bool {
60        string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
61    }
62
63    /// Add configuration options for the MySQL database
64    #[instrument(level = "trace")]
65    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
66        let record_stmt_in_spans = options.get_record_stmt_in_spans();
67        let mut sqlx_opts = options
68            .url
69            .parse::<MySqlConnectOptions>()
70            .map_err(sqlx_error_to_conn_err)?;
71        use sqlx::ConnectOptions;
72        if !options.sqlx_logging {
73            sqlx_opts = sqlx_opts.disable_statement_logging();
74        } else {
75            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
76            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
77                sqlx_opts = sqlx_opts.log_slow_statements(
78                    options.sqlx_slow_statements_logging_level,
79                    options.sqlx_slow_statements_logging_threshold,
80                );
81            }
82        }
83
84        if let Some(f) = &options.mysql_opts_fn {
85            sqlx_opts = f(sqlx_opts);
86        }
87        let after_connect = options.after_connect.clone();
88        let connect_lazy = options.connect_lazy;
89        let mysql_pool_opts_fn = options.mysql_pool_opts_fn.clone();
90        let mut pool_options = options.sqlx_pool_options();
91        if let Some(f) = &mysql_pool_opts_fn {
92            pool_options = f(pool_options);
93        }
94        let pool = if connect_lazy {
95            pool_options.connect_lazy_with(sqlx_opts)
96        } else {
97            pool_options
98                .connect_with(sqlx_opts)
99                .await
100                .map_err(sqlx_error_to_conn_err)?
101        };
102
103        let conn: DatabaseConnection =
104            DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
105                pool,
106                metric_callback: None,
107                record_stmt_in_spans,
108            })
109            .into();
110
111        if let Some(cb) = after_connect {
112            cb(conn.clone()).await?;
113        }
114
115        Ok(conn)
116    }
117}
118
119impl SqlxMySqlConnector {
120    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
121    pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
122        DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
123            pool,
124            metric_callback: None,
125            record_stmt_in_spans: true,
126        })
127        .into()
128    }
129}
130
131impl SqlxMySqlPoolConnection {
132    /// Execute a [Statement] on a MySQL backend
133    #[instrument(level = "trace")]
134    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
135        debug_print!("{}", stmt);
136
137        let query = sqlx_query(&stmt);
138        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
139        crate::metric::metric!(self.metric_callback, &stmt, {
140            match query.execute(&mut *conn).await {
141                Ok(res) => Ok(res.into()),
142                Err(err) => Err(sqlx_error_to_exec_err(err)),
143            }
144        })
145    }
146
147    /// Execute an unprepared SQL statement on a MySQL backend
148    #[instrument(level = "trace")]
149    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
150        debug_print!("{}", sql);
151
152        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
153        match conn.execute(sqlx::AssertSqlSafe(sql.to_owned())).await {
154            Ok(res) => Ok(res.into()),
155            Err(err) => Err(sqlx_error_to_exec_err(err)),
156        }
157    }
158
159    /// Get one result from a SQL query. Returns [Option::None] if no match was found
160    #[instrument(level = "trace")]
161    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
162        debug_print!("{}", stmt);
163
164        let query = sqlx_query(&stmt);
165        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
166        crate::metric::metric!(self.metric_callback, &stmt, {
167            match query.fetch_one(&mut *conn).await {
168                Ok(row) => Ok(Some(row.into())),
169                Err(err) => match err {
170                    sqlx::Error::RowNotFound => Ok(None),
171                    _ => Err(sqlx_error_to_query_err(err)),
172                },
173            }
174        })
175    }
176
177    /// Get the results of a query returning them as a Vec<[QueryResult]>
178    #[instrument(level = "trace")]
179    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
180        debug_print!("{}", stmt);
181
182        let query = sqlx_query(&stmt);
183        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
184        crate::metric::metric!(self.metric_callback, &stmt, {
185            match query.fetch_all(&mut *conn).await {
186                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
187                Err(err) => Err(sqlx_error_to_query_err(err)),
188            }
189        })
190    }
191
192    /// Stream the results of executing a SQL query
193    #[instrument(level = "trace")]
194    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
195        debug_print!("{}", stmt);
196
197        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
198        Ok(QueryStream::from((
199            conn,
200            stmt,
201            self.metric_callback.clone(),
202        )))
203    }
204
205    /// Bundle a set of SQL statements that execute together.
206    #[instrument(level = "trace")]
207    pub async fn begin(
208        &self,
209        isolation_level: Option<IsolationLevel>,
210        access_mode: Option<AccessMode>,
211    ) -> Result<DatabaseTransaction, DbErr> {
212        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
213        DatabaseTransaction::new_mysql(
214            conn,
215            self.metric_callback.clone(),
216            self.record_stmt_in_spans,
217            isolation_level,
218            access_mode,
219        )
220        .await
221    }
222
223    /// Create a MySQL transaction
224    #[instrument(level = "trace", skip(callback))]
225    pub async fn transaction<F, T, E>(
226        &self,
227        callback: F,
228        isolation_level: Option<IsolationLevel>,
229        access_mode: Option<AccessMode>,
230    ) -> Result<T, TransactionError<E>>
231    where
232        F: for<'b> FnOnce(
233                &'b DatabaseTransaction,
234            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
235            + Send,
236        T: Send,
237        E: std::fmt::Display + std::fmt::Debug + Send,
238    {
239        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
240        let transaction = DatabaseTransaction::new_mysql(
241            conn,
242            self.metric_callback.clone(),
243            self.record_stmt_in_spans,
244            isolation_level,
245            access_mode,
246        )
247        .await
248        .map_err(|e| TransactionError::Connection(e))?;
249        transaction.run(callback).await
250    }
251
252    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
253    where
254        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
255    {
256        self.metric_callback = Some(Arc::new(callback));
257    }
258
259    /// Checks if a connection to the database is still valid.
260    pub async fn ping(&self) -> Result<(), DbErr> {
261        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
262        match conn.ping().await {
263            Ok(_) => Ok(()),
264            Err(err) => Err(sqlx_error_to_conn_err(err)),
265        }
266    }
267
268    /// Explicitly close the MySQL connection.
269    /// See [`Self::close_by_ref`] for usage with references.
270    pub async fn close(self) -> Result<(), DbErr> {
271        self.close_by_ref().await
272    }
273
274    /// Explicitly close the MySQL connection
275    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
276        self.pool.close().await;
277        Ok(())
278    }
279}
280
281impl From<MySqlRow> for QueryResult {
282    fn from(row: MySqlRow) -> QueryResult {
283        QueryResult {
284            row: QueryResultRow::SqlxMySql(row),
285        }
286    }
287}
288
289impl From<MySqlQueryResult> for ExecResult {
290    fn from(result: MySqlQueryResult) -> ExecResult {
291        ExecResult {
292            result: ExecResultHolder::SqlxMySql(result),
293        }
294    }
295}
296
297pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
298    let values = stmt
299        .values
300        .as_ref()
301        .map_or(Values(Vec::new()), |values| values.clone());
302    sqlx::query_with(sqlx::AssertSqlSafe(stmt.sql.as_str()), SqlxValues(values))
303}
304
305pub(crate) async fn set_transaction_config(
306    conn: &mut PoolConnection<MySql>,
307    isolation_level: Option<IsolationLevel>,
308    access_mode: Option<AccessMode>,
309) -> Result<(), DbErr> {
310    let mut settings = Vec::new();
311
312    if let Some(isolation_level) = isolation_level {
313        settings.push(format!("ISOLATION LEVEL {isolation_level}"));
314    }
315
316    if let Some(access_mode) = access_mode {
317        settings.push(access_mode.to_string());
318    }
319
320    if !settings.is_empty() {
321        let stmt = Statement {
322            sql: format!("SET TRANSACTION {}", settings.join(", ")),
323            values: None,
324            db_backend: DbBackend::MySql,
325        };
326        let query = sqlx_query(&stmt);
327        conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
328    }
329    Ok(())
330}
331
332impl
333    From<(
334        PoolConnection<sqlx::MySql>,
335        Statement,
336        Option<crate::metric::Callback>,
337    )> for crate::QueryStream
338{
339    fn from(
340        (conn, stmt, metric_callback): (
341            PoolConnection<sqlx::MySql>,
342            Statement,
343            Option<crate::metric::Callback>,
344        ),
345    ) -> Self {
346        crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
347    }
348}
349
350impl crate::DatabaseTransaction {
351    pub(crate) async fn new_mysql(
352        inner: PoolConnection<sqlx::MySql>,
353        metric_callback: Option<crate::metric::Callback>,
354        record_stmt_in_spans: bool,
355        isolation_level: Option<IsolationLevel>,
356        access_mode: Option<AccessMode>,
357    ) -> Result<crate::DatabaseTransaction, DbErr> {
358        Self::begin(
359            Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
360            crate::DbBackend::MySql,
361            metric_callback,
362            record_stmt_in_spans,
363            isolation_level,
364            access_mode,
365            None,
366        )
367        .await
368    }
369}
370
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-mysql/0.7.2/src/sqlx_mysql/protocol/text/column.rs.html
    // https://docs.rs/sqlx-mysql/0.7.2/sqlx_mysql/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    // Each column is decoded into a `sea_query::Value`, keyed by column name.
    // The target Rust type is chosen from the sqlx type name. Temporal, decimal
    // and JSON arms exist only when the matching feature is enabled.
    //
    // Panics if a value cannot be decoded as the chosen type. Also panics when
    // the type name has no arm, e.g. `DECIMAL` with neither decimal feature enabled.
    let values = row
        .columns()
        .iter()
        .map(|column| {
            let name = column.name().to_string();
            let idx = column.ordinal();
            let type_name = column.type_info().name();
            let value = match type_name {
                "TINYINT(1)" | "BOOLEAN" => {
                    Value::Bool(row.try_get(idx).expect("Failed to get boolean"))
                }
                "TINYINT UNSIGNED" => Value::TinyUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned tiny integer"),
                ),
                "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned small integer"),
                ),
                "INT UNSIGNED" => {
                    Value::Unsigned(row.try_get(idx).expect("Failed to get unsigned integer"))
                }
                "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                    row.try_get(idx)
                        .expect("Failed to get unsigned big integer"),
                ),
                "TINYINT" => Value::TinyInt(row.try_get(idx).expect("Failed to get tiny integer")),
                "SMALLINT" => {
                    Value::SmallInt(row.try_get(idx).expect("Failed to get small integer"))
                }
                "INT" => Value::Int(row.try_get(idx).expect("Failed to get integer")),
                "MEDIUMINT" | "BIGINT" => {
                    Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                }
                "FLOAT" => Value::Float(row.try_get(idx).expect("Failed to get float")),
                "DOUBLE" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                | "LONGBLOB" => Value::Bytes(
                    row.try_get::<Option<Vec<u8>>, _>(idx)
                        .expect("Failed to get bytes")
                        .map(Box::new),
                ),

                "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                    Value::String(
                        row.try_get::<Option<String>, _>(idx)
                            .expect("Failed to get string")
                            .map(Box::new),
                    )
                }

                #[cfg(feature = "with-chrono")]
                "TIMESTAMP" => Value::ChronoDateTimeUtc(
                    row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(idx)
                        .expect("Failed to get timestamp")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIMESTAMP" => Value::TimeDateTime(
                    row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                        .expect("Failed to get timestamp")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATE" => Value::ChronoDate(
                    row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                        .expect("Failed to get date")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATE" => Value::TimeDate(
                    row.try_get::<Option<time::Date>, _>(idx)
                        .expect("Failed to get date")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "TIME" => Value::ChronoTime(
                    row.try_get::<Option<chrono::NaiveTime>, _>(idx)
                        .expect("Failed to get time")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIME" => Value::TimeTime(
                    row.try_get::<Option<time::Time>, _>(idx)
                        .expect("Failed to get time")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "DATETIME" => Value::ChronoDateTime(
                    row.try_get::<Option<chrono::NaiveDateTime>, _>(idx)
                        .expect("Failed to get datetime")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATETIME" => Value::TimeDateTime(
                    row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                        .expect("Failed to get datetime")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-chrono")]
                "YEAR" => Value::ChronoDate(
                    row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                        .expect("Failed to get year")
                        .map(Box::new),
                ),
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "YEAR" => Value::TimeDate(
                    row.try_get::<Option<time::Date>, _>(idx)
                        .expect("Failed to get year")
                        .map(Box::new),
                ),

                "ENUM" | "SET" | "GEOMETRY" => Value::String(
                    row.try_get::<Option<String>, _>(idx)
                        .expect("Failed to get serialized string")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-bigdecimal")]
                "DECIMAL" => Value::BigDecimal(
                    row.try_get::<Option<bigdecimal::BigDecimal>, _>(idx)
                        .expect("Failed to get decimal")
                        .map(Box::new),
                ),
                #[cfg(all(
                    feature = "with-rust_decimal",
                    not(feature = "with-bigdecimal")
                ))]
                "DECIMAL" => Value::Decimal(
                    row.try_get::<Option<rust_decimal::Decimal>, _>(idx)
                        .expect("Failed to get decimal")
                        .map(Box::new),
                ),

                #[cfg(feature = "with-json")]
                "JSON" => Value::Json(
                    row.try_get::<Option<serde_json::Value>, _>(idx)
                        .expect("Failed to get json")
                        .map(Box::new),
                ),

                _ => unreachable!("Unknown column type: {}", type_name),
            };
            (name, value)
        })
        .collect();

    crate::ProxyRow { values }
}