squawk_syntax/ast/
node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use crate::ast;
36use crate::ast::AstNode;
37use crate::{SyntaxNode, TokenText};
38
39use super::support;
40
41impl ast::Constraint {
42    #[inline]
43    pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
44        support::child(self.syntax())
45    }
46}
47
48impl ast::BinExpr {
49    #[inline]
50    pub fn lhs(&self) -> Option<ast::Expr> {
51        support::children(self.syntax()).next()
52    }
53
54    #[inline]
55    pub fn rhs(&self) -> Option<ast::Expr> {
56        support::children(self.syntax()).nth(1)
57    }
58}
59
60impl ast::FieldExpr {
61    // We have NameRef as a variant of Expr which complicates things (and it
62    // might not be worth it).
63    // Rust analyzer doesn't do this so it doesn't have to special case this.
64    #[inline]
65    pub fn base(&self) -> Option<ast::Expr> {
66        support::children(self.syntax()).next()
67    }
68    #[inline]
69    pub fn field(&self) -> Option<ast::NameRef> {
70        support::children(self.syntax()).last()
71    }
72}
73
74impl ast::IndexExpr {
75    #[inline]
76    pub fn base(&self) -> Option<ast::Expr> {
77        support::children(&self.syntax).next()
78    }
79    #[inline]
80    pub fn index(&self) -> Option<ast::Expr> {
81        support::children(&self.syntax).nth(1)
82    }
83}
84
85impl ast::SliceExpr {
86    #[inline]
87    pub fn base(&self) -> Option<ast::Expr> {
88        support::children(&self.syntax).next()
89    }
90
91    #[inline]
92    pub fn start(&self) -> Option<ast::Expr> {
93        // With `select x[1:]`, we have two exprs, `x` and `1`.
94        // We skip over the first one, and then we want the second one, but we
95        // want to make sure we don't choose the end expr if instead we had:
96        // `select x[:1]`
97        let colon = self.colon_token()?;
98        support::children(&self.syntax)
99            .skip(1)
100            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
101    }
102
103    #[inline]
104    pub fn end(&self) -> Option<ast::Expr> {
105        // We want to make sure we get the last expr after the `:` which is the
106        // end of the slice, i.e., `2` in: `select x[:2]`
107        let colon = self.colon_token()?;
108        support::children(&self.syntax)
109            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
110    }
111}
112
113impl ast::RenameColumn {
114    #[inline]
115    pub fn from(&self) -> Option<ast::NameRef> {
116        support::children(&self.syntax).nth(0)
117    }
118    #[inline]
119    pub fn to(&self) -> Option<ast::NameRef> {
120        support::children(&self.syntax).nth(1)
121    }
122}
123
124impl ast::ForeignKeyConstraint {
125    #[inline]
126    pub fn from_columns(&self) -> Option<ast::ColumnList> {
127        support::children(&self.syntax).nth(0)
128    }
129    #[inline]
130    pub fn to_columns(&self) -> Option<ast::ColumnList> {
131        support::children(&self.syntax).nth(1)
132    }
133}
134
135impl ast::BetweenExpr {
136    #[inline]
137    pub fn target(&self) -> Option<ast::Expr> {
138        support::children(&self.syntax).nth(0)
139    }
140    #[inline]
141    pub fn start(&self) -> Option<ast::Expr> {
142        support::children(&self.syntax).nth(1)
143    }
144    #[inline]
145    pub fn end(&self) -> Option<ast::Expr> {
146        support::children(&self.syntax).nth(2)
147    }
148}
149
150impl ast::WhenClause {
151    #[inline]
152    pub fn condition(&self) -> Option<ast::Expr> {
153        support::children(&self.syntax).next()
154    }
155    #[inline]
156    pub fn then(&self) -> Option<ast::Expr> {
157        support::children(&self.syntax).nth(1)
158    }
159}
160
161impl ast::CompoundSelect {
162    #[inline]
163    pub fn lhs(&self) -> Option<ast::SelectVariant> {
164        support::children(&self.syntax).next()
165    }
166    #[inline]
167    pub fn rhs(&self) -> Option<ast::SelectVariant> {
168        support::children(&self.syntax).nth(1)
169    }
170}
171
172impl ast::NameRef {
173    #[inline]
174    pub fn text(&self) -> TokenText<'_> {
175        text_of_first_token(self.syntax())
176    }
177}
178
179impl ast::Name {
180    #[inline]
181    pub fn text(&self) -> TokenText<'_> {
182        text_of_first_token(self.syntax())
183    }
184}
185
186impl ast::CharType {
187    #[inline]
188    pub fn text(&self) -> TokenText<'_> {
189        text_of_first_token(self.syntax())
190    }
191}
192
193impl ast::OpSig {
194    #[inline]
195    pub fn lhs(&self) -> Option<ast::Type> {
196        support::children(self.syntax()).next()
197    }
198
199    #[inline]
200    pub fn rhs(&self) -> Option<ast::Type> {
201        support::children(self.syntax()).nth(1)
202    }
203}
204
205impl ast::CastSig {
206    #[inline]
207    pub fn lhs(&self) -> Option<ast::Type> {
208        support::children(self.syntax()).next()
209    }
210
211    #[inline]
212    pub fn rhs(&self) -> Option<ast::Type> {
213        support::children(self.syntax()).nth(1)
214    }
215}
216
217pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
218    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
219        green_ref
220            .children()
221            .next()
222            .and_then(NodeOrToken::into_token)
223            .unwrap()
224    }
225
226    match node.green() {
227        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
228        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
229    }
230}
231
232impl ast::WithQuery {
233    #[inline]
234    pub fn with_clause(&self) -> Option<ast::WithClause> {
235        support::child(self.syntax())
236    }
237}
238
239impl ast::HasParamList for ast::FunctionSig {}
240impl ast::HasParamList for ast::Aggregate {}
241
242impl ast::NameLike for ast::Name {}
243impl ast::NameLike for ast::NameRef {}
244
245impl ast::HasWithClause for ast::Select {}
246impl ast::HasWithClause for ast::SelectInto {}
247impl ast::HasWithClause for ast::Insert {}
248impl ast::HasWithClause for ast::Update {}
249impl ast::HasWithClause for ast::Delete {}
250
251impl ast::HasCreateTable for ast::CreateTable {}
252impl ast::HasCreateTable for ast::CreateForeignTable {}
253impl ast::HasCreateTable for ast::CreateTableLike {}
254
/// `IndexExpr::base` and `IndexExpr::index` pick out `foo` and `bar` from
/// `foo[bar]`.
#[test]
fn index_expr() {
    let sql = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file = parse.tree();

    // The only statement must be a select.
    let select = match file.stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let targets = select.select_clause().unwrap().target_list().unwrap();
    let first_target = targets.targets().next().unwrap();
    let index_expr = match first_target.expr().unwrap() {
        ast::Expr::IndexExpr(index_expr) => index_expr,
        _ => unreachable!(),
    };

    let base = index_expr.base().unwrap();
    let index = index_expr.index().unwrap();
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(index.syntax().text(), "bar");
}
282
/// `SliceExpr::{base, start, end}` handle every combination of present and
/// missing bounds: `x[1:2]`, `x[2:]`, `x[:3]` and `x[:]`.
#[test]
fn slice_expr() {
    // `assert_snapshot!` comes from the file-level `#[cfg(test)]` import, as
    // in the other tests of this file; no function-local `use` is needed.
    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let select_clause = select.select_clause().unwrap();
    let mut targets = select_clause.target_list().unwrap().targets();

    // Pulls the next select target and requires it to be a slice expression.
    let mut next_slice = || match targets.next().unwrap().expr().unwrap() {
        ast::Expr::SliceExpr(slice) => slice,
        _ => unreachable!(),
    };

    // Both bounds present.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Only the lower bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Only the upper bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // Neither bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
331
/// `FieldExpr::base` and `FieldExpr::field` pick out `foo` and `bar` from
/// `foo.bar`.
#[test]
fn field_expr() {
    let sql = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file = parse.tree();

    // The only statement must be a select.
    let select = match file.stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let targets = select.select_clause().unwrap().target_list().unwrap();
    let first_target = targets.targets().next().unwrap();
    let field_expr = match first_target.expr().unwrap() {
        ast::Expr::FieldExpr(field_expr) => field_expr,
        _ => unreachable!(),
    };

    let base = field_expr.base().unwrap();
    let field = field_expr.field().unwrap();
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(field.syntax().text(), "bar");
}
359
/// `BetweenExpr::{target, start, end}` pick out `2`, `1` and `3` from
/// `2 between 1 and 3`.
#[test]
fn between_expr() {
    let sql = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file = parse.tree();

    // The only statement must be a select.
    let select = match file.stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let targets = select.select_clause().unwrap().target_list().unwrap();
    let first_target = targets.targets().next().unwrap();
    let between = match first_target.expr().unwrap() {
        ast::Expr::BetweenExpr(between) => between,
        _ => unreachable!(),
    };

    let tested = between.target().unwrap();
    let lower = between.start().unwrap();
    let upper = between.end().unwrap();
    assert_eq!(tested.syntax().text(), "2");
    assert_eq!(lower.syntax().text(), "1");
    assert_eq!(upper.syntax().text(), "3");
}
389
/// `CastExpr::expr` and `CastExpr::ty` return the casted expression and the
/// target type for every cast spelling: `cast(.. as ..)`, prefix-type
/// literals (`int '123'`), the `::` operator, and `interval` literals.
#[test]
fn cast_expr() {
    // `assert_snapshot!` comes from the file-level `#[cfg(test)]` import, as
    // in the other tests of this file; no function-local `use` is needed.

    /// Parses `sql` (which must be error-free), takes the first target of
    /// the first statement (which must be a select) and returns it as a cast
    /// expression.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
            unreachable!()
        };
        // `expr()` already hands back an owned node, so no clone is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }

    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");
}
488
/// `OpSig::lhs` and `OpSig::rhs` pick out `int4` and `int8` from the operator
/// signature `(int4, int8)`.
#[test]
fn op_sig() {
    let sql = "
      alter operator p.+ (int4, int8) 
        owner to u;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file = parse.tree();

    // The only statement must be an `alter operator`.
    let alter_op = match file.stmts().next().unwrap() {
        ast::Stmt::AlterOperator(alter_op) => alter_op,
        _ => unreachable!(),
    };
    let signature = alter_op.op_sig().unwrap();

    let left = signature.lhs().unwrap();
    let right = signature.rhs().unwrap();
    assert_snapshot!(left.syntax().text(), @"int4");
    assert_snapshot!(right.syntax().text(), @"int8");
}
508
/// `CastSig::lhs` and `CastSig::rhs` pick out `text` and `int` from the cast
/// signature `(text as int)`.
#[test]
fn cast_sig() {
    let sql = "
      drop cast (text as int);
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file = parse.tree();

    // The only statement must be a `drop cast`.
    let drop_cast = match file.stmts().next().unwrap() {
        ast::Stmt::DropCast(drop_cast) => drop_cast,
        _ => unreachable!(),
    };
    let signature = drop_cast.cast_sig().unwrap();

    let left = signature.lhs().unwrap();
    let right = signature.rhs().unwrap();
    assert_snapshot!(left.syntax().text(), @"text");
    assert_snapshot!(right.syntax().text(), @"int");
}