1use crate::error::{DataError, DataResult};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum AdapterKind {
7 None,
8 Postgres,
9 MySql,
10 Sqlite,
11}
12
13impl AdapterKind {
14 pub fn as_str(self) -> &'static str {
15 match self {
16 Self::None => "none",
17 Self::Postgres => "postgres",
18 Self::MySql => "mysql",
19 Self::Sqlite => "sqlite",
20 }
21 }
22
23 pub fn parse(raw: &str) -> DataResult<Self> {
24 match raw.trim().to_ascii_lowercase().as_str() {
25 "none" => Ok(Self::None),
26 "postgres" | "postgresql" | "pg" => Ok(Self::Postgres),
27 "mysql" => Ok(Self::MySql),
28 "sqlite" | "sqlite3" => Ok(Self::Sqlite),
29 value => Err(DataError::Config(format!(
30 "unsupported database adapter `{value}`; expected one of: none, postgres, mysql, sqlite"
31 ))),
32 }
33 }
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub struct DatabaseConfig {
38 pub adapter: AdapterKind,
39 pub url: Option<String>,
40 pub url_env: Option<String>,
41}
42
43impl Default for DatabaseConfig {
44 fn default() -> Self {
45 Self {
46 adapter: AdapterKind::None,
47 url: None,
48 url_env: Some("DATABASE_URL".to_string()),
49 }
50 }
51}
52
53impl DatabaseConfig {
54 pub fn from_toml_like_str(content: &str) -> DataResult<Self> {
55 let mut config = Self::default();
56 let mut in_database_section = false;
57
58 for raw_line in content.lines() {
59 let line = raw_line.trim();
60 if line.is_empty() || line.starts_with('#') {
61 continue;
62 }
63 if line.starts_with('[') && line.ends_with(']') {
64 in_database_section = line == "[database]";
65 continue;
66 }
67 if !in_database_section {
68 continue;
69 }
70
71 let Some((key, value)) = line.split_once('=') else {
72 continue;
73 };
74
75 let key = key.trim();
76 let value = strip_quotes(value.trim());
77 match key {
78 "adapter" => config.adapter = AdapterKind::parse(value)?,
79 "url" => config.url = Some(value.to_string()),
80 "url_env" => config.url_env = Some(value.to_string()),
81 _ => {}
82 }
83 }
84
85 Ok(config)
86 }
87
88 pub fn resolve_url(&self) -> Option<String> {
89 if let Some(url) = &self.url {
90 return Some(url.clone());
91 }
92 self.url_env
93 .as_deref()
94 .and_then(|env_name| std::env::var(env_name).ok())
95 }
96}
97
/// Removes one pair of matching surrounding quotes from a TOML-style scalar.
///
/// Both basic (`"..."`) and literal (`'...'`) TOML strings are unwrapped. Escape sequences
/// inside basic strings are not processed, so the inner text is returned verbatim. A value
/// that is not wrapped in a matching pair of quotes (including a lone `"` or `'`, or mixed
/// quotes like `"x'`) is returned unchanged.
fn strip_quotes(value: &str) -> &str {
    ['"', '\'']
        .iter()
        .find_map(|&quote| value.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(value)
}
104
#[cfg(test)]
mod tests {
    use super::{AdapterKind, DatabaseConfig};
    use proptest::prelude::*;

    #[test]
    fn parse_database_config() {
        let config = DatabaseConfig::from_toml_like_str(
            r#"
[database]
adapter = "postgres"
url_env = "APP_DB_URL"
"#,
        )
        .unwrap();

        assert_eq!(config.adapter, AdapterKind::Postgres);
        assert_eq!(config.url_env.as_deref(), Some("APP_DB_URL"));
    }

    /// An empty document must leave every field at its `Default` value.
    #[test]
    fn empty_document_yields_default_config() {
        let config = DatabaseConfig::from_toml_like_str("").unwrap();

        assert_eq!(config, DatabaseConfig::default());
        assert_eq!(config.adapter, AdapterKind::None);
        assert_eq!(config.url, None);
        assert_eq!(config.url_env.as_deref(), Some("DATABASE_URL"));
    }

    /// Keys before any header and in other tables must not leak into the database config.
    #[test]
    fn keys_outside_database_section_are_ignored() {
        let config = DatabaseConfig::from_toml_like_str(
            r#"
url = "top-level"

[server]
url = "http://server"

[database]
adapter = "sqlite"

[cache]
adapter = "not-an-adapter"
url = "redis://cache"
"#,
        )
        .unwrap();

        assert_eq!(config.adapter, AdapterKind::Sqlite);
        assert_eq!(config.url, None);
    }

    #[test]
    fn unknown_adapter_is_rejected() {
        let result = DatabaseConfig::from_toml_like_str("[database]\nadapter = \"oracle\"\n");
        assert!(result.is_err());
    }

    /// An explicit `url` wins without consulting the environment variable.
    #[test]
    fn explicit_url_takes_precedence_over_env() {
        let config = DatabaseConfig {
            adapter: AdapterKind::Postgres,
            url: Some("postgres://localhost/app".to_string()),
            url_env: Some("APP_DB_URL_TEST_SHOULD_NOT_BE_READ".to_string()),
        };

        assert_eq!(
            config.resolve_url().as_deref(),
            Some("postgres://localhost/app")
        );
    }

    /// Every canonical name produced by `as_str` must be accepted by `parse`.
    #[test]
    fn as_str_round_trips_through_parse() {
        for kind in [
            AdapterKind::None,
            AdapterKind::Postgres,
            AdapterKind::MySql,
            AdapterKind::Sqlite,
        ] {
            assert_eq!(AdapterKind::parse(kind.as_str()).unwrap(), kind);
        }
    }

    proptest! {
        #[test]
        fn adapter_parse_accepts_aliases_case_and_whitespace(
            alias in prop_oneof![
                Just("none"),
                Just("postgres"),
                Just("postgresql"),
                Just("pg"),
                Just("mysql"),
                Just("sqlite"),
                Just("sqlite3"),
            ],
            left_ws in 0usize..3,
            right_ws in 0usize..3,
            uppercase in any::<bool>(),
        ) {
            let alias = if uppercase {
                alias.to_ascii_uppercase()
            } else {
                alias.to_string()
            };
            let input = format!("{}{}{}", " ".repeat(left_ws), alias, " ".repeat(right_ws));
            let kind = AdapterKind::parse(&input).unwrap();
            let expected = match alias.to_ascii_lowercase().as_str() {
                "none" => AdapterKind::None,
                "postgres" | "postgresql" | "pg" => AdapterKind::Postgres,
                "mysql" => AdapterKind::MySql,
                "sqlite" | "sqlite3" => AdapterKind::Sqlite,
                _ => unreachable!("input generated from known aliases"),
            };
            prop_assert_eq!(kind, expected);
        }

        #[test]
        fn adapter_parse_rejects_unknown_values(raw in "[a-zA-Z0-9_\\-]{1,24}") {
            let normalized = raw.trim().to_ascii_lowercase();
            prop_assume!(
                normalized != "none" &&
                normalized != "postgres" &&
                normalized != "postgresql" &&
                normalized != "pg" &&
                normalized != "mysql" &&
                normalized != "sqlite" &&
                normalized != "sqlite3"
            );
            prop_assert!(AdapterKind::parse(&raw).is_err());
        }
    }
}