// rustauth_cli/config.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use toml_edit::{value, Array, DocumentMut};
7
8#[derive(Debug, Error)]
9pub enum ConfigError {
10    #[error("failed to read {path}: {source}")]
11    Read {
12        path: PathBuf,
13        source: std::io::Error,
14    },
15    #[error("failed to write {path}: {source}")]
16    Write {
17        path: PathBuf,
18        source: std::io::Error,
19    },
20    #[error("failed to parse RustAuth config: {0}")]
21    ParseToml(#[from] toml_edit::de::Error),
22    #[error("failed to parse RustAuth config document: {0}")]
23    ParseDocument(#[from] toml_edit::TomlError),
24    #[error("failed to render RustAuth config: {0}")]
25    SerializeToml(#[from] toml_edit::ser::Error),
26    #[error("plugins.enabled must be an array")]
27    InvalidPlugins,
28    #[error(
29        "database.adapter is required; set it explicitly in rustauth.toml \
30         (e.g. sqlx, diesel, tokio-postgres, deadpool-postgres)"
31    )]
32    MissingDatabaseAdapter,
33}
34
35#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(default)]
37pub struct CliConfig {
38    pub project: ProjectConfig,
39    pub database: DatabaseConfig,
40    pub security: SecurityConfig,
41    pub plugins: PluginsConfig,
42}
43
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(default)]
46pub struct ProjectConfig {
47    pub framework: Option<String>,
48    pub base_url: String,
49    pub base_path: String,
50    pub production: bool,
51}
52
53#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
54#[serde(default)]
55pub struct DatabaseConfig {
56    pub adapter: Option<String>,
57    pub provider: Option<String>,
58    pub url_env: String,
59    pub migrations_dir: String,
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
63#[serde(default)]
64pub struct SecurityConfig {
65    pub secret_env: String,
66}
67
68#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
69#[serde(default)]
70pub struct PluginsConfig {
71    pub enabled: Vec<String>,
72}
73
74impl Default for ProjectConfig {
75    fn default() -> Self {
76        Self {
77            framework: None,
78            base_url: "http://localhost:3000/api/auth".to_owned(),
79            base_path: "/api/auth".to_owned(),
80            production: false,
81        }
82    }
83}
84
85impl Default for DatabaseConfig {
86    fn default() -> Self {
87        Self {
88            adapter: None,
89            provider: None,
90            url_env: "DATABASE_URL".to_owned(),
91            migrations_dir: "migrations/rustauth".to_owned(),
92        }
93    }
94}
95
96impl Default for SecurityConfig {
97    fn default() -> Self {
98        Self {
99            secret_env: "RUSTAUTH_SECRET".to_owned(),
100        }
101    }
102}
103
104impl CliConfig {
105    pub fn parse_str(source: &str) -> Result<Self, ConfigError> {
106        source.parse()
107    }
108
109    pub fn load(path: &Path) -> Result<Self, ConfigError> {
110        let source = fs::read_to_string(path).map_err(|source| ConfigError::Read {
111            path: path.to_path_buf(),
112            source,
113        })?;
114        let config = Self::parse_str(&source)?;
115        config.validate_loaded_fields()?;
116        Ok(config)
117    }
118
119    pub fn database_adapter(&self) -> Option<&str> {
120        self.database
121            .adapter
122            .as_deref()
123            .filter(|adapter| !adapter.trim().is_empty())
124    }
125
126    pub fn validate_loaded_fields(&self) -> Result<(), ConfigError> {
127        if self.database_adapter().is_none() {
128            return Err(ConfigError::MissingDatabaseAdapter);
129        }
130        Ok(())
131    }
132
133    pub fn load_optional(path: &Path) -> Result<Option<Self>, ConfigError> {
134        if !path.exists() {
135            return Ok(None);
136        }
137        Self::load(path).map(Some)
138    }
139
140    pub fn to_toml_string(&self) -> Result<String, ConfigError> {
141        Ok(toml_edit::ser::to_string_pretty(self)?)
142    }
143
144    pub fn write(&self, path: &Path) -> Result<(), ConfigError> {
145        let rendered = self.to_toml_string()?;
146        fs::write(path, rendered).map_err(|source| ConfigError::Write {
147            path: path.to_path_buf(),
148            source,
149        })
150    }
151
152    pub fn add_plugin_to_document(source: &str, plugin: &str) -> Result<String, ConfigError> {
153        let mut document = source.parse::<DocumentMut>()?;
154        ensure_plugin_array(&mut document)?;
155        let enabled = document["plugins"]["enabled"]
156            .as_array_mut()
157            .ok_or(ConfigError::InvalidPlugins)?;
158        if !enabled.iter().any(|item| item.as_str() == Some(plugin)) {
159            enabled.push(plugin);
160        }
161        Ok(document.to_string())
162    }
163
164    pub fn remove_plugin_from_document(source: &str, plugin: &str) -> Result<String, ConfigError> {
165        let mut document = source.parse::<DocumentMut>()?;
166        ensure_plugin_array(&mut document)?;
167        let enabled = document["plugins"]["enabled"]
168            .as_array_mut()
169            .ok_or(ConfigError::InvalidPlugins)?;
170        enabled.retain(|item| item.as_str() != Some(plugin));
171        Ok(document.to_string())
172    }
173}
174
175impl std::str::FromStr for CliConfig {
176    type Err = ConfigError;
177
178    fn from_str(source: &str) -> Result<Self, Self::Err> {
179        Ok(toml_edit::de::from_str(source)?)
180    }
181}
182
183fn ensure_plugin_array(document: &mut DocumentMut) -> Result<(), ConfigError> {
184    if document["plugins"].is_none() {
185        document["plugins"] = toml_edit::table();
186    }
187    if document["plugins"]["enabled"].is_none() {
188        document["plugins"]["enabled"] = value(Array::default());
189    }
190    if !document["plugins"]["enabled"].is_array() {
191        return Err(ConfigError::InvalidPlugins);
192    }
193    Ok(())
194}