1mod conversion;
2mod error;
3
4use crate::{
5 ast::{Query, Value},
6 connector::{metrics, queryable::*, ResultSet},
7 error::{Error, ErrorKind},
8 visitor::{self, Visitor},
9};
10use async_trait::async_trait;
11use lru_cache::LruCache;
12use mysql_async::{
13 self as my,
14 prelude::{Query as _, Queryable as _},
15};
16use percent_encoding::percent_decode;
17use std::{
18 borrow::Cow,
19 future::Future,
20 path::{Path, PathBuf},
21 sync::atomic::{AtomicBool, Ordering},
22 time::Duration,
23};
24use tokio::sync::Mutex;
25use url::Url;
26
27#[cfg(feature = "expose-drivers")]
30pub use mysql_async;
31
32use super::IsolationLevel;
33
34#[derive(Debug)]
36#[cfg_attr(feature = "docs", doc(cfg(feature = "mysql")))]
37pub struct Mysql {
38 pub(crate) conn: Mutex<my::Conn>,
39 pub(crate) url: MysqlUrl,
40 socket_timeout: Option<Duration>,
41 is_healthy: AtomicBool,
42 statement_cache: Mutex<LruCache<String, my::Statement>>,
43}
44
45#[derive(Debug, Clone)]
47#[cfg_attr(feature = "docs", doc(cfg(feature = "mysql")))]
48pub struct MysqlUrl {
49 url: Url,
50 query_params: MysqlUrlQueryParams,
51}
52
53impl MysqlUrl {
54 pub fn new(url: Url) -> Result<Self, Error> {
57 let query_params = Self::parse_query_params(&url)?;
58
59 Ok(Self { url, query_params })
60 }
61
62 pub fn url(&self) -> &Url {
64 &self.url
65 }
66
67 pub fn username(&self) -> Cow<str> {
69 match percent_decode(self.url.username().as_bytes()).decode_utf8() {
70 Ok(username) => username,
71 Err(_) => {
72 tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version.");
73
74 self.url.username().into()
75 }
76 }
77 }
78
79 pub fn password(&self) -> Option<Cow<str>> {
81 match self.url.password().and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) {
82 Some(password) => Some(password),
83 None => self.url.password().map(|s| s.into()),
84 }
85 }
86
87 pub fn dbname(&self) -> &str {
89 match self.url.path_segments() {
90 Some(mut segments) => segments.next().unwrap_or("mysql"),
91 None => "mysql",
92 }
93 }
94
95 pub fn host(&self) -> &str {
97 self.url.host_str().unwrap_or("localhost")
98 }
99
100 pub fn socket(&self) -> &Option<String> {
102 &self.query_params.socket
103 }
104
105 pub fn port(&self) -> u16 {
107 self.url.port().unwrap_or(3306)
108 }
109
110 pub fn connect_timeout(&self) -> Option<Duration> {
112 self.query_params.connect_timeout
113 }
114
115 pub fn pool_timeout(&self) -> Option<Duration> {
117 self.query_params.pool_timeout
118 }
119
120 pub fn socket_timeout(&self) -> Option<Duration> {
122 self.query_params.socket_timeout
123 }
124
125 pub fn prefer_socket(&self) -> Option<bool> {
127 self.query_params.prefer_socket
128 }
129
130 pub fn max_connection_lifetime(&self) -> Option<Duration> {
132 self.query_params.max_connection_lifetime
133 }
134
135 pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
137 self.query_params.max_idle_connection_lifetime
138 }
139
140 fn statement_cache_size(&self) -> usize {
141 self.query_params.statement_cache_size
142 }
143
144 pub(crate) fn cache(&self) -> LruCache<String, my::Statement> {
145 LruCache::new(self.query_params.statement_cache_size)
146 }
147
148 fn parse_query_params(url: &Url) -> Result<MysqlUrlQueryParams, Error> {
149 let mut ssl_opts = my::SslOpts::default();
150 ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true);
151
152 let mut connection_limit = None;
153 let mut use_ssl = false;
154 let mut socket = None;
155 let mut socket_timeout = None;
156 let mut connect_timeout = Some(Duration::from_secs(5));
157 let mut pool_timeout = Some(Duration::from_secs(10));
158 let mut max_connection_lifetime = None;
159 let mut max_idle_connection_lifetime = Some(Duration::from_secs(300));
160 let mut prefer_socket = None;
161 let mut statement_cache_size = 100;
162 let mut identity: Option<(Option<PathBuf>, Option<String>)> = None;
163
164 for (k, v) in url.query_pairs() {
165 match k.as_ref() {
166 "connection_limit" => {
167 let as_int: usize =
168 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
169
170 connection_limit = Some(as_int);
171 }
172 "statement_cache_size" => {
173 statement_cache_size =
174 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
175 }
176 "sslcert" => {
177 use_ssl = true;
178 ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf()));
179 }
180 "sslidentity" => {
181 use_ssl = true;
182
183 identity = match identity {
184 Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)),
185 None => Some((Some(Path::new(&*v).to_path_buf()), None)),
186 };
187 }
188 "sslpassword" => {
189 use_ssl = true;
190
191 identity = match identity {
192 Some((path, _)) => Some((path, Some(v.to_string()))),
193 None => Some((None, Some(v.to_string()))),
194 };
195 }
196 "socket" => {
197 socket = Some(v.replace(['(', ')'], ""));
198 }
199 "socket_timeout" => {
200 let as_int =
201 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
202 socket_timeout = Some(Duration::from_secs(as_int));
203 }
204 "prefer_socket" => {
205 let as_bool =
206 v.parse::<bool>().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
207 prefer_socket = Some(as_bool)
208 }
209 "connect_timeout" => {
210 let as_int =
211 v.parse::<u64>().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
212
213 connect_timeout = match as_int {
214 0 => None,
215 _ => Some(Duration::from_secs(as_int)),
216 };
217 }
218 "pool_timeout" => {
219 let as_int =
220 v.parse::<u64>().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
221
222 pool_timeout = match as_int {
223 0 => None,
224 _ => Some(Duration::from_secs(as_int)),
225 };
226 }
227 "sslaccept" => {
228 use_ssl = true;
229 match v.as_ref() {
230 "strict" => {
231 ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false);
232 }
233 "accept_invalid_certs" => {}
234 _ => {
235 tracing::debug!(
236 message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`",
237 mode = &*v
238 );
239 }
240 };
241 }
242 "max_connection_lifetime" => {
243 let as_int =
244 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
245
246 if as_int == 0 {
247 max_connection_lifetime = None;
248 } else {
249 max_connection_lifetime = Some(Duration::from_secs(as_int));
250 }
251 }
252 "max_idle_connection_lifetime" => {
253 let as_int =
254 v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
255
256 if as_int == 0 {
257 max_idle_connection_lifetime = None;
258 } else {
259 max_idle_connection_lifetime = Some(Duration::from_secs(as_int));
260 }
261 }
262 _ => {
263 tracing::trace!(message = "Discarding connection string param", param = &*k);
264 }
265 };
266 }
267
268 ssl_opts = match identity {
269 Some((Some(path), Some(pw))) => {
270 let identity = mysql_async::ClientIdentity::new(path).with_password(pw);
271 ssl_opts.with_client_identity(Some(identity))
272 }
273 Some((Some(path), None)) => {
274 let identity = mysql_async::ClientIdentity::new(path);
275 ssl_opts.with_client_identity(Some(identity))
276 }
277 _ => ssl_opts,
278 };
279
280 Ok(MysqlUrlQueryParams {
281 ssl_opts,
282 connection_limit,
283 use_ssl,
284 socket,
285 socket_timeout,
286 connect_timeout,
287 pool_timeout,
288 max_connection_lifetime,
289 max_idle_connection_lifetime,
290 prefer_socket,
291 statement_cache_size,
292 })
293 }
294
295 #[cfg(feature = "pooled")]
296 pub(crate) fn connection_limit(&self) -> Option<usize> {
297 self.query_params.connection_limit
298 }
299
300 pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder {
301 let mut config = my::OptsBuilder::default()
302 .stmt_cache_size(Some(0))
303 .user(Some(self.username()))
304 .pass(self.password())
305 .db_name(Some(self.dbname()));
306
307 match self.socket() {
308 Some(ref socket) => {
309 config = config.socket(Some(socket));
310 }
311 None => {
312 config = config.ip_or_hostname(self.host()).tcp_port(self.port());
313 }
314 }
315
316 config = config.conn_ttl(Some(Duration::from_secs(5)));
317
318 if self.query_params.use_ssl {
319 config = config.ssl_opts(Some(self.query_params.ssl_opts.clone()));
320 }
321
322 if self.query_params.prefer_socket.is_some() {
323 config = config.prefer_socket(self.query_params.prefer_socket);
324 }
325
326 config
327 }
328}
329
330#[derive(Debug, Clone)]
331pub(crate) struct MysqlUrlQueryParams {
332 ssl_opts: my::SslOpts,
333 connection_limit: Option<usize>,
334 use_ssl: bool,
335 socket: Option<String>,
336 socket_timeout: Option<Duration>,
337 connect_timeout: Option<Duration>,
338 pool_timeout: Option<Duration>,
339 max_connection_lifetime: Option<Duration>,
340 max_idle_connection_lifetime: Option<Duration>,
341 prefer_socket: Option<bool>,
342 statement_cache_size: usize,
343}
344
345impl Mysql {
346 pub async fn new(url: MysqlUrl) -> crate::Result<Self> {
348 let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?;
349
350 Ok(Self {
351 socket_timeout: url.query_params.socket_timeout,
352 conn: Mutex::new(conn),
353 statement_cache: Mutex::new(url.cache()),
354 url,
355 is_healthy: AtomicBool::new(true),
356 })
357 }
358
359 #[cfg(feature = "expose-drivers")]
363 pub fn conn(&self) -> &Mutex<mysql_async::Conn> {
364 &self.conn
365 }
366
367 async fn perform_io<F, U, T>(&self, op: U) -> crate::Result<T>
368 where
369 F: Future<Output = crate::Result<T>>,
370 U: FnOnce() -> F,
371 {
372 match super::timeout::socket(self.socket_timeout, op()).await {
373 Err(e) if e.is_closed() => {
374 self.is_healthy.store(false, Ordering::SeqCst);
375 Err(e)
376 }
377 res => Ok(res?),
378 }
379 }
380
381 async fn prepared<F, U, T>(&self, sql: &str, op: U) -> crate::Result<T>
382 where
383 F: Future<Output = crate::Result<T>>,
384 U: Fn(my::Statement) -> F,
385 {
386 if self.url.statement_cache_size() == 0 {
387 self.perform_io(|| async move {
388 let stmt = {
389 let mut conn = self.conn.lock().await;
390 conn.prep(sql).await?
391 };
392
393 let res = op(stmt.clone()).await;
394
395 {
396 let mut conn = self.conn.lock().await;
397 conn.close(stmt).await?;
398 }
399
400 res
401 })
402 .await
403 } else {
404 self.perform_io(|| async move {
405 let stmt = self.fetch_cached(sql).await?;
406 op(stmt).await
407 })
408 .await
409 }
410 }
411
412 async fn fetch_cached(&self, sql: &str) -> crate::Result<my::Statement> {
413 let mut cache = self.statement_cache.lock().await;
414 let capacity = cache.capacity();
415 let stored = cache.len();
416
417 match cache.get_mut(sql) {
418 Some(stmt) => {
419 tracing::trace!(message = "CACHE HIT!", query = sql, capacity = capacity, stored = stored,);
420
421 Ok(stmt.clone()) }
423 None => {
424 tracing::trace!(message = "CACHE MISS!", query = sql, capacity = capacity, stored = stored,);
425
426 let mut conn = self.conn.lock().await;
427 if cache.capacity() == cache.len() {
428 if let Some((_, stmt)) = cache.remove_lru() {
429 conn.close(stmt).await?;
430 }
431 }
432
433 let stmt = conn.prep(sql).await?;
434 cache.insert(sql.to_string(), stmt.clone());
435
436 Ok(stmt)
437 }
438 }
439 }
440}
441
442impl TransactionCapable for Mysql {}
443
444#[async_trait]
445impl Queryable for Mysql {
446 async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
447 let (sql, params) = visitor::Mysql::build(q)?;
448 self.query_raw(&sql, ¶ms).await
449 }
450
451 async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
452 metrics::query("mysql.query_raw", sql, params, move || async move {
453 self.prepared(sql, |stmt| async move {
454 let mut conn = self.conn.lock().await;
455 let rows: Vec<my::Row> = conn.exec(&stmt, conversion::conv_params(params)?).await?;
456 let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect();
457
458 let last_id = conn.last_insert_id();
459 let mut result_set = ResultSet::new(columns, Vec::new());
460
461 for mut row in rows {
462 result_set.rows.push(row.take_result_row()?);
463 }
464
465 if let Some(id) = last_id {
466 result_set.set_last_insert_id(id);
467 };
468
469 Ok(result_set)
470 })
471 .await
472 })
473 .await
474 }
475
476 async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
477 self.query_raw(sql, params).await
478 }
479
480 async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
481 let (sql, params) = visitor::Mysql::build(q)?;
482 self.execute_raw(&sql, ¶ms).await
483 }
484
485 async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
486 metrics::query("mysql.execute_raw", sql, params, move || async move {
487 self.prepared(sql, |stmt| async move {
488 let mut conn = self.conn.lock().await;
489 conn.exec_drop(stmt, conversion::conv_params(params)?).await?;
490
491 Ok(conn.affected_rows())
492 })
493 .await
494 })
495 .await
496 }
497
498 async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
499 self.execute_raw(sql, params).await
500 }
501
502 async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
503 metrics::query("mysql.raw_cmd", cmd, &[], move || async move {
504 self.perform_io(|| async move {
505 let mut conn = self.conn.lock().await;
506 let mut result = cmd.run(&mut *conn).await?;
507
508 loop {
509 result.map(drop).await?;
510
511 if result.is_empty() {
512 result.map(drop).await?;
513 break;
514 }
515 }
516
517 Ok(())
518 })
519 .await
520 })
521 .await
522 }
523
524 async fn version(&self) -> crate::Result<Option<String>> {
525 let query = r#"SELECT @@GLOBAL.version version"#;
526 let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?;
527
528 let version_string = rows.get(0).and_then(|row| row.get("version").and_then(|version| version.to_string()));
529
530 Ok(version_string)
531 }
532
533 fn is_healthy(&self) -> bool {
534 self.is_healthy.load(Ordering::SeqCst)
535 }
536
537 async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
538 if matches!(isolation_level, IsolationLevel::Snapshot) {
539 return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build());
540 }
541
542 self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")).await?;
543
544 Ok(())
545 }
546
547 fn requires_isolation_first(&self) -> bool {
548 true
549 }
550}
551
552#[cfg(test)]
553mod tests {
554 use super::MysqlUrl;
555 use crate::tests::test_api::mysql::CONN_STR;
556 use crate::{error::*, single::Sqlint};
557 use url::Url;
558
559 #[test]
560 fn should_parse_socket_url() {
561 let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap();
562 assert_eq!("dbname", url.dbname());
563 assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket());
564 }
565
566 #[test]
567 fn should_parse_prefer_socket() {
568 let url =
569 MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap();
570 assert_eq!(false, url.prefer_socket().unwrap());
571 }
572
573 #[test]
574 fn should_parse_sslaccept() {
575 let url =
576 MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap();
577 assert_eq!(true, url.query_params.use_ssl);
578 assert_eq!(false, url.query_params.ssl_opts.skip_domain_validation());
579 assert_eq!(false, url.query_params.ssl_opts.accept_invalid_certs());
580 }
581
582 #[test]
583 fn should_allow_changing_of_cache_size() {
584 let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap())
585 .unwrap();
586 assert_eq!(420, url.cache().capacity());
587 }
588
589 #[test]
590 fn should_have_default_cache_size() {
591 let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap();
592 assert_eq!(100, url.cache().capacity());
593 }
594
595 #[tokio::test]
596 async fn should_map_nonexisting_database_error() {
597 let mut url = Url::parse(&CONN_STR).unwrap();
598 url.set_username("root").unwrap();
599 url.set_path("/this_does_not_exist");
600
601 let url = url.as_str().to_string();
602 let res = Sqlint::new(&url).await;
603
604 let err = res.unwrap_err();
605
606 match err.kind() {
607 ErrorKind::DatabaseDoesNotExist { db_name } => {
608 assert_eq!(Some("1049"), err.original_code());
609 assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message());
610 assert_eq!(&Name::available("this_does_not_exist"), db_name)
611 }
612 e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e),
613 }
614 }
615
616 #[tokio::test]
617 async fn should_map_wrong_credentials_error() {
618 let mut url = Url::parse(&CONN_STR).unwrap();
619 url.set_username("WRONG").unwrap();
620
621 let res = Sqlint::new(url.as_str()).await;
622 assert!(res.is_err());
623
624 let err = res.unwrap_err();
625 assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG")));
626 }
627}