1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, MySql, MySqlPool,
8 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9 pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18 executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Connector that builds sea-orm [`DatabaseConnection`]s backed by an sqlx
/// MySQL connection pool.
///
/// It carries no state; see `accepts`, `connect` and `from_sqlx_mysql_pool`.
#[derive(Debug)]
pub struct SqlxMySqlConnector;
26
/// An sqlx MySQL connection pool, plus an optional metric callback that is
/// passed to the `metric!` macro when statements run through this connection.
#[derive(Clone)]
pub struct SqlxMySqlPoolConnection {
    /// The sqlx pool that every operation acquires its connection from.
    pub(crate) pool: MySqlPool,
    /// Callback installed via `set_metric_callback`; the constructors in this
    /// file all start with `None`.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxMySqlPoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<MySqlPool> for SqlxMySqlPoolConnection {
41 fn from(pool: MySqlPool) -> Self {
42 SqlxMySqlPoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<MySqlPool> for DatabaseConnection {
50 fn from(pool: MySqlPool) -> Self {
51 DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
52 }
53}
54
55impl SqlxMySqlConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut sqlx_opts = options
65 .url
66 .parse::<MySqlConnectOptions>()
67 .map_err(sqlx_error_to_conn_err)?;
68 use sqlx::ConnectOptions;
69 if !options.sqlx_logging {
70 sqlx_opts = sqlx_opts.disable_statement_logging();
71 } else {
72 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
73 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
74 sqlx_opts = sqlx_opts.log_slow_statements(
75 options.sqlx_slow_statements_logging_level,
76 options.sqlx_slow_statements_logging_threshold,
77 );
78 }
79 }
80 if let Some(f) = &options.mysql_opts_fn {
81 sqlx_opts = f(sqlx_opts);
82 }
83 let pool = if options.connect_lazy {
84 options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
85 } else {
86 options
87 .sqlx_pool_options()
88 .connect_with(sqlx_opts)
89 .await
90 .map_err(sqlx_error_to_conn_err)?
91 };
92 Ok(
93 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
94 pool,
95 metric_callback: None,
96 })
97 .into(),
98 )
99 }
100}
101
102impl SqlxMySqlConnector {
103 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
105 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
106 pool,
107 metric_callback: None,
108 })
109 .into()
110 }
111}
112
113impl SqlxMySqlPoolConnection {
114 #[instrument(level = "trace")]
116 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
117 debug_print!("{}", stmt);
118
119 let query = sqlx_query(&stmt);
120 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
121 crate::metric::metric!(self.metric_callback, &stmt, {
122 match query.execute(&mut *conn).await {
123 Ok(res) => Ok(res.into()),
124 Err(err) => Err(sqlx_error_to_exec_err(err)),
125 }
126 })
127 }
128
129 #[instrument(level = "trace")]
131 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
132 debug_print!("{}", sql);
133
134 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
135 match conn.execute(sql).await {
136 Ok(res) => Ok(res.into()),
137 Err(err) => Err(sqlx_error_to_exec_err(err)),
138 }
139 }
140
141 #[instrument(level = "trace")]
143 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
144 debug_print!("{}", stmt);
145
146 let query = sqlx_query(&stmt);
147 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
148 crate::metric::metric!(self.metric_callback, &stmt, {
149 match query.fetch_one(&mut *conn).await {
150 Ok(row) => Ok(Some(row.into())),
151 Err(err) => match err {
152 sqlx::Error::RowNotFound => Ok(None),
153 _ => Err(sqlx_error_to_query_err(err)),
154 },
155 }
156 })
157 }
158
159 #[instrument(level = "trace")]
161 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
162 debug_print!("{}", stmt);
163
164 let query = sqlx_query(&stmt);
165 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
166 crate::metric::metric!(self.metric_callback, &stmt, {
167 match query.fetch_all(&mut *conn).await {
168 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
169 Err(err) => Err(sqlx_error_to_query_err(err)),
170 }
171 })
172 }
173
174 #[instrument(level = "trace")]
176 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
177 debug_print!("{}", stmt);
178
179 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
180 Ok(QueryStream::from((
181 conn,
182 stmt,
183 self.metric_callback.clone(),
184 )))
185 }
186
187 #[instrument(level = "trace")]
189 pub async fn begin(
190 &self,
191 isolation_level: Option<IsolationLevel>,
192 access_mode: Option<AccessMode>,
193 ) -> Result<DatabaseTransaction, DbErr> {
194 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
195 DatabaseTransaction::new_mysql(
196 conn,
197 self.metric_callback.clone(),
198 isolation_level,
199 access_mode,
200 )
201 .await
202 }
203
204 #[instrument(level = "trace", skip(callback))]
206 pub async fn transaction<F, T, E>(
207 &self,
208 callback: F,
209 isolation_level: Option<IsolationLevel>,
210 access_mode: Option<AccessMode>,
211 ) -> Result<T, TransactionError<E>>
212 where
213 F: for<'b> FnOnce(
214 &'b DatabaseTransaction,
215 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
216 + Send,
217 T: Send,
218 E: std::fmt::Display + std::fmt::Debug + Send,
219 {
220 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
221 let transaction = DatabaseTransaction::new_mysql(
222 conn,
223 self.metric_callback.clone(),
224 isolation_level,
225 access_mode,
226 )
227 .await
228 .map_err(|e| TransactionError::Connection(e))?;
229 transaction.run(callback).await
230 }
231
232 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
233 where
234 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
235 {
236 self.metric_callback = Some(Arc::new(callback));
237 }
238
239 pub async fn ping(&self) -> Result<(), DbErr> {
241 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
242 match conn.ping().await {
243 Ok(_) => Ok(()),
244 Err(err) => Err(sqlx_error_to_conn_err(err)),
245 }
246 }
247
248 pub async fn close(self) -> Result<(), DbErr> {
251 self.close_by_ref().await
252 }
253
254 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
256 self.pool.close().await;
257 Ok(())
258 }
259}
260
261impl From<MySqlRow> for QueryResult {
262 fn from(row: MySqlRow) -> QueryResult {
263 QueryResult {
264 row: QueryResultRow::SqlxMySql(row),
265 }
266 }
267}
268
269impl From<MySqlQueryResult> for ExecResult {
270 fn from(result: MySqlQueryResult) -> ExecResult {
271 ExecResult {
272 result: ExecResultHolder::SqlxMySql(result),
273 }
274 }
275}
276
277pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
278 let values = stmt
279 .values
280 .as_ref()
281 .map_or(Values(Vec::new()), |values| values.clone());
282 sqlx::query_with(&stmt.sql, SqlxValues(values))
283}
284
285pub(crate) async fn set_transaction_config(
286 conn: &mut PoolConnection<MySql>,
287 isolation_level: Option<IsolationLevel>,
288 access_mode: Option<AccessMode>,
289) -> Result<(), DbErr> {
290 let mut settings = Vec::new();
291
292 if let Some(isolation_level) = isolation_level {
293 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
294 }
295
296 if let Some(access_mode) = access_mode {
297 settings.push(access_mode.to_string());
298 }
299
300 if !settings.is_empty() {
301 let stmt = Statement {
302 sql: format!("SET TRANSACTION {}", settings.join(", ")),
303 values: None,
304 db_backend: DbBackend::MySql,
305 };
306 let query = sqlx_query(&stmt);
307 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
308 }
309 Ok(())
310}
311
312impl
313 From<(
314 PoolConnection<sqlx::MySql>,
315 Statement,
316 Option<crate::metric::Callback>,
317 )> for crate::QueryStream
318{
319 fn from(
320 (conn, stmt, metric_callback): (
321 PoolConnection<sqlx::MySql>,
322 Statement,
323 Option<crate::metric::Callback>,
324 ),
325 ) -> Self {
326 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
327 }
328}
329
330impl crate::DatabaseTransaction {
331 pub(crate) async fn new_mysql(
332 inner: PoolConnection<sqlx::MySql>,
333 metric_callback: Option<crate::metric::Callback>,
334 isolation_level: Option<IsolationLevel>,
335 access_mode: Option<AccessMode>,
336 ) -> Result<crate::DatabaseTransaction, DbErr> {
337 Self::begin(
338 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
339 crate::DbBackend::MySql,
340 metric_callback,
341 isolation_level,
342 access_mode,
343 )
344 .await
345 }
346}
347
/// Converts a raw sqlx MySQL row into a [`crate::ProxyRow`], keyed by column
/// name.
///
/// Each column is decoded according to the type name sqlx reports for it
/// (`TypeInfo::name`). A SQL `NULL` becomes the `None` form of the matching
/// [`sea_query::Value`] variant instead of aborting the conversion.
///
/// # Panics
///
/// Panics in two cases:
/// - A non-NULL column cannot be decoded as the Rust type chosen for its
///   reported type name.
/// - The type name is not handled here. This includes types whose mapping is
///   gated behind a disabled cargo feature, e.g. `DECIMAL` without
///   `with-bigdecimal`/`with-rust_decimal`, or `JSON` without `with-json`.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    /// Decodes column `idx` of `row` as a nullable `T`, panicking with `what`
    /// on a decode error.
    ///
    /// Decoding through `Option<T>` maps SQL `NULL` to `None`. Decoding a bare
    /// `T` would fail on `NULL`, and the `expect` would then panic.
    fn get<'r, T>(row: &'r sqlx::mysql::MySqlRow, idx: usize, what: &str) -> Option<T>
    where
        T: sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql>,
    {
        row.try_get::<Option<T>, _>(idx).expect(what)
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let idx = c.ordinal();
                let value = match c.type_info().name() {
                    "TINYINT(1)" | "BOOLEAN" => {
                        Value::Bool(get(row, idx, "Failed to get boolean"))
                    }
                    "TINYINT UNSIGNED" => Value::TinyUnsigned(get(
                        row,
                        idx,
                        "Failed to get unsigned tiny integer",
                    )),
                    "SMALLINT UNSIGNED" => Value::SmallUnsigned(get(
                        row,
                        idx,
                        "Failed to get unsigned small integer",
                    )),
                    "INT UNSIGNED" => {
                        Value::Unsigned(get(row, idx, "Failed to get unsigned integer"))
                    }
                    "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(get(
                        row,
                        idx,
                        "Failed to get unsigned big integer",
                    )),
                    "TINYINT" => Value::TinyInt(get(row, idx, "Failed to get tiny integer")),
                    "SMALLINT" => {
                        Value::SmallInt(get(row, idx, "Failed to get small integer"))
                    }
                    "INT" => Value::Int(get(row, idx, "Failed to get integer")),
                    "MEDIUMINT" | "BIGINT" => {
                        Value::BigInt(get(row, idx, "Failed to get big integer"))
                    }
                    "FLOAT" => Value::Float(get(row, idx, "Failed to get float")),
                    "DOUBLE" => Value::Double(get(row, idx, "Failed to get double")),

                    "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                    | "LONGBLOB" => {
                        Value::Bytes(get(row, idx, "Failed to get bytes").map(Box::new))
                    }

                    "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                        Value::String(get(row, idx, "Failed to get string").map(Box::new))
                    }

                    #[cfg(feature = "with-chrono")]
                    "TIMESTAMP" => Value::ChronoDateTimeUtc(
                        get(row, idx, "Failed to get timestamp").map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIMESTAMP" => Value::TimeDateTime(
                        get(row, idx, "Failed to get timestamp").map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => {
                        Value::ChronoDate(get(row, idx, "Failed to get date").map(Box::new))
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => {
                        Value::TimeDate(get(row, idx, "Failed to get date").map(Box::new))
                    }

                    #[cfg(feature = "with-chrono")]
                    "TIME" => {
                        Value::ChronoTime(get(row, idx, "Failed to get time").map(Box::new))
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => {
                        Value::TimeTime(get(row, idx, "Failed to get time").map(Box::new))
                    }

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTime(
                        get(row, idx, "Failed to get datetime").map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTime(
                        get(row, idx, "Failed to get datetime").map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "YEAR" => {
                        Value::ChronoDate(get(row, idx, "Failed to get year").map(Box::new))
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "YEAR" => {
                        Value::TimeDate(get(row, idx, "Failed to get year").map(Box::new))
                    }

                    "ENUM" | "SET" | "GEOMETRY" => Value::String(
                        get(row, idx, "Failed to get serialized string").map(Box::new),
                    ),

                    #[cfg(feature = "with-bigdecimal")]
                    "DECIMAL" => {
                        Value::BigDecimal(get(row, idx, "Failed to get decimal").map(Box::new))
                    }
                    #[cfg(all(
                        feature = "with-rust_decimal",
                        not(feature = "with-bigdecimal")
                    ))]
                    "DECIMAL" => {
                        Value::Decimal(get(row, idx, "Failed to get decimal").map(Box::new))
                    }

                    #[cfg(feature = "with-json")]
                    "JSON" => Value::Json(get(row, idx, "Failed to get json").map(Box::new)),

                    // This arm is reachable: unsupported types and feature-gated
                    // mappings whose feature is disabled both land here.
                    _ => panic!("Unknown column type: {}", c.type_info().name()),
                };
                (c.name().to_string(), value)
            })
            .collect(),
    }
}