1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
8 pool::PoolConnection,
9 Connection, Executor, MySql, MySqlPool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 debug_print, error::*, executor::*, AccessMode, ConnectOptions, DatabaseConnection,
17 DatabaseTransaction, DbBackend, IsolationLevel, QueryStream, Statement, TransactionError,
18};
19
20use super::sqlx_common::*;
21
/// Entry point for MySQL connections backed by sqlx.
///
/// The type carries no state; its associated functions validate MySQL
/// connection strings and build a [`DatabaseConnection`] around a
/// [`MySqlPool`].
#[derive(Debug)]
pub struct SqlxMySqlConnector;
25
26#[derive(Clone)]
28pub struct SqlxMySqlPoolConnection {
29 pub(crate) pool: MySqlPool,
30 metric_callback: Option<crate::metric::Callback>,
31}
32
33impl std::fmt::Debug for SqlxMySqlPoolConnection {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
36 }
37}
38
39impl From<MySqlPool> for SqlxMySqlPoolConnection {
40 fn from(pool: MySqlPool) -> Self {
41 SqlxMySqlPoolConnection {
42 pool,
43 metric_callback: None,
44 }
45 }
46}
47
48impl From<MySqlPool> for DatabaseConnection {
49 fn from(pool: MySqlPool) -> Self {
50 DatabaseConnection::SqlxMySqlPoolConnection(pool.into())
51 }
52}
53
54impl SqlxMySqlConnector {
55 pub fn accepts(string: &str) -> bool {
57 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
58 }
59
60 #[instrument(level = "trace")]
62 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
63 let mut opt = options
64 .url
65 .parse::<MySqlConnectOptions>()
66 .map_err(sqlx_error_to_conn_err)?;
67 use sqlx::ConnectOptions;
68 if !options.sqlx_logging {
69 opt = opt.disable_statement_logging();
70 } else {
71 opt = opt.log_statements(options.sqlx_logging_level);
72 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
73 opt = opt.log_slow_statements(
74 options.sqlx_slow_statements_logging_level,
75 options.sqlx_slow_statements_logging_threshold,
76 );
77 }
78 }
79 let pool = if options.connect_lazy {
80 options.sqlx_pool_options().connect_lazy_with(opt)
81 } else {
82 options
83 .sqlx_pool_options()
84 .connect_with(opt)
85 .await
86 .map_err(sqlx_error_to_conn_err)?
87 };
88 Ok(DatabaseConnection::SqlxMySqlPoolConnection(
89 SqlxMySqlPoolConnection {
90 pool,
91 metric_callback: None,
92 },
93 ))
94 }
95}
96
97impl SqlxMySqlConnector {
98 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
100 DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
101 pool,
102 metric_callback: None,
103 })
104 }
105}
106
107impl SqlxMySqlPoolConnection {
108 #[instrument(level = "trace")]
110 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
111 debug_print!("{}", stmt);
112
113 let query = sqlx_query(&stmt);
114 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
115 crate::metric::metric!(self.metric_callback, &stmt, {
116 match query.execute(&mut *conn).await {
117 Ok(res) => Ok(res.into()),
118 Err(err) => Err(sqlx_error_to_exec_err(err)),
119 }
120 })
121 }
122
123 #[instrument(level = "trace")]
125 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
126 debug_print!("{}", sql);
127
128 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
129 match conn.execute(sql).await {
130 Ok(res) => Ok(res.into()),
131 Err(err) => Err(sqlx_error_to_exec_err(err)),
132 }
133 }
134
135 #[instrument(level = "trace")]
137 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
138 debug_print!("{}", stmt);
139
140 let query = sqlx_query(&stmt);
141 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
142 crate::metric::metric!(self.metric_callback, &stmt, {
143 match query.fetch_one(&mut *conn).await {
144 Ok(row) => Ok(Some(row.into())),
145 Err(err) => match err {
146 sqlx::Error::RowNotFound => Ok(None),
147 _ => Err(sqlx_error_to_query_err(err)),
148 },
149 }
150 })
151 }
152
153 #[instrument(level = "trace")]
155 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
156 debug_print!("{}", stmt);
157
158 let query = sqlx_query(&stmt);
159 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160 crate::metric::metric!(self.metric_callback, &stmt, {
161 match query.fetch_all(&mut *conn).await {
162 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
163 Err(err) => Err(sqlx_error_to_query_err(err)),
164 }
165 })
166 }
167
168 #[instrument(level = "trace")]
170 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
171 debug_print!("{}", stmt);
172
173 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
174 Ok(QueryStream::from((
175 conn,
176 stmt,
177 self.metric_callback.clone(),
178 )))
179 }
180
181 #[instrument(level = "trace")]
183 pub async fn begin(
184 &self,
185 isolation_level: Option<IsolationLevel>,
186 access_mode: Option<AccessMode>,
187 ) -> Result<DatabaseTransaction, DbErr> {
188 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
189 DatabaseTransaction::new_mysql(
190 conn,
191 self.metric_callback.clone(),
192 isolation_level,
193 access_mode,
194 )
195 .await
196 }
197
198 #[instrument(level = "trace", skip(callback))]
200 pub async fn transaction<F, T, E>(
201 &self,
202 callback: F,
203 isolation_level: Option<IsolationLevel>,
204 access_mode: Option<AccessMode>,
205 ) -> Result<T, TransactionError<E>>
206 where
207 F: for<'b> FnOnce(
208 &'b DatabaseTransaction,
209 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
210 + Send,
211 T: Send,
212 E: std::error::Error + Send,
213 {
214 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
215 let transaction = DatabaseTransaction::new_mysql(
216 conn,
217 self.metric_callback.clone(),
218 isolation_level,
219 access_mode,
220 )
221 .await
222 .map_err(|e| TransactionError::Connection(e))?;
223 transaction.run(callback).await
224 }
225
226 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
227 where
228 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
229 {
230 self.metric_callback = Some(Arc::new(callback));
231 }
232
233 pub async fn ping(&self) -> Result<(), DbErr> {
235 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
236 match conn.ping().await {
237 Ok(_) => Ok(()),
238 Err(err) => Err(sqlx_error_to_conn_err(err)),
239 }
240 }
241
242 pub async fn close(self) -> Result<(), DbErr> {
244 self.pool.close().await;
245 Ok(())
246 }
247}
248
249impl From<MySqlRow> for QueryResult {
250 fn from(row: MySqlRow) -> QueryResult {
251 QueryResult {
252 row: QueryResultRow::SqlxMySql(row),
253 }
254 }
255}
256
257impl From<MySqlQueryResult> for ExecResult {
258 fn from(result: MySqlQueryResult) -> ExecResult {
259 ExecResult {
260 result: ExecResultHolder::SqlxMySql(result),
261 }
262 }
263}
264
265pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
266 let values = stmt
267 .values
268 .as_ref()
269 .map_or(Values(Vec::new()), |values| values.clone());
270 sqlx::query_with(&stmt.sql, SqlxValues(values))
271}
272
273pub(crate) async fn set_transaction_config(
274 conn: &mut PoolConnection<MySql>,
275 isolation_level: Option<IsolationLevel>,
276 access_mode: Option<AccessMode>,
277) -> Result<(), DbErr> {
278 let mut settings = Vec::new();
279
280 if let Some(isolation_level) = isolation_level {
281 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
282 }
283
284 if let Some(access_mode) = access_mode {
285 settings.push(access_mode.to_string());
286 }
287
288 if !settings.is_empty() {
289 let stmt = Statement {
290 sql: format!("SET TRANSACTION {}", settings.join(", ")),
291 values: None,
292 db_backend: DbBackend::MySql,
293 };
294 let query = sqlx_query(&stmt);
295 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
296 }
297 Ok(())
298}
299
300impl
301 From<(
302 PoolConnection<sqlx::MySql>,
303 Statement,
304 Option<crate::metric::Callback>,
305 )> for crate::QueryStream
306{
307 fn from(
308 (conn, stmt, metric_callback): (
309 PoolConnection<sqlx::MySql>,
310 Statement,
311 Option<crate::metric::Callback>,
312 ),
313 ) -> Self {
314 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
315 }
316}
317
318impl crate::DatabaseTransaction {
319 pub(crate) async fn new_mysql(
320 inner: PoolConnection<sqlx::MySql>,
321 metric_callback: Option<crate::metric::Callback>,
322 isolation_level: Option<IsolationLevel>,
323 access_mode: Option<AccessMode>,
324 ) -> Result<crate::DatabaseTransaction, DbErr> {
325 Self::begin(
326 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
327 crate::DbBackend::MySql,
328 metric_callback,
329 isolation_level,
330 access_mode,
331 )
332 .await
333 }
334}
335
/// Converts a sqlx [`MySqlRow`](sqlx::mysql::MySqlRow) into a [`crate::ProxyRow`].
///
/// Each column is decoded according to the type name reported by sqlx and
/// stored under the column's name as a `sea_query::Value`.
///
/// Columns are decoded as `Option<T>`, so a SQL `NULL` becomes the `None`
/// payload of the matching `Value` variant. Decoding into a non-nullable `T`
/// would make `try_get` fail on `NULL` and panic through `expect`.
///
/// # Panics
///
/// Panics when a non-null column cannot be decoded into the Rust type chosen
/// for its type name, or when the reported type name is not handled below.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    /// Boxes the payload of a nullable column, keeping `None` for SQL `NULL`.
    ///
    /// Being a generic function (not a method call on an inferred type) lets
    /// the decoded type be inferred from the surrounding `Value` variant.
    fn boxed<T>(value: Option<T>) -> Option<Box<T>> {
        value.map(Box::new)
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        "TINYINT(1)" | "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }
                        "TINYINT UNSIGNED" => Value::TinyUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned tiny integer"),
                        ),
                        "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned small integer"),
                        ),
                        "INT UNSIGNED" => Value::Unsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned integer"),
                        ),
                        "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                            row.try_get(c.ordinal())
                                .expect("Failed to get unsigned big integer"),
                        ),
                        "TINYINT" => Value::TinyInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get tiny integer"),
                        ),
                        "SMALLINT" => Value::SmallInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get small integer"),
                        ),
                        "INT" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }
                        "MEDIUMINT" | "BIGINT" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),
                        "FLOAT" => {
                            Value::Float(row.try_get(c.ordinal()).expect("Failed to get float"))
                        }
                        "DOUBLE" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                        | "LONGBLOB" => Value::Bytes(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get bytes"),
                        )),

                        "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                            Value::String(boxed(
                                row.try_get(c.ordinal()).expect("Failed to get string"),
                            ))
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIMESTAMP" => Value::ChronoDateTimeUtc(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get timestamp"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIMESTAMP" => Value::TimeDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get timestamp"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get date"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get date"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get time"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get time"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get datetime"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTime(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get datetime"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "YEAR" => Value::ChronoDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get year"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "YEAR" => Value::TimeDate(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get year"),
                        )),

                        "ENUM" | "SET" | "GEOMETRY" => Value::String(boxed(
                            row.try_get(c.ordinal())
                                .expect("Failed to get serialized string"),
                        )),

                        #[cfg(feature = "with-bigdecimal")]
                        "DECIMAL" => Value::BigDecimal(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get decimal"),
                        )),
                        #[cfg(all(
                            feature = "with-rust_decimal",
                            not(feature = "with-bigdecimal")
                        ))]
                        "DECIMAL" => Value::Decimal(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get decimal"),
                        )),

                        #[cfg(feature = "with-json")]
                        "JSON" => Value::Json(boxed(
                            row.try_get(c.ordinal()).expect("Failed to get json"),
                        )),

                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}