1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, MySql, MySqlPool,
8 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9 pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18 executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Entry point for creating MySQL-backed [`DatabaseConnection`]s through SQLx.
///
/// The type carries no state; all functionality lives in its associated functions.
#[derive(Debug)]
pub struct SqlxMySqlConnector;
26
/// A SQLx MySQL connection pool together with an optional metric callback.
#[derive(Clone)]
pub struct SqlxMySqlPoolConnection {
    /// The underlying SQLx pool from which connections are acquired.
    pub(crate) pool: MySqlPool,
    /// Callback passed to `crate::metric::metric!` around statement execution;
    /// `None` until `set_metric_callback` installs one.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxMySqlPoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<MySqlPool> for SqlxMySqlPoolConnection {
41 fn from(pool: MySqlPool) -> Self {
42 SqlxMySqlPoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<MySqlPool> for DatabaseConnection {
50 fn from(pool: MySqlPool) -> Self {
51 DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
52 }
53}
54
55impl SqlxMySqlConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut sqlx_opts = options
65 .url
66 .parse::<MySqlConnectOptions>()
67 .map_err(sqlx_error_to_conn_err)?;
68 use sqlx::ConnectOptions;
69 if !options.sqlx_logging {
70 sqlx_opts = sqlx_opts.disable_statement_logging();
71 } else {
72 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
73 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
74 sqlx_opts = sqlx_opts.log_slow_statements(
75 options.sqlx_slow_statements_logging_level,
76 options.sqlx_slow_statements_logging_threshold,
77 );
78 }
79 }
80
81 if let Some(f) = &options.mysql_opts_fn {
82 sqlx_opts = f(sqlx_opts);
83 }
84
85 let after_connect = options.after_connect.clone();
86
87 let pool = if options.connect_lazy {
88 options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
89 } else {
90 options
91 .sqlx_pool_options()
92 .connect_with(sqlx_opts)
93 .await
94 .map_err(sqlx_error_to_conn_err)?
95 };
96
97 let conn: DatabaseConnection =
98 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
99 pool,
100 metric_callback: None,
101 })
102 .into();
103
104 if let Some(cb) = after_connect {
105 cb(conn.clone()).await?;
106 }
107
108 Ok(conn)
109 }
110}
111
112impl SqlxMySqlConnector {
113 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
115 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
116 pool,
117 metric_callback: None,
118 })
119 .into()
120 }
121}
122
123impl SqlxMySqlPoolConnection {
124 #[instrument(level = "trace")]
126 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
127 debug_print!("{}", stmt);
128
129 let query = sqlx_query(&stmt);
130 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
131 crate::metric::metric!(self.metric_callback, &stmt, {
132 match query.execute(&mut *conn).await {
133 Ok(res) => Ok(res.into()),
134 Err(err) => Err(sqlx_error_to_exec_err(err)),
135 }
136 })
137 }
138
139 #[instrument(level = "trace")]
141 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
142 debug_print!("{}", sql);
143
144 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
145 match conn.execute(sql).await {
146 Ok(res) => Ok(res.into()),
147 Err(err) => Err(sqlx_error_to_exec_err(err)),
148 }
149 }
150
151 #[instrument(level = "trace")]
153 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
154 debug_print!("{}", stmt);
155
156 let query = sqlx_query(&stmt);
157 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
158 crate::metric::metric!(self.metric_callback, &stmt, {
159 match query.fetch_one(&mut *conn).await {
160 Ok(row) => Ok(Some(row.into())),
161 Err(err) => match err {
162 sqlx::Error::RowNotFound => Ok(None),
163 _ => Err(sqlx_error_to_query_err(err)),
164 },
165 }
166 })
167 }
168
169 #[instrument(level = "trace")]
171 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
172 debug_print!("{}", stmt);
173
174 let query = sqlx_query(&stmt);
175 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
176 crate::metric::metric!(self.metric_callback, &stmt, {
177 match query.fetch_all(&mut *conn).await {
178 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
179 Err(err) => Err(sqlx_error_to_query_err(err)),
180 }
181 })
182 }
183
184 #[instrument(level = "trace")]
186 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
187 debug_print!("{}", stmt);
188
189 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
190 Ok(QueryStream::from((
191 conn,
192 stmt,
193 self.metric_callback.clone(),
194 )))
195 }
196
197 #[instrument(level = "trace")]
199 pub async fn begin(
200 &self,
201 isolation_level: Option<IsolationLevel>,
202 access_mode: Option<AccessMode>,
203 ) -> Result<DatabaseTransaction, DbErr> {
204 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
205 DatabaseTransaction::new_mysql(
206 conn,
207 self.metric_callback.clone(),
208 isolation_level,
209 access_mode,
210 )
211 .await
212 }
213
214 #[instrument(level = "trace", skip(callback))]
216 pub async fn transaction<F, T, E>(
217 &self,
218 callback: F,
219 isolation_level: Option<IsolationLevel>,
220 access_mode: Option<AccessMode>,
221 ) -> Result<T, TransactionError<E>>
222 where
223 F: for<'b> FnOnce(
224 &'b DatabaseTransaction,
225 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
226 + Send,
227 T: Send,
228 E: std::fmt::Display + std::fmt::Debug + Send,
229 {
230 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
231 let transaction = DatabaseTransaction::new_mysql(
232 conn,
233 self.metric_callback.clone(),
234 isolation_level,
235 access_mode,
236 )
237 .await
238 .map_err(|e| TransactionError::Connection(e))?;
239 transaction.run(callback).await
240 }
241
242 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
243 where
244 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
245 {
246 self.metric_callback = Some(Arc::new(callback));
247 }
248
249 pub async fn ping(&self) -> Result<(), DbErr> {
251 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
252 match conn.ping().await {
253 Ok(_) => Ok(()),
254 Err(err) => Err(sqlx_error_to_conn_err(err)),
255 }
256 }
257
258 pub async fn close(self) -> Result<(), DbErr> {
261 self.close_by_ref().await
262 }
263
264 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
266 self.pool.close().await;
267 Ok(())
268 }
269}
270
271impl From<MySqlRow> for QueryResult {
272 fn from(row: MySqlRow) -> QueryResult {
273 QueryResult {
274 row: QueryResultRow::SqlxMySql(row),
275 }
276 }
277}
278
279impl From<MySqlQueryResult> for ExecResult {
280 fn from(result: MySqlQueryResult) -> ExecResult {
281 ExecResult {
282 result: ExecResultHolder::SqlxMySql(result),
283 }
284 }
285}
286
287pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
288 let values = stmt
289 .values
290 .as_ref()
291 .map_or(Values(Vec::new()), |values| values.clone());
292 sqlx::query_with(&stmt.sql, SqlxValues(values))
293}
294
295pub(crate) async fn set_transaction_config(
296 conn: &mut PoolConnection<MySql>,
297 isolation_level: Option<IsolationLevel>,
298 access_mode: Option<AccessMode>,
299) -> Result<(), DbErr> {
300 let mut settings = Vec::new();
301
302 if let Some(isolation_level) = isolation_level {
303 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
304 }
305
306 if let Some(access_mode) = access_mode {
307 settings.push(access_mode.to_string());
308 }
309
310 if !settings.is_empty() {
311 let stmt = Statement {
312 sql: format!("SET TRANSACTION {}", settings.join(", ")),
313 values: None,
314 db_backend: DbBackend::MySql,
315 };
316 let query = sqlx_query(&stmt);
317 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
318 }
319 Ok(())
320}
321
322impl
323 From<(
324 PoolConnection<sqlx::MySql>,
325 Statement,
326 Option<crate::metric::Callback>,
327 )> for crate::QueryStream
328{
329 fn from(
330 (conn, stmt, metric_callback): (
331 PoolConnection<sqlx::MySql>,
332 Statement,
333 Option<crate::metric::Callback>,
334 ),
335 ) -> Self {
336 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
337 }
338}
339
340impl crate::DatabaseTransaction {
341 pub(crate) async fn new_mysql(
342 inner: PoolConnection<sqlx::MySql>,
343 metric_callback: Option<crate::metric::Callback>,
344 isolation_level: Option<IsolationLevel>,
345 access_mode: Option<AccessMode>,
346 ) -> Result<crate::DatabaseTransaction, DbErr> {
347 Self::begin(
348 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
349 crate::DbBackend::MySql,
350 metric_callback,
351 isolation_level,
352 access_mode,
353 )
354 .await
355 }
356}
357
/// Converts a SQLx MySQL row into a [`crate::ProxyRow`].
///
/// Each column is decoded into a `sea_query::Value` variant chosen from the
/// column's SQLx type name, and paired with the column name.
///
/// # Panics
///
/// Panics when a column cannot be decoded as the type chosen for it, and when
/// the column's type name matches none of the arms below. Several arms only
/// exist when the corresponding cargo feature is enabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|col| {
                let name = col.name().to_string();
                let idx = col.ordinal();
                let type_name = col.type_info().name();

                let value = match type_name {
                    "TINYINT(1)" | "BOOLEAN" => {
                        Value::Bool(row.try_get(idx).expect("Failed to get boolean"))
                    }
                    "TINYINT UNSIGNED" => Value::TinyUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned tiny integer"),
                    ),
                    "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned small integer"),
                    ),
                    "INT UNSIGNED" => {
                        Value::Unsigned(row.try_get(idx).expect("Failed to get unsigned integer"))
                    }
                    "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned big integer"),
                    ),
                    "TINYINT" => {
                        Value::TinyInt(row.try_get(idx).expect("Failed to get tiny integer"))
                    }
                    "SMALLINT" => {
                        Value::SmallInt(row.try_get(idx).expect("Failed to get small integer"))
                    }
                    "INT" => Value::Int(row.try_get(idx).expect("Failed to get integer")),
                    "MEDIUMINT" | "BIGINT" => {
                        Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                    }
                    "FLOAT" => Value::Float(row.try_get(idx).expect("Failed to get float")),
                    "DOUBLE" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                    "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                    | "LONGBLOB" => Value::Bytes(
                        row.try_get::<Option<Vec<u8>>, _>(idx)
                            .expect("Failed to get bytes")
                            .map(Box::new),
                    ),

                    "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                        Value::String(
                            row.try_get::<Option<String>, _>(idx)
                                .expect("Failed to get string")
                                .map(Box::new),
                        )
                    }

                    // For the date/time arms below, chrono wins when both
                    // `with-chrono` and `with-time` are enabled.
                    #[cfg(feature = "with-chrono")]
                    "TIMESTAMP" => Value::ChronoDateTimeUtc(
                        row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIMESTAMP" => Value::TimeDateTime(
                        row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => Value::ChronoDate(
                        row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => Value::TimeDate(
                        row.try_get::<Option<time::Date>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "TIME" => Value::ChronoTime(
                        row.try_get::<Option<chrono::NaiveTime>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => Value::TimeTime(
                        row.try_get::<Option<time::Time>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTime(
                        row.try_get::<Option<chrono::NaiveDateTime>, _>(idx)
                            .expect("Failed to get datetime")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTime(
                        row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                            .expect("Failed to get datetime")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "YEAR" => Value::ChronoDate(
                        row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                            .expect("Failed to get year")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "YEAR" => Value::TimeDate(
                        row.try_get::<Option<time::Date>, _>(idx)
                            .expect("Failed to get year")
                            .map(Box::new),
                    ),

                    "ENUM" | "SET" | "GEOMETRY" => Value::String(
                        row.try_get::<Option<String>, _>(idx)
                            .expect("Failed to get serialized string")
                            .map(Box::new),
                    ),

                    // bigdecimal wins when both decimal features are enabled.
                    #[cfg(feature = "with-bigdecimal")]
                    "DECIMAL" => Value::BigDecimal(
                        row.try_get::<Option<bigdecimal::BigDecimal>, _>(idx)
                            .expect("Failed to get decimal")
                            .map(Box::new),
                    ),
                    #[cfg(all(
                        feature = "with-rust_decimal",
                        not(feature = "with-bigdecimal")
                    ))]
                    "DECIMAL" => Value::Decimal(
                        row.try_get::<Option<rust_decimal::Decimal>, _>(idx)
                            .expect("Failed to get decimal")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-json")]
                    "JSON" => Value::Json(
                        row.try_get::<Option<serde_json::Value>, _>(idx)
                            .expect("Failed to get json")
                            .map(Box::new),
                    ),

                    unknown => unreachable!("Unknown column type: {}", unknown),
                };

                (name, value)
            })
            .collect(),
    }
}