yaml_include/
transformer.rs

1use anyhow::{anyhow, Result};
2use serde_yaml_ng::{
3    value::{Tag, TaggedValue},
4    Mapping, Value,
5};
6
7use std::{
8    collections::HashSet,
9    fmt,
10    fs::{canonicalize, read_to_string},
11    path::PathBuf,
12    str::FromStr,
13};
14
15use crate::helpers::{load_as_base64, load_yaml};
16
17struct FilePath {
18    path: PathBuf,
19    extension: Extension,
20}
21
/// Category of an included file, which decides how its content is inlined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Extension {
    /// Structured document (`yaml`, `yml`, `json`): parsed and processed recursively.
    Yaml,
    /// Plain text (`md`, `markdown`, `txt`): inlined as a string.
    Text,
    /// Anything else: inlined as base64 encoded data.
    Binary,
}
27
/// Reasons why the value of an `!include` tag cannot be turned into a [`FilePath`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseError {
    /// The mapping form has no string `path` entry.
    MissingPath,
    /// No extension is available: the mapping form has no string `extension` entry,
    /// or the path given in string form has no (UTF-8) extension.
    MissingExtension,
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::MissingPath => "missing `path` in `!include` value",
            Self::MissingExtension => "missing file extension in `!include` value",
        };

        f.write_str(message)
    }
}

// Implementing `Error` lets callers propagate a `ParseError` with `?` into
// `anyhow::Result` or `Box<dyn Error>`.
impl std::error::Error for ParseError {}
33
34impl FromStr for Extension {
35    type Err = ();
36
37    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
38        match s {
39            "yaml" | "yml" | "json" => Ok(Self::Yaml),
40            "md" | "markdown" | "txt" => Ok(Self::Text),
41            _ => Ok(Self::Binary),
42        }
43    }
44}
45
46impl TryFrom<Mapping> for FilePath {
47    type Error = ParseError;
48
49    fn try_from(value: Mapping) -> Result<Self, Self::Error> {
50        let path = value
51            .get("path")
52            .and_then(|value| value.as_str())
53            .ok_or(ParseError::MissingPath)?
54            .into();
55
56        let extension = Extension::from_str(
57            value
58                .get("extension")
59                .and_then(|value| value.as_str())
60                .ok_or(ParseError::MissingExtension)?,
61        )
62        .expect("Infaillible conversion");
63
64        Ok(Self { path, extension })
65    }
66}
67
68impl TryFrom<String> for FilePath {
69    type Error = ParseError;
70
71    fn try_from(value: String) -> Result<Self, Self::Error> {
72        let path: PathBuf = value.into();
73
74        let extension = Extension::from_str(
75            path.extension()
76                .and_then(|ext| ext.to_str())
77                .ok_or(ParseError::MissingExtension)?,
78        )
79        .expect("Infaillible conversion");
80
81        Ok(Self { path, extension })
82    }
83}
84
/// Processes a YAML document, expanding the files referenced by `!include <path>` tags.
///
/// ## Features
///
/// - includes and recursively parses `yaml` (and `json`) files
/// - includes `markdown` and `txt` files as text
/// - includes any other type as `base64` encoded binary data
/// - optionally handles circular references gracefully with a `!circular` tag
///
/// ## Example
/// ```
/// use std::path::PathBuf;
/// use yaml_include::Transformer;
///
/// let path = PathBuf::from("data/sample/main.yml");
/// if let Ok(transformer) = Transformer::new(path, false) {
///     println!("{}", transformer);
/// };
/// ```
#[derive(Debug, Clone)]
pub struct Transformer {
    /// When `true`, a circular include aborts instead of yielding a `!circular` tag.
    error_on_circular: bool,
    /// Path of the document handled by this transformer, as given by the caller.
    root_path: PathBuf,
    /// Canonical paths of the documents on the current include chain, used to detect
    /// circular references.
    seen_paths: HashSet<PathBuf>,
}
110
111impl Transformer {
112    /// Instance a transformer from a yaml file path.
113    ///
114    /// # Example:
115    ///
116    /// ```
117    /// use std::path::PathBuf;
118    /// use yaml_include::Transformer;
119    ///
120    /// let path = PathBuf::from("data/sample/main.yml");
121    /// if let Ok(transformer) = Transformer::new(path, false) {
122    ///     dbg!(transformer);
123    /// };
124    /// ```
125    pub fn new(root_path: PathBuf, strict: bool) -> Result<Self> {
126        Self::new_node(root_path, strict, None)
127    }
128
129    /// Parse yaml with recursively processing `!include`
130    ///
131    /// # Example:
132    ///
133    /// ```
134    /// use std::path::PathBuf;
135    /// use yaml_include::Transformer;
136    ///
137    /// let path = PathBuf::from("data/sample/main.yml");
138    /// if let Ok(transformer) = Transformer::new(path, false) {
139    ///     let parsed = transformer.parse();
140    ///     dbg!(parsed);
141    /// };
142    /// ```
143    pub fn parse(&self) -> Value {
144        let file_path = self.root_path.clone();
145        let input = load_yaml(file_path).unwrap();
146
147        self.clone().recursive_process(input)
148    }
149
150    fn new_node(
151        root_path: PathBuf,
152        strict: bool,
153        seen_paths_option: Option<HashSet<PathBuf>>,
154    ) -> Result<Self> {
155        let mut seen_paths = seen_paths_option.unwrap_or_default();
156
157        let normalized_path = canonicalize(&root_path).unwrap();
158
159        // Circular reference guard
160        if seen_paths.contains(&normalized_path) {
161            return Err(anyhow!(
162                "circular reference: {}",
163                &normalized_path.display()
164            ));
165        }
166
167        seen_paths.insert(normalized_path);
168
169        Ok(Transformer {
170            error_on_circular: strict,
171            root_path,
172            seen_paths,
173        })
174    }
175
176    fn recursive_process(self, input: Value) -> Value {
177        match input {
178            Value::Sequence(seq) => seq
179                .iter()
180                .map(|v| self.clone().recursive_process(v.clone()))
181                .collect(),
182            Value::Mapping(map) => Value::Mapping(Mapping::from_iter(
183                map.iter()
184                    .map(|(k, v)| (k.clone(), self.clone().recursive_process(v.clone()))),
185            )),
186            Value::Tagged(tagged_value) => match tagged_value.tag.to_string().as_str() {
187                "!include" => {
188                    let file_path: FilePath = match tagged_value.value {
189                        Value::String(path) => path.try_into().unwrap(),
190                        Value::Mapping(mapping) => mapping.try_into().unwrap(),
191                        _ => panic!("Unsupported Value"),
192                    };
193
194                    self.handle_include_extension(file_path)
195                }
196                _ => Value::Tagged(tagged_value),
197            },
198            // default no transform
199            _ => input,
200        }
201    }
202
203    fn handle_include_extension(&self, file_path: FilePath) -> Value {
204        let normalized_file_path = self.process_path(&file_path.path);
205
206        let result = match file_path.extension {
207            Extension::Yaml => {
208                match Transformer::new_node(
209                    normalized_file_path,
210                    self.error_on_circular,
211                    Some(self.seen_paths.clone()),
212                ) {
213                    Ok(transformer) => transformer.parse(),
214                    Err(e) => {
215                        if self.error_on_circular {
216                            panic!("{:?}", e);
217                        }
218
219                        return Value::Tagged(
220                            TaggedValue {
221                                tag: Tag::new("circular"),
222                                value: Value::String(file_path.path.display().to_string()),
223                            }
224                            .into(),
225                        );
226                    }
227                }
228            }
229            // inlining markdow and text files
230            Extension::Text => Value::String(read_to_string(normalized_file_path).unwrap()),
231            // inlining other include as binary files
232            Extension::Binary => Value::Tagged(Box::new(TaggedValue {
233                tag: Tag::new("binary"),
234                value: Value::Mapping(Mapping::from_iter([
235                    (
236                        Value::String("filename".into()),
237                        Value::String(
238                            normalized_file_path
239                                .file_name()
240                                .unwrap()
241                                .to_string_lossy()
242                                .to_string(),
243                        ),
244                    ),
245                    (
246                        Value::String("base64".into()),
247                        Value::String(load_as_base64(&normalized_file_path).unwrap()),
248                    ),
249                ])),
250            })),
251        };
252
253        result
254    }
255
256    fn process_path(&self, file_path: &PathBuf) -> PathBuf {
257        if file_path.is_absolute() {
258            return file_path.clone();
259        }
260        let joined = self.root_path.parent().unwrap().join(file_path);
261
262        if !joined.is_file() {
263            panic!("{:?} not found", joined);
264        }
265
266        canonicalize(joined).unwrap()
267    }
268}
269
270impl fmt::Display for Transformer {
271    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
272        write!(
273            f,
274            "{}",
275            serde_yaml_ng::to_string(&self.clone().parse()).unwrap()
276        )
277    }
278}
279
/// Renders the `data/root.yml` fixture in non-strict mode and compares the output
/// with the content of `data/expected.yml`.
#[test]
fn test_transformer() -> Result<()> {
    let expected = read_to_string("data/expected.yml").unwrap();

    let root = PathBuf::from("data/root.yml");
    let actual = Transformer::new(root, false)?.to_string();

    assert_eq!(expected, actual);

    Ok(())
}