Skip to main content

yaml_edit/
visitor.rs

1//! Visitor pattern for traversing YAML AST nodes.
2//!
3//! Allows traversing and processing YAML structures without modifying node types.
4//! Useful for validation, transformation, collection, or analysis.
5//!
6//! # Example
7//!
8//! ```rust
9//! use yaml_edit::visitor::{YamlVisitor, YamlAccept};
10//! use yaml_edit::{Document, Scalar, Mapping, Sequence};
11//! use std::str::FromStr;
12//!
13//! struct ScalarCounter {
14//!     count: usize,
15//! }
16//!
17//! impl YamlVisitor for ScalarCounter {
18//!     fn visit_scalar(&mut self, _scalar: &Scalar) {
19//!         self.count += 1;
20//!     }
21//! }
22//!
23//! let doc = Document::from_str("key: value\nlist:\n  - item1\n  - item2").unwrap();
24//! let mut counter = ScalarCounter { count: 0 };
25//! doc.accept(&mut counter);
26//! assert_eq!(counter.count, 5); // "key", "value", "list", "item1", "item2"
27//! ```
28//!
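//! The default `visit_mapping` and `visit_sequence` implementations recurse into
//! child nodes through `walk_mapping` and `walk_sequence`. Override a `visit_*`
//! method for custom logic, and call the matching `walk_*` method from it if the
//! children should still be visited.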
31
32use crate::yaml::{Document, Mapping, Scalar, Sequence, YamlFile};
33use rowan::ast::AstNode;
34
35/// Trait for implementing the visitor pattern on YAML nodes.
36pub trait YamlVisitor {
37    /// Visit a YAML root node
38    fn visit_yaml(&mut self, yaml: &YamlFile) {
39        // Default implementation visits all documents
40        for doc in yaml.documents() {
41            self.visit_document(&doc);
42        }
43    }
44
45    /// Visit a document node
46    fn visit_document(&mut self, document: &Document) {
47        // Default implementation visits the document's content
48        if let Some(mapping) = document.as_mapping() {
49            self.visit_mapping(&mapping);
50        } else if let Some(sequence) = document.as_sequence() {
51            self.visit_sequence(&sequence);
52        } else if let Some(scalar) = document.as_scalar() {
53            self.visit_scalar(&scalar);
54        }
55    }
56
57    /// Visit a scalar node
58    fn visit_scalar(&mut self, _scalar: &Scalar) {}
59
60    /// Visit a mapping node
61    ///
62    /// The default implementation traverses all key-value pairs in the mapping,
63    /// visiting both keys and values recursively. Override this method if you need
64    /// custom logic when encountering mappings.
65    fn visit_mapping(&mut self, mapping: &Mapping) {
66        self.walk_mapping(mapping);
67    }
68
69    /// Visit a sequence node
70    ///
71    /// The default implementation traverses all items in the sequence recursively.
72    /// Override this method if you need custom logic when encountering sequences.
73    fn visit_sequence(&mut self, sequence: &Sequence) {
74        self.walk_sequence(sequence);
75    }
76
77    /// Traverse all key-value pairs in a mapping (helper for default traversal).
78    ///
79    /// This method is called by the default `visit_mapping` implementation.
80    /// You can call it explicitly if you override `visit_mapping` and want to
81    /// preserve the default traversal behavior.
82    fn walk_mapping(&mut self, mapping: &Mapping) {
83        use crate::yaml::{extract_mapping, extract_scalar, extract_sequence};
84
85        for (key_node, value_node) in mapping.pairs() {
86            // Visit key
87            if let Some(scalar) = extract_scalar(&key_node) {
88                self.visit_scalar(&scalar);
89            } else if let Some(sequence) = extract_sequence(&key_node) {
90                self.visit_sequence(&sequence);
91            } else if let Some(mapping) = extract_mapping(&key_node) {
92                self.visit_mapping(&mapping);
93            }
94
95            // Visit value
96            if let Some(scalar) = extract_scalar(&value_node) {
97                self.visit_scalar(&scalar);
98            } else if let Some(nested_mapping) = extract_mapping(&value_node) {
99                self.visit_mapping(&nested_mapping);
100            } else if let Some(nested_sequence) = extract_sequence(&value_node) {
101                self.visit_sequence(&nested_sequence);
102            }
103        }
104    }
105
106    /// Traverse all items in a sequence (helper for default traversal).
107    ///
108    /// This method is called by the default `visit_sequence` implementation.
109    /// You can call it explicitly if you override `visit_sequence` and want to
110    /// preserve the default traversal behavior.
111    fn walk_sequence(&mut self, sequence: &Sequence) {
112        for item in sequence.items() {
113            if let Some(scalar) = Scalar::cast(item.clone()) {
114                self.visit_scalar(&scalar);
115            } else if let Some(nested_mapping) = Mapping::cast(item.clone()) {
116                self.visit_mapping(&nested_mapping);
117            } else if let Some(nested_sequence) = Sequence::cast(item.clone()) {
118                self.visit_sequence(&nested_sequence);
119            }
120        }
121    }
122}
123
124/// Trait for nodes that can accept a visitor
125pub trait YamlAccept {
126    /// Accept a visitor for traversal
127    fn accept<V: YamlVisitor>(&self, visitor: &mut V);
128}
129
130impl YamlAccept for YamlFile {
131    fn accept<V: YamlVisitor>(&self, visitor: &mut V) {
132        visitor.visit_yaml(self);
133    }
134}
135
136impl YamlAccept for Document {
137    fn accept<V: YamlVisitor>(&self, visitor: &mut V) {
138        visitor.visit_document(self);
139    }
140}
141
142impl YamlAccept for Scalar {
143    fn accept<V: YamlVisitor>(&self, visitor: &mut V) {
144        visitor.visit_scalar(self);
145    }
146}
147
148impl YamlAccept for Mapping {
149    fn accept<V: YamlVisitor>(&self, visitor: &mut V) {
150        visitor.visit_mapping(self);
151    }
152}
153
154impl YamlAccept for Sequence {
155    fn accept<V: YamlVisitor>(&self, visitor: &mut V) {
156        visitor.visit_sequence(self);
157    }
158}
159
/// A visitor that collects the text of every scalar in a YAML tree.
///
/// Each scalar's `to_string()` form is appended to
/// [`ScalarCollector::scalars`] in the order the traversal reaches it.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ScalarCollector {
    /// The collected scalar values, in visitation order.
    pub scalars: Vec<String>,
}

impl ScalarCollector {
    /// Create a new, empty scalar collector.
    pub const fn new() -> Self {
        Self {
            scalars: Vec::new(),
        }
    }
}
180
181impl YamlVisitor for ScalarCollector {
182    fn visit_scalar(&mut self, scalar: &Scalar) {
183        self.scalars.push(scalar.to_string());
184    }
185    // Default traversal methods handle mapping and sequence traversal automatically
186}
187
/// A visitor that counts the documents, scalars, mappings and sequences it
/// visits.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NodeCounter {
    /// Number of document nodes encountered
    pub document_count: usize,
    /// Number of scalar nodes encountered
    pub scalar_count: usize,
    /// Number of mapping nodes encountered
    pub mapping_count: usize,
    /// Number of sequence nodes encountered
    pub sequence_count: usize,
}

impl NodeCounter {
    /// Create a new node counter with every count at zero.
    pub const fn new() -> Self {
        Self {
            document_count: 0,
            scalar_count: 0,
            mapping_count: 0,
            sequence_count: 0,
        }
    }

    /// Get the total node count: the sum of all four per-kind counts.
    pub const fn total(&self) -> usize {
        self.document_count + self.scalar_count + self.mapping_count + self.sequence_count
    }
}
222
223impl YamlVisitor for NodeCounter {
224    fn visit_document(&mut self, document: &Document) {
225        self.document_count += 1;
226        // Continue visiting children
227        if let Some(mapping) = document.as_mapping() {
228            self.visit_mapping(&mapping);
229        } else if let Some(sequence) = document.as_sequence() {
230            self.visit_sequence(&sequence);
231        } else if let Some(scalar) = document.as_scalar() {
232            self.visit_scalar(&scalar);
233        }
234    }
235
236    fn visit_scalar(&mut self, _scalar: &Scalar) {
237        self.scalar_count += 1;
238    }
239
240    fn visit_mapping(&mut self, mapping: &Mapping) {
241        self.mapping_count += 1;
242        // Call default traversal to visit children
243        self.walk_mapping(mapping);
244    }
245
246    fn visit_sequence(&mut self, sequence: &Sequence) {
247        self.sequence_count += 1;
248        // Call default traversal to visit children
249        self.walk_sequence(sequence);
250    }
251}
252
/// A visitor that applies a string transformation to every scalar it visits.
///
/// The visited nodes are only read, never modified. Each scalar's text is passed
/// to the transformation function, and the resulting `(original, transformed)`
/// pair is recorded for retrieval via [`ScalarTransformer::results`] or
/// [`ScalarTransformer::into_results`].
pub struct ScalarTransformer<F> {
    /// The user-supplied transformation. Trait bounds live on the `impl`
    /// blocks that need them, not on the struct.
    transform: F,
    /// Recorded `(original, transformed)` pairs, in visitation order.
    transformed: Vec<(String, String)>,
}

impl<F> ScalarTransformer<F>
where
    F: FnMut(&str) -> String,
{
    /// Create a new scalar transformer with the given transformation function.
    pub fn new(transform: F) -> Self {
        Self {
            transform,
            transformed: Vec::new(),
        }
    }
}

impl<F> ScalarTransformer<F> {
    /// Get the recorded `(original, transformed)` pairs.
    pub fn results(&self) -> &[(String, String)] {
        &self.transformed
    }

    /// Consume the transformer and return the recorded pairs without cloning.
    pub fn into_results(self) -> Vec<(String, String)> {
        self.transformed
    }
}

impl<F> std::fmt::Debug for ScalarTransformer<F> {
    // The closure is generally not `Debug`, so only the recorded pairs are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScalarTransformer")
            .field("transformed", &self.transformed)
            .finish_non_exhaustive()
    }
}
279
280impl<F> YamlVisitor for ScalarTransformer<F>
281where
282    F: FnMut(&str) -> String,
283{
284    fn visit_scalar(&mut self, scalar: &Scalar) {
285        let original = scalar.to_string();
286        let transformed = (self.transform)(&original);
287        self.transformed.push((original, transformed));
288    }
289
290    // Default traversal methods handle mapping and sequence traversal automatically
291}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::YamlFile;

    #[test]
    fn test_scalar_collector() {
        let yaml_text = r#"
name: John Doe
age: 30
address:
  street: 123 Main St
  city: New York
  country: USA
hobbies:
  - reading
  - coding
  - hiking
"#;

        let parsed = YamlFile::parse(yaml_text);
        let yaml = parsed.tree();

        let mut collector = ScalarCollector::new();
        yaml.accept(&mut collector);

        // Should collect every scalar (keys included) in traversal order.
        assert_eq!(
            collector.scalars,
            vec![
                "name",
                "John Doe",
                "age",
                "30",
                "address",
                "street",
                "123 Main St",
                "city",
                "New York",
                "country",
                "USA",
                "hobbies",
                "reading",
                "coding",
                "hiking",
            ]
        );
    }

    #[test]
    fn test_node_counter() {
        let yaml_text = r#"
name: John Doe
age: 30
address:
  street: 123 Main St
  city: New York
  country: USA
hobbies:
  - reading
  - coding
  - hiking
"#;

        let parsed = YamlFile::parse(yaml_text);
        let yaml = parsed.tree();

        let mut counter = NodeCounter::new();
        yaml.accept(&mut counter);

        // Should count all node types
        assert_eq!(counter.document_count, 1);
        assert_eq!(counter.scalar_count, 15);
        assert_eq!(counter.mapping_count, 2); // Root mapping and address mapping
        assert_eq!(counter.sequence_count, 1); // hobbies sequence

        // 1 document + 15 scalars + 2 mappings + 1 sequence. Pin the exact
        // value rather than re-deriving it from the fields `total()` sums.
        assert_eq!(counter.total(), 19);
    }

    #[test]
    fn test_scalar_transformer() {
        let yaml_text = r#"
name: john
city: new york
country: usa
"#;

        let parsed = YamlFile::parse(yaml_text);
        let yaml = parsed.tree();

        // Transform all scalars to uppercase
        let mut transformer = ScalarTransformer::new(|s: &str| s.to_uppercase());
        yaml.accept(&mut transformer);

        let results = transformer.results();

        // Check that transformations were applied
        assert_eq!(
            results,
            &[
                ("name".to_string(), "NAME".to_string()),
                ("john".to_string(), "JOHN".to_string()),
                ("city".to_string(), "CITY".to_string()),
                ("new york".to_string(), "NEW YORK".to_string()),
                ("country".to_string(), "COUNTRY".to_string()),
                ("usa".to_string(), "USA".to_string()),
            ]
        );
    }

    #[test]
    fn test_visitor_on_sequence() {
        let yaml_text = r#"
- item1
- item2
- nested:
    - subitem1
    - subitem2
"#;

        let parsed = YamlFile::parse(yaml_text);
        let yaml = parsed.tree();

        let mut collector = ScalarCollector::new();
        yaml.accept(&mut collector);

        // Should collect all scalars including nested ones
        assert_eq!(
            collector.scalars,
            vec!["item1", "item2", "nested", "subitem1", "subitem2"]
        );
    }

    #[test]
    fn test_visitor_on_empty_document() {
        let yaml_text = "";

        let parsed = YamlFile::parse(yaml_text);
        let yaml = parsed.tree();

        let mut counter = NodeCounter::new();
        yaml.accept(&mut counter);

        // Empty input: nothing at all may be counted, not even a document.
        assert_eq!(counter.total(), 0);
    }

    #[test]
    fn test_custom_visitor() {
        // Define a custom visitor that counts only keys in mappings
        struct KeyCounter {
            key_count: usize,
        }

        impl YamlVisitor for KeyCounter {
            fn visit_scalar(&mut self, _scalar: &crate::Scalar) {
                // Don't count scalars that are not keys
            }

            fn visit_mapping(&mut self, mapping: &crate::Mapping) {
                // Count keys in this mapping
                for (_key, value) in mapping.iter() {
                    self.key_count += 1;
                    // Recursively visit nested structures
                    if let Some(nested_mapping) = value.as_mapping() {
                        nested_mapping.accept(self);
                    } else if let Some(nested_sequence) = value.as_sequence() {
                        nested_sequence.accept(self);
                    }
                }
            }

            fn visit_sequence(&mut self, sequence: &crate::Sequence) {
                // Visit items in sequence to find nested mappings
                for value in sequence.values() {
                    if let Some(nested_mapping) = value.as_mapping() {
                        nested_mapping.accept(self);
                    } else if let Some(nested_sequence) = value.as_sequence() {
                        nested_sequence.accept(self);
                    }
                }
            }
        }

        let yaml_text = r#"
name: John
age: 30
address:
  street: 123 Main St
  city: New York
metadata:
  created: 2024-01-01
  updated: 2024-01-02
"#;

        let parsed = YamlFile::parse(yaml_text);
        let yaml = parsed.tree();

        let mut key_counter = KeyCounter { key_count: 0 };
        yaml.accept(&mut key_counter);

        // Should count: name, age, address, street, city, metadata, created, updated = 8 keys
        assert_eq!(key_counter.key_count, 8);
    }

    #[test]
    fn test_visitor_with_multiple_documents() {
        let yaml_text = r#"
---
doc1: value1
---
doc2: value2
---
doc3: value3
"#;

        let parsed = YamlFile::parse(yaml_text);
        let yaml = parsed.tree();

        let mut counter = NodeCounter::new();
        yaml.accept(&mut counter);

        // Should count 3 documents
        assert_eq!(counter.document_count, 3);
        assert!(counter.scalar_count >= 6); // At least 3 keys and 3 values
    }
}