Skip to main content

squawk_syntax/
lib.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/lib.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27pub mod ast;
28mod generated;
29mod parsing;
30mod ptr;
31pub mod quote;
32pub mod syntax_error;
33mod syntax_node;
34mod token_text;
35pub mod unescape;
36mod validation;
37
38#[cfg(test)]
39mod test;
40
41use std::{marker::PhantomData, sync::Arc};
42
43pub use squawk_parser::SyntaxKind;
44
45use ast::AstNode;
46pub use ptr::{AstPtr, SyntaxNodePtr};
47use rowan::GreenNode;
48use syntax_error::SyntaxError;
49pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
50pub use token_text::TokenText;
51
/// `Parse` is the result of the parsing: a syntax tree and a collection of
/// errors.
///
/// Note that we always produce a syntax tree, even for completely invalid
/// files.
///
/// The type parameter `T` records which typed AST node the root of the tree
/// is expected to cast to; it is a compile-time marker only.
#[derive(Debug, PartialEq, Eq)]
pub struct Parse<T> {
    // Green tree from which `syntax_node` builds a fresh root on every call.
    green: GreenNode,
    // Errors handed to `Parse::new`; stored as `None` when that list was empty.
    errors: Option<Arc<[SyntaxError]>>,
    // Zero-sized marker for the root's AST type; no `T` value is ever stored.
    _ty: PhantomData<fn() -> T>,
}
63
64impl<T> Clone for Parse<T> {
65    fn clone(&self) -> Parse<T> {
66        Parse {
67            green: self.green.clone(),
68            errors: self.errors.clone(),
69            _ty: PhantomData,
70        }
71    }
72}
73
74impl<T> Parse<T> {
75    fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
76        Parse {
77            green,
78            errors: if errors.is_empty() {
79                None
80            } else {
81                Some(errors.into())
82            },
83            _ty: PhantomData,
84        }
85    }
86
87    pub fn syntax_node(&self) -> SyntaxNode {
88        SyntaxNode::new_root(self.green.clone())
89    }
90
91    pub fn errors(&self) -> Vec<SyntaxError> {
92        let mut errors = if let Some(e) = self.errors.as_deref() {
93            e.to_vec()
94        } else {
95            vec![]
96        };
97        validation::validate(&self.syntax_node(), &mut errors);
98        errors.sort_by_key(|error| error.range().start());
99        errors
100    }
101}
102
103impl<T: AstNode> Parse<T> {
104    /// Converts this parse result into a parse result for an untyped syntax tree.
105    pub fn to_syntax(self) -> Parse<SyntaxNode> {
106        Parse {
107            green: self.green,
108            errors: self.errors,
109            _ty: PhantomData,
110        }
111    }
112
113    /// Gets the parsed syntax tree as a typed ast node.
114    ///
115    /// # Panics
116    ///
117    /// Panics if the root node cannot be casted into the typed ast node
118    /// (e.g. if it's an `ERROR` node).
119    pub fn tree(&self) -> T {
120        T::cast(self.syntax_node()).unwrap()
121    }
122
123    /// Converts from `Parse<T>` to [`Result<T, Vec<SyntaxError>>`].
124    pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
125        match self.errors() {
126            errors if !errors.is_empty() => Err(errors),
127            _ => Ok(self.tree()),
128        }
129    }
130}
131
132impl Parse<SyntaxNode> {
133    pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
134        if N::cast(self.syntax_node()).is_some() {
135            Some(Parse {
136                green: self.green,
137                errors: self.errors,
138                _ty: PhantomData,
139            })
140        } else {
141            None
142        }
143    }
144}
145
146/// `SourceFile` represents a parse tree for a single SQL file.
147pub use crate::ast::SourceFile;
148
149impl SourceFile {
150    pub fn parse(text: &str) -> Parse<SourceFile> {
151        let (green, errors) = parsing::parse_text(text);
152        let root = SyntaxNode::new_root(green.clone());
153
154        assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
155        Parse::new(green, errors)
156    }
157}
158
/// Matches a `SyntaxNode` against an `ast` type.
///
/// Arms are tried in order: each one calls the named type's `cast` on a
/// clone of the node, and the first cast that returns `Some` binds its value
/// to the arm's pattern and evaluates the arm's expression. A trailing `_`
/// arm is mandatory and is used when no cast succeeds.
///
/// The node may be given as a bare identifier or as a parenthesized
/// expression. The expression is expanded once per arm, so it should be free
/// of side effects.
///
/// # Example:
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::CreateFunction(it) => { ... },
///         ast::FuncOption(it) => { ... },
///         ast::Literal(it) => { ... },
///         _ => None,
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    // Bare identifier form: forward to the parenthesized-expression rule below.
    (match $node:ident { $($tt:tt)* }) => { $crate::match_ast!(match ($node) { $($tt)* }) };

    // Expands to an `if let ... else if let ... else { catch_all }` chain,
    // with one `cast($node.clone())` attempt per arm.
    (match ($node:expr) {
        $( $( $path:ident )::+ ($it:pat) => $res:expr, )*
        _ => $catch_all:expr $(,)?
    }) => {{
        $( if let Some($it) = $($path::)+cast($node.clone()) { $res } else )*
        { $catch_all }
    }};
}
185
/// Walks through the crate's API on a small `create function` statement.
///
/// The assertions are illustrative rather than exhaustive: the test is meant
/// to be read top to bottom as a tour of the typed AST and untyped CST APIs.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // `SourceFile` is the main entry point.
    //
    // The `parse` method returns a `Parse` -- a pair of syntax tree and a list
    // of errors. That is, syntax tree is constructed even in presence of errors.
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    // The `tree` method returns an owned syntax node of type `SourceFile`.
    // Owned nodes are cheap: inside, they are `Rc` handles to the underlying data.
    let file: SourceFile = parse.tree();

    // `SourceFile` is the root of the syntax tree. We can iterate file's items.
    // Let's fetch the `foo` function.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Each AST node has a bunch of getters for children. All getters return
    // `Option`s though, to account for incomplete code. Some getters are common
    // for several kinds of node; in that case a shared trait usually exists.
    // By convention, all ast types should be used with `ast::` qualifier.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    // return type
    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    // params
    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    // Enums are used to group related ast nodes together, and can be used for
    // matching. However, because there are no public fields, it's possible to
    // match only the top level enum: that is the price we pay for increased API
    // flexibility
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Besides the "typed" AST API, there's an untyped CST one as well.
    // To switch from AST to CST, call `.syntax()` method:
    let func_option_syntax = func_option.syntax();

    // Note how `func_option_syntax` and `option` are in fact the same node underneath:
    assert!(func_option_syntax == option.syntax());

    // To go from CST to AST, `AstNode::cast` function is used:
    let _expr: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(e) => e,
        None => unreachable!(),
    };

    // Every syntax node has two properties. The first is a `SyntaxKind`:
    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // The second is a text range:
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );

    // You can get node's text as a `SyntaxText` object, which will traverse the
    // tree collecting token's text:
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // There's a bunch of traversal methods on `SyntaxNode`:
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    // As well as some iterator helpers:
    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    assert_eq!(
        func_option_syntax.descendants_with_tokens().count(),
        5, // 1 the node itself: `as 'select 1 + 1'`
           // 2 direct tokens: `as`, ` `
           // 1 child literal node: `'select 1 + 1'`
           // 1 string token inside the literal: `'select 1 + 1'`
    );

    // There's also a `preorder` method with a more fine-grained iteration control:
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                // The " " argument is padded to `indent` columns; at depth 0
                // it still emits one space, which `trim` removes below.
                writeln!(
                    buf,
                    "{:indent$}{:?} {:?}",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                )
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // To recursively process the tree, there are three approaches:
    // 1. explicitly call getter methods on AST nodes.
    // 2. use descendants and `AstNode::cast`.
    // 3. use descendants and `match_ast!`.
    //
    // The first approach is what most of this test has done so far.
    // Here's what the second one looks like:
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();

    // And the third one, using the `match_ast!` macro:
    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
393
/// Parses two `create table` statements and snapshots, for each table, its
/// name and the `(name, type)` text of every column argument.
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );
        
        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // Statements other than `create table` are skipped, as are table
    // arguments that are not plain columns.
    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(create_table) => Some(create_table),
            _ => None,
        })
        .map(|create_table| {
            let table_name = create_table.path().unwrap().syntax().to_string();
            let columns = create_table
                .table_arg_list()
                .unwrap()
                .args()
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(column) => {
                        let column_name = column.name().unwrap();
                        let column_type = column.ty().unwrap();
                        Some((
                            column_name.syntax().to_string(),
                            column_type.syntax().to_string(),
                        ))
                    }
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .collect();
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}
488
/// Parses `select 1 is not null;` and snapshots the left operand, operator
/// and right operand of the resulting binary expression.
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let source_code = "select 1 is not null;";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // The only statement in the file must be a `select`.
    let first_stmt = file.stmts().next().unwrap();
    let select = match first_stmt {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };

    // Its first target must be a binary expression.
    let select_clause = select.select_clause().unwrap();
    let target_list = select_clause.target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let bin_expr = match first_target.expr().unwrap() {
        ast::Expr::BinExpr(bin_expr) => bin_expr,
        _ => unreachable!(),
    };

    let lhs = bin_expr.lhs();
    let op = bin_expr.op();
    let rhs = bin_expr.rhs();

    assert_debug_snapshot!(lhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(op, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(rhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}