Skip to main content

squawk_syntax/ast/
node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use rowan::Direction;
36
37use crate::ast;
38use crate::ast::AstNode;
39use crate::unescape::{escape_unicode_esc_str, uescape_char};
40use crate::{SyntaxKind, SyntaxNode, SyntaxToken, TokenText};
41
42use super::support;
43
/// The kind of an [`ast::Literal`], as classified by [`ast::Literal::kind`].
///
/// Every variant carries the literal's underlying token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LitKind {
    /// A `BIT_STRING` token.
    BitString(SyntaxToken),
    /// A `BYTE_STRING` token.
    ByteString(SyntaxToken),
    /// A `DOLLAR_QUOTED_STRING` token.
    DollarQuotedString(SyntaxToken),
    /// An `ESC_STRING` token.
    EscString(SyntaxToken),
    /// A `FLOAT_NUMBER` token.
    FloatNumber(SyntaxToken),
    /// An `INT_NUMBER` token.
    IntNumber(SyntaxToken),
    /// The `NULL` keyword (`NULL_KW` token).
    Null(SyntaxToken),
    /// A `POSITIONAL_PARAM` token.
    PositionalParam(SyntaxToken),
    /// A `STRING` token.
    String(SyntaxToken),
    /// A `UNICODE_ESC_STRING` token.
    UnicodeEscString(SyntaxToken),
}
57
58impl ast::Literal {
59    pub fn kind(&self) -> Option<LitKind> {
60        let token = self.syntax().first_child_or_token()?.into_token()?;
61        let kind = match token.kind() {
62            SyntaxKind::STRING => LitKind::String(token),
63            SyntaxKind::NULL_KW => LitKind::Null(token),
64            SyntaxKind::FLOAT_NUMBER => LitKind::FloatNumber(token),
65            SyntaxKind::INT_NUMBER => LitKind::IntNumber(token),
66            SyntaxKind::BYTE_STRING => LitKind::ByteString(token),
67            SyntaxKind::BIT_STRING => LitKind::BitString(token),
68            SyntaxKind::DOLLAR_QUOTED_STRING => LitKind::DollarQuotedString(token),
69            SyntaxKind::UNICODE_ESC_STRING => LitKind::UnicodeEscString(token),
70            SyntaxKind::ESC_STRING => LitKind::EscString(token),
71            SyntaxKind::POSITIONAL_PARAM => LitKind::PositionalParam(token),
72            _ => return None,
73        };
74        Some(kind)
75    }
76}
77
78impl ast::Constraint {
79    #[inline]
80    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
81        support::child(self.syntax())
82    }
83}
84
/// The operator of an [`ast::BinExpr`], as found by [`ast::BinExpr::op`].
///
/// Variants holding a `SyntaxToken` come from a single operator token (for
/// example `AND_KW`, `EQ`, `PLUS`); variants holding an `ast::*` node come
/// from an operator node (for example `AT_TIME_ZONE`, `IS_NOT_DISTINCT_FROM`,
/// `CUSTOM_OP`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BinOp {
    And(SyntaxToken),
    AtTimeZone(ast::AtTimeZone),
    Caret(SyntaxToken),
    Collate(SyntaxToken),
    ColonColon(ast::ColonColon),
    ColonEq(SyntaxToken),
    CustomOp(ast::CustomOp),
    Eq(SyntaxToken),
    FatArrow(SyntaxToken),
    Gteq(SyntaxToken),
    Ilike(SyntaxToken),
    In(SyntaxToken),
    Is(SyntaxToken),
    IsDistinctFrom(ast::IsDistinctFrom),
    IsNot(ast::IsNot),
    IsNotDistinctFrom(ast::IsNotDistinctFrom),
    LAngle(SyntaxToken),
    Like(SyntaxToken),
    Lteq(SyntaxToken),
    Minus(SyntaxToken),
    Neq(SyntaxToken),
    Neqb(SyntaxToken),
    NotIlike(ast::NotIlike),
    NotIn(ast::NotIn),
    NotLike(ast::NotLike),
    NotSimilarTo(ast::NotSimilarTo),
    OperatorCall(ast::OperatorCall),
    Or(SyntaxToken),
    Overlaps(SyntaxToken),
    Percent(SyntaxToken),
    Plus(SyntaxToken),
    RAngle(SyntaxToken),
    SimilarTo(ast::SimilarTo),
    Slash(SyntaxToken),
    Star(SyntaxToken),
}
123
/// The operator of an [`ast::PostfixExpr`], as found by
/// [`ast::PostfixExpr::op`].
///
/// `AtLocal`, `IsNull` and `NotNull` wrap a single keyword token (`AT_KW`,
/// `ISNULL_KW`, `NOTNULL_KW`); the remaining variants wrap an operator node
/// such as `IS_JSON` or `IS_NOT_NORMALIZED`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PostfixOp {
    AtLocal(SyntaxToken),
    IsJson(ast::IsJson),
    IsJsonArray(ast::IsJsonArray),
    IsJsonObject(ast::IsJsonObject),
    IsJsonScalar(ast::IsJsonScalar),
    IsJsonValue(ast::IsJsonValue),
    IsNormalized(ast::IsNormalized),
    IsNotJson(ast::IsNotJson),
    IsNotJsonArray(ast::IsNotJsonArray),
    IsNotJsonObject(ast::IsNotJsonObject),
    IsNotJsonScalar(ast::IsNotJsonScalar),
    IsNotJsonValue(ast::IsNotJsonValue),
    IsNotNormalized(ast::IsNotNormalized),
    IsNull(SyntaxToken),
    NotNull(SyntaxToken),
}
142
143impl ast::BinExpr {
144    #[inline]
145    pub fn lhs(&self) -> Option<ast::Expr> {
146        support::children(self.syntax()).next()
147    }
148
149    #[inline]
150    pub fn rhs(&self) -> Option<ast::Expr> {
151        support::children(self.syntax()).nth(1)
152    }
153
154    pub fn op(&self) -> Option<BinOp> {
155        let lhs = self.lhs()?;
156        for child in lhs.syntax().siblings_with_tokens(Direction::Next).skip(1) {
157            match child {
158                NodeOrToken::Token(token) => {
159                    let op = match token.kind() {
160                        SyntaxKind::AND_KW => BinOp::And(token),
161                        SyntaxKind::CARET => BinOp::Caret(token),
162                        SyntaxKind::COLLATE_KW => BinOp::Collate(token),
163                        SyntaxKind::COLON_EQ => BinOp::ColonEq(token),
164                        SyntaxKind::EQ => BinOp::Eq(token),
165                        SyntaxKind::FAT_ARROW => BinOp::FatArrow(token),
166                        SyntaxKind::GTEQ => BinOp::Gteq(token),
167                        SyntaxKind::ILIKE_KW => BinOp::Ilike(token),
168                        SyntaxKind::IN_KW => BinOp::In(token),
169                        SyntaxKind::IS_KW => BinOp::Is(token),
170                        SyntaxKind::L_ANGLE => BinOp::LAngle(token),
171                        SyntaxKind::LIKE_KW => BinOp::Like(token),
172                        SyntaxKind::LTEQ => BinOp::Lteq(token),
173                        SyntaxKind::MINUS => BinOp::Minus(token),
174                        SyntaxKind::NEQ => BinOp::Neq(token),
175                        SyntaxKind::NEQB => BinOp::Neqb(token),
176                        SyntaxKind::OR_KW => BinOp::Or(token),
177                        SyntaxKind::OVERLAPS_KW => BinOp::Overlaps(token),
178                        SyntaxKind::PERCENT => BinOp::Percent(token),
179                        SyntaxKind::PLUS => BinOp::Plus(token),
180                        SyntaxKind::R_ANGLE => BinOp::RAngle(token),
181                        SyntaxKind::SLASH => BinOp::Slash(token),
182                        SyntaxKind::STAR => BinOp::Star(token),
183                        _ => continue,
184                    };
185                    return Some(op);
186                }
187                NodeOrToken::Node(node) => {
188                    let op = match node.kind() {
189                        SyntaxKind::AT_TIME_ZONE => {
190                            BinOp::AtTimeZone(ast::AtTimeZone { syntax: node })
191                        }
192                        SyntaxKind::COLON_COLON => {
193                            BinOp::ColonColon(ast::ColonColon { syntax: node })
194                        }
195                        SyntaxKind::CUSTOM_OP => BinOp::CustomOp(ast::CustomOp { syntax: node }),
196                        SyntaxKind::IS_DISTINCT_FROM => {
197                            BinOp::IsDistinctFrom(ast::IsDistinctFrom { syntax: node })
198                        }
199                        SyntaxKind::IS_NOT => BinOp::IsNot(ast::IsNot { syntax: node }),
200                        SyntaxKind::IS_NOT_DISTINCT_FROM => {
201                            BinOp::IsNotDistinctFrom(ast::IsNotDistinctFrom { syntax: node })
202                        }
203                        SyntaxKind::NOT_ILIKE => BinOp::NotIlike(ast::NotIlike { syntax: node }),
204                        SyntaxKind::NOT_IN => BinOp::NotIn(ast::NotIn { syntax: node }),
205                        SyntaxKind::NOT_LIKE => BinOp::NotLike(ast::NotLike { syntax: node }),
206                        SyntaxKind::NOT_SIMILAR_TO => {
207                            BinOp::NotSimilarTo(ast::NotSimilarTo { syntax: node })
208                        }
209                        SyntaxKind::OPERATOR_CALL => {
210                            BinOp::OperatorCall(ast::OperatorCall { syntax: node })
211                        }
212                        SyntaxKind::SIMILAR_TO => BinOp::SimilarTo(ast::SimilarTo { syntax: node }),
213                        _ => continue,
214                    };
215                    return Some(op);
216                }
217            }
218        }
219        None
220    }
221}
222
223impl ast::PostfixExpr {
224    pub fn op(&self) -> Option<PostfixOp> {
225        let lhs = self.expr()?;
226
227        let siblings = lhs.syntax().siblings_with_tokens(Direction::Next).skip(1);
228        for child in siblings {
229            match child {
230                NodeOrToken::Token(token) => {
231                    let op = match token.kind() {
232                        SyntaxKind::AT_KW => PostfixOp::AtLocal(token),
233                        SyntaxKind::ISNULL_KW => PostfixOp::IsNull(token),
234                        SyntaxKind::NOTNULL_KW => PostfixOp::NotNull(token),
235                        _ => continue,
236                    };
237                    return Some(op);
238                }
239                NodeOrToken::Node(node) => {
240                    let op = match node.kind() {
241                        SyntaxKind::IS_JSON => PostfixOp::IsJson(ast::IsJson { syntax: node }),
242                        SyntaxKind::IS_JSON_ARRAY => {
243                            PostfixOp::IsJsonArray(ast::IsJsonArray { syntax: node })
244                        }
245                        SyntaxKind::IS_JSON_OBJECT => {
246                            PostfixOp::IsJsonObject(ast::IsJsonObject { syntax: node })
247                        }
248                        SyntaxKind::IS_JSON_SCALAR => {
249                            PostfixOp::IsJsonScalar(ast::IsJsonScalar { syntax: node })
250                        }
251                        SyntaxKind::IS_JSON_VALUE => {
252                            PostfixOp::IsJsonValue(ast::IsJsonValue { syntax: node })
253                        }
254                        SyntaxKind::IS_NORMALIZED => {
255                            PostfixOp::IsNormalized(ast::IsNormalized { syntax: node })
256                        }
257                        SyntaxKind::IS_NOT_JSON => {
258                            PostfixOp::IsNotJson(ast::IsNotJson { syntax: node })
259                        }
260                        SyntaxKind::IS_NOT_JSON_ARRAY => {
261                            PostfixOp::IsNotJsonArray(ast::IsNotJsonArray { syntax: node })
262                        }
263                        SyntaxKind::IS_NOT_JSON_OBJECT => {
264                            PostfixOp::IsNotJsonObject(ast::IsNotJsonObject { syntax: node })
265                        }
266                        SyntaxKind::IS_NOT_JSON_SCALAR => {
267                            PostfixOp::IsNotJsonScalar(ast::IsNotJsonScalar { syntax: node })
268                        }
269                        SyntaxKind::IS_NOT_JSON_VALUE => {
270                            PostfixOp::IsNotJsonValue(ast::IsNotJsonValue { syntax: node })
271                        }
272                        SyntaxKind::IS_NOT_NORMALIZED => {
273                            PostfixOp::IsNotNormalized(ast::IsNotNormalized { syntax: node })
274                        }
275                        _ => continue,
276                    };
277                    return Some(op);
278                }
279            }
280        }
281
282        None
283    }
284}
285
286impl ast::FieldExpr {
287    // We have NameRef as a variant of Expr which complicates things (and it
288    // might not be worth it).
289    // Rust analyzer doesn't do this so it doesn't have to special case this.
290    #[inline]
291    pub fn base(&self) -> Option<ast::Expr> {
292        support::children(self.syntax()).next()
293    }
294    #[inline]
295    pub fn field(&self) -> Option<ast::NameRef> {
296        support::children(self.syntax()).last()
297    }
298}
299
300impl ast::IndexExpr {
301    #[inline]
302    pub fn base(&self) -> Option<ast::Expr> {
303        support::children(&self.syntax).next()
304    }
305    #[inline]
306    pub fn index(&self) -> Option<ast::Expr> {
307        support::children(&self.syntax).nth(1)
308    }
309}
310
311impl ast::SliceExpr {
312    #[inline]
313    pub fn base(&self) -> Option<ast::Expr> {
314        support::children(&self.syntax).next()
315    }
316
317    #[inline]
318    pub fn start(&self) -> Option<ast::Expr> {
319        // With `select x[1:]`, we have two exprs, `x` and `1`.
320        // We skip over the first one, and then we want the second one, but we
321        // want to make sure we don't choose the end expr if instead we had:
322        // `select x[:1]`
323        let colon = self.colon_token()?;
324        support::children(&self.syntax)
325            .skip(1)
326            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
327    }
328
329    #[inline]
330    pub fn end(&self) -> Option<ast::Expr> {
331        // We want to make sure we get the last expr after the `:` which is the
332        // end of the slice, i.e., `2` in: `select x[:2]`
333        let colon = self.colon_token()?;
334        support::children(&self.syntax)
335            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
336    }
337}
338
339impl ast::RenameColumn {
340    #[inline]
341    pub fn from(&self) -> Option<ast::NameRef> {
342        support::children(&self.syntax).nth(0)
343    }
344    #[inline]
345    pub fn to(&self) -> Option<ast::NameRef> {
346        support::children(&self.syntax).nth(1)
347    }
348}
349
350impl ast::ForeignKeyConstraint {
351    #[inline]
352    pub fn from_columns(&self) -> Option<ast::ColumnList> {
353        support::children(&self.syntax).nth(0)
354    }
355    #[inline]
356    pub fn to_columns(&self) -> Option<ast::ColumnList> {
357        support::children(&self.syntax).nth(1)
358    }
359}
360
361impl ast::BetweenExpr {
362    #[inline]
363    pub fn target(&self) -> Option<ast::Expr> {
364        support::children(&self.syntax).nth(0)
365    }
366    #[inline]
367    pub fn start(&self) -> Option<ast::Expr> {
368        support::children(&self.syntax).nth(1)
369    }
370    #[inline]
371    pub fn end(&self) -> Option<ast::Expr> {
372        support::children(&self.syntax).nth(2)
373    }
374}
375
376impl ast::WhenClause {
377    #[inline]
378    pub fn condition(&self) -> Option<ast::Expr> {
379        support::children(&self.syntax).next()
380    }
381    #[inline]
382    pub fn then(&self) -> Option<ast::Expr> {
383        support::children(&self.syntax).nth(1)
384    }
385}
386
387impl ast::CompoundSelect {
388    #[inline]
389    pub fn lhs(&self) -> Option<ast::SelectVariant> {
390        support::children(&self.syntax).next()
391    }
392    #[inline]
393    pub fn rhs(&self) -> Option<ast::SelectVariant> {
394        support::children(&self.syntax).nth(1)
395    }
396}
397
398impl ast::NameRef {
399    #[inline]
400    pub fn text(&self) -> String {
401        normalize_name_node(self.syntax())
402    }
403
404    #[inline]
405    pub fn is_quoted(&self) -> bool {
406        is_quoted(self.syntax())
407    }
408}
409
410impl ast::Name {
411    #[inline]
412    pub fn text(&self) -> String {
413        normalize_name_node(self.syntax())
414    }
415
416    #[inline]
417    pub fn is_quoted(&self) -> bool {
418        is_quoted(self.syntax())
419    }
420}
421
422fn is_quoted(node: &SyntaxNode) -> bool {
423    let text = node.text();
424    let first = text.char_at(0.into());
425    let second = text.char_at(1.into());
426    matches!(
427        (first, second),
428        (Some('u' | 'U'), Some('"')) | (Some('"'), Some(_))
429    )
430}
431
432// TODO: return a NewType wrapper around String?
433fn normalize_name_node(node: &SyntaxNode) -> String {
434    let mut tokens = node
435        .children_with_tokens()
436        .filter_map(|el| el.into_token())
437        .filter(|t| !t.kind().is_trivia());
438
439    let Some(ident_token) = tokens.next() else {
440        return String::new();
441    };
442    let raw = ident_token.text();
443
444    let unicode_inner = raw
445        .strip_prefix(['u', 'U'])
446        .and_then(|s| s.strip_prefix("&\""))
447        .and_then(|s| s.strip_suffix('"'));
448
449    if let Some(inner) = unicode_inner {
450        let mut escape_char = '\\';
451        if let Some(uesc) = tokens.next()
452            && uesc.kind() == SyntaxKind::UESCAPE_KW
453            && let Some(token) = tokens.next()
454            && let Some(ch) = uescape_char(token.text())
455        {
456            escape_char = ch;
457        }
458
459        let inner = inner.replace(r#""""#, "\"");
460        let mut result = String::with_capacity(inner.len());
461        escape_unicode_esc_str(&inner, escape_char, |_range, r| {
462            if let Ok(ch) = r {
463                result.push(ch);
464            }
465        });
466        return result;
467    }
468
469    raw.strip_prefix('"')
470        .and_then(|t| t.strip_suffix('"'))
471        .map(|x| x.replace(r#""""#, "\""))
472        .unwrap_or_else(|| raw.to_ascii_lowercase())
473}
474
475impl ast::CharType {
476    #[inline]
477    pub fn text(&self) -> TokenText<'_> {
478        text_of_first_token(self.syntax())
479    }
480}
481
/// Returns `true` when `text`, the token that follows a boolean option name
/// (e.g. `FULL` in `VACUUM (FULL false)`), switches that option off.
///
/// PostgreSQL's boolean options treat these as false:
/// - the words `false` / `off` (case-insensitive), written bare or as a
///   single-quoted string literal such as `'off'`;
/// - the integer zero, however many digits it is written with (`0`, `00`).
///
/// A quoted `'0'` is a string rather than an integer, so it is not falsey
/// (PostgreSQL rejects it as a boolean value).
fn is_falsey_option(text: &str) -> bool {
    let is_false_word = |s: &str| s.eq_ignore_ascii_case("false") || s.eq_ignore_ascii_case("off");

    if let Some(quoted) = text.strip_prefix('\'').and_then(|t| t.strip_suffix('\'')) {
        return is_false_word(quoted);
    }

    let is_zero = !text.is_empty() && text.bytes().all(|b| b == b'0');
    is_zero || is_false_word(text)
}
485
486impl ast::Vacuum {
487    pub fn is_full(&self) -> bool {
488        self.full_token().is_some()
489            // TODO: we need a better way of handling option lists
490            || self.vacuum_option_list().is_some_and(|opt_list| {
491                opt_list.vacuum_options().any(|opt| {
492                    let mut tokens = opt
493                        .syntax()
494                        .descendants_with_tokens()
495                        .filter_map(|child| child.into_token())
496                        .filter(|token| !token.kind().is_trivia());
497
498                    tokens
499                        .next()
500                        .is_some_and(|token| token.text().eq_ignore_ascii_case("full"))
501                        && tokens
502                            .next()
503                            .is_none_or(|token| !is_falsey_option(token.text()))
504                })
505            })
506    }
507}
508
509impl ast::OpSig {
510    #[inline]
511    pub fn lhs(&self) -> Option<ast::Type> {
512        support::children(self.syntax()).next()
513    }
514
515    #[inline]
516    pub fn rhs(&self) -> Option<ast::Type> {
517        support::children(self.syntax()).nth(1)
518    }
519}
520
521impl ast::CastSig {
522    #[inline]
523    pub fn lhs(&self) -> Option<ast::Type> {
524        support::children(self.syntax()).next()
525    }
526
527    #[inline]
528    pub fn rhs(&self) -> Option<ast::Type> {
529        support::children(self.syntax()).nth(1)
530    }
531}
532
533pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
534    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
535        green_ref
536            .children()
537            .next()
538            .and_then(NodeOrToken::into_token)
539            .unwrap()
540    }
541
542    match node.green() {
543        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
544        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
545    }
546}
547
548impl ast::WithQuery {
549    #[inline]
550    pub fn with_clause(&self) -> Option<ast::WithClause> {
551        support::child(self.syntax())
552    }
553}
554
555impl ast::SelectVariant {
556    #[inline]
557    pub fn target_list(&self) -> Option<ast::TargetList> {
558        match self {
559            ast::SelectVariant::Select(select) => {
560                return select.select_clause()?.target_list();
561            }
562            ast::SelectVariant::SelectInto(select_into) => {
563                return select_into.select_clause()?.target_list();
564            }
565            ast::SelectVariant::ParenSelect(paren_select) => {
566                return paren_select.select()?.target_list();
567            }
568            _ => return None,
569        }
570    }
571}
572
// Empty impls: these node types rely entirely on the trait's provided
// methods.
impl ast::HasParamList for ast::FunctionSig {}
impl ast::HasParamList for ast::Aggregate {}

// `NameLike::text` forwards to the inherent `text` methods, which return the
// normalized identifier produced by `normalize_name_node`.
impl ast::NameLike for ast::Name {
    #[inline]
    fn text(&self) -> String {
        // Resolves to the inherent `ast::Name::text` (inherent methods take
        // precedence over trait methods), so this does not recurse.
        self.text()
    }
}
impl ast::NameLike for ast::NameRef {
    #[inline]
    fn text(&self) -> String {
        // Resolves to the inherent `ast::NameRef::text`, not this method.
        self.text()
    }
}

// Statements that can carry a `WITH` clause; all methods are provided.
impl ast::HasWithClause for ast::Select {}
impl ast::HasWithClause for ast::SelectInto {}
impl ast::HasWithClause for ast::Insert {}
impl ast::HasWithClause for ast::Update {}
impl ast::HasWithClause for ast::Delete {}

// `CREATE TABLE`-like statements sharing the `HasCreateTable` accessors.
impl ast::HasCreateTable for ast::CreateTable {}
impl ast::HasCreateTable for ast::CreateForeignTable {}
impl ast::HasCreateTable for ast::CreateTableLike {}
598
#[test]
fn name() {
    /// Parses `source_code` and returns the normalized alias of the first
    /// select target.
    fn extract_name(source_code: &str) -> String {
        let parse = SourceFile::parse(source_code);
        assert!(parse.errors().is_empty());
        let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
            unreachable!()
        };
        let first_target = select
            .select_clause()
            .and_then(|clause| clause.target_list())
            .and_then(|list| list.targets().next())
            .unwrap();
        first_target
            .as_name()
            .and_then(|as_name| as_name.name())
            .unwrap()
            .text()
    }

    assert_snapshot!(extract_name("select 1 foo"), @"foo");
    assert_snapshot!(extract_name("select 1 FOO"), @"foo");
    assert_snapshot!(extract_name(r#"select 1 "foo""#), @"foo");
    assert_snapshot!(extract_name(r#"select 1 "Foo""#), @"Foo");
    assert_snapshot!(extract_name(r#"select 1 "FOO""#), @"FOO");
    assert_snapshot!(extract_name(r#"select 1 U&"\0066\006f\006f""#), @"foo");
    assert_snapshot!(extract_name(r#"select 1 U&"@0066@006f@006f" uescape '@'"#), @"foo");
}
631
#[test]
fn name_ref() {
    /// Parses `source_code` and returns the normalized text of the first
    /// select target, which must be a bare name reference.
    fn extract_name_ref(source_code: &str) -> String {
        let parse = SourceFile::parse(source_code);
        assert!(parse.errors().is_empty());
        let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
            unreachable!()
        };
        let expr = select
            .select_clause()
            .and_then(|clause| clause.target_list())
            .and_then(|list| list.targets().next())
            .and_then(|target| target.expr())
            .unwrap();
        match expr {
            ast::Expr::NameRef(name_ref) => name_ref.text(),
            _ => unreachable!(),
        }
    }

    assert_snapshot!(extract_name_ref("select foo"), @"foo");
    assert_snapshot!(extract_name_ref("select FOO"), @"foo");
    assert_snapshot!(extract_name_ref(r#"select "foo""#), @"foo");
    assert_snapshot!(extract_name_ref(r#"select "Foo""#), @"Foo");
    assert_snapshot!(extract_name_ref(r#"select "FOO""#), @"FOO");
    assert_snapshot!(extract_name_ref(r#"select U&"\0066\006f\006f""#), @"foo");
    assert_snapshot!(extract_name_ref(r#"select U&"@0066@006f@006f" uescape '@'"#), @"foo");
}
662
#[test]
fn index_expr() {
    let source_code = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let expr = select
        .select_clause()
        .and_then(|clause| clause.target_list())
        .and_then(|list| list.targets().next())
        .and_then(|target| target.expr())
        .unwrap();
    let ast::Expr::IndexExpr(index_expr) = expr else {
        unreachable!()
    };
    assert_eq!(index_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(index_expr.index().unwrap().syntax().text(), "bar");
}
689
#[test]
fn slice_expr() {
    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let mut targets = select.select_clause().unwrap().target_list().unwrap().targets();

    // (whole slice text, expected start bound, expected end bound)
    let expected: [(&str, Option<&str>, Option<&str>); 4] = [
        ("x[1:2]", Some("1"), Some("2")),
        ("x[2:]", Some("2"), None),
        ("x[:3]", None, Some("3")),
        ("x[:]", None, None),
    ];
    for (whole, start, end) in expected {
        let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
            unreachable!()
        };
        assert_eq!(slice.syntax().to_string(), whole);
        assert_eq!(slice.base().unwrap().syntax().text(), "x");
        assert_eq!(slice.start().map(|e| e.syntax().to_string()).as_deref(), start);
        assert_eq!(slice.end().map(|e| e.syntax().to_string()).as_deref(), end);
    }
}
737
#[test]
fn field_expr() {
    let source_code = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let expr = select
        .select_clause()
        .and_then(|clause| clause.target_list())
        .and_then(|list| list.targets().next())
        .and_then(|target| target.expr())
        .unwrap();
    let ast::Expr::FieldExpr(field_expr) = expr else {
        unreachable!()
    };
    assert_eq!(field_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(field_expr.field().unwrap().syntax().text(), "bar");
}
764
#[test]
fn between_expr() {
    let source_code = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
        unreachable!()
    };
    let expr = select
        .select_clause()
        .and_then(|clause| clause.target_list())
        .and_then(|list| list.targets().next())
        .and_then(|target| target.expr())
        .unwrap();
    let ast::Expr::BetweenExpr(between_expr) = expr else {
        unreachable!()
    };
    assert_eq!(between_expr.target().unwrap().syntax().text(), "2");
    assert_eq!(between_expr.start().unwrap().syntax().text(), "1");
    assert_eq!(between_expr.end().unwrap().syntax().text(), "3");
}
793
#[test]
fn cast_expr() {
    /// Parses `sql` and returns its first select target, which must be a cast.
    fn first_cast(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let ast::Stmt::Select(select) = parse.tree().stmts().next().unwrap() else {
            unreachable!()
        };
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }

    // (sql, expected operand text, expected type text)
    let cases = [
        ("select cast('123' as int)", "'123'", "int"),
        ("select cast('123' as pg_catalog.int4)", "'123'", "pg_catalog.int4"),
        ("select int '123'", "'123'", "int"),
        ("select pg_catalog.int4 '123'", "'123'", "pg_catalog.int4"),
        ("select '123'::int", "'123'", "int"),
        ("select '123'::int4", "'123'", "int4"),
        ("select '123'::pg_catalog.int4", "'123'", "pg_catalog.int4"),
        (
            "select '{123}'::pg_catalog.varchar(10)[]",
            "'{123}'",
            "pg_catalog.varchar(10)[]",
        ),
        (
            "select cast('{123}' as pg_catalog.varchar(10)[])",
            "'{123}'",
            "pg_catalog.varchar(10)[]",
        ),
        (
            "select pg_catalog.varchar(10) '{123}'",
            "'{123}'",
            "pg_catalog.varchar(10)",
        ),
        ("select interval '1' month", "'1'", "interval"),
    ];

    for (sql, expected_expr, expected_ty) in cases {
        let cast = first_cast(sql);
        let expr = cast.expr().unwrap();
        let ty = cast.ty().unwrap();
        assert_eq!(expr.syntax().to_string(), expected_expr);
        assert_eq!(ty.syntax().to_string(), expected_ty);
    }
}
892
#[test]
fn op_sig() {
    let source_code = "
      alter operator p.+ (int4, int8) 
        owner to u;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let op_sig = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::AlterOperator(alter_op) => alter_op.op_sig().unwrap(),
        _ => unreachable!(),
    };
    assert_snapshot!(op_sig.lhs().unwrap().syntax().text(), @"int4");
    assert_snapshot!(op_sig.rhs().unwrap().syntax().text(), @"int8");
}
911
#[test]
fn cast_sig() {
    let source_code = "
      drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let cast_sig = match parse.tree().stmts().next().unwrap() {
        ast::Stmt::DropCast(drop_cast) => drop_cast.cast_sig().unwrap(),
        _ => unreachable!(),
    };
    assert_snapshot!(cast_sig.lhs().unwrap().syntax().text(), @"text");
    assert_snapshot!(cast_sig.rhs().unwrap().syntax().text(), @"int");
}
929
/// Test helper: parses `sql`, asserts it has no errors, and returns its
/// first statement, which must be a `VACUUM`.
#[cfg(test)]
fn extract_vacuum(sql: &str) -> ast::Vacuum {
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    match parse.tree().stmts().next().unwrap() {
        ast::Stmt::Vacuum(vacuum) => vacuum,
        _ => unreachable!(),
    }
}
940
#[test]
fn vacuum_full_is_full() {
    let vacuum = extract_vacuum("VACUUM FULL foo;");
    assert!(vacuum.is_full());
}

#[test]
fn vacuum_option_list_full_is_full() {
    let vacuum = extract_vacuum("VACUUM (FULL) foo;");
    assert!(vacuum.is_full());
}

#[test]
fn vacuum_full_true_is_full() {
    let vacuum = extract_vacuum("VACUUM (FULL TRUE) foo;");
    assert!(vacuum.is_full());
}

#[test]
fn vacuum_full_on_is_full() {
    let vacuum = extract_vacuum("VACUUM (FULL ON) foo;");
    assert!(vacuum.is_full());
}

#[test]
fn vacuum_full_1_is_full() {
    let vacuum = extract_vacuum("VACUUM (FULL 1) foo;");
    assert!(vacuum.is_full());
}

#[test]
fn vacuum_no_full_is_not_full() {
    let vacuum = extract_vacuum("VACUUM foo;");
    assert!(!vacuum.is_full());
}

#[test]
fn vacuum_other_option_is_not_full() {
    let vacuum = extract_vacuum("VACUUM (FREEZE) foo;");
    assert!(!vacuum.is_full());
}

#[test]
fn vacuum_full_false_is_not_full() {
    let vacuum = extract_vacuum("VACUUM (FULL FALSE) foo;");
    assert!(!vacuum.is_full());
}

#[test]
fn vacuum_full_off_is_not_full() {
    let vacuum = extract_vacuum("VACUUM (FULL OFF) foo;");
    assert!(!vacuum.is_full());
}

#[test]
fn vacuum_full_0_is_not_full() {
    let vacuum = extract_vacuum("VACUUM (FULL 0) foo;");
    assert!(!vacuum.is_full());
}