xcfg/
format.rs

1use std::path::{Path, PathBuf};
2
3use super::error::Error;
4
5#[derive(Debug)]
6pub struct File<T, P = PathBuf>
7where
8    P: AsRef<Path>,
9{
10    pub path: P,
11    pub fmt: Format,
12    pub inner: T,
13}
14
15impl<T, P> Clone for File<T, P>
16where
17    T: Clone,
18    P: AsRef<Path> + Clone,
19{
20    fn clone(&self) -> Self {
21        Self {
22            path: self.path.clone(),
23            fmt: self.fmt,
24            inner: self.inner.clone(),
25        }
26    }
27}
28// ugly but works
29macro_rules! fmt_impl {
30    ($([$name:literal, $mod:ident, $fmt:ident, $ext:pat]),*) => {
31        $(
32            #[cfg(feature = $name)]
33            mod $mod;
34        )*
35
36        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
37        pub enum Format {
38            $(
39                #[cfg(feature = $name)]
40                $fmt,
41            )*
42        }
43        impl Format {
44            pub fn match_ext(ext: &str) -> Option<Self> {
45                match ext {
46                    $(
47                        #[cfg(feature = $name)]
48                        $ext => Some(Self::$fmt),
49                    )*
50                    _ => None,
51                }
52            }
53            pub fn serialize<T>(&self, input: &T) -> Result<String, Error>
54            where
55                T: serde::Serialize,
56            {
57                match self {
58                    $(
59                        #[cfg(feature = $name)]
60                        Self::$fmt => $mod::to_string(input),
61                    )*
62                }
63            }
64            pub fn deserialize<T>(&self, input: &str) -> Result<T, Error>
65            where
66                T: serde::de::DeserializeOwned,
67            {
68                match self {
69                    $(
70                        #[cfg(feature = $name)]
71                        Self::$fmt => $mod::from_str(input),
72                    )*
73                }
74            }
75        }
76    };
77}
78fmt_impl!(
79    ["toml", toml_impl, Toml, "toml"],
80    ["yaml", yaml_impl, Yaml, "yaml" | "yml"],
81    ["json", json_impl, Json, "json"]
82);
83
84mod file_impl {
85    use super::Format;
86    #[derive(Debug, PartialEq, Clone)]
87    pub enum LoadFormat {
88        Unknown,
89        Any,
90        Format(Format),
91    }
92    use std::path::Path;
93
94    use crate::error::Error;
95
96    pub fn load_fmt<P: AsRef<Path>>(path: P) -> LoadFormat {
97        match path.as_ref().extension() {
98            Some(ext) => match ext.to_str() {
99                Some("") => LoadFormat::Any,
100                None => LoadFormat::Unknown,
101                Some(ext) => match Format::match_ext(ext) {
102                    Some(fmt) => LoadFormat::Format(fmt),
103                    _ => LoadFormat::Unknown,
104                },
105            },
106            None => LoadFormat::Any,
107        }
108    }
109
110    pub fn load<T, P: AsRef<Path>>(fmt: Format, path: P) -> Result<T, Error>
111    where
112        T: serde::de::DeserializeOwned,
113    {
114        fmt.deserialize(&std::fs::read_to_string(path)?)
115    }
116
117    #[cfg(test)]
118    mod tests {
119        use super::*;
120        #[test]
121        fn test_load_fmt() {
122            let path = Path::new("test.toml");
123            assert_eq!(load_fmt(path), LoadFormat::Format(Format::Toml));
124            let path = Path::new("test");
125            assert_eq!(load_fmt(path), LoadFormat::Any);
126            let path = Path::new("test.");
127            assert_eq!(load_fmt(path), LoadFormat::Any);
128            let path = Path::new("test.unknown");
129            assert_eq!(load_fmt(path), LoadFormat::Unknown);
130        }
131    }
132}
133impl<T> File<T, PathBuf> {
134    pub fn any_load<AsP>(path: AsP) -> Result<File<T, PathBuf>, Error>
135    where
136        AsP: AsRef<Path>,
137        T: serde::de::DeserializeOwned,
138    {
139        let mut parent = path.as_ref().parent().ok_or(Error::InvalidPath)?;
140        if parent.as_os_str().is_empty() {
141            parent = Path::new(".");
142        }
143        let fname = path
144            .as_ref()
145            .file_name()
146            .and_then(|name| name.to_str())
147            .ok_or(Error::InvalidPath)?;
148        for entry in std::fs::read_dir(parent)? {
149            let entry_path = entry?.path();
150            if !entry_path.is_file() {
151                continue;
152            }
153            let name = match entry_path.file_name().and_then(|name| name.to_str()) {
154                Some(name) => name,
155                None => continue,
156            };
157            if !name.starts_with(fname) {
158                continue;
159            }
160            let load_fmt = file_impl::load_fmt(name);
161            match load_fmt {
162                file_impl::LoadFormat::Unknown | file_impl::LoadFormat::Any => continue,
163                file_impl::LoadFormat::Format(fmt) => {
164                    return File::with_fmt(entry_path, fmt);
165                }
166            }
167        }
168        Err(Error::InvalidPath)
169    }
170}
171impl<T, P> File<T, P>
172where
173    P: AsRef<Path>,
174{
175    pub fn into_inner(self) -> T {
176        self.inner
177    }
178    pub fn new(path: P, inner: T) -> Result<Self, Error> {
179        match file_impl::load_fmt(&path) {
180            file_impl::LoadFormat::Unknown | file_impl::LoadFormat::Any => {
181                Err(Error::UnknownFileFormat)
182            }
183            file_impl::LoadFormat::Format(fmt) => Ok(Self { path, fmt, inner }),
184        }
185    }
186    pub fn with_fmt(path: P, fmt: Format) -> Result<Self, Error>
187    where
188        T: serde::de::DeserializeOwned,
189    {
190        let inner = file_impl::load(fmt, path.as_ref())?;
191        Ok(Self { path, fmt, inner })
192    }
193    pub fn load(mut self) -> Result<(), Error>
194    where
195        T: serde::de::DeserializeOwned,
196    {
197        self.inner = file_impl::load(self.fmt, self.path.as_ref())?;
198        Ok(())
199    }
200    pub fn to_string(&self) -> Result<String, Error>
201    where
202        T: serde::Serialize,
203    {
204        let buf = self.fmt.serialize(&self.inner)?;
205        Ok(buf)
206    }
207    pub fn save(&self) -> Result<(), Error>
208    where
209        T: serde::Serialize,
210    {
211        let buf = self.to_string()?;
212        let parent = self.path.as_ref().parent().ok_or(Error::InvalidPath)?;
213        std::fs::create_dir_all(parent)?;
214        std::fs::write(self.path.as_ref(), buf)?;
215        Ok(())
216    }
217}
218
219pub trait XCfg {
220    fn with_format<P: AsRef<Path>>(path: P, fmt: Format) -> Result<File<Self, P>, Error>
221    where
222        Self: serde::de::DeserializeOwned,
223    {
224        File::with_fmt(path, fmt)
225    }
226    /// # Example
227    ///
228    /// ```rust
229    /// use serde::{Deserialize, Serialize};
230    /// use xcfg::XCfg;
231    /// #[derive(XCfg, Serialize, Deserialize, PartialEq, Debug, Clone)]
232    /// pub struct Test {
233    ///     a: i32,
234    ///     b: Vec<i32>,
235    ///     sub: SubTest,
236    /// }
237    ///
238    /// #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
239    /// pub struct SubTest {
240    ///     c: Vec<String>,
241    /// }
242    ///
243    /// let test = Test {
244    ///     a: 1,
245    ///     b: vec![0, 1, 2],
246    ///     sub: SubTest {
247    ///         c: vec!["ab".to_string(), "cd".to_string()],
248    ///     },
249    /// };
250    /// let path = "./test.toml";
251    /// test.save(path).unwrap();
252    /// assert_eq!(Test::load(path).unwrap().into_inner(), test);
253    /// std::fs::remove_file(path).unwrap();
254    fn load<P: AsRef<Path>>(path: P) -> Result<File<Self, PathBuf>, Error>
255    where
256        Self: serde::de::DeserializeOwned,
257    {
258        use file_impl::LoadFormat;
259        let inner = match file_impl::load_fmt(&path) {
260            LoadFormat::Any => File::any_load(path)?,
261            LoadFormat::Unknown => {
262                return Err(Error::UnknownFileFormat);
263            }
264            LoadFormat::Format(fmt) => {
265                let inner = file_impl::load(fmt, path.as_ref())?;
266                let path = path.as_ref().to_path_buf();
267                File { path, fmt, inner }
268            }
269        };
270        Ok(inner)
271    }
272    /// # Example
273    ///
274    /// ```rust
275    /// use serde::{Deserialize, Serialize};
276    /// use xcfg::XCfg;
277    /// #[derive(XCfg, Serialize, Deserialize, PartialEq, Debug, Clone)]
278    /// pub struct Test {
279    ///     a: i32,
280    ///     b: Vec<i32>,
281    ///     sub: SubTest,
282    /// }
283    /// impl Default for Test {
284    ///     fn default() -> Self {
285    ///         Self {
286    ///             a: 0,
287    ///             b: vec![],
288    ///             sub: SubTest::default(),
289    ///         }
290    ///     }
291    /// }
292    ///
293    /// #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
294    /// pub struct SubTest {
295    ///     c: Vec<String>,
296    /// }
297    /// impl Default for SubTest {
298    ///     fn default() -> Self {
299    ///         Self { c: vec![] }
300    ///     }
301    /// }
302    ///
303    /// let test = Test {
304    ///     a: 1,
305    ///     b: vec![0, 1, 2],
306    ///     sub: SubTest {
307    ///         c: vec!["ab".to_string(), "cd".to_string()],
308    ///     },
309    /// };
310    /// let path = "./test.toml";
311    /// let mut f = Test::load_or_default(path).unwrap();
312    /// assert_eq!(f.inner, Test::default());
313    /// f.inner = test.clone();
314    /// f.save().unwrap();
315    /// assert_eq!(Test::load(path).unwrap().into_inner(), test);
316    /// std::fs::remove_file(path).unwrap();
317    fn load_or_default<P: AsRef<Path>>(path: P) -> Result<File<Self, P>, Error>
318    where
319        Self: Default + serde::de::DeserializeOwned,
320    {
321        use file_impl::LoadFormat;
322        let inner = match file_impl::load_fmt(&path) {
323            LoadFormat::Format(fmt) => {
324                let inner = file_impl::load(fmt, path.as_ref()).unwrap_or_default();
325                File { path, fmt, inner }
326            }
327            _ => {
328                return Err(Error::UnknownFileFormat);
329            }
330        };
331        Ok(inner)
332    }
333    /// # Example
334    ///
335    /// ```rust
336    /// use serde::{Deserialize, Serialize};
337    /// use xcfg::XCfg;
338    /// #[derive(XCfg, Serialize, Deserialize, PartialEq, Debug, Clone)]
339    /// pub struct Test {
340    ///     a: i32,
341    ///     b: Vec<i32>,
342    ///     sub: SubTest,
343    /// }
344    ///
345    /// #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
346    /// pub struct SubTest {
347    ///     c: Vec<String>,
348    /// }
349    ///
350    /// let test = Test {
351    ///     a: 1,
352    ///     b: vec![0, 1, 2],
353    ///     sub: SubTest {
354    ///         c: vec!["ab".to_string(), "cd".to_string()],
355    ///     },
356    /// };
357    /// let path = "./test.toml";
358    /// test.save(path).unwrap();
359    /// std::fs::remove_file(path).unwrap();
360    fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error>
361    where
362        Self: serde::Serialize,
363    {
364        File::new(path, self)?.save()
365    }
366    /// # Example
367    /// ```rust
368    /// use serde::{Deserialize, Serialize};
369    /// use xcfg::{XCfg, Format};
370    /// #[derive(XCfg, Serialize, Deserialize, PartialEq, Debug, Clone)]
371    /// pub struct Test {
372    ///     a: i32,
373    ///     b: Vec<i32>,
374    ///     sub: SubTest,
375    /// }
376    /// #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
377    /// pub struct SubTest {
378    ///     c: Vec<String>,
379    /// }
380    /// let test = Test {
381    ///     a: 1,
382    ///     b: vec![0, 1, 2],
383    ///     sub: SubTest {
384    ///         c: vec!["ab".to_string(), "cd".to_string()],
385    ///    },
386    /// };
387    /// let right = r#"a = 1
388    /// b = [0, 1, 2]
389    ///
390    /// [sub]
391    /// c = ["ab", "cd"]
392    /// "#;
393    /// assert_eq!(test.fmt_to_string(Format::Toml).unwrap(), right);
394    fn fmt_to_string(&self, fmt: Format) -> Result<String, Error>
395    where
396        Self: serde::Serialize,
397    {
398        fmt.serialize(&self)
399    }
400}