// sqlcx_core/config.rs

use crate::error::{Result, SqlcxError};
use schemars::JsonSchema;
use serde::Deserialize;
use std::collections::HashMap;
use std::io;
use std::path::Path;

7#[derive(Deserialize, Clone, Debug, JsonSchema)]
8pub struct SqlcxConfig {
9    pub sql: String,
10    pub parser: String,
11    pub targets: Vec<TargetConfig>,
12    #[serde(default)]
13    pub overrides: HashMap<String, String>,
14    #[serde(default)]
15    pub migrate: Option<MigrateConfig>,
16}
17
18#[derive(Deserialize, Clone, Debug, JsonSchema)]
19pub struct TargetConfig {
20    pub language: String,
21    pub out: String,
22    pub schema: String,
23    pub driver: String,
24    #[serde(default)]
25    pub overrides: HashMap<String, String>,
26}
27
28#[derive(Deserialize, Clone, Debug, JsonSchema)]
29pub struct MigrateConfig {
30    #[serde(default = "default_migrate_dir")]
31    pub dir: String,
32    #[serde(default)]
33    pub database_url: Option<String>,
34    #[serde(default = "default_auto_regenerate")]
35    pub auto_regenerate: bool,
36}
37
/// Serde default for `MigrateConfig::dir`, used when a `[migrate]` section
/// does not set `dir`.
fn default_migrate_dir() -> String {
    const DEFAULT_DIR: &str = "./sql/migrations";
    DEFAULT_DIR.to_owned()
}

/// Serde default for `MigrateConfig::auto_regenerate`, used when a `[migrate]`
/// section does not set it: the flag is on.
const fn default_auto_regenerate() -> bool {
    true
}

46/// Load config from a directory by auto-detecting sqlcx.toml or sqlcx.json.
47/// Tries sqlcx.toml first, then sqlcx.json.
48pub fn load_config(dir: &Path) -> Result<SqlcxConfig> {
49    let toml_path = dir.join("sqlcx.toml");
50    if toml_path.exists() {
51        let content = std::fs::read_to_string(&toml_path)?;
52        return toml::from_str(&content).map_err(SqlcxError::from);
53    }
54
55    let json_path = dir.join("sqlcx.json");
56    if json_path.exists() {
57        let content = std::fs::read_to_string(&json_path)?;
58        return serde_json::from_str(&content).map_err(SqlcxError::from);
59    }
60
61    Err(SqlcxError::ConfigNotFound(
62        "no sqlcx.toml or sqlcx.json found".to_string(),
63    ))
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserialize_toml_config() {
        let toml_str = r#"
sql = "./sql"
parser = "postgres"
[[targets]]
language = "typescript"
out = "./src/db"
schema = "typebox"
driver = "bun-sql"
[overrides]
uuid = "string"
"#;
        let config: SqlcxConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.sql, "./sql");
        assert_eq!(config.parser, "postgres");
        assert_eq!(config.targets.len(), 1);
        assert_eq!(config.targets[0].language, "typescript");
        assert_eq!(config.overrides.get("uuid"), Some(&"string".to_string()));
        // Nothing in this document configures migrations.
        assert!(config.migrate.is_none());
    }

    #[test]
    fn deserialize_json_config() {
        let json_str = r#"{"sql":"./sql","parser":"postgres","targets":[{"language":"typescript","out":"./src/db","schema":"typebox","driver":"bun-sql"}]}"#;
        let config: SqlcxConfig = serde_json::from_str(json_str).unwrap();
        assert_eq!(config.sql, "./sql");
        assert_eq!(config.targets.len(), 1);
        assert!(config.overrides.is_empty());
        assert!(config.targets[0].overrides.is_empty());
    }

    #[test]
    fn deserialize_target_overrides() {
        let toml_str = r#"
sql = "./sql"
parser = "postgres"
[[targets]]
language = "typescript"
out = "./src/db"
schema = "typebox"
driver = "bun-sql"
[targets.overrides]
uuid = "string"
"#;
        let config: SqlcxConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(
            config.targets[0].overrides.get("uuid"),
            Some(&"string".to_string())
        );
        // Target-level overrides must not leak into the top-level map.
        assert!(config.overrides.is_empty());
    }

    #[test]
    fn migrate_section_uses_defaults() {
        let toml_str = r#"
sql = "./sql"
parser = "postgres"
targets = []
[migrate]
"#;
        let config: SqlcxConfig = toml::from_str(toml_str).unwrap();
        let migrate = config.migrate.expect("[migrate] section should be parsed");
        assert_eq!(migrate.dir, "./sql/migrations");
        assert!(migrate.database_url.is_none());
        assert!(migrate.auto_regenerate);
    }

    #[test]
    fn migrate_section_explicit_values() {
        let json_str = r#"{"sql":"./sql","parser":"postgres","targets":[],"migrate":{"dir":"./db/migrations","database_url":"postgres://localhost/app","auto_regenerate":false}}"#;
        let config: SqlcxConfig = serde_json::from_str(json_str).unwrap();
        let migrate = config.migrate.expect("migrate object should be parsed");
        assert_eq!(migrate.dir, "./db/migrations");
        assert_eq!(
            migrate.database_url.as_deref(),
            Some("postgres://localhost/app")
        );
        assert!(!migrate.auto_regenerate);
    }

    #[test]
    fn load_config_auto_detect_toml() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("sqlcx.toml"),
            r#"
sql = "./sql"
parser = "postgres"
[[targets]]
language = "typescript"
out = "./src/db"
schema = "typebox"
driver = "bun-sql"
"#,
        )
        .unwrap();
        let config = load_config(dir.path()).unwrap();
        assert_eq!(config.parser, "postgres");
    }

    #[test]
    fn load_config_auto_detect_json() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("sqlcx.json"),
            r#"{"sql":"./sql","parser":"postgres","targets":[{"language":"typescript","out":"./src/db","schema":"typebox","driver":"bun-sql"}]}"#,
        )
        .unwrap();
        let config = load_config(dir.path()).unwrap();
        assert_eq!(config.parser, "postgres");
    }

    #[test]
    fn load_config_prefers_toml_over_json() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("sqlcx.toml"),
            r#"
sql = "./sql"
parser = "postgres"
targets = []
"#,
        )
        .unwrap();
        std::fs::write(
            dir.path().join("sqlcx.json"),
            r#"{"sql":"./sql","parser":"mysql","targets":[]}"#,
        )
        .unwrap();
        let config = load_config(dir.path()).unwrap();
        assert_eq!(config.parser, "postgres");
    }

    #[test]
    fn load_config_invalid_toml_is_error() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("sqlcx.toml"), "sql = ").unwrap();
        let err = load_config(dir.path()).unwrap_err();
        // A malformed file must be reported as such, not as a missing one.
        assert!(!matches!(err, SqlcxError::ConfigNotFound(_)));
    }

    #[test]
    fn load_config_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let err = load_config(dir.path()).unwrap_err();
        // Pin the variant: a bare `is_err()` would also accept I/O or parse errors.
        assert!(matches!(err, SqlcxError::ConfigNotFound(_)));
    }
}