squawk_syntax/ast/node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use crate::ast;
36use crate::ast::AstNode;
37use crate::{SyntaxNode, TokenText};
38
39use super::support;
40
41impl ast::Constraint {
42    #[inline]
43    pub fn name(&self) -> Option<ast::Name> {
44        support::child(self.syntax())
45    }
46}
47
48impl ast::BinExpr {
49    #[inline]
50    pub fn lhs(&self) -> Option<ast::Expr> {
51        support::children(self.syntax()).next()
52    }
53
54    #[inline]
55    pub fn rhs(&self) -> Option<ast::Expr> {
56        support::children(self.syntax()).nth(1)
57    }
58}
59
60impl ast::FieldExpr {
61    // We have NameRef as a variant of Expr which complicates things (and it
62    // might not be worth it).
63    // Rust analyzer doesn't do this so it doesn't have to special case this.
64    #[inline]
65    pub fn base(&self) -> Option<ast::Expr> {
66        support::children(self.syntax()).next()
67    }
68    #[inline]
69    pub fn field(&self) -> Option<ast::NameRef> {
70        support::children(self.syntax()).last()
71    }
72}
73
74impl ast::IndexExpr {
75    #[inline]
76    pub fn base(&self) -> Option<ast::Expr> {
77        support::children(&self.syntax).next()
78    }
79    #[inline]
80    pub fn index(&self) -> Option<ast::Expr> {
81        support::children(&self.syntax).nth(1)
82    }
83}
84
85impl ast::SliceExpr {
86    #[inline]
87    pub fn base(&self) -> Option<ast::Expr> {
88        support::children(&self.syntax).next()
89    }
90
91    #[inline]
92    pub fn start(&self) -> Option<ast::Expr> {
93        // With `select x[1:]`, we have two exprs, `x` and `1`.
94        // We skip over the first one, and then we want the second one, but we
95        // want to make sure we don't choose the end expr if instead we had:
96        // `select x[:1]`
97        let colon = self.colon_token()?;
98        support::children(&self.syntax)
99            .skip(1)
100            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
101    }
102
103    #[inline]
104    pub fn end(&self) -> Option<ast::Expr> {
105        // We want to make sure we get the last expr after the `:` which is the
106        // end of the slice, i.e., `2` in: `select x[:2]`
107        let colon = self.colon_token()?;
108        support::children(&self.syntax)
109            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
110    }
111}
112
113impl ast::RenameColumn {
114    #[inline]
115    pub fn from(&self) -> Option<ast::NameRef> {
116        support::children(&self.syntax).nth(0)
117    }
118    #[inline]
119    pub fn to(&self) -> Option<ast::NameRef> {
120        support::children(&self.syntax).nth(1)
121    }
122}
123
124impl ast::ForeignKeyConstraint {
125    #[inline]
126    pub fn from_columns(&self) -> Option<ast::ColumnList> {
127        support::children(&self.syntax).nth(0)
128    }
129    #[inline]
130    pub fn to_columns(&self) -> Option<ast::ColumnList> {
131        support::children(&self.syntax).nth(1)
132    }
133}
134
135impl ast::BetweenExpr {
136    #[inline]
137    pub fn target(&self) -> Option<ast::Expr> {
138        support::children(&self.syntax).nth(0)
139    }
140    #[inline]
141    pub fn start(&self) -> Option<ast::Expr> {
142        support::children(&self.syntax).nth(1)
143    }
144    #[inline]
145    pub fn end(&self) -> Option<ast::Expr> {
146        support::children(&self.syntax).nth(2)
147    }
148}
149
150impl ast::WhenClause {
151    #[inline]
152    pub fn condition(&self) -> Option<ast::Expr> {
153        support::children(&self.syntax).next()
154    }
155    #[inline]
156    pub fn then(&self) -> Option<ast::Expr> {
157        support::children(&self.syntax).nth(1)
158    }
159}
160
161impl ast::CompoundSelect {
162    #[inline]
163    pub fn lhs(&self) -> Option<ast::SelectVariant> {
164        support::children(&self.syntax).next()
165    }
166    #[inline]
167    pub fn rhs(&self) -> Option<ast::SelectVariant> {
168        support::children(&self.syntax).nth(1)
169    }
170}
171
172impl ast::NameRef {
173    #[inline]
174    pub fn text(&self) -> TokenText<'_> {
175        text_of_first_token(self.syntax())
176    }
177}
178
179impl ast::Name {
180    #[inline]
181    pub fn text(&self) -> TokenText<'_> {
182        text_of_first_token(self.syntax())
183    }
184}
185
186impl ast::CharType {
187    #[inline]
188    pub fn text(&self) -> TokenText<'_> {
189        text_of_first_token(self.syntax())
190    }
191}
192
193impl ast::OpSig {
194    #[inline]
195    pub fn lhs(&self) -> Option<ast::Type> {
196        support::children(self.syntax()).next()
197    }
198
199    #[inline]
200    pub fn rhs(&self) -> Option<ast::Type> {
201        support::children(self.syntax()).nth(1)
202    }
203}
204
205impl ast::CastSig {
206    #[inline]
207    pub fn lhs(&self) -> Option<ast::Type> {
208        support::children(self.syntax()).next()
209    }
210
211    #[inline]
212    pub fn rhs(&self) -> Option<ast::Type> {
213        support::children(self.syntax()).nth(1)
214    }
215}
216
217pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
218    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
219        green_ref
220            .children()
221            .next()
222            .and_then(NodeOrToken::into_token)
223            .unwrap()
224    }
225
226    match node.green() {
227        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
228        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
229    }
230}
231
232impl ast::HasParamList for ast::FunctionSig {}
233impl ast::HasParamList for ast::Aggregate {}
234impl ast::NameLike for ast::Name {}
235impl ast::NameLike for ast::NameRef {}
236
#[test]
fn index_expr() {
    // `foo[bar]` should expose `foo` as the base and `bar` as the index.
    let source_code = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::IndexExpr(index_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };
    let (base, index) = (index_expr.base().unwrap(), index_expr.index().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(index.syntax().text(), "bar");
}
264
#[test]
fn slice_expr() {
    // Exercises every slice shape: both bounds, start only, end only, and neither.
    // `assert_snapshot` is already in scope via the file-level test import, so
    // the redundant function-local `use` is gone.
    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let select_clause = select.select_clause().unwrap();
    let mut targets = select_clause.target_list().unwrap().targets();

    // Pulls the next select target and asserts that it parsed as a slice.
    let mut next_slice = || {
        let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
            unreachable!()
        };
        slice
    };

    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
313
#[test]
fn field_expr() {
    // `foo.bar` should split into the base `foo` and the field `bar`.
    let source_code = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::FieldExpr(field_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };
    let (base, field) = (field_expr.base().unwrap(), field_expr.field().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(field.syntax().text(), "bar");
}
341
#[test]
fn between_expr() {
    // `2 between 1 and 3` should map to target `2`, start `1`, and end `3`.
    let source_code = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::BetweenExpr(between) = first_target.expr().unwrap() else {
        unreachable!()
    };
    let (tested, low, high) = (
        between.target().unwrap(),
        between.start().unwrap(),
        between.end().unwrap(),
    );
    assert_eq!(tested.syntax().text(), "2");
    assert_eq!(low.syntax().text(), "1");
    assert_eq!(high.syntax().text(), "3");
}
371
#[test]
fn cast_expr() {
    // Every spelling of a cast (`cast(.. as ..)`, `type 'lit'`, `'lit'::type`,
    // `interval '..' unit`) should expose the same `expr`/`ty` pair.
    // `assert_snapshot` comes from the file-level test import.
    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");

    /// Parses `sql`, which must be a single error-free `select`, and returns
    /// its first target expression as a `CastExpr`.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
            unreachable!()
        };
        // `expr()` already returns an owned node, so no `.clone()` is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }
}
470
#[test]
fn op_sig() {
    // The two operand types in `(int4, int8)` map to `lhs` and `rhs` in order.
    let source_code = "
      alter operator p.+ (int4, int8) 
        owner to u;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::AlterOperator(alter_operator) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let sig = alter_operator.op_sig().unwrap();
    let (left, right) = (sig.lhs().unwrap(), sig.rhs().unwrap());
    assert_snapshot!(left.syntax().text(), @"int4");
    assert_snapshot!(right.syntax().text(), @"int8");
}
490
#[test]
fn cast_sig() {
    // `(text as int)` should map to source type `text` and target type `int`.
    let source_code = "
      drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    // Named for what it is: a `DROP CAST` statement, not an `ALTER OPERATOR`
    // (the old `alter_op` name was copied from the `op_sig` test).
    let ast::Stmt::DropCast(drop_cast) = stmt else {
        unreachable!()
    };
    let cast_sig = drop_cast.cast_sig().unwrap();
    let lhs = cast_sig.lhs().unwrap();
    let rhs = cast_sig.rhs().unwrap();
    assert_snapshot!(lhs.syntax().text(), @"text");
    assert_snapshot!(rhs.syntax().text(), @"int");
}