Skip to main content

tank_mysql/
connection.rs

1use crate::{MySQLDriver, MySQLQueryable, MySQLTransaction};
2use mysql_async::{ClientIdentity, Conn, Opts, OptsBuilder};
3use std::{borrow::Cow, env, path::PathBuf};
4use tank_core::{
5    Connection, Error, ErrorContext, Result, impl_executor_transaction, truncate_long,
6};
7
/// Connection wrapper used by the MySQL/MariaDB driver.
///
/// Holds the underlying `mysql_async` connection and adapts it to the
/// `tank_core::Connection`/`Executor` APIs.
///
/// Instances are created through [`Connection::connect`], which accepts a MySQL URL plus
/// optional `ssl_ca` / `ssl_cert` / `ssl_pass` query parameters (or the matching
/// `MYSQL_SSL_CA` / `MYSQL_SSL_CERT` / `MYSQL_SSL_PASS` environment variables).
/// Transactions are started with [`Connection::begin`], which yields a `MySQLTransaction`.
pub struct MySQLConnection {
    /// The live `mysql_async` connection, wrapped in [`MySQLQueryable`] (stored in its
    /// `executor` field). NOTE(review): presumably this wrapper is what lets query
    /// execution be shared with `MySQLTransaction` — confirm in `MySQLQueryable`.
    pub(crate) conn: MySQLQueryable<Conn>,
}
14
/// MariaDB is served by the same driver, so its connection type is simply an alias of
/// [`MySQLConnection`].
pub type MariaDBConnection = MySQLConnection;

// Generates the executor/transaction plumbing for `MySQLConnection`, delegating to its
// `conn` field. NOTE(review): the exact trait impls produced are defined by
// `tank_core::impl_executor_transaction!` — confirm there before relying on specifics.
impl_executor_transaction!(MySQLDriver, MySQLConnection, conn);
18
19impl Connection for MySQLConnection {
20    async fn connect(url: Cow<'static, str>) -> Result<Self> {
21        let context = format!("While trying to connect to `{}`", truncate_long!(&url));
22        let mut url = Self::sanitize_url(url)?;
23        let mut take_url_param = |key: &str, env_var: &str, remove: bool| {
24            let value = url
25                .query_pairs()
26                .find_map(|(k, v)| if k == key { Some(v) } else { None })
27                .map(|v| v.to_string());
28            if remove && let Some(..) = value {
29                let mut result = url.clone();
30                result.set_query(None);
31                result
32                    .query_pairs_mut()
33                    .extend_pairs(url.query_pairs().filter(|(k, _)| k != key));
34                url = result;
35            };
36            value.or_else(|| env::var(env_var).ok().map(Into::into))
37        };
38        let ssl_ca = take_url_param("ssl_ca", "MYSQL_SSL_CA", true);
39        let ssl_cert = take_url_param("ssl_cert", "MYSQL_SSL_CERT", true);
40        let ssl_pass = take_url_param("ssl_pass", "MYSQL_SSL_PASS", true);
41        let opts = Opts::from_url(url.as_str()).with_context(|| context.clone())?;
42        let mut ssl_opts = opts.ssl_opts().cloned();
43        let mut opts = OptsBuilder::from_opts(opts);
44        if let Some(ssl_ca) = ssl_ca {
45            let ca_path = PathBuf::from(ssl_ca);
46            if !ca_path.exists() {
47                let error = Error::msg(format!(
48                    "SSL CA file not found: `{}`",
49                    ca_path.to_string_lossy()
50                ))
51                .context(context);
52                log::error!("{:#}", error);
53                return Err(error);
54            }
55            let certs = vec![ca_path.into()];
56            ssl_opts = Some(ssl_opts.unwrap_or_default().with_root_certs(certs));
57        }
58        if let Some(ssl_cert) = ssl_cert {
59            let ssl_cert = PathBuf::from(ssl_cert);
60            if !ssl_cert.exists() {
61                let error = Error::msg(format!(
62                    "SSL CERT file not found: `{}`",
63                    ssl_cert.to_string_lossy()
64                ))
65                .context(context);
66                log::error!("{:#}", error);
67                return Err(error);
68            }
69            let mut identity = ClientIdentity::new(ssl_cert.into());
70            if let Some(ssl_pass) = ssl_pass {
71                identity = identity.with_password(ssl_pass);
72            };
73            ssl_opts = Some(
74                ssl_opts
75                    .unwrap_or_default()
76                    .with_client_identity(Some(identity)),
77            );
78        }
79        opts = opts.ssl_opts(ssl_opts);
80        let connection = Conn::new(opts).await.context(context)?;
81        Ok(MySQLConnection {
82            conn: MySQLQueryable {
83                executor: connection,
84            },
85        })
86    }
87
88    fn begin(&mut self) -> impl Future<Output = Result<MySQLTransaction<'_>>> {
89        MySQLTransaction::new(self)
90    }
91}