1use futures::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow},
8 pool::PoolConnection,
9 Connection, Executor, MySql, MySqlPool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::instrument;
14
15use crate::{
16 debug_print, error::*, executor::*, AccessMode, ConnectOptions, DatabaseConnection,
17 DatabaseTransaction, DbBackend, IsolationLevel, QueryStream, Statement, TransactionError,
18};
19
20use super::sqlx_common::*;
21
/// Entry point for creating MySQL-backed [`DatabaseConnection`]s through sqlx.
///
/// A unit struct: it holds no state and only groups the associated
/// constructors (`accepts`, `connect`, `from_sqlx_mysql_pool`).
#[derive(Debug)]
pub struct SqlxMySqlConnector;
25
/// A sqlx MySQL connection pool together with an optional metric callback.
///
/// Statements run through this type acquire a connection from `pool` per call.
#[derive(Clone)]
pub struct SqlxMySqlPoolConnection {
    /// The underlying sqlx MySQL connection pool.
    pub(crate) pool: MySqlPool,
    /// Optional metric callback, installed via `set_metric_callback` and
    /// cloned into the streams and transactions created from this connection.
    metric_callback: Option<crate::metric::Callback>,
}
32
33impl std::fmt::Debug for SqlxMySqlPoolConnection {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "SqlxMySqlPoolConnection {{ pool: {:?} }}", self.pool)
36 }
37}
38
39impl SqlxMySqlConnector {
40 pub fn accepts(string: &str) -> bool {
42 string.starts_with("mysql://") && string.parse::<MySqlConnectOptions>().is_ok()
43 }
44
45 #[instrument(level = "trace")]
47 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
48 let mut opt = options
49 .url
50 .parse::<MySqlConnectOptions>()
51 .map_err(sqlx_error_to_conn_err)?;
52 use sqlx::ConnectOptions;
53 if !options.sqlx_logging {
54 opt = opt.disable_statement_logging();
55 } else {
56 opt = opt.log_statements(options.sqlx_logging_level);
57 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
58 opt = opt.log_slow_statements(
59 options.sqlx_slow_statements_logging_level,
60 options.sqlx_slow_statements_logging_threshold,
61 );
62 }
63 }
64 let pool = if options.connect_lazy {
65 options.sqlx_pool_options().connect_lazy_with(opt)
66 } else {
67 options
68 .sqlx_pool_options()
69 .connect_with(opt)
70 .await
71 .map_err(sqlx_error_to_conn_err)?
72 };
73 Ok(DatabaseConnection::SqlxMySqlPoolConnection(
74 SqlxMySqlPoolConnection {
75 pool,
76 metric_callback: None,
77 },
78 ))
79 }
80}
81
82impl SqlxMySqlConnector {
83 pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
85 DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection {
86 pool,
87 metric_callback: None,
88 })
89 }
90}
91
92impl SqlxMySqlPoolConnection {
93 #[instrument(level = "trace")]
95 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
96 debug_print!("{}", stmt);
97
98 let query = sqlx_query(&stmt);
99 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
100 crate::metric::metric!(self.metric_callback, &stmt, {
101 match query.execute(&mut *conn).await {
102 Ok(res) => Ok(res.into()),
103 Err(err) => Err(sqlx_error_to_exec_err(err)),
104 }
105 })
106 }
107
108 #[instrument(level = "trace")]
110 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
111 debug_print!("{}", sql);
112
113 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
114 match conn.execute(sql).await {
115 Ok(res) => Ok(res.into()),
116 Err(err) => Err(sqlx_error_to_exec_err(err)),
117 }
118 }
119
120 #[instrument(level = "trace")]
122 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
123 debug_print!("{}", stmt);
124
125 let query = sqlx_query(&stmt);
126 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
127 crate::metric::metric!(self.metric_callback, &stmt, {
128 match query.fetch_one(&mut *conn).await {
129 Ok(row) => Ok(Some(row.into())),
130 Err(err) => match err {
131 sqlx::Error::RowNotFound => Ok(None),
132 _ => Err(sqlx_error_to_query_err(err)),
133 },
134 }
135 })
136 }
137
138 #[instrument(level = "trace")]
140 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
141 debug_print!("{}", stmt);
142
143 let query = sqlx_query(&stmt);
144 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
145 crate::metric::metric!(self.metric_callback, &stmt, {
146 match query.fetch_all(&mut *conn).await {
147 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
148 Err(err) => Err(sqlx_error_to_query_err(err)),
149 }
150 })
151 }
152
153 #[instrument(level = "trace")]
155 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
156 debug_print!("{}", stmt);
157
158 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
159 Ok(QueryStream::from((
160 conn,
161 stmt,
162 self.metric_callback.clone(),
163 )))
164 }
165
166 #[instrument(level = "trace")]
168 pub async fn begin(
169 &self,
170 isolation_level: Option<IsolationLevel>,
171 access_mode: Option<AccessMode>,
172 ) -> Result<DatabaseTransaction, DbErr> {
173 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
174 DatabaseTransaction::new_mysql(
175 conn,
176 self.metric_callback.clone(),
177 isolation_level,
178 access_mode,
179 )
180 .await
181 }
182
183 #[instrument(level = "trace", skip(callback))]
185 pub async fn transaction<F, T, E>(
186 &self,
187 callback: F,
188 isolation_level: Option<IsolationLevel>,
189 access_mode: Option<AccessMode>,
190 ) -> Result<T, TransactionError<E>>
191 where
192 F: for<'b> FnOnce(
193 &'b DatabaseTransaction,
194 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
195 + Send,
196 T: Send,
197 E: std::error::Error + Send,
198 {
199 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
200 let transaction = DatabaseTransaction::new_mysql(
201 conn,
202 self.metric_callback.clone(),
203 isolation_level,
204 access_mode,
205 )
206 .await
207 .map_err(|e| TransactionError::Connection(e))?;
208 transaction.run(callback).await
209 }
210
211 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
212 where
213 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
214 {
215 self.metric_callback = Some(Arc::new(callback));
216 }
217
218 pub async fn ping(&self) -> Result<(), DbErr> {
220 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
221 match conn.ping().await {
222 Ok(_) => Ok(()),
223 Err(err) => Err(sqlx_error_to_conn_err(err)),
224 }
225 }
226
227 pub async fn close(self) -> Result<(), DbErr> {
229 self.pool.close().await;
230 Ok(())
231 }
232}
233
234impl From<MySqlRow> for QueryResult {
235 fn from(row: MySqlRow) -> QueryResult {
236 QueryResult {
237 row: QueryResultRow::SqlxMySql(row),
238 }
239 }
240}
241
242impl From<MySqlQueryResult> for ExecResult {
243 fn from(result: MySqlQueryResult) -> ExecResult {
244 ExecResult {
245 result: ExecResultHolder::SqlxMySql(result),
246 }
247 }
248}
249
250pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, SqlxValues> {
251 let values = stmt
252 .values
253 .as_ref()
254 .map_or(Values(Vec::new()), |values| values.clone());
255 sqlx::query_with(&stmt.sql, SqlxValues(values))
256}
257
258pub(crate) async fn set_transaction_config(
259 conn: &mut PoolConnection<MySql>,
260 isolation_level: Option<IsolationLevel>,
261 access_mode: Option<AccessMode>,
262) -> Result<(), DbErr> {
263 if let Some(isolation_level) = isolation_level {
264 let stmt = Statement {
265 sql: format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}"),
266 values: None,
267 db_backend: DbBackend::MySql,
268 };
269 let query = sqlx_query(&stmt);
270 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
271 }
272 if let Some(access_mode) = access_mode {
273 let stmt = Statement {
274 sql: format!("SET TRANSACTION {access_mode}"),
275 values: None,
276 db_backend: DbBackend::MySql,
277 };
278 let query = sqlx_query(&stmt);
279 conn.execute(query).await.map_err(sqlx_error_to_exec_err)?;
280 }
281 Ok(())
282}
283
284impl
285 From<(
286 PoolConnection<sqlx::MySql>,
287 Statement,
288 Option<crate::metric::Callback>,
289 )> for crate::QueryStream
290{
291 fn from(
292 (conn, stmt, metric_callback): (
293 PoolConnection<sqlx::MySql>,
294 Statement,
295 Option<crate::metric::Callback>,
296 ),
297 ) -> Self {
298 crate::QueryStream::build(stmt, crate::InnerConnection::MySql(conn), metric_callback)
299 }
300}
301
302impl crate::DatabaseTransaction {
303 pub(crate) async fn new_mysql(
304 inner: PoolConnection<sqlx::MySql>,
305 metric_callback: Option<crate::metric::Callback>,
306 isolation_level: Option<IsolationLevel>,
307 access_mode: Option<AccessMode>,
308 ) -> Result<crate::DatabaseTransaction, DbErr> {
309 Self::begin(
310 Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))),
311 crate::DbBackend::MySql,
312 metric_callback,
313 isolation_level,
314 access_mode,
315 )
316 .await
317 }
318}
319
/// Converts a sqlx MySQL row into a [`crate::ProxyRow`], mapping each column
/// name to a [`sea_query::Value`].
///
/// The `Value` variant is chosen from the MySQL type name sqlx reports for the
/// column. SQL `NULL` is preserved as that variant's `None` payload (for
/// example `Value::Int(None)`), rather than failing to decode.
///
/// # Panics
///
/// Panics in either of these cases:
///
/// - A non-NULL cell cannot be decoded as the Rust type implied by its MySQL
///   type.
/// - The column type is not handled here. This includes types whose support
///   depends on a disabled cargo feature, such as `DECIMAL` when neither
///   `with-bigdecimal` nor `with-rust_decimal` is enabled, or `JSON` without
///   `with-json`.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_mysql_row_to_proxy_row(row: &sqlx::mysql::MySqlRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    // Boxes the payload for the `Value` variants that store `Option<Box<T>>`.
    // As a plain function call it also lets `try_get` infer `Option<T>`.
    fn boxed<T>(value: Option<T>) -> Option<Box<T>> {
        value.map(Box::new)
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let i = c.ordinal();
                // Every `try_get` below decodes into `Option<_>` (inferred from
                // the `Value` variant). A NULL cell therefore becomes `None`
                // instead of a decode error that would trip the `expect`.
                let value = match c.type_info().name() {
                    "TINYINT(1)" | "BOOLEAN" => {
                        Value::Bool(row.try_get(i).expect("Failed to get boolean"))
                    }
                    "TINYINT UNSIGNED" => Value::TinyUnsigned(
                        row.try_get(i)
                            .expect("Failed to get unsigned tiny integer"),
                    ),
                    "SMALLINT UNSIGNED" => Value::SmallUnsigned(
                        row.try_get(i)
                            .expect("Failed to get unsigned small integer"),
                    ),
                    "INT UNSIGNED" => {
                        Value::Unsigned(row.try_get(i).expect("Failed to get unsigned integer"))
                    }
                    "MEDIUMINT UNSIGNED" | "BIGINT UNSIGNED" => Value::BigUnsigned(
                        row.try_get(i)
                            .expect("Failed to get unsigned big integer"),
                    ),
                    "TINYINT" => {
                        Value::TinyInt(row.try_get(i).expect("Failed to get tiny integer"))
                    }
                    "SMALLINT" => {
                        Value::SmallInt(row.try_get(i).expect("Failed to get small integer"))
                    }
                    "INT" => Value::Int(row.try_get(i).expect("Failed to get integer")),
                    "MEDIUMINT" | "BIGINT" => {
                        Value::BigInt(row.try_get(i).expect("Failed to get big integer"))
                    }
                    "FLOAT" => Value::Float(row.try_get(i).expect("Failed to get float")),
                    "DOUBLE" => Value::Double(row.try_get(i).expect("Failed to get double")),

                    "BIT" | "BINARY" | "VARBINARY" | "TINYBLOB" | "BLOB" | "MEDIUMBLOB"
                    | "LONGBLOB" => {
                        Value::Bytes(boxed(row.try_get(i).expect("Failed to get bytes")))
                    }

                    "CHAR" | "VARCHAR" | "TINYTEXT" | "TEXT" | "MEDIUMTEXT" | "LONGTEXT" => {
                        Value::String(boxed(row.try_get(i).expect("Failed to get string")))
                    }

                    #[cfg(feature = "with-chrono")]
                    "TIMESTAMP" => Value::ChronoDateTimeUtc(boxed(
                        row.try_get(i).expect("Failed to get timestamp"),
                    )),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIMESTAMP" => Value::TimeDateTime(boxed(
                        row.try_get(i).expect("Failed to get timestamp"),
                    )),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => {
                        Value::ChronoDate(boxed(row.try_get(i).expect("Failed to get date")))
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => Value::TimeDate(boxed(row.try_get(i).expect("Failed to get date"))),

                    #[cfg(feature = "with-chrono")]
                    "TIME" => {
                        Value::ChronoTime(boxed(row.try_get(i).expect("Failed to get time")))
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => Value::TimeTime(boxed(row.try_get(i).expect("Failed to get time"))),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTime(boxed(
                        row.try_get(i).expect("Failed to get datetime"),
                    )),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTime(boxed(
                        row.try_get(i).expect("Failed to get datetime"),
                    )),

                    // NOTE(review): YEAR is decoded as a date type here; confirm
                    // that sqlx accepts MySQL YEAR for these target types.
                    #[cfg(feature = "with-chrono")]
                    "YEAR" => {
                        Value::ChronoDate(boxed(row.try_get(i).expect("Failed to get year")))
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "YEAR" => Value::TimeDate(boxed(row.try_get(i).expect("Failed to get year"))),

                    "ENUM" | "SET" | "GEOMETRY" => Value::String(boxed(
                        row.try_get(i)
                            .expect("Failed to get serialized string"),
                    )),

                    #[cfg(feature = "with-bigdecimal")]
                    "DECIMAL" => {
                        Value::BigDecimal(boxed(row.try_get(i).expect("Failed to get decimal")))
                    }
                    #[cfg(all(
                        feature = "with-rust_decimal",
                        not(feature = "with-bigdecimal")
                    ))]
                    "DECIMAL" => {
                        Value::Decimal(boxed(row.try_get(i).expect("Failed to get decimal")))
                    }

                    #[cfg(feature = "with-json")]
                    "JSON" => Value::Json(boxed(row.try_get(i).expect("Failed to get json"))),

                    // Reachable: unsupported MySQL types, and types whose arm is
                    // compiled out because its cargo feature is disabled.
                    other => unimplemented!("Unknown column type: {}", other),
                };
                (c.name().to_string(), value)
            })
            .collect(),
    }
}