scratchstack_config/
database.rs

1use {
2    crate::error::{ConfigError, DatabaseConfigErrorKind},
3    log::{debug, error, info},
4    serde::Deserialize,
5    sqlx::{any::Any as AnyDB, pool::PoolOptions},
6    std::{fmt::Debug, fs::read, time::Duration},
7};
8
9#[derive(Clone, Deserialize, Debug)]
10pub struct DatabaseConfig {
11    pub url: String,
12
13    #[serde(default)]
14    pub password: Option<String>,
15
16    #[serde(default)]
17    pub password_file: Option<String>,
18
19    #[serde(default)]
20    pub max_connections: Option<u32>,
21
22    #[serde(default)]
23    pub min_connections: Option<u32>,
24
25    #[serde(with = "humantime_serde", default)]
26    pub connection_timeout: Option<Duration>,
27
28    #[serde(with = "humantime_serde", default)]
29    pub max_lifetime: Option<Duration>,
30
31    #[serde(with = "humantime_serde", default)]
32    pub idle_timeout: Option<Duration>,
33
34    #[serde(default)]
35    pub test_before_acquire: Option<bool>,
36}
37
38impl DatabaseConfig {
39    pub fn get_database_url(&self) -> Result<String, ConfigError> {
40        let url = self.url.clone();
41
42        if let Some(password) = &self.password {
43            debug!("Database password specified in config file -- replacing occurrences in URL");
44            Ok(url.replace("${password}", password))
45        } else if let Some(password_file) = &self.password_file {
46            debug!("Database password file specified.");
47            match read(password_file) {
48                Ok(password_u8) => match std::str::from_utf8(&password_u8) {
49                    Ok(password) => {
50                        info!("Successfully read database password file {}; replacing URL", password_file);
51                        let password = password.trim();
52                        Ok(url.replace("${password}", password))
53                    }
54                    Err(e) => {
55                        error!("Found non-UTF-8 characters in database password file {}", password_file);
56                        Err(DatabaseConfigErrorKind::InvalidPasswordFileEncoding(password_file.to_string(), e).into())
57                    }
58                },
59                Err(e) => {
60                    error!("Failed to open database password file {}: {}", password_file, e);
61                    Err(ConfigError::IO(e))
62                }
63            }
64        } else if url.contains("${password}") {
65            error!("Found password placeholder '${{password}}' in database URL but no password was supplied: {}", url);
66            Err(DatabaseConfigErrorKind::MissingPassword.into())
67        } else {
68            Ok(url)
69        }
70    }
71
72    pub fn get_pool_options(&self) -> Result<PoolOptions<AnyDB>, ConfigError> {
73        let options = PoolOptions::<AnyDB>::new();
74        let options = options.max_lifetime(self.max_lifetime);
75        let mut options = options.idle_timeout(self.idle_timeout);
76
77        if let Some(size) = self.max_connections {
78            options = options.max_connections(size);
79        }
80
81        if let Some(size) = self.min_connections {
82            options = options.min_connections(size);
83        }
84
85        if let Some(duration) = self.connection_timeout {
86            options = options.acquire_timeout(duration);
87        }
88
89        if let Some(b) = self.test_before_acquire {
90            options = options.test_before_acquire(b);
91        }
92
93        Ok(options)
94    }
95
96    pub fn resolve(&self) -> Result<ResolvedDatabaseConfig, ConfigError> {
97        let url = self.get_database_url()?;
98        let pool_options = self.get_pool_options()?;
99
100        Ok(ResolvedDatabaseConfig {
101            url,
102            pool_options,
103        })
104    }
105}
106
/// A database configuration reduced to the two values this file produces for a connection pool:
/// a URL string and the pool options.
#[derive(Debug)]
pub struct ResolvedDatabaseConfig {
    /// The database URL; when produced by `DatabaseConfig::resolve`, any `${password}`
    /// placeholder has already been substituted.
    pub url: String,
    /// Pool options for the `Any` database driver.
    pub pool_options: PoolOptions<AnyDB>,
}