Skip to main content

styx_cst/
ast.rs

1//! Typed AST wrappers over CST nodes.
2//!
3//! These provide a more ergonomic API for navigating the syntax tree
4//! while still preserving access to the underlying CST for source locations.
5
6use crate::syntax_kind::{SyntaxKind, SyntaxNode, SyntaxToken};
7
8/// Trait for AST nodes that wrap CST nodes.
9pub trait AstNode: Sized {
10    /// Try to cast a syntax node to this AST type.
11    fn cast(node: SyntaxNode) -> Option<Self>;
12
13    /// Get the underlying syntax node.
14    fn syntax(&self) -> &SyntaxNode;
15
16    /// Get the source text of this node.
17    fn text(&self) -> std::borrow::Cow<'_, str> {
18        std::borrow::Cow::Owned(self.syntax().to_string())
19    }
20}
21
/// Generates a newtype wrapper around `SyntaxNode` together with its `AstNode` impl.
///
/// `$name` is the wrapper type to define and `$kind` is the syntax kind a node
/// must have for `cast` to accept it. Attributes written before the name
/// (doc comments included) are forwarded to the generated struct.
macro_rules! ast_node {
    ($(#[$attr:meta])* $name:ident, $kind:expr) => {
        $(#[$attr])*
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        pub struct $name(SyntaxNode);

        impl AstNode for $name {
            fn cast(node: SyntaxNode) -> Option<Self> {
                // Only a node of the expected kind may be wrapped.
                if node.kind() != $kind {
                    return None;
                }
                Some(Self(node))
            }

            fn syntax(&self) -> &SyntaxNode {
                let Self(inner) = self;
                inner
            }
        }
    };
}
44
45ast_node!(
46    /// The root document node.
47    Document,
48    SyntaxKind::DOCUMENT
49);
50
51ast_node!(
52    /// An entry (key-value pair or sequence element).
53    Entry,
54    SyntaxKind::ENTRY
55);
56
57ast_node!(
58    /// An explicit object `{ ... }`.
59    Object,
60    SyntaxKind::OBJECT
61);
62
63ast_node!(
64    /// A sequence `( ... )`.
65    Sequence,
66    SyntaxKind::SEQUENCE
67);
68
69ast_node!(
70    /// A scalar value.
71    Scalar,
72    SyntaxKind::SCALAR
73);
74
75ast_node!(
76    /// A unit value `@`.
77    Unit,
78    SyntaxKind::UNIT
79);
80
81ast_node!(
82    /// A tag `@name` with optional payload.
83    Tag,
84    SyntaxKind::TAG
85);
86
87ast_node!(
88    /// A heredoc value.
89    Heredoc,
90    SyntaxKind::HEREDOC
91);
92
93ast_node!(
94    /// The key part of an entry.
95    Key,
96    SyntaxKind::KEY
97);
98
99ast_node!(
100    /// The value part of an entry.
101    Value,
102    SyntaxKind::VALUE
103);
104
105// === Document ===
106
107impl Document {
108    /// Iterate over top-level entries.
109    pub fn entries(&self) -> impl Iterator<Item = Entry> {
110        self.0.children().filter_map(Entry::cast)
111    }
112}
113
114// === Entry ===
115
116impl Entry {
117    /// Get the key of this entry (if it has one).
118    pub fn key(&self) -> Option<Key> {
119        self.0.children().find_map(Key::cast)
120    }
121
122    /// Get the value of this entry (if it has one).
123    pub fn value(&self) -> Option<Value> {
124        self.0.children().find_map(Value::cast)
125    }
126
127    /// Get the key text.
128    pub fn key_text(&self) -> Option<String> {
129        self.key().map(|k| k.text_content())
130    }
131
132    /// Get preceding doc comments.
133    pub fn doc_comments(&self) -> impl Iterator<Item = SyntaxToken> {
134        // Look for DOC_COMMENT tokens before this entry in the parent
135        self.0
136            .siblings_with_tokens(rowan::Direction::Prev)
137            .skip(1) // Skip self
138            .take_while(|el| {
139                el.kind() == SyntaxKind::WHITESPACE
140                    || el.kind() == SyntaxKind::NEWLINE
141                    || el.kind() == SyntaxKind::DOC_COMMENT
142            })
143            .filter_map(|el| el.into_token())
144            .filter(|t| t.kind() == SyntaxKind::DOC_COMMENT)
145    }
146}
147
148// === Key ===
149
150impl Key {
151    /// Get the text content of this key, processing escapes if quoted.
152    pub fn text_content(&self) -> String {
153        // Get the first meaningful token
154        for child in self.0.children_with_tokens() {
155            match child {
156                rowan::NodeOrToken::Token(token) => {
157                    return match token.kind() {
158                        SyntaxKind::BARE_SCALAR => token.text().to_string(),
159                        SyntaxKind::QUOTED_SCALAR => unescape_quoted(token.text()),
160                        SyntaxKind::RAW_SCALAR => token.text().to_string(),
161                        _ => continue,
162                    };
163                }
164                rowan::NodeOrToken::Node(node) => {
165                    // Recurse into SCALAR node
166                    if node.kind() == SyntaxKind::SCALAR
167                        && let Some(scalar) = Scalar::cast(node)
168                    {
169                        return scalar.text_content();
170                    }
171                }
172            }
173        }
174        String::new()
175    }
176
177    /// Get the raw text without escape processing.
178    pub fn raw_text(&self) -> String {
179        self.0.to_string()
180    }
181}
182
183// === Value ===
184
185impl Value {
186    /// Get the inner value as an enum.
187    pub fn kind(&self) -> ValueKind {
188        for child in self.0.children() {
189            match child.kind() {
190                SyntaxKind::SCALAR => return ValueKind::Scalar(Scalar::cast(child).unwrap()),
191                SyntaxKind::OBJECT => return ValueKind::Object(Object::cast(child).unwrap()),
192                SyntaxKind::SEQUENCE => return ValueKind::Sequence(Sequence::cast(child).unwrap()),
193                SyntaxKind::UNIT => return ValueKind::Unit(Unit::cast(child).unwrap()),
194                SyntaxKind::TAG => return ValueKind::Tag(Tag::cast(child).unwrap()),
195                SyntaxKind::HEREDOC => return ValueKind::Heredoc(Heredoc::cast(child).unwrap()),
196                _ => continue,
197            }
198        }
199        ValueKind::Missing
200    }
201}
202
203/// The kind of value in an entry.
204#[derive(Debug, Clone)]
205pub enum ValueKind {
206    /// A scalar value.
207    Scalar(Scalar),
208    /// An object.
209    Object(Object),
210    /// A sequence.
211    Sequence(Sequence),
212    /// A unit value.
213    Unit(Unit),
214    /// A tag.
215    Tag(Tag),
216    /// A heredoc.
217    Heredoc(Heredoc),
218    /// Missing value (parse error).
219    Missing,
220}
221
222// === Object ===
223
/// The separator mode detected in an object (see `Object::separator`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Separator {
    /// Entries separated by newlines.
    Newline,
    /// Entries separated by commas.
    ///
    /// Also reported when the object contains no separator at all
    /// (a single entry or an empty object), i.e. the inline format.
    Comma,
    /// Mixed separators (error): both commas and newlines were found.
    Mixed,
}
234
235impl Object {
236    /// Iterate over entries in this object.
237    pub fn entries(&self) -> impl Iterator<Item = Entry> {
238        self.0.children().filter_map(Entry::cast)
239    }
240
241    /// Detect the separator mode used in this object.
242    pub fn separator(&self) -> Separator {
243        let mut has_comma = false;
244        let mut has_newline = false;
245
246        for token in self
247            .0
248            .children_with_tokens()
249            .filter_map(|el| el.into_token())
250        {
251            match token.kind() {
252                SyntaxKind::COMMA => has_comma = true,
253                SyntaxKind::NEWLINE => has_newline = true,
254                _ => {}
255            }
256        }
257
258        if has_comma && has_newline {
259            Separator::Mixed
260        } else if has_newline {
261            Separator::Newline
262        } else {
263            // Comma-separated or no separators (single/empty) = inline format
264            Separator::Comma
265        }
266    }
267
268    /// Get an entry by key name.
269    pub fn get(&self, key: &str) -> Option<Entry> {
270        self.entries()
271            .find(|e| e.key_text().as_deref() == Some(key))
272    }
273}
274
275// === Sequence ===
276
277impl Sequence {
278    /// Iterate over elements in this sequence.
279    ///
280    /// The parser wraps sequence elements in ENTRY/KEY nodes for uniformity.
281    /// This method extracts the actual value from each entry.
282    pub fn elements(&self) -> impl Iterator<Item = SyntaxNode> {
283        self.0.children().filter_map(|n| {
284            if n.kind() == SyntaxKind::ENTRY {
285                // Find the KEY child, then get its first value child
286                n.children()
287                    .find(|c| c.kind() == SyntaxKind::KEY)
288                    .and_then(|key| {
289                        key.children().find(|c| {
290                            matches!(
291                                c.kind(),
292                                SyntaxKind::SCALAR
293                                    | SyntaxKind::OBJECT
294                                    | SyntaxKind::SEQUENCE
295                                    | SyntaxKind::UNIT
296                                    | SyntaxKind::TAG
297                                    | SyntaxKind::HEREDOC
298                            )
299                        })
300                    })
301            } else {
302                // Fallback: direct value children (shouldn't happen with current parser)
303                matches!(
304                    n.kind(),
305                    SyntaxKind::SCALAR
306                        | SyntaxKind::OBJECT
307                        | SyntaxKind::SEQUENCE
308                        | SyntaxKind::UNIT
309                        | SyntaxKind::TAG
310                        | SyntaxKind::HEREDOC
311                )
312                .then_some(n)
313            }
314        })
315    }
316
317    /// Iterate over entries in this sequence (ENTRY nodes).
318    pub fn entries(&self) -> impl Iterator<Item = Entry> {
319        self.0.children().filter_map(Entry::cast)
320    }
321
322    /// Get the number of elements.
323    pub fn len(&self) -> usize {
324        self.elements().count()
325    }
326
327    /// Check if empty.
328    pub fn is_empty(&self) -> bool {
329        self.len() == 0
330    }
331
332    /// Check if the sequence is multiline (contains newlines or comments).
333    pub fn is_multiline(&self) -> bool {
334        self.0
335            .children_with_tokens()
336            .filter_map(|el| el.into_token())
337            .any(|t| {
338                matches!(
339                    t.kind(),
340                    SyntaxKind::NEWLINE | SyntaxKind::LINE_COMMENT | SyntaxKind::DOC_COMMENT
341                )
342            })
343    }
344}
345
346// === Scalar ===
347
/// The kind of scalar, as reported by `Scalar::kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarKind {
    /// Bare (unquoted) scalar.
    ///
    /// Also the fallback `Scalar::kind` returns when the node holds no
    /// scalar token at all.
    Bare,
    /// Quoted string.
    Quoted,
    /// Raw string.
    Raw,
}
358
359impl Scalar {
360    /// Get the text content, processing escapes for quoted strings.
361    pub fn text_content(&self) -> String {
362        for token in self
363            .0
364            .children_with_tokens()
365            .filter_map(|el| el.into_token())
366        {
367            return match token.kind() {
368                SyntaxKind::BARE_SCALAR => token.text().to_string(),
369                SyntaxKind::QUOTED_SCALAR => unescape_quoted(token.text()),
370                SyntaxKind::RAW_SCALAR => token.text().to_string(),
371                _ => continue,
372            };
373        }
374        String::new()
375    }
376
377    /// Get the raw text without escape processing.
378    pub fn raw_text(&self) -> String {
379        self.0.to_string()
380    }
381
382    /// Get the kind of scalar.
383    pub fn kind(&self) -> ScalarKind {
384        for token in self
385            .0
386            .children_with_tokens()
387            .filter_map(|el| el.into_token())
388        {
389            return match token.kind() {
390                SyntaxKind::BARE_SCALAR => ScalarKind::Bare,
391                SyntaxKind::QUOTED_SCALAR => ScalarKind::Quoted,
392                SyntaxKind::RAW_SCALAR => ScalarKind::Raw,
393                _ => continue,
394            };
395        }
396        ScalarKind::Bare
397    }
398}
399
400// === Tag ===
401
402impl Tag {
403    /// Get the tag name (without @).
404    pub fn name(&self) -> Option<String> {
405        // The tag token is @name, so we strip the @ prefix
406        self.0
407            .children_with_tokens()
408            .filter_map(|el| el.into_token())
409            .find(|t| t.kind() == SyntaxKind::TAG_TOKEN)
410            .map(|t| t.text()[1..].to_string()) // Skip the '@' prefix
411    }
412
413    /// Get the tag payload if present.
414    pub fn payload(&self) -> Option<SyntaxNode> {
415        self.0
416            .children()
417            .find(|n| n.kind() == SyntaxKind::TAG_PAYLOAD)
418            .and_then(|n| n.children().next())
419    }
420}
421
422// === Heredoc ===
423
424impl Heredoc {
425    /// Get the heredoc content (without delimiters).
426    pub fn content(&self) -> String {
427        for token in self
428            .0
429            .children_with_tokens()
430            .filter_map(|el| el.into_token())
431        {
432            if token.kind() == SyntaxKind::HEREDOC_CONTENT {
433                return token.text().to_string();
434            }
435        }
436        String::new()
437    }
438
439    /// Get the delimiter name.
440    pub fn delimiter(&self) -> Option<String> {
441        for token in self
442            .0
443            .children_with_tokens()
444            .filter_map(|el| el.into_token())
445        {
446            if token.kind() == SyntaxKind::HEREDOC_START {
447                // Extract delimiter from <<DELIM\n
448                let text = token.text();
449                if let Some(rest) = text.strip_prefix("<<") {
450                    return Some(rest.trim_end().to_string());
451                }
452            }
453        }
454        None
455    }
456}
457
458// === Helpers ===
459
/// Resolve the escape sequences of a quoted string token.
///
/// A surrounding pair of double quotes is removed first; text that is not
/// wrapped in a matching pair is processed as-is. Recognised escapes are
/// `\n`, `\r`, `\t`, `\\` and `\"`. Any other escape is copied through
/// unchanged (backslash included), as is a lone trailing backslash.
fn unescape_quoted(text: &str) -> String {
    let body = match text.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(unquoted) => unquoted,
        None => text,
    };

    let mut out = String::with_capacity(body.len());
    let mut rest = body.chars();

    while let Some(ch) = rest.next() {
        if ch != '\\' {
            out.push(ch);
            continue;
        }

        // A backslash consumes the following character as its escape code.
        match rest.next() {
            Some('n') => out.push('\n'),
            Some('r') => out.push('\r'),
            Some('t') => out.push('\t'),
            Some(literal @ ('\\' | '"')) => out.push(literal),
            Some(other) => {
                // Unknown escape: keep both characters untouched.
                out.push('\\');
                out.push(other);
            }
            // Backslash at the very end of the input.
            None => out.push('\\'),
        }
    }

    out
}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;

    /// Parse `source`, fail the test on any parse error, and return the root document.
    fn doc(source: &str) -> Document {
        let p = parse(source);
        assert!(p.is_ok(), "parse errors: {:?}", p.errors());
        Document::cast(p.syntax()).unwrap()
    }

    #[test]
    fn test_document_entries() {
        let d = doc("a 1\nb 2\nc 3");
        let entries: Vec<_> = d.entries().collect();
        assert_eq!(entries.len(), 3);
    }

    #[test]
    fn test_entry_key_value() {
        let d = doc("host localhost");
        let entry = d.entries().next().unwrap();

        assert_eq!(entry.key_text(), Some("host".to_string()));

        let value = entry.value().unwrap();
        if let ValueKind::Scalar(s) = value.kind() {
            assert_eq!(s.text_content(), "localhost");
        } else {
            panic!("expected scalar value");
        }
    }

    #[test]
    fn test_object_entries() {
        let d = doc("config { host localhost, port 8080 }");
        let entry = d.entries().next().unwrap();
        let value = entry.value().unwrap();

        if let ValueKind::Object(obj) = value.kind() {
            assert_eq!(obj.separator(), Separator::Comma);

            let entries: Vec<_> = obj.entries().collect();
            assert_eq!(entries.len(), 2);

            assert_eq!(entries[0].key_text(), Some("host".to_string()));
            assert_eq!(entries[1].key_text(), Some("port".to_string()));
        } else {
            panic!("expected object value");
        }
    }

    #[test]
    fn test_object_get() {
        let d = doc("{ name Alice, age 30 }");
        let entry = d.entries().next().unwrap();
        let key = entry.key().unwrap();
        let obj_node = key.syntax().children().next().unwrap();
        let obj = Object::cast(obj_node).unwrap();

        let name_entry = obj.get("name").unwrap();
        let val = name_entry.value().unwrap();
        if let ValueKind::Scalar(s) = val.kind() {
            assert_eq!(s.text_content(), "Alice");
        } else {
            // Without this branch the test would pass without checking anything.
            panic!("expected scalar value");
        }
    }

    #[test]
    fn test_sequence() {
        let d = doc("items (a b c)");
        let entry = d.entries().next().unwrap();
        let value = entry.value().unwrap();

        if let ValueKind::Sequence(seq) = value.kind() {
            assert_eq!(seq.len(), 3);
        } else {
            panic!("expected sequence value");
        }
    }

    #[test]
    fn test_quoted_string_escapes() {
        let d = doc(r#"msg "hello\nworld""#);
        let entry = d.entries().next().unwrap();
        let value = entry.value().unwrap();

        if let ValueKind::Scalar(s) = value.kind() {
            assert_eq!(s.text_content(), "hello\nworld");
            assert_eq!(s.kind(), ScalarKind::Quoted);
        } else {
            panic!("expected scalar value");
        }
    }

    #[test]
    fn test_tag() {
        // Tag with attached payload (no space) - payload IS part of tag
        let d = doc("key @Some(value)");
        let entry = d.entries().next().unwrap();
        let value = entry.value().unwrap();
        let tag_node = value.syntax().children().next().unwrap();
        let tag = Tag::cast(tag_node).unwrap();

        assert_eq!(tag.name(), Some("Some".to_string()));
        assert!(tag.payload().is_some(), "attached payload should exist");
    }

    #[test]
    fn test_tag_without_payload() {
        // Tag with space before next value - NO payload (per grammar)
        let d = doc("@Some value");
        let entry = d.entries().next().unwrap();
        let key = entry.key().unwrap();
        let tag_node = key.syntax().children().next().unwrap();
        let tag = Tag::cast(tag_node).unwrap();

        assert_eq!(tag.name(), Some("Some".to_string()));
        assert!(
            tag.payload().is_none(),
            "spaced value should not be payload"
        );

        // The value should be separate
        let value = entry.value().unwrap();
        assert!(matches!(value.kind(), ValueKind::Scalar(_)));
    }

    #[test]
    fn test_chained_tag_payload_is_nested_tag() {
        let d = doc("value @must_emit/@discover_start{executor default}");
        let entry = d.entries().next().unwrap();
        let value = entry.value().unwrap();
        let outer_tag = Tag::cast(value.syntax().children().next().unwrap()).unwrap();

        assert_eq!(outer_tag.name(), Some("must_emit".to_string()));

        let inner = outer_tag.payload().unwrap();
        let inner_tag = Tag::cast(inner).unwrap();
        assert_eq!(inner_tag.name(), Some("discover_start".to_string()));
        assert!(
            inner_tag.payload().is_some(),
            "inner tag should keep payload"
        );
    }

    #[test]
    fn test_three_segment_chained_tag_payload_is_nested_tags() {
        let d = doc("value @a/@b/@c");
        let entry = d.entries().next().unwrap();
        let value = entry.value().unwrap();
        let outer = Tag::cast(value.syntax().children().next().unwrap()).unwrap();
        assert_eq!(outer.name(), Some("a".to_string()));

        let middle = Tag::cast(outer.payload().unwrap()).unwrap();
        assert_eq!(middle.name(), Some("b".to_string()));

        let inner = Tag::cast(middle.payload().unwrap()).unwrap();
        assert_eq!(inner.name(), Some("c".to_string()));
        assert!(inner.payload().is_none(), "leaf tag should be unit");
    }

    #[test]
    fn test_chained_tag_scalar_leaf_payload_is_scalar() {
        let d = doc(r#"value @a/@b"foo""#);
        let entry = d.entries().next().unwrap();
        let value = entry.value().unwrap();
        let outer = Tag::cast(value.syntax().children().next().unwrap()).unwrap();
        let inner = Tag::cast(outer.payload().unwrap()).unwrap();
        let payload = Scalar::cast(inner.payload().unwrap()).unwrap();
        assert_eq!(payload.text_content(), "foo");
    }

    #[test]
    fn test_unit() {
        let d = doc("empty @");
        let entry = d.entries().next().unwrap();
        let value = entry.value().unwrap();

        assert!(matches!(value.kind(), ValueKind::Unit(_)));
    }

    #[test]
    fn test_unescape_quoted() {
        assert_eq!(unescape_quoted(r#""hello""#), "hello");
        assert_eq!(unescape_quoted(r#""hello\nworld""#), "hello\nworld");
        assert_eq!(unescape_quoted(r#""tab\there""#), "tab\there");
        assert_eq!(unescape_quoted(r#""quote\"here""#), "quote\"here");
        assert_eq!(unescape_quoted(r#""back\\slash""#), "back\\slash");
        // Unknown escapes are passed through with their backslash intact.
        assert_eq!(unescape_quoted(r#""odd\qescape""#), "odd\\qescape");
        // Text without surrounding quotes is processed as-is.
        assert_eq!(unescape_quoted("bare"), "bare");
    }
}
683}