Skip to main content

squawk_syntax/ast/
node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use rowan::Direction;
36
37use crate::ast;
38use crate::ast::AstNode;
39use crate::unescape::{escape_unicode_esc_str, uescape_char};
40use crate::{SyntaxKind, SyntaxNode, SyntaxToken, TokenText};
41
42use super::support;
43
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum LitKind {
46    BitString(SyntaxToken),
47    ByteString(SyntaxToken),
48    Default(SyntaxToken),
49    DollarQuotedString(SyntaxToken),
50    EscString(SyntaxToken),
51    False(SyntaxToken),
52    IntNumber(SyntaxToken),
53    Null(SyntaxToken),
54    NumericNumber(SyntaxToken),
55    PositionalParam(SyntaxToken),
56    String(SyntaxToken),
57    True(SyntaxToken),
58    UnicodeEscString(SyntaxToken),
59}
60
61impl ast::Literal {
62    pub fn kind(&self) -> Option<LitKind> {
63        let token = self.syntax().first_child_or_token()?.into_token()?;
64        let kind = match token.kind() {
65            SyntaxKind::BIT_STRING => LitKind::BitString(token),
66            SyntaxKind::BYTE_STRING => LitKind::ByteString(token),
67            SyntaxKind::DEFAULT_KW => LitKind::Default(token),
68            SyntaxKind::DOLLAR_QUOTED_STRING => LitKind::DollarQuotedString(token),
69            SyntaxKind::ESC_STRING => LitKind::EscString(token),
70            SyntaxKind::FALSE_KW => LitKind::False(token),
71            SyntaxKind::INT_NUMBER => LitKind::IntNumber(token),
72            SyntaxKind::NULL_KW => LitKind::Null(token),
73            SyntaxKind::NUMERIC_NUMBER => LitKind::NumericNumber(token),
74            SyntaxKind::POSITIONAL_PARAM => LitKind::PositionalParam(token),
75            SyntaxKind::STRING => LitKind::String(token),
76            SyntaxKind::TRUE_KW => LitKind::True(token),
77            SyntaxKind::UNICODE_ESC_STRING => LitKind::UnicodeEscString(token),
78            _ => return None,
79        };
80        Some(kind)
81    }
82}
83
84impl ast::Constraint {
85    #[inline]
86    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
87        support::child(self.syntax())
88    }
89}
90
91#[derive(Debug, Clone, PartialEq, Eq)]
92pub enum BinOp {
93    And(SyntaxToken),
94    AtTimeZone(ast::AtTimeZone),
95    Caret(SyntaxToken),
96    Collate(SyntaxToken),
97    ColonColon(ast::ColonColon),
98    ColonEq(SyntaxToken),
99    CustomOp(ast::CustomOp),
100    Eq(SyntaxToken),
101    FatArrow(SyntaxToken),
102    Gteq(SyntaxToken),
103    Ilike(SyntaxToken),
104    In(SyntaxToken),
105    Is(SyntaxToken),
106    IsDistinctFrom(ast::IsDistinctFrom),
107    IsNot(ast::IsNot),
108    IsNotDistinctFrom(ast::IsNotDistinctFrom),
109    LAngle(SyntaxToken),
110    Like(SyntaxToken),
111    Lteq(SyntaxToken),
112    Minus(SyntaxToken),
113    Neq(SyntaxToken),
114    Neqb(SyntaxToken),
115    NotIlike(ast::NotIlike),
116    NotIn(ast::NotIn),
117    NotLike(ast::NotLike),
118    NotSimilarTo(ast::NotSimilarTo),
119    OperatorCall(ast::OperatorCall),
120    Or(SyntaxToken),
121    Overlaps(SyntaxToken),
122    Percent(SyntaxToken),
123    Plus(SyntaxToken),
124    RAngle(SyntaxToken),
125    SimilarTo(ast::SimilarTo),
126    Slash(SyntaxToken),
127    Star(SyntaxToken),
128}
129
130#[derive(Debug, Clone, PartialEq, Eq)]
131pub enum PostfixOp {
132    AtLocal(SyntaxToken),
133    IsJson(ast::IsJson),
134    IsJsonArray(ast::IsJsonArray),
135    IsJsonObject(ast::IsJsonObject),
136    IsJsonScalar(ast::IsJsonScalar),
137    IsJsonValue(ast::IsJsonValue),
138    IsNormalized(ast::IsNormalized),
139    IsNotJson(ast::IsNotJson),
140    IsNotJsonArray(ast::IsNotJsonArray),
141    IsNotJsonObject(ast::IsNotJsonObject),
142    IsNotJsonScalar(ast::IsNotJsonScalar),
143    IsNotJsonValue(ast::IsNotJsonValue),
144    IsNotNormalized(ast::IsNotNormalized),
145    IsNull(SyntaxToken),
146    NotNull(SyntaxToken),
147}
148
149impl ast::BinExpr {
150    #[inline]
151    pub fn lhs(&self) -> Option<ast::Expr> {
152        support::children(self.syntax()).next()
153    }
154
155    #[inline]
156    pub fn rhs(&self) -> Option<ast::Expr> {
157        support::children(self.syntax()).nth(1)
158    }
159
160    pub fn op(&self) -> Option<BinOp> {
161        let lhs = self.lhs()?;
162        for child in lhs.syntax().siblings_with_tokens(Direction::Next).skip(1) {
163            match child {
164                NodeOrToken::Token(token) => {
165                    let op = match token.kind() {
166                        SyntaxKind::AND_KW => BinOp::And(token),
167                        SyntaxKind::CARET => BinOp::Caret(token),
168                        SyntaxKind::COLLATE_KW => BinOp::Collate(token),
169                        SyntaxKind::COLON_EQ => BinOp::ColonEq(token),
170                        SyntaxKind::EQ => BinOp::Eq(token),
171                        SyntaxKind::FAT_ARROW => BinOp::FatArrow(token),
172                        SyntaxKind::GTEQ => BinOp::Gteq(token),
173                        SyntaxKind::ILIKE_KW => BinOp::Ilike(token),
174                        SyntaxKind::IN_KW => BinOp::In(token),
175                        SyntaxKind::IS_KW => BinOp::Is(token),
176                        SyntaxKind::L_ANGLE => BinOp::LAngle(token),
177                        SyntaxKind::LIKE_KW => BinOp::Like(token),
178                        SyntaxKind::LTEQ => BinOp::Lteq(token),
179                        SyntaxKind::MINUS => BinOp::Minus(token),
180                        SyntaxKind::NEQ => BinOp::Neq(token),
181                        SyntaxKind::NEQB => BinOp::Neqb(token),
182                        SyntaxKind::OR_KW => BinOp::Or(token),
183                        SyntaxKind::OVERLAPS_KW => BinOp::Overlaps(token),
184                        SyntaxKind::PERCENT => BinOp::Percent(token),
185                        SyntaxKind::PLUS => BinOp::Plus(token),
186                        SyntaxKind::R_ANGLE => BinOp::RAngle(token),
187                        SyntaxKind::SLASH => BinOp::Slash(token),
188                        SyntaxKind::STAR => BinOp::Star(token),
189                        _ => continue,
190                    };
191                    return Some(op);
192                }
193                NodeOrToken::Node(node) => {
194                    let op = match node.kind() {
195                        SyntaxKind::AT_TIME_ZONE => {
196                            BinOp::AtTimeZone(ast::AtTimeZone { syntax: node })
197                        }
198                        SyntaxKind::COLON_COLON => {
199                            BinOp::ColonColon(ast::ColonColon { syntax: node })
200                        }
201                        SyntaxKind::CUSTOM_OP => BinOp::CustomOp(ast::CustomOp { syntax: node }),
202                        SyntaxKind::IS_DISTINCT_FROM => {
203                            BinOp::IsDistinctFrom(ast::IsDistinctFrom { syntax: node })
204                        }
205                        SyntaxKind::IS_NOT => BinOp::IsNot(ast::IsNot { syntax: node }),
206                        SyntaxKind::IS_NOT_DISTINCT_FROM => {
207                            BinOp::IsNotDistinctFrom(ast::IsNotDistinctFrom { syntax: node })
208                        }
209                        SyntaxKind::NOT_ILIKE => BinOp::NotIlike(ast::NotIlike { syntax: node }),
210                        SyntaxKind::NOT_IN => BinOp::NotIn(ast::NotIn { syntax: node }),
211                        SyntaxKind::NOT_LIKE => BinOp::NotLike(ast::NotLike { syntax: node }),
212                        SyntaxKind::NOT_SIMILAR_TO => {
213                            BinOp::NotSimilarTo(ast::NotSimilarTo { syntax: node })
214                        }
215                        SyntaxKind::OPERATOR_CALL => {
216                            BinOp::OperatorCall(ast::OperatorCall { syntax: node })
217                        }
218                        SyntaxKind::SIMILAR_TO => BinOp::SimilarTo(ast::SimilarTo { syntax: node }),
219                        _ => continue,
220                    };
221                    return Some(op);
222                }
223            }
224        }
225        None
226    }
227}
228
229impl ast::PostfixExpr {
230    pub fn op(&self) -> Option<PostfixOp> {
231        let lhs = self.expr()?;
232
233        let siblings = lhs.syntax().siblings_with_tokens(Direction::Next).skip(1);
234        for child in siblings {
235            match child {
236                NodeOrToken::Token(token) => {
237                    let op = match token.kind() {
238                        SyntaxKind::AT_KW => PostfixOp::AtLocal(token),
239                        SyntaxKind::ISNULL_KW => PostfixOp::IsNull(token),
240                        SyntaxKind::NOTNULL_KW => PostfixOp::NotNull(token),
241                        _ => continue,
242                    };
243                    return Some(op);
244                }
245                NodeOrToken::Node(node) => {
246                    let op = match node.kind() {
247                        SyntaxKind::IS_JSON => PostfixOp::IsJson(ast::IsJson { syntax: node }),
248                        SyntaxKind::IS_JSON_ARRAY => {
249                            PostfixOp::IsJsonArray(ast::IsJsonArray { syntax: node })
250                        }
251                        SyntaxKind::IS_JSON_OBJECT => {
252                            PostfixOp::IsJsonObject(ast::IsJsonObject { syntax: node })
253                        }
254                        SyntaxKind::IS_JSON_SCALAR => {
255                            PostfixOp::IsJsonScalar(ast::IsJsonScalar { syntax: node })
256                        }
257                        SyntaxKind::IS_JSON_VALUE => {
258                            PostfixOp::IsJsonValue(ast::IsJsonValue { syntax: node })
259                        }
260                        SyntaxKind::IS_NORMALIZED => {
261                            PostfixOp::IsNormalized(ast::IsNormalized { syntax: node })
262                        }
263                        SyntaxKind::IS_NOT_JSON => {
264                            PostfixOp::IsNotJson(ast::IsNotJson { syntax: node })
265                        }
266                        SyntaxKind::IS_NOT_JSON_ARRAY => {
267                            PostfixOp::IsNotJsonArray(ast::IsNotJsonArray { syntax: node })
268                        }
269                        SyntaxKind::IS_NOT_JSON_OBJECT => {
270                            PostfixOp::IsNotJsonObject(ast::IsNotJsonObject { syntax: node })
271                        }
272                        SyntaxKind::IS_NOT_JSON_SCALAR => {
273                            PostfixOp::IsNotJsonScalar(ast::IsNotJsonScalar { syntax: node })
274                        }
275                        SyntaxKind::IS_NOT_JSON_VALUE => {
276                            PostfixOp::IsNotJsonValue(ast::IsNotJsonValue { syntax: node })
277                        }
278                        SyntaxKind::IS_NOT_NORMALIZED => {
279                            PostfixOp::IsNotNormalized(ast::IsNotNormalized { syntax: node })
280                        }
281                        _ => continue,
282                    };
283                    return Some(op);
284                }
285            }
286        }
287
288        None
289    }
290}
291
292impl ast::FieldExpr {
293    // We have NameRef as a variant of Expr which complicates things (and it
294    // might not be worth it).
295    // Rust analyzer doesn't do this so it doesn't have to special case this.
296    #[inline]
297    pub fn base(&self) -> Option<ast::Expr> {
298        support::children(self.syntax()).next()
299    }
300    #[inline]
301    pub fn field(&self) -> Option<ast::NameRef> {
302        support::children(self.syntax()).last()
303    }
304}
305
306impl ast::IndexExpr {
307    #[inline]
308    pub fn base(&self) -> Option<ast::Expr> {
309        support::children(&self.syntax).next()
310    }
311    #[inline]
312    pub fn index(&self) -> Option<ast::Expr> {
313        support::children(&self.syntax).nth(1)
314    }
315}
316
317impl ast::SliceExpr {
318    #[inline]
319    pub fn base(&self) -> Option<ast::Expr> {
320        support::children(&self.syntax).next()
321    }
322
323    #[inline]
324    pub fn start(&self) -> Option<ast::Expr> {
325        // With `select x[1:]`, we have two exprs, `x` and `1`.
326        // We skip over the first one, and then we want the second one, but we
327        // want to make sure we don't choose the end expr if instead we had:
328        // `select x[:1]`
329        let colon = self.colon_token()?;
330        support::children(&self.syntax)
331            .skip(1)
332            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
333    }
334
335    #[inline]
336    pub fn end(&self) -> Option<ast::Expr> {
337        // We want to make sure we get the last expr after the `:` which is the
338        // end of the slice, i.e., `2` in: `select x[:2]`
339        let colon = self.colon_token()?;
340        support::children(&self.syntax)
341            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
342    }
343}
344
345impl ast::RenameColumn {
346    #[inline]
347    pub fn from(&self) -> Option<ast::NameRef> {
348        support::children(&self.syntax).nth(0)
349    }
350    #[inline]
351    pub fn to(&self) -> Option<ast::NameRef> {
352        support::children(&self.syntax).nth(1)
353    }
354}
355
356impl ast::ForeignKeyConstraint {
357    #[inline]
358    pub fn from_columns(&self) -> Option<ast::ColumnList> {
359        support::children(&self.syntax).nth(0)
360    }
361    #[inline]
362    pub fn to_columns(&self) -> Option<ast::ColumnList> {
363        support::children(&self.syntax).nth(1)
364    }
365}
366
367impl ast::BetweenExpr {
368    #[inline]
369    pub fn target(&self) -> Option<ast::Expr> {
370        support::children(&self.syntax).nth(0)
371    }
372    #[inline]
373    pub fn start(&self) -> Option<ast::Expr> {
374        support::children(&self.syntax).nth(1)
375    }
376    #[inline]
377    pub fn end(&self) -> Option<ast::Expr> {
378        support::children(&self.syntax).nth(2)
379    }
380}
381
382impl ast::FrameBetween {
383    #[inline]
384    pub fn start(&self) -> Option<ast::FrameBound> {
385        support::children(&self.syntax).nth(0)
386    }
387    #[inline]
388    pub fn end(&self) -> Option<ast::FrameBound> {
389        support::children(&self.syntax).nth(1)
390    }
391}
392
393impl ast::WhenClause {
394    #[inline]
395    pub fn condition(&self) -> Option<ast::Expr> {
396        support::children(&self.syntax).next()
397    }
398    #[inline]
399    pub fn then(&self) -> Option<ast::Expr> {
400        support::children(&self.syntax).nth(1)
401    }
402}
403
404impl ast::CompoundSelect {
405    #[inline]
406    pub fn lhs(&self) -> Option<ast::SelectVariant> {
407        support::children(&self.syntax).next()
408    }
409    #[inline]
410    pub fn rhs(&self) -> Option<ast::SelectVariant> {
411        support::children(&self.syntax).nth(1)
412    }
413}
414
415impl ast::NameRef {
416    #[inline]
417    pub fn text(&self) -> String {
418        normalize_name_node(self.syntax())
419    }
420
421    #[inline]
422    pub fn is_quoted(&self) -> bool {
423        is_quoted(self.syntax())
424    }
425}
426
427impl ast::Name {
428    #[inline]
429    pub fn text(&self) -> String {
430        normalize_name_node(self.syntax())
431    }
432
433    #[inline]
434    pub fn is_quoted(&self) -> bool {
435        is_quoted(self.syntax())
436    }
437}
438
439fn is_quoted(node: &SyntaxNode) -> bool {
440    let text = node.text();
441    let first = text.char_at(0.into());
442    let second = text.char_at(1.into());
443    matches!(
444        (first, second),
445        (Some('u' | 'U'), Some('"')) | (Some('"'), Some(_))
446    )
447}
448
449// TODO: return a NewType wrapper around String?
450fn normalize_name_node(node: &SyntaxNode) -> String {
451    let mut tokens = node
452        .children_with_tokens()
453        .filter_map(|el| el.into_token())
454        .filter(|t| !t.kind().is_trivia());
455
456    let Some(ident_token) = tokens.next() else {
457        return String::new();
458    };
459    let raw = ident_token.text();
460
461    let unicode_inner = raw
462        .strip_prefix(['u', 'U'])
463        .and_then(|s| s.strip_prefix("&\""))
464        .and_then(|s| s.strip_suffix('"'));
465
466    if let Some(inner) = unicode_inner {
467        let mut escape_char = '\\';
468        if let Some(uesc) = tokens.next()
469            && uesc.kind() == SyntaxKind::UESCAPE_KW
470            && let Some(token) = tokens.next()
471            && let Some(ch) = uescape_char(token.text())
472        {
473            escape_char = ch;
474        }
475
476        let inner = inner.replace(r#""""#, "\"");
477        let mut result = String::with_capacity(inner.len());
478        escape_unicode_esc_str(&inner, escape_char, |_range, r| {
479            if let Ok(ch) = r {
480                result.push(ch);
481            }
482        });
483        return result;
484    }
485
486    raw.strip_prefix('"')
487        .and_then(|t| t.strip_suffix('"'))
488        .map(|x| x.replace(r#""""#, "\""))
489        .unwrap_or_else(|| raw.to_ascii_lowercase())
490}
491
492impl ast::CharType {
493    #[inline]
494    pub fn text(&self) -> TokenText<'_> {
495        text_of_first_token(self.syntax())
496    }
497}
498
/// Reports whether `text`, the source text of an option's value token,
/// spells a false boolean.
///
/// Recognised as false:
/// - the integer `0`;
/// - the words `false` and `off`, compared ASCII case-insensitively;
/// - those same words written as a single-quoted string (`'false'`, `'off'`),
///   since PostgreSQL also accepts boolean option values as string constants.
///
/// A quoted `'0'` is *not* treated as false: PostgreSQL only gives `0` its
/// boolean meaning when it is an integer, not a string.
fn is_falsey_option(text: &str) -> bool {
    let is_false_word =
        |word: &str| word.eq_ignore_ascii_case("false") || word.eq_ignore_ascii_case("off");

    if text == "0" || is_false_word(text) {
        return true;
    }

    // The token text of a string literal still carries its quotes.
    text.strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))
        .is_some_and(is_false_word)
}
502
503impl ast::Vacuum {
504    pub fn is_full(&self) -> bool {
505        self.full_token().is_some()
506            // TODO: we need a better way of handling option lists
507            || self.vacuum_option_list().is_some_and(|opt_list| {
508                opt_list.vacuum_options().any(|opt| {
509                    let mut tokens = opt
510                        .syntax()
511                        .descendants_with_tokens()
512                        .filter_map(|child| child.into_token())
513                        .filter(|token| !token.kind().is_trivia());
514
515                    tokens
516                        .next()
517                        .is_some_and(|token| token.text().eq_ignore_ascii_case("full"))
518                        && tokens
519                            .next()
520                            .is_none_or(|token| !is_falsey_option(token.text()))
521                })
522            })
523    }
524}
525
526impl ast::OpSig {
527    #[inline]
528    pub fn lhs(&self) -> Option<ast::Type> {
529        support::children(self.syntax()).next()
530    }
531
532    #[inline]
533    pub fn rhs(&self) -> Option<ast::Type> {
534        support::children(self.syntax()).nth(1)
535    }
536}
537
538impl ast::CastSig {
539    #[inline]
540    pub fn lhs(&self) -> Option<ast::Type> {
541        support::children(self.syntax()).next()
542    }
543
544    #[inline]
545    pub fn rhs(&self) -> Option<ast::Type> {
546        support::children(self.syntax()).nth(1)
547    }
548}
549
550pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
551    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
552        green_ref
553            .children()
554            .next()
555            .and_then(NodeOrToken::into_token)
556            .unwrap()
557    }
558
559    match node.green() {
560        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
561        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
562    }
563}
564
565impl ast::WithQuery {
566    #[inline]
567    pub fn with_clause(&self) -> Option<ast::WithClause> {
568        support::child(self.syntax())
569    }
570}
571
572impl ast::SelectVariant {
573    #[inline]
574    pub fn target_list(&self) -> Option<ast::TargetList> {
575        match self {
576            ast::SelectVariant::Select(select) => {
577                return select.select_clause()?.target_list();
578            }
579            ast::SelectVariant::SelectInto(select_into) => {
580                return select_into.select_clause()?.target_list();
581            }
582            ast::SelectVariant::ParenSelect(paren_select) => {
583                return paren_select.select()?.target_list();
584            }
585            _ => return None,
586        }
587    }
588}
589
590impl ast::HasParamList for ast::FunctionSig {}
591impl ast::HasParamList for ast::Aggregate {}
592
593impl ast::NameLike for ast::Name {
594    #[inline]
595    fn text(&self) -> String {
596        self.text()
597    }
598}
599impl ast::NameLike for ast::NameRef {
600    #[inline]
601    fn text(&self) -> String {
602        self.text()
603    }
604}
605
606impl ast::HasWithClause for ast::Select {}
607impl ast::HasWithClause for ast::SelectInto {}
608impl ast::HasWithClause for ast::Insert {}
609impl ast::HasWithClause for ast::Update {}
610impl ast::HasWithClause for ast::Delete {}
611
612impl ast::HasCreateTable for ast::CreateTable {}
613impl ast::HasCreateTable for ast::CreateForeignTable {}
614impl ast::HasCreateTable for ast::CreateTableLike {}
615
/// `ast::Name::text` lowercases bare aliases, preserves the case of quoted
/// ones, and decodes Unicode-escape identifiers (default and custom escape).
#[test]
fn name() {
    /// Parses `source_code` as one `select` and returns the normalized alias
    /// of its first target.
    fn extract_name(source_code: &str) -> String {
        let parse = SourceFile::parse(source_code);
        assert!(parse.errors().is_empty());
        let select = match parse.tree().stmts().next().unwrap() {
            ast::Stmt::Select(select) => select,
            _ => unreachable!(),
        };
        let alias = select
            .select_clause()
            .and_then(|clause| clause.target_list())
            .and_then(|list| list.targets().next())
            .and_then(|target| target.as_name())
            .and_then(|as_name| as_name.name())
            .unwrap();
        alias.text()
    }

    assert_snapshot!(extract_name("select 1 foo"), @"foo");
    assert_snapshot!(extract_name("select 1 FOO"), @"foo");
    assert_snapshot!(extract_name(r#"select 1 "foo""#), @"foo");
    assert_snapshot!(extract_name(r#"select 1 "Foo""#), @"Foo");
    assert_snapshot!(extract_name(r#"select 1 "FOO""#), @"FOO");
    assert_snapshot!(extract_name(r#"select 1 U&"\0066\006f\006f""#), @"foo");
    assert_snapshot!(extract_name(r#"select 1 U&"@0066@006f@006f" uescape '@'"#), @"foo");
}
648
/// `ast::NameRef::text` applies the same normalization as `ast::Name::text`.
#[test]
fn name_ref() {
    /// Parses `source_code` as one `select` and returns the normalized text
    /// of its first target, which must be a bare name reference.
    fn extract_name_ref(source_code: &str) -> String {
        let parse = SourceFile::parse(source_code);
        assert!(parse.errors().is_empty());
        let select = match parse.tree().stmts().next().unwrap() {
            ast::Stmt::Select(select) => select,
            _ => unreachable!(),
        };
        let first_target = select
            .select_clause()
            .unwrap()
            .target_list()
            .and_then(|list| list.targets().next())
            .unwrap();
        match first_target.expr().unwrap() {
            ast::Expr::NameRef(name_ref) => name_ref.text(),
            _ => unreachable!(),
        }
    }

    assert_snapshot!(extract_name_ref("select foo"), @"foo");
    assert_snapshot!(extract_name_ref("select FOO"), @"foo");
    assert_snapshot!(extract_name_ref(r#"select "foo""#), @"foo");
    assert_snapshot!(extract_name_ref(r#"select "Foo""#), @"Foo");
    assert_snapshot!(extract_name_ref(r#"select "FOO""#), @"FOO");
    assert_snapshot!(extract_name_ref(r#"select U&"\0066\006f\006f""#), @"foo");
    assert_snapshot!(extract_name_ref(r#"select U&"@0066@006f@006f" uescape '@'"#), @"foo");
}
679
/// Only doubled *double* quotes are collapsed inside a `U&"…"` identifier;
/// doubled single quotes must come through untouched.
#[test]
fn unicode_quoted_name_keeps_doubled_single_quotes() {
    let parse = SourceFile::parse(r#"select 1 U&"a''b""#);
    assert!(parse.errors().is_empty());

    let select = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let alias = select
        .select_clause()
        .and_then(|clause| clause.target_list())
        .and_then(|list| list.targets().next())
        .and_then(|target| target.as_name())
        .and_then(|as_name| as_name.name())
        .unwrap();

    assert_snapshot!(alias.text(), @"a''b");
}
703
/// `IndexExpr::base` and `IndexExpr::index` pick out `foo` and `bar` from
/// `foo[bar]`.
#[test]
fn index_expr() {
    let parse = SourceFile::parse(
        "
        select foo[bar];
    ",
    );
    assert!(parse.errors().is_empty());

    let select = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .and_then(|list| list.targets().next())
        .unwrap();
    let indexed = match first_target.expr().unwrap() {
        ast::Expr::IndexExpr(index_expr) => index_expr,
        _ => unreachable!(),
    };

    assert_eq!(indexed.base().unwrap().syntax().text(), "foo");
    assert_eq!(indexed.index().unwrap().syntax().text(), "bar");
}
730
/// `SliceExpr::{base, start, end}` for every combination of present and
/// absent bounds.
#[test]
fn slice_expr() {
    use insta::assert_snapshot;

    let parse = SourceFile::parse(
        "
        select x[1:2], x[2:], x[:3], x[:];
    ",
    );
    assert!(parse.errors().is_empty());

    let select = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let mut targets = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets();
    // Pulls the next target and requires it to be a slice expression.
    let mut next_slice = || match targets.next().unwrap().expr().unwrap() {
        ast::Expr::SliceExpr(slice) => slice,
        _ => unreachable!(),
    };

    // Both bounds.
    let both = next_slice();
    assert_snapshot!(both.syntax(), @"x[1:2]");
    assert_eq!(both.base().unwrap().syntax().text(), "x");
    assert_eq!(both.start().unwrap().syntax().text(), "1");
    assert_eq!(both.end().unwrap().syntax().text(), "2");

    // Lower bound only.
    let lower_only = next_slice();
    assert_snapshot!(lower_only.syntax(), @"x[2:]");
    assert_eq!(lower_only.base().unwrap().syntax().text(), "x");
    assert_eq!(lower_only.start().unwrap().syntax().text(), "2");
    assert!(lower_only.end().is_none());

    // Upper bound only.
    let upper_only = next_slice();
    assert_snapshot!(upper_only.syntax(), @"x[:3]");
    assert_eq!(upper_only.base().unwrap().syntax().text(), "x");
    assert!(upper_only.start().is_none());
    assert_eq!(upper_only.end().unwrap().syntax().text(), "3");

    // No bounds.
    let unbounded = next_slice();
    assert_snapshot!(unbounded.syntax(), @"x[:]");
    assert_eq!(unbounded.base().unwrap().syntax().text(), "x");
    assert!(unbounded.start().is_none());
    assert!(unbounded.end().is_none());
}
778
/// `FieldExpr::base` and `FieldExpr::field` pick out `foo` and `bar` from
/// `foo.bar`.
#[test]
fn field_expr() {
    let parse = SourceFile::parse(
        "
        select foo.bar;
    ",
    );
    assert!(parse.errors().is_empty());

    let select = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .and_then(|list| list.targets().next())
        .unwrap();
    let access = match first_target.expr().unwrap() {
        ast::Expr::FieldExpr(field_expr) => field_expr,
        _ => unreachable!(),
    };

    assert_eq!(access.base().unwrap().syntax().text(), "foo");
    assert_eq!(access.field().unwrap().syntax().text(), "bar");
}
805
/// `BetweenExpr::{target, start, end}` pick out `2`, `1` and `3` from
/// `2 between 1 and 3`.
#[test]
fn between_expr() {
    let parse = SourceFile::parse(
        "
        select 2 between 1 and 3;
    ",
    );
    assert!(parse.errors().is_empty());

    let select = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .and_then(|list| list.targets().next())
        .unwrap();
    let between = match first_target.expr().unwrap() {
        ast::Expr::BetweenExpr(between_expr) => between_expr,
        _ => unreachable!(),
    };

    let tested = between.target().unwrap();
    let lower = between.start().unwrap();
    let upper = between.end().unwrap();
    assert_eq!(tested.syntax().text(), "2");
    assert_eq!(lower.syntax().text(), "1");
    assert_eq!(upper.syntax().text(), "3");
}
834
/// Every cast spelling exercised here (`cast(x as t)`, `t 'literal'`,
/// `x::t`, and `interval '1' month`) must parse to an `ast::CastExpr` whose
/// `expr()` is the operand being cast and whose `ty()` is the target type.
#[test]
fn cast_expr() {
    /// Parses `sql` and returns the expression of the first target of its
    /// first statement as a cast expression.
    ///
    /// Panics if the parse reports errors, if the first statement is not a
    /// `select` with at least one target, or if that target's expression is
    /// not a cast.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let stmt = parse.tree().stmts().next().unwrap();
        let ast::Stmt::Select(select) = stmt else {
            unreachable!()
        };
        let target = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap();
        let ast::Expr::CastExpr(cast) = target.expr().unwrap() else {
            unreachable!()
        };
        cast
    }

    // `cast(<expr> as <type>)` form.
    let cast = extract_expr("select cast('123' as int)");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"pg_catalog.int4");

    // `<type> '<literal>'` prefix form.
    let cast = extract_expr("select int '123'");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"pg_catalog.int4");

    // `<expr>::<type>` form.
    let cast = extract_expr("select '123'::int");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"pg_catalog.int4");

    // Target types carrying type arguments and array brackets.
    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"pg_catalog.varchar(10)");

    // The trailing `month` is not part of the node returned by `ty()`.
    let cast = extract_expr("select interval '1' month");
    assert_snapshot!(cast.expr().expect("cast should have an operand").syntax(), @"'1'");
    assert_snapshot!(cast.ty().expect("cast should have a target type").syntax(), @"interval");
}
933
/// The operator signature of an `alter operator` statement exposes its two
/// argument types through `lhs()` and `rhs()`.
#[test]
fn op_sig() {
    // Written with escapes so the exact whitespace of the input, including
    // the trailing space after the argument list, stays visible.
    let sql = "\n      alter operator p.+ (int4, int8) \n        owner to u;\n    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let alter_op = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::AlterOperator(alter_op) => alter_op,
        _ => unreachable!(),
    };
    let sig = alter_op.op_sig().unwrap();
    let (left, right) = (sig.lhs().unwrap(), sig.rhs().unwrap());
    assert_snapshot!(left.syntax().text(), @"int4");
    assert_snapshot!(right.syntax().text(), @"int8");
}
952
/// The cast signature of a `drop cast` statement exposes the source type via
/// `lhs()` and the target type via `rhs()`.
#[test]
fn cast_sig() {
    // Written with escapes so the exact whitespace of the input stays visible.
    let sql = "\n      drop cast (text as int);\n    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let drop_cast = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::DropCast(drop_cast) => drop_cast,
        _ => unreachable!(),
    };
    let sig = drop_cast.cast_sig().unwrap();
    let (source_ty, target_ty) = (sig.lhs().unwrap(), sig.rhs().unwrap());
    assert_snapshot!(source_ty.syntax().text(), @"text");
    assert_snapshot!(target_ty.syntax().text(), @"int");
}
970
/// Test helper: parses `sql` and returns its first statement as an
/// `ast::Vacuum`.
///
/// Panics if the parse reports any errors, if there is no statement, or if
/// the first statement is not a `VACUUM`.
#[cfg(test)]
fn extract_vacuum(sql: &str) -> ast::Vacuum {
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    match parse.tree().stmts().next().unwrap() {
        ast::Stmt::Vacuum(vacuum) => vacuum,
        _ => unreachable!(),
    }
}
981
/// A bare `FULL` keyword (no parenthesized option list) counts as full.
#[test]
fn vacuum_full_is_full() {
    let vacuum = extract_vacuum("VACUUM FULL foo;");
    assert!(vacuum.is_full());
}
986
/// `FULL` inside a parenthesized option list, with no value, counts as full.
#[test]
fn vacuum_option_list_full_is_full() {
    let vacuum = extract_vacuum("VACUUM (FULL) foo;");
    assert!(vacuum.is_full());
}
991
/// An explicit `TRUE` value on the `FULL` option counts as full.
#[test]
fn vacuum_full_true_is_full() {
    let vacuum = extract_vacuum("VACUUM (FULL TRUE) foo;");
    assert!(vacuum.is_full());
}
996
/// An explicit `ON` value on the `FULL` option counts as full.
#[test]
fn vacuum_full_on_is_full() {
    let vacuum = extract_vacuum("VACUUM (FULL ON) foo;");
    assert!(vacuum.is_full());
}
1001
/// An explicit `1` value on the `FULL` option counts as full.
#[test]
fn vacuum_full_1_is_full() {
    let vacuum = extract_vacuum("VACUUM (FULL 1) foo;");
    assert!(vacuum.is_full());
}
1006
/// A `VACUUM` with no options at all is not full.
#[test]
fn vacuum_no_full_is_not_full() {
    let vacuum = extract_vacuum("VACUUM foo;");
    assert!(!vacuum.is_full());
}
1011
/// An option list that only names a different option (`FREEZE`) is not full.
#[test]
fn vacuum_other_option_is_not_full() {
    let vacuum = extract_vacuum("VACUUM (FREEZE) foo;");
    assert!(!vacuum.is_full());
}
1016
/// An explicit `FALSE` value on the `FULL` option is not full.
#[test]
fn vacuum_full_false_is_not_full() {
    let vacuum = extract_vacuum("VACUUM (FULL FALSE) foo;");
    assert!(!vacuum.is_full());
}
1021
/// An explicit `OFF` value on the `FULL` option is not full.
#[test]
fn vacuum_full_off_is_not_full() {
    let vacuum = extract_vacuum("VACUUM (FULL OFF) foo;");
    assert!(!vacuum.is_full());
}
1026
/// An explicit `0` value on the `FULL` option is not full.
#[test]
fn vacuum_full_0_is_not_full() {
    let vacuum = extract_vacuum("VACUUM (FULL 0) foo;");
    assert!(!vacuum.is_full());
}