squawk_syntax/ast/
node_ext.rs1use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use crate::ast;
36use crate::ast::AstNode;
37use crate::{SyntaxNode, TokenText};
38
39use super::support;
40
41impl ast::Constraint {
42 #[inline]
43 pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
44 support::child(self.syntax())
45 }
46}
47
48impl ast::BinExpr {
49 #[inline]
50 pub fn lhs(&self) -> Option<ast::Expr> {
51 support::children(self.syntax()).next()
52 }
53
54 #[inline]
55 pub fn rhs(&self) -> Option<ast::Expr> {
56 support::children(self.syntax()).nth(1)
57 }
58}
59
60impl ast::FieldExpr {
61 #[inline]
65 pub fn base(&self) -> Option<ast::Expr> {
66 support::children(self.syntax()).next()
67 }
68 #[inline]
69 pub fn field(&self) -> Option<ast::NameRef> {
70 support::children(self.syntax()).last()
71 }
72}
73
74impl ast::IndexExpr {
75 #[inline]
76 pub fn base(&self) -> Option<ast::Expr> {
77 support::children(&self.syntax).next()
78 }
79 #[inline]
80 pub fn index(&self) -> Option<ast::Expr> {
81 support::children(&self.syntax).nth(1)
82 }
83}
84
85impl ast::SliceExpr {
86 #[inline]
87 pub fn base(&self) -> Option<ast::Expr> {
88 support::children(&self.syntax).next()
89 }
90
91 #[inline]
92 pub fn start(&self) -> Option<ast::Expr> {
93 let colon = self.colon_token()?;
98 support::children(&self.syntax)
99 .skip(1)
100 .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
101 }
102
103 #[inline]
104 pub fn end(&self) -> Option<ast::Expr> {
105 let colon = self.colon_token()?;
108 support::children(&self.syntax)
109 .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
110 }
111}
112
113impl ast::RenameColumn {
114 #[inline]
115 pub fn from(&self) -> Option<ast::NameRef> {
116 support::children(&self.syntax).nth(0)
117 }
118 #[inline]
119 pub fn to(&self) -> Option<ast::NameRef> {
120 support::children(&self.syntax).nth(1)
121 }
122}
123
124impl ast::ForeignKeyConstraint {
125 #[inline]
126 pub fn from_columns(&self) -> Option<ast::ColumnList> {
127 support::children(&self.syntax).nth(0)
128 }
129 #[inline]
130 pub fn to_columns(&self) -> Option<ast::ColumnList> {
131 support::children(&self.syntax).nth(1)
132 }
133}
134
135impl ast::BetweenExpr {
136 #[inline]
137 pub fn target(&self) -> Option<ast::Expr> {
138 support::children(&self.syntax).nth(0)
139 }
140 #[inline]
141 pub fn start(&self) -> Option<ast::Expr> {
142 support::children(&self.syntax).nth(1)
143 }
144 #[inline]
145 pub fn end(&self) -> Option<ast::Expr> {
146 support::children(&self.syntax).nth(2)
147 }
148}
149
150impl ast::WhenClause {
151 #[inline]
152 pub fn condition(&self) -> Option<ast::Expr> {
153 support::children(&self.syntax).next()
154 }
155 #[inline]
156 pub fn then(&self) -> Option<ast::Expr> {
157 support::children(&self.syntax).nth(1)
158 }
159}
160
161impl ast::CompoundSelect {
162 #[inline]
163 pub fn lhs(&self) -> Option<ast::SelectVariant> {
164 support::children(&self.syntax).next()
165 }
166 #[inline]
167 pub fn rhs(&self) -> Option<ast::SelectVariant> {
168 support::children(&self.syntax).nth(1)
169 }
170}
171
172impl ast::NameRef {
173 #[inline]
174 pub fn text(&self) -> TokenText<'_> {
175 text_of_first_token(self.syntax())
176 }
177}
178
179impl ast::Name {
180 #[inline]
181 pub fn text(&self) -> TokenText<'_> {
182 text_of_first_token(self.syntax())
183 }
184}
185
186impl ast::CharType {
187 #[inline]
188 pub fn text(&self) -> TokenText<'_> {
189 text_of_first_token(self.syntax())
190 }
191}
192
193impl ast::OpSig {
194 #[inline]
195 pub fn lhs(&self) -> Option<ast::Type> {
196 support::children(self.syntax()).next()
197 }
198
199 #[inline]
200 pub fn rhs(&self) -> Option<ast::Type> {
201 support::children(self.syntax()).nth(1)
202 }
203}
204
205impl ast::CastSig {
206 #[inline]
207 pub fn lhs(&self) -> Option<ast::Type> {
208 support::children(self.syntax()).next()
209 }
210
211 #[inline]
212 pub fn rhs(&self) -> Option<ast::Type> {
213 support::children(self.syntax()).nth(1)
214 }
215}
216
217pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
218 fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
219 green_ref
220 .children()
221 .next()
222 .and_then(NodeOrToken::into_token)
223 .unwrap()
224 }
225
226 match node.green() {
227 Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
228 Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
229 }
230}
231
232impl ast::WithQuery {
233 #[inline]
234 pub fn with_clause(&self) -> Option<ast::WithClause> {
235 support::child(self.syntax())
236 }
237}
238
239impl ast::HasParamList for ast::FunctionSig {}
240impl ast::HasParamList for ast::Aggregate {}
241
242impl ast::NameLike for ast::Name {}
243impl ast::NameLike for ast::NameRef {}
244
245impl ast::HasWithClause for ast::Select {}
246impl ast::HasWithClause for ast::SelectInto {}
247impl ast::HasWithClause for ast::Insert {}
248impl ast::HasWithClause for ast::Update {}
249impl ast::HasWithClause for ast::Delete {}
250
251impl ast::HasCreateTable for ast::CreateTable {}
252impl ast::HasCreateTable for ast::CreateForeignTable {}
253impl ast::HasCreateTable for ast::CreateTableLike {}
254
/// `IndexExpr::base` and `IndexExpr::index` return the two sides of `foo[bar]`.
#[test]
fn index_expr() {
    let source_code = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // Drill down to the expression of the first select target.
    let first_stmt = file.stmts().next().unwrap();
    let select = match first_stmt {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let index_expr = match first_target.expr().unwrap() {
        ast::Expr::IndexExpr(index_expr) => index_expr,
        _ => unreachable!(),
    };

    assert_eq!(index_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(index_expr.index().unwrap().syntax().text(), "bar");
}
282
/// `SliceExpr::{base, start, end}` handle every combination of present and
/// omitted slice bounds.
#[test]
fn slice_expr() {
    use insta::assert_snapshot;

    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let select = match file.stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let mut targets = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets();

    // Pulls the next select target and unwraps it as a slice expression.
    let mut next_slice = || match targets.next().unwrap().expr().unwrap() {
        ast::Expr::SliceExpr(slice) => slice,
        _ => unreachable!(),
    };

    // Both bounds present.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Only the lower bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Only the upper bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // Neither bound.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
331
/// `FieldExpr::base` and `FieldExpr::field` return the two sides of `foo.bar`.
#[test]
fn field_expr() {
    let source_code = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // Drill down to the expression of the first select target.
    let first_stmt = file.stmts().next().unwrap();
    let select = match first_stmt {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let field_expr = match first_target.expr().unwrap() {
        ast::Expr::FieldExpr(field_expr) => field_expr,
        _ => unreachable!(),
    };

    assert_eq!(field_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(field_expr.field().unwrap().syntax().text(), "bar");
}
359
/// `BetweenExpr::{target, start, end}` return the three operands of
/// `2 between 1 and 3` in source order.
#[test]
fn between_expr() {
    let source_code = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // Drill down to the expression of the first select target.
    let first_stmt = file.stmts().next().unwrap();
    let select = match first_stmt {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let between_expr = match first_target.expr().unwrap() {
        ast::Expr::BetweenExpr(between_expr) => between_expr,
        _ => unreachable!(),
    };

    let tested = between_expr.target().unwrap();
    let lower = between_expr.start().unwrap();
    let upper = between_expr.end().unwrap();
    assert_eq!(tested.syntax().text(), "2");
    assert_eq!(lower.syntax().text(), "1");
    assert_eq!(upper.syntax().text(), "3");
}
389
/// `CastExpr::expr` and `CastExpr::ty` return the operand and target type for
/// each cast spelling exercised below: `cast(.. as ..)`, a type-prefixed
/// literal, and `::`.
#[test]
fn cast_expr() {
    use insta::assert_snapshot;

    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");

    /// Parses `sql` (which must parse without errors) and returns the cast
    /// expression found in the first target of its first statement.
    ///
    /// Panics if the first statement is not a `select` or if that target's
    /// expression is not a cast.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();

        // Only the first statement matters, so take it directly instead of
        // mapping over every statement.
        let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
            unreachable!()
        };
        let target = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap();
        // `expr()` already returns an owned node, so no clone is needed.
        let ast::Expr::CastExpr(cast) = target.expr().unwrap() else {
            unreachable!()
        };
        cast
    }
}
488
/// `OpSig::lhs` and `OpSig::rhs` return the two operand types of an operator
/// signature.
#[test]
fn op_sig() {
    let source_code = "
        alter operator p.+ (int4, int8)
        owner to u;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let first_stmt = file.stmts().next().unwrap();
    let alter_op = match first_stmt {
        ast::Stmt::AlterOperator(alter_op) => alter_op,
        _ => unreachable!(),
    };
    let signature = alter_op.op_sig().unwrap();
    let left = signature.lhs().unwrap();
    let right = signature.rhs().unwrap();

    assert_snapshot!(left.syntax().text(), @"int4");
    assert_snapshot!(right.syntax().text(), @"int8");
}
508
/// `CastSig::lhs` and `CastSig::rhs` return the two types of a cast signature.
#[test]
fn cast_sig() {
    let source_code = "
        drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let first_stmt = file.stmts().next().unwrap();
    let drop_cast = match first_stmt {
        ast::Stmt::DropCast(drop_cast) => drop_cast,
        _ => unreachable!(),
    };
    let signature = drop_cast.cast_sig().unwrap();
    let left = signature.lhs().unwrap();
    let right = signature.rhs().unwrap();

    assert_snapshot!(left.syntax().text(), @"text");
    assert_snapshot!(right.syntax().text(), @"int");
}