1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
8 pool::PoolConnection,
9 Connection, Executor, MySql, MySqlPool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 debug_print, error::*, executor::*, AccessMode, ConnectOptions, DatabaseConnection,
17 DatabaseTransaction, DbBackend, IsolationLevel, QueryStream, Statement, TransactionError,
18};
19
20use super::sqlx_common::*;
21
/// Zero-sized entry point for the sqlx-backed MySQL driver.
///
/// It carries no state; its associated functions check connection strings and
/// build a [`DatabaseConnection`] on top of a [`MySqlPool`].
#[derive(Debug)]
pub struct SqlxMySqlConnector;
25
26#[derive(Clone)]
28pub struct SqlxMySqlPoolConnection {
29 pub(crate) pool: MySqlPool,
30 metric_callback: Option<crate::metric::Callback>,
31}
32
33impl std::fmt::Debug for SqlxMySqlPoolConnection {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
36 }
37}
38
39impl From<MySqlPool> for SqlxMySqlPoolConnection {
40 fn from(pool: MySqlPool) -> Self {
41 SqlxMySqlPoolConnection {
42 pool,
43 metric_callback: None,
44 }
45 }
46}
47
48impl From<MySqlPool> for DatabaseConnection {
49 fn from(pool: MySqlPool) -> Self {
50 DatabaseConnection::SqlxMySqlPoolConnection(pool.into())
51 }
52}
53
54impl SqlxMySqlConnector {
55 pub fn accepts(string: &str) -> bool {
57 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
58 }
59
60 #[instrument(level = "trace")]
62 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
63 let mut sqlx_opts = options
64 .url
65 .parse::<MySqlConnectOptions>()
66 .map_err(sqlx_error_to_conn_err)?;
67 use sqlx::ConnectOptions;
68 if !options.sqlx_logging {
69 sqlx_opts = sqlx_opts.disable_statement_logging();
70 } else {
71 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
72 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
73 sqlx_opts = sqlx_opts.log_slow_statements(
74 options.sqlx_slow_statements_logging_level,
75 options.sqlx_slow_statements_logging_threshold,
76 );
77 }
78 }
79 if let Some(f) = &options.mysql_opts_fn {
80 sqlx_opts = f(sqlx_opts);
81 }
82 let pool = if options.connect_lazy {
83 options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
84 } else {
85 options
86 .sqlx_pool_options()
87 .connect_with(sqlx_opts)
88 .await
89 .map_err(sqlx_error_to_conn_err)?
90 };
91 Ok(DatabaseConnection::SqlxMySqlPoolConnection(
92 SqlxMySqlPoolConnection {
93 pool,
94 metric_callback: None,
95 },
96 ))
97 }
98}
99
100impl SqlxMySqlConnector {
101 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
103 DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
104 pool,
105 metric_callback: None,
106 })
107 }
108}
109
110impl SqlxMySqlPoolConnection {
111 #[instrument(level = "trace")]
113 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
114 debug_print!("{}", stmt);
115
116 let query = sqlx_query(&stmt);
117 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
118 crate::metric::metric!(self.metric_callback, &stmt, {
119 match query.execute(&mut *conn).await {
120 Ok(res) => Ok(res.into()),
121 Err(err) => Err(sqlx_error_to_exec_err(err)),
122 }
123 })
124 }
125
126 #[instrument(level = "trace")]
128 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
129 debug_print!("{}", sql);
130
131 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
132 match conn.execute(sql).await {
133 Ok(res) => Ok(res.into()),
134 Err(err) => Err(sqlx_error_to_exec_err(err)),
135 }
136 }
137
138 #[instrument(level = "trace")]
140 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
141 debug_print!("{}", stmt);
142
143 let query = sqlx_query(&stmt);
144 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
145 crate::metric::metric!(self.metric_callback, &stmt, {
146 match query.fetch_one(&mut *conn).await {
147 Ok(row) => Ok(Some(row.into())),
148 Err(err) => match err {
149 sqlx::Error::RowNotFound => Ok(None),
150 _ => Err(sqlx_error_to_query_err(err)),
151 },
152 }
153 })
154 }
155
156 #[instrument(level = "trace")]
158 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
159 debug_print!("{}", stmt);
160
161 let query = sqlx_query(&stmt);
162 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
163 crate::metric::metric!(self.metric_callback, &stmt, {
164 match query.fetch_all(&mut *conn).await {
165 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
166 Err(err) => Err(sqlx_error_to_query_err(err)),
167 }
168 })
169 }
170
171 #[instrument(level = "trace")]
173 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
174 debug_print!("{}", stmt);
175
176 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
177 Ok(QueryStream::from((
178 conn,
179 stmt,
180 self.metric_callback.clone(),
181 )))
182 }
183
184 #[instrument(level = "trace")]
186 pub async fn begin(
187 &self,
188 isolation_level: Option<IsolationLevel>,
189 access_mode: Option<AccessMode>,
190 ) -> Result<DatabaseTransaction, DbErr> {
191 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
192 DatabaseTransaction::new_mysql(
193 conn,
194 self.metric_callback.clone(),
195 isolation_level,
196 access_mode,
197 )
198 .await
199 }
200
201 #[instrument(level = "trace", skip(callback))]
203 pub async fn transaction<F, T, E>(
204 &self,
205 callback: F,
206 isolation_level: Option<IsolationLevel>,
207 access_mode: Option<AccessMode>,
208 ) -> Result<T, TransactionError<E>>
209 where
210 F: for<'b> FnOnce(
211 &'b DatabaseTransaction,
212 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
213 + Send,
214 T: Send,
215 E: std::fmt::Display + std::fmt::Debug + Send,
216 {
217 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
218 let transaction = DatabaseTransaction::new_mysql(
219 conn,
220 self.metric_callback.clone(),
221 isolation_level,
222 access_mode,
223 )
224 .await
225 .map_err(|e| TransactionError::Connection(e))?;
226 transaction.run(callback).await
227 }
228
229 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
230 where
231 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
232 {
233 self.metric_callback = Some(Arc::new(callback));
234 }
235
236 pub async fn ping(&self) -> Result<(), DbErr> {
238 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
239 match conn.ping().await {
240 Ok(_) => Ok(()),
241 Err(err) => Err(sqlx_error_to_conn_err(err)),
242 }
243 }
244
245 pub async fn close(self) -> Result<(), DbErr> {
248 self.close_by_ref().await
249 }
250
251 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
253 self.pool.close().await;
254 Ok(())
255 }
256}
257
258impl From<MySqlRow> for QueryResult {
259 fn from(row: MySqlRow) -> QueryResult {
260 QueryResult {
261 row: QueryResultRow::SqlxMySql(row),
262 }
263 }
264}
265
266impl From<MySqlQueryResult> for ExecResult {
267 fn from(result: MySqlQueryResult) -> ExecResult {
268 ExecResult {
269 result: ExecResultHolder::SqlxMySql(result),
270 }
271 }
272}
273
274pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
275 let values = stmt
276 .values
277 .as_ref()
278 .map_or(Values(Vec::new()), |values| values.clone());
279 sqlx::query_with(&stmt.sql, SqlxValues(values))
280}
281
282pub(crate) async fn set_transaction_config(
283 conn: &mut PoolConnection<MySql>,
284 isolation_level: Option<IsolationLevel>,
285 access_mode: Option<AccessMode>,
286) -> Result<(), DbErr> {
287 let mut settings = Vec::new();
288
289 if let Some(isolation_level) = isolation_level {
290 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
291 }
292
293 if let Some(access_mode) = access_mode {
294 settings.push(access_mode.to_string());
295 }
296
297 if !settings.is_empty() {
298 let stmt = Statement {
299 sql: format!("SET TRANSACTION {}", settings.join(", ")),
300 values: None,
301 db_backend: DbBackend::MySql,
302 };
303 let query = sqlx_query(&stmt);
304 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
305 }
306 Ok(())
307}
308
309impl
310 From<(
311 PoolConnection<sqlx::MySql>,
312 Statement,
313 Option<crate::metric::Callback>,
314 )> for crate::QueryStream
315{
316 fn from(
317 (conn, stmt, metric_callback): (
318 PoolConnection<sqlx::MySql>,
319 Statement,
320 Option<crate::metric::Callback>,
321 ),
322 ) -> Self {
323 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
324 }
325}
326
327impl crate::DatabaseTransaction {
328 pub(crate) async fn new_mysql(
329 inner: PoolConnection<sqlx::MySql>,
330 metric_callback: Option<crate::metric::Callback>,
331 isolation_level: Option<IsolationLevel>,
332 access_mode: Option<AccessMode>,
333 ) -> Result<crate::DatabaseTransaction, DbErr> {
334 Self::begin(
335 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
336 crate::DbBackend::MySql,
337 metric_callback,
338 isolation_level,
339 access_mode,
340 )
341 .await
342 }
343}
344
/// Converts a sqlx MySQL row into a `ProxyRow`, keyed by column name.
///
/// Each column is decoded according to the type name reported by sqlx and
/// stored as the matching `sea_query::Value` variant. Several arms are only
/// compiled in when the corresponding date/time, decimal or JSON feature is
/// enabled.
///
/// # Panics
///
/// Panics when a column value fails to decode as the expected Rust type, or
/// when the column's type name is not matched by any compiled-in arm.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|column| {
                let name = column.name().to_string();
                let idx = column.ordinal();
                let type_name = column.type_info().name();

                let value = match type_name {
                    // Boolean and integer types.
                    "TINYINT(1)" | "BOOLEAN" => {
                        Value::Bool(row.try_get(idx).expect("Failed to get boolean"))
                    }
                    "TINYINT UNSIGNED" => Value::TinyUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned tiny integer"),
                    ),
                    "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned small integer"),
                    ),
                    "INT UNSIGNED" => {
                        Value::Unsigned(row.try_get(idx).expect("Failed to get unsigned integer"))
                    }
                    "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                        row.try_get(idx)
                            .expect("Failed to get unsigned big integer"),
                    ),
                    "TINYINT" => {
                        Value::TinyInt(row.try_get(idx).expect("Failed to get tiny integer"))
                    }
                    "SMALLINT" => {
                        Value::SmallInt(row.try_get(idx).expect("Failed to get small integer"))
                    }
                    "INT" => Value::Int(row.try_get(idx).expect("Failed to get integer")),
                    "MEDIUMINT" | "BIGINT" => {
                        Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                    }

                    // Floating point types.
                    "FLOAT" => Value::Float(row.try_get(idx).expect("Failed to get float")),
                    "DOUBLE" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                    // Binary types.
                    "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                    | "LONGBLOB" => Value::Bytes(
                        row.try_get::<Option<Vec<u8>>, _>(idx)
                            .expect("Failed to get bytes")
                            .map(Box::new),
                    ),

                    // Text types.
                    "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                        Value::String(
                            row.try_get::<Option<String>, _>(idx)
                                .expect("Failed to get string")
                                .map(Box::new),
                        )
                    }

                    // Date/time types; chrono is used when enabled, otherwise time.
                    #[cfg(feature = "with-chrono")]
                    "TIMESTAMP" => Value::ChronoDateTimeUtc(
                        row.try_get::<Option<chrono::DateTime<chrono::Utc>>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIMESTAMP" => Value::TimeDateTime(
                        row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => Value::ChronoDate(
                        row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => Value::TimeDate(
                        row.try_get::<Option<time::Date>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "TIME" => Value::ChronoTime(
                        row.try_get::<Option<chrono::NaiveTime>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => Value::TimeTime(
                        row.try_get::<Option<time::Time>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTime(
                        row.try_get::<Option<chrono::NaiveDateTime>, _>(idx)
                            .expect("Failed to get datetime")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTime(
                        row.try_get::<Option<time::PrimitiveDateTime>, _>(idx)
                            .expect("Failed to get datetime")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "YEAR" => Value::ChronoDate(
                        row.try_get::<Option<chrono::NaiveDate>, _>(idx)
                            .expect("Failed to get year")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "YEAR" => Value::TimeDate(
                        row.try_get::<Option<time::Date>, _>(idx)
                            .expect("Failed to get year")
                            .map(Box::new),
                    ),

                    // Types carried over as strings.
                    "ENUM" | "SET" | "GEOMETRY" => Value::String(
                        row.try_get::<Option<String>, _>(idx)
                            .expect("Failed to get serialized string")
                            .map(Box::new),
                    ),

                    // Decimal types; bigdecimal is used when enabled, otherwise rust_decimal.
                    #[cfg(feature = "with-bigdecimal")]
                    "DECIMAL" => Value::BigDecimal(
                        row.try_get::<Option<bigdecimal::BigDecimal>, _>(idx)
                            .expect("Failed to get decimal")
                            .map(Box::new),
                    ),
                    #[cfg(all(
                        feature = "with-rust_decimal",
                        not(feature = "with-bigdecimal")
                    ))]
                    "DECIMAL" => Value::Decimal(
                        row.try_get::<Option<rust_decimal::Decimal>, _>(idx)
                            .expect("Failed to get decimal")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-json")]
                    "JSON" => Value::Json(
                        row.try_get::<Option<serde_json::Value>, _>(idx)
                            .expect("Failed to get json")
                            .map(Box::new),
                    ),

                    _ => unreachable!("Unknown column type: {}", type_name),
                };

                (name, value)
            })
            .collect(),
    }
}