squawk_syntax/ast/
node_ext.rs1use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use crate::ast;
36use crate::ast::AstNode;
37use crate::{SyntaxNode, TokenText};
38
39use super::support;
40
41impl ast::Constraint {
42 #[inline]
43 pub fn name(&self) -> Option<ast::Name> {
44 support::child(self.syntax())
45 }
46}
47
48impl ast::BinExpr {
49 #[inline]
50 pub fn lhs(&self) -> Option<ast::Expr> {
51 support::children(self.syntax()).next()
52 }
53
54 #[inline]
55 pub fn rhs(&self) -> Option<ast::Expr> {
56 support::children(self.syntax()).nth(1)
57 }
58}
59
60impl ast::FieldExpr {
61 #[inline]
65 pub fn base(&self) -> Option<ast::Expr> {
66 support::children(self.syntax()).next()
67 }
68 #[inline]
69 pub fn field(&self) -> Option<ast::NameRef> {
70 support::children(self.syntax()).last()
71 }
72}
73
74impl ast::IndexExpr {
75 #[inline]
76 pub fn base(&self) -> Option<ast::Expr> {
77 support::children(&self.syntax).next()
78 }
79 #[inline]
80 pub fn index(&self) -> Option<ast::Expr> {
81 support::children(&self.syntax).nth(1)
82 }
83}
84
85impl ast::SliceExpr {
86 #[inline]
87 pub fn base(&self) -> Option<ast::Expr> {
88 support::children(&self.syntax).next()
89 }
90
91 #[inline]
92 pub fn start(&self) -> Option<ast::Expr> {
93 let colon = self.colon_token()?;
98 support::children(&self.syntax)
99 .skip(1)
100 .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
101 }
102
103 #[inline]
104 pub fn end(&self) -> Option<ast::Expr> {
105 let colon = self.colon_token()?;
108 support::children(&self.syntax)
109 .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
110 }
111}
112
113impl ast::RenameColumn {
114 #[inline]
115 pub fn from(&self) -> Option<ast::NameRef> {
116 support::children(&self.syntax).nth(0)
117 }
118 #[inline]
119 pub fn to(&self) -> Option<ast::NameRef> {
120 support::children(&self.syntax).nth(1)
121 }
122}
123
124impl ast::ForeignKeyConstraint {
125 #[inline]
126 pub fn from_columns(&self) -> Option<ast::ColumnList> {
127 support::children(&self.syntax).nth(0)
128 }
129 #[inline]
130 pub fn to_columns(&self) -> Option<ast::ColumnList> {
131 support::children(&self.syntax).nth(1)
132 }
133}
134
135impl ast::BetweenExpr {
136 #[inline]
137 pub fn target(&self) -> Option<ast::Expr> {
138 support::children(&self.syntax).nth(0)
139 }
140 #[inline]
141 pub fn start(&self) -> Option<ast::Expr> {
142 support::children(&self.syntax).nth(1)
143 }
144 #[inline]
145 pub fn end(&self) -> Option<ast::Expr> {
146 support::children(&self.syntax).nth(2)
147 }
148}
149
150impl ast::WhenClause {
151 #[inline]
152 pub fn condition(&self) -> Option<ast::Expr> {
153 support::children(&self.syntax).next()
154 }
155 #[inline]
156 pub fn then(&self) -> Option<ast::Expr> {
157 support::children(&self.syntax).nth(1)
158 }
159}
160
161impl ast::CompoundSelect {
162 #[inline]
163 pub fn lhs(&self) -> Option<ast::SelectVariant> {
164 support::children(&self.syntax).next()
165 }
166 #[inline]
167 pub fn rhs(&self) -> Option<ast::SelectVariant> {
168 support::children(&self.syntax).nth(1)
169 }
170}
171
172impl ast::NameRef {
173 #[inline]
174 pub fn text(&self) -> TokenText<'_> {
175 text_of_first_token(self.syntax())
176 }
177}
178
179impl ast::Name {
180 #[inline]
181 pub fn text(&self) -> TokenText<'_> {
182 text_of_first_token(self.syntax())
183 }
184}
185
186impl ast::CharType {
187 #[inline]
188 pub fn text(&self) -> TokenText<'_> {
189 text_of_first_token(self.syntax())
190 }
191}
192
193impl ast::OpSig {
194 #[inline]
195 pub fn lhs(&self) -> Option<ast::Type> {
196 support::children(self.syntax()).next()
197 }
198
199 #[inline]
200 pub fn rhs(&self) -> Option<ast::Type> {
201 support::children(self.syntax()).nth(1)
202 }
203}
204
205impl ast::CastSig {
206 #[inline]
207 pub fn lhs(&self) -> Option<ast::Type> {
208 support::children(self.syntax()).next()
209 }
210
211 #[inline]
212 pub fn rhs(&self) -> Option<ast::Type> {
213 support::children(self.syntax()).nth(1)
214 }
215}
216
217pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
218 fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
219 green_ref
220 .children()
221 .next()
222 .and_then(NodeOrToken::into_token)
223 .unwrap()
224 }
225
226 match node.green() {
227 Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
228 Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
229 }
230}
231
#[test]
fn index_expr() {
    let sql = "
    select foo[bar];
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::IndexExpr(indexed) = first_target.expr().unwrap() else {
        unreachable!()
    };
    assert_eq!(indexed.base().unwrap().syntax().text(), "foo");
    assert_eq!(indexed.index().unwrap().syntax().text(), "bar");
}
259
#[test]
fn slice_expr() {
    // `assert_snapshot!` comes from the module-level `#[cfg(test)]` import;
    // no function-local re-import is needed.
    let source_code = "
    select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    // Every target in the select list is expected to be a slice expression.
    let mut slices = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .map(|target| match target.expr().unwrap() {
            ast::Expr::SliceExpr(slice) => slice,
            _ => unreachable!(),
        });

    // Both bounds present.
    let slice = slices.next().unwrap();
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    // Upper bound omitted.
    let slice = slices.next().unwrap();
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    // Lower bound omitted.
    let slice = slices.next().unwrap();
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    // Both bounds omitted.
    let slice = slices.next().unwrap();
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());

    // The statement lists exactly four targets.
    assert!(slices.next().is_none());
}
308
#[test]
fn field_expr() {
    let sql = "
    select foo.bar;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::FieldExpr(access) = first_target.expr().unwrap() else {
        unreachable!()
    };
    assert_eq!(access.base().unwrap().syntax().text(), "foo");
    assert_eq!(access.field().unwrap().syntax().text(), "bar");
}
336
#[test]
fn between_expr() {
    let sql = "
    select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let select_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let ast::Expr::BetweenExpr(range_check) = select_target.expr().unwrap() else {
        unreachable!()
    };
    assert_eq!(range_check.target().unwrap().syntax().text(), "2");
    assert_eq!(range_check.start().unwrap().syntax().text(), "1");
    assert_eq!(range_check.end().unwrap().syntax().text(), "3");
}
366
#[test]
fn cast_expr() {
    // `assert_snapshot!` comes from the module-level `#[cfg(test)]` import.

    /// Parses `sql`, which must be an error-free statement list whose first
    /// statement is a `select` with a cast as its first target, and returns
    /// that cast expression.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let Some(ast::Stmt::Select(select)) = file.stmts().next() else {
            unreachable!()
        };
        // `expr()` already yields an owned `ast::Expr`; no clone is needed.
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }

    let cast = extract_expr("select cast('123' as int)");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'123'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'{123}'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert_snapshot!(cast.expr().expect("cast operand").syntax(), @"'1'");
    assert_snapshot!(cast.ty().expect("cast type").syntax(), @"interval");
}
465
#[test]
fn op_sig() {
    let sql = "
    alter operator p.+ (int4, int8)
    owner to u;
    ";
    let parse = SourceFile::parse(sql);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let ast::Stmt::AlterOperator(alter_operator) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let signature = alter_operator.op_sig().unwrap();
    assert_snapshot!(signature.lhs().unwrap().syntax().text(), @"int4");
    assert_snapshot!(signature.rhs().unwrap().syntax().text(), @"int8");
}
485
#[test]
fn cast_sig() {
    let source_code = "
    drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    // Bound as `drop_cast` (was `alter_op`, a leftover from the `op_sig` test).
    let ast::Stmt::DropCast(drop_cast) = stmt else {
        unreachable!()
    };
    let cast_sig = drop_cast.cast_sig().unwrap();
    let lhs = cast_sig.lhs().unwrap();
    let rhs = cast_sig.rhs().unwrap();
    assert_snapshot!(lhs.syntax().text(), @"text");
    assert_snapshot!(rhs.syntax().text(), @"int");
}