1use crate::{
4 BasicComposer, CommentPreservingComposer, CommentedValue, Composer, Error, Limits, Position,
5 Result, Value,
6};
7
8pub trait Constructor {
10 fn construct(&mut self) -> Result<Option<Value>>;
12
13 fn check_data(&self) -> bool;
15
16 fn reset(&mut self);
18}
19
20pub trait CommentPreservingConstructor {
22 fn construct_commented(&mut self) -> Result<Option<CommentedValue>>;
24
25 fn check_data(&self) -> bool;
27
28 fn reset(&mut self);
30}
31
32#[derive(Debug)]
34pub struct SafeConstructor {
35 composer: BasicComposer,
36 position: Position,
37 limits: Limits,
38}
39
40impl SafeConstructor {
41 pub fn new(input: String) -> Self {
43 Self::with_limits(input, Limits::default())
44 }
45
46 pub fn with_limits(input: String, limits: Limits) -> Self {
48 let composer = BasicComposer::new_eager_with_limits(input, limits.clone());
50 let position = Position::start();
51
52 Self {
53 composer,
54 position,
55 limits,
56 }
57 }
58
59 pub fn from_composer(composer: BasicComposer) -> Self {
61 let position = Position::start();
62 let limits = Limits::default();
63
64 Self {
65 composer,
66 position,
67 limits,
68 }
69 }
70
71 pub fn from_composer_with_limits(composer: BasicComposer, limits: Limits) -> Self {
73 let position = Position::start();
74
75 Self {
76 composer,
77 position,
78 limits,
79 }
80 }
81
82 fn validate_value(&self, value: Value) -> Result<Value> {
84 match value {
85 Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::String(_) => {
87 Ok(value)
88 }
89
90 Value::Sequence(seq) => {
92 if seq.len() > self.limits.max_collection_size {
94 return Err(Error::limit_exceeded(format!(
95 "Sequence size {} exceeds max_collection_size limit of {}",
96 seq.len(),
97 self.limits.max_collection_size
98 )));
99 }
100 let mut safe_seq = Vec::with_capacity(seq.len());
101 for item in seq {
102 safe_seq.push(self.validate_value(item)?);
103 }
104 Ok(Value::Sequence(safe_seq))
105 }
106
107 Value::Mapping(map) => {
109 if map.len() > self.limits.max_collection_size {
111 return Err(Error::limit_exceeded(format!(
112 "Mapping size {} exceeds max_collection_size limit of {}",
113 map.len(),
114 self.limits.max_collection_size
115 )));
116 }
117 let mut safe_map = indexmap::IndexMap::new();
118 for (key, val) in map {
119 let safe_key = self.validate_value(key)?;
120 let safe_val = self.validate_value(val)?;
121 safe_map.insert(safe_key, safe_val);
122 }
123 Ok(Value::Mapping(safe_map))
124 }
125 }
126 }
127
128 fn apply_safety_rules(&self, value: Value) -> Result<Value> {
130 match value {
131 Value::String(ref s) if s.len() > self.limits.max_string_length => {
133 Err(Error::limit_exceeded(format!(
134 "String too long: {} bytes (max: {})",
135 s.len(),
136 self.limits.max_string_length
137 )))
138 }
139
140 Value::Sequence(ref seq) if seq.len() > self.limits.max_collection_size => {
142 Err(Error::limit_exceeded(format!(
143 "Sequence too long: {} elements (max: {})",
144 seq.len(),
145 self.limits.max_collection_size
146 )))
147 }
148
149 Value::Mapping(ref map) if map.len() > self.limits.max_collection_size => {
151 Err(Error::limit_exceeded(format!(
152 "Mapping too large: {} entries (max: {})",
153 map.len(),
154 self.limits.max_collection_size
155 )))
156 }
157
158 Value::Sequence(seq) => {
160 let mut safe_seq = Vec::with_capacity(seq.len());
161 for item in seq {
162 safe_seq.push(self.apply_safety_rules(item)?);
163 }
164 Ok(Value::Sequence(safe_seq))
165 }
166
167 Value::Mapping(map) => {
168 let mut safe_map = indexmap::IndexMap::new();
169 for (key, val) in map {
170 let safe_key = self.apply_safety_rules(key)?;
171 let safe_val = self.apply_safety_rules(val)?;
172 safe_map.insert(safe_key, safe_val);
173 }
174 Ok(Value::Mapping(safe_map))
175 }
176
177 _ => Ok(value),
179 }
180 }
181}
182
183impl Default for SafeConstructor {
184 fn default() -> Self {
185 Self::new(String::new())
186 }
187}
188
189impl Constructor for SafeConstructor {
190 fn construct(&mut self) -> Result<Option<Value>> {
191 let document = match self.composer.compose_document()? {
193 Some(doc) => doc,
194 None => return Ok(None),
195 };
196
197 let validated = self.validate_value(document)?;
199 let safe_value = self.apply_safety_rules(validated)?;
200
201 Ok(Some(safe_value))
202 }
203
204 fn check_data(&self) -> bool {
205 self.composer.check_document()
206 }
207
208 fn reset(&mut self) {
209 self.composer.reset();
210 self.position = Position::start();
211 }
212}
213
214#[derive(Debug)]
216pub struct RoundTripConstructor {
217 composer: CommentPreservingComposer,
218 position: Position,
219 limits: Limits,
220}
221
222impl RoundTripConstructor {
223 pub fn new(input: String) -> Self {
225 Self::with_limits(input, Limits::default())
226 }
227
228 pub fn with_limits(input: String, limits: Limits) -> Self {
230 let composer = CommentPreservingComposer::with_limits(input, limits.clone());
232 let position = Position::start();
233
234 Self {
235 composer,
236 position,
237 limits,
238 }
239 }
240
241 fn parse_with_comments(&mut self) -> Result<Option<CommentedValue>> {
243 self.composer.compose_document()
245 }
246}
247
248impl CommentPreservingConstructor for RoundTripConstructor {
249 fn construct_commented(&mut self) -> Result<Option<CommentedValue>> {
250 self.parse_with_comments()
251 }
252
253 fn check_data(&self) -> bool {
254 true
257 }
258
259 fn reset(&mut self) {
260 self.position = Position::start();
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Value::String` built from a string slice.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    /// Runs a fresh `SafeConstructor` over `input` and returns the constructed
    /// value, panicking if construction fails or yields nothing.
    fn construct_one(input: &str) -> Value {
        SafeConstructor::new(input.to_string())
            .construct()
            .expect("construction should succeed")
            .expect("input should yield a document")
    }

    #[test]
    fn test_safe_scalar_construction() {
        assert_eq!(construct_one("42"), Value::Int(42));
    }

    #[test]
    fn test_safe_sequence_construction() {
        let expected = Value::Sequence(vec![Value::Int(1), Value::Int(2), Value::Int(3)]);
        assert_eq!(construct_one("[1, 2, 3]"), expected);
    }

    #[test]
    fn test_safe_mapping_construction() {
        let mut expected_map = indexmap::IndexMap::new();
        expected_map.insert(text("key"), text("value"));

        assert_eq!(construct_one("{'key': 'value'}"), Value::Mapping(expected_map));
    }

    /// Every structural mismatch panics, so the test cannot pass without
    /// reaching the assertions on the nested user record.
    #[test]
    fn test_nested_construction() {
        let map = match construct_one("{'users': [{'name': 'Alice', 'age': 30}]}") {
            Value::Mapping(map) => map,
            other => panic!("Expected mapping, got {:?}", other),
        };

        let users = match map.get(&text("users")) {
            Some(Value::Sequence(users)) => users,
            other => panic!("Expected `users` to be a sequence, got {:?}", other),
        };
        assert_eq!(users.len(), 1);

        let user = match &users[0] {
            Value::Mapping(user) => user,
            other => panic!("Expected user entry to be a mapping, got {:?}", other),
        };
        assert_eq!(user.get(&text("name")), Some(&text("Alice")));
        assert_eq!(user.get(&text("age")), Some(&Value::Int(30)));
    }

    #[test]
    fn test_check_data() {
        let constructor = SafeConstructor::new("42".to_string());
        assert!(constructor.check_data());
    }

    #[test]
    fn test_multiple_types() {
        let yaml_content = "{'string': 'hello', 'int': 42, 'bool': true, 'null_key': null}";

        let map = match construct_one(yaml_content) {
            Value::Mapping(map) => map,
            other => panic!("Expected mapping, got {:?}", other),
        };

        assert_eq!(map.get(&text("string")), Some(&text("hello")));
        assert_eq!(map.get(&text("int")), Some(&Value::Int(42)));
        assert_eq!(map.get(&text("bool")), Some(&Value::Bool(true)));
        assert_eq!(map.get(&text("null_key")), Some(&Value::Null));
    }

    /// Feeds a 1000-byte string through the constructor.
    ///
    /// The test deliberately accepts either outcome: if the value is
    /// constructed, the string must arrive intact; if it is rejected, the
    /// error must carry a non-empty message.
    /// NOTE(review): which branch runs depends on `Limits::default()`, which
    /// is defined elsewhere; consider pinning the expected outcome.
    #[test]
    fn test_safety_limits() {
        let large_string = "a".repeat(1000);
        let yaml_content = format!("value: '{}'", large_string);
        let mut constructor = SafeConstructor::new(yaml_content);

        match constructor.construct() {
            Ok(Some(Value::Mapping(map))) => {
                if let Some(Value::String(s)) = map.get(&text("value")) {
                    assert_eq!(s.len(), 1000);
                }
            }
            Ok(_) => {}
            Err(error) => assert!(!error.to_string().is_empty()),
        }
    }

    #[test]
    fn test_boolean_values() {
        let test_cases = [
            ("true", true),
            ("false", false),
            ("yes", true),
            ("no", false),
            ("on", true),
            ("off", false),
        ];

        for &(input, expected) in test_cases.iter() {
            assert_eq!(
                construct_one(input),
                Value::Bool(expected),
                "Failed for input: {}",
                input
            );
        }
    }

    #[test]
    fn test_null_values() {
        for &input in ["null", "~"].iter() {
            assert_eq!(construct_one(input), Value::Null, "Failed for input: {}", input);
        }
    }
}