1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, MySql, MySqlPool,
8 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9 pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18 executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Connector for `mysql://` URLs backed by the sqlx MySQL driver.
///
/// Its associated functions check whether a URL is accepted, open a pooled
/// connection from [`ConnectOptions`], or wrap an existing [`MySqlPool`].
#[derive(Debug)]
pub struct SqlxMySqlConnector;
/// A sqlx MySQL connection pool together with an optional metric callback.
#[derive(Clone)]
pub struct SqlxMySqlPoolConnection {
    /// The underlying sqlx MySQL pool; each operation acquires a connection from it.
    pub(crate) pool: MySqlPool,
    /// Optional callback handed to `crate::metric::metric!` when statements run
    /// and passed on to streams and transactions created from this pool.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxMySqlPoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<MySqlPool> for SqlxMySqlPoolConnection {
41 fn from(pool: MySqlPool) -> Self {
42 SqlxMySqlPoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<MySqlPool> for DatabaseConnection {
50 fn from(pool: MySqlPool) -> Self {
51 DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
52 }
53}
54
55impl SqlxMySqlConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut sqlx_opts = options
65 .url
66 .parse::<MySqlConnectOptions>()
67 .map_err(sqlx_error_to_conn_err)?;
68 use sqlx::ConnectOptions;
69 if !options.sqlx_logging {
70 sqlx_opts = sqlx_opts.disable_statement_logging();
71 } else {
72 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
73 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
74 sqlx_opts = sqlx_opts.log_slow_statements(
75 options.sqlx_slow_statements_logging_level,
76 options.sqlx_slow_statements_logging_threshold,
77 );
78 }
79 }
80
81 if let Some(f) = &options.mysql_opts_fn {
82 sqlx_opts = f(sqlx_opts);
83 }
84 let after_connect = options.after_connect.clone();
85 let connect_lazy = options.connect_lazy;
86 let mysql_pool_opts_fn = options.mysql_pool_opts_fn.clone();
87 let mut pool_options = options.sqlx_pool_options();
88 if let Some(f) = &mysql_pool_opts_fn {
89 pool_options = f(pool_options);
90 }
91 let pool = if connect_lazy {
92 pool_options.connect_lazy_with(sqlx_opts)
93 } else {
94 pool_options
95 .connect_with(sqlx_opts)
96 .await
97 .map_err(sqlx_error_to_conn_err)?
98 };
99
100 let conn: DatabaseConnection =
101 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
102 pool,
103 metric_callback: None,
104 })
105 .into();
106
107 if let Some(cb) = after_connect {
108 cb(conn.clone()).await?;
109 }
110
111 Ok(conn)
112 }
113}
114
115impl SqlxMySqlConnector {
116 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
118 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
119 pool,
120 metric_callback: None,
121 })
122 .into()
123 }
124}
125
126impl SqlxMySqlPoolConnection {
127 #[instrument(level = "trace")]
129 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
130 debug_print!("{}", stmt);
131
132 let query = sqlx_query(&stmt);
133 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
134 crate::metric::metric!(self.metric_callback, &stmt, {
135 match query.execute(&mut *conn).await {
136 Ok(res) => Ok(res.into()),
137 Err(err) => Err(sqlx_error_to_exec_err(err)),
138 }
139 })
140 }
141
142 #[instrument(level = "trace")]
144 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
145 debug_print!("{}", sql);
146
147 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
148 match conn.execute(sql).await {
149 Ok(res) => Ok(res.into()),
150 Err(err) => Err(sqlx_error_to_exec_err(err)),
151 }
152 }
153
154 #[instrument(level = "trace")]
156 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
157 debug_print!("{}", stmt);
158
159 let query = sqlx_query(&stmt);
160 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
161 crate::metric::metric!(self.metric_callback, &stmt, {
162 match query.fetch_one(&mut *conn).await {
163 Ok(row) => Ok(Some(row.into())),
164 Err(err) => match err {
165 sqlx::Error::RowNotFound => Ok(None),
166 _ => Err(sqlx_error_to_query_err(err)),
167 },
168 }
169 })
170 }
171
172 #[instrument(level = "trace")]
174 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
175 debug_print!("{}", stmt);
176
177 let query = sqlx_query(&stmt);
178 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
179 crate::metric::metric!(self.metric_callback, &stmt, {
180 match query.fetch_all(&mut *conn).await {
181 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
182 Err(err) => Err(sqlx_error_to_query_err(err)),
183 }
184 })
185 }
186
187 #[instrument(level = "trace")]
189 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
190 debug_print!("{}", stmt);
191
192 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
193 Ok(QueryStream::from((
194 conn,
195 stmt,
196 self.metric_callback.clone(),
197 )))
198 }
199
200 #[instrument(level = "trace")]
202 pub async fn begin(
203 &self,
204 isolation_level: Option<IsolationLevel>,
205 access_mode: Option<AccessMode>,
206 ) -> Result<DatabaseTransaction, DbErr> {
207 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
208 DatabaseTransaction::new_mysql(
209 conn,
210 self.metric_callback.clone(),
211 isolation_level,
212 access_mode,
213 )
214 .await
215 }
216
217 #[instrument(level = "trace", skip(callback))]
219 pub async fn transaction<F, T, E>(
220 &self,
221 callback: F,
222 isolation_level: Option<IsolationLevel>,
223 access_mode: Option<AccessMode>,
224 ) -> Result<T, TransactionError<E>>
225 where
226 F: for<'b> FnOnce(
227 &'b DatabaseTransaction,
228 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
229 + Send,
230 T: Send,
231 E: std::fmt::Display + std::fmt::Debug + Send,
232 {
233 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
234 let transaction = DatabaseTransaction::new_mysql(
235 conn,
236 self.metric_callback.clone(),
237 isolation_level,
238 access_mode,
239 )
240 .await
241 .map_err(|e| TransactionError::Connection(e))?;
242 transaction.run(callback).await
243 }
244
245 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
246 where
247 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
248 {
249 self.metric_callback = Some(Arc::new(callback));
250 }
251
252 pub async fn ping(&self) -> Result<(), DbErr> {
254 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
255 match conn.ping().await {
256 Ok(_) => Ok(()),
257 Err(err) => Err(sqlx_error_to_conn_err(err)),
258 }
259 }
260
261 pub async fn close(self) -> Result<(), DbErr> {
264 self.close_by_ref().await
265 }
266
267 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
269 self.pool.close().await;
270 Ok(())
271 }
272}
273
274impl From<MySqlRow> for QueryResult {
275 fn from(row: MySqlRow) -> QueryResult {
276 QueryResult {
277 row: QueryResultRow::SqlxMySql(row),
278 }
279 }
280}
281
282impl From<MySqlQueryResult> for ExecResult {
283 fn from(result: MySqlQueryResult) -> ExecResult {
284 ExecResult {
285 result: ExecResultHolder::SqlxMySql(result),
286 }
287 }
288}
289
290pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
291 let values = stmt
292 .values
293 .as_ref()
294 .map_or(Values(Vec::new()), |values| values.clone());
295 sqlx::query_with(&stmt.sql, SqlxValues(values))
296}
297
298pub(crate) async fn set_transaction_config(
299 conn: &mut PoolConnection<MySql>,
300 isolation_level: Option<IsolationLevel>,
301 access_mode: Option<AccessMode>,
302) -> Result<(), DbErr> {
303 let mut settings = Vec::new();
304
305 if let Some(isolation_level) = isolation_level {
306 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
307 }
308
309 if let Some(access_mode) = access_mode {
310 settings.push(access_mode.to_string());
311 }
312
313 if !settings.is_empty() {
314 let stmt = Statement {
315 sql: format!("SET TRANSACTION {}", settings.join(", ")),
316 values: None,
317 db_backend: DbBackend::MySql,
318 };
319 let query = sqlx_query(&stmt);
320 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
321 }
322 Ok(())
323}
324
325impl
326 From<(
327 PoolConnection<sqlx::MySql>,
328 Statement,
329 Option<crate::metric::Callback>,
330 )> for crate::QueryStream
331{
332 fn from(
333 (conn, stmt, metric_callback): (
334 PoolConnection<sqlx::MySql>,
335 Statement,
336 Option<crate::metric::Callback>,
337 ),
338 ) -> Self {
339 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
340 }
341}
342
343impl crate::DatabaseTransaction {
344 pub(crate) async fn new_mysql(
345 inner: PoolConnection<sqlx::MySql>,
346 metric_callback: Option<crate::metric::Callback>,
347 isolation_level: Option<IsolationLevel>,
348 access_mode: Option<AccessMode>,
349 ) -> Result<crate::DatabaseTransaction, DbErr> {
350 Self::begin(
351 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
352 crate::DbBackend::MySql,
353 metric_callback,
354 isolation_level,
355 access_mode,
356 None,
357 )
358 .await
359 }
360}
361
/// Converts a sqlx MySQL row into a [`crate::ProxyRow`] keyed by column name.
///
/// The [`sea_query::Value`] variant for each column is chosen from the MySQL type
/// name reported by sqlx (`TypeInfo::name`). Date/time, decimal and JSON columns
/// are only mapped when the matching `with-*` feature is enabled. When both are
/// enabled, `with-chrono` takes precedence over `with-time`, and `with-bigdecimal`
/// takes precedence over `with-rust_decimal`.
///
/// # Panics
///
/// Panics if a column value cannot be decoded into the Rust type chosen for its
/// type name. Also panics if the type name is not handled by any arm, including
/// types whose arm is compiled out because its feature is disabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        "TINYINT(1)" | "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }
                        "TINYINT UNSIGNED" => Value::TinyUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned tiny integer"),
                        ),
                        "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned small integer"),
                        ),
                        "INT UNSIGNED" => Value::Unsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned integer"),
                        ),
                        "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned big integer"),
                        ),
                        "TINYINT" => Value::TinyInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get tiny integer"),
                        ),
                        "SMALLINT" => Value::SmallInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get small integer"),
                        ),
                        "INT" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }
                        "MEDIUMINT" | "BIGINT" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),
                        "FLOAT" => {
                            Value::Float(row.try_get(c.ordinal()).expect("Failed to get float"))
                        }
                        "DOUBLE" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                        | "LONGBLOB" => Value::Bytes(
                            row.try_get::<Option<Vec<u8>>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                            Value::String(
                                row.try_get::<Option<String>, _>(c.ordinal())
                                    .expect("Failed to get string")
                                    .map(Box::new),
                            )
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIMESTAMP" => Value::ChronoDateTimeUtc(
                            row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIMESTAMP" => Value::TimeDateTime(
                            row.try_get::<Option<time::PrimitiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(
                            row.try_get::<Option<chrono::NaiveDate>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(
                            row.try_get::<Option<time::Date>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(
                            row.try_get::<Option<chrono::NaiveTime>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(
                            row.try_get::<Option<time::Time>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTime(
                            row.try_get::<Option<chrono::NaiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get datetime")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTime(
                            row.try_get::<Option<time::PrimitiveDateTime>, _>(c.ordinal())
                                .expect("Failed to get datetime")
                                .map(Box::new),
                        ),

                        // NOTE(review): YEAR is decoded as a full date type here;
                        // confirm sqlx accepts YEAR columns for these Rust types.
                        #[cfg(feature = "with-chrono")]
                        "YEAR" => Value::ChronoDate(
                            row.try_get::<Option<chrono::NaiveDate>, _>(c.ordinal())
                                .expect("Failed to get year")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "YEAR" => Value::TimeDate(
                            row.try_get::<Option<time::Date>, _>(c.ordinal())
                                .expect("Failed to get year")
                                .map(Box::new),
                        ),

                        "ENUM" | "SET" | "GEOMETRY" => Value::String(
                            row.try_get::<Option<String>, _>(c.ordinal())
                                .expect("Failed to get serialized string")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-bigdecimal")]
                        "DECIMAL" => Value::BigDecimal(
                            row.try_get::<Option<bigdecimal::BigDecimal>, _>(c.ordinal())
                                .expect("Failed to get decimal")
                                .map(Box::new),
                        ),
                        #[cfg(all(
                            feature = "with-rust_decimal",
                            not(feature = "with-bigdecimal")
                        ))]
                        "DECIMAL" => Value::Decimal(
                            row.try_get::<Option<rust_decimal::Decimal>, _>(c.ordinal())
                                .expect("Failed to get decimal")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-json")]
                        "JSON" => Value::Json(
                            row.try_get::<Option<serde_json::Value>, _>(c.ordinal())
                                .expect("Failed to get json")
                                .map(Box::new),
                        ),

                        // This arm is reachable from data (unlisted type names, or
                        // arms compiled out by disabled features), so it reports an
                        // unsupported type rather than asserting impossibility.
                        _ => unimplemented!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}