1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, MySql, MySqlPool,
8 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9 pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18 executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Entry point for building MySQL-backed [`DatabaseConnection`]s with sqlx.
///
/// Carries no state; all functionality lives in associated functions such as
/// `accepts`, `connect` and `from_sqlx_mysql_pool`.
#[derive(Debug)]
pub struct SqlxMySqlConnector;
26
27#[derive(Clone)]
29pub struct SqlxMySqlPoolConnection {
30 pub(crate) pool: MySqlPool,
31 metric_callback: Option<crate::metric::Callback>,
32 pub(crate) record_stmt_in_spans: bool,
33}
34
35impl std::fmt::Debug for SqlxMySqlPoolConnection {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
38 }
39}
40
41impl From<MySqlPool> for SqlxMySqlPoolConnection {
42 fn from(pool: MySqlPool) -> Self {
43 SqlxMySqlPoolConnection {
44 pool,
45 metric_callback: None,
46 record_stmt_in_spans: true,
47 }
48 }
49}
50
51impl From<MySqlPool> for DatabaseConnection {
52 fn from(pool: MySqlPool) -> Self {
53 DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
54 }
55}
56
57impl SqlxMySqlConnector {
58 pub fn accepts(string: &str) -> bool {
60 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
61 }
62
63 #[instrument(level = "trace")]
65 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
66 let record_stmt_in_spans = options.get_record_stmt_in_spans();
67 let mut sqlx_opts = options
68 .url
69 .parse::<MySqlConnectOptions>()
70 .map_err(sqlx_error_to_conn_err)?;
71 use sqlx::ConnectOptions;
72 if !options.sqlx_logging {
73 sqlx_opts = sqlx_opts.disable_statement_logging();
74 } else {
75 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
76 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
77 sqlx_opts = sqlx_opts.log_slow_statements(
78 options.sqlx_slow_statements_logging_level,
79 options.sqlx_slow_statements_logging_threshold,
80 );
81 }
82 }
83
84 if let Some(f) = &options.mysql_opts_fn {
85 sqlx_opts = f(sqlx_opts);
86 }
87 let after_connect = options.after_connect.clone();
88 let connect_lazy = options.connect_lazy;
89 let mysql_pool_opts_fn = options.mysql_pool_opts_fn.clone();
90 let mut pool_options = options.sqlx_pool_options();
91 if let Some(f) = &mysql_pool_opts_fn {
92 pool_options = f(pool_options);
93 }
94 let pool = if connect_lazy {
95 pool_options.connect_lazy_with(sqlx_opts)
96 } else {
97 pool_options
98 .connect_with(sqlx_opts)
99 .await
100 .map_err(sqlx_error_to_conn_err)?
101 };
102
103 let conn: DatabaseConnection =
104 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
105 pool,
106 metric_callback: None,
107 record_stmt_in_spans,
108 })
109 .into();
110
111 if let Some(cb) = after_connect {
112 cb(conn.clone()).await?;
113 }
114
115 Ok(conn)
116 }
117}
118
119impl SqlxMySqlConnector {
120 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
122 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
123 pool,
124 metric_callback: None,
125 record_stmt_in_spans: true,
126 })
127 .into()
128 }
129}
130
131impl SqlxMySqlPoolConnection {
132 #[instrument(level = "trace")]
134 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
135 debug_print!("{}", stmt);
136
137 let query = sqlx_query(&stmt);
138 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
139 crate::metric::metric!(self.metric_callback, &stmt, {
140 match query.execute(&mut *conn).await {
141 Ok(res) => Ok(res.into()),
142 Err(err) => Err(sqlx_error_to_exec_err(err)),
143 }
144 })
145 }
146
147 #[instrument(level = "trace")]
149 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
150 debug_print!("{}", sql);
151
152 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
153 match conn.execute(sqlx::AssertSqlSafe(sql.to_owned())).await {
154 Ok(res) => Ok(res.into()),
155 Err(err) => Err(sqlx_error_to_exec_err(err)),
156 }
157 }
158
159 #[instrument(level = "trace")]
161 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
162 debug_print!("{}", stmt);
163
164 let query = sqlx_query(&stmt);
165 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
166 crate::metric::metric!(self.metric_callback, &stmt, {
167 match query.fetch_one(&mut *conn).await {
168 Ok(row) => Ok(Some(row.into())),
169 Err(err) => match err {
170 sqlx::Error::RowNotFound => Ok(None),
171 _ => Err(sqlx_error_to_query_err(err)),
172 },
173 }
174 })
175 }
176
177 #[instrument(level = "trace")]
179 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
180 debug_print!("{}", stmt);
181
182 let query = sqlx_query(&stmt);
183 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
184 crate::metric::metric!(self.metric_callback, &stmt, {
185 match query.fetch_all(&mut *conn).await {
186 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
187 Err(err) => Err(sqlx_error_to_query_err(err)),
188 }
189 })
190 }
191
192 #[instrument(level = "trace")]
194 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
195 debug_print!("{}", stmt);
196
197 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
198 Ok(QueryStream::from((
199 conn,
200 stmt,
201 self.metric_callback.clone(),
202 )))
203 }
204
205 #[instrument(level = "trace")]
207 pub async fn begin(
208 &self,
209 isolation_level: Option<IsolationLevel>,
210 access_mode: Option<AccessMode>,
211 ) -> Result<DatabaseTransaction, DbErr> {
212 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
213 DatabaseTransaction::new_mysql(
214 conn,
215 self.metric_callback.clone(),
216 self.record_stmt_in_spans,
217 isolation_level,
218 access_mode,
219 )
220 .await
221 }
222
223 #[instrument(level = "trace", skip(callback))]
225 pub async fn transaction<F, T, E>(
226 &self,
227 callback: F,
228 isolation_level: Option<IsolationLevel>,
229 access_mode: Option<AccessMode>,
230 ) -> Result<T, TransactionError<E>>
231 where
232 F: for<'b> FnOnce(
233 &'b DatabaseTransaction,
234 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
235 + Send,
236 T: Send,
237 E: std::fmt::Display + std::fmt::Debug + Send,
238 {
239 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
240 let transaction = DatabaseTransaction::new_mysql(
241 conn,
242 self.metric_callback.clone(),
243 self.record_stmt_in_spans,
244 isolation_level,
245 access_mode,
246 )
247 .await
248 .map_err(|e| TransactionError::Connection(e))?;
249 transaction.run(callback).await
250 }
251
252 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
253 where
254 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
255 {
256 self.metric_callback = Some(Arc::new(callback));
257 }
258
259 pub async fn ping(&self) -> Result<(), DbErr> {
261 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
262 match conn.ping().await {
263 Ok(_) => Ok(()),
264 Err(err) => Err(sqlx_error_to_conn_err(err)),
265 }
266 }
267
268 pub async fn close(self) -> Result<(), DbErr> {
271 self.close_by_ref().await
272 }
273
274 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
276 self.pool.close().await;
277 Ok(())
278 }
279}
280
281impl From<MySqlRow> for QueryResult {
282 fn from(row: MySqlRow) -> QueryResult {
283 QueryResult {
284 row: QueryResultRow::SqlxMySql(row),
285 }
286 }
287}
288
289impl From<MySqlQueryResult> for ExecResult {
290 fn from(result: MySqlQueryResult) -> ExecResult {
291 ExecResult {
292 result: ExecResultHolder::SqlxMySql(result),
293 }
294 }
295}
296
297pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
298 let values = stmt
299 .values
300 .as_ref()
301 .map_or(Values(Vec::new()), |values| values.clone());
302 sqlx::query_with(sqlx::AssertSqlSafe(stmt.sql.as_str()), SqlxValues(values))
303}
304
305pub(crate) async fn set_transaction_config(
306 conn: &mut PoolConnection<MySql>,
307 isolation_level: Option<IsolationLevel>,
308 access_mode: Option<AccessMode>,
309) -> Result<(), DbErr> {
310 let mut settings = Vec::new();
311
312 if let Some(isolation_level) = isolation_level {
313 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
314 }
315
316 if let Some(access_mode) = access_mode {
317 settings.push(access_mode.to_string());
318 }
319
320 if !settings.is_empty() {
321 let stmt = Statement {
322 sql: format!("SET TRANSACTION {}", settings.join(", ")),
323 values: None,
324 db_backend: DbBackend::MySql,
325 };
326 let query = sqlx_query(&stmt);
327 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
328 }
329 Ok(())
330}
331
332impl
333 From<(
334 PoolConnection<sqlx::MySql>,
335 Statement,
336 Option<crate::metric::Callback>,
337 )> for crate::QueryStream
338{
339 fn from(
340 (conn, stmt, metric_callback): (
341 PoolConnection<sqlx::MySql>,
342 Statement,
343 Option<crate::metric::Callback>,
344 ),
345 ) -> Self {
346 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
347 }
348}
349
350impl crate::DatabaseTransaction {
351 pub(crate) async fn new_mysql(
352 inner: PoolConnection<sqlx::MySql>,
353 metric_callback: Option<crate::metric::Callback>,
354 record_stmt_in_spans: bool,
355 isolation_level: Option<IsolationLevel>,
356 access_mode: Option<AccessMode>,
357 ) -> Result<crate::DatabaseTransaction, DbErr> {
358 Self::begin(
359 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
360 crate::DbBackend::MySql,
361 metric_callback,
362 record_stmt_in_spans,
363 isolation_level,
364 access_mode,
365 None,
366 )
367 .await
368 }
369}
370
/// Converts a sqlx MySQL row into a [`crate::ProxyRow`].
///
/// Each column becomes a `(column name, Value)` pair. The `sea_query::Value` variant is
/// picked by matching the column's sqlx type name (`c.type_info().name()`). Arms with an
/// explicit `Option<_>` turbofish map SQL `NULL` to `None`. The remaining arms infer the
/// decode type from the chosen `Value` variant.
///
/// # Panics
///
/// - When a column cannot be decoded into the Rust type its arm selects, because every
///   decode ends in `.expect(..)`.
/// - When the type name matches no arm, through the final `unreachable!`. That includes
///   types whose arms are compiled out by feature flags, for example `DECIMAL` without
///   `with-bigdecimal`/`with-rust_decimal`, `JSON` without `with-json`, and the
///   date/time types without `with-chrono`/`with-time`.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        // Boolean columns (sqlx may report TINYINT(1) as BOOLEAN).
                        "TINYINT(1)" | "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }
                        // Unsigned integers, widened where sea_query has no exact width.
                        "TINYINT UNSIGNED" => Value::TinyUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned tiny integer"),
                        ),
                        "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned small integer"),
                        ),
                        "INT UNSIGNED" => Value::Unsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned integer"),
                        ),
                        // MEDIUMINT has no dedicated Value variant, so it shares BigUnsigned.
                        "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned big integer"),
                        ),
                        // Signed integers, with the same widening for MEDIUMINT.
                        "TINYINT" => Value::TinyInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get tiny integer"),
                        ),
                        "SMALLINT" => Value::SmallInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get small integer"),
                        ),
                        "INT" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }
                        "MEDIUMINT" | "BIGINT" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),
                        "FLOAT" => {
                            Value::Float(row.try_get(c.ordinal()).expect("Failed to get float"))
                        }
                        "DOUBLE" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        // Binary payloads are decoded as raw bytes.
                        // NOTE(review): confirm sqlx accepts `Vec<u8>` for BIT columns;
                        // otherwise this arm panics on BIT.
                        "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                        | "LONGBLOB" => Value::Bytes(
                            row.try_get::<Option<Vec<u8>>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                            Value::String(
                                row.try_get::<Option<String>, _>(c.ordinal())
                                    .expect("Failed to get string")
                                    .map(Box::new),
                            )
                        }

                        // Date/time arms: chrono takes precedence when both chrono and
                        // time features are enabled.
                        #[cfg(feature = "with-chrono")]
                        "TIMESTAMP" => Value::ChronoDateTimeUtc(
                            row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIMESTAMP" => Value::TimeDateTime(
                            row.try_get::<Option<time::PrimitiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(
                            row.try_get::<Option<chrono::NaiveDate>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(
                            row.try_get::<Option<time::Date>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(
                            row.try_get::<Option<chrono::NaiveTime>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(
                            row.try_get::<Option<time::Time>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTime(
                            row.try_get::<Option<chrono::NaiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get datetime")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTime(
                            row.try_get::<Option<time::PrimitiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get datetime")
                                .map(Box::new),
                        ),

                        // NOTE(review): YEAR is decoded into a full date type here. Confirm
                        // sqlx's MySQL driver can decode YEAR into `NaiveDate`/`time::Date`;
                        // if it cannot, this arm panics with "Failed to get year".
                        #[cfg(feature = "with-chrono")]
                        "YEAR" => Value::ChronoDate(
                            row.try_get::<Option<chrono::NaiveDate>, _>(c.ordinal())
                                .expect("Failed to get year")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "YEAR" => Value::TimeDate(
                            row.try_get::<Option<time::Date>, _>(c.ordinal())
                                .expect("Failed to get year")
                                .map(Box::new),
                        ),

                        // ENUM, SET and GEOMETRY are surfaced as their string form.
                        "ENUM" | "SET" | "GEOMETRY" => Value::String(
                            row.try_get::<Option<String>, _>(c.ordinal())
                                .expect("Failed to get serialized string")
                                .map(Box::new),
                        ),

                        // DECIMAL: bigdecimal takes precedence over rust_decimal.
                        #[cfg(feature = "with-bigdecimal")]
                        "DECIMAL" => Value::BigDecimal(
                            row.try_get::<Option<bigdecimal::BigDecimal>, _>(c.ordinal())
                                .expect("Failed to get decimal")
                                .map(Box::new),
                        ),
                        #[cfg(all(
                            feature = "with-rust_decimal",
                            not(feature = "with-bigdecimal")
                        ))]
                        "DECIMAL" => Value::Decimal(
                            row.try_get::<Option<rust_decimal::Decimal>, _>(c.ordinal())
                                .expect("Failed to get decimal")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-json")]
                        "JSON" => Value::Json(
                            row.try_get::<Option<serde_json::Value>, _>(c.ordinal())
                                .expect("Failed to get json")
                                .map(Box::new),
                        ),

                        // NOTE(review): this arm is reachable for any unmapped type name,
                        // including types whose arms are feature-gated off. It is a
                        // runtime panic path, not a true invariant.
                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}