Skip to main content

squawk_syntax/
lib.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/lib.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27pub mod ast;
28mod generated;
29pub mod identifier;
30mod parsing;
31mod ptr;
32pub mod quote;
33pub mod syntax_error;
34mod syntax_node;
35mod token_text;
36mod unescape;
37mod validation;
38
39#[cfg(test)]
40mod test;
41
42use std::{marker::PhantomData, sync::Arc};
43
44pub use squawk_parser::SyntaxKind;
45
46use ast::AstNode;
47pub use ptr::{AstPtr, SyntaxNodePtr};
48use rowan::GreenNode;
49use syntax_error::SyntaxError;
50pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
51pub use token_text::TokenText;
52
/// `Parse` is the result of parsing: a syntax tree plus the errors the parser
/// reported while building it.
///
/// A syntax tree is always produced, even for completely invalid input;
/// problems are reported through [`Parse::errors`] instead of by failing.
///
/// The type parameter `T` records which AST node type the root is expected to
/// be (see [`Parse::tree`]). It exists only at the type level.
#[derive(Debug, PartialEq, Eq)]
pub struct Parse<T> {
    // Root of the immutable syntax tree; `syntax_node` wraps it in a `SyntaxNode`.
    green: GreenNode,
    // Parser errors. `Parse::new` stores `None` instead of an empty slice when
    // there are no errors.
    errors: Option<Arc<[SyntaxError]>>,
    // Type-level marker only: no value of type `T` is stored.
    _ty: PhantomData<fn() -> T>,
}
64
65impl<T> Clone for Parse<T> {
66    fn clone(&self) -> Parse<T> {
67        Parse {
68            green: self.green.clone(),
69            errors: self.errors.clone(),
70            _ty: PhantomData,
71        }
72    }
73}
74
75impl<T> Parse<T> {
76    fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
77        Parse {
78            green,
79            errors: if errors.is_empty() {
80                None
81            } else {
82                Some(errors.into())
83            },
84            _ty: PhantomData,
85        }
86    }
87
88    pub fn syntax_node(&self) -> SyntaxNode {
89        SyntaxNode::new_root(self.green.clone())
90    }
91
92    pub fn errors(&self) -> Vec<SyntaxError> {
93        let mut errors = if let Some(e) = self.errors.as_deref() {
94            e.to_vec()
95        } else {
96            vec![]
97        };
98        validation::validate(&self.syntax_node(), &mut errors);
99        errors.sort_by_key(|error| error.range().start());
100        errors
101    }
102}
103
104impl<T: AstNode> Parse<T> {
105    /// Converts this parse result into a parse result for an untyped syntax tree.
106    pub fn to_syntax(self) -> Parse<SyntaxNode> {
107        Parse {
108            green: self.green,
109            errors: self.errors,
110            _ty: PhantomData,
111        }
112    }
113
114    /// Gets the parsed syntax tree as a typed ast node.
115    ///
116    /// # Panics
117    ///
118    /// Panics if the root node cannot be casted into the typed ast node
119    /// (e.g. if it's an `ERROR` node).
120    pub fn tree(&self) -> T {
121        T::cast(self.syntax_node()).unwrap()
122    }
123
124    /// Converts from `Parse<T>` to [`Result<T, Vec<SyntaxError>>`].
125    pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
126        match self.errors() {
127            errors if !errors.is_empty() => Err(errors),
128            _ => Ok(self.tree()),
129        }
130    }
131}
132
133impl Parse<SyntaxNode> {
134    pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
135        if N::cast(self.syntax_node()).is_some() {
136            Some(Parse {
137                green: self.green,
138                errors: self.errors,
139                _ty: PhantomData,
140            })
141        } else {
142            None
143        }
144    }
145}
146
147/// `SourceFile` represents a parse tree for a single SQL file.
148pub use crate::ast::SourceFile;
149
150impl SourceFile {
151    pub fn parse(text: &str) -> Parse<SourceFile> {
152        let (green, errors) = parsing::parse_text(text);
153        let root = SyntaxNode::new_root(green.clone());
154
155        assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
156        Parse::new(green, errors)
157    }
158}
159
/// Matches a `SyntaxNode` against a sequence of `ast` types.
///
/// Each arm names a type path and a pattern. The arms are tried top to bottom
/// by calling `<path>::cast` on a clone of the scrutinee, and the first cast
/// that returns `Some` binds the pattern and evaluates that arm. The final `_`
/// arm is mandatory and runs when no cast succeeds. All arms must evaluate to
/// the same type.
///
/// The scrutinee is either a bare identifier (`match node { .. }`) or a
/// parenthesized expression (`match (expr) { .. }`). The expression is
/// re-evaluated (and cloned) for every arm tried, so prefer passing a variable.
///
/// # Examples
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::CreateFunction(it) => { ... },
///         ast::FuncOption(it) => { ... },
///         ast::Literal(it) => { ... },
///         _ => None,
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    (match $subject:ident { $($arms:tt)* }) => {
        $crate::match_ast!(match ($subject) { $($arms)* })
    };

    (match ($subject:expr) {
        $( $( $seg:ident )::+ ($bind:pat) => $arm:expr, )*
        _ => $fallback:expr $(,)?
    }) => {{
        $( if let Some($bind) = $($seg::)+cast($subject.clone()) { $arm } else )*
        { $fallback }
    }};
}
186
/// A guided tour of the crate's API. It parses a small `create function`
/// statement and walks through the typed AST, the untyped CST, and the
/// traversal helpers. The assertions double as documentation of what each
/// call returns.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // `SourceFile` is the main entry point.
    //
    // The `parse` method returns a `Parse`: a syntax tree paired with a list
    // of errors. A syntax tree is constructed even when there are errors.
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    // The `tree` method returns an owned syntax node of type `SourceFile`.
    // Owned nodes are cheap: inside, they are `Rc` handles to the underlying data.
    let file: SourceFile = parse.tree();

    // `SourceFile` is the root of the syntax tree, and we can iterate over its
    // statements. The source holds a single `create function`, so any other
    // statement kind is unexpected. Let's fetch the `foo` function.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Each AST node has a bunch of getters for children. All getters return
    // `Option`s, to account for incomplete code. By convention, all AST types
    // are used with the `ast::` qualifier.
    // NOTE(review): rust-analyzer groups getters shared by several node kinds
    // into traits such as `ast::NameOwner`; confirm whether squawk does too.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    // Return type: `returns int`.
    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    // Parameters: `(p int8)`.
    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    // Enums are used to group related AST nodes together, and can be used for
    // matching. However, because there are no public fields, you can only
    // match on the top-level enum. That is the price we pay for increased API
    // flexibility.
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Besides the "typed" AST API, there's an untyped CST one as well.
    // To switch from AST to CST, call the `.syntax()` method:
    let func_option_syntax = func_option.syntax();

    // Note how `func_option_syntax` and `option` are in fact the same node underneath:
    assert!(func_option_syntax == option.syntax());

    // To go from CST to AST, use the `AstNode::cast` function:
    let _expr: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(e) => e,
        None => unreachable!(),
    };

    // Every syntax node has two core properties. The first is its `SyntaxKind`:
    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // The second is its text range (byte offsets into `source_code`):
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );

    // You can get a node's text as a `SyntaxText` object, which traverses the
    // tree collecting the tokens' text:
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // There's a bunch of traversal methods on `SyntaxNode`:
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    // As well as some iterator helpers:
    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    // `descendants_with_tokens` yields the node itself plus everything below
    // it. For `as 'select 1 + 1'` that is AS_FUNC_OPTION, the `as` token, a
    // whitespace token, the LITERAL node, and its STRING token (the same five
    // entries shown in the preorder dump below).
    assert_eq!(
        func_option_syntax.descendants_with_tokens().count(),
        5,
    );

    // There's also `preorder_with_tokens`, for more fine-grained iteration control:
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                // `{:indent$}` pads a single space to `indent` columns. At depth 0
                // this still emits one space, which `buf.trim()` below strips.
                buf.write_fmt(format_args!(
                    "{:indent$}{:?} {:?}\n",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                ))
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // To recursively process the tree, there are three approaches:
    // 1. explicitly call getter methods on AST nodes.
    // 2. use descendants and `AstNode::cast`.
    // 3. use descendants and `match_ast!`.
    //
    // The getter calls earlier in this test show the first approach.
    // Here's how the second one looks:
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();

    // And the third, which does the same traversal with the `match_ast!` macro.
    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
394
/// Parses two `create table` statements and collects each table's name with
/// its `(column name, column type)` pairs. Table constraints and `like`
/// clauses are skipped.
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );
        
        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(create_table) => Some(create_table),
            _ => None,
        })
        .map(|create_table| {
            let table_name = create_table.path().unwrap().syntax().to_string();
            let columns: Vec<(String, String)> = create_table
                .table_arg_list()
                .unwrap()
                .args()
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(column) => {
                        let name = column.name().unwrap();
                        let ty = column.ty().unwrap();
                        Some((name.syntax().to_string(), ty.syntax().to_string()))
                    }
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .collect();
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}
489
/// `select 1 is not null` parses to a binary expression: two literal operands
/// around an `IS_NOT` operator node made of two keyword tokens.
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let parse = SourceFile::parse("select 1 is not null;");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let select = match file.stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };

    let target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let bin_expr = match target.expr().unwrap() {
        ast::Expr::BinExpr(bin_expr) => bin_expr,
        _ => unreachable!(),
    };

    let (left, operator, right) = (bin_expr.lhs(), bin_expr.op(), bin_expr.rhs());

    assert_debug_snapshot!(left, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(operator, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(right, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}