// yaml_include/transformer.rs
use anyhow::{anyhow, Context, Result};
use serde_yaml_ng::{
    value::{Tag, TaggedValue},
    Mapping, Value,
};

use std::{
    collections::HashSet,
    fmt,
    fs::{canonicalize, read_to_string},
    path::{Path, PathBuf},
};

use crate::helpers::{load_as_base64, load_yaml};
15
/// Expands `!include` tags in a YAML document that is rooted at a file on disk.
#[derive(Clone, Debug)]
pub struct Transformer {
    /// When set, a circular include is treated as a hard failure instead of
    /// being replaced with a `!circular` placeholder.
    error_on_circular: bool,
    /// The document this transformer expands; relative `!include` targets are
    /// resolved against its parent directory.
    root_path: PathBuf,
    /// Canonical paths of this document and of every document above it in the
    /// include chain; used to detect include cycles.
    seen_paths: HashSet<PathBuf>,
}
41
42impl Transformer {
43 pub fn new(root_path: PathBuf, strict: bool) -> Result<Self> {
57 Self::new_node(root_path, strict, None)
58 }
59
60 pub fn parse(&self) -> Value {
75 let file_path = self.root_path.clone();
76 let input = load_yaml(file_path).unwrap();
77
78 self.clone().recursive_process(input)
79 }
80
81 fn new_node(
82 root_path: PathBuf,
83 strict: bool,
84 seen_paths_option: Option<HashSet<PathBuf>>,
85 ) -> Result<Self> {
86 let mut seen_paths = seen_paths_option.unwrap_or_default();
87
88 let normalized_path = canonicalize(&root_path).unwrap();
89
90 if seen_paths.contains(&normalized_path) {
92 return Err(anyhow!(
93 "circular reference: {}",
94 &normalized_path.display()
95 ));
96 }
97
98 seen_paths.insert(normalized_path);
99
100 Ok(Transformer {
101 error_on_circular: strict,
102 root_path,
103 seen_paths,
104 })
105 }
106
107 fn recursive_process(self, input: Value) -> Value {
108 match input {
109 Value::Sequence(seq) => seq
110 .iter()
111 .map(|v| self.clone().recursive_process(v.clone()))
112 .collect(),
113 Value::Mapping(map) => Value::Mapping(Mapping::from_iter(
114 map.iter()
115 .map(|(k, v)| (k.clone(), self.clone().recursive_process(v.clone()))),
116 )),
117 Value::Tagged(tagged_value) => match tagged_value.tag.to_string().as_str() {
118 "!include" => {
119 let value = tagged_value.value.as_str().unwrap();
120 let file_path = PathBuf::from(value);
121
122 self.handle_include_extension(file_path)
123 }
124 _ => Value::Tagged(tagged_value),
125 },
126 _ => input,
128 }
129 }
130
131 fn handle_include_extension(&self, file_path: PathBuf) -> Value {
132 let normalized_file_path = self.process_path(&file_path);
133
134 let result = match normalized_file_path.extension() {
135 Some(os_str) => match os_str.to_str() {
136 Some("yaml") | Some("yml") | Some("json") => {
137 match Transformer::new_node(
138 normalized_file_path,
139 self.error_on_circular,
140 Some(self.seen_paths.clone()),
141 ) {
142 Ok(transformer) => transformer.parse(),
143 Err(e) => {
144 if self.error_on_circular {
145 panic!("{:?}", e);
147 }
148
149 return Value::Tagged(
150 TaggedValue {
151 tag: Tag::new("circular"),
152 value: Value::String(file_path.display().to_string()),
153 }
154 .into(),
155 );
156 }
157 }
158 }
159 Some("txt") | Some("markdown") | Some("md") => {
161 Value::String(read_to_string(normalized_file_path).unwrap())
162 }
163 None | Some(&_) => Value::Tagged(Box::new(TaggedValue {
165 tag: Tag::new("binary"),
166 value: Value::Mapping(Mapping::from_iter([
167 (
168 Value::String("filename".into()),
169 Value::String(
170 normalized_file_path
171 .file_name()
172 .unwrap()
173 .to_string_lossy()
174 .to_string(),
175 ),
176 ),
177 (
178 Value::String("base64".into()),
179 Value::String(load_as_base64(&normalized_file_path).unwrap()),
180 ),
181 ])),
182 })),
183 },
184 _ => panic!("{:?} path missing file extension", normalized_file_path),
185 };
186
187 result
188 }
189
190 fn process_path(&self, file_path: &PathBuf) -> PathBuf {
191 if file_path.is_absolute() {
192 return file_path.clone();
193 }
194 let joined = self.root_path.parent().unwrap().join(file_path);
195
196 if !joined.is_file() {
197 panic!("{:?} not found", joined);
198 }
199
200 canonicalize(joined).unwrap()
201 }
202}
203
204impl fmt::Display for Transformer {
205 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
206 write!(
207 f,
208 "{}",
209 serde_yaml_ng::to_string(&self.clone().parse()).unwrap()
210 )
211 }
212}
213
/// Golden-file test: expanding `data/root.yml` in non-strict mode must reproduce
/// `data/expected.yml` byte for byte.
#[test]
fn test_transformer() -> Result<()> {
    let expected = read_to_string("data/expected.yml").unwrap();

    let actual = Transformer::new(PathBuf::from("data/root.yml"), false)?.to_string();
    assert_eq!(expected, actual);

    Ok(())
}