// yaml_include/transformer.rs

use std::{
    collections::HashSet,
    fmt,
    fs::{canonicalize, read_to_string},
    path::{Path, PathBuf},
};

use anyhow::{anyhow, Context, Result};
use serde_yaml_ng::{
    value::{Tag, TaggedValue},
    Mapping, Value,
};

use crate::helpers::{load_as_base64, load_yaml};

/// Processes a YAML document, expanding `!include <path>` tags.
///
/// ## Features
///
/// - include and recursively parse `yaml`/`yml` (and `json`) files
/// - include `markdown`/`md` and `txt` text files as strings
/// - include any other file type as `base64`-encoded binary data
/// - optionally handle circular references gracefully with a `!circular` tag
///
/// ## Example
/// ```
/// use std::path::PathBuf;
/// use yaml_include::Transformer;
///
/// let path = PathBuf::from("data/sample/main.yml");
/// if let Ok(transformer) = Transformer::new(path, false) {
///     println!("{}", transformer);
/// };
/// ```
#[derive(Debug, Clone)]
pub struct Transformer {
    /// When `true`, a circular include panics; otherwise it is replaced by a
    /// `!circular` tagged value.
    error_on_circular: bool,
    /// File processed by this transformer; relative includes are resolved
    /// against its parent directory.
    root_path: PathBuf,
    /// Canonical paths of this file and of every ancestor include, used for
    /// circular reference detection.
    seen_paths: HashSet<PathBuf>,
}

42impl Transformer {
43    /// Instance a transformer from a yaml file path.
44    ///
45    /// # Example:
46    ///
47    /// ```
48    /// use std::path::PathBuf;
49    /// use yaml_include::Transformer;
50    ///
51    /// let path = PathBuf::from("data/sample/main.yml");
52    /// if let Ok(transformer) = Transformer::new(path, false) {
53    ///     dbg!(transformer);
54    /// };
55    /// ```
56    pub fn new(root_path: PathBuf, strict: bool) -> Result<Self> {
57        Self::new_node(root_path, strict, None)
58    }
59
60    /// Parse yaml with recursively processing `!include`
61    ///
62    /// # Example:
63    ///
64    /// ```
65    /// use std::path::PathBuf;
66    /// use yaml_include::Transformer;
67    ///
68    /// let path = PathBuf::from("data/sample/main.yml");
69    /// if let Ok(transformer) = Transformer::new(path, false) {
70    ///     let parsed = transformer.parse();
71    ///     dbg!(parsed);
72    /// };
73    /// ```
74    pub fn parse(&self) -> Value {
75        let file_path = self.root_path.clone();
76        let input = load_yaml(file_path).unwrap();
77
78        self.clone().recursive_process(input)
79    }
80
81    fn new_node(
82        root_path: PathBuf,
83        strict: bool,
84        seen_paths_option: Option<HashSet<PathBuf>>,
85    ) -> Result<Self> {
86        let mut seen_paths = seen_paths_option.unwrap_or_default();
87
88        let normalized_path = canonicalize(&root_path).unwrap();
89
90        // Circular reference guard
91        if seen_paths.contains(&normalized_path) {
92            return Err(anyhow!(
93                "circular reference: {}",
94                &normalized_path.display()
95            ));
96        }
97
98        seen_paths.insert(normalized_path);
99
100        Ok(Transformer {
101            error_on_circular: strict,
102            root_path,
103            seen_paths,
104        })
105    }
106
107    fn recursive_process(self, input: Value) -> Value {
108        match input {
109            Value::Sequence(seq) => seq
110                .iter()
111                .map(|v| self.clone().recursive_process(v.clone()))
112                .collect(),
113            Value::Mapping(map) => Value::Mapping(Mapping::from_iter(
114                map.iter()
115                    .map(|(k, v)| (k.clone(), self.clone().recursive_process(v.clone()))),
116            )),
117            Value::Tagged(tagged_value) => match tagged_value.tag.to_string().as_str() {
118                "!include" => {
119                    let value = tagged_value.value.as_str().unwrap();
120                    let file_path = PathBuf::from(value);
121
122                    self.handle_include_extension(file_path)
123                }
124                _ => Value::Tagged(tagged_value),
125            },
126            // default no transform
127            _ => input,
128        }
129    }
130
131    fn handle_include_extension(&self, file_path: PathBuf) -> Value {
132        let normalized_file_path = self.process_path(&file_path);
133
134        let result = match normalized_file_path.extension() {
135            Some(os_str) => match os_str.to_str() {
136                Some("yaml") | Some("yml") | Some("json") => {
137                    match Transformer::new_node(
138                        normalized_file_path,
139                        self.error_on_circular,
140                        Some(self.seen_paths.clone()),
141                    ) {
142                        Ok(transformer) => transformer.parse(),
143                        Err(e) => {
144                            if self.error_on_circular {
145                                // TODO: probably something better to do than panic ?
146                                panic!("{:?}", e);
147                            }
148
149                            return Value::Tagged(
150                                TaggedValue {
151                                    tag: Tag::new("circular"),
152                                    value: Value::String(file_path.display().to_string()),
153                                }
154                                .into(),
155                            );
156                        }
157                    }
158                }
159                // inlining markdow and text files
160                Some("txt") | Some("markdown") | Some("md") => {
161                    Value::String(read_to_string(normalized_file_path).unwrap())
162                }
163                // inlining other include as binary files
164                None | Some(&_) => Value::Tagged(Box::new(TaggedValue {
165                    tag: Tag::new("binary"),
166                    value: Value::Mapping(Mapping::from_iter([
167                        (
168                            Value::String("filename".into()),
169                            Value::String(
170                                normalized_file_path
171                                    .file_name()
172                                    .unwrap()
173                                    .to_string_lossy()
174                                    .to_string(),
175                            ),
176                        ),
177                        (
178                            Value::String("base64".into()),
179                            Value::String(load_as_base64(&normalized_file_path).unwrap()),
180                        ),
181                    ])),
182                })),
183            },
184            _ => panic!("{:?} path missing file extension", normalized_file_path),
185        };
186
187        result
188    }
189
190    fn process_path(&self, file_path: &PathBuf) -> PathBuf {
191        if file_path.is_absolute() {
192            return file_path.clone();
193        }
194        let joined = self.root_path.parent().unwrap().join(file_path);
195
196        if !joined.is_file() {
197            panic!("{:?} not found", joined);
198        }
199
200        canonicalize(joined).unwrap()
201    }
202}
203
204impl fmt::Display for Transformer {
205    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
206        write!(
207            f,
208            "{}",
209            serde_yaml_ng::to_string(&self.clone().parse()).unwrap()
210        )
211    }
212}
213
/// End-to-end check: processing `data/root.yml` must reproduce `data/expected.yml`.
#[test]
fn test_transformer() -> Result<()> {
    let expected = read_to_string("data/expected.yml")?;
    let actual = Transformer::new(PathBuf::from("data/root.yml"), false)?.to_string();

    assert_eq!(expected, actual);

    Ok(())
}