Skip to main content

xidl_parser/
parser.rs

1use crate::error::ParserResult;
2use std::borrow::Cow;
3use std::collections::{HashMap, HashSet};
4use tree_sitter::Node;
5
/// Mutable state shared by every [`FromTreeSitter::from_node`] call of one parse.
pub struct ParseContext<'a> {
    /// Raw bytes of the source text; node text and doc comments are sliced out of this buffer
    /// using tree-sitter byte offsets.
    pub source: &'a [u8],
    /// String-to-string table exposed to the node converters.
    // NOTE(review): nothing in this file reads or writes `symbols`; its exact meaning is
    // defined by the converters elsewhere in the crate — confirm there.
    pub symbols: HashMap<String, String>,
    /// Start byte offsets of nodes whose leading doc comment has already been handed out, so
    /// the same comment is never attached twice.
    doc_consumed: HashSet<usize>,
}
11
12impl<'a> ParseContext<'a> {
13    pub fn new(source: &'a [u8]) -> Self {
14        Self {
15            source,
16            symbols: HashMap::new(),
17            doc_consumed: HashSet::new(),
18        }
19    }
20
21    pub fn node_text(&self, node: &Node) -> ParserResult<&str> {
22        Ok(node.utf8_text(self.source)?)
23    }
24
25    pub fn take_doc_comment(&mut self, node: &Node) -> Option<String> {
26        let start = node.start_byte();
27        if self.doc_consumed.contains(&start) {
28            return None;
29        }
30        let doc = self.extract_doc_comment(start);
31        if doc.is_some() {
32            self.doc_consumed.insert(start);
33        }
34        doc
35    }
36
37    fn extract_doc_comment(&self, start: usize) -> Option<String> {
38        if start == 0 {
39            return None;
40        }
41        let src = self.source;
42        let mut line_end = if start > 0 && src[start - 1] == b'\n' {
43            start - 1
44        } else {
45            start
46        };
47        let mut lines = Vec::new();
48        let mut first = true;
49        loop {
50            let mut line_start = 0;
51            if line_end > 0 {
52                let mut i = line_end;
53                while i > 0 && src[i - 1] != b'\n' {
54                    i -= 1;
55                }
56                line_start = i;
57            }
58            let mut line = &src[line_start..line_end];
59            if line.ends_with(b"\r") {
60                line = &line[..line.len() - 1];
61            }
62            if line.iter().all(|b| b.is_ascii_whitespace()) {
63                if first {
64                    if line_start == 0 {
65                        break;
66                    }
67                    line_end = line_start - 1;
68                    first = false;
69                    continue;
70                }
71                break;
72            }
73            first = false;
74            let mut idx = 0;
75            while idx < line.len() && line[idx].is_ascii_whitespace() {
76                idx += 1;
77            }
78            let trimmed = &line[idx..];
79            if trimmed.starts_with(b"///") {
80                let mut content = &trimmed[3..];
81                if content.first() == Some(&b' ') {
82                    content = &content[1..];
83                }
84                lines.push(String::from_utf8_lossy(content).to_string());
85                if line_start == 0 {
86                    break;
87                }
88                line_end = line_start - 1;
89                continue;
90            }
91            break;
92        }
93        if lines.is_empty() {
94            None
95        } else {
96            lines.reverse();
97            Some(lines.join("\n"))
98        }
99    }
100}
101
102pub trait FromTreeSitter<'a>: Sized {
103    fn from_node(node: Node<'a>, context: &mut ParseContext<'a>) -> ParserResult<Self>;
104}
105
106impl<'a> FromTreeSitter<'a> for String {
107    fn from_node(node: Node<'a>, context: &mut ParseContext<'a>) -> ParserResult<Self> {
108        Ok(context.node_text(&node)?.to_string())
109    }
110}
111
112impl<'a, T> FromTreeSitter<'a> for Box<T>
113where
114    T: FromTreeSitter<'a>,
115{
116    fn from_node(node: Node<'a>, context: &mut ParseContext<'a>) -> ParserResult<Self> {
117        Ok(Box::new(T::from_node(node, context)?))
118    }
119}
120
121pub fn parser_text(text: &str) -> ParserResult<crate::typed_ast::Specification> {
122    use crate::typed_ast::Specification;
123
124    let mut parser = tree_sitter::Parser::new();
125    parser.set_language(&tree_sitter_idl::language()).unwrap();
126
127    let normalized = normalize_source_for_tree_sitter(text);
128
129    let tree = parser.parse(normalized.as_ref(), None).ok_or_else(|| {
130        crate::error::ParseError::TreeSitterError("Failed to parse text".to_string())
131    })?;
132
133    let root_node = tree.root_node();
134    if root_node.has_error() {
135        return Err(crate::error::ParseError::TreeSitterError(
136            "Failed to parse text".to_string(),
137        ));
138    }
139    let mut context = ParseContext::new(text.as_bytes());
140
141    Specification::from_node(root_node, &mut context)
142}
143
/// Blanks out annotation argument lists that contain `[` or `]`, keeping byte offsets intact.
///
/// For every `@name(...)` found outside a quoted literal, where the parenthesised arguments
/// contain a square bracket outside quotes, everything between the outer parentheses is
/// replaced by the same number of spaces. All other text is copied verbatim, so the result
/// always has exactly the same byte length as `text`, and offsets into the result are valid
/// offsets into `text`.
///
/// Returns `Cow::Borrowed(text)` when nothing had to be rewritten; no allocation happens in
/// that case.
///
/// Non-ASCII text is preserved byte-for-byte. Only ASCII delimiter bytes are ever inspected
/// or replaced, so every slice boundary used below falls on a UTF-8 character boundary.
///
/// NOTE(review): comments are not recognised here, so a lone `'` or `"` inside a comment
/// opens a "quoted" region that suppresses rewriting until the next matching quote.
pub fn normalize_source_for_tree_sitter(text: &str) -> Cow<'_, str> {
    let bytes = text.as_bytes();
    // Created lazily on the first rewrite; `copied` is how much of `text` it already holds.
    let mut out: Option<String> = None;
    let mut copied = 0usize;
    let mut i = 0usize;
    let mut quote: Option<u8> = None;

    while i < bytes.len() {
        let byte = bytes[i];

        if let Some(open) = quote {
            if byte == b'\\' && i + 1 < bytes.len() {
                // Skip the escaped byte so an escaped quote does not close the literal.
                i += 2;
            } else {
                if byte == open {
                    quote = None;
                }
                i += 1;
            }
            continue;
        }

        match byte {
            b'"' | b'\'' => {
                quote = Some(byte);
                i += 1;
            }
            b'@' => {
                i += 1;
                while i < bytes.len() && is_annotation_name_byte(bytes[i]) {
                    i += 1;
                }
                if bytes.get(i) == Some(&b'(') {
                    if let Some(close) = bracketed_annotation_args_end(bytes, i) {
                        let buf = out.get_or_insert_with(|| String::with_capacity(text.len()));
                        // Verbatim copy up to and including the opening parenthesis.
                        buf.push_str(&text[copied..=i]);
                        // One space per replaced byte keeps all later offsets unchanged.
                        buf.extend(std::iter::repeat(' ').take(close - i - 1));
                        buf.push(')');
                        copied = close + 1;
                        i = close + 1;
                    }
                }
                // Otherwise `i` stays on the byte after the name and is handled normally.
            }
            _ => i += 1,
        }
    }

    match out {
        Some(mut buf) => {
            buf.push_str(&text[copied..]);
            Cow::Owned(buf)
        }
        None => Cow::Borrowed(text),
    }
}

/// Returns `true` for bytes accepted as part of an annotation name after `@`.
fn is_annotation_name_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'-' | b':')
}

/// Finds the `)` matching the `(` at `open`, if the enclosed arguments need blanking.
///
/// Returns the index of the closing parenthesis only when it exists and a `[` or `]` was seen
/// outside quoted literals; returns `None` otherwise.
fn bracketed_annotation_args_end(bytes: &[u8], open: usize) -> Option<usize> {
    let mut depth = 1usize;
    let mut quote: Option<u8> = None;
    let mut has_bracket = false;
    let mut j = open + 1;

    while j < bytes.len() {
        let byte = bytes[j];

        if let Some(q) = quote {
            if byte == b'\\' && j + 1 < bytes.len() {
                j += 2;
                continue;
            }
            if byte == q {
                quote = None;
            }
            j += 1;
            continue;
        }

        match byte {
            b'"' | b'\'' => quote = Some(byte),
            b'[' | b']' => has_bracket = true,
            b'(' => depth += 1,
            b')' => {
                depth -= 1;
                if depth == 0 {
                    return has_bracket.then_some(j);
                }
            }
            _ => {}
        }
        j += 1;
    }

    None
}
241
#[cfg(test)]
mod tests {
    use super::parser_text;
    use crate::typed_ast::{
        AnnotationAppl, AnnotationName, AnnotationParams, Definition, TemplateTypeSpec,
        TypeDclInner, TypeDeclaratorInner, TypeSpec,
    };

    /// A `typedef` over a template type keeps the template identifier and its arguments.
    #[test]
    fn parse_template_type_spec() {
        use crate::typed_ast::{BaseTypeSpec, SimpleTypeSpec};

        let source = r#"
            module m {
                typedef Vec<long> MyVec;
            };
            "#;
        let typed = parser_text(source).expect("parse should succeed");

        let other = &typed.0[0];
        let Definition::ModuleDcl(module) = other else {
            panic!("expected module, got {other:?}");
        };

        let other = &module.definition[0];
        let Definition::TypeDcl(type_dcl) = other else {
            panic!("expected type declaration, got {other:?}");
        };

        let other = &type_dcl.decl;
        let TypeDclInner::TypedefDcl(typedef) = other else {
            panic!("expected typedef, got {other:?}");
        };

        let other = &typedef.decl.ty;
        let TypeDeclaratorInner::TemplateTypeSpec(TemplateTypeSpec::TemplateType(template)) =
            other
        else {
            panic!("expected template_type, got {other:?}");
        };

        assert_eq!(template.ident.0, "Vec");
        assert_eq!(template.args.len(), 1);
        assert!(matches!(
            template.args[0],
            TypeSpec::SimpleTypeSpec(SimpleTypeSpec::BaseTypeSpec(BaseTypeSpec::IntegerType(_)))
        ));
    }

    /// `///` comments above a module, a struct and a struct member each surface as a `doc`
    /// annotation on the corresponding AST node.
    #[test]
    fn parse_doc_comments_as_doc_annotation() {
        use crate::typed_ast::{ConstrTypeDcl, StructDcl};

        let source = r#"
            /// module doc
            module m {
                /// struct doc
                struct S {
                    /// field doc
                    long x;
                };
            };
            "#;
        let typed = parser_text(source).expect("parse should succeed");

        let other = &typed.0[0];
        let Definition::ModuleDcl(module) = other else {
            panic!("expected module, got {other:?}");
        };
        assert_has_doc(&module.annotations, "\"module doc\"");

        let other = &module.definition[0];
        let Definition::TypeDcl(type_dcl) = other else {
            panic!("expected type declaration, got {other:?}");
        };
        assert_has_doc(&type_dcl.annotations, "\"struct doc\"");

        let other = &type_dcl.decl;
        let TypeDclInner::ConstrTypeDcl(ConstrTypeDcl::StructDcl(StructDcl::StructDef(
            struct_def,
        ))) = other
        else {
            panic!("expected struct def, got {other:?}");
        };
        let member = &struct_def.member[0];
        assert_has_doc(&member.annotations, "\"field doc\"");
    }

    /// Asserts that the first builtin `doc` annotation carrying raw parameters has exactly
    /// the raw text `expected`.
    fn assert_has_doc(annotations: &[AnnotationAppl], expected: &str) {
        let doc = annotations
            .iter()
            .filter(|anno| matches!(&anno.name, AnnotationName::Builtin(name) if name == "doc"))
            .find_map(|anno| match &anno.params {
                Some(AnnotationParams::Raw(raw)) => Some(raw.as_str()),
                _ => None,
            });
        assert_eq!(doc, Some(expected));
    }
}