squawk_syntax/ast/
node_ext.rs1use std::borrow::Cow;
28
29#[cfg(test)]
30use insta::assert_snapshot;
31use rowan::{GreenNodeData, GreenTokenData, NodeOrToken};
32
33#[cfg(test)]
34use crate::SourceFile;
35use crate::ast;
36use crate::ast::AstNode;
37use crate::{SyntaxNode, TokenText};
38
39use super::support;
40
41impl ast::Constraint {
42 #[inline]
43 pub fn name(&self) -> Option<ast::Name> {
44 support::child(self.syntax())
45 }
46}
47
48impl ast::BinExpr {
49 #[inline]
50 pub fn lhs(&self) -> Option<ast::Expr> {
51 support::children(self.syntax()).next()
52 }
53
54 #[inline]
55 pub fn rhs(&self) -> Option<ast::Expr> {
56 support::children(self.syntax()).nth(1)
57 }
58}
59
60impl ast::FieldExpr {
61 #[inline]
65 pub fn base(&self) -> Option<ast::Expr> {
66 support::children(self.syntax()).next()
67 }
68 #[inline]
69 pub fn field(&self) -> Option<ast::NameRef> {
70 support::children(self.syntax()).last()
71 }
72}
73
74impl ast::IndexExpr {
75 #[inline]
76 pub fn base(&self) -> Option<ast::Expr> {
77 support::children(&self.syntax).next()
78 }
79 #[inline]
80 pub fn index(&self) -> Option<ast::Expr> {
81 support::children(&self.syntax).nth(1)
82 }
83}
84
85impl ast::SliceExpr {
86 #[inline]
87 pub fn base(&self) -> Option<ast::Expr> {
88 support::children(&self.syntax).next()
89 }
90
91 #[inline]
92 pub fn start(&self) -> Option<ast::Expr> {
93 let colon = self.colon_token()?;
98 support::children(&self.syntax)
99 .skip(1)
100 .find(|expr: &ast::Expr| expr.syntax().text_range().end() <= colon.text_range().start())
101 }
102
103 #[inline]
104 pub fn end(&self) -> Option<ast::Expr> {
105 let colon = self.colon_token()?;
108 support::children(&self.syntax)
109 .find(|expr: &ast::Expr| expr.syntax().text_range().start() >= colon.text_range().end())
110 }
111}
112
113impl ast::RenameColumn {
114 #[inline]
115 pub fn from(&self) -> Option<ast::NameRef> {
116 support::children(&self.syntax).nth(0)
117 }
118 #[inline]
119 pub fn to(&self) -> Option<ast::NameRef> {
120 support::children(&self.syntax).nth(1)
121 }
122}
123
124impl ast::ForeignKeyConstraint {
125 #[inline]
126 pub fn from_columns(&self) -> Option<ast::ColumnList> {
127 support::children(&self.syntax).nth(0)
128 }
129 #[inline]
130 pub fn to_columns(&self) -> Option<ast::ColumnList> {
131 support::children(&self.syntax).nth(1)
132 }
133}
134
135impl ast::BetweenExpr {
136 #[inline]
137 pub fn target(&self) -> Option<ast::Expr> {
138 support::children(&self.syntax).nth(0)
139 }
140 #[inline]
141 pub fn start(&self) -> Option<ast::Expr> {
142 support::children(&self.syntax).nth(1)
143 }
144 #[inline]
145 pub fn end(&self) -> Option<ast::Expr> {
146 support::children(&self.syntax).nth(2)
147 }
148}
149
150impl ast::WhenClause {
151 #[inline]
152 pub fn condition(&self) -> Option<ast::Expr> {
153 support::children(&self.syntax).next()
154 }
155 #[inline]
156 pub fn then(&self) -> Option<ast::Expr> {
157 support::children(&self.syntax).nth(1)
158 }
159}
160
161impl ast::CompoundSelect {
162 #[inline]
163 pub fn lhs(&self) -> Option<ast::SelectVariant> {
164 support::children(&self.syntax).next()
165 }
166 #[inline]
167 pub fn rhs(&self) -> Option<ast::SelectVariant> {
168 support::children(&self.syntax).nth(1)
169 }
170}
171
172impl ast::NameRef {
173 #[inline]
174 pub fn text(&self) -> TokenText<'_> {
175 text_of_first_token(self.syntax())
176 }
177}
178
179impl ast::Name {
180 #[inline]
181 pub fn text(&self) -> TokenText<'_> {
182 text_of_first_token(self.syntax())
183 }
184}
185
186impl ast::CharType {
187 #[inline]
188 pub fn text(&self) -> TokenText<'_> {
189 text_of_first_token(self.syntax())
190 }
191}
192
193impl ast::OpSig {
194 #[inline]
195 pub fn lhs(&self) -> Option<ast::Type> {
196 support::children(self.syntax()).next()
197 }
198
199 #[inline]
200 pub fn rhs(&self) -> Option<ast::Type> {
201 support::children(self.syntax()).nth(1)
202 }
203}
204
205impl ast::CastSig {
206 #[inline]
207 pub fn lhs(&self) -> Option<ast::Type> {
208 support::children(self.syntax()).next()
209 }
210
211 #[inline]
212 pub fn rhs(&self) -> Option<ast::Type> {
213 support::children(self.syntax()).nth(1)
214 }
215}
216
217pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> {
218 fn first_token(green_ref: &GreenNodeData) -> &GreenTokenData {
219 green_ref
220 .children()
221 .next()
222 .and_then(NodeOrToken::into_token)
223 .unwrap()
224 }
225
226 match node.green() {
227 Cow::Borrowed(green_ref) => TokenText::borrowed(first_token(green_ref).text()),
228 Cow::Owned(green) => TokenText::owned(first_token(&green).to_owned()),
229 }
230}
231
232impl ast::HasParamList for ast::FunctionSig {}
233impl ast::HasParamList for ast::Aggregate {}
234impl ast::NameLike for ast::Name {}
235impl ast::NameLike for ast::NameRef {}
236
/// `foo[bar]` splits into base `foo` and index `bar`.
#[test]
fn index_expr() {
    let source_code = "
        select foo[bar];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let ast::Expr::IndexExpr(index_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (base, index) = (index_expr.base().unwrap(), index_expr.index().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(index.syntax().text(), "bar");
}
264
/// Covers all four slice shapes: both bounds, start only, end only, and neither.
#[test]
fn slice_expr() {
    let source_code = "
        select x[1:2], x[2:], x[:3], x[:];
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    let ast::Stmt::Select(select) = stmt else {
        unreachable!()
    };
    let select_clause = select.select_clause().unwrap();
    let mut targets = select_clause.target_list().unwrap().targets();

    let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
        unreachable!()
    };
    assert_snapshot!(slice.syntax(), @"x[1:2]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "1");
    assert_eq!(slice.end().unwrap().syntax().text(), "2");

    let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
        unreachable!()
    };
    assert_snapshot!(slice.syntax(), @"x[2:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert_eq!(slice.start().unwrap().syntax().text(), "2");
    assert!(slice.end().is_none());

    let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
        unreachable!()
    };
    assert_snapshot!(slice.syntax(), @"x[:3]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert_eq!(slice.end().unwrap().syntax().text(), "3");

    let ast::Expr::SliceExpr(slice) = targets.next().unwrap().expr().unwrap() else {
        unreachable!()
    };
    assert_snapshot!(slice.syntax(), @"x[:]");
    assert_eq!(slice.base().unwrap().syntax().text(), "x");
    assert!(slice.start().is_none());
    assert!(slice.end().is_none());
}
313
/// `foo.bar` splits into base `foo` and field `bar`.
#[test]
fn field_expr() {
    let source_code = "
        select foo.bar;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let ast::Expr::FieldExpr(field_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (base, field) = (field_expr.base().unwrap(), field_expr.field().unwrap());
    assert_eq!(base.syntax().text(), "foo");
    assert_eq!(field.syntax().text(), "bar");
}
341
/// `2 between 1 and 3` exposes target `2`, start `1` and end `3`.
#[test]
fn between_expr() {
    let source_code = "
        select 2 between 1 and 3;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let first_target = target_list.targets().next().unwrap();
    let ast::Expr::BetweenExpr(between) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let tested = between.target().unwrap();
    let low = between.start().unwrap();
    let high = between.end().unwrap();
    assert_eq!(tested.syntax().text(), "2");
    assert_eq!(low.syntax().text(), "1");
    assert_eq!(high.syntax().text(), "3");
}
371
/// Every cast syntax (`cast(.. as ..)`, prefix type literal, `::`) yields the
/// same `expr` / `ty` split.
#[test]
fn cast_expr() {
    let cast = extract_expr("select cast('123' as int)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select cast('123' as pg_catalog.int4)");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select int '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select pg_catalog.int4 '123'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '123'::int");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int");

    let cast = extract_expr("select '123'::int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"int4");

    let cast = extract_expr("select '123'::pg_catalog.int4");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'123'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.int4");

    let cast = extract_expr("select '{123}'::pg_catalog.varchar(10)[]");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select cast('{123}' as pg_catalog.varchar(10)[])");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)[]");

    let cast = extract_expr("select pg_catalog.varchar(10) '{123}'");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'{123}'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"pg_catalog.varchar(10)");

    let cast = extract_expr("select interval '1' month");
    assert!(cast.expr().is_some());
    assert_snapshot!(cast.expr().unwrap().syntax(), @"'1'");
    assert!(cast.ty().is_some());
    assert_snapshot!(cast.ty().unwrap().syntax(), @"interval");

    /// Parses `sql` (which must parse without errors) and returns the first
    /// target of its first statement. That statement must be a `select`, and
    /// the target must be a cast expression.
    fn extract_expr(sql: &str) -> ast::CastExpr {
        let parse = SourceFile::parse(sql);
        assert!(parse.errors().is_empty());
        let file: SourceFile = parse.tree();
        let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
            unreachable!()
        };
        let expr = select
            .select_clause()
            .unwrap()
            .target_list()
            .unwrap()
            .targets()
            .next()
            .unwrap()
            .expr()
            .unwrap();
        let ast::Expr::CastExpr(cast) = expr else {
            unreachable!()
        };
        cast
    }
}
470
/// The operator signature `(int4, int8)` exposes `int4` as lhs and `int8` as rhs.
#[test]
fn op_sig() {
    let source_code = "
        alter operator p.+ (int4, int8)
        owner to u;
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::AlterOperator(alter_op) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let signature = alter_op.op_sig().unwrap();
    let (left, right) = (signature.lhs().unwrap(), signature.rhs().unwrap());
    assert_snapshot!(left.syntax().text(), @"int4");
    assert_snapshot!(right.syntax().text(), @"int8");
}
490
/// The cast signature `(text as int)` exposes `text` as lhs and `int` as rhs.
#[test]
fn cast_sig() {
    let source_code = "
        drop cast (text as int);
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();
    let stmt = file.stmts().next().unwrap();
    // Named for what it is: a `DROP CAST` statement (not an alter-operator one).
    let ast::Stmt::DropCast(drop_cast) = stmt else {
        unreachable!()
    };
    let cast_sig = drop_cast.cast_sig().unwrap();
    let lhs = cast_sig.lhs().unwrap();
    let rhs = cast_sig.rhs().unwrap();
    assert_snapshot!(lhs.syntax().text(), @"text");
    assert_snapshot!(rhs.syntax().text(), @"int");
}