squawk_syntax/ast/
node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use crate::ast;
36use crate::ast::AstNode;
37use crate::{SyntaxNode, TokenText};
38
39use super::support;
40
41impl ast::Constraint {
42    #[inline]
43    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
44        support::child(self.syntax())
45    }
46}
47
48impl ast::BinExpr {
49    #[inline]
50    pub fn lhs(&self) -> Option<ast::Expr> {
51        support::children(self.syntax()).next()
52    }
53
54    #[inline]
55    pub fn rhs(&self) -> Option<ast::Expr> {
56        support::children(self.syntax()).nth(1)
57    }
58}
59
60impl ast::FieldExpr {
61    // We have NameRef as a variant of Expr which complicates things (and it
62    // might not be worth it).
63    // Rust analyzer doesn't do this so it doesn't have to special case this.
64    #[inline]
65    pub fn base(&self) -> Option<ast::Expr> {
66        support::children(self.syntax()).next()
67    }
68    #[inline]
69    pub fn field(&self) -> Option<ast::NameRef> {
70        support::children(self.syntax()).last()
71    }
72}
73
74impl ast::IndexExpr {
75    #[inline]
76    pub fn base(&self) -> Option<ast::Expr> {
77        support::children(&self.syntax).next()
78    }
79    #[inline]
80    pub fn index(&self) -> Option<ast::Expr> {
81        support::children(&self.syntax).nth(1)
82    }
83}
84
85impl ast::SliceExpr {
86    #[inline]
87    pub fn base(&self) -> Option<ast::Expr> {
88        support::children(&self.syntax).next()
89    }
90
91    #[inline]
92    pub fn start(&self) -> Option<ast::Expr> {
93        // With `select x[1:]`, we have two exprs, `x` and `1`.
94        // We skip over the first one, and then we want the second one, but we
95        // want to make sure we don't choose the end expr if instead we had:
96        // `select x[:1]`
97        let colon = self.colon_token()?;
98        support::children(&self.syntax)
99            .skip(1)
100            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
101    }
102
103    #[inline]
104    pub fn end(&self) -> Option<ast::Expr> {
105        // We want to make sure we get the last expr after the `:` which is the
106        // end of the slice, i.e., `2` in: `select x[:2]`
107        let colon = self.colon_token()?;
108        support::children(&self.syntax)
109            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
110    }
111}
112
113impl ast::RenameColumn {
114    #[inline]
115    pub fn from(&self) -> Option<ast::NameRef> {
116        support::children(&self.syntax).nth(0)
117    }
118    #[inline]
119    pub fn to(&self) -> Option<ast::NameRef> {
120        support::children(&self.syntax).nth(1)
121    }
122}
123
124impl ast::ForeignKeyConstraint {
125    #[inline]
126    pub fn from_columns(&self) -> Option<ast::ColumnList> {
127        support::children(&self.syntax).nth(0)
128    }
129    #[inline]
130    pub fn to_columns(&self) -> Option<ast::ColumnList> {
131        support::children(&self.syntax).nth(1)
132    }
133}
134
135impl ast::BetweenExpr {
136    #[inline]
137    pub fn target(&self) -> Option<ast::Expr> {
138        support::children(&self.syntax).nth(0)
139    }
140    #[inline]
141    pub fn start(&self) -> Option<ast::Expr> {
142        support::children(&self.syntax).nth(1)
143    }
144    #[inline]
145    pub fn end(&self) -> Option<ast::Expr> {
146        support::children(&self.syntax).nth(2)
147    }
148}
149
150impl ast::WhenClause {
151    #[inline]
152    pub fn condition(&self) -> Option<ast::Expr> {
153        support::children(&self.syntax).next()
154    }
155    #[inline]
156    pub fn then(&self) -> Option<ast::Expr> {
157        support::children(&self.syntax).nth(1)
158    }
159}
160
161impl ast::CompoundSelect {
162    #[inline]
163    pub fn lhs(&self) -> Option<ast::SelectVariant> {
164        support::children(&self.syntax).next()
165    }
166    #[inline]
167    pub fn rhs(&self) -> Option<ast::SelectVariant> {
168        support::children(&self.syntax).nth(1)
169    }
170}
171
172impl ast::NameRef {
173    #[inline]
174    pub fn text(&self) -> TokenText<'_> {
175        text_of_first_token(self.syntax())
176    }
177}
178
179impl ast::Name {
180    #[inline]
181    pub fn text(&self) -> TokenText<'_> {
182        text_of_first_token(self.syntax())
183    }
184}
185
186impl ast::CharType {
187    #[inline]
188    pub fn text(&self) -> TokenText<'_> {
189        text_of_first_token(self.syntax())
190    }
191}
192
193impl ast::OpSig {
194    #[inline]
195    pub fn lhs(&self) -> Option<ast::Type> {
196        support::children(self.syntax()).next()
197    }
198
199    #[inline]
200    pub fn rhs(&self) -> Option<ast::Type> {
201        support::children(self.syntax()).nth(1)
202    }
203}
204
205impl ast::CastSig {
206    #[inline]
207    pub fn lhs(&self) -> Option<ast::Type> {
208        support::children(self.syntax()).next()
209    }
210
211    #[inline]
212    pub fn rhs(&self) -> Option<ast::Type> {
213        support::children(self.syntax()).nth(1)
214    }
215}
216
217pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
218    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
219        green_ref
220            .children()
221            .next()
222            .and_then(NodeOrToken::into_token)
223            .unwrap()
224    }
225
226    match node.green() {
227        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
228        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
229    }
230}
231
232impl ast::HasParamList for ast::FunctionSig {}
233impl ast::HasParamList for ast::Aggregate {}
234
235impl ast::NameLike for ast::Name {}
236impl ast::NameLike for ast::NameRef {}
237
238impl ast::HasWithClause for ast::Select {}
239impl ast::HasWithClause for ast::SelectInto {}
240impl ast::HasWithClause for ast::Insert {}
241impl ast::HasWithClause for ast::Update {}
242impl ast::HasWithClause for ast::Delete {}
243
244impl ast::HasCreateTable for ast::CreateTable {}
245impl ast::HasCreateTable for ast::CreateForeignTable {}
246impl ast::HasCreateTable for ast::CreateTableLike {}
247
#[test]
fn index_expr() {
    // `foo[bar]` exposes `foo` as the base and `bar` as the subscript.
    let sql = "
        select foo[bar];
    ";
    let parsed = SourceFile::parse(sql);
    assert!(parsed.errors().is_empty());
    let tree: SourceFile = parsed.tree();

    let first_stmt = tree.stmts().next().unwrap();
    let ast::Stmt::Select(select) = first_stmt else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::IndexExpr(indexed) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (base, subscript) = (indexed.base().unwrap(), indexed.index().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(subscript.syntax().text(), "bar");
}
275
#[test]
fn slice_expr() {
    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let select_clause = select.select_clause().unwrap();
    let mut targets = select_clause.target_list().unwrap().targets();
    // Every target in the statement above is a slice; hand them out in order.
    let mut next_slice = || {
        let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
            unreachable!()
        };
        slice
    };

    // Both bounds present.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Only the lower bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Only the upper bound: `3` must not be reported as the start.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // No bounds at all.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
324
#[test]
fn field_expr() {
    // `foo.bar` exposes `foo` as the base and `bar` as the accessed field.
    let sql = "
        select foo.bar;
    ";
    let parsed = SourceFile::parse(sql);
    assert!(parsed.errors().is_empty());
    let tree: SourceFile = parsed.tree();

    let first_stmt = tree.stmts().next().unwrap();
    let ast::Stmt::Select(select) = first_stmt else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::FieldExpr(field_access) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (base, field) = (field_access.base().unwrap(), field_access.field().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(field.syntax().text(), "bar");
}
352
#[test]
fn between_expr() {
    // Operands come back in source order: `target between start and end`.
    let sql = "
        select 2 between 1 and 3;
    ";
    let parsed = SourceFile::parse(sql);
    assert!(parsed.errors().is_empty());
    let tree: SourceFile = parsed.tree();

    let first_stmt = tree.stmts().next().unwrap();
    let ast::Stmt::Select(select) = first_stmt else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::BetweenExpr(between) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (target, start, end) = (
        between.target().unwrap(),
        between.start().unwrap(),
        between.end().unwrap(),
    );
    assert_eq!(target.syntax().text(), "2");
    assert_eq!(start.syntax().text(), "1");
    assert_eq!(end.syntax().text(), "3");
}
382
#[test]
fn cast_expr() {
    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");

    /// Parses `sql`, asserting it has no errors, and returns the first select
    /// target of its first statement as a `CastExpr`. Panics otherwise.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
            unreachable!()
        };
        // `expr()` already yields an owned node, so no clone is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }
}
481
#[test]
fn op_sig() {
    // `(int4, int8)` is the operator signature: left and right operand types.
    let sql = "
      alter operator p.+ (int4, int8) 
        owner to u;
    ";
    let parsed = SourceFile::parse(sql);
    assert!(parsed.errors().is_empty());
    let tree: SourceFile = parsed.tree();

    let first_stmt = tree.stmts().next().unwrap();
    let ast::Stmt::AlterOperator(alter_operator) = first_stmt else {
        unreachable!()
    };
    let signature = alter_operator.op_sig().unwrap();
    let (lhs, rhs) = (signature.lhs().unwrap(), signature.rhs().unwrap());
    assert_snapshot!(lhs.syntax().text(), @"int4");
    assert_snapshot!(rhs.syntax().text(), @"int8");
}
501
#[test]
fn cast_sig() {
    let source_code = "
      drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    let ast::Stmt::DropCast(drop_cast) = stmt else {
        unreachable!()
    };
    // `(text as int)`: source type on the left, target type on the right.
    let cast_sig = drop_cast.cast_sig().unwrap();
    let lhs = cast_sig.lhs().unwrap();
    let rhs = cast_sig.rhs().unwrap();
    assert_snapshot!(lhs.syntax().text(), @"text");
    assert_snapshot!(rhs.syntax().text(), @"int");
}