// squawk_syntax/ast/node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use rowan::Direction;
36
37use crate::ast;
38use crate::ast::AstNode;
39use crate::{SyntaxKind, SyntaxNode, SyntaxToken, TokenText};
40
41use super::support;
42
43#[derive(Debug, Clone, PartialEq, Eq)]
44pub enum LitKind {
45    BitString(SyntaxToken),
46    ByteString(SyntaxToken),
47    DollarQuotedString(SyntaxToken),
48    EscString(SyntaxToken),
49    FloatNumber(SyntaxToken),
50    IntNumber(SyntaxToken),
51    Null(SyntaxToken),
52    PositionalParam(SyntaxToken),
53    String(SyntaxToken),
54    UnicodeEscString(SyntaxToken),
55}
56
57impl ast::Literal {
58    pub fn kind(&self) -> Option<LitKind> {
59        let token = self.syntax().first_child_or_token()?.into_token()?;
60        let kind = match token.kind() {
61            SyntaxKind::STRING => LitKind::String(token),
62            SyntaxKind::NULL_KW => LitKind::Null(token),
63            SyntaxKind::FLOAT_NUMBER => LitKind::FloatNumber(token),
64            SyntaxKind::INT_NUMBER => LitKind::IntNumber(token),
65            SyntaxKind::BYTE_STRING => LitKind::ByteString(token),
66            SyntaxKind::BIT_STRING => LitKind::BitString(token),
67            SyntaxKind::DOLLAR_QUOTED_STRING => LitKind::DollarQuotedString(token),
68            SyntaxKind::UNICODE_ESC_STRING => LitKind::UnicodeEscString(token),
69            SyntaxKind::ESC_STRING => LitKind::EscString(token),
70            SyntaxKind::POSITIONAL_PARAM => LitKind::PositionalParam(token),
71            _ => return None,
72        };
73        Some(kind)
74    }
75}
76
77impl ast::Constraint {
78    #[inline]
79    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
80        support::child(self.syntax())
81    }
82}
83
84#[derive(Debug, Clone, PartialEq, Eq)]
85pub enum BinOp {
86    And(SyntaxToken),
87    AtTimeZone(ast::AtTimeZone),
88    Caret(SyntaxToken),
89    Collate(SyntaxToken),
90    ColonColon(ast::ColonColon),
91    ColonEq(ast::ColonEq),
92    CustomOp(ast::CustomOp),
93    Eq(SyntaxToken),
94    FatArrow(ast::FatArrow),
95    Gteq(ast::Gteq),
96    Ilike(SyntaxToken),
97    In(SyntaxToken),
98    Is(SyntaxToken),
99    IsDistinctFrom(ast::IsDistinctFrom),
100    IsJson(ast::IsJson),
101    IsJsonArray(ast::IsJsonArray),
102    IsJsonObject(ast::IsJsonObject),
103    IsJsonScalar(ast::IsJsonScalar),
104    IsJsonValue(ast::IsJsonValue),
105    IsNot(ast::IsNot),
106    IsNotDistinctFrom(ast::IsNotDistinctFrom),
107    IsNotJson(ast::IsNotJson),
108    IsNotJsonArray(ast::IsNotJsonArray),
109    IsNotJsonObject(ast::IsNotJsonObject),
110    IsNotJsonScalar(ast::IsNotJsonScalar),
111    IsNotJsonValue(ast::IsNotJsonValue),
112    LAngle(SyntaxToken),
113    Like(SyntaxToken),
114    Lteq(ast::Lteq),
115    Minus(SyntaxToken),
116    Neq(ast::Neq),
117    Neqb(ast::Neqb),
118    NotIlike(ast::NotIlike),
119    NotIn(ast::NotIn),
120    NotLike(ast::NotLike),
121    NotSimilarTo(ast::NotSimilarTo),
122    OperatorCall(ast::OperatorCall),
123    Or(SyntaxToken),
124    Overlaps(SyntaxToken),
125    Percent(SyntaxToken),
126    Plus(SyntaxToken),
127    RAngle(SyntaxToken),
128    SimilarTo(ast::SimilarTo),
129    Slash(SyntaxToken),
130    Star(SyntaxToken),
131}
132
133impl ast::BinExpr {
134    #[inline]
135    pub fn lhs(&self) -> Option<ast::Expr> {
136        support::children(self.syntax()).next()
137    }
138
139    #[inline]
140    pub fn rhs(&self) -> Option<ast::Expr> {
141        support::children(self.syntax()).nth(1)
142    }
143
144    pub fn op(&self) -> Option<BinOp> {
145        let lhs = self.lhs()?;
146        for child in lhs.syntax().siblings_with_tokens(Direction::Next).skip(1) {
147            match child {
148                NodeOrToken::Token(token) => {
149                    let op = match token.kind() {
150                        SyntaxKind::AND_KW => BinOp::And(token),
151                        SyntaxKind::CARET => BinOp::Caret(token),
152                        SyntaxKind::COLLATE_KW => BinOp::Collate(token),
153                        SyntaxKind::EQ => BinOp::Eq(token),
154                        SyntaxKind::ILIKE_KW => BinOp::Ilike(token),
155                        SyntaxKind::IN_KW => BinOp::In(token),
156                        SyntaxKind::IS_KW => BinOp::Is(token),
157                        SyntaxKind::L_ANGLE => BinOp::LAngle(token),
158                        SyntaxKind::LIKE_KW => BinOp::Like(token),
159                        SyntaxKind::MINUS => BinOp::Minus(token),
160                        SyntaxKind::OR_KW => BinOp::Or(token),
161                        SyntaxKind::OVERLAPS_KW => BinOp::Overlaps(token),
162                        SyntaxKind::PERCENT => BinOp::Percent(token),
163                        SyntaxKind::PLUS => BinOp::Plus(token),
164                        SyntaxKind::R_ANGLE => BinOp::RAngle(token),
165                        SyntaxKind::SLASH => BinOp::Slash(token),
166                        SyntaxKind::STAR => BinOp::Star(token),
167                        _ => continue,
168                    };
169                    return Some(op);
170                }
171                NodeOrToken::Node(node) => {
172                    let op = match node.kind() {
173                        SyntaxKind::AT_TIME_ZONE => {
174                            BinOp::AtTimeZone(ast::AtTimeZone { syntax: node })
175                        }
176                        SyntaxKind::COLON_COLON => {
177                            BinOp::ColonColon(ast::ColonColon { syntax: node })
178                        }
179                        SyntaxKind::COLON_EQ => BinOp::ColonEq(ast::ColonEq { syntax: node }),
180                        SyntaxKind::CUSTOM_OP => BinOp::CustomOp(ast::CustomOp { syntax: node }),
181                        SyntaxKind::FAT_ARROW => BinOp::FatArrow(ast::FatArrow { syntax: node }),
182                        SyntaxKind::GTEQ => BinOp::Gteq(ast::Gteq { syntax: node }),
183                        SyntaxKind::IS_DISTINCT_FROM => {
184                            BinOp::IsDistinctFrom(ast::IsDistinctFrom { syntax: node })
185                        }
186                        SyntaxKind::IS_JSON => BinOp::IsJson(ast::IsJson { syntax: node }),
187                        SyntaxKind::IS_JSON_ARRAY => {
188                            BinOp::IsJsonArray(ast::IsJsonArray { syntax: node })
189                        }
190                        SyntaxKind::IS_JSON_OBJECT => {
191                            BinOp::IsJsonObject(ast::IsJsonObject { syntax: node })
192                        }
193                        SyntaxKind::IS_JSON_SCALAR => {
194                            BinOp::IsJsonScalar(ast::IsJsonScalar { syntax: node })
195                        }
196                        SyntaxKind::IS_JSON_VALUE => {
197                            BinOp::IsJsonValue(ast::IsJsonValue { syntax: node })
198                        }
199                        SyntaxKind::IS_NOT => BinOp::IsNot(ast::IsNot { syntax: node }),
200                        SyntaxKind::IS_NOT_DISTINCT_FROM => {
201                            BinOp::IsNotDistinctFrom(ast::IsNotDistinctFrom { syntax: node })
202                        }
203                        SyntaxKind::IS_NOT_JSON => {
204                            BinOp::IsNotJson(ast::IsNotJson { syntax: node })
205                        }
206                        SyntaxKind::IS_NOT_JSON_ARRAY => {
207                            BinOp::IsNotJsonArray(ast::IsNotJsonArray { syntax: node })
208                        }
209                        SyntaxKind::IS_NOT_JSON_OBJECT => {
210                            BinOp::IsNotJsonObject(ast::IsNotJsonObject { syntax: node })
211                        }
212                        SyntaxKind::IS_NOT_JSON_SCALAR => {
213                            BinOp::IsNotJsonScalar(ast::IsNotJsonScalar { syntax: node })
214                        }
215                        SyntaxKind::IS_NOT_JSON_VALUE => {
216                            BinOp::IsNotJsonValue(ast::IsNotJsonValue { syntax: node })
217                        }
218                        SyntaxKind::LTEQ => BinOp::Lteq(ast::Lteq { syntax: node }),
219                        SyntaxKind::NEQ => BinOp::Neq(ast::Neq { syntax: node }),
220                        SyntaxKind::NEQB => BinOp::Neqb(ast::Neqb { syntax: node }),
221                        SyntaxKind::NOT_ILIKE => BinOp::NotIlike(ast::NotIlike { syntax: node }),
222                        SyntaxKind::NOT_IN => BinOp::NotIn(ast::NotIn { syntax: node }),
223                        SyntaxKind::NOT_LIKE => BinOp::NotLike(ast::NotLike { syntax: node }),
224                        SyntaxKind::NOT_SIMILAR_TO => {
225                            BinOp::NotSimilarTo(ast::NotSimilarTo { syntax: node })
226                        }
227                        SyntaxKind::OPERATOR_CALL => {
228                            BinOp::OperatorCall(ast::OperatorCall { syntax: node })
229                        }
230                        SyntaxKind::SIMILAR_TO => BinOp::SimilarTo(ast::SimilarTo { syntax: node }),
231                        _ => continue,
232                    };
233                    return Some(op);
234                }
235            }
236        }
237        None
238    }
239}
240
241impl ast::FieldExpr {
242    // We have NameRef as a variant of Expr which complicates things (and it
243    // might not be worth it).
244    // Rust analyzer doesn't do this so it doesn't have to special case this.
245    #[inline]
246    pub fn base(&self) -> Option<ast::Expr> {
247        support::children(self.syntax()).next()
248    }
249    #[inline]
250    pub fn field(&self) -> Option<ast::NameRef> {
251        support::children(self.syntax()).last()
252    }
253}
254
255impl ast::IndexExpr {
256    #[inline]
257    pub fn base(&self) -> Option<ast::Expr> {
258        support::children(&self.syntax).next()
259    }
260    #[inline]
261    pub fn index(&self) -> Option<ast::Expr> {
262        support::children(&self.syntax).nth(1)
263    }
264}
265
266impl ast::SliceExpr {
267    #[inline]
268    pub fn base(&self) -> Option<ast::Expr> {
269        support::children(&self.syntax).next()
270    }
271
272    #[inline]
273    pub fn start(&self) -> Option<ast::Expr> {
274        // With `select x[1:]`, we have two exprs, `x` and `1`.
275        // We skip over the first one, and then we want the second one, but we
276        // want to make sure we don't choose the end expr if instead we had:
277        // `select x[:1]`
278        let colon = self.colon_token()?;
279        support::children(&self.syntax)
280            .skip(1)
281            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
282    }
283
284    #[inline]
285    pub fn end(&self) -> Option<ast::Expr> {
286        // We want to make sure we get the last expr after the `:` which is the
287        // end of the slice, i.e., `2` in: `select x[:2]`
288        let colon = self.colon_token()?;
289        support::children(&self.syntax)
290            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
291    }
292}
293
294impl ast::RenameColumn {
295    #[inline]
296    pub fn from(&self) -> Option<ast::NameRef> {
297        support::children(&self.syntax).nth(0)
298    }
299    #[inline]
300    pub fn to(&self) -> Option<ast::NameRef> {
301        support::children(&self.syntax).nth(1)
302    }
303}
304
305impl ast::ForeignKeyConstraint {
306    #[inline]
307    pub fn from_columns(&self) -> Option<ast::ColumnList> {
308        support::children(&self.syntax).nth(0)
309    }
310    #[inline]
311    pub fn to_columns(&self) -> Option<ast::ColumnList> {
312        support::children(&self.syntax).nth(1)
313    }
314}
315
316impl ast::BetweenExpr {
317    #[inline]
318    pub fn target(&self) -> Option<ast::Expr> {
319        support::children(&self.syntax).nth(0)
320    }
321    #[inline]
322    pub fn start(&self) -> Option<ast::Expr> {
323        support::children(&self.syntax).nth(1)
324    }
325    #[inline]
326    pub fn end(&self) -> Option<ast::Expr> {
327        support::children(&self.syntax).nth(2)
328    }
329}
330
331impl ast::WhenClause {
332    #[inline]
333    pub fn condition(&self) -> Option<ast::Expr> {
334        support::children(&self.syntax).next()
335    }
336    #[inline]
337    pub fn then(&self) -> Option<ast::Expr> {
338        support::children(&self.syntax).nth(1)
339    }
340}
341
342impl ast::CompoundSelect {
343    #[inline]
344    pub fn lhs(&self) -> Option<ast::SelectVariant> {
345        support::children(&self.syntax).next()
346    }
347    #[inline]
348    pub fn rhs(&self) -> Option<ast::SelectVariant> {
349        support::children(&self.syntax).nth(1)
350    }
351}
352
353impl ast::NameRef {
354    #[inline]
355    pub fn text(&self) -> TokenText<'_> {
356        text_of_first_token(self.syntax())
357    }
358}
359
360impl ast::Name {
361    #[inline]
362    pub fn text(&self) -> TokenText<'_> {
363        text_of_first_token(self.syntax())
364    }
365}
366
367impl ast::CharType {
368    #[inline]
369    pub fn text(&self) -> TokenText<'_> {
370        text_of_first_token(self.syntax())
371    }
372}
373
374impl ast::OpSig {
375    #[inline]
376    pub fn lhs(&self) -> Option<ast::Type> {
377        support::children(self.syntax()).next()
378    }
379
380    #[inline]
381    pub fn rhs(&self) -> Option<ast::Type> {
382        support::children(self.syntax()).nth(1)
383    }
384}
385
386impl ast::CastSig {
387    #[inline]
388    pub fn lhs(&self) -> Option<ast::Type> {
389        support::children(self.syntax()).next()
390    }
391
392    #[inline]
393    pub fn rhs(&self) -> Option<ast::Type> {
394        support::children(self.syntax()).nth(1)
395    }
396}
397
398pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
399    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
400        green_ref
401            .children()
402            .next()
403            .and_then(NodeOrToken::into_token)
404            .unwrap()
405    }
406
407    match node.green() {
408        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
409        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
410    }
411}
412
413impl ast::WithQuery {
414    #[inline]
415    pub fn with_clause(&self) -> Option<ast::WithClause> {
416        support::child(self.syntax())
417    }
418}
419
420impl ast::HasParamList for ast::FunctionSig {}
421impl ast::HasParamList for ast::Aggregate {}
422
423impl ast::NameLike for ast::Name {}
424impl ast::NameLike for ast::NameRef {}
425
426impl ast::HasWithClause for ast::Select {}
427impl ast::HasWithClause for ast::SelectInto {}
428impl ast::HasWithClause for ast::Insert {}
429impl ast::HasWithClause for ast::Update {}
430impl ast::HasWithClause for ast::Delete {}
431
432impl ast::HasCreateTable for ast::CreateTable {}
433impl ast::HasCreateTable for ast::CreateForeignTable {}
434impl ast::HasCreateTable for ast::CreateTableLike {}
435
/// `foo[bar]` exposes `foo` as the base and `bar` as the index.
#[test]
fn index_expr() {
    let sql = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let tree: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = tree.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let ast::Expr::IndexExpr(indexed) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (base, index) = (indexed.base().unwrap(), indexed.index().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(index.syntax().text(), "bar");
}
463
/// Slice bounds are reported correctly whether both, one, or neither are
/// present.
#[test]
fn slice_expr() {
    use insta::assert_snapshot;
    let sql = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let tree: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = tree.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let mut targets = target_list.targets();
    // Pulls the next select target, which must be a slice expression.
    let mut next_slice = || match targets.next().unwrap().expr().unwrap() {
        ast::Expr::SliceExpr(slice) => slice,
        _ => unreachable!(),
    };

    // Both bounds.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Lower bound only.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Upper bound only.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // Neither bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
512
/// `foo.bar` exposes `foo` as the base and `bar` as the field.
#[test]
fn field_expr() {
    let sql = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let tree: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = tree.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let ast::Expr::FieldExpr(access) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (base, field) = (access.base().unwrap(), access.field().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(field.syntax().text(), "bar");
}
540
/// `2 between 1 and 3` exposes the tested value and both bounds.
#[test]
fn between_expr() {
    let sql = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let tree: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = tree.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let ast::Expr::BetweenExpr(between) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (tested, low, high) = (
        between.target().unwrap(),
        between.start().unwrap(),
        between.end().unwrap(),
    );
    assert_eq!(tested.syntax().text(), "2");
    assert_eq!(low.syntax().text(), "1");
    assert_eq!(high.syntax().text(), "3");
}
570
/// Every cast spelling (`cast(.. as ..)`, `type 'lit'`, `::`, `interval`)
/// exposes the cast operand and the target type.
#[test]
fn cast_expr() {
    use insta::assert_snapshot;

    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");

    /// Parses `sql`, which must be a single error-free `select`, and returns
    /// its first target expression as a `CastExpr`.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
            unreachable!()
        };
        // `expr()` already returns an owned `Expr`, so no clone is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }
}
669
/// The operator signature in `alter operator` exposes both operand types.
#[test]
fn op_sig() {
    let sql = "\n      alter operator p.+ (int4, int8) \n        owner to u;\n    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let tree: SourceFile = parse.tree();

    let ast::Stmt::AlterOperator(alter_operator) = tree.stmts().next().unwrap() else {
        unreachable!()
    };
    let sig = alter_operator.op_sig().unwrap();
    let (lhs, rhs) = (sig.lhs().unwrap(), sig.rhs().unwrap());
    assert_snapshot!(lhs.syntax().text(), @"int4");
    assert_snapshot!(rhs.syntax().text(), @"int8");
}
689
/// The cast signature in `drop cast` exposes the source and target types.
#[test]
fn cast_sig() {
    let sql = "
      drop cast (text as int);
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let tree: SourceFile = parse.tree();

    let ast::Stmt::DropCast(drop_cast) = tree.stmts().next().unwrap() else {
        unreachable!()
    };
    let sig = drop_cast.cast_sig().unwrap();
    let (lhs, rhs) = (sig.lhs().unwrap(), sig.rhs().unwrap());
    assert_snapshot!(lhs.syntax().text(), @"text");
    assert_snapshot!(rhs.syntax().text(), @"int");
}