squawk_syntax/ast/
node_ext.rs1use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use crate::ast;
36use crate::ast::AstNode;
37use crate::{SyntaxNode, TokenText};
38
39use super::support;
40
41impl ast::Constraint {
42 #[inline]
43 pub fn constraint_name(&self) -> Option<ast::ConstraintName> {
44 support::child(self.syntax())
45 }
46}
47
48impl ast::BinExpr {
49 #[inline]
50 pub fn lhs(&self) -> Option<ast::Expr> {
51 support::children(self.syntax()).next()
52 }
53
54 #[inline]
55 pub fn rhs(&self) -> Option<ast::Expr> {
56 support::children(self.syntax()).nth(1)
57 }
58}
59
60impl ast::FieldExpr {
61 #[inline]
65 pub fn base(&self) -> Option<ast::Expr> {
66 support::children(self.syntax()).next()
67 }
68 #[inline]
69 pub fn field(&self) -> Option<ast::NameRef> {
70 support::children(self.syntax()).last()
71 }
72}
73
74impl ast::IndexExpr {
75 #[inline]
76 pub fn base(&self) -> Option<ast::Expr> {
77 support::children(&self.syntax).next()
78 }
79 #[inline]
80 pub fn index(&self) -> Option<ast::Expr> {
81 support::children(&self.syntax).nth(1)
82 }
83}
84
85impl ast::SliceExpr {
86 #[inline]
87 pub fn base(&self) -> Option<ast::Expr> {
88 support::children(&self.syntax).next()
89 }
90
91 #[inline]
92 pub fn start(&self) -> Option<ast::Expr> {
93 let colon = self.colon_token()?;
98 support::children(&self.syntax)
99 .skip(1)
100 .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
101 }
102
103 #[inline]
104 pub fn end(&self) -> Option<ast::Expr> {
105 let colon = self.colon_token()?;
108 support::children(&self.syntax)
109 .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
110 }
111}
112
113impl ast::RenameColumn {
114 #[inline]
115 pub fn from(&self) -> Option<ast::NameRef> {
116 support::children(&self.syntax).nth(0)
117 }
118 #[inline]
119 pub fn to(&self) -> Option<ast::NameRef> {
120 support::children(&self.syntax).nth(1)
121 }
122}
123
124impl ast::ForeignKeyConstraint {
125 #[inline]
126 pub fn from_columns(&self) -> Option<ast::ColumnList> {
127 support::children(&self.syntax).nth(0)
128 }
129 #[inline]
130 pub fn to_columns(&self) -> Option<ast::ColumnList> {
131 support::children(&self.syntax).nth(1)
132 }
133}
134
135impl ast::BetweenExpr {
136 #[inline]
137 pub fn target(&self) -> Option<ast::Expr> {
138 support::children(&self.syntax).nth(0)
139 }
140 #[inline]
141 pub fn start(&self) -> Option<ast::Expr> {
142 support::children(&self.syntax).nth(1)
143 }
144 #[inline]
145 pub fn end(&self) -> Option<ast::Expr> {
146 support::children(&self.syntax).nth(2)
147 }
148}
149
150impl ast::WhenClause {
151 #[inline]
152 pub fn condition(&self) -> Option<ast::Expr> {
153 support::children(&self.syntax).next()
154 }
155 #[inline]
156 pub fn then(&self) -> Option<ast::Expr> {
157 support::children(&self.syntax).nth(1)
158 }
159}
160
161impl ast::CompoundSelect {
162 #[inline]
163 pub fn lhs(&self) -> Option<ast::SelectVariant> {
164 support::children(&self.syntax).next()
165 }
166 #[inline]
167 pub fn rhs(&self) -> Option<ast::SelectVariant> {
168 support::children(&self.syntax).nth(1)
169 }
170}
171
172impl ast::NameRef {
173 #[inline]
174 pub fn text(&self) -> TokenText<'_> {
175 text_of_first_token(self.syntax())
176 }
177}
178
179impl ast::Name {
180 #[inline]
181 pub fn text(&self) -> TokenText<'_> {
182 text_of_first_token(self.syntax())
183 }
184}
185
186impl ast::CharType {
187 #[inline]
188 pub fn text(&self) -> TokenText<'_> {
189 text_of_first_token(self.syntax())
190 }
191}
192
193impl ast::OpSig {
194 #[inline]
195 pub fn lhs(&self) -> Option<ast::Type> {
196 support::children(self.syntax()).next()
197 }
198
199 #[inline]
200 pub fn rhs(&self) -> Option<ast::Type> {
201 support::children(self.syntax()).nth(1)
202 }
203}
204
205impl ast::CastSig {
206 #[inline]
207 pub fn lhs(&self) -> Option<ast::Type> {
208 support::children(self.syntax()).next()
209 }
210
211 #[inline]
212 pub fn rhs(&self) -> Option<ast::Type> {
213 support::children(self.syntax()).nth(1)
214 }
215}
216
217pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
218 fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
219 green_ref
220 .children()
221 .next()
222 .and_then(NodeOrToken::into_token)
223 .unwrap()
224 }
225
226 match node.green() {
227 Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
228 Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
229 }
230}
231
232impl ast::HasParamList for ast::FunctionSig {}
233impl ast::HasParamList for ast::Aggregate {}
234
235impl ast::NameLike for ast::Name {}
236impl ast::NameLike for ast::NameRef {}
237
238impl ast::HasWithClause for ast::Select {}
239impl ast::HasWithClause for ast::SelectInto {}
240impl ast::HasWithClause for ast::Insert {}
241impl ast::HasWithClause for ast::Update {}
242impl ast::HasWithClause for ast::Delete {}
243
244impl ast::HasCreateTable for ast::CreateTable {}
245impl ast::HasCreateTable for ast::CreateForeignTable {}
246impl ast::HasCreateTable for ast::CreateTableLike {}
247
/// `IndexExpr::base` and `IndexExpr::index` pick out the two sides of
/// `foo[bar]`.
#[test]
fn index_expr() {
    let parse = SourceFile::parse("\n select foo[bar];\n ");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::IndexExpr(index_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    assert_eq!(index_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(index_expr.index().unwrap().syntax().text(), "bar");
}
275
/// `SliceExpr::{base, start, end}` handle every combination of present and
/// omitted slice bounds.
#[test]
fn slice_expr() {
    use insta::assert_snapshot;

    let parse = SourceFile::parse("\n select x[1:2], x[2:], x[:3], x[:];\n ");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let mut targets = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets();

    // Pulls the next select target and requires it to be a slice expression.
    let mut next_slice = || {
        let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
            unreachable!()
        };
        slice
    };

    // Both bounds present.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Upper bound omitted.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Lower bound omitted.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // Both bounds omitted.
    let slice = next_slice();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
324
/// `FieldExpr::base` and `FieldExpr::field` pick out the two sides of
/// `foo.bar`.
#[test]
fn field_expr() {
    let parse = SourceFile::parse("\n select foo.bar;\n ");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::FieldExpr(field_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    assert_eq!(field_expr.base().unwrap().syntax().text(), "foo");
    assert_eq!(field_expr.field().unwrap().syntax().text(), "bar");
}
352
/// `BetweenExpr::{target, start, end}` return the three operands of
/// `2 between 1 and 3` in source order.
#[test]
fn between_expr() {
    let parse = SourceFile::parse("\n select 2 between 1 and 3;\n ");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::BetweenExpr(between_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let tested = between_expr.target().unwrap();
    let lower = between_expr.start().unwrap();
    let upper = between_expr.end().unwrap();
    assert_eq!(tested.syntax().text(), "2");
    assert_eq!(lower.syntax().text(), "1");
    assert_eq!(upper.syntax().text(), "3");
}
382
/// `CastExpr::expr` and `CastExpr::ty` work across the cast spellings the
/// parser accepts: `cast(.. as ..)`, prefix type literals, and `::`.
#[test]
fn cast_expr() {
    use insta::assert_snapshot;

    /// Parses `sql`, which must be an error-free `select` whose first target
    /// is a cast, and returns that cast expression.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let stmt = file.stmts().next().unwrap();
        let ast::Stmt::Select(select) = stmt else {
            unreachable!()
        };
        // `expr()` already hands back an owned node, so no clone is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }

    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");
}
481
/// `OpSig::lhs` and `OpSig::rhs` return the two operand types of an
/// operator signature.
#[test]
fn op_sig() {
    let parse = SourceFile::parse("\n alter operator p.+ (int4, int8)\n owner to u;\n ");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::AlterOperator(alter_op) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let signature = alter_op.op_sig().unwrap();

    let left_ty = signature.lhs().unwrap();
    let right_ty = signature.rhs().unwrap();
    assert_snapshot!(left_ty.syntax().text(), @"int4");
    assert_snapshot!(right_ty.syntax().text(), @"int8");
}
501
/// `CastSig::lhs` and `CastSig::rhs` return the source and target types of
/// a cast signature.
#[test]
fn cast_sig() {
    let parse = SourceFile::parse("\n drop cast (text as int);\n ");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::DropCast(drop_cast) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let signature = drop_cast.cast_sig().unwrap();

    let source_ty = signature.lhs().unwrap();
    let target_ty = signature.rhs().unwrap();
    assert_snapshot!(source_ty.syntax().text(), @"text");
    assert_snapshot!(target_ty.syntax().text(), @"int");
}