Skip to main content

shelly_data/
adapter.rs

1use crate::error::{DataError, DataResult};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum AdapterKind {
7    None,
8    Postgres,
9    MySql,
10    Sqlite,
11    SingleStore,
12    ClickHouse,
13    BigQuery,
14    OpenSearch,
15}
16
17impl AdapterKind {
18    pub fn as_str(self) -> &'static str {
19        match self {
20            Self::None => "none",
21            Self::Postgres => "postgres",
22            Self::MySql => "mysql",
23            Self::Sqlite => "sqlite",
24            Self::SingleStore => "singlestore",
25            Self::ClickHouse => "clickhouse",
26            Self::BigQuery => "bigquery",
27            Self::OpenSearch => "opensearch",
28        }
29    }
30
31    pub fn parse(raw: &str) -> DataResult<Self> {
32        match raw.trim().to_ascii_lowercase().as_str() {
33            "none" => Ok(Self::None),
34            "postgres" | "postgresql" | "pg" => Ok(Self::Postgres),
35            "mysql" => Ok(Self::MySql),
36            "sqlite" | "sqlite3" => Ok(Self::Sqlite),
37            "singlestore" | "single_store" | "memsql" => Ok(Self::SingleStore),
38            "clickhouse" | "click_house" => Ok(Self::ClickHouse),
39            "bigquery" | "big_query" | "bq" => Ok(Self::BigQuery),
40            "opensearch" | "open_search" => Ok(Self::OpenSearch),
41            value => Err(DataError::Config(format!(
42                "unsupported database adapter `{value}`; expected one of: none, postgres, mysql, sqlite, singlestore, clickhouse, bigquery, opensearch"
43            ))),
44        }
45    }
46
47    pub fn is_sql_backend(self) -> bool {
48        matches!(
49            self,
50            Self::Postgres
51                | Self::MySql
52                | Self::Sqlite
53                | Self::SingleStore
54                | Self::ClickHouse
55                | Self::BigQuery
56        )
57    }
58
59    pub fn supports_migrations(self) -> bool {
60        self.is_sql_backend()
61    }
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
65pub struct DatabaseConfig {
66    pub adapter: AdapterKind,
67    pub url: Option<String>,
68    pub url_env: Option<String>,
69}
70
71impl Default for DatabaseConfig {
72    fn default() -> Self {
73        Self {
74            adapter: AdapterKind::None,
75            url: None,
76            url_env: Some("DATABASE_URL".to_string()),
77        }
78    }
79}
80
81impl DatabaseConfig {
82    pub fn from_toml_like_str(content: &str) -> DataResult<Self> {
83        let mut config = Self::default();
84        let mut in_database_section = false;
85
86        for raw_line in content.lines() {
87            let line = raw_line.trim();
88            if line.is_empty() || line.starts_with('#') {
89                continue;
90            }
91            if line.starts_with('[') && line.ends_with(']') {
92                in_database_section = line == "[database]";
93                continue;
94            }
95            if !in_database_section {
96                continue;
97            }
98
99            let Some((key, value)) = line.split_once('=') else {
100                continue;
101            };
102
103            let key = key.trim();
104            let value = strip_quotes(value.trim());
105            match key {
106                "adapter" => config.adapter = AdapterKind::parse(value)?,
107                "url" => config.url = Some(value.to_string()),
108                "url_env" => config.url_env = Some(value.to_string()),
109                _ => {}
110            }
111        }
112
113        Ok(config)
114    }
115
116    pub fn resolve_url(&self) -> Option<String> {
117        if let Some(url) = &self.url {
118            return Some(url.clone());
119        }
120        self.url_env
121            .as_deref()
122            .and_then(|env_name| std::env::var(env_name).ok())
123    }
124}
125
/// Removes one pair of matching quotes surrounding `value`, if present.
///
/// Both TOML string styles are handled: basic strings (`"..."`) and literal
/// strings (`'...'`). The opening and closing quote must be the same
/// character; anything else (unquoted text, a lone quote, mismatched quotes)
/// is returned unchanged. Escape sequences inside the quotes are not
/// interpreted.
fn strip_quotes(value: &str) -> &str {
    ['"', '\'']
        .iter()
        .find_map(|&quote| value.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(value)
}
132
#[cfg(test)]
mod tests {
    use super::{AdapterKind, DatabaseConfig};
    use proptest::prelude::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Supplies a distinct suffix for each environment variable name a test creates.
    static TEST_ENV_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Lowercase spellings the property tests treat as recognised adapter names.
    const KNOWN_ALIASES: &[&str] = &[
        "none",
        "postgres",
        "postgresql",
        "pg",
        "mysql",
        "sqlite",
        "sqlite3",
        "singlestore",
        "single_store",
        "memsql",
        "clickhouse",
        "click_house",
        "bigquery",
        "big_query",
        "bq",
        "opensearch",
        "open_search",
    ];

    /// Variant that a recognised spelling (in any ASCII case) must parse to.
    fn expected_kind(alias: &str) -> AdapterKind {
        match alias.to_ascii_lowercase().as_str() {
            "none" => AdapterKind::None,
            "postgres" | "postgresql" | "pg" => AdapterKind::Postgres,
            "mysql" => AdapterKind::MySql,
            "sqlite" | "sqlite3" => AdapterKind::Sqlite,
            "singlestore" | "single_store" | "memsql" => AdapterKind::SingleStore,
            "clickhouse" | "click_house" => AdapterKind::ClickHouse,
            "bigquery" | "big_query" | "bq" => AdapterKind::BigQuery,
            "opensearch" | "open_search" => AdapterKind::OpenSearch,
            _ => unreachable!("input generated from known aliases"),
        }
    }

    /// The `[database]` table is read and quoted values are unquoted.
    #[test]
    fn parse_database_config() {
        let document = r#"
[database]
adapter = "postgres"
url_env = "APP_DB_URL"
"#;
        let config = DatabaseConfig::from_toml_like_str(document).unwrap();

        assert_eq!(config.adapter, AdapterKind::Postgres);
        assert_eq!(config.url_env.as_deref(), Some("APP_DB_URL"));
    }

    /// An inline URL wins; without one the named environment variable is used.
    #[test]
    fn resolve_url_prefers_inline_then_env() {
        let inline = DatabaseConfig {
            adapter: AdapterKind::Sqlite,
            url: Some("sqlite://inline.db".to_string()),
            url_env: Some("IGNORED_ENV".to_string()),
        };
        assert_eq!(inline.resolve_url().as_deref(), Some("sqlite://inline.db"));

        // A fresh variable name per run of this test.
        let suffix = TEST_ENV_COUNTER.fetch_add(1, Ordering::Relaxed);
        let key = format!("SHELLY_DATA_TEST_DB_URL_{suffix}");
        std::env::set_var(&key, "sqlite://from-env.db");
        let from_env = DatabaseConfig {
            adapter: AdapterKind::Sqlite,
            url: None,
            url_env: Some(key.clone()),
        };
        assert_eq!(from_env.resolve_url().as_deref(), Some("sqlite://from-env.db"));
        std::env::remove_var(key);
    }

    proptest! {
        /// Every known spelling parses to its variant regardless of ASCII case
        /// and of up to two spaces of padding on either side.
        #[test]
        fn adapter_parse_accepts_aliases_case_and_whitespace(
            alias in prop_oneof![
                Just("none"),
                Just("postgres"),
                Just("postgresql"),
                Just("pg"),
                Just("mysql"),
                Just("sqlite"),
                Just("sqlite3"),
                Just("singlestore"),
                Just("single_store"),
                Just("memsql"),
                Just("clickhouse"),
                Just("click_house"),
                Just("bigquery"),
                Just("big_query"),
                Just("bq"),
                Just("opensearch"),
                Just("open_search"),
            ],
            left_ws in 0usize..3,
            right_ws in 0usize..3,
            uppercase in any::<bool>(),
        ) {
            let spelled = if uppercase {
                alias.to_ascii_uppercase()
            } else {
                alias.to_owned()
            };
            let (lead, trail) = (" ".repeat(left_ws), " ".repeat(right_ws));
            let input = format!("{lead}{spelled}{trail}");

            let parsed = AdapterKind::parse(&input).unwrap();
            prop_assert_eq!(parsed, expected_kind(&spelled));
        }

        /// Any other identifier-like string is rejected.
        #[test]
        fn adapter_parse_rejects_unknown_values(raw in "[a-zA-Z0-9_\\-]{1,24}") {
            let normalized = raw.trim().to_ascii_lowercase();
            // Discard the rare inputs that happen to be a known spelling.
            prop_assume!(!KNOWN_ALIASES.contains(&normalized.as_str()));
            prop_assert!(AdapterKind::parse(&raw).is_err());
        }
    }
}