sqlint/connector/
mysql.rs

1mod conversion;
2mod error;
3
4use crate::{
5    ast::{Query, Value},
6    connector::{metrics, queryable::*, ResultSet},
7    error::{Error, ErrorKind},
8    visitor::{self, Visitor},
9};
10use async_trait::async_trait;
11use lru_cache::LruCache;
12use mysql_async::{
13    self as my,
14    prelude::{Query as _, Queryable as _},
15};
16use percent_encoding::percent_decode;
17use std::{
18    borrow::Cow,
19    future::Future,
20    path::{Path, PathBuf},
21    sync::atomic::{AtomicBool, Ordering},
22    time::Duration,
23};
24use tokio::sync::Mutex;
25use url::Url;
26
27/// The underlying MySQL driver. Only available with the `expose-drivers`
28/// Cargo feature.
29#[cfg(feature = "expose-drivers")]
30pub use mysql_async;
31
32use super::IsolationLevel;
33
34/// A connector interface for the MySQL database.
35#[derive(Debug)]
36#[cfg_attr(feature = "docs", doc(cfg(feature = "mysql")))]
37pub struct Mysql {
38    pub(crate) conn: Mutex<my::Conn>,
39    pub(crate) url: MysqlUrl,
40    socket_timeout: Option<Duration>,
41    is_healthy: AtomicBool,
42    statement_cache: Mutex<LruCache<String, my::Statement>>,
43}
44
45/// Wraps a connection url and exposes the parsing logic used by sqlint, including default values.
46#[derive(Debug, Clone)]
47#[cfg_attr(feature = "docs", doc(cfg(feature = "mysql")))]
48pub struct MysqlUrl {
49    url: Url,
50    query_params: MysqlUrlQueryParams,
51}
52
53impl MysqlUrl {
54    /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection
55    /// parameters.
56    pub fn new(url: Url) -> Result<Self, Error> {
57        let query_params = Self::parse_query_params(&url)?;
58
59        Ok(Self { url, query_params })
60    }
61
62    /// The bare `Url` to the database.
63    pub fn url(&self) -> &Url {
64        &self.url
65    }
66
67    /// The percent-decoded database username.
68    pub fn username(&self) -> Cow<str> {
69        match percent_decode(self.url.username().as_bytes()).decode_utf8() {
70            Ok(username) => username,
71            Err(_) => {
72                tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version.");
73
74                self.url.username().into()
75            }
76        }
77    }
78
79    /// The percent-decoded database password.
80    pub fn password(&self) -> Option<Cow<str>> {
81        match self.url.password().and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) {
82            Some(password) => Some(password),
83            None => self.url.password().map(|s| s.into()),
84        }
85    }
86
87    /// Name of the database connected. Defaults to `mysql`.
88    pub fn dbname(&self) -> &str {
89        match self.url.path_segments() {
90            Some(mut segments) => segments.next().unwrap_or("mysql"),
91            None => "mysql",
92        }
93    }
94
95    /// The database host. If `socket` and `host` are not set, defaults to `localhost`.
96    pub fn host(&self) -> &str {
97        self.url.host_str().unwrap_or("localhost")
98    }
99
100    /// If set, connected to the database through a Unix socket.
101    pub fn socket(&self) -> &Option<String> {
102        &self.query_params.socket
103    }
104
105    /// The database port, defaults to `3306`.
106    pub fn port(&self) -> u16 {
107        self.url.port().unwrap_or(3306)
108    }
109
110    /// The connection timeout.
111    pub fn connect_timeout(&self) -> Option<Duration> {
112        self.query_params.connect_timeout
113    }
114
115    /// The pool check_out timeout
116    pub fn pool_timeout(&self) -> Option<Duration> {
117        self.query_params.pool_timeout
118    }
119
120    /// The socket timeout
121    pub fn socket_timeout(&self) -> Option<Duration> {
122        self.query_params.socket_timeout
123    }
124
125    /// Prefer socket connection
126    pub fn prefer_socket(&self) -> Option<bool> {
127        self.query_params.prefer_socket
128    }
129
130    /// The maximum connection lifetime
131    pub fn max_connection_lifetime(&self) -> Option<Duration> {
132        self.query_params.max_connection_lifetime
133    }
134
135    /// The maximum idle connection lifetime
136    pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
137        self.query_params.max_idle_connection_lifetime
138    }
139
140    fn statement_cache_size(&self) -> usize {
141        self.query_params.statement_cache_size
142    }
143
144    pub(crate) fn cache(&self) -> LruCache<String, my::Statement> {
145        LruCache::new(self.query_params.statement_cache_size)
146    }
147
148    fn parse_query_params(url: &Url) -> Result<MysqlUrlQueryParams, Error> {
149        let mut ssl_opts = my::SslOpts::default();
150        ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true);
151
152        let mut connection_limit = None;
153        let mut use_ssl = false;
154        let mut socket = None;
155        let mut socket_timeout = None;
156        let mut connect_timeout = Some(Duration::from_secs(5));
157        let mut pool_timeout = Some(Duration::from_secs(10));
158        let mut max_connection_lifetime = None;
159        let mut max_idle_connection_lifetime = Some(Duration::from_secs(300));
160        let mut prefer_socket = None;
161        let mut statement_cache_size = 100;
162        let mut identity: Option<(Option<PathBuf>, Option<String>)> = None;
163
164        for (k, v) in url.query_pairs() {
165            match k.as_ref() {
166                "connection_limit" => {
167                    let as_int: usize =
168                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
169
170                    connection_limit = Some(as_int);
171                }
172                "statement_cache_size" => {
173                    statement_cache_size =
174                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
175                }
176                "sslcert" => {
177                    use_ssl = true;
178                    ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf()));
179                }
180                "sslidentity" => {
181                    use_ssl = true;
182
183                    identity = match identity {
184                        Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)),
185                        None => Some((Some(Path::new(&*v).to_path_buf()), None)),
186                    };
187                }
188                "sslpassword" => {
189                    use_ssl = true;
190
191                    identity = match identity {
192                        Some((path, _)) => Some((path, Some(v.to_string()))),
193                        None => Some((None, Some(v.to_string()))),
194                    };
195                }
196                "socket" => {
197                    socket = Some(v.replace(['(', ')'], ""));
198                }
199                "socket_timeout" => {
200                    let as_int =
201                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
202                    socket_timeout = Some(Duration::from_secs(as_int));
203                }
204                "prefer_socket" => {
205                    let as_bool =
206                        v.parse::<bool>().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
207                    prefer_socket = Some(as_bool)
208                }
209                "connect_timeout" => {
210                    let as_int =
211                        v.parse::<u64>().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
212
213                    connect_timeout = match as_int {
214                        0 => None,
215                        _ => Some(Duration::from_secs(as_int)),
216                    };
217                }
218                "pool_timeout" => {
219                    let as_int =
220                        v.parse::<u64>().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
221
222                    pool_timeout = match as_int {
223                        0 => None,
224                        _ => Some(Duration::from_secs(as_int)),
225                    };
226                }
227                "sslaccept" => {
228                    use_ssl = true;
229                    match v.as_ref() {
230                        "strict" => {
231                            ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false);
232                        }
233                        "accept_invalid_certs" => {}
234                        _ => {
235                            tracing::debug!(
236                                message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`",
237                                mode = &*v
238                            );
239                        }
240                    };
241                }
242                "max_connection_lifetime" => {
243                    let as_int =
244                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
245
246                    if as_int == 0 {
247                        max_connection_lifetime = None;
248                    } else {
249                        max_connection_lifetime = Some(Duration::from_secs(as_int));
250                    }
251                }
252                "max_idle_connection_lifetime" => {
253                    let as_int =
254                        v.parse().map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
255
256                    if as_int == 0 {
257                        max_idle_connection_lifetime = None;
258                    } else {
259                        max_idle_connection_lifetime = Some(Duration::from_secs(as_int));
260                    }
261                }
262                _ => {
263                    tracing::trace!(message = "Discarding connection string param", param = &*k);
264                }
265            };
266        }
267
268        ssl_opts = match identity {
269            Some((Some(path), Some(pw))) => {
270                let identity = mysql_async::ClientIdentity::new(path).with_password(pw);
271                ssl_opts.with_client_identity(Some(identity))
272            }
273            Some((Some(path), None)) => {
274                let identity = mysql_async::ClientIdentity::new(path);
275                ssl_opts.with_client_identity(Some(identity))
276            }
277            _ => ssl_opts,
278        };
279
280        Ok(MysqlUrlQueryParams {
281            ssl_opts,
282            connection_limit,
283            use_ssl,
284            socket,
285            socket_timeout,
286            connect_timeout,
287            pool_timeout,
288            max_connection_lifetime,
289            max_idle_connection_lifetime,
290            prefer_socket,
291            statement_cache_size,
292        })
293    }
294
295    #[cfg(feature = "pooled")]
296    pub(crate) fn connection_limit(&self) -> Option<usize> {
297        self.query_params.connection_limit
298    }
299
300    pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder {
301        let mut config = my::OptsBuilder::default()
302            .stmt_cache_size(Some(0))
303            .user(Some(self.username()))
304            .pass(self.password())
305            .db_name(Some(self.dbname()));
306
307        match self.socket() {
308            Some(ref socket) => {
309                config = config.socket(Some(socket));
310            }
311            None => {
312                config = config.ip_or_hostname(self.host()).tcp_port(self.port());
313            }
314        }
315
316        config = config.conn_ttl(Some(Duration::from_secs(5)));
317
318        if self.query_params.use_ssl {
319            config = config.ssl_opts(Some(self.query_params.ssl_opts.clone()));
320        }
321
322        if self.query_params.prefer_socket.is_some() {
323            config = config.prefer_socket(self.query_params.prefer_socket);
324        }
325
326        config
327    }
328}
329
330#[derive(Debug, Clone)]
331pub(crate) struct MysqlUrlQueryParams {
332    ssl_opts: my::SslOpts,
333    connection_limit: Option<usize>,
334    use_ssl: bool,
335    socket: Option<String>,
336    socket_timeout: Option<Duration>,
337    connect_timeout: Option<Duration>,
338    pool_timeout: Option<Duration>,
339    max_connection_lifetime: Option<Duration>,
340    max_idle_connection_lifetime: Option<Duration>,
341    prefer_socket: Option<bool>,
342    statement_cache_size: usize,
343}
344
345impl Mysql {
346    /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate.
347    pub async fn new(url: MysqlUrl) -> crate::Result<Self> {
348        let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?;
349
350        Ok(Self {
351            socket_timeout: url.query_params.socket_timeout,
352            conn: Mutex::new(conn),
353            statement_cache: Mutex::new(url.cache()),
354            url,
355            is_healthy: AtomicBool::new(true),
356        })
357    }
358
359    /// The underlying mysql_async::Conn. Only available with the
360    /// `expose-drivers` Cargo feature. This is a lower level API when you need
361    /// to get into database specific features.
362    #[cfg(feature = "expose-drivers")]
363    pub fn conn(&self) -> &Mutex<mysql_async::Conn> {
364        &self.conn
365    }
366
367    async fn perform_io<F, U, T>(&self, op: U) -> crate::Result<T>
368    where
369        F: Future<Output = crate::Result<T>>,
370        U: FnOnce() -> F,
371    {
372        match super::timeout::socket(self.socket_timeout, op()).await {
373            Err(e) if e.is_closed() => {
374                self.is_healthy.store(false, Ordering::SeqCst);
375                Err(e)
376            }
377            res => Ok(res?),
378        }
379    }
380
381    async fn prepared<F, U, T>(&self, sql: &str, op: U) -> crate::Result<T>
382    where
383        F: Future<Output = crate::Result<T>>,
384        U: Fn(my::Statement) -> F,
385    {
386        if self.url.statement_cache_size() == 0 {
387            self.perform_io(|| async move {
388                let stmt = {
389                    let mut conn = self.conn.lock().await;
390                    conn.prep(sql).await?
391                };
392
393                let res = op(stmt.clone()).await;
394
395                {
396                    let mut conn = self.conn.lock().await;
397                    conn.close(stmt).await?;
398                }
399
400                res
401            })
402            .await
403        } else {
404            self.perform_io(|| async move {
405                let stmt = self.fetch_cached(sql).await?;
406                op(stmt).await
407            })
408            .await
409        }
410    }
411
412    async fn fetch_cached(&self, sql: &str) -> crate::Result<my::Statement> {
413        let mut cache = self.statement_cache.lock().await;
414        let capacity = cache.capacity();
415        let stored = cache.len();
416
417        match cache.get_mut(sql) {
418            Some(stmt) => {
419                tracing::trace!(message = "CACHE HIT!", query = sql, capacity = capacity, stored = stored,);
420
421                Ok(stmt.clone()) // arc'd
422            }
423            None => {
424                tracing::trace!(message = "CACHE MISS!", query = sql, capacity = capacity, stored = stored,);
425
426                let mut conn = self.conn.lock().await;
427                if cache.capacity() == cache.len() {
428                    if let Some((_, stmt)) = cache.remove_lru() {
429                        conn.close(stmt).await?;
430                    }
431                }
432
433                let stmt = conn.prep(sql).await?;
434                cache.insert(sql.to_string(), stmt.clone());
435
436                Ok(stmt)
437            }
438        }
439    }
440}
441
442impl TransactionCapable for Mysql {}
443
444#[async_trait]
445impl Queryable for Mysql {
446    async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
447        let (sql, params) = visitor::Mysql::build(q)?;
448        self.query_raw(&sql, &params).await
449    }
450
451    async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
452        metrics::query("mysql.query_raw", sql, params, move || async move {
453            self.prepared(sql, |stmt| async move {
454                let mut conn = self.conn.lock().await;
455                let rows: Vec<my::Row> = conn.exec(&stmt, conversion::conv_params(params)?).await?;
456                let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect();
457
458                let last_id = conn.last_insert_id();
459                let mut result_set = ResultSet::new(columns, Vec::new());
460
461                for mut row in rows {
462                    result_set.rows.push(row.take_result_row()?);
463                }
464
465                if let Some(id) = last_id {
466                    result_set.set_last_insert_id(id);
467                };
468
469                Ok(result_set)
470            })
471            .await
472        })
473        .await
474    }
475
476    async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
477        self.query_raw(sql, params).await
478    }
479
480    async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
481        let (sql, params) = visitor::Mysql::build(q)?;
482        self.execute_raw(&sql, &params).await
483    }
484
485    async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
486        metrics::query("mysql.execute_raw", sql, params, move || async move {
487            self.prepared(sql, |stmt| async move {
488                let mut conn = self.conn.lock().await;
489                conn.exec_drop(stmt, conversion::conv_params(params)?).await?;
490
491                Ok(conn.affected_rows())
492            })
493            .await
494        })
495        .await
496    }
497
498    async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
499        self.execute_raw(sql, params).await
500    }
501
502    async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
503        metrics::query("mysql.raw_cmd", cmd, &[], move || async move {
504            self.perform_io(|| async move {
505                let mut conn = self.conn.lock().await;
506                let mut result = cmd.run(&mut *conn).await?;
507
508                loop {
509                    result.map(drop).await?;
510
511                    if result.is_empty() {
512                        result.map(drop).await?;
513                        break;
514                    }
515                }
516
517                Ok(())
518            })
519            .await
520        })
521        .await
522    }
523
524    async fn version(&self) -> crate::Result<Option<String>> {
525        let query = r#"SELECT @@GLOBAL.version version"#;
526        let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?;
527
528        let version_string = rows.get(0).and_then(|row| row.get("version").and_then(|version| version.to_string()));
529
530        Ok(version_string)
531    }
532
533    fn is_healthy(&self) -> bool {
534        self.is_healthy.load(Ordering::SeqCst)
535    }
536
537    async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
538        if matches!(isolation_level, IsolationLevel::Snapshot) {
539            return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build());
540        }
541
542        self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")).await?;
543
544        Ok(())
545    }
546
547    fn requires_isolation_first(&self) -> bool {
548        true
549    }
550}
551
#[cfg(test)]
mod tests {
    use super::MysqlUrl;
    use crate::tests::test_api::mysql::CONN_STR;
    use crate::{error::*, single::Sqlint};
    use url::Url;

    /// Parses `raw` into a `MysqlUrl`, panicking on any failure.
    fn parse(raw: &str) -> MysqlUrl {
        MysqlUrl::new(Url::parse(raw).unwrap()).unwrap()
    }

    /// The `socket` parameter is read with its parentheses stripped.
    #[test]
    fn should_parse_socket_url() {
        let url = parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)");

        assert_eq!("dbname", url.dbname());
        assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket());
    }

    /// `prefer_socket=false` is parsed as a boolean.
    #[test]
    fn should_parse_prefer_socket() {
        let url = parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false");

        assert!(!url.prefer_socket().unwrap());
    }

    /// `sslaccept=strict` turns SSL on and stops accepting invalid certs.
    #[test]
    fn should_parse_sslaccept() {
        let url = parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict");

        assert!(url.query_params.use_ssl);
        assert!(!url.query_params.ssl_opts.skip_domain_validation());
        assert!(!url.query_params.ssl_opts.accept_invalid_certs());
    }

    /// `statement_cache_size` sets the capacity of the statement cache.
    #[test]
    fn should_allow_changing_of_cache_size() {
        let url = parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420");

        assert_eq!(420, url.cache().capacity());
    }

    /// Without `statement_cache_size` the cache holds 100 statements.
    #[test]
    fn should_have_default_cache_size() {
        let url = parse("mysql:///root:root@localhost:3307/foo");

        assert_eq!(100, url.cache().capacity());
    }

    /// Connecting to a missing database maps to `DatabaseDoesNotExist`.
    #[tokio::test]
    async fn should_map_nonexisting_database_error() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.set_username("root").unwrap();
        url.set_path("/this_does_not_exist");

        let err = Sqlint::new(url.as_str()).await.unwrap_err();

        match err.kind() {
            ErrorKind::DatabaseDoesNotExist { db_name } => {
                assert_eq!(Some("1049"), err.original_code());
                assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message());
                assert_eq!(&Name::available("this_does_not_exist"), db_name)
            }
            e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e),
        }
    }

    /// Connecting with an unknown user maps to `AuthenticationFailed`.
    #[tokio::test]
    async fn should_map_wrong_credentials_error() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.set_username("WRONG").unwrap();

        let res = Sqlint::new(url.as_str()).await;
        assert!(res.is_err());

        let err = res.unwrap_err();
        let wrong_user = Name::available("WRONG");
        assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &wrong_user));
    }
}
627}