1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, MySql, MySqlPool,
8 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
9 pool::PoolConnection,
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*,
18 executor::*,
19};
20
21use super::sqlx_common::*;
22
/// Connector for MySQL databases backed by [`sqlx`].
///
/// This is a stateless marker type; its associated functions test whether a
/// connection string is a MySQL URL and build a [`DatabaseConnection`] from
/// either connection options or an existing [`MySqlPool`].
#[derive(Debug)]
pub struct SqlxMySqlConnector;
26
27#[derive(Clone)]
29pub struct SqlxMySqlPoolConnection {
30 pub(crate) pool: MySqlPool,
31 metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxMySqlPoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<MySqlPool> for SqlxMySqlPoolConnection {
41 fn from(pool: MySqlPool) -> Self {
42 SqlxMySqlPoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<MySqlPool> for DatabaseConnection {
50 fn from(pool: MySqlPool) -> Self {
51 DatabaseConnectionType::SqlxMySqlPoolConnection(pool.into()).into()
52 }
53}
54
55impl SqlxMySqlConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut sqlx_opts = options
65 .url
66 .parse::<MySqlConnectOptions>()
67 .map_err(sqlx_error_to_conn_err)?;
68 use sqlx::ConnectOptions;
69 if !options.sqlx_logging {
70 sqlx_opts = sqlx_opts.disable_statement_logging();
71 } else {
72 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
73 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
74 sqlx_opts = sqlx_opts.log_slow_statements(
75 options.sqlx_slow_statements_logging_level,
76 options.sqlx_slow_statements_logging_threshold,
77 );
78 }
79 }
80
81 if let Some(f) = &options.mysql_opts_fn {
82 sqlx_opts = f(sqlx_opts);
83 }
84
85 let after_connect = options.after_connect.clone();
86
87 let pool = if options.connect_lazy {
88 options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
89 } else {
90 options
91 .sqlx_pool_options()
92 .connect_with(sqlx_opts)
93 .await
94 .map_err(sqlx_error_to_conn_err)?
95 };
96
97 let conn: DatabaseConnection =
98 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
99 pool,
100 metric_callback: None,
101 })
102 .into();
103
104 if let Some(cb) = after_connect {
105 cb(conn.clone()).await?;
106 }
107
108 Ok(conn)
109 }
110}
111
112impl SqlxMySqlConnector {
113 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
115 DatabaseConnectionType::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
116 pool,
117 metric_callback: None,
118 })
119 .into()
120 }
121}
122
123impl SqlxMySqlPoolConnection {
124 #[instrument(level = "trace")]
126 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
127 debug_print!("{}", stmt);
128
129 let query = sqlx_query(&stmt);
130 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
131 crate::metric::metric!(self.metric_callback, &stmt, {
132 match query.execute(&mut *conn).await {
133 Ok(res) => Ok(res.into()),
134 Err(err) => Err(sqlx_error_to_exec_err(err)),
135 }
136 })
137 }
138
139 #[instrument(level = "trace")]
141 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
142 debug_print!("{}", sql);
143
144 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
145 match conn.execute(sql).await {
146 Ok(res) => Ok(res.into()),
147 Err(err) => Err(sqlx_error_to_exec_err(err)),
148 }
149 }
150
151 #[instrument(level = "trace")]
153 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
154 debug_print!("{}", stmt);
155
156 let query = sqlx_query(&stmt);
157 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
158 crate::metric::metric!(self.metric_callback, &stmt, {
159 match query.fetch_one(&mut *conn).await {
160 Ok(row) => Ok(Some(row.into())),
161 Err(err) => match err {
162 sqlx::Error::RowNotFound => Ok(None),
163 _ => Err(sqlx_error_to_query_err(err)),
164 },
165 }
166 })
167 }
168
169 #[instrument(level = "trace")]
171 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
172 debug_print!("{}", stmt);
173
174 let query = sqlx_query(&stmt);
175 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
176 crate::metric::metric!(self.metric_callback, &stmt, {
177 match query.fetch_all(&mut *conn).await {
178 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
179 Err(err) => Err(sqlx_error_to_query_err(err)),
180 }
181 })
182 }
183
184 #[instrument(level = "trace")]
186 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
187 debug_print!("{}", stmt);
188
189 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
190 Ok(QueryStream::from((
191 conn,
192 stmt,
193 self.metric_callback.clone(),
194 )))
195 }
196
197 #[instrument(level = "trace")]
199 pub async fn begin(
200 &self,
201 isolation_level: Option<IsolationLevel>,
202 access_mode: Option<AccessMode>,
203 ) -> Result<DatabaseTransaction, DbErr> {
204 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
205 DatabaseTransaction::new_mysql(
206 conn,
207 self.metric_callback.clone(),
208 isolation_level,
209 access_mode,
210 )
211 .await
212 }
213
214 #[instrument(level = "trace", skip(callback))]
216 pub async fn transaction<F, T, E>(
217 &self,
218 callback: F,
219 isolation_level: Option<IsolationLevel>,
220 access_mode: Option<AccessMode>,
221 ) -> Result<T, TransactionError<E>>
222 where
223 F: for<'b> FnOnce(
224 &'b DatabaseTransaction,
225 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
226 + Send,
227 T: Send,
228 E: std::fmt::Display + std::fmt::Debug + Send,
229 {
230 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
231 let transaction = DatabaseTransaction::new_mysql(
232 conn,
233 self.metric_callback.clone(),
234 isolation_level,
235 access_mode,
236 )
237 .await
238 .map_err(|e| TransactionError::Connection(e))?;
239 transaction.run(callback).await
240 }
241
242 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
243 where
244 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
245 {
246 self.metric_callback = Some(Arc::new(callback));
247 }
248
249 pub async fn ping(&self) -> Result<(), DbErr> {
251 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
252 match conn.ping().await {
253 Ok(_) => Ok(()),
254 Err(err) => Err(sqlx_error_to_conn_err(err)),
255 }
256 }
257
258 pub async fn close(self) -> Result<(), DbErr> {
261 self.close_by_ref().await
262 }
263
264 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
266 self.pool.close().await;
267 Ok(())
268 }
269}
270
271impl From<MySqlRow> for QueryResult {
272 fn from(row: MySqlRow) -> QueryResult {
273 QueryResult {
274 row: QueryResultRow::SqlxMySql(row),
275 }
276 }
277}
278
279impl From<MySqlQueryResult> for ExecResult {
280 fn from(result: MySqlQueryResult) -> ExecResult {
281 ExecResult {
282 result: ExecResultHolder::SqlxMySql(result),
283 }
284 }
285}
286
287pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
288 let values = stmt
289 .values
290 .as_ref()
291 .map_or(Values(Vec::new()), |values| values.clone());
292 sqlx::query_with(&stmt.sql, SqlxValues(values))
293}
294
295pub(crate) async fn set_transaction_config(
296 conn: &mut PoolConnection<MySql>,
297 isolation_level: Option<IsolationLevel>,
298 access_mode: Option<AccessMode>,
299) -> Result<(), DbErr> {
300 let mut settings = Vec::new();
301
302 if let Some(isolation_level) = isolation_level {
303 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
304 }
305
306 if let Some(access_mode) = access_mode {
307 settings.push(access_mode.to_string());
308 }
309
310 if !settings.is_empty() {
311 let stmt = Statement {
312 sql: format!("SET TRANSACTION {}", settings.join(", ")),
313 values: None,
314 db_backend: DbBackend::MySql,
315 };
316 let query = sqlx_query(&stmt);
317 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
318 }
319 Ok(())
320}
321
322impl
323 From<(
324 PoolConnection<sqlx::MySql>,
325 Statement,
326 Option<crate::metric::Callback>,
327 )> for crate::QueryStream
328{
329 fn from(
330 (conn, stmt, metric_callback): (
331 PoolConnection<sqlx::MySql>,
332 Statement,
333 Option<crate::metric::Callback>,
334 ),
335 ) -> Self {
336 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
337 }
338}
339
340impl crate::DatabaseTransaction {
341 pub(crate) async fn new_mysql(
342 inner: PoolConnection<sqlx::MySql>,
343 metric_callback: Option<crate::metric::Callback>,
344 isolation_level: Option<IsolationLevel>,
345 access_mode: Option<AccessMode>,
346 ) -> Result<crate::DatabaseTransaction, DbErr> {
347 Self::begin(
348 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
349 crate::DbBackend::MySql,
350 metric_callback,
351 isolation_level,
352 access_mode,
353 )
354 .await
355 }
356}
357
/// Convert a sqlx MySQL row into a [`crate::ProxyRow`].
///
/// Each column is keyed by its name and mapped to a `sea_query::Value`
/// variant chosen from the column's reported MySQL type name. Columns are
/// decoded as `Option<T>`, so a SQL `NULL` becomes the `None` payload of the
/// corresponding variant instead of a decode failure.
///
/// # Panics
///
/// Panics if a column value cannot be decoded as the Rust type chosen for its
/// type name, or if the type name has no matching arm. Temporal, decimal and
/// JSON arms exist only when the corresponding crate features are enabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        "TINYINT(1)" | "BOOLEAN" => Value::Bool(
                            row.try_get(c.ordinal()).expect("Failed to get boolean"),
                        ),
                        "TINYINT UNSIGNED" => Value::TinyUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned tiny integer"),
                        ),
                        "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned small integer"),
                        ),
                        "INT UNSIGNED" => Value::Unsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned integer"),
                        ),
                        "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned big integer"),
                        ),
                        "TINYINT" => Value::TinyInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get tiny integer"),
                        ),
                        "SMALLINT" => Value::SmallInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get small integer"),
                        ),
                        "INT" => Value::Int(
                            row.try_get(c.ordinal()).expect("Failed to get integer"),
                        ),
                        "MEDIUMINT" | "BIGINT" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),
                        "FLOAT" => Value::Float(
                            row.try_get(c.ordinal()).expect("Failed to get float"),
                        ),
                        "DOUBLE" => Value::Double(
                            row.try_get(c.ordinal()).expect("Failed to get double"),
                        ),

                        // Boxed variants: decode as `Option<_>` first, then box
                        // only when a value is actually present.
                        "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                        | "LONGBLOB" => Value::Bytes(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                            Value::String(
                                row.try_get::<Option<_>, _>(c.ordinal())
                                    .expect("Failed to get string")
                                    .map(Box::new),
                            )
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIMESTAMP" => Value::ChronoDateTimeUtc(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIMESTAMP" => Value::TimeDateTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get datetime")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get datetime")
                                .map(Box::new),
                        ),

                        // NOTE(review): YEAR is decoded as a date type here;
                        // confirm sqlx accepts a YEAR column for that Rust type.
                        #[cfg(feature = "with-chrono")]
                        "YEAR" => Value::ChronoDate(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get year")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "YEAR" => Value::TimeDate(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get year")
                                .map(Box::new),
                        ),

                        "ENUM" | "SET" | "GEOMETRY" => Value::String(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get serialized string")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-bigdecimal")]
                        "DECIMAL" => Value::BigDecimal(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get decimal")
                                .map(Box::new),
                        ),
                        #[cfg(all(
                            feature = "with-rust_decimal",
                            not(feature = "with-bigdecimal")
                        ))]
                        "DECIMAL" => Value::Decimal(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get decimal")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-json")]
                        "JSON" => Value::Json(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get json")
                                .map(Box::new),
                        ),

                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}