1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, MySql, MySqlPool,
8 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9 pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 DbBackend, IsolationLevel, Statement, TransactionError, debug_print, error::*, executor::*,
18};
19
20use super::sqlx_common::*;
21
22#[cfg(feature = "stream")]
23use crate::QueryStream;
24
/// Entry point for creating MySQL [`DatabaseConnection`]s through sqlx.
///
/// See [`SqlxMySqlConnector::connect`] to open a new pool from [`ConnectOptions`], and
/// [`SqlxMySqlConnector::from_sqlx_mysql_pool`] to wrap an existing [`MySqlPool`].
#[derive(Debug)]
pub struct SqlxMySqlConnector;
/// A MySQL connection handle backed by an sqlx [`MySqlPool`].
///
/// Each query acquires a connection from the pool for its duration.
#[derive(Clone)]
pub struct SqlxMySqlPoolConnection {
    /// The underlying sqlx connection pool.
    pub(crate) pool: MySqlPool,
    /// Optional metric callback. It is passed to `metric!` around executed statements and
    /// cloned into streams and transactions created from this connection.
    metric_callback: Option<crate::metric::Callback>,
    /// Forwarded to transactions started from this connection. Presumably controls whether
    /// SQL statements are recorded in tracing spans (TODO confirm in `DatabaseTransaction`).
    pub(crate) record_stmt_in_spans: bool,
}
36
37impl std::fmt::Debug for SqlxMySqlPoolConnection {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
40 }
41}
42
43impl From<MySqlPool> for SqlxMySqlPoolConnection {
44 fn from(pool: MySqlPool) -> Self {
45 SqlxMySqlPoolConnection {
46 pool,
47 metric_callback: None,
48 record_stmt_in_spans: true,
49 }
50 }
51}
52
53impl From<MySqlPool> for DatabaseConnection {
54 fn from(pool: MySqlPool) -> Self {
55 DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
56 }
57}
58
59impl SqlxMySqlConnector {
60 pub fn accepts(string: &str) -> bool {
62 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
63 }
64
65 #[instrument(level = "trace")]
67 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
68 let record_stmt_in_spans = options.get_record_stmt_in_spans();
69 let mut sqlx_opts = options
70 .url
71 .parse::<MySqlConnectOptions>()
72 .map_err(sqlx_error_to_conn_err)?;
73 use sqlx::ConnectOptions;
74 if !options.sqlx_logging {
75 sqlx_opts = sqlx_opts.disable_statement_logging();
76 } else {
77 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
78 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
79 sqlx_opts = sqlx_opts.log_slow_statements(
80 options.sqlx_slow_statements_logging_level,
81 options.sqlx_slow_statements_logging_threshold,
82 );
83 }
84 }
85
86 if let Some(f) = &options.mysql_opts_fn {
87 sqlx_opts = f(sqlx_opts);
88 }
89 let after_connect = options.after_connect.clone();
90 let connect_lazy = options.connect_lazy;
91 let mysql_pool_opts_fn = options.mysql_pool_opts_fn.clone();
92 let mut pool_options = options.sqlx_pool_options();
93 if let Some(f) = &mysql_pool_opts_fn {
94 pool_options = f(pool_options);
95 }
96 let pool = if connect_lazy {
97 pool_options.connect_lazy_with(sqlx_opts)
98 } else {
99 pool_options
100 .connect_with(sqlx_opts)
101 .await
102 .map_err(sqlx_error_to_conn_err)?
103 };
104
105 let conn: DatabaseConnection =
106 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
107 pool,
108 metric_callback: None,
109 record_stmt_in_spans,
110 })
111 .into();
112
113 if let Some(cb) = after_connect {
114 cb(conn.clone()).await?;
115 }
116
117 Ok(conn)
118 }
119}
120
121impl SqlxMySqlConnector {
122 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
124 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
125 pool,
126 metric_callback: None,
127 record_stmt_in_spans: true,
128 })
129 .into()
130 }
131}
132
133impl SqlxMySqlPoolConnection {
134 #[instrument(level = "trace", skip(stmt))]
136 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
137 debug_print!("{}", stmt);
138
139 let query = sqlx_query(&stmt);
140 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
141 crate::metric::metric!(self.metric_callback, &stmt, {
142 match query.execute(&mut *conn).await {
143 Ok(res) => Ok(res.into()),
144 Err(err) => Err(sqlx_error_to_exec_err(err)),
145 }
146 })
147 }
148
149 #[instrument(level = "trace", skip(sql))]
151 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
152 debug_print!("{}", sql);
153
154 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
155 match conn.execute(sqlx::AssertSqlSafe(sql.to_owned())).await {
156 Ok(res) => Ok(res.into()),
157 Err(err) => Err(sqlx_error_to_exec_err(err)),
158 }
159 }
160
161 #[instrument(level = "trace", skip(stmt))]
163 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
164 debug_print!("{}", stmt);
165
166 let query = sqlx_query(&stmt);
167 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
168 crate::metric::metric!(self.metric_callback, &stmt, {
169 match query.fetch_one(&mut *conn).await {
170 Ok(row) => Ok(Some(row.into())),
171 Err(err) => match err {
172 sqlx::Error::RowNotFound => Ok(None),
173 _ => Err(sqlx_error_to_query_err(err)),
174 },
175 }
176 })
177 }
178
179 #[instrument(level = "trace", skip(stmt))]
181 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
182 debug_print!("{}", stmt);
183
184 let query = sqlx_query(&stmt);
185 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
186 crate::metric::metric!(self.metric_callback, &stmt, {
187 match query.fetch_all(&mut *conn).await {
188 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
189 Err(err) => Err(sqlx_error_to_query_err(err)),
190 }
191 })
192 }
193
194 #[instrument(level = "trace", skip(stmt))]
196 #[cfg(feature = "stream")]
197 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
198 debug_print!("{}", stmt);
199
200 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
201 Ok(QueryStream::from((
202 conn,
203 stmt,
204 self.metric_callback.clone(),
205 )))
206 }
207
208 #[instrument(level = "trace")]
210 pub async fn begin(
211 &self,
212 isolation_level: Option<IsolationLevel>,
213 access_mode: Option<AccessMode>,
214 ) -> Result<DatabaseTransaction, DbErr> {
215 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
216 DatabaseTransaction::new_mysql(
217 conn,
218 self.metric_callback.clone(),
219 self.record_stmt_in_spans,
220 isolation_level,
221 access_mode,
222 )
223 .await
224 }
225
226 #[instrument(level = "trace", skip(callback))]
228 pub async fn transaction<F, T, E>(
229 &self,
230 callback: F,
231 isolation_level: Option<IsolationLevel>,
232 access_mode: Option<AccessMode>,
233 ) -> Result<T, TransactionError<E>>
234 where
235 F: for<'b> FnOnce(
236 &'b DatabaseTransaction,
237 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
238 + Send,
239 T: Send,
240 E: std::fmt::Display + std::fmt::Debug + Send,
241 {
242 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
243 let transaction = DatabaseTransaction::new_mysql(
244 conn,
245 self.metric_callback.clone(),
246 self.record_stmt_in_spans,
247 isolation_level,
248 access_mode,
249 )
250 .await
251 .map_err(|e| TransactionError::Connection(e))?;
252 transaction.run(callback).await
253 }
254
255 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
256 where
257 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
258 {
259 self.metric_callback = Some(Arc::new(callback));
260 }
261
262 pub async fn ping(&self) -> Result<(), DbErr> {
264 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
265 match conn.ping().await {
266 Ok(_) => Ok(()),
267 Err(err) => Err(sqlx_error_to_conn_err(err)),
268 }
269 }
270
271 pub async fn close(self) -> Result<(), DbErr> {
274 self.close_by_ref().await
275 }
276
277 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
279 self.pool.close().await;
280 Ok(())
281 }
282}
283
284impl From<MySqlRow> for QueryResult {
285 fn from(row: MySqlRow) -> QueryResult {
286 QueryResult {
287 row: QueryResultRow::SqlxMySql(row),
288 }
289 }
290}
291
292impl From<MySqlQueryResult> for ExecResult {
293 fn from(result: MySqlQueryResult) -> ExecResult {
294 ExecResult {
295 result: ExecResultHolder::SqlxMySql(result),
296 }
297 }
298}
299
300pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
301 let values = stmt
302 .values
303 .as_ref()
304 .map_or(Values(Vec::new()), |values| values.clone());
305 sqlx::query_with(sqlx::AssertSqlSafe(stmt.sql.as_str()), SqlxValues(values))
306}
307
308pub(crate) async fn set_transaction_config(
309 conn: &mut PoolConnection<MySql>,
310 isolation_level: Option<IsolationLevel>,
311 access_mode: Option<AccessMode>,
312) -> Result<(), DbErr> {
313 let mut settings = Vec::new();
314
315 if let Some(isolation_level) = isolation_level {
316 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
317 }
318
319 if let Some(access_mode) = access_mode {
320 settings.push(access_mode.to_string());
321 }
322
323 if !settings.is_empty() {
324 let stmt = Statement {
325 sql: format!("SET TRANSACTION {}", settings.join(", ")),
326 values: None,
327 db_backend: DbBackend::MySql,
328 };
329 let query = sqlx_query(&stmt);
330 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
331 }
332 Ok(())
333}
334
#[cfg(feature = "stream")]
impl
    From<(
        PoolConnection<sqlx::MySql>,
        Statement,
        Option<crate::metric::Callback>,
    )> for crate::QueryStream
{
    /// Builds a [`crate::QueryStream`] from a pooled MySQL connection, a statement, and an
    /// optional metric callback.
    ///
    /// All three are passed to `QueryStream::build`.
    fn from(
        parts: (
            PoolConnection<sqlx::MySql>,
            Statement,
            Option<crate::metric::Callback>,
        ),
    ) -> Self {
        let (conn, stmt, metric_callback) = parts;
        let inner = crate::InnerConnection::MySql(conn);
        crate::QueryStream::build(stmt, inner, metric_callback)
    }
}
353
354impl crate::DatabaseTransaction {
355 pub(crate) async fn new_mysql(
356 inner: PoolConnection<sqlx::MySql>,
357 metric_callback: Option<crate::metric::Callback>,
358 record_stmt_in_spans: bool,
359 isolation_level: Option<IsolationLevel>,
360 access_mode: Option<AccessMode>,
361 ) -> Result<crate::DatabaseTransaction, DbErr> {
362 Self::begin(
363 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
364 crate::DbBackend::MySql,
365 metric_callback,
366 record_stmt_in_spans,
367 isolation_level,
368 access_mode,
369 None,
370 )
371 .await
372 }
373}
374
/// Converts an sqlx MySQL row into a [`crate::ProxyRow`].
///
/// Each column becomes a `(column name, sea_query::Value)` pair. The `Value` variant is
/// chosen from the MySQL type name reported by `type_info().name()`.
///
/// Temporal, `DECIMAL` and `JSON` columns are only handled when the matching `with-*`
/// feature is enabled. When both `with-chrono` and `with-time` are enabled, chrono is used;
/// when both `with-bigdecimal` and `with-rust_decimal` are enabled, bigdecimal is used.
///
/// # Panics
///
/// - Panics if a column value cannot be decoded into the Rust type chosen for it.
/// - Panics if the column's type name is not matched by any arm. This includes
///   feature-gated types whose feature is disabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        // In the scalar arms below, `try_get`'s target type is inferred from
                        // the payload type of the `Value` variant being constructed.
                        "TINYINT(1)" | "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }
                        "TINYINT UNSIGNED" => Value::TinyUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned tiny integer"),
                        ),
                        "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned small integer"),
                        ),
                        "INT UNSIGNED" => Value::Unsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned integer"),
                        ),
                        // MEDIUMINT has no dedicated `Value` variant, so it is widened to the
                        // 64-bit variant.
                        "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned big integer"),
                        ),
                        "TINYINT" => Value::TinyInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get tiny integer"),
                        ),
                        "SMALLINT" => Value::SmallInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get small integer"),
                        ),
                        "INT" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }
                        "MEDIUMINT" | "BIGINT" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),
                        "FLOAT" => {
                            Value::Float(row.try_get(c.ordinal()).expect("Failed to get float"))
                        }
                        "DOUBLE" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                        | "LONGBLOB" => Value::Bytes(
                            row.try_get::<Option<Vec<u8>>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                            Value::String(
                                row.try_get::<Option<String>, _>(c.ordinal())
                                    .expect("Failed to get string")
                                    .map(Box::new),
                            )
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIMESTAMP" => Value::ChronoDateTimeUtc(
                            row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIMESTAMP" => Value::TimeDateTime(
                            row.try_get::<Option<time::PrimitiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(
                            row.try_get::<Option<chrono::NaiveDate>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(
                            row.try_get::<Option<time::Date>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(
                            row.try_get::<Option<chrono::NaiveTime>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(
                            row.try_get::<Option<time::Time>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTime(
                            row.try_get::<Option<chrono::NaiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get datetime")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTime(
                            row.try_get::<Option<time::PrimitiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get datetime")
                                .map(Box::new),
                        ),

                        // NOTE(review): YEAR is decoded into a date type here. Confirm that
                        // sqlx supports decoding MySQL YEAR into this type; otherwise this arm
                        // panics.
                        #[cfg(feature = "with-chrono")]
                        "YEAR" => Value::ChronoDate(
                            row.try_get::<Option<chrono::NaiveDate>, _>(c.ordinal())
                                .expect("Failed to get year")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "YEAR" => Value::TimeDate(
                            row.try_get::<Option<time::Date>, _>(c.ordinal())
                                .expect("Failed to get year")
                                .map(Box::new),
                        ),

                        // These types are carried in their string representation.
                        "ENUM" | "SET" | "GEOMETRY" => Value::String(
                            row.try_get::<Option<String>, _>(c.ordinal())
                                .expect("Failed to get serialized string")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-bigdecimal")]
                        "DECIMAL" => Value::BigDecimal(
                            row.try_get::<Option<bigdecimal::BigDecimal>, _>(c.ordinal())
                                .expect("Failed to get decimal")
                                .map(Box::new),
                        ),
                        #[cfg(all(
                            feature = "with-rust_decimal",
                            not(feature = "with-bigdecimal")
                        ))]
                        "DECIMAL" => Value::Decimal(
                            row.try_get::<Option<rust_decimal::Decimal>, _>(c.ordinal())
                                .expect("Failed to get decimal")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-json")]
                        "JSON" => Value::Json(
                            row.try_get::<Option<serde_json::Value>, _>(c.ordinal())
                                .expect("Failed to get json")
                                .map(Box::new),
                        ),

                        // Reached for any type name not matched above, including gated types
                        // whose feature is disabled.
                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}