Skip to main content

squawk_syntax/ast/
node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use rowan::Direction;
36
37use crate::ast;
38use crate::ast::AstNode;
39use crate::unescape::{escape_unicode_esc_str, uescape_char};
40use crate::{SyntaxKind, SyntaxNode, SyntaxToken, TokenText};
41
42use super::support;
43
/// The kind of a literal, as classified by [`ast::Literal::kind`].
///
/// Every variant wraps the token the literal was built from, so callers can
/// still inspect its text and range.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LitKind {
    /// A `BIT_STRING` token.
    BitString(SyntaxToken),
    /// A `BYTE_STRING` token.
    ByteString(SyntaxToken),
    /// The `DEFAULT` keyword (`DEFAULT_KW`).
    Default(SyntaxToken),
    /// A `DOLLAR_QUOTED_STRING` token.
    DollarQuotedString(SyntaxToken),
    /// An `ESC_STRING` token.
    EscString(SyntaxToken),
    /// The `FALSE` keyword (`FALSE_KW`).
    False(SyntaxToken),
    /// An `INT_NUMBER` token.
    IntNumber(SyntaxToken),
    /// The `NULL` keyword (`NULL_KW`).
    Null(SyntaxToken),
    /// A `NUMERIC_NUMBER` token.
    NumericNumber(SyntaxToken),
    /// A `POSITIONAL_PARAM` token.
    PositionalParam(SyntaxToken),
    /// A plain `STRING` token.
    String(SyntaxToken),
    /// The `TRUE` keyword (`TRUE_KW`).
    True(SyntaxToken),
    /// A `UNICODE_ESC_STRING` token.
    UnicodeEscString(SyntaxToken),
}
60
61impl ast::Literal {
62    pub fn kind(&self) -> Option<LitKind> {
63        let token = self.syntax().first_child_or_token()?.into_token()?;
64        let kind = match token.kind() {
65            SyntaxKind::BIT_STRING => LitKind::BitString(token),
66            SyntaxKind::BYTE_STRING => LitKind::ByteString(token),
67            SyntaxKind::DEFAULT_KW => LitKind::Default(token),
68            SyntaxKind::DOLLAR_QUOTED_STRING => LitKind::DollarQuotedString(token),
69            SyntaxKind::ESC_STRING => LitKind::EscString(token),
70            SyntaxKind::FALSE_KW => LitKind::False(token),
71            SyntaxKind::INT_NUMBER => LitKind::IntNumber(token),
72            SyntaxKind::NULL_KW => LitKind::Null(token),
73            SyntaxKind::NUMERIC_NUMBER => LitKind::NumericNumber(token),
74            SyntaxKind::POSITIONAL_PARAM => LitKind::PositionalParam(token),
75            SyntaxKind::STRING => LitKind::String(token),
76            SyntaxKind::TRUE_KW => LitKind::True(token),
77            SyntaxKind::UNICODE_ESC_STRING => LitKind::UnicodeEscString(token),
78            _ => return None,
79        };
80        Some(kind)
81    }
82}
83
84impl ast::Constraint {
85    #[inline]
86    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
87        support::child(self.syntax())
88    }
89}
90
/// The operator of a binary expression, as located by [`ast::BinExpr::op`].
///
/// Variants holding a `SyntaxToken` are single-token operators: the payload is
/// that token. Variants holding an AST node are operators the parser wraps in
/// their own node, e.g. multi-word forms such as `IsNot` or `NotLike`, or
/// `CustomOp` / `OperatorCall`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BinOp {
    /// `AND_KW` token.
    And(SyntaxToken),
    /// `AT_TIME_ZONE` node.
    AtTimeZone(ast::AtTimeZone),
    /// `CARET` token.
    Caret(SyntaxToken),
    /// `COLLATE_KW` token.
    Collate(SyntaxToken),
    /// `COLON_COLON` node.
    ColonColon(ast::ColonColon),
    /// `COLON_EQ` token.
    ColonEq(SyntaxToken),
    /// `CUSTOM_OP` node.
    CustomOp(ast::CustomOp),
    /// `EQ` token.
    Eq(SyntaxToken),
    /// `ESCAPE_KW` token.
    Escape(SyntaxToken),
    /// `FAT_ARROW` token.
    FatArrow(SyntaxToken),
    /// `GTEQ` token.
    Gteq(SyntaxToken),
    /// `ILIKE_KW` token.
    Ilike(SyntaxToken),
    /// `IN_KW` token.
    In(SyntaxToken),
    /// `IS_KW` token.
    Is(SyntaxToken),
    /// `IS_DISTINCT_FROM` node.
    IsDistinctFrom(ast::IsDistinctFrom),
    /// `IS_NOT` node.
    IsNot(ast::IsNot),
    /// `IS_NOT_DISTINCT_FROM` node.
    IsNotDistinctFrom(ast::IsNotDistinctFrom),
    /// `L_ANGLE` token.
    LAngle(SyntaxToken),
    /// `LIKE_KW` token.
    Like(SyntaxToken),
    /// `LTEQ` token.
    Lteq(SyntaxToken),
    /// `MINUS` token.
    Minus(SyntaxToken),
    /// `NEQ` token.
    Neq(SyntaxToken),
    /// `NEQB` token.
    Neqb(SyntaxToken),
    /// `NOT_ILIKE` node.
    NotIlike(ast::NotIlike),
    /// `NOT_IN` node.
    NotIn(ast::NotIn),
    /// `NOT_LIKE` node.
    NotLike(ast::NotLike),
    /// `NOT_SIMILAR_TO` node.
    NotSimilarTo(ast::NotSimilarTo),
    /// `OPERATOR_CALL` node.
    OperatorCall(ast::OperatorCall),
    /// `OR_KW` token.
    Or(SyntaxToken),
    /// `OVERLAPS_KW` token.
    Overlaps(SyntaxToken),
    /// `PERCENT` token.
    Percent(SyntaxToken),
    /// `PLUS` token.
    Plus(SyntaxToken),
    /// `R_ANGLE` token.
    RAngle(SyntaxToken),
    /// `SIMILAR_TO` node.
    SimilarTo(ast::SimilarTo),
    /// `SLASH` token.
    Slash(SyntaxToken),
    /// `STAR` token.
    Star(SyntaxToken),
}
130
/// The operator of a postfix expression, as located by [`ast::PostfixExpr::op`].
///
/// `IsNull` and `NotNull` wrap the keyword token (`ISNULL_KW` / `NOTNULL_KW`).
/// Every other variant wraps the AST node of the same name (`AT_LOCAL`,
/// `IS_JSON*`, `IS_NOT_JSON*`, `IS_NORMALIZED`, `IS_NOT_NORMALIZED`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PostfixOp {
    AtLocal(ast::AtLocal),
    IsJson(ast::IsJson),
    IsJsonArray(ast::IsJsonArray),
    IsJsonObject(ast::IsJsonObject),
    IsJsonScalar(ast::IsJsonScalar),
    IsJsonValue(ast::IsJsonValue),
    IsNormalized(ast::IsNormalized),
    IsNotJson(ast::IsNotJson),
    IsNotJsonArray(ast::IsNotJsonArray),
    IsNotJsonObject(ast::IsNotJsonObject),
    IsNotJsonScalar(ast::IsNotJsonScalar),
    IsNotJsonValue(ast::IsNotJsonValue),
    IsNotNormalized(ast::IsNotNormalized),
    /// `ISNULL_KW` token.
    IsNull(SyntaxToken),
    /// `NOTNULL_KW` token.
    NotNull(SyntaxToken),
}
149
150impl ast::BinExpr {
151    #[inline]
152    pub fn lhs(&self) -> Option<ast::Expr> {
153        support::children(self.syntax()).next()
154    }
155
156    #[inline]
157    pub fn rhs(&self) -> Option<ast::Expr> {
158        support::children(self.syntax()).nth(1)
159    }
160
161    pub fn op(&self) -> Option<BinOp> {
162        let lhs = self.lhs()?;
163        for child in lhs.syntax().siblings_with_tokens(Direction::Next).skip(1) {
164            match child {
165                NodeOrToken::Token(token) => {
166                    let op = match token.kind() {
167                        SyntaxKind::AND_KW => BinOp::And(token),
168                        SyntaxKind::CARET => BinOp::Caret(token),
169                        SyntaxKind::COLLATE_KW => BinOp::Collate(token),
170                        SyntaxKind::COLON_EQ => BinOp::ColonEq(token),
171                        SyntaxKind::EQ => BinOp::Eq(token),
172                        SyntaxKind::ESCAPE_KW => BinOp::Escape(token),
173                        SyntaxKind::FAT_ARROW => BinOp::FatArrow(token),
174                        SyntaxKind::GTEQ => BinOp::Gteq(token),
175                        SyntaxKind::ILIKE_KW => BinOp::Ilike(token),
176                        SyntaxKind::IN_KW => BinOp::In(token),
177                        SyntaxKind::IS_KW => BinOp::Is(token),
178                        SyntaxKind::L_ANGLE => BinOp::LAngle(token),
179                        SyntaxKind::LIKE_KW => BinOp::Like(token),
180                        SyntaxKind::LTEQ => BinOp::Lteq(token),
181                        SyntaxKind::MINUS => BinOp::Minus(token),
182                        SyntaxKind::NEQ => BinOp::Neq(token),
183                        SyntaxKind::NEQB => BinOp::Neqb(token),
184                        SyntaxKind::OR_KW => BinOp::Or(token),
185                        SyntaxKind::OVERLAPS_KW => BinOp::Overlaps(token),
186                        SyntaxKind::PERCENT => BinOp::Percent(token),
187                        SyntaxKind::PLUS => BinOp::Plus(token),
188                        SyntaxKind::R_ANGLE => BinOp::RAngle(token),
189                        SyntaxKind::SLASH => BinOp::Slash(token),
190                        SyntaxKind::STAR => BinOp::Star(token),
191                        _ => continue,
192                    };
193                    return Some(op);
194                }
195                NodeOrToken::Node(node) => {
196                    let op = match node.kind() {
197                        SyntaxKind::AT_TIME_ZONE => {
198                            BinOp::AtTimeZone(ast::AtTimeZone { syntax: node })
199                        }
200                        SyntaxKind::COLON_COLON => {
201                            BinOp::ColonColon(ast::ColonColon { syntax: node })
202                        }
203                        SyntaxKind::CUSTOM_OP => BinOp::CustomOp(ast::CustomOp { syntax: node }),
204                        SyntaxKind::IS_DISTINCT_FROM => {
205                            BinOp::IsDistinctFrom(ast::IsDistinctFrom { syntax: node })
206                        }
207                        SyntaxKind::IS_NOT => BinOp::IsNot(ast::IsNot { syntax: node }),
208                        SyntaxKind::IS_NOT_DISTINCT_FROM => {
209                            BinOp::IsNotDistinctFrom(ast::IsNotDistinctFrom { syntax: node })
210                        }
211                        SyntaxKind::NOT_ILIKE => BinOp::NotIlike(ast::NotIlike { syntax: node }),
212                        SyntaxKind::NOT_IN => BinOp::NotIn(ast::NotIn { syntax: node }),
213                        SyntaxKind::NOT_LIKE => BinOp::NotLike(ast::NotLike { syntax: node }),
214                        SyntaxKind::NOT_SIMILAR_TO => {
215                            BinOp::NotSimilarTo(ast::NotSimilarTo { syntax: node })
216                        }
217                        SyntaxKind::OPERATOR_CALL => {
218                            BinOp::OperatorCall(ast::OperatorCall { syntax: node })
219                        }
220                        SyntaxKind::SIMILAR_TO => BinOp::SimilarTo(ast::SimilarTo { syntax: node }),
221                        _ => continue,
222                    };
223                    return Some(op);
224                }
225            }
226        }
227        None
228    }
229}
230
231impl ast::PostfixExpr {
232    pub fn op(&self) -> Option<PostfixOp> {
233        let lhs = self.expr()?;
234
235        let siblings = lhs.syntax().siblings_with_tokens(Direction::Next).skip(1);
236        for child in siblings {
237            match child {
238                NodeOrToken::Token(token) => {
239                    let op = match token.kind() {
240                        SyntaxKind::ISNULL_KW => PostfixOp::IsNull(token),
241                        SyntaxKind::NOTNULL_KW => PostfixOp::NotNull(token),
242                        _ => continue,
243                    };
244                    return Some(op);
245                }
246                NodeOrToken::Node(node) => {
247                    let op = match node.kind() {
248                        SyntaxKind::AT_LOCAL => PostfixOp::AtLocal(ast::AtLocal { syntax: node }),
249                        SyntaxKind::IS_JSON => PostfixOp::IsJson(ast::IsJson { syntax: node }),
250                        SyntaxKind::IS_JSON_ARRAY => {
251                            PostfixOp::IsJsonArray(ast::IsJsonArray { syntax: node })
252                        }
253                        SyntaxKind::IS_JSON_OBJECT => {
254                            PostfixOp::IsJsonObject(ast::IsJsonObject { syntax: node })
255                        }
256                        SyntaxKind::IS_JSON_SCALAR => {
257                            PostfixOp::IsJsonScalar(ast::IsJsonScalar { syntax: node })
258                        }
259                        SyntaxKind::IS_JSON_VALUE => {
260                            PostfixOp::IsJsonValue(ast::IsJsonValue { syntax: node })
261                        }
262                        SyntaxKind::IS_NORMALIZED => {
263                            PostfixOp::IsNormalized(ast::IsNormalized { syntax: node })
264                        }
265                        SyntaxKind::IS_NOT_JSON => {
266                            PostfixOp::IsNotJson(ast::IsNotJson { syntax: node })
267                        }
268                        SyntaxKind::IS_NOT_JSON_ARRAY => {
269                            PostfixOp::IsNotJsonArray(ast::IsNotJsonArray { syntax: node })
270                        }
271                        SyntaxKind::IS_NOT_JSON_OBJECT => {
272                            PostfixOp::IsNotJsonObject(ast::IsNotJsonObject { syntax: node })
273                        }
274                        SyntaxKind::IS_NOT_JSON_SCALAR => {
275                            PostfixOp::IsNotJsonScalar(ast::IsNotJsonScalar { syntax: node })
276                        }
277                        SyntaxKind::IS_NOT_JSON_VALUE => {
278                            PostfixOp::IsNotJsonValue(ast::IsNotJsonValue { syntax: node })
279                        }
280                        SyntaxKind::IS_NOT_NORMALIZED => {
281                            PostfixOp::IsNotNormalized(ast::IsNotNormalized { syntax: node })
282                        }
283                        _ => continue,
284                    };
285                    return Some(op);
286                }
287            }
288        }
289
290        None
291    }
292}
293
294impl ast::FieldExpr {
295    // We have NameRef as a variant of Expr which complicates things (and it
296    // might not be worth it).
297    // Rust analyzer doesn't do this so it doesn't have to special case this.
298    #[inline]
299    pub fn base(&self) -> Option<ast::Expr> {
300        support::children(self.syntax()).next()
301    }
302    #[inline]
303    pub fn field(&self) -> Option<ast::NameRef> {
304        support::children(self.syntax()).last()
305    }
306}
307
308impl ast::IndexExpr {
309    #[inline]
310    pub fn base(&self) -> Option<ast::Expr> {
311        support::children(&self.syntax).next()
312    }
313    #[inline]
314    pub fn index(&self) -> Option<ast::Expr> {
315        support::children(&self.syntax).nth(1)
316    }
317}
318
319impl ast::SliceExpr {
320    #[inline]
321    pub fn base(&self) -> Option<ast::Expr> {
322        support::children(&self.syntax).next()
323    }
324
325    #[inline]
326    pub fn start(&self) -> Option<ast::Expr> {
327        // With `select x[1:]`, we have two exprs, `x` and `1`.
328        // We skip over the first one, and then we want the second one, but we
329        // want to make sure we don't choose the end expr if instead we had:
330        // `select x[:1]`
331        let colon = self.colon_token()?;
332        support::children(&self.syntax)
333            .skip(1)
334            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
335    }
336
337    #[inline]
338    pub fn end(&self) -> Option<ast::Expr> {
339        // We want to make sure we get the last expr after the `:` which is the
340        // end of the slice, i.e., `2` in: `select x[:2]`
341        let colon = self.colon_token()?;
342        support::children(&self.syntax)
343            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
344    }
345}
346
347impl ast::RenameColumn {
348    #[inline]
349    pub fn from(&self) -> Option<ast::NameRef> {
350        support::children(&self.syntax).nth(0)
351    }
352    #[inline]
353    pub fn to(&self) -> Option<ast::NameRef> {
354        support::children(&self.syntax).nth(1)
355    }
356}
357
358impl ast::ForeignKeyConstraint {
359    #[inline]
360    pub fn from_columns(&self) -> Option<ast::ColumnList> {
361        support::children(&self.syntax).nth(0)
362    }
363    #[inline]
364    pub fn to_columns(&self) -> Option<ast::ColumnList> {
365        support::children(&self.syntax).nth(1)
366    }
367}
368
369impl ast::BetweenExpr {
370    #[inline]
371    pub fn target(&self) -> Option<ast::Expr> {
372        support::children(&self.syntax).nth(0)
373    }
374    #[inline]
375    pub fn start(&self) -> Option<ast::Expr> {
376        support::children(&self.syntax).nth(1)
377    }
378    #[inline]
379    pub fn end(&self) -> Option<ast::Expr> {
380        support::children(&self.syntax).nth(2)
381    }
382}
383
384impl ast::FrameBetween {
385    #[inline]
386    pub fn start(&self) -> Option<ast::FrameBound> {
387        support::children(&self.syntax).nth(0)
388    }
389    #[inline]
390    pub fn end(&self) -> Option<ast::FrameBound> {
391        support::children(&self.syntax).nth(1)
392    }
393}
394
395impl ast::WhenClause {
396    #[inline]
397    pub fn condition(&self) -> Option<ast::Expr> {
398        support::children(&self.syntax).next()
399    }
400    #[inline]
401    pub fn then(&self) -> Option<ast::Expr> {
402        support::children(&self.syntax).nth(1)
403    }
404}
405
406impl ast::CompoundSelect {
407    #[inline]
408    pub fn lhs(&self) -> Option<ast::SelectVariant> {
409        support::children(&self.syntax).next()
410    }
411    #[inline]
412    pub fn rhs(&self) -> Option<ast::SelectVariant> {
413        support::children(&self.syntax).nth(1)
414    }
415}
416
417impl ast::NameRef {
418    #[inline]
419    pub fn text(&self) -> String {
420        normalize_name_node(self.syntax())
421    }
422
423    #[inline]
424    pub fn is_quoted(&self) -> bool {
425        is_quoted(self.syntax())
426    }
427}
428
429impl ast::Name {
430    #[inline]
431    pub fn text(&self) -> String {
432        normalize_name_node(self.syntax())
433    }
434
435    #[inline]
436    pub fn is_quoted(&self) -> bool {
437        is_quoted(self.syntax())
438    }
439}
440
441fn is_quoted(node: &SyntaxNode) -> bool {
442    let text = node.text();
443    let first = text.char_at(0.into());
444    let second = text.char_at(1.into());
445    matches!(
446        (first, second),
447        (Some('u' | 'U'), Some('"')) | (Some('"'), Some(_))
448    )
449}
450
451// TODO: return a NewType wrapper around String?
452fn normalize_name_node(node: &SyntaxNode) -> String {
453    let mut tokens = node
454        .children_with_tokens()
455        .filter_map(|el| el.into_token())
456        .filter(|t| !t.kind().is_trivia());
457
458    let Some(ident_token) = tokens.next() else {
459        return String::new();
460    };
461    let raw = ident_token.text();
462
463    let unicode_inner = raw
464        .strip_prefix(['u', 'U'])
465        .and_then(|s| s.strip_prefix("&\""))
466        .and_then(|s| s.strip_suffix('"'));
467
468    if let Some(inner) = unicode_inner {
469        let mut escape_char = '\\';
470        if let Some(uesc) = tokens.next()
471            && uesc.kind() == SyntaxKind::UESCAPE_KW
472            && let Some(token) = tokens.next()
473            && let Some(ch) = uescape_char(token.text())
474        {
475            escape_char = ch;
476        }
477
478        let inner = inner.replace(r#""""#, "\"");
479        let mut result = String::with_capacity(inner.len());
480        escape_unicode_esc_str(&inner, escape_char, |_range, r| {
481            if let Ok(ch) = r {
482                result.push(ch);
483            }
484        });
485        return result;
486    }
487
488    raw.strip_prefix('"')
489        .and_then(|t| t.strip_suffix('"'))
490        .map(|x| x.replace(r#""""#, "\""))
491        .unwrap_or_else(|| raw.to_ascii_lowercase())
492}
493
494impl ast::CharType {
495    #[inline]
496    pub fn text(&self) -> TokenText<'_> {
497        text_of_first_token(self.syntax())
498    }
499}
500
/// Returns `true` when an option value's source text means "false" to PostgreSQL.
///
/// This mirrors `defGetBoolean`. The bare integer `0` is false. The words
/// `false` and `off` are false, case-insensitively, whether written bare
/// (`off`), as a string literal (`'off'`) or as a quoted identifier (`"off"`).
/// A quoted `'0'` is not treated as false, because PostgreSQL only accepts `0`
/// as an integer.
fn is_falsey_option(text: &str) -> bool {
    if text == "0" {
        return true;
    }
    let word = text
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))
        .or_else(|| text.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')))
        .unwrap_or(text);
    word.eq_ignore_ascii_case("false") || word.eq_ignore_ascii_case("off")
}
504
505impl ast::Vacuum {
506    pub fn is_full(&self) -> bool {
507        self.full_token().is_some()
508            // TODO: we need a better way of handling option lists
509            || self.vacuum_option_list().is_some_and(|opt_list| {
510                opt_list.vacuum_options().any(|opt| {
511                    let mut tokens = opt
512                        .syntax()
513                        .descendants_with_tokens()
514                        .filter_map(|child| child.into_token())
515                        .filter(|token| !token.kind().is_trivia());
516
517                    tokens
518                        .next()
519                        .is_some_and(|token| token.text().eq_ignore_ascii_case("full"))
520                        && tokens
521                            .next()
522                            .is_none_or(|token| !is_falsey_option(token.text()))
523                })
524            })
525    }
526}
527
528impl ast::OpSig {
529    #[inline]
530    pub fn lhs(&self) -> Option<ast::Type> {
531        support::children(self.syntax()).next()
532    }
533
534    #[inline]
535    pub fn rhs(&self) -> Option<ast::Type> {
536        support::children(self.syntax()).nth(1)
537    }
538}
539
540impl ast::CastSig {
541    #[inline]
542    pub fn lhs(&self) -> Option<ast::Type> {
543        support::children(self.syntax()).next()
544    }
545
546    #[inline]
547    pub fn rhs(&self) -> Option<ast::Type> {
548        support::children(self.syntax()).nth(1)
549    }
550}
551
552pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
553    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
554        green_ref
555            .children()
556            .next()
557            .and_then(NodeOrToken::into_token)
558            .unwrap()
559    }
560
561    match node.green() {
562        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
563        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
564    }
565}
566
567impl ast::WithQuery {
568    #[inline]
569    pub fn with_clause(&self) -> Option<ast::WithClause> {
570        support::child(self.syntax())
571    }
572}
573
574impl ast::CreateTableAsQuery {
575    #[inline]
576    pub fn select_variant(&self) -> Option<ast::SelectVariant> {
577        match self {
578            ast::CreateTableAsQuery::Execute(_) => None,
579            ast::CreateTableAsQuery::SelectVariant(select_variant) => Some(select_variant.clone()),
580        }
581    }
582}
583
584impl ast::SelectVariant {
585    #[inline]
586    pub fn target_list(&self) -> Option<ast::TargetList> {
587        match self {
588            ast::SelectVariant::Select(select) => {
589                return select.select_clause()?.target_list();
590            }
591            ast::SelectVariant::SelectInto(select_into) => {
592                return select_into.select_clause()?.target_list();
593            }
594            ast::SelectVariant::ParenSelect(paren_select) => {
595                return paren_select.select()?.target_list();
596            }
597            _ => return None,
598        }
599    }
600}
601
// Opt these nodes into `HasParamList`. The impls are empty, so every trait
// item they use must be a provided (default) method of the trait.
impl ast::HasParamList for ast::FunctionSig {}
impl ast::HasParamList for ast::Aggregate {}
604
605impl ast::NameLike for ast::Name {
606    #[inline]
607    fn text(&self) -> String {
608        self.text()
609    }
610}
611impl ast::NameLike for ast::NameRef {
612    #[inline]
613    fn text(&self) -> String {
614        self.text()
615    }
616}
617
// Opt these statement nodes into `HasWithClause`. The impls are empty, so they
// rely entirely on the trait's provided (default) methods.
impl ast::HasWithClause for ast::Select {}
impl ast::HasWithClause for ast::SelectInto {}
impl ast::HasWithClause for ast::Insert {}
impl ast::HasWithClause for ast::Update {}
impl ast::HasWithClause for ast::Delete {}

// Opt these nodes into `HasCreateTable`, likewise relying only on the trait's
// provided (default) methods.
impl ast::HasCreateTable for ast::CreateTable {}
impl ast::HasCreateTable for ast::CreateForeignTable {}
impl ast::HasCreateTable for ast::CreateTableLike {}
627
#[test]
fn name() {
    /// Parses `select <expr> <alias>` and returns the alias's normalized text.
    fn extract_name(source_code: &str) -> String {
        let parse = SourceFile::parse(source_code);
        assert!(parse.errors().is_empty());
        let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
            unreachable!()
        };
        let first_target = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap();
        let alias = first_target.as_name().unwrap().name().unwrap();
        alias.text()
    }

    assert_snapshot!(extract_name("select 1 foo"), @"foo");
    assert_snapshot!(extract_name("select 1 FOO"), @"foo");
    assert_snapshot!(extract_name(r#"select 1 "foo""#), @"foo");
    assert_snapshot!(extract_name(r#"select 1 "Foo""#), @"Foo");
    assert_snapshot!(extract_name(r#"select 1 "FOO""#), @"FOO");
    assert_snapshot!(extract_name(r#"select 1 U&"\0066\006f\006f""#), @"foo");
    assert_snapshot!(extract_name(r#"select 1 U&"@0066@006f@006f" uescape '@'"#), @"foo");
}
660
#[test]
fn name_ref() {
    /// Parses `select <name>` and returns the referenced name's normalized text.
    fn extract_name_ref(source_code: &str) -> String {
        let parse = SourceFile::parse(source_code);
        assert!(parse.errors().is_empty());
        let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
            unreachable!()
        };
        let first_target = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap();
        match first_target.expr().unwrap() {
            ast::Expr::NameRef(name_ref) => name_ref.text(),
            _ => unreachable!(),
        }
    }

    assert_snapshot!(extract_name_ref("select foo"), @"foo");
    assert_snapshot!(extract_name_ref("select FOO"), @"foo");
    assert_snapshot!(extract_name_ref(r#"select "foo""#), @"foo");
    assert_snapshot!(extract_name_ref(r#"select "Foo""#), @"Foo");
    assert_snapshot!(extract_name_ref(r#"select "FOO""#), @"FOO");
    assert_snapshot!(extract_name_ref(r#"select U&"\0066\006f\006f""#), @"foo");
    assert_snapshot!(extract_name_ref(r#"select U&"@0066@006f@006f" uescape '@'"#), @"foo");
}
691
#[test]
fn unicode_quoted_name_keeps_doubled_single_quotes() {
    // Only doubled *double* quotes are collapsed inside `U&"..."`; `''` stays.
    let parse = SourceFile::parse(r#"select 1 U&"a''b""#);
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let alias = first_target.as_name().unwrap().name().unwrap();

    assert_snapshot!(alias.text(), @"a''b");
}
715
#[test]
fn index_expr() {
    let parse = SourceFile::parse(
        "
        select foo[bar];
    ",
    );
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::IndexExpr(index_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    assert_eq!(index_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(index_expr.index().unwrap().syntax().text(), "bar");
}
742
#[test]
fn slice_expr() {
    let parse = SourceFile::parse(
        "
        select x[1:2], x[2:], x[:3], x[:];
    ",
    );
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let mut targets = select.select_clause().unwrap().target_list().unwrap().targets();
    // Pulls the next select target, which must be a slice expression.
    let mut next_slice = || {
        let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
            unreachable!()
        };
        slice
    };

    let both = next_slice();
    assert_snapshot!(both.syntax(), @"x[1:2]");
    assert_eq!(both.base().unwrap().syntax().text(), "x");
    assert_eq!(both.start().unwrap().syntax().text(), "1");
    assert_eq!(both.end().unwrap().syntax().text(), "2");

    let start_only = next_slice();
    assert_snapshot!(start_only.syntax(), @"x[2:]");
    assert_eq!(start_only.base().unwrap().syntax().text(), "x");
    assert_eq!(start_only.start().unwrap().syntax().text(), "2");
    assert!(start_only.end().is_none());

    let end_only = next_slice();
    assert_snapshot!(end_only.syntax(), @"x[:3]");
    assert_eq!(end_only.base().unwrap().syntax().text(), "x");
    assert!(end_only.start().is_none());
    assert_eq!(end_only.end().unwrap().syntax().text(), "3");

    let neither = next_slice();
    assert_snapshot!(neither.syntax(), @"x[:]");
    assert_eq!(neither.base().unwrap().syntax().text(), "x");
    assert!(neither.start().is_none());
    assert!(neither.end().is_none());
}
790
#[test]
fn field_expr() {
    let parse = SourceFile::parse(
        "
        select foo.bar;
    ",
    );
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::FieldExpr(field_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    assert_eq!(field_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(field_expr.field().unwrap().syntax().text(), "bar");
}
817
#[test]
fn between_expr() {
    let parse = SourceFile::parse(
        "
        select 2 between 1 and 3;
    ",
    );
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::BetweenExpr(between) = first_target.expr().unwrap() else {
        unreachable!()
    };

    assert_eq!(between.target().unwrap().syntax().text(), "2");
    assert_eq!(between.start().unwrap().syntax().text(), "1");
    assert_eq!(between.end().unwrap().syntax().text(), "3");
}
846
#[test]
fn cast_expr() {
    // Each cast spelling below (`cast(x as t)`, the type-prefixed literal
    // `t 'x'`, and `x::t`) is expected to parse to a `CastExpr` whose `expr()`
    // is the value being cast and whose `ty()` is the target type.
    // `unwrap()` already fails the test when an accessor returns `None`, so no
    // separate `is_some()` check is needed.
    let cast = extract_expr("select cast('123' as int)");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    // The trailing `month` qualifier is not part of `ty()`.
    let cast = extract_expr("select interval '1' month");
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");

    /// Parses `sql` (which must produce no parse errors), takes its first
    /// statement as a `select`, and returns the expression of the first select
    /// target as a `CastExpr`.
    ///
    /// Panics, naming the offending SQL, if any of those expectations fail.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty(), "unexpected parse errors for {sql:?}");
        let stmt = parse.tree().stmts().next().unwrap();
        let ast::Stmt::Select(select) = stmt else {
            panic!("expected a select statement in {sql:?}");
        };
        // `expr()` already yields an owned `ast::Expr`; no clone is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        match expr {
            ast::Expr::CastExpr(cast) => cast,
            other => panic!(
                "expected a cast expression in {sql:?}, found `{}`",
                other.syntax()
            ),
        }
    }
}
945
#[test]
fn op_sig() {
    // An `alter operator` signature exposes its left and right operand types
    // through `lhs()` and `rhs()` respectively.
    let sql = concat!(
        "\n",
        "      alter operator p.+ (int4, int8) \n",
        "        owner to u;\n",
        "    ",
    );
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());

    let first = parse.tree().stmts().next().unwrap();
    let ast::Stmt::AlterOperator(alter) = first else {
        unreachable!()
    };

    let sig = alter.op_sig().unwrap();
    let (left, right) = (sig.lhs().unwrap(), sig.rhs().unwrap());
    assert_snapshot!(left.syntax().text(), @"int4");
    assert_snapshot!(right.syntax().text(), @"int8");
}
964
#[test]
fn cast_sig() {
    // In `drop cast (text as int)` the signature's `lhs()` is the source type
    // (`text`) and its `rhs()` is the target type (`int`).
    let source_code = "
      drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let stmt = parse.tree().stmts().next().unwrap();
    // Named for what it is: a `DROP CAST` statement (previously a copy-pasted
    // `alter_op` from the `op_sig` test).
    let ast::Stmt::DropCast(drop_cast) = stmt else {
        unreachable!()
    };
    let cast_sig = drop_cast.cast_sig().unwrap();
    let lhs = cast_sig.lhs().unwrap();
    let rhs = cast_sig.rhs().unwrap();
    assert_snapshot!(lhs.syntax().text(), @"text");
    assert_snapshot!(rhs.syntax().text(), @"int");
}
982
/// Test helper: parses `sql`, asserts it produced no parse errors, and returns
/// its first statement as an `ast::Vacuum`.
///
/// Panics if `sql` has no statements or its first statement is not a `VACUUM`.
#[cfg(test)]
fn extract_vacuum(sql: &str) -> ast::Vacuum {
    let parsed = SourceFile::parse(sql);
    assert!(parsed.errors().is_empty());
    match parsed.tree().stmts().next().unwrap() {
        ast::Stmt::Vacuum(found) => found,
        _ => unreachable!(),
    }
}
993
#[test]
fn vacuum_full_is_full() {
    // Unparenthesized `FULL` keyword.
    let vacuum = extract_vacuum("VACUUM FULL foo;");
    assert!(vacuum.is_full());
}
998
#[test]
fn vacuum_option_list_full_is_full() {
    // `FULL` inside a parenthesized option list, with no explicit value.
    let vacuum = extract_vacuum("VACUUM (FULL) foo;");
    assert!(vacuum.is_full());
}
1003
#[test]
fn vacuum_full_true_is_full() {
    // Explicit `TRUE` value for the `FULL` option.
    let vacuum = extract_vacuum("VACUUM (FULL TRUE) foo;");
    assert!(vacuum.is_full());
}
1008
#[test]
fn vacuum_full_on_is_full() {
    // Explicit `ON` value for the `FULL` option.
    let vacuum = extract_vacuum("VACUUM (FULL ON) foo;");
    assert!(vacuum.is_full());
}
1013
#[test]
fn vacuum_full_1_is_full() {
    // Numeric `1` value for the `FULL` option.
    let vacuum = extract_vacuum("VACUUM (FULL 1) foo;");
    assert!(vacuum.is_full());
}
1018
#[test]
fn vacuum_no_full_is_not_full() {
    // Plain `VACUUM` with no options at all.
    let vacuum = extract_vacuum("VACUUM foo;");
    let full = vacuum.is_full();
    assert!(!full);
}
1023
#[test]
fn vacuum_other_option_is_not_full() {
    // An option list that contains only a different option (`FREEZE`).
    let vacuum = extract_vacuum("VACUUM (FREEZE) foo;");
    let full = vacuum.is_full();
    assert!(!full);
}
1028
#[test]
fn vacuum_full_false_is_not_full() {
    // `FULL` explicitly disabled with `FALSE`.
    let vacuum = extract_vacuum("VACUUM (FULL FALSE) foo;");
    let full = vacuum.is_full();
    assert!(!full);
}
1033
#[test]
fn vacuum_full_off_is_not_full() {
    // `FULL` explicitly disabled with `OFF`.
    let vacuum = extract_vacuum("VACUUM (FULL OFF) foo;");
    let full = vacuum.is_full();
    assert!(!full);
}
1038
#[test]
fn vacuum_full_0_is_not_full() {
    // `FULL` explicitly disabled with numeric `0`.
    let vacuum = extract_vacuum("VACUUM (FULL 0) foo;");
    let full = vacuum.is_full();
    assert!(!full);
}