1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
8 pool::PoolConnection,
9 Connection, Executor, MySql, MySqlPool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 debug_print, error::*, executor::*, AccessMode, ConnectOptions, DatabaseConnection,
17 DatabaseTransaction, DbBackend, IsolationLevel, QueryStream, Statement, TransactionError,
18};
19
20use super::sqlx_common::*;
21
/// Entry point for creating sqlx-backed MySQL [`DatabaseConnection`]s.
///
/// This is a stateless unit type; its associated functions (`accepts`, `connect`,
/// `from_sqlx_mysql_pool`) do the actual work.
#[derive(Debug)]
pub struct SqlxMySqlConnector;
25
/// A sqlx MySQL connection pool, as held by the MySQL variant of [`DatabaseConnection`].
///
/// Every query method acquires a connection from `pool` for the duration of the call.
#[derive(Clone)]
pub struct SqlxMySqlPoolConnection {
    /// The underlying sqlx pool that connections are acquired from.
    pub(crate) pool: MySqlPool,
    /// Optional metric hook, installed via `set_metric_callback`. It is passed to
    /// `crate::metric::metric!` around statement execution and cloned into streams and
    /// transactions created from this connection.
    metric_callback: Option<crate::metric::Callback>,
}
32
33impl std::fmt::Debug for SqlxMySqlPoolConnection {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
36 }
37}
38
39impl From<MySqlPool> for SqlxMySqlPoolConnection {
40 fn from(pool: MySqlPool) -> Self {
41 SqlxMySqlPoolConnection {
42 pool,
43 metric_callback: None,
44 }
45 }
46}
47
48impl From<MySqlPool> for DatabaseConnection {
49 fn from(pool: MySqlPool) -> Self {
50 DatabaseConnection::SqlxMySqlPoolConnection(pool.into())
51 }
52}
53
54impl SqlxMySqlConnector {
55 pub fn accepts(string: &str) -> bool {
57 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
58 }
59
60 #[instrument(level = "trace")]
62 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
63 let mut opt = options
64 .url
65 .parse::<MySqlConnectOptions>()
66 .map_err(sqlx_error_to_conn_err)?;
67 use sqlx::ConnectOptions;
68 if !options.sqlx_logging {
69 opt = opt.disable_statement_logging();
70 } else {
71 opt = opt.log_statements(options.sqlx_logging_level);
72 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
73 opt = opt.log_slow_statements(
74 options.sqlx_slow_statements_logging_level,
75 options.sqlx_slow_statements_logging_threshold,
76 );
77 }
78 }
79 let pool = if options.connect_lazy {
80 options.sqlx_pool_options().connect_lazy_with(opt)
81 } else {
82 options
83 .sqlx_pool_options()
84 .connect_with(opt)
85 .await
86 .map_err(sqlx_error_to_conn_err)?
87 };
88 Ok(DatabaseConnection::SqlxMySqlPoolConnection(
89 SqlxMySqlPoolConnection {
90 pool,
91 metric_callback: None,
92 },
93 ))
94 }
95}
96
97impl SqlxMySqlConnector {
98 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
100 DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
101 pool,
102 metric_callback: None,
103 })
104 }
105}
106
107impl SqlxMySqlPoolConnection {
108 #[instrument(level = "trace")]
110 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
111 debug_print!("{}", stmt);
112
113 let query = sqlx_query(&stmt);
114 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
115 crate::metric::metric!(self.metric_callback, &stmt, {
116 match query.execute(&mut *conn).await {
117 Ok(res) => Ok(res.into()),
118 Err(err) => Err(sqlx_error_to_exec_err(err)),
119 }
120 })
121 }
122
123 #[instrument(level = "trace")]
125 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
126 debug_print!("{}", sql);
127
128 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
129 match conn.execute(sql).await {
130 Ok(res) => Ok(res.into()),
131 Err(err) => Err(sqlx_error_to_exec_err(err)),
132 }
133 }
134
135 #[instrument(level = "trace")]
137 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
138 debug_print!("{}", stmt);
139
140 let query = sqlx_query(&stmt);
141 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
142 crate::metric::metric!(self.metric_callback, &stmt, {
143 match query.fetch_one(&mut *conn).await {
144 Ok(row) => Ok(Some(row.into())),
145 Err(err) => match err {
146 sqlx::Error::RowNotFound => Ok(None),
147 _ => Err(sqlx_error_to_query_err(err)),
148 },
149 }
150 })
151 }
152
153 #[instrument(level = "trace")]
155 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
156 debug_print!("{}", stmt);
157
158 let query = sqlx_query(&stmt);
159 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160 crate::metric::metric!(self.metric_callback, &stmt, {
161 match query.fetch_all(&mut *conn).await {
162 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
163 Err(err) => Err(sqlx_error_to_query_err(err)),
164 }
165 })
166 }
167
168 #[instrument(level = "trace")]
170 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
171 debug_print!("{}", stmt);
172
173 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
174 Ok(QueryStream::from((
175 conn,
176 stmt,
177 self.metric_callback.clone(),
178 )))
179 }
180
181 #[instrument(level = "trace")]
183 pub async fn begin(
184 &self,
185 isolation_level: Option<IsolationLevel>,
186 access_mode: Option<AccessMode>,
187 ) -> Result<DatabaseTransaction, DbErr> {
188 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
189 DatabaseTransaction::new_mysql(
190 conn,
191 self.metric_callback.clone(),
192 isolation_level,
193 access_mode,
194 )
195 .await
196 }
197
198 #[instrument(level = "trace", skip(callback))]
200 pub async fn transaction<F, T, E>(
201 &self,
202 callback: F,
203 isolation_level: Option<IsolationLevel>,
204 access_mode: Option<AccessMode>,
205 ) -> Result<T, TransactionError<E>>
206 where
207 F: for<'b> FnOnce(
208 &'b DatabaseTransaction,
209 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
210 + Send,
211 T: Send,
212 E: std::fmt::Display + std::fmt::Debug + Send,
213 {
214 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
215 let transaction = DatabaseTransaction::new_mysql(
216 conn,
217 self.metric_callback.clone(),
218 isolation_level,
219 access_mode,
220 )
221 .await
222 .map_err(|e| TransactionError::Connection(e))?;
223 transaction.run(callback).await
224 }
225
226 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
227 where
228 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
229 {
230 self.metric_callback = Some(Arc::new(callback));
231 }
232
233 pub async fn ping(&self) -> Result<(), DbErr> {
235 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
236 match conn.ping().await {
237 Ok(_) => Ok(()),
238 Err(err) => Err(sqlx_error_to_conn_err(err)),
239 }
240 }
241
242 pub async fn close(self) -> Result<(), DbErr> {
245 self.close_by_ref().await
246 }
247
248 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
250 self.pool.close().await;
251 Ok(())
252 }
253}
254
255impl From<MySqlRow> for QueryResult {
256 fn from(row: MySqlRow) -> QueryResult {
257 QueryResult {
258 row: QueryResultRow::SqlxMySql(row),
259 }
260 }
261}
262
263impl From<MySqlQueryResult> for ExecResult {
264 fn from(result: MySqlQueryResult) -> ExecResult {
265 ExecResult {
266 result: ExecResultHolder::SqlxMySql(result),
267 }
268 }
269}
270
271pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
272 let values = stmt
273 .values
274 .as_ref()
275 .map_or(Values(Vec::new()), |values| values.clone());
276 sqlx::query_with(&stmt.sql, SqlxValues(values))
277}
278
279pub(crate) async fn set_transaction_config(
280 conn: &mut PoolConnection<MySql>,
281 isolation_level: Option<IsolationLevel>,
282 access_mode: Option<AccessMode>,
283) -> Result<(), DbErr> {
284 let mut settings = Vec::new();
285
286 if let Some(isolation_level) = isolation_level {
287 settings.push(format!("ISOLATION LEVEL {isolation_level}"));
288 }
289
290 if let Some(access_mode) = access_mode {
291 settings.push(access_mode.to_string());
292 }
293
294 if !settings.is_empty() {
295 let stmt = Statement {
296 sql: format!("SET TRANSACTION {}", settings.join(", ")),
297 values: None,
298 db_backend: DbBackend::MySql,
299 };
300 let query = sqlx_query(&stmt);
301 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
302 }
303 Ok(())
304}
305
306impl
307 From<(
308 PoolConnection<sqlx::MySql>,
309 Statement,
310 Option<crate::metric::Callback>,
311 )> for crate::QueryStream
312{
313 fn from(
314 (conn, stmt, metric_callback): (
315 PoolConnection<sqlx::MySql>,
316 Statement,
317 Option<crate::metric::Callback>,
318 ),
319 ) -> Self {
320 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
321 }
322}
323
324impl crate::DatabaseTransaction {
325 pub(crate) async fn new_mysql(
326 inner: PoolConnection<sqlx::MySql>,
327 metric_callback: Option<crate::metric::Callback>,
328 isolation_level: Option<IsolationLevel>,
329 access_mode: Option<AccessMode>,
330 ) -> Result<crate::DatabaseTransaction, DbErr> {
331 Self::begin(
332 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
333 crate::DbBackend::MySql,
334 metric_callback,
335 isolation_level,
336 access_mode,
337 )
338 .await
339 }
340}
341
/// Convert a sqlx MySQL row into a [`crate::ProxyRow`], keyed by column name.
///
/// Each column is decoded into the [`sea_query::Value`] variant chosen for its MySQL type
/// name. A SQL `NULL` becomes that variant's `None` (e.g. `Value::Int(None)`) instead of
/// aborting the conversion.
///
/// # Panics
///
/// - If a non-NULL value cannot be decoded as the Rust type chosen for its column.
/// - If the column type is not handled below. This includes types whose arms are compiled
///   out because a feature such as `with-chrono`, `with-time`, `with-json`,
///   `with-bigdecimal` or `with-rust_decimal` is disabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    /// Decode column `ordinal` of `row` as a nullable `T`, mapping SQL `NULL` to `None`.
    ///
    /// Panics with `Failed to get {what}: {err:?}` if the value cannot be decoded as `T`.
    /// This is the same message shape the previous `expect` calls produced.
    fn get<'r, T>(row: &'r sqlx::mysql::MySqlRow, ordinal: usize, what: &str) -> Option<T>
    where
        T: sqlx::Decode<'r, sqlx::MySql> + sqlx::Type<sqlx::MySql>,
    {
        row.try_get::<Option<T>, _>(ordinal)
            .unwrap_or_else(|err| panic!("Failed to get {what}: {err:?}"))
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let i = c.ordinal();
                let value = match c.type_info().name() {
                    "TINYINT(1)" | "BOOLEAN" => Value::Bool(get(row, i, "boolean")),
                    "TINYINT UNSIGNED" => {
                        Value::TinyUnsigned(get(row, i, "unsigned tiny integer"))
                    }
                    "SMALLINT UNSIGNED" => {
                        Value::SmallUnsigned(get(row, i, "unsigned small integer"))
                    }
                    "INT UNSIGNED" => Value::Unsigned(get(row, i, "unsigned integer")),
                    "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => {
                        Value::BigUnsigned(get(row, i, "unsigned big integer"))
                    }
                    "TINYINT" => Value::TinyInt(get(row, i, "tiny integer")),
                    "SMALLINT" => Value::SmallInt(get(row, i, "small integer")),
                    "INT" => Value::Int(get(row, i, "integer")),
                    "MEDIUMINT" | "BIGINT" => Value::BigInt(get(row, i, "big integer")),
                    "FLOAT" => Value::Float(get(row, i, "float")),
                    "DOUBLE" => Value::Double(get(row, i, "double")),

                    "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                    | "LONGBLOB" => Value::Bytes(get(row, i, "bytes").map(Box::new)),

                    "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                        Value::String(get(row, i, "string").map(Box::new))
                    }

                    #[cfg(feature = "with-chrono")]
                    "TIMESTAMP" => {
                        Value::ChronoDateTimeUtc(get(row, i, "timestamp").map(Box::new))
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIMESTAMP" => Value::TimeDateTime(get(row, i, "timestamp").map(Box::new)),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => Value::ChronoDate(get(row, i, "date").map(Box::new)),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => Value::TimeDate(get(row, i, "date").map(Box::new)),

                    #[cfg(feature = "with-chrono")]
                    "TIME" => Value::ChronoTime(get(row, i, "time").map(Box::new)),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => Value::TimeTime(get(row, i, "time").map(Box::new)),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTime(get(row, i, "datetime").map(Box::new)),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTime(get(row, i, "datetime").map(Box::new)),

                    // NOTE(review): YEAR is decoded as a calendar date here. Confirm sqlx can
                    // decode a MySQL YEAR column into a date type; if not, this arm panics on
                    // any non-NULL value.
                    #[cfg(feature = "with-chrono")]
                    "YEAR" => Value::ChronoDate(get(row, i, "year").map(Box::new)),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "YEAR" => Value::TimeDate(get(row, i, "year").map(Box::new)),

                    "ENUM" | "SET" | "GEOMETRY" => {
                        Value::String(get(row, i, "serialized string").map(Box::new))
                    }

                    #[cfg(feature = "with-bigdecimal")]
                    "DECIMAL" => Value::BigDecimal(get(row, i, "decimal").map(Box::new)),
                    #[cfg(all(
                        feature = "with-rust_decimal",
                        not(feature = "with-bigdecimal")
                    ))]
                    "DECIMAL" => Value::Decimal(get(row, i, "decimal").map(Box::new)),

                    #[cfg(feature = "with-json")]
                    "JSON" => Value::Json(get(row, i, "json").map(Box::new)),

                    _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                };
                (c.name().to_string(), value)
            })
            .collect(),
    }
}