// squawk_syntax/ast/node_ext.rs
// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
//
// Permission is hereby granted, free of charge, to any
// person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the
// Software without restriction, including without
// limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice
// shall be included in all copies or substantial portions
// of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use rowan::Direction;
36
37use crate::ast;
38use crate::ast::AstNode;
39use crate::unescape::{escape_unicode_esc_str, uescape_char};
40use crate::{SyntaxKind, SyntaxNode, SyntaxToken, TokenText};
41
42use super::support;
43
/// Kind of a literal, as classified by `ast::Literal::kind` from the literal's
/// first child token. Every variant carries that token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LitKind {
    /// `BIT_STRING` token.
    BitString(SyntaxToken),
    /// `BYTE_STRING` token.
    ByteString(SyntaxToken),
    /// `DEFAULT_KW` token.
    Default(SyntaxToken),
    /// `DOLLAR_QUOTED_STRING` token.
    DollarQuotedString(SyntaxToken),
    /// `ESC_STRING` token.
    EscString(SyntaxToken),
    /// `FALSE_KW` token.
    False(SyntaxToken),
    /// `INT_NUMBER` token.
    IntNumber(SyntaxToken),
    /// `NATIONAL_STRING` token.
    NationalString(SyntaxToken),
    /// `NULL_KW` token.
    Null(SyntaxToken),
    /// `NUMERIC_NUMBER` token.
    NumericNumber(SyntaxToken),
    /// `POSITIONAL_PARAM` token.
    PositionalParam(SyntaxToken),
    /// `STRING` token.
    String(SyntaxToken),
    /// `TRUE_KW` token.
    True(SyntaxToken),
    /// `UNICODE_ESC_STRING` token.
    UnicodeEscString(SyntaxToken),
}
61
62impl ast::Literal {
63    pub fn kind(&self) -> Option<LitKind> {
64        let token = self.syntax().first_child_or_token()?.into_token()?;
65        let kind = match token.kind() {
66            SyntaxKind::BIT_STRING => LitKind::BitString(token),
67            SyntaxKind::BYTE_STRING => LitKind::ByteString(token),
68            SyntaxKind::DEFAULT_KW => LitKind::Default(token),
69            SyntaxKind::DOLLAR_QUOTED_STRING => LitKind::DollarQuotedString(token),
70            SyntaxKind::ESC_STRING => LitKind::EscString(token),
71            SyntaxKind::FALSE_KW => LitKind::False(token),
72            SyntaxKind::INT_NUMBER => LitKind::IntNumber(token),
73            SyntaxKind::NATIONAL_STRING => LitKind::NationalString(token),
74            SyntaxKind::NULL_KW => LitKind::Null(token),
75            SyntaxKind::NUMERIC_NUMBER => LitKind::NumericNumber(token),
76            SyntaxKind::POSITIONAL_PARAM => LitKind::PositionalParam(token),
77            SyntaxKind::STRING => LitKind::String(token),
78            SyntaxKind::TRUE_KW => LitKind::True(token),
79            SyntaxKind::UNICODE_ESC_STRING => LitKind::UnicodeEscString(token),
80            _ => return None,
81        };
82        Some(kind)
83    }
84}
85
86impl ast::Constraint {
87    #[inline]
88    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
89        support::child(self.syntax())
90    }
91}
92
/// Binary operator of a `BinExpr`, as found by `ast::BinExpr::op`.
///
/// Variants holding a `SyntaxToken` come from a single operator token.
/// Variants holding an AST node come from an operator that has its own
/// syntax node in the tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BinOp {
    /// `AND_KW` token.
    And(SyntaxToken),
    /// `AT_TIME_ZONE` node.
    AtTimeZone(ast::AtTimeZone),
    /// `CARET` token.
    Caret(SyntaxToken),
    /// `COLLATE_KW` token.
    Collate(SyntaxToken),
    /// `COLON_COLON` node.
    ColonColon(ast::ColonColon),
    /// `COLON_EQ` token.
    ColonEq(SyntaxToken),
    /// `CUSTOM_OP` node.
    CustomOp(ast::CustomOp),
    /// `EQ` token.
    Eq(SyntaxToken),
    /// `ESCAPE_KW` token.
    Escape(SyntaxToken),
    /// `FAT_ARROW` token.
    FatArrow(SyntaxToken),
    /// `GTEQ` token.
    Gteq(SyntaxToken),
    /// `ILIKE_KW` token.
    Ilike(SyntaxToken),
    /// `IN_KW` token.
    In(SyntaxToken),
    /// `IS_KW` token.
    Is(SyntaxToken),
    /// `IS_DISTINCT_FROM` node.
    IsDistinctFrom(ast::IsDistinctFrom),
    /// `IS_NOT` node.
    IsNot(ast::IsNot),
    /// `IS_NOT_DISTINCT_FROM` node.
    IsNotDistinctFrom(ast::IsNotDistinctFrom),
    /// `L_ANGLE` token.
    LAngle(SyntaxToken),
    /// `LIKE_KW` token.
    Like(SyntaxToken),
    /// `LTEQ` token.
    Lteq(SyntaxToken),
    /// `MINUS` token.
    Minus(SyntaxToken),
    /// `NEQ` token.
    Neq(SyntaxToken),
    /// `NEQB` token.
    Neqb(SyntaxToken),
    /// `NOT_ILIKE` node.
    NotIlike(ast::NotIlike),
    /// `NOT_IN` node.
    NotIn(ast::NotIn),
    /// `NOT_LIKE` node.
    NotLike(ast::NotLike),
    /// `NOT_SIMILAR_TO` node.
    NotSimilarTo(ast::NotSimilarTo),
    /// `OPERATOR_CALL` node.
    OperatorCall(ast::OperatorCall),
    /// `OR_KW` token.
    Or(SyntaxToken),
    /// `OVERLAPS_KW` token.
    Overlaps(SyntaxToken),
    /// `PERCENT` token.
    Percent(SyntaxToken),
    /// `PLUS` token.
    Plus(SyntaxToken),
    /// `R_ANGLE` token.
    RAngle(SyntaxToken),
    /// `SIMILAR_TO` node.
    SimilarTo(ast::SimilarTo),
    /// `SLASH` token.
    Slash(SyntaxToken),
    /// `STAR` token.
    Star(SyntaxToken),
}
132
/// Postfix operator of a `PostfixExpr`, as found by `ast::PostfixExpr::op`.
///
/// `IsNull` / `NotNull` carry the single keyword token. Every other variant
/// carries the syntax node the operator is wrapped in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PostfixOp {
    /// `AT_LOCAL` node.
    AtLocal(ast::AtLocal),
    /// `IS_JSON` node.
    IsJson(ast::IsJson),
    /// `IS_JSON_ARRAY` node.
    IsJsonArray(ast::IsJsonArray),
    /// `IS_JSON_OBJECT` node.
    IsJsonObject(ast::IsJsonObject),
    /// `IS_JSON_SCALAR` node.
    IsJsonScalar(ast::IsJsonScalar),
    /// `IS_JSON_VALUE` node.
    IsJsonValue(ast::IsJsonValue),
    /// `IS_NORMALIZED` node.
    IsNormalized(ast::IsNormalized),
    /// `IS_NOT_JSON` node.
    IsNotJson(ast::IsNotJson),
    /// `IS_NOT_JSON_ARRAY` node.
    IsNotJsonArray(ast::IsNotJsonArray),
    /// `IS_NOT_JSON_OBJECT` node.
    IsNotJsonObject(ast::IsNotJsonObject),
    /// `IS_NOT_JSON_SCALAR` node.
    IsNotJsonScalar(ast::IsNotJsonScalar),
    /// `IS_NOT_JSON_VALUE` node.
    IsNotJsonValue(ast::IsNotJsonValue),
    /// `IS_NOT_NORMALIZED` node.
    IsNotNormalized(ast::IsNotNormalized),
    /// `ISNULL_KW` token.
    IsNull(SyntaxToken),
    /// `NOTNULL_KW` token.
    NotNull(SyntaxToken),
}
151
152impl ast::BinExpr {
153    #[inline]
154    pub fn lhs(&self) -> Option<ast::Expr> {
155        support::children(self.syntax()).next()
156    }
157
158    #[inline]
159    pub fn rhs(&self) -> Option<ast::Expr> {
160        support::children(self.syntax()).nth(1)
161    }
162
163    pub fn op(&self) -> Option<BinOp> {
164        let lhs = self.lhs()?;
165        for child in lhs.syntax().siblings_with_tokens(Direction::Next).skip(1) {
166            match child {
167                NodeOrToken::Token(token) => {
168                    let op = match token.kind() {
169                        SyntaxKind::AND_KW => BinOp::And(token),
170                        SyntaxKind::CARET => BinOp::Caret(token),
171                        SyntaxKind::COLLATE_KW => BinOp::Collate(token),
172                        SyntaxKind::COLON_EQ => BinOp::ColonEq(token),
173                        SyntaxKind::EQ => BinOp::Eq(token),
174                        SyntaxKind::ESCAPE_KW => BinOp::Escape(token),
175                        SyntaxKind::FAT_ARROW => BinOp::FatArrow(token),
176                        SyntaxKind::GTEQ => BinOp::Gteq(token),
177                        SyntaxKind::ILIKE_KW => BinOp::Ilike(token),
178                        SyntaxKind::IN_KW => BinOp::In(token),
179                        SyntaxKind::IS_KW => BinOp::Is(token),
180                        SyntaxKind::L_ANGLE => BinOp::LAngle(token),
181                        SyntaxKind::LIKE_KW => BinOp::Like(token),
182                        SyntaxKind::LTEQ => BinOp::Lteq(token),
183                        SyntaxKind::MINUS => BinOp::Minus(token),
184                        SyntaxKind::NEQ => BinOp::Neq(token),
185                        SyntaxKind::NEQB => BinOp::Neqb(token),
186                        SyntaxKind::OR_KW => BinOp::Or(token),
187                        SyntaxKind::OVERLAPS_KW => BinOp::Overlaps(token),
188                        SyntaxKind::PERCENT => BinOp::Percent(token),
189                        SyntaxKind::PLUS => BinOp::Plus(token),
190                        SyntaxKind::R_ANGLE => BinOp::RAngle(token),
191                        SyntaxKind::SLASH => BinOp::Slash(token),
192                        SyntaxKind::STAR => BinOp::Star(token),
193                        _ => continue,
194                    };
195                    return Some(op);
196                }
197                NodeOrToken::Node(node) => {
198                    let op = match node.kind() {
199                        SyntaxKind::AT_TIME_ZONE => {
200                            BinOp::AtTimeZone(ast::AtTimeZone { syntax: node })
201                        }
202                        SyntaxKind::COLON_COLON => {
203                            BinOp::ColonColon(ast::ColonColon { syntax: node })
204                        }
205                        SyntaxKind::CUSTOM_OP => BinOp::CustomOp(ast::CustomOp { syntax: node }),
206                        SyntaxKind::IS_DISTINCT_FROM => {
207                            BinOp::IsDistinctFrom(ast::IsDistinctFrom { syntax: node })
208                        }
209                        SyntaxKind::IS_NOT => BinOp::IsNot(ast::IsNot { syntax: node }),
210                        SyntaxKind::IS_NOT_DISTINCT_FROM => {
211                            BinOp::IsNotDistinctFrom(ast::IsNotDistinctFrom { syntax: node })
212                        }
213                        SyntaxKind::NOT_ILIKE => BinOp::NotIlike(ast::NotIlike { syntax: node }),
214                        SyntaxKind::NOT_IN => BinOp::NotIn(ast::NotIn { syntax: node }),
215                        SyntaxKind::NOT_LIKE => BinOp::NotLike(ast::NotLike { syntax: node }),
216                        SyntaxKind::NOT_SIMILAR_TO => {
217                            BinOp::NotSimilarTo(ast::NotSimilarTo { syntax: node })
218                        }
219                        SyntaxKind::OPERATOR_CALL => {
220                            BinOp::OperatorCall(ast::OperatorCall { syntax: node })
221                        }
222                        SyntaxKind::SIMILAR_TO => BinOp::SimilarTo(ast::SimilarTo { syntax: node }),
223                        _ => continue,
224                    };
225                    return Some(op);
226                }
227            }
228        }
229        None
230    }
231}
232
233impl ast::PostfixExpr {
234    pub fn op(&self) -> Option<PostfixOp> {
235        let lhs = self.expr()?;
236
237        let siblings = lhs.syntax().siblings_with_tokens(Direction::Next).skip(1);
238        for child in siblings {
239            match child {
240                NodeOrToken::Token(token) => {
241                    let op = match token.kind() {
242                        SyntaxKind::ISNULL_KW => PostfixOp::IsNull(token),
243                        SyntaxKind::NOTNULL_KW => PostfixOp::NotNull(token),
244                        _ => continue,
245                    };
246                    return Some(op);
247                }
248                NodeOrToken::Node(node) => {
249                    let op = match node.kind() {
250                        SyntaxKind::AT_LOCAL => PostfixOp::AtLocal(ast::AtLocal { syntax: node }),
251                        SyntaxKind::IS_JSON => PostfixOp::IsJson(ast::IsJson { syntax: node }),
252                        SyntaxKind::IS_JSON_ARRAY => {
253                            PostfixOp::IsJsonArray(ast::IsJsonArray { syntax: node })
254                        }
255                        SyntaxKind::IS_JSON_OBJECT => {
256                            PostfixOp::IsJsonObject(ast::IsJsonObject { syntax: node })
257                        }
258                        SyntaxKind::IS_JSON_SCALAR => {
259                            PostfixOp::IsJsonScalar(ast::IsJsonScalar { syntax: node })
260                        }
261                        SyntaxKind::IS_JSON_VALUE => {
262                            PostfixOp::IsJsonValue(ast::IsJsonValue { syntax: node })
263                        }
264                        SyntaxKind::IS_NORMALIZED => {
265                            PostfixOp::IsNormalized(ast::IsNormalized { syntax: node })
266                        }
267                        SyntaxKind::IS_NOT_JSON => {
268                            PostfixOp::IsNotJson(ast::IsNotJson { syntax: node })
269                        }
270                        SyntaxKind::IS_NOT_JSON_ARRAY => {
271                            PostfixOp::IsNotJsonArray(ast::IsNotJsonArray { syntax: node })
272                        }
273                        SyntaxKind::IS_NOT_JSON_OBJECT => {
274                            PostfixOp::IsNotJsonObject(ast::IsNotJsonObject { syntax: node })
275                        }
276                        SyntaxKind::IS_NOT_JSON_SCALAR => {
277                            PostfixOp::IsNotJsonScalar(ast::IsNotJsonScalar { syntax: node })
278                        }
279                        SyntaxKind::IS_NOT_JSON_VALUE => {
280                            PostfixOp::IsNotJsonValue(ast::IsNotJsonValue { syntax: node })
281                        }
282                        SyntaxKind::IS_NOT_NORMALIZED => {
283                            PostfixOp::IsNotNormalized(ast::IsNotNormalized { syntax: node })
284                        }
285                        _ => continue,
286                    };
287                    return Some(op);
288                }
289            }
290        }
291
292        None
293    }
294}
295
296impl ast::FieldExpr {
297    // We have NameRef as a variant of Expr which complicates things (and it
298    // might not be worth it).
299    // Rust analyzer doesn't do this so it doesn't have to special case this.
300    #[inline]
301    pub fn base(&self) -> Option<ast::Expr> {
302        support::children(self.syntax()).next()
303    }
304    #[inline]
305    pub fn field(&self) -> Option<ast::NameRef> {
306        support::children(self.syntax()).last()
307    }
308}
309
310impl ast::IndexExpr {
311    #[inline]
312    pub fn base(&self) -> Option<ast::Expr> {
313        support::children(&self.syntax).next()
314    }
315    #[inline]
316    pub fn index(&self) -> Option<ast::Expr> {
317        support::children(&self.syntax).nth(1)
318    }
319}
320
321impl ast::SliceExpr {
322    #[inline]
323    pub fn base(&self) -> Option<ast::Expr> {
324        support::children(&self.syntax).next()
325    }
326
327    #[inline]
328    pub fn start(&self) -> Option<ast::Expr> {
329        // With `select x[1:]`, we have two exprs, `x` and `1`.
330        // We skip over the first one, and then we want the second one, but we
331        // want to make sure we don't choose the end expr if instead we had:
332        // `select x[:1]`
333        let colon = self.colon_token()?;
334        support::children(&self.syntax)
335            .skip(1)
336            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
337    }
338
339    #[inline]
340    pub fn end(&self) -> Option<ast::Expr> {
341        // We want to make sure we get the last expr after the `:` which is the
342        // end of the slice, i.e., `2` in: `select x[:2]`
343        let colon = self.colon_token()?;
344        support::children(&self.syntax)
345            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
346    }
347}
348
349impl ast::RenameColumn {
350    #[inline]
351    pub fn from(&self) -> Option<ast::NameRef> {
352        support::children(&self.syntax).nth(0)
353    }
354    #[inline]
355    pub fn to(&self) -> Option<ast::NameRef> {
356        support::children(&self.syntax).nth(1)
357    }
358}
359
360impl ast::ForeignKeyConstraint {
361    #[inline]
362    pub fn from_columns(&self) -> Option<ast::ColumnList> {
363        support::children(&self.syntax).nth(0)
364    }
365    #[inline]
366    pub fn to_columns(&self) -> Option<ast::ColumnList> {
367        support::children(&self.syntax).nth(1)
368    }
369}
370
371impl ast::BetweenExpr {
372    #[inline]
373    pub fn target(&self) -> Option<ast::Expr> {
374        support::children(&self.syntax).nth(0)
375    }
376    #[inline]
377    pub fn start(&self) -> Option<ast::Expr> {
378        support::children(&self.syntax).nth(1)
379    }
380    #[inline]
381    pub fn end(&self) -> Option<ast::Expr> {
382        support::children(&self.syntax).nth(2)
383    }
384}
385
386impl ast::FrameBetween {
387    #[inline]
388    pub fn start(&self) -> Option<ast::FrameBound> {
389        support::children(&self.syntax).nth(0)
390    }
391    #[inline]
392    pub fn end(&self) -> Option<ast::FrameBound> {
393        support::children(&self.syntax).nth(1)
394    }
395}
396
397impl ast::WhenClause {
398    #[inline]
399    pub fn condition(&self) -> Option<ast::Expr> {
400        support::children(&self.syntax).next()
401    }
402    #[inline]
403    pub fn then(&self) -> Option<ast::Expr> {
404        support::children(&self.syntax).nth(1)
405    }
406}
407
408impl ast::CompoundSelect {
409    #[inline]
410    pub fn lhs(&self) -> Option<ast::SelectVariant> {
411        support::children(&self.syntax).next()
412    }
413    #[inline]
414    pub fn rhs(&self) -> Option<ast::SelectVariant> {
415        support::children(&self.syntax).nth(1)
416    }
417}
418
419impl ast::NameRef {
420    #[inline]
421    pub fn text(&self) -> String {
422        normalize_name_node(self.syntax())
423    }
424
425    #[inline]
426    pub fn is_quoted(&self) -> bool {
427        is_quoted(self.syntax())
428    }
429}
430
431impl ast::Name {
432    #[inline]
433    pub fn text(&self) -> String {
434        normalize_name_node(self.syntax())
435    }
436
437    #[inline]
438    pub fn is_quoted(&self) -> bool {
439        is_quoted(self.syntax())
440    }
441}
442
443fn is_quoted(node: &SyntaxNode) -> bool {
444    let text = node.text();
445    let first = text.char_at(0.into());
446    let second = text.char_at(1.into());
447    matches!(
448        (first, second),
449        (Some('u' | 'U'), Some('"')) | (Some('"'), Some(_))
450    )
451}
452
453// TODO: return a NewType wrapper around String?
454fn normalize_name_node(node: &SyntaxNode) -> String {
455    let mut tokens = node
456        .children_with_tokens()
457        .filter_map(|el| el.into_token())
458        .filter(|t| !t.kind().is_trivia());
459
460    let Some(ident_token) = tokens.next() else {
461        return String::new();
462    };
463    let raw = ident_token.text();
464
465    let unicode_inner = raw
466        .strip_prefix(['u', 'U'])
467        .and_then(|s| s.strip_prefix("&\""))
468        .and_then(|s| s.strip_suffix('"'));
469
470    if let Some(inner) = unicode_inner {
471        let mut escape_char = '\\';
472        if let Some(uesc) = tokens.next()
473            && uesc.kind() == SyntaxKind::UESCAPE_KW
474            && let Some(token) = tokens.next()
475            && let Some(ch) = uescape_char(token.text())
476        {
477            escape_char = ch;
478        }
479
480        let inner = inner.replace(r#""""#, "\"");
481        let mut result = String::with_capacity(inner.len());
482        escape_unicode_esc_str(&inner, escape_char, |_range, r| {
483            if let Ok(ch) = r {
484                result.push(ch);
485            }
486        });
487        return result;
488    }
489
490    raw.strip_prefix('"')
491        .and_then(|t| t.strip_suffix('"'))
492        .map(|x| x.replace(r#""""#, "\""))
493        .unwrap_or_else(|| raw.to_ascii_lowercase())
494}
495
496impl ast::CharType {
497    #[inline]
498    pub fn text(&self) -> TokenText<'_> {
499        text_of_first_token(self.syntax())
500    }
501}
502
/// Returns `true` when an option value turns a boolean option off.
///
/// Recognized values are `0`, `false` and `off`, with the words matched
/// case-insensitively. PostgreSQL also accepts the words written as a string
/// literal (`'false'`) or a quoted identifier (`"off"`), because it compares
/// the unquoted text. Those forms are unwrapped before comparing. A quoted
/// `'0'` is not a valid boolean in PostgreSQL, so it is not treated as falsey.
fn is_falsey_option(text: &str) -> bool {
    if text == "0" {
        return true;
    }
    let unquoted = ['\'', '"']
        .into_iter()
        .find_map(|quote| text.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(text);
    unquoted.eq_ignore_ascii_case("false") || unquoted.eq_ignore_ascii_case("off")
}
506
507impl ast::Vacuum {
508    pub fn is_full(&self) -> bool {
509        self.full_token().is_some()
510            // TODO: we need a better way of handling option lists
511            || self.vacuum_option_list().is_some_and(|opt_list| {
512                opt_list.vacuum_options().any(|opt| {
513                    let mut tokens = opt
514                        .syntax()
515                        .descendants_with_tokens()
516                        .filter_map(|child| child.into_token())
517                        .filter(|token| !token.kind().is_trivia());
518
519                    tokens
520                        .next()
521                        .is_some_and(|token| token.text().eq_ignore_ascii_case("full"))
522                        && tokens
523                            .next()
524                            .is_none_or(|token| !is_falsey_option(token.text()))
525                })
526            })
527    }
528}
529
530impl ast::OpSig {
531    #[inline]
532    pub fn lhs(&self) -> Option<ast::Type> {
533        support::children(self.syntax()).next()
534    }
535
536    #[inline]
537    pub fn rhs(&self) -> Option<ast::Type> {
538        support::children(self.syntax()).nth(1)
539    }
540}
541
542impl ast::CastSig {
543    #[inline]
544    pub fn lhs(&self) -> Option<ast::Type> {
545        support::children(self.syntax()).next()
546    }
547
548    #[inline]
549    pub fn rhs(&self) -> Option<ast::Type> {
550        support::children(self.syntax()).nth(1)
551    }
552}
553
554pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
555    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
556        green_ref
557            .children()
558            .next()
559            .and_then(NodeOrToken::into_token)
560            .unwrap()
561    }
562
563    match node.green() {
564        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
565        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
566    }
567}
568
569impl ast::WithQuery {
570    #[inline]
571    pub fn with_clause(&self) -> Option<ast::WithClause> {
572        support::child(self.syntax())
573    }
574}
575
576impl ast::CreateTableAsQuery {
577    #[inline]
578    pub fn select_variant(&self) -> Option<ast::SelectVariant> {
579        match self {
580            ast::CreateTableAsQuery::Execute(_) => None,
581            ast::CreateTableAsQuery::SelectVariant(select_variant) => Some(select_variant.clone()),
582        }
583    }
584}
585
586impl ast::SelectVariant {
587    #[inline]
588    pub fn target_list(&self) -> Option<ast::TargetList> {
589        match self {
590            ast::SelectVariant::Select(select) => {
591                return select.select_clause()?.target_list();
592            }
593            ast::SelectVariant::SelectInto(select_into) => {
594                return select_into.select_clause()?.target_list();
595            }
596            ast::SelectVariant::ParenSelect(paren_select) => {
597                return paren_select.select()?.target_list();
598            }
599            _ => return None,
600        }
601    }
602}
603
// `FunctionSig` and `Aggregate` implement `HasParamList` with empty bodies,
// i.e. they rely solely on the trait's provided items.
impl ast::HasParamList for ast::FunctionSig {}
impl ast::HasParamList for ast::Aggregate {}
606
// `Name` and `NameRef` implement `NameLike` by delegating to their inherent
// `text` methods (defined above). Inherent methods take precedence over trait
// methods in method-call resolution, so `self.text()` here does not recurse.
impl ast::NameLike for ast::Name {
    #[inline]
    fn text(&self) -> String {
        self.text()
    }
}
impl ast::NameLike for ast::NameRef {
    #[inline]
    fn text(&self) -> String {
        self.text()
    }
}
619
// Statement nodes that implement `HasWithClause` with empty bodies, relying
// solely on the trait's provided items.
impl ast::HasWithClause for ast::Select {}
impl ast::HasWithClause for ast::SelectInto {}
impl ast::HasWithClause for ast::Insert {}
impl ast::HasWithClause for ast::Update {}
impl ast::HasWithClause for ast::Delete {}

// `CREATE TABLE`-style nodes that implement `HasCreateTable` with empty
// bodies, relying solely on the trait's provided items.
impl ast::HasCreateTable for ast::CreateTable {}
impl ast::HasCreateTable for ast::CreateForeignTable {}
impl ast::HasCreateTable for ast::CreateTableLike {}
629
#[test]
fn name() {
    /// Parses `source_code` as one `SELECT` and returns the normalized text of
    /// the first target's alias name.
    fn extract_name(source_code: &str) -> String {
        let parse = SourceFile::parse(source_code);
        assert!(parse.errors().is_empty());
        let stmt = parse.tree().stmts().next().unwrap();
        let ast::Stmt::Select(select) = stmt else {
            unreachable!()
        };
        let select_clause = select.select_clause().unwrap();
        let first_target = select_clause
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap();
        let alias = first_target.as_name().unwrap();
        alias.name().unwrap().text()
    }

    assert_snapshot!(extract_name("select 1 foo"), @"foo");
    assert_snapshot!(extract_name("select 1 FOO"), @"foo");
    assert_snapshot!(extract_name(r#"select 1 "foo""#), @"foo");
    assert_snapshot!(extract_name(r#"select 1 "Foo""#), @"Foo");
    assert_snapshot!(extract_name(r#"select 1 "FOO""#), @"FOO");
    assert_snapshot!(extract_name(r#"select 1 U&"\0066\006f\006f""#), @"foo");
    assert_snapshot!(extract_name(r#"select 1 U&"@0066@006f@006f" uescape '@'"#), @"foo");
}
662
#[test]
fn name_ref() {
    /// Parses `source_code` as one `SELECT` whose first target is a bare name
    /// reference and returns its normalized text.
    fn extract_name_ref(source_code: &str) -> String {
        let parse = SourceFile::parse(source_code);
        assert!(parse.errors().is_empty());
        let stmt = parse.tree().stmts().next().unwrap();
        let ast::Stmt::Select(select) = stmt else {
            unreachable!()
        };
        let first_target = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap();
        let expr = first_target.expr().unwrap();
        let ast::Expr::NameRef(name_ref) = expr else {
            unreachable!()
        };
        name_ref.text()
    }

    assert_snapshot!(extract_name_ref("select foo"), @"foo");
    assert_snapshot!(extract_name_ref("select FOO"), @"foo");
    assert_snapshot!(extract_name_ref(r#"select "foo""#), @"foo");
    assert_snapshot!(extract_name_ref(r#"select "Foo""#), @"Foo");
    assert_snapshot!(extract_name_ref(r#"select "FOO""#), @"FOO");
    assert_snapshot!(extract_name_ref(r#"select U&"\0066\006f\006f""#), @"foo");
    assert_snapshot!(extract_name_ref(r#"select U&"@0066@006f@006f" uescape '@'"#), @"foo");
}
693
#[test]
fn unicode_quoted_name_keeps_doubled_single_quotes() {
    let parse = SourceFile::parse(r#"select 1 U&"a''b""#);
    assert!(parse.errors().is_empty());
    let stmt = parse.tree().stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let name = first_target.as_name().unwrap().name().unwrap();

    // Inside a delimited identifier only `""` is unescaped; `''` stays as-is.
    assert_snapshot!(name.text(), @"a''b");
}
717
#[test]
fn index_expr() {
    let source_code = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let stmt = parse.tree().stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let expr = first_target.expr().unwrap();
    let ast::Expr::IndexExpr(index_expr) = expr else {
        unreachable!()
    };

    assert_eq!(index_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(index_expr.index().unwrap().syntax().text(), "bar");
}
744
#[test]
fn slice_expr() {
    use insta::assert_snapshot;
    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let stmt = parse.tree().stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let select_clause = select.select_clause().unwrap();
    let mut targets = select_clause.target_list().unwrap().targets();

    // Pulls the next target and asserts it is a slice expression.
    let mut next_slice = || {
        let expr = targets.next().unwrap().expr().unwrap();
        let ast::Expr::SliceExpr(slice) = expr else {
            unreachable!()
        };
        slice
    };

    // Both bounds present.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Only the lower bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Only the upper bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // Neither bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
792
#[test]
fn field_expr() {
    let source_code = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let stmt = parse.tree().stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let expr = first_target.expr().unwrap();
    let ast::Expr::FieldExpr(field_expr) = expr else {
        unreachable!()
    };

    assert_eq!(field_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(field_expr.field().unwrap().syntax().text(), "bar");
}
819
#[test]
fn between_expr() {
    let source_code = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let stmt = parse.tree().stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let expr = first_target.expr().unwrap();
    let ast::Expr::BetweenExpr(between_expr) = expr else {
        unreachable!()
    };

    assert_eq!(between_expr.target().unwrap().syntax().text(), "2");
    assert_eq!(between_expr.start().unwrap().syntax().text(), "1");
    assert_eq!(between_expr.end().unwrap().syntax().text(), "3");
}
848
#[test]
fn cast_expr() {
    // Each case is (sql, expected cast operand text, expected target type
    // text). Covers the `cast(x as t)`, `t 'x'` and `x::t` spellings with
    // plain, schema-qualified, typmod'd and array type names, plus
    // `interval '1' month`.
    let cases = [
        ("select cast('123' as int)", "'123'", "int"),
        (
            "select cast('123' as pg_catalog.int4)",
            "'123'",
            "pg_catalog.int4",
        ),
        ("select int '123'", "'123'", "int"),
        ("select pg_catalog.int4 '123'", "'123'", "pg_catalog.int4"),
        ("select '123'::int", "'123'", "int"),
        ("select '123'::int4", "'123'", "int4"),
        ("select '123'::pg_catalog.int4", "'123'", "pg_catalog.int4"),
        (
            "select '{123}'::pg_catalog.varchar(10)[]",
            "'{123}'",
            "pg_catalog.varchar(10)[]",
        ),
        (
            "select cast('{123}' as pg_catalog.varchar(10)[])",
            "'{123}'",
            "pg_catalog.varchar(10)[]",
        ),
        (
            "select pg_catalog.varchar(10) '{123}'",
            "'{123}'",
            "pg_catalog.varchar(10)",
        ),
        ("select interval '1' month", "'1'", "interval"),
    ];

    for (sql, expected_expr, expected_ty) in cases {
        let cast = extract_expr(sql);

        let expr = cast
            .expr()
            .unwrap_or_else(|| panic!("no cast operand for `{sql}`"));
        assert_eq!(
            expr.syntax().text(),
            expected_expr,
            "cast operand for `{sql}`"
        );

        let ty = cast
            .ty()
            .unwrap_or_else(|| panic!("no cast type for `{sql}`"));
        assert_eq!(ty.syntax().text(), expected_ty, "cast type for `{sql}`");
    }

    /// Parses `sql` (which must parse without errors) and returns the first
    /// target of its first `select` statement, which must be a cast.
    ///
    /// Panics with a message naming `sql` if any of those expectations fail.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(
            parse.errors().is_empty(),
            "unexpected parse errors for `{sql}`"
        );
        let Some(ast::Stmt::Select(select)) = parse.tree().stmts().next() else {
            panic!("expected a select statement in `{sql}`");
        };
        let expr = select
            .select_clause()
            .and_then(|clause| clause.target_list())
            .and_then(|list| list.targets().next())
            .and_then(|target| target.expr())
            .unwrap_or_else(|| panic!("expected a select target in `{sql}`"));
        match expr {
            ast::Expr::CastExpr(cast) => cast,
            other => panic!(
                "expected a cast expression in `{sql}`, got `{}`",
                other.syntax()
            ),
        }
    }
}
947
#[test]
fn op_sig() {
    // `alter operator` carries an operator signature `(left_type, right_type)`;
    // both operand types should be reachable through `op_sig()`.
    let source_code = concat!(
        "\n",
        "      alter operator p.+ (int4, int8) \n",
        "        owner to u;\n",
        "    ",
    );
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    let alter_operator = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::AlterOperator(it) => it,
        _ => unreachable!(),
    };

    let sig = alter_operator.op_sig().unwrap();
    let (left, right) = (sig.lhs().unwrap(), sig.rhs().unwrap());
    assert_snapshot!(left.syntax().text(), @"int4");
    assert_snapshot!(right.syntax().text(), @"int8");
}
966
#[test]
fn cast_sig() {
    // `drop cast (a as b)`: the type before `as` is exposed as `lhs`, the
    // type after it as `rhs`.
    let source_code = concat!("\n", "      drop cast (text as int);\n", "    ");
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    let drop_cast = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::DropCast(it) => it,
        _ => unreachable!(),
    };

    let sig = drop_cast.cast_sig().unwrap();
    let (left, right) = (sig.lhs().unwrap(), sig.rhs().unwrap());
    assert_snapshot!(left.syntax().text(), @"text");
    assert_snapshot!(right.syntax().text(), @"int");
}
984
/// Test helper: parses `sql` and returns its first statement as a `VACUUM`.
///
/// Panics if parsing reports errors, if there is no statement, or if the
/// first statement is not a `VACUUM`.
#[cfg(test)]
fn extract_vacuum(sql: &str) -> ast::Vacuum {
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    match parse.tree().stmts().next().unwrap() {
        ast::Stmt::Vacuum(it) => it,
        _ => unreachable!(),
    }
}
995
#[test]
fn vacuum_full_is_full() {
    // Legacy keyword form: `VACUUM FULL <table>`.
    let vacuum = extract_vacuum("VACUUM FULL foo;");
    assert!(vacuum.is_full());
}
1000
#[test]
fn vacuum_option_list_full_is_full() {
    // Parenthesized option list with `FULL` and no explicit value.
    let vacuum = extract_vacuum("VACUUM (FULL) foo;");
    assert!(vacuum.is_full());
}
1005
#[test]
fn vacuum_full_true_is_full() {
    // `FULL` with an explicit `TRUE` value.
    let vacuum = extract_vacuum("VACUUM (FULL TRUE) foo;");
    assert!(vacuum.is_full());
}
1010
#[test]
fn vacuum_full_on_is_full() {
    // `FULL` with an explicit `ON` value.
    let vacuum = extract_vacuum("VACUUM (FULL ON) foo;");
    assert!(vacuum.is_full());
}
1015
#[test]
fn vacuum_full_1_is_full() {
    // `FULL` with the numeric value `1`.
    let vacuum = extract_vacuum("VACUUM (FULL 1) foo;");
    assert!(vacuum.is_full());
}
1020
#[test]
fn vacuum_no_full_is_not_full() {
    // Plain `VACUUM` with no options at all.
    let vacuum = extract_vacuum("VACUUM foo;");
    assert!(!vacuum.is_full());
}
1025
#[test]
fn vacuum_other_option_is_not_full() {
    // An option list that names some other option, not `FULL`.
    let vacuum = extract_vacuum("VACUUM (FREEZE) foo;");
    assert!(!vacuum.is_full());
}
1030
#[test]
fn vacuum_full_false_is_not_full() {
    // `FULL` explicitly disabled with `FALSE`.
    let vacuum = extract_vacuum("VACUUM (FULL FALSE) foo;");
    assert!(!vacuum.is_full());
}
1035
#[test]
fn vacuum_full_off_is_not_full() {
    // `FULL` explicitly disabled with `OFF`.
    let vacuum = extract_vacuum("VACUUM (FULL OFF) foo;");
    assert!(!vacuum.is_full());
}
1040
#[test]
fn vacuum_full_0_is_not_full() {
    // `FULL` disabled with the numeric value `0`.
    let vacuum = extract_vacuum("VACUUM (FULL 0) foo;");
    assert!(!vacuum.is_full());
}