universal_config/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod error;
4
5use crate::error::{
6    DeserializationError, Result, UniversalConfigError as Error, UniversalConfigError,
7};
8use dirs::{config_dir, home_dir};
9use serde::de::DeserializeOwned;
10use std::fs;
11use std::path::{Path, PathBuf};
12use tracing::debug;
13
/// Supported config formats.
///
/// Each variant only exists when its matching cargo feature
/// (`json`, `yaml`, `toml`, `corn`, `xml`, `ron`, `kdl`) is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// `.json` files.
    #[cfg(feature = "json")]
    Json,
    /// `.yaml` or `.yml` files.
    #[cfg(feature = "yaml")]
    Yaml,
    /// `.toml` files.
    #[cfg(feature = "toml")]
    Toml,
    /// `.corn` files.
    #[cfg(feature = "corn")]
    Corn,
    /// `.xml` files.
    #[cfg(feature = "xml")]
    Xml,
    /// `.ron` files.
    #[cfg(feature = "ron")]
    Ron,
    /// `.kdl` files.
    #[cfg(feature = "kdl")]
    Kdl,
}
38
39impl Format {
40    #[allow(dead_code)] // ignore warning when all feature flags disabled
41    const fn extension(&self) -> &str {
42        match *self {
43            #[cfg(feature = "json")]
44            Self::Json => "json",
45            #[cfg(feature = "yaml")]
46            Self::Yaml => "yaml",
47            #[cfg(feature = "toml")]
48            Self::Toml => "toml",
49            #[cfg(feature = "corn")]
50            Self::Corn => "corn",
51            #[cfg(feature = "xml")]
52            Self::Xml => "xml",
53            #[cfg(feature = "ron")]
54            Self::Ron => "ron",
55            #[cfg(feature = "kdl")]
56            Self::Kdl => "kdl",
57        }
58    }
59}
60
61/// The main loader struct.
62///
63/// Create a new loader and configure as appropriate
64/// to load your config file.
65pub struct ConfigLoader<'a> {
66    /// The name of your program, used when determining the directory path.
67    app_name: &'a str,
68    /// The name of the file (*excluding* extension) to search for.
69    /// Defaults to `config`.
70    file_name: &'a str,
71    /// Allowed file formats.
72    /// Defaults to all formats.
73    /// Set to disable formats you do not wish to allow.
74    formats: &'a [Format],
75    /// The directory to load the config file from.
76    /// Defaults to your system config dir (`$XDG_CONFIG_DIR` on Linux),
77    /// or your home dir if that does not exist.
78    config_dir: Option<&'a str>,
79}
80
81impl<'a> ConfigLoader<'a> {
82    /// Creates a new config loader for the provided app name.
83    /// Uses a default file name of "config" and all formats.
84    #[must_use]
85    pub const fn new(app_name: &'a str) -> ConfigLoader<'a> {
86        Self {
87            app_name,
88            file_name: "config",
89            formats: &[
90                #[cfg(feature = "json")]
91                Format::Json,
92                #[cfg(feature = "yaml")]
93                Format::Yaml,
94                #[cfg(feature = "toml")]
95                Format::Toml,
96                #[cfg(feature = "corn")]
97                Format::Corn,
98                #[cfg(feature = "xml")]
99                Format::Xml,
100                #[cfg(feature = "ron")]
101                Format::Ron,
102                #[cfg(feature = "kdl")]
103                Format::Kdl,
104            ],
105            config_dir: None,
106        }
107    }
108
109    /// Specifies the file name to look for, excluding the extension.
110    ///
111    /// If not specified, defaults to "config".
112    #[must_use]
113    pub const fn with_file_name(mut self, file_name: &'a str) -> Self {
114        self.file_name = file_name;
115        self
116    }
117
118    /// Specifies which file formats to search for, and in which order.
119    ///
120    /// If not specified, all formats are checked for
121    /// in the order JSON, YAML, TOML, Corn.
122    #[must_use]
123    pub const fn with_formats(mut self, formats: &'a [Format]) -> Self {
124        self.formats = formats;
125        self
126    }
127
128    /// Specifies which directory the config should be loaded from.
129    ///
130    /// If not specified, loads from `$XDG_CONFIG_DIR/<app_name>`
131    /// or `$HOME/.<app_name>` if the config dir does not exist.
132    #[must_use]
133    pub const fn with_config_dir(mut self, dir: &'a str) -> Self {
134        self.config_dir = Some(dir);
135        self
136    }
137
138    /// Attempts to locate a config file on disk and load it.
139    ///
140    /// # Errors
141    ///
142    /// Will return a `UniversalConfigError` if any error occurs
143    /// when looking for, reading, or deserializing a config file.
144    pub fn find_and_load<T: DeserializeOwned>(&self) -> Result<T> {
145        let file = self.try_find_file()?;
146        debug!("Found file at: '{}", file.display());
147        Self::load(&file)
148    }
149
150    /// Attempts to find the directory in which the config file is stored.
151    ///
152    /// # Errors
153    ///
154    /// Will error if the user's home directory cannot be located.
155    pub fn config_dir(&self) -> std::result::Result<PathBuf, UniversalConfigError> {
156        self.config_dir
157            .map(Into::into)
158            .or_else(|| config_dir().map(|dir| dir.join(self.app_name)))
159            .or_else(|| home_dir().map(|dir| dir.join(format!(".{}", self.app_name))))
160            .ok_or(Error::MissingUserDir)
161    }
162
163    /// Attempts to find a config file for the given app name
164    /// in the app's config directory
165    /// that matches any of the allowed formats.
166    fn try_find_file(&self) -> Result<PathBuf> {
167        let config_dir = self.config_dir()?;
168
169        let extensions = self.get_extensions();
170
171        debug!("Using config dir: {}", config_dir.display());
172
173        let file = extensions.into_iter().find_map(|extension| {
174            let full_path = config_dir.join(format!("{}.{extension}", self.file_name));
175
176            if Path::exists(&full_path) {
177                Some(full_path)
178            } else {
179                None
180            }
181        });
182
183        file.ok_or(Error::FileNotFound)
184    }
185
186    /// Loads the file at the given path,
187    /// deserializing it into a new `T`.
188    ///
189    /// The type is automatically determined from the file extension.
190    ///
191    /// # Errors
192    ///
193    /// Will return a `UniversalConfigError` if unable to read or deserialize the file.
194    pub fn load<T: DeserializeOwned, P: AsRef<Path>>(path: P) -> Result<T> {
195        let str = fs::read_to_string(&path)?;
196
197        let extension = path
198            .as_ref()
199            .extension()
200            .unwrap_or_default()
201            .to_str()
202            .unwrap_or_default();
203
204        let config = Self::deserialize(&str, extension)?;
205        Ok(config)
206    }
207
208    /// Gets a list of supported and enabled file extensions.
209    fn get_extensions(&self) -> Vec<&'static str> {
210        #[allow(unused_mut)] // ignore warning when all feature flags disabled
211        let mut extensions = vec![];
212
213        for format in self.formats {
214            match *format {
215                #[cfg(feature = "json")]
216                Format::Json => extensions.push("json"),
217                #[cfg(feature = "yaml")]
218                Format::Yaml => {
219                    extensions.push("yaml");
220                    extensions.push("yml");
221                }
222                #[cfg(feature = "toml")]
223                Format::Toml => extensions.push("toml"),
224                #[cfg(feature = "corn")]
225                Format::Corn => extensions.push("corn"),
226                #[cfg(feature = "xml")]
227                Format::Xml => extensions.push("xml"),
228                #[cfg(feature = "ron")]
229                Format::Ron => extensions.push("ron"),
230                #[cfg(feature = "kdl")]
231                Format::Kdl => extensions.push("kdl"),
232            }
233        }
234
235        extensions
236    }
237
238    /// Attempts to deserialize the provided input into `T`,
239    /// based on the provided file extension.
240    fn deserialize<T: DeserializeOwned>(
241        str: &str,
242        extension: &str,
243    ) -> std::result::Result<T, DeserializationError> {
244        let res = match extension {
245            #[cfg(feature = "json")]
246            "json" => serde_json::from_str(str).map_err(DeserializationError::from),
247            #[cfg(feature = "toml")]
248            "toml" => toml::from_str(str).map_err(DeserializationError::from),
249            #[cfg(feature = "yaml")]
250            "yaml" | "yml" => serde_yaml::from_str(str).map_err(DeserializationError::from),
251            #[cfg(feature = "corn")]
252            "corn" => corn::from_str(str).map_err(DeserializationError::from),
253            #[cfg(feature = "xml")]
254            "xml" => serde_xml_rs::from_str(str).map_err(DeserializationError::from),
255            #[cfg(feature = "ron")]
256            "ron" => ron::from_str(str).map_err(DeserializationError::from),
257            #[cfg(feature = "kdl")]
258            "kdl" => kaydle::serde::from_str(str).map_err(DeserializationError::from),
259            _ => {
260                dbg!(str);
261                Err(DeserializationError::UnsupportedExtension(
262                    extension.to_string(),
263                ))
264            }
265        }?;
266
267        Ok(res)
268    }
269
270    /// Saves the provided configuration into a file of the specified format.
271    ///
272    /// The file is stored in the app's configuration directory.
273    /// Directories are automatically created if required.
274    ///
275    /// # Errors
276    ///
277    /// If the provided config cannot be serialised into the format, an error will be returned.
278    /// The `.corn` format is not supported, and the function will error if specified.
279    ///
280    /// If a valid config dir cannot be found, an error will be returned.
281    ///
282    /// If the file cannot be written to the specified path, an error will be returned.
283    #[cfg(feature = "save")]
284    pub fn save<T: serde::Serialize>(&self, config: &T, format: &Format) -> Result<()> {
285        use crate::error::SerializationError;
286
287        let str: std::result::Result<String, SerializationError> = match *format {
288            #[cfg(feature = "json")]
289            Format::Json => serde_json::to_string_pretty(config).map_err(SerializationError::from),
290            #[cfg(feature = "yaml")]
291            Format::Yaml => serde_yaml::to_string(config).map_err(SerializationError::from),
292            #[cfg(feature = "toml")]
293            Format::Toml => toml::to_string_pretty(config).map_err(SerializationError::from),
294            #[cfg(feature = "corn")]
295            Format::Corn => Err(SerializationError::UnsupportedExtension("corn".to_string())),
296            #[cfg(feature = "xml")]
297            Format::Xml => serde_xml_rs::to_string(config).map_err(SerializationError::from),
298            #[cfg(feature = "ron")]
299            Format::Ron => ron::to_string(config).map_err(SerializationError::from),
300            #[cfg(feature = "kdl")]
301            Format::Kdl => Err(SerializationError::UnsupportedExtension("kdl".to_string())),
302        };
303        let str = str?;
304
305        let config_dir = self.config_dir()?;
306        let file_name = format!("{}.{}", self.file_name, format.extension());
307        let full_path = config_dir.join(file_name);
308
309        fs::create_dir_all(config_dir)?;
310        fs::write(full_path, str)?;
311
312        Ok(())
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Shape shared by every fixture in `test_configs/`.
    #[allow(dead_code)] // unused when every format feature is disabled
    #[derive(Deserialize)]
    struct ConfigContents {
        test: String,
    }

    /// Loads the fixture at `path` and checks it holds the expected value.
    #[allow(dead_code)] // unused when every format feature is disabled
    fn assert_fixture(path: &str) {
        let res: ConfigContents = ConfigLoader::load(path).unwrap();
        assert_eq!(res.test, "hello world");
    }

    // Each format test is gated on its feature: without the feature,
    // `load` returns `UnsupportedExtension` and the test would fail.

    #[cfg(feature = "json")]
    #[test]
    fn test_json() {
        assert_fixture("test_configs/config.json");
    }

    #[cfg(feature = "yaml")]
    #[test]
    fn test_yaml() {
        assert_fixture("test_configs/config.yaml");
    }

    #[cfg(feature = "toml")]
    #[test]
    fn test_toml() {
        assert_fixture("test_configs/config.toml");
    }

    #[cfg(feature = "corn")]
    #[test]
    fn test_corn() {
        assert_fixture("test_configs/config.corn");
    }

    #[cfg(feature = "xml")]
    #[test]
    fn test_xml() {
        assert_fixture("test_configs/config.xml");
    }

    #[cfg(feature = "ron")]
    #[test]
    fn test_ron() {
        assert_fixture("test_configs/config.ron");
    }

    #[cfg(feature = "kdl")]
    #[test]
    fn test_kdl() {
        assert_fixture("test_configs/config.kdl");
    }

    // With no format enabled there are no extensions to search for,
    // so `find_and_load` can only return `FileNotFound`.
    #[cfg(any(
        feature = "json",
        feature = "yaml",
        feature = "toml",
        feature = "corn",
        feature = "xml",
        feature = "ron",
        feature = "kdl"
    ))]
    #[test]
    fn test_find_load() {
        let config = ConfigLoader::new("universal-config");
        let res: ConfigContents = config
            .with_config_dir("test_configs")
            .find_and_load()
            .unwrap();
        assert_eq!(res.test, "hello world");
    }
}