Skip to main content

squawk_syntax/ast/
node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use rowan::Direction;
36
37use crate::ast;
38use crate::ast::AstNode;
39use crate::{SyntaxKind, SyntaxNode, SyntaxToken, TokenText};
40
41use super::support;
42
43#[derive(Debug, Clone, PartialEq, Eq)]
44pub enum LitKind {
45    BitString(SyntaxToken),
46    ByteString(SyntaxToken),
47    DollarQuotedString(SyntaxToken),
48    EscString(SyntaxToken),
49    FloatNumber(SyntaxToken),
50    IntNumber(SyntaxToken),
51    Null(SyntaxToken),
52    PositionalParam(SyntaxToken),
53    String(SyntaxToken),
54    UnicodeEscString(SyntaxToken),
55}
56
57impl ast::Literal {
58    pub fn kind(&self) -> Option<LitKind> {
59        let token = self.syntax().first_child_or_token()?.into_token()?;
60        let kind = match token.kind() {
61            SyntaxKind::STRING => LitKind::String(token),
62            SyntaxKind::NULL_KW => LitKind::Null(token),
63            SyntaxKind::FLOAT_NUMBER => LitKind::FloatNumber(token),
64            SyntaxKind::INT_NUMBER => LitKind::IntNumber(token),
65            SyntaxKind::BYTE_STRING => LitKind::ByteString(token),
66            SyntaxKind::BIT_STRING => LitKind::BitString(token),
67            SyntaxKind::DOLLAR_QUOTED_STRING => LitKind::DollarQuotedString(token),
68            SyntaxKind::UNICODE_ESC_STRING => LitKind::UnicodeEscString(token),
69            SyntaxKind::ESC_STRING => LitKind::EscString(token),
70            SyntaxKind::POSITIONAL_PARAM => LitKind::PositionalParam(token),
71            _ => return None,
72        };
73        Some(kind)
74    }
75}
76
77impl ast::Constraint {
78    #[inline]
79    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
80        support::child(self.syntax())
81    }
82}
83
84#[derive(Debug, Clone, PartialEq, Eq)]
85pub enum BinOp {
86    And(SyntaxToken),
87    AtTimeZone(ast::AtTimeZone),
88    Caret(SyntaxToken),
89    Collate(SyntaxToken),
90    ColonColon(ast::ColonColon),
91    ColonEq(SyntaxToken),
92    CustomOp(ast::CustomOp),
93    Eq(SyntaxToken),
94    FatArrow(SyntaxToken),
95    Gteq(SyntaxToken),
96    Ilike(SyntaxToken),
97    In(SyntaxToken),
98    Is(SyntaxToken),
99    IsDistinctFrom(ast::IsDistinctFrom),
100    IsNot(ast::IsNot),
101    IsNotDistinctFrom(ast::IsNotDistinctFrom),
102    LAngle(SyntaxToken),
103    Like(SyntaxToken),
104    Lteq(SyntaxToken),
105    Minus(SyntaxToken),
106    Neq(SyntaxToken),
107    Neqb(SyntaxToken),
108    NotIlike(ast::NotIlike),
109    NotIn(ast::NotIn),
110    NotLike(ast::NotLike),
111    NotSimilarTo(ast::NotSimilarTo),
112    OperatorCall(ast::OperatorCall),
113    Or(SyntaxToken),
114    Overlaps(SyntaxToken),
115    Percent(SyntaxToken),
116    Plus(SyntaxToken),
117    RAngle(SyntaxToken),
118    SimilarTo(ast::SimilarTo),
119    Slash(SyntaxToken),
120    Star(SyntaxToken),
121}
122
123#[derive(Debug, Clone, PartialEq, Eq)]
124pub enum PostfixOp {
125    AtLocal(SyntaxToken),
126    IsJson(ast::IsJson),
127    IsJsonArray(ast::IsJsonArray),
128    IsJsonObject(ast::IsJsonObject),
129    IsJsonScalar(ast::IsJsonScalar),
130    IsJsonValue(ast::IsJsonValue),
131    IsNormalized(ast::IsNormalized),
132    IsNotJson(ast::IsNotJson),
133    IsNotJsonArray(ast::IsNotJsonArray),
134    IsNotJsonObject(ast::IsNotJsonObject),
135    IsNotJsonScalar(ast::IsNotJsonScalar),
136    IsNotJsonValue(ast::IsNotJsonValue),
137    IsNotNormalized(ast::IsNotNormalized),
138    IsNull(SyntaxToken),
139    NotNull(SyntaxToken),
140}
141
142impl ast::BinExpr {
143    #[inline]
144    pub fn lhs(&self) -> Option<ast::Expr> {
145        support::children(self.syntax()).next()
146    }
147
148    #[inline]
149    pub fn rhs(&self) -> Option<ast::Expr> {
150        support::children(self.syntax()).nth(1)
151    }
152
153    pub fn op(&self) -> Option<BinOp> {
154        let lhs = self.lhs()?;
155        for child in lhs.syntax().siblings_with_tokens(Direction::Next).skip(1) {
156            match child {
157                NodeOrToken::Token(token) => {
158                    let op = match token.kind() {
159                        SyntaxKind::AND_KW => BinOp::And(token),
160                        SyntaxKind::CARET => BinOp::Caret(token),
161                        SyntaxKind::COLLATE_KW => BinOp::Collate(token),
162                        SyntaxKind::COLON_EQ => BinOp::ColonEq(token),
163                        SyntaxKind::EQ => BinOp::Eq(token),
164                        SyntaxKind::FAT_ARROW => BinOp::FatArrow(token),
165                        SyntaxKind::GTEQ => BinOp::Gteq(token),
166                        SyntaxKind::ILIKE_KW => BinOp::Ilike(token),
167                        SyntaxKind::IN_KW => BinOp::In(token),
168                        SyntaxKind::IS_KW => BinOp::Is(token),
169                        SyntaxKind::L_ANGLE => BinOp::LAngle(token),
170                        SyntaxKind::LIKE_KW => BinOp::Like(token),
171                        SyntaxKind::LTEQ => BinOp::Lteq(token),
172                        SyntaxKind::MINUS => BinOp::Minus(token),
173                        SyntaxKind::NEQ => BinOp::Neq(token),
174                        SyntaxKind::NEQB => BinOp::Neqb(token),
175                        SyntaxKind::OR_KW => BinOp::Or(token),
176                        SyntaxKind::OVERLAPS_KW => BinOp::Overlaps(token),
177                        SyntaxKind::PERCENT => BinOp::Percent(token),
178                        SyntaxKind::PLUS => BinOp::Plus(token),
179                        SyntaxKind::R_ANGLE => BinOp::RAngle(token),
180                        SyntaxKind::SLASH => BinOp::Slash(token),
181                        SyntaxKind::STAR => BinOp::Star(token),
182                        _ => continue,
183                    };
184                    return Some(op);
185                }
186                NodeOrToken::Node(node) => {
187                    let op = match node.kind() {
188                        SyntaxKind::AT_TIME_ZONE => {
189                            BinOp::AtTimeZone(ast::AtTimeZone { syntax: node })
190                        }
191                        SyntaxKind::COLON_COLON => {
192                            BinOp::ColonColon(ast::ColonColon { syntax: node })
193                        }
194                        SyntaxKind::CUSTOM_OP => BinOp::CustomOp(ast::CustomOp { syntax: node }),
195                        SyntaxKind::IS_DISTINCT_FROM => {
196                            BinOp::IsDistinctFrom(ast::IsDistinctFrom { syntax: node })
197                        }
198                        SyntaxKind::IS_NOT => BinOp::IsNot(ast::IsNot { syntax: node }),
199                        SyntaxKind::IS_NOT_DISTINCT_FROM => {
200                            BinOp::IsNotDistinctFrom(ast::IsNotDistinctFrom { syntax: node })
201                        }
202                        SyntaxKind::NOT_ILIKE => BinOp::NotIlike(ast::NotIlike { syntax: node }),
203                        SyntaxKind::NOT_IN => BinOp::NotIn(ast::NotIn { syntax: node }),
204                        SyntaxKind::NOT_LIKE => BinOp::NotLike(ast::NotLike { syntax: node }),
205                        SyntaxKind::NOT_SIMILAR_TO => {
206                            BinOp::NotSimilarTo(ast::NotSimilarTo { syntax: node })
207                        }
208                        SyntaxKind::OPERATOR_CALL => {
209                            BinOp::OperatorCall(ast::OperatorCall { syntax: node })
210                        }
211                        SyntaxKind::SIMILAR_TO => BinOp::SimilarTo(ast::SimilarTo { syntax: node }),
212                        _ => continue,
213                    };
214                    return Some(op);
215                }
216            }
217        }
218        None
219    }
220}
221
222impl ast::PostfixExpr {
223    pub fn op(&self) -> Option<PostfixOp> {
224        let lhs = self.expr()?;
225
226        let siblings = lhs.syntax().siblings_with_tokens(Direction::Next).skip(1);
227        for child in siblings {
228            match child {
229                NodeOrToken::Token(token) => {
230                    let op = match token.kind() {
231                        SyntaxKind::AT_KW => PostfixOp::AtLocal(token),
232                        SyntaxKind::ISNULL_KW => PostfixOp::IsNull(token),
233                        SyntaxKind::NOTNULL_KW => PostfixOp::NotNull(token),
234                        _ => continue,
235                    };
236                    return Some(op);
237                }
238                NodeOrToken::Node(node) => {
239                    let op = match node.kind() {
240                        SyntaxKind::IS_JSON => PostfixOp::IsJson(ast::IsJson { syntax: node }),
241                        SyntaxKind::IS_JSON_ARRAY => {
242                            PostfixOp::IsJsonArray(ast::IsJsonArray { syntax: node })
243                        }
244                        SyntaxKind::IS_JSON_OBJECT => {
245                            PostfixOp::IsJsonObject(ast::IsJsonObject { syntax: node })
246                        }
247                        SyntaxKind::IS_JSON_SCALAR => {
248                            PostfixOp::IsJsonScalar(ast::IsJsonScalar { syntax: node })
249                        }
250                        SyntaxKind::IS_JSON_VALUE => {
251                            PostfixOp::IsJsonValue(ast::IsJsonValue { syntax: node })
252                        }
253                        SyntaxKind::IS_NORMALIZED => {
254                            PostfixOp::IsNormalized(ast::IsNormalized { syntax: node })
255                        }
256                        SyntaxKind::IS_NOT_JSON => {
257                            PostfixOp::IsNotJson(ast::IsNotJson { syntax: node })
258                        }
259                        SyntaxKind::IS_NOT_JSON_ARRAY => {
260                            PostfixOp::IsNotJsonArray(ast::IsNotJsonArray { syntax: node })
261                        }
262                        SyntaxKind::IS_NOT_JSON_OBJECT => {
263                            PostfixOp::IsNotJsonObject(ast::IsNotJsonObject { syntax: node })
264                        }
265                        SyntaxKind::IS_NOT_JSON_SCALAR => {
266                            PostfixOp::IsNotJsonScalar(ast::IsNotJsonScalar { syntax: node })
267                        }
268                        SyntaxKind::IS_NOT_JSON_VALUE => {
269                            PostfixOp::IsNotJsonValue(ast::IsNotJsonValue { syntax: node })
270                        }
271                        SyntaxKind::IS_NOT_NORMALIZED => {
272                            PostfixOp::IsNotNormalized(ast::IsNotNormalized { syntax: node })
273                        }
274                        _ => continue,
275                    };
276                    return Some(op);
277                }
278            }
279        }
280
281        None
282    }
283}
284
285impl ast::FieldExpr {
286    // We have NameRef as a variant of Expr which complicates things (and it
287    // might not be worth it).
288    // Rust analyzer doesn't do this so it doesn't have to special case this.
289    #[inline]
290    pub fn base(&self) -> Option<ast::Expr> {
291        support::children(self.syntax()).next()
292    }
293    #[inline]
294    pub fn field(&self) -> Option<ast::NameRef> {
295        support::children(self.syntax()).last()
296    }
297}
298
299impl ast::IndexExpr {
300    #[inline]
301    pub fn base(&self) -> Option<ast::Expr> {
302        support::children(&self.syntax).next()
303    }
304    #[inline]
305    pub fn index(&self) -> Option<ast::Expr> {
306        support::children(&self.syntax).nth(1)
307    }
308}
309
310impl ast::SliceExpr {
311    #[inline]
312    pub fn base(&self) -> Option<ast::Expr> {
313        support::children(&self.syntax).next()
314    }
315
316    #[inline]
317    pub fn start(&self) -> Option<ast::Expr> {
318        // With `select x[1:]`, we have two exprs, `x` and `1`.
319        // We skip over the first one, and then we want the second one, but we
320        // want to make sure we don't choose the end expr if instead we had:
321        // `select x[:1]`
322        let colon = self.colon_token()?;
323        support::children(&self.syntax)
324            .skip(1)
325            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
326    }
327
328    #[inline]
329    pub fn end(&self) -> Option<ast::Expr> {
330        // We want to make sure we get the last expr after the `:` which is the
331        // end of the slice, i.e., `2` in: `select x[:2]`
332        let colon = self.colon_token()?;
333        support::children(&self.syntax)
334            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
335    }
336}
337
338impl ast::RenameColumn {
339    #[inline]
340    pub fn from(&self) -> Option<ast::NameRef> {
341        support::children(&self.syntax).nth(0)
342    }
343    #[inline]
344    pub fn to(&self) -> Option<ast::NameRef> {
345        support::children(&self.syntax).nth(1)
346    }
347}
348
349impl ast::ForeignKeyConstraint {
350    #[inline]
351    pub fn from_columns(&self) -> Option<ast::ColumnList> {
352        support::children(&self.syntax).nth(0)
353    }
354    #[inline]
355    pub fn to_columns(&self) -> Option<ast::ColumnList> {
356        support::children(&self.syntax).nth(1)
357    }
358}
359
360impl ast::BetweenExpr {
361    #[inline]
362    pub fn target(&self) -> Option<ast::Expr> {
363        support::children(&self.syntax).nth(0)
364    }
365    #[inline]
366    pub fn start(&self) -> Option<ast::Expr> {
367        support::children(&self.syntax).nth(1)
368    }
369    #[inline]
370    pub fn end(&self) -> Option<ast::Expr> {
371        support::children(&self.syntax).nth(2)
372    }
373}
374
375impl ast::WhenClause {
376    #[inline]
377    pub fn condition(&self) -> Option<ast::Expr> {
378        support::children(&self.syntax).next()
379    }
380    #[inline]
381    pub fn then(&self) -> Option<ast::Expr> {
382        support::children(&self.syntax).nth(1)
383    }
384}
385
386impl ast::CompoundSelect {
387    #[inline]
388    pub fn lhs(&self) -> Option<ast::SelectVariant> {
389        support::children(&self.syntax).next()
390    }
391    #[inline]
392    pub fn rhs(&self) -> Option<ast::SelectVariant> {
393        support::children(&self.syntax).nth(1)
394    }
395}
396
397impl ast::NameRef {
398    #[inline]
399    pub fn text(&self) -> TokenText<'_> {
400        text_of_first_token(self.syntax())
401    }
402}
403
404impl ast::Name {
405    #[inline]
406    pub fn text(&self) -> TokenText<'_> {
407        text_of_first_token(self.syntax())
408    }
409}
410
411impl ast::CharType {
412    #[inline]
413    pub fn text(&self) -> TokenText<'_> {
414        text_of_first_token(self.syntax())
415    }
416}
417
/// Returns `true` when `text` is a value that turns a boolean option off.
///
/// Postgres treats the integer zero and the words `false` / `off` (in any
/// letter case) as "off" for boolean option arguments. The words may also be
/// written as a string constant (`'off'`) or a quoted identifier (`"off"`),
/// so one pair of surrounding quotes is stripped before comparing.
fn is_falsey_option(text: &str) -> bool {
    // Integer literal zero, including forms with redundant leading zeros (`00`).
    if !text.is_empty() && text.bytes().all(|b| b == b'0') {
        return true;
    }
    let unquoted = text
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))
        .or_else(|| text.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')))
        .unwrap_or(text);
    unquoted.eq_ignore_ascii_case("false") || unquoted.eq_ignore_ascii_case("off")
}
421
422impl ast::Vacuum {
423    pub fn is_full(&self) -> bool {
424        self.full_token().is_some()
425            // TODO: we need a better way of handling option lists
426            || self.vacuum_option_list().is_some_and(|opt_list| {
427                opt_list.vacuum_options().any(|opt| {
428                    let mut tokens = opt
429                        .syntax()
430                        .descendants_with_tokens()
431                        .filter_map(|child| child.into_token())
432                        .filter(|token| !token.kind().is_trivia());
433
434                    tokens
435                        .next()
436                        .is_some_and(|token| token.text().eq_ignore_ascii_case("full"))
437                        && tokens
438                            .next()
439                            .is_none_or(|token| !is_falsey_option(token.text()))
440                })
441            })
442    }
443}
444
445impl ast::OpSig {
446    #[inline]
447    pub fn lhs(&self) -> Option<ast::Type> {
448        support::children(self.syntax()).next()
449    }
450
451    #[inline]
452    pub fn rhs(&self) -> Option<ast::Type> {
453        support::children(self.syntax()).nth(1)
454    }
455}
456
457impl ast::CastSig {
458    #[inline]
459    pub fn lhs(&self) -> Option<ast::Type> {
460        support::children(self.syntax()).next()
461    }
462
463    #[inline]
464    pub fn rhs(&self) -> Option<ast::Type> {
465        support::children(self.syntax()).nth(1)
466    }
467}
468
469pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
470    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
471        green_ref
472            .children()
473            .next()
474            .and_then(NodeOrToken::into_token)
475            .unwrap()
476    }
477
478    match node.green() {
479        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
480        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
481    }
482}
483
484impl ast::WithQuery {
485    #[inline]
486    pub fn with_clause(&self) -> Option<ast::WithClause> {
487        support::child(self.syntax())
488    }
489}
490
491impl ast::SelectVariant {
492    #[inline]
493    pub fn target_list(&self) -> Option<ast::TargetList> {
494        match self {
495            ast::SelectVariant::Select(select) => {
496                return select.select_clause()?.target_list();
497            }
498            ast::SelectVariant::SelectInto(select_into) => {
499                return select_into.select_clause()?.target_list();
500            }
501            ast::SelectVariant::ParenSelect(paren_select) => {
502                return paren_select.select()?.target_list();
503            }
504            _ => return None,
505        }
506    }
507}
508
509impl ast::HasParamList for ast::FunctionSig {}
510impl ast::HasParamList for ast::Aggregate {}
511
512impl ast::NameLike for ast::Name {}
513impl ast::NameLike for ast::NameRef {}
514
515impl ast::HasWithClause for ast::Select {}
516impl ast::HasWithClause for ast::SelectInto {}
517impl ast::HasWithClause for ast::Insert {}
518impl ast::HasWithClause for ast::Update {}
519impl ast::HasWithClause for ast::Delete {}
520
521impl ast::HasCreateTable for ast::CreateTable {}
522impl ast::HasCreateTable for ast::CreateForeignTable {}
523impl ast::HasCreateTable for ast::CreateTableLike {}
524
#[test]
fn index_expr() {
    let source_code = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // The only statement is a `select` whose first target is `foo[bar]`.
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let ast::Expr::IndexExpr(index_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    assert_eq!(index_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(index_expr.index().unwrap().syntax().text(), "bar");
}
552
#[test]
fn slice_expr() {
    // `assert_snapshot` comes from the file-level `#[cfg(test)]` import; the
    // previous function-local `use` merely shadowed it.
    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let select_clause = select.select_clause().unwrap();
    let mut targets = select_clause.target_list().unwrap().targets();

    // Every target in the select list is expected to be a slice expression.
    let mut next_slice = || {
        let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
            unreachable!()
        };
        slice
    };

    // Both bounds present.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Only a lower bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Only an upper bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // Neither bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
601
#[test]
fn field_expr() {
    let source_code = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    // `foo.bar` splits into the base `foo` and the field name `bar`.
    let ast::Expr::FieldExpr(field_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    assert_eq!(field_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(field_expr.field().unwrap().syntax().text(), "bar");
}
629
#[test]
fn between_expr() {
    let source_code = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let ast::Expr::BetweenExpr(between_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    // `2 between 1 and 3`: tested value, lower bound, upper bound.
    assert_eq!(between_expr.target().unwrap().syntax().text(), "2");
    assert_eq!(between_expr.start().unwrap().syntax().text(), "1");
    assert_eq!(between_expr.end().unwrap().syntax().text(), "3");
}
659
#[test]
fn cast_expr() {
    // `assert_snapshot` comes from the file-level `#[cfg(test)]` import.
    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");

    /// Parses `sql`, which must be error-free and start with a `select`, and
    /// returns the first target expression of that select as a `CastExpr`.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
            unreachable!()
        };
        // `expr()` already yields an owned node, so no clone is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }
}
758
#[test]
fn op_sig() {
    let source_code = "
      alter operator p.+ (int4, int8) 
        owner to u;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::AlterOperator(alter_op) = file.stmts().next().unwrap() else {
        unreachable!()
    };

    // `(int4, int8)` is the operator signature.
    let sig = alter_op.op_sig().unwrap();
    assert_snapshot!(sig.lhs().unwrap().syntax().text(), @"int4");
    assert_snapshot!(sig.rhs().unwrap().syntax().text(), @"int8");
}
778
#[test]
fn cast_sig() {
    let source_code = "
      drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::DropCast(drop_cast) = file.stmts().next().unwrap() else {
        unreachable!()
    };

    // `(text as int)` is the cast signature: source type, then target type.
    let sig = drop_cast.cast_sig().unwrap();
    assert_snapshot!(sig.lhs().unwrap().syntax().text(), @"text");
    assert_snapshot!(sig.rhs().unwrap().syntax().text(), @"int");
}
797
/// Parses `sql`, which must be error-free, and returns its first statement,
/// which must be a `VACUUM`.
#[cfg(test)]
fn extract_vacuum(sql: &str) -> ast::Vacuum {
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    match file.stmts().next().unwrap() {
        ast::Stmt::Vacuum(vacuum) => vacuum,
        _ => unreachable!(),
    }
}

// The bare `FULL` keyword.
#[test]
fn vacuum_full_is_full() {
    assert!(extract_vacuum("VACUUM FULL foo;").is_full());
}

// `FULL` inside an option list, with no argument.
#[test]
fn vacuum_option_list_full_is_full() {
    assert!(extract_vacuum("VACUUM (FULL) foo;").is_full());
}

// `FULL` with truthy arguments.
#[test]
fn vacuum_full_true_is_full() {
    assert!(extract_vacuum("VACUUM (FULL TRUE) foo;").is_full());
}

#[test]
fn vacuum_full_on_is_full() {
    assert!(extract_vacuum("VACUUM (FULL ON) foo;").is_full());
}

#[test]
fn vacuum_full_1_is_full() {
    assert!(extract_vacuum("VACUUM (FULL 1) foo;").is_full());
}

// No `FULL` anywhere.
#[test]
fn vacuum_no_full_is_not_full() {
    assert!(!extract_vacuum("VACUUM foo;").is_full());
}

#[test]
fn vacuum_other_option_is_not_full() {
    assert!(!extract_vacuum("VACUUM (FREEZE) foo;").is_full());
}

// `FULL` explicitly switched off.
#[test]
fn vacuum_full_false_is_not_full() {
    assert!(!extract_vacuum("VACUUM (FULL FALSE) foo;").is_full());
}

#[test]
fn vacuum_full_off_is_not_full() {
    assert!(!extract_vacuum("VACUUM (FULL OFF) foo;").is_full());
}

#[test]
fn vacuum_full_0_is_not_full() {
    assert!(!extract_vacuum("VACUUM (FULL 0) foo;").is_full());
}