Skip to main content

sanitize_engine/processor/
yaml_proc.rs

1//! YAML structured processor.
2//!
3//! Parses YAML input, walks the value tree, replaces matched field
4//! values, and serializes back. Structure is preserved but minor
5//! formatting differences are possible (serde_yaml normalizes some
6//! whitespace).
7//!
8//! Key paths use the same dot-separated convention as the JSON processor.
9
10use crate::error::{Result, SanitizeError};
11use crate::processor::limits::{DEFAULT_DEPTH, YAML_INPUT_SIZE, YAML_NODE_COUNT};
12use crate::processor::{walk_tree, FileTypeProfile, Processor, TreeNode};
13use crate::store::MappingStore;
14use serde_yaml_ng::Value;
15
/// Structured processor for YAML files.
///
/// This is a field-less unit type, so it carries no state of its own. The
/// common traits are derived so the processor can be logged with `{:?}`,
/// copied freely, and built through `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct YamlProcessor;
18
19impl Processor for YamlProcessor {
20    fn name(&self) -> &'static str {
21        "yaml"
22    }
23
24    fn can_handle(&self, content: &[u8], profile: &FileTypeProfile) -> bool {
25        if profile.processor == "yaml" {
26            return true;
27        }
28        // Heuristic: starts with `---` or a YAML-ish key: value.
29        let text = String::from_utf8_lossy(content);
30        let trimmed = text.trim_start();
31        trimmed.starts_with("---")
32            || trimmed.starts_with("- ")
33            || trimmed.starts_with('{')
34            || trimmed.contains(": ")
35    }
36
37    fn process(
38        &self,
39        content: &[u8],
40        profile: &FileTypeProfile,
41        store: &MappingStore,
42    ) -> Result<Vec<u8>> {
43        // Guard against alias bombs: reject inputs above YAML_INPUT_SIZE.
44        if content.len() > YAML_INPUT_SIZE {
45            return Err(SanitizeError::InputTooLarge {
46                size: content.len(),
47                limit: YAML_INPUT_SIZE,
48            });
49        }
50
51        let text = std::str::from_utf8(content).map_err(|e| SanitizeError::ParseError {
52            format: "YAML".into(),
53            message: format!("invalid UTF-8: {}", e),
54        })?;
55
56        let mut value: Value =
57            serde_yaml_ng::from_str(text).map_err(|e| SanitizeError::ParseError {
58                format: "YAML".into(),
59                message: format!("YAML parse error: {}", e),
60            })?;
61
62        // F-06 fix: count total nodes in the deserialized tree to detect
63        // alias bombs. After expansion, aliased subtrees become
64        // independent copies in memory, so the node count reflects the
65        // true memory footprint.
66        let node_count = count_yaml_nodes(&value);
67        if node_count > YAML_NODE_COUNT {
68            return Err(SanitizeError::InputTooLarge {
69                size: node_count,
70                limit: YAML_NODE_COUNT,
71            });
72        }
73
74        walk_yaml(&mut value, "", profile, store, 0)?;
75
76        let output = serde_yaml_ng::to_string(&value)
77            .map_err(|e| SanitizeError::IoError(format!("YAML serialize error: {}", e)))?;
78
79        Ok(output.into_bytes())
80    }
81}
82
83/// Count the total number of nodes in a YAML value tree (F-06 fix).
84/// Used to detect alias bombs that produce a small source document
85/// but expand to millions of nodes after alias resolution.
86fn count_yaml_nodes(value: &Value) -> usize {
87    count_yaml_nodes_inner(value, 0)
88}
89
90/// Inner recursive counter with depth guard to prevent stack overflow
91/// on deeply nested YAML before `walk_yaml`'s depth check is reached.
92fn count_yaml_nodes_inner(value: &Value, depth: usize) -> usize {
93    if depth > DEFAULT_DEPTH {
94        return 1; // Stop counting deeper; walk_yaml will catch depth violations
95    }
96    match value {
97        Value::Mapping(map) => {
98            1 + map
99                .iter()
100                .map(|(k, v)| {
101                    count_yaml_nodes_inner(k, depth + 1) + count_yaml_nodes_inner(v, depth + 1)
102                })
103                .sum::<usize>()
104        }
105        Value::Sequence(seq) => {
106            1 + seq
107                .iter()
108                .map(|v| count_yaml_nodes_inner(v, depth + 1))
109                .sum::<usize>()
110        }
111        Value::Tagged(tagged) => 1 + count_yaml_nodes_inner(&tagged.value, depth + 1),
112        _ => 1, // Null, Bool, Number, String
113    }
114}
115
116impl TreeNode for Value {
117    fn for_each_map_entry<F>(&mut self, mut f: F) -> Result<()>
118    where
119        F: FnMut(&str, &mut Self) -> Result<()>,
120    {
121        if let Self::Mapping(map) = self {
122            let keys: Vec<Self> = map.keys().cloned().collect();
123            for key in keys {
124                let key_str = yaml_key_to_string(&key);
125                if let Some(v) = map.get_mut(&key) {
126                    f(&key_str, v)?;
127                }
128            }
129        }
130        Ok(())
131    }
132
133    fn for_each_seq_item<F>(&mut self, mut f: F) -> Result<()>
134    where
135        F: FnMut(&mut Self) -> Result<()>,
136    {
137        if let Self::Sequence(seq) = self {
138            for item in seq.iter_mut() {
139                f(item)?;
140            }
141        }
142        Ok(())
143    }
144
145    fn as_str_mut(&mut self) -> Option<&mut String> {
146        if let Self::String(s) = self {
147            Some(s)
148        } else {
149            None
150        }
151    }
152
153    fn is_scalar(&self) -> bool {
154        matches!(self, Self::Number(_) | Self::Bool(_))
155    }
156
157    fn scalar_to_string(&self) -> String {
158        yaml_scalar_to_string(self)
159    }
160
161    fn set_string(&mut self, s: String) {
162        *self = Self::String(s);
163    }
164}
165
166/// Recursively walk a YAML value tree, replacing matched field values.
167fn walk_yaml(
168    value: &mut Value,
169    prefix: &str,
170    profile: &FileTypeProfile,
171    store: &MappingStore,
172    depth: usize,
173) -> Result<()> {
174    walk_tree(value, prefix, profile, store, depth, "YAML")
175}
176
177fn yaml_key_to_string(key: &Value) -> String {
178    match key {
179        Value::String(s) => s.clone(),
180        Value::Number(n) => n.to_string(),
181        Value::Bool(b) => b.to_string(),
182        _ => format!("{:?}", key),
183    }
184}
185
186fn yaml_scalar_to_string(v: &Value) -> String {
187    match v {
188        Value::String(s) => s.clone(),
189        Value::Number(n) => n.to_string(),
190        Value::Bool(b) => b.to_string(),
191        _ => String::new(),
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::category::Category;
    use crate::generator::HmacGenerator;
    use crate::processor::profile::FieldRule;
    use std::sync::Arc;

    /// Builds a mapping store backed by an HMAC generator with a fixed key.
    fn make_store() -> MappingStore {
        let generator = Arc::new(HmacGenerator::new([42u8; 32]));
        MappingStore::new(generator, None)
    }

    /// Runs `YamlProcessor` over `content` with a fresh store and returns
    /// the output as a string, panicking on any failure.
    fn sanitize(content: &[u8], profile: &FileTypeProfile) -> String {
        let store = make_store();
        let bytes = YamlProcessor.process(content, profile, &store).unwrap();
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn basic_yaml_replacement() {
        let rules = vec![
            FieldRule::new("database.password").with_category(Category::Custom("pw".into())),
            FieldRule::new("database.host").with_category(Category::Hostname),
        ];
        let profile = FileTypeProfile::new("yaml", rules);

        let out = sanitize(
            b"database:\n  host: db.corp.com\n  password: s3cret\nport: 5432\n",
            &profile,
        );

        // Both targeted values must be gone from the output.
        for original in ["s3cret", "db.corp.com"] {
            assert!(!out.contains(original));
        }
        // The untargeted port value must survive.
        assert!(out.contains("5432"));
    }

    #[test]
    fn yaml_sequence_traversal() {
        let rules = vec![FieldRule::new("users.email").with_category(Category::Email)];
        let profile = FileTypeProfile::new("yaml", rules);

        let out = sanitize(
            b"users:\n  - email: a@b.com\n  - email: c@d.com\n",
            &profile,
        );

        // Every item of the sequence must have its email replaced.
        for original in ["a@b.com", "c@d.com"] {
            assert!(!out.contains(original));
        }
    }
}
248}