squawk_syntax/
lib.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/lib.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27pub mod ast;
28pub mod identifier;
29mod parsing;
30pub mod syntax_error;
31mod syntax_node;
32mod token_text;
33mod validation;
34
35#[cfg(test)]
36mod test;
37
38use std::{marker::PhantomData, sync::Arc};
39
40pub use squawk_parser::SyntaxKind;
41
42use ast::AstNode;
43use rowan::GreenNode;
44use syntax_error::SyntaxError;
45pub use syntax_node::{SyntaxNode, SyntaxToken};
46pub use token_text::TokenText;
47
48/// `Parse` is the result of the parsing: a syntax tree and a collection of
49/// errors.
50///
51/// Note that we always produce a syntax tree, even for completely invalid
52/// files.
53#[derive(Debug, PartialEq, Eq)]
54pub struct Parse<T> {
55    green: GreenNode,
56    errors: Option<Arc<[SyntaxError]>>,
57    _ty: PhantomData<fn() -> T>,
58}
59
60impl<T> Clone for Parse<T> {
61    fn clone(&self) -> Parse<T> {
62        Parse {
63            green: self.green.clone(),
64            errors: self.errors.clone(),
65            _ty: PhantomData,
66        }
67    }
68}
69
70impl<T> Parse<T> {
71    fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
72        Parse {
73            green,
74            errors: if errors.is_empty() {
75                None
76            } else {
77                Some(errors.into())
78            },
79            _ty: PhantomData,
80        }
81    }
82
83    pub fn syntax_node(&self) -> SyntaxNode {
84        SyntaxNode::new_root(self.green.clone())
85    }
86
87    pub fn errors(&self) -> Vec<SyntaxError> {
88        let mut errors = if let Some(e) = self.errors.as_deref() {
89            e.to_vec()
90        } else {
91            vec![]
92        };
93        validation::validate(&self.syntax_node(), &mut errors);
94        errors
95    }
96}
97
98impl<T: AstNode> Parse<T> {
99    /// Converts this parse result into a parse result for an untyped syntax tree.
100    pub fn to_syntax(self) -> Parse<SyntaxNode> {
101        Parse {
102            green: self.green,
103            errors: self.errors,
104            _ty: PhantomData,
105        }
106    }
107
108    /// Gets the parsed syntax tree as a typed ast node.
109    ///
110    /// # Panics
111    ///
112    /// Panics if the root node cannot be casted into the typed ast node
113    /// (e.g. if it's an `ERROR` node).
114    pub fn tree(&self) -> T {
115        T::cast(self.syntax_node()).unwrap()
116    }
117
118    /// Converts from `Parse<T>` to [`Result<T, Vec<SyntaxError>>`].
119    pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
120        match self.errors() {
121            errors if !errors.is_empty() => Err(errors),
122            _ => Ok(self.tree()),
123        }
124    }
125}
126
127impl Parse<SyntaxNode> {
128    pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
129        if N::cast(self.syntax_node()).is_some() {
130            Some(Parse {
131                green: self.green,
132                errors: self.errors,
133                _ty: PhantomData,
134            })
135        } else {
136            None
137        }
138    }
139}
140
141/// `SourceFile` represents a parse tree for a single SQL file.
142pub use crate::ast::SourceFile;
143
144impl SourceFile {
145    pub fn parse(text: &str) -> Parse<SourceFile> {
146        let (green, errors) = parsing::parse_text(text);
147        let root = SyntaxNode::new_root(green.clone());
148
149        assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
150        Parse::new(green, errors)
151    }
152}
153
/// Dispatches on the concrete `ast` type of a `SyntaxNode`.
///
/// Arms are tried from top to bottom. For each arm the node is cloned and
/// passed to the named type's `cast`; the first arm whose cast returns `Some`
/// has its expression evaluated with the cast value bound to the arm's
/// pattern. The trailing `_` arm is mandatory and is used when no cast
/// succeeds.
///
/// The scrutinee is either a plain identifier or a parenthesized expression.
/// It is expanded into every arm, so an expression scrutinee may be evaluated
/// once for each arm that is tried.
///
/// # Example:
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::CreateFunction(it) => { ... },
///         ast::Param(it) => { ... },
///         ast::Literal(it) => { ... },
///         _ => None,
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    // `match ident { .. }`: wrap the identifier in parentheses and defer to
    // the general rule below.
    (match $subject:ident { $($arms:tt)* }) => {
        $crate::match_ast!(match ($subject) { $($arms)* })
    };

    // General form: expands to an `if let .. else if let .. else` chain.
    (match ($subject:expr) {
        $( $( $ty:ident )::+ ($binding:pat) => $body:expr, )*
        _ => $fallback:expr $(,)?
    }) => {{
        $(
            if let Some($binding) = $($ty::)+cast($subject.clone()) {
                $body
            } else
        )*
        {
            $fallback
        }
    }};
}
180
/// This test walks through the crate's API step by step; the assertions show
/// what each call returns for the sample input.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // `SourceFile` is the main entry point.
    //
    // The `parse` method returns a `Parse` -- a pair of syntax tree and a list
    // of errors. That is, the syntax tree is constructed even in the presence
    // of errors.
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    // The `tree` method returns an owned syntax node of type `SourceFile`.
    // Owned nodes are cheap: inside, they are `Rc` handles to the underlying data.
    let file: SourceFile = parse.tree();

    // `SourceFile` is the root of the syntax tree. We can iterate the file's
    // statements. Let's fetch the `foo` function.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Each AST node has a bunch of getters for children. All getters return
    // `Option`s though, to account for incomplete code. Some getters are common
    // for several kinds of node. In this case, a trait like `ast::NameOwner`
    // usually exists. By convention, all ast types should be used with `ast::`
    // qualifier.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    // Return type
    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    // Params
    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    // Enums are used to group related ast nodes together, and can be used for
    // matching. However, because there are no public fields, it's possible to
    // match only the top level enum: that is the price we pay for increased API
    // flexibility.
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Besides the "typed" AST API, there's an untyped CST one as well.
    // To switch from AST to CST, call the `.syntax()` method:
    let func_option_syntax: &SyntaxNode = func_option.syntax();

    // Note how `func_option` and `option` are in fact the same node underneath:
    assert!(func_option_syntax == option.syntax());

    // To go from CST to AST, the `AstNode::cast` function is used:
    let _expr: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(e) => e,
        None => unreachable!(),
    };

    // Each syntax node has two properties. The first is a `SyntaxKind`:
    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // The second is a text range:
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );

    // You can get a node's text as a `SyntaxText` object, which will traverse
    // the tree collecting the tokens' text:
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // There's a bunch of traversal methods on `SyntaxNode`:
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    // As well as some iterator helpers:
    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    assert_eq!(
        func_option_syntax.descendants_with_tokens().count(),
        5, // 1 the node itself: `as 'select 1 + 1'`
           // 2 tokens directly under it: `as`, ` `
           // 1 child `LITERAL` node: `'select 1 + 1'`
           // 1 `STRING` token inside that literal
    );

    // There's also a `preorder` method with a more fine-grained iteration control:
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                // The `" "` argument is padded to `indent` columns, which
                // produces the per-level indentation of the dump.
                writeln!(
                    buf,
                    "{:indent$}{:?} {:?}",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                )
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // To recursively process the tree, there are three approaches:
    // 1. explicitly call getter methods on AST nodes.
    // 2. use descendants and `AstNode::cast`.
    // 3. use descendants and `match_ast!`.
    //
    // The first approach is what most of this test has been doing so far.
    // Here's what the second one looks like:
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();

    // And the third one, using the `match_ast!` macro:
    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
388
/// Parses two `create table` statements and snapshots each table's name
/// together with its (column name, column type) pairs as read through the
/// typed AST.
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );
        
        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        // Statements other than `create table` are skipped.
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(create_table) => Some(create_table),
            _ => None,
        })
        .map(|create_table| {
            let table_name = create_table.path().unwrap().syntax().to_string();
            let columns = create_table
                .table_arg_list()
                .unwrap()
                .args()
                // Only column definitions contribute to the snapshot.
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(column) => {
                        let name = column.name().unwrap().syntax().to_string();
                        let ty = column.ty().unwrap().syntax().to_string();
                        Some((name, ty))
                    }
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .collect();
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}