squawk_syntax/ast/
node_ext.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/ast/node_ext.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use crate::ast;
36use crate::ast::AstNode;
37use crate::{SyntaxNode, TokenText};
38
39use super::support;
40
41impl ast::Constraint {
42    #[inline]
43    pub fn name(&self) -> Option<ast::Name> {
44        support::child(self.syntax())
45    }
46}
47
48impl ast::BinExpr {
49    #[inline]
50    pub fn lhs(&self) -> Option<ast::Expr> {
51        support::children(self.syntax()).next()
52    }
53
54    #[inline]
55    pub fn rhs(&self) -> Option<ast::Expr> {
56        support::children(self.syntax()).nth(1)
57    }
58}
59
60impl ast::FieldExpr {
61    // We have NameRef as a variant of Expr which complicates things (and it
62    // might not be worth it).
63    // Rust analyzer doesn't do this so it doesn't have to special case this.
64    #[inline]
65    pub fn base(&self) -> Option<ast::Expr> {
66        support::children(self.syntax()).next()
67    }
68    #[inline]
69    pub fn field(&self) -> Option<ast::NameRef> {
70        support::children(self.syntax()).last()
71    }
72}
73
74impl ast::IndexExpr {
75    #[inline]
76    pub fn base(&self) -> Option<ast::Expr> {
77        support::children(&self.syntax).next()
78    }
79    #[inline]
80    pub fn index(&self) -> Option<ast::Expr> {
81        support::children(&self.syntax).nth(1)
82    }
83}
84
85impl ast::SliceExpr {
86    #[inline]
87    pub fn base(&self) -> Option<ast::Expr> {
88        support::children(&self.syntax).next()
89    }
90
91    #[inline]
92    pub fn start(&self) -> Option<ast::Expr> {
93        // With `select x[1:]`, we have two exprs, `x` and `1`.
94        // We skip over the first one, and then we want the second one, but we
95        // want to make sure we don't choose the end expr if instead we had:
96        // `select x[:1]`
97        let colon = self.colon_token()?;
98        support::children(&self.syntax)
99            .skip(1)
100            .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
101    }
102
103    #[inline]
104    pub fn end(&self) -> Option<ast::Expr> {
105        // We want to make sure we get the last expr after the `:` which is the
106        // end of the slice, i.e., `2` in: `select x[:2]`
107        let colon = self.colon_token()?;
108        support::children(&self.syntax)
109            .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
110    }
111}
112
113impl ast::RenameColumn {
114    #[inline]
115    pub fn from(&self) -> Option<ast::NameRef> {
116        support::children(&self.syntax).nth(0)
117    }
118    #[inline]
119    pub fn to(&self) -> Option<ast::NameRef> {
120        support::children(&self.syntax).nth(1)
121    }
122}
123
124impl ast::ForeignKeyConstraint {
125    #[inline]
126    pub fn from_columns(&self) -> Option<ast::ColumnList> {
127        support::children(&self.syntax).nth(0)
128    }
129    #[inline]
130    pub fn to_columns(&self) -> Option<ast::ColumnList> {
131        support::children(&self.syntax).nth(1)
132    }
133}
134
135impl ast::BetweenExpr {
136    #[inline]
137    pub fn target(&self) -> Option<ast::Expr> {
138        support::children(&self.syntax).nth(0)
139    }
140    #[inline]
141    pub fn start(&self) -> Option<ast::Expr> {
142        support::children(&self.syntax).nth(1)
143    }
144    #[inline]
145    pub fn end(&self) -> Option<ast::Expr> {
146        support::children(&self.syntax).nth(2)
147    }
148}
149
150impl ast::WhenClause {
151    #[inline]
152    pub fn condition(&self) -> Option<ast::Expr> {
153        support::children(&self.syntax).next()
154    }
155    #[inline]
156    pub fn then(&self) -> Option<ast::Expr> {
157        support::children(&self.syntax).nth(1)
158    }
159}
160
161impl ast::CompoundSelect {
162    #[inline]
163    pub fn lhs(&self) -> Option<ast::SelectVariant> {
164        support::children(&self.syntax).next()
165    }
166    #[inline]
167    pub fn rhs(&self) -> Option<ast::SelectVariant> {
168        support::children(&self.syntax).nth(1)
169    }
170}
171
172impl ast::NameRef {
173    #[inline]
174    pub fn text(&self) -> TokenText<'_> {
175        text_of_first_token(self.syntax())
176    }
177}
178
179impl ast::Name {
180    #[inline]
181    pub fn text(&self) -> TokenText<'_> {
182        text_of_first_token(self.syntax())
183    }
184}
185
186impl ast::CharType {
187    #[inline]
188    pub fn text(&self) -> TokenText<'_> {
189        text_of_first_token(self.syntax())
190    }
191}
192
193impl ast::OpSig {
194    #[inline]
195    pub fn lhs(&self) -> Option<ast::Type> {
196        support::children(self.syntax()).next()
197    }
198
199    #[inline]
200    pub fn rhs(&self) -> Option<ast::Type> {
201        support::children(self.syntax()).nth(1)
202    }
203}
204
205impl ast::CastSig {
206    #[inline]
207    pub fn lhs(&self) -> Option<ast::Type> {
208        support::children(self.syntax()).next()
209    }
210
211    #[inline]
212    pub fn rhs(&self) -> Option<ast::Type> {
213        support::children(self.syntax()).nth(1)
214    }
215}
216
217pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
218    fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
219        green_ref
220            .children()
221            .next()
222            .and_then(NodeOrToken::into_token)
223            .unwrap()
224    }
225
226    match node.green() {
227        Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
228        Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
229    }
230}
231
/// `foo[bar]` exposes `foo` as the base and `bar` as the index.
#[test]
fn index_expr() {
    let sql = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::IndexExpr(indexed) = first_target.expr().unwrap() else {
        unreachable!()
    };
    let (base, index) = (indexed.base().unwrap(), indexed.index().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(index.syntax().text(), "bar");
}
259
/// Checks `base`/`start`/`end` for every combination of present and missing
/// slice bounds.
#[test]
fn slice_expr() {
    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let select_clause = select.select_clause().unwrap();
    let mut targets = select_clause.target_list().unwrap().targets();
    // Pulls the next select target and asserts it is a slice expression.
    let mut next_slice = || {
        let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
            unreachable!()
        };
        slice
    };

    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
308
/// `foo.bar` exposes `foo` as the base and `bar` as the field.
#[test]
fn field_expr() {
    let sql = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::FieldExpr(access) = first_target.expr().unwrap() else {
        unreachable!()
    };
    let (base, field) = (access.base().unwrap(), access.field().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(field.syntax().text(), "bar");
}
336
/// `2 between 1 and 3` exposes `2`, `1` and `3` as target, start and end.
#[test]
fn between_expr() {
    let sql = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::BetweenExpr(between) = first_target.expr().unwrap() else {
        unreachable!()
    };
    let parts = [
        between.target().unwrap(),
        between.start().unwrap(),
        between.end().unwrap(),
    ];
    let texts: Vec<String> = parts
        .iter()
        .map(|part| part.syntax().text().to_string())
        .collect();
    assert_eq!(texts, ["2", "1", "3"]);
}
366
/// Every spelling of a cast (`cast(.. as ..)`, `type 'literal'`, `::`, and
/// `interval '..' unit`) exposes the casted expression and the target type.
#[test]
fn cast_expr() {
    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");

    /// Parses `sql` (which must be error-free), then returns the first target
    /// of the first statement as a `CastExpr`. Panics if that statement is not
    /// a `select` or the target is not a cast.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
            unreachable!()
        };
        // `expr()` already returns an owned node, so no clone is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }
}
465
/// In `alter operator p.+ (int4, int8)`, the signature's `lhs` is `int4` and
/// its `rhs` is `int8`.
#[test]
fn op_sig() {
    let sql = "
      alter operator p.+ (int4, int8) 
        owner to u;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::AlterOperator(alter_operator) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let signature = alter_operator.op_sig().unwrap();
    let (lhs, rhs) = (signature.lhs().unwrap(), signature.rhs().unwrap());
    assert_snapshot!(lhs.syntax().text(), @"int4");
    assert_snapshot!(rhs.syntax().text(), @"int8");
}
485
/// In `drop cast (text as int)`, the cast signature's `lhs` is `text` and its
/// `rhs` is `int`.
#[test]
fn cast_sig() {
    let source_code = "
      drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    // Named for what it is: a `DropCast` statement (it was `alter_op`).
    let ast::Stmt::DropCast(drop_cast) = stmt else {
        unreachable!()
    };
    let cast_sig = drop_cast.cast_sig().unwrap();
    let lhs = cast_sig.lhs().unwrap();
    let rhs = cast_sig.rhs().unwrap();
    assert_snapshot!(lhs.syntax().text(), @"text");
    assert_snapshot!(rhs.syntax().text(), @"int");
}