1use squawk_syntax::{
2 SyntaxKind, SyntaxNode,
3 ast::{self, AstNode},
4};
5
6use crate::quote::normalize_identifier;
7
8#[derive(Clone, Debug, PartialEq)]
9pub(crate) enum ColumnName {
10 Column(String),
11 UnknownColumn(Option<String>),
23 Star,
24}
25
26impl ColumnName {
27 pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
29 if let Some(as_name) = target.as_name()
30 && let Some(name_node) = as_name.name()
31 {
32 let text = name_node.text();
33 let normalized = normalize_identifier(&text);
34 return Some((ColumnName::Column(normalized), name_node.syntax().clone()));
35 }
36 Self::inferred_from_target(target)
37 }
38
39 pub(crate) fn inferred_from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
41 if let Some(expr) = target.expr()
42 && let Some(name) = name_from_expr(expr, false)
43 {
44 return Some(name);
45 } else if target.star_token().is_some() {
46 return Some((ColumnName::Star, target.syntax().clone()));
47 }
48 None
49 }
50
51 fn new(name: String, unknown_column: bool) -> ColumnName {
52 if unknown_column {
53 ColumnName::UnknownColumn(Some(name))
54 } else {
55 ColumnName::Column(name)
56 }
57 }
58
59 pub(crate) fn to_string(&self) -> Option<String> {
60 match self {
61 ColumnName::Column(string) => Some(string.to_string()),
62 ColumnName::Star => None,
63 ColumnName::UnknownColumn(c) => {
64 Some(c.clone().unwrap_or_else(|| "?column?".to_string()))
65 }
66 }
67 }
68}
69
70fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> {
71 match ty {
72 ast::Type::PathType(path_type) => {
73 if let Some(name_ref) = path_type
74 .path()
75 .and_then(|x| x.segment())
76 .and_then(|x| x.name_ref())
77 {
78 return name_from_name_ref(name_ref, true).map(|(column, node)| {
79 let column = match column {
80 ColumnName::Column(c) => ColumnName::new(c, unknown_column),
81 _ => column,
82 };
83 (column, node)
84 });
85 }
86 }
87 ast::Type::BitType(bit_type) => {
88 let name = if bit_type.varying_token().is_some() {
89 "varbit"
90 } else {
91 "bit"
92 };
93 return Some((
94 ColumnName::new(name.to_string(), unknown_column),
95 bit_type.syntax().clone(),
96 ));
97 }
98 ast::Type::CharType(char_type) => {
99 let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
100 {
101 "varchar"
102 } else {
103 "bpchar"
104 };
105 return Some((
106 ColumnName::new(name.to_string(), unknown_column),
107 char_type.syntax().clone(),
108 ));
109 }
110 ast::Type::DoubleType(double_type) => {
111 return Some((
112 ColumnName::new("float8".to_string(), unknown_column),
113 double_type.syntax().clone(),
114 ));
115 }
116 ast::Type::IntervalType(interval_type) => {
117 return Some((
118 ColumnName::new("interval".to_string(), unknown_column),
119 interval_type.syntax().clone(),
120 ));
121 }
122 ast::Type::TimeType(time_type) => {
123 let mut name = if time_type.timestamp_token().is_some() {
124 "timestamp".to_owned()
125 } else {
126 "time".to_owned()
127 };
128 if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
129 name.push_str("tz");
132 };
133 return Some((
134 ColumnName::new(name.to_string(), unknown_column),
135 time_type.syntax().clone(),
136 ));
137 }
138 ast::Type::ArrayType(array_type) => {
139 if let Some(inner_ty) = array_type.ty() {
140 return name_from_type(inner_ty, unknown_column);
141 }
142 }
143 ast::Type::PercentType(_) => return None,
146 ast::Type::ExprType(expr_type) => {
147 if let Some(expr) = expr_type.expr() {
148 return name_from_expr(expr, true).map(|(column, node)| {
149 let column = match column {
150 ColumnName::Column(c) => ColumnName::new(c, unknown_column),
151 _ => column,
152 };
153 (column, node)
154 });
155 }
156 }
157 }
158 None
159}
160
161fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
162 if in_type {
163 for node in name_ref.syntax().children_with_tokens() {
164 match node.kind() {
165 SyntaxKind::BIGINT_KW => {
166 return Some((
167 ColumnName::Column("int8".to_owned()),
168 name_ref.syntax().clone(),
169 ));
170 }
171 SyntaxKind::BOOLEAN_KW => {
172 return Some((
173 ColumnName::Column("bool".to_owned()),
174 name_ref.syntax().clone(),
175 ));
176 }
177 SyntaxKind::DECIMAL_KW => {
178 return Some((
179 ColumnName::Column("numeric".to_owned()),
180 name_ref.syntax().clone(),
181 ));
182 }
183 SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => {
184 return Some((
185 ColumnName::Column("int4".to_owned()),
186 name_ref.syntax().clone(),
187 ));
188 }
189 SyntaxKind::SMALLINT_KW => {
190 return Some((
191 ColumnName::Column("int2".to_owned()),
192 name_ref.syntax().clone(),
193 ));
194 }
195 SyntaxKind::REAL_KW => {
196 return Some((
197 ColumnName::Column("float4".to_owned()),
198 name_ref.syntax().clone(),
199 ));
200 }
201 _ => (),
202 }
203 }
204 }
205 let text = name_ref.text();
206 let normalized = normalize_identifier(&text);
207 return Some((ColumnName::Column(normalized), name_ref.syntax().clone()));
208}
209
210fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
226 let node = expr.syntax().clone();
227 match expr {
228 ast::Expr::ArrayExpr(_) => {
229 return Some((ColumnName::Column("array".to_string()), node));
230 }
231 ast::Expr::BetweenExpr(_) => {
232 return Some((ColumnName::UnknownColumn(None), node));
233 }
234 ast::Expr::BinExpr(bin_expr) => match bin_expr.op() {
235 Some(ast::BinOp::AtTimeZone(_)) => {
236 return Some((ColumnName::Column("timezone".to_string()), node));
237 }
238 Some(ast::BinOp::Overlaps(_)) => {
239 return Some((ColumnName::Column("overlaps".to_string()), node));
240 }
241 _ => return Some((ColumnName::UnknownColumn(None), node)),
242 },
243 ast::Expr::CallExpr(call_expr) => {
244 if let Some(exists_fn) = call_expr.exists_fn() {
245 return Some((
246 ColumnName::Column("exists".to_string()),
247 exists_fn.syntax().clone(),
248 ));
249 }
250 if let Some(extract_fn) = call_expr.extract_fn() {
251 return Some((
252 ColumnName::Column("extract".to_string()),
253 extract_fn.syntax().clone(),
254 ));
255 }
256 if let Some(json_exists_fn) = call_expr.json_exists_fn() {
257 return Some((
258 ColumnName::Column("json_exists".to_string()),
259 json_exists_fn.syntax().clone(),
260 ));
261 }
262 if let Some(json_array_fn) = call_expr.json_array_fn() {
263 return Some((
264 ColumnName::Column("json_array".to_string()),
265 json_array_fn.syntax().clone(),
266 ));
267 }
268 if let Some(json_object_fn) = call_expr.json_object_fn() {
269 return Some((
270 ColumnName::Column("json_object".to_string()),
271 json_object_fn.syntax().clone(),
272 ));
273 }
274 if let Some(json_object_agg_fn) = call_expr.json_object_agg_fn() {
275 return Some((
276 ColumnName::Column("json_objectagg".to_string()),
277 json_object_agg_fn.syntax().clone(),
278 ));
279 }
280 if let Some(json_array_agg_fn) = call_expr.json_array_agg_fn() {
281 return Some((
282 ColumnName::Column("json_arrayagg".to_string()),
283 json_array_agg_fn.syntax().clone(),
284 ));
285 }
286 if let Some(json_query_fn) = call_expr.json_query_fn() {
287 return Some((
288 ColumnName::Column("json_query".to_string()),
289 json_query_fn.syntax().clone(),
290 ));
291 }
292 if let Some(json_scalar_fn) = call_expr.json_scalar_fn() {
293 return Some((
294 ColumnName::Column("json_scalar".to_string()),
295 json_scalar_fn.syntax().clone(),
296 ));
297 }
298 if let Some(json_serialize_fn) = call_expr.json_serialize_fn() {
299 return Some((
300 ColumnName::Column("json_serialize".to_string()),
301 json_serialize_fn.syntax().clone(),
302 ));
303 }
304 if let Some(json_value_fn) = call_expr.json_value_fn() {
305 return Some((
306 ColumnName::Column("json_value".to_string()),
307 json_value_fn.syntax().clone(),
308 ));
309 }
310 if let Some(json_fn) = call_expr.json_fn() {
311 return Some((
312 ColumnName::Column("json".to_string()),
313 json_fn.syntax().clone(),
314 ));
315 }
316 if let Some(substring_fn) = call_expr.substring_fn() {
317 return Some((
318 ColumnName::Column("substring".to_string()),
319 substring_fn.syntax().clone(),
320 ));
321 }
322 if let Some(position_fn) = call_expr.position_fn() {
323 return Some((
324 ColumnName::Column("position".to_string()),
325 position_fn.syntax().clone(),
326 ));
327 }
328 if let Some(overlay_fn) = call_expr.overlay_fn() {
329 return Some((
330 ColumnName::Column("overlay".to_string()),
331 overlay_fn.syntax().clone(),
332 ));
333 }
334 if let Some(trim_fn) = call_expr.trim_fn() {
335 let name = if trim_fn.leading_token().is_some() {
336 "ltrim"
337 } else if trim_fn.trailing_token().is_some() {
338 "rtrim"
339 } else {
340 "btrim"
341 };
342 return Some((
343 ColumnName::Column(name.to_string()),
344 trim_fn.syntax().clone(),
345 ));
346 }
347 if let Some(xml_root_fn) = call_expr.xml_root_fn() {
348 return Some((
349 ColumnName::Column("xml_root".to_string()),
350 xml_root_fn.syntax().clone(),
351 ));
352 }
353 if let Some(xml_serialize_fn) = call_expr.xml_serialize_fn() {
354 return Some((
355 ColumnName::Column("xml_serialize".to_string()),
356 xml_serialize_fn.syntax().clone(),
357 ));
358 }
359 if let Some(xml_element_fn) = call_expr.xml_element_fn() {
360 return Some((
361 ColumnName::Column("xml_element".to_string()),
362 xml_element_fn.syntax().clone(),
363 ));
364 }
365 if let Some(xml_forest_fn) = call_expr.xml_forest_fn() {
366 return Some((
367 ColumnName::Column("xml_forest".to_string()),
368 xml_forest_fn.syntax().clone(),
369 ));
370 }
371 if let Some(xml_exists_fn) = call_expr.xml_exists_fn() {
372 return Some((
373 ColumnName::Column("xml_exists".to_string()),
374 xml_exists_fn.syntax().clone(),
375 ));
376 }
377 if let Some(xml_parse_fn) = call_expr.xml_parse_fn() {
378 return Some((
379 ColumnName::Column("xml_parse".to_string()),
380 xml_parse_fn.syntax().clone(),
381 ));
382 }
383 if let Some(xml_pi_fn) = call_expr.xml_pi_fn() {
384 return Some((
385 ColumnName::Column("xml_pi".to_string()),
386 xml_pi_fn.syntax().clone(),
387 ));
388 }
389 if let Some(collation_for_fn) = call_expr.collation_for_fn() {
390 return Some((
391 ColumnName::Column("pg_collation_for".to_string()),
392 collation_for_fn.syntax().clone(),
393 ));
394 }
395 if let Some(func_name) = call_expr.expr() {
396 match func_name {
397 ast::Expr::ArrayExpr(_)
398 | ast::Expr::BetweenExpr(_)
399 | ast::Expr::ParenExpr(_)
400 | ast::Expr::BinExpr(_)
401 | ast::Expr::CallExpr(_)
402 | ast::Expr::CaseExpr(_)
403 | ast::Expr::CastExpr(_)
404 | ast::Expr::Literal(_)
405 | ast::Expr::PostfixExpr(_)
406 | ast::Expr::PrefixExpr(_)
407 | ast::Expr::TupleExpr(_)
408 | ast::Expr::IndexExpr(_)
409 | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"),
410 ast::Expr::FieldExpr(field_expr) => {
411 if let Some(name_ref) = field_expr.field() {
412 return name_from_name_ref(name_ref, in_type);
413 }
414 }
415 ast::Expr::NameRef(name_ref) => {
416 return name_from_name_ref(name_ref, in_type);
417 }
418 }
419 }
420 }
421 ast::Expr::CaseExpr(case) => {
422 if let Some(else_clause) = case.else_clause()
423 && let Some(expr) = else_clause.expr()
424 && let Some((column, node)) = name_from_expr(expr, in_type)
425 {
426 if !matches!(column, ColumnName::UnknownColumn(_)) {
427 return Some((column, node));
428 }
429 }
430 return Some((ColumnName::Column("case".to_string()), node));
431 }
432 ast::Expr::CastExpr(cast_expr) => {
433 let mut unknown_column = false;
434 if let Some(expr) = cast_expr.expr()
435 && let Some((column, node)) = name_from_expr(expr, in_type)
436 {
437 match column {
438 ColumnName::Column(_) => return Some((column, node)),
439 ColumnName::UnknownColumn(_) => unknown_column = true,
440 ColumnName::Star => (),
441 }
442 }
443 if let Some(ty) = cast_expr.ty() {
444 return name_from_type(ty, unknown_column);
445 }
446 }
447 ast::Expr::FieldExpr(field_expr) => {
448 if let Some(name_ref) = field_expr.field() {
449 return name_from_name_ref(name_ref, in_type);
450 }
451 }
452 ast::Expr::IndexExpr(index_expr) => {
453 if let Some(base) = index_expr.base() {
454 return name_from_expr(base, in_type);
455 }
456 }
457 ast::Expr::SliceExpr(slice_expr) => {
458 if let Some(base) = slice_expr.base() {
459 return name_from_expr(base, in_type);
460 }
461 }
462 ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) => {
463 return Some((ColumnName::UnknownColumn(None), node));
464 }
465 ast::Expr::PostfixExpr(postfix_expr) => match postfix_expr.op() {
466 Some(ast::PostfixOp::AtLocal(_)) => {
467 return Some((ColumnName::Column("timezone".to_string()), node));
468 }
469 Some(ast::PostfixOp::IsNormalized(_)) => {
470 return Some((ColumnName::Column("is_normalized".to_string()), node));
471 }
472 _ => return Some((ColumnName::UnknownColumn(None), node)),
473 },
474 ast::Expr::NameRef(name_ref) => {
475 return name_from_name_ref(name_ref, in_type);
476 }
477 ast::Expr::ParenExpr(paren_expr) => {
478 if let Some(expr) = paren_expr.expr() {
479 return name_from_expr(expr, in_type);
480 } else if let Some(select) = paren_expr.select()
481 && let Some(mut targets) = select
482 .select_clause()
483 .and_then(|x| x.target_list())
484 .map(|x| x.targets())
485 && let Some(target) = targets.next()
486 {
487 return ColumnName::from_target(target);
488 }
489 }
490 ast::Expr::TupleExpr(_) => {
491 return Some((ColumnName::Column("row".to_string()), node));
492 }
493 }
494 None
495}
496
497#[test]
498fn examples() {
499 use insta::assert_snapshot;
500
501 assert_snapshot!(name("array(select 1)"), @"array");
503 assert_snapshot!(name("array[1, 2, 3]"), @"array");
504
505 assert_snapshot!(name("1 between 0 and 10"), @"?column?");
507 assert_snapshot!(name("1 + 2"), @"?column?");
508 assert_snapshot!(name("42"), @"?column?");
509 assert_snapshot!(name("'string'"), @"?column?");
510 assert_snapshot!(name("-42"), @"?column?");
512 assert_snapshot!(name("|/ 42"), @"?column?");
513 assert_snapshot!(name("x is null"), @"?column?");
515 assert_snapshot!(name("x is not null"), @"?column?");
516 assert_snapshot!(name("'foo' is normalized"), @"is_normalized");
517 assert_snapshot!(name("'foo' is not normalized"), @"?column?");
518 assert_snapshot!(name("now() at local"), @"timezone");
519 assert_snapshot!(name("now() at time zone 'America/Chicago'"), @"timezone");
521 assert_snapshot!(
522 name("(DATE '2001-02-16', DATE '2001-12-21') OVERLAPS (DATE '2001-10-30', DATE '2002-10-30')"),
523 @"overlaps"
524 );
525 assert_snapshot!(name("(1 * 2)"), @"?column?");
527 assert_snapshot!(name("(select 1 as a)"), @"a");
528
529 assert_snapshot!(name("count(*)"), @"count");
531 assert_snapshot!(name("schema.func_name(1)"), @"func_name");
532
533 assert_snapshot!(name("collation for ('bar')"), @"pg_collation_for");
535 assert_snapshot!(name("extract(year from now())"), @"extract");
536 assert_snapshot!(name("exists(select 1)"), @"exists");
537 assert_snapshot!(name(r#"json_exists('{"a":1}', '$.a')"#), @"json_exists");
538 assert_snapshot!(name("json_array(1, 2)"), @"json_array");
539 assert_snapshot!(name("json_object('a': 1)"), @"json_object");
540 assert_snapshot!(name("json_objectagg('a': 1)"), @"json_objectagg");
541 assert_snapshot!(name("json_arrayagg(1)"), @"json_arrayagg");
542 assert_snapshot!(name(r#"json_query('{"a":1}', '$.a')"#), @"json_query");
543 assert_snapshot!(name("json_scalar(1)"), @"json_scalar");
544 assert_snapshot!(name(r#"json_serialize('{"a":1}')"#), @"json_serialize");
545 assert_snapshot!(name(r#"json_value('{"a":1}', '$.a')"#), @"json_value");
546 assert_snapshot!(name(r#"json('{"a":1}')"#), @"json");
547 assert_snapshot!(name("substring('hello' from 2 for 3)"), @"substring");
548 assert_snapshot!(name("position('a' in 'abc')"), @"position");
549 assert_snapshot!(name("overlay('hello' placing 'X' from 2)"), @"overlay");
550 assert_snapshot!(name("trim(' hi ')"), @"btrim");
551 assert_snapshot!(name("trim(leading ' ' from ' hi ')"), @"ltrim");
552 assert_snapshot!(name("trim(trailing ' ' from ' hi ')"), @"rtrim");
553 assert_snapshot!(name("trim(both ' ' from ' hi ')"), @"btrim");
554 assert_snapshot!(name("xmlroot('<a/>', version '1.0')"), @"xml_root");
555 assert_snapshot!(name("xmlserialize(document '<a/>' as text)"), @"xml_serialize");
556 assert_snapshot!(name("xmlelement(name foo, 'bar')"), @"xml_element");
557 assert_snapshot!(name("xmlforest('bar' as foo)"), @"xml_forest");
558 assert_snapshot!(name("xmlexists('//a' passing '<a/>')"), @"xml_exists");
559 assert_snapshot!(name("xmlparse(document '<a/>')"), @"xml_parse");
560 assert_snapshot!(name("xmlpi(name foo, 'bar')"), @"xml_pi");
561
562 assert_snapshot!(name("foo[bar]"), @"foo");
564 assert_snapshot!(name("foo[1]"), @"foo");
565
566 assert_snapshot!(name("database.schema.table.column"), @"column");
568 assert_snapshot!(name("t.a"), @"a");
569 assert_snapshot!(name("col_name"), @"col_name");
570 assert_snapshot!(name("(c)"), @"c");
571
572 assert_snapshot!(name("case when true then 'foo' end"), @"case");
574 assert_snapshot!(name("case when true then 'foo' else now()::text end"), @"now");
575 assert_snapshot!(name("case when true then 'foo' else 'bar' end"), @"case");
576 assert_snapshot!(name("case when true then 'foo' else '1'::bigint::text end"), @"case");
577
578 assert_snapshot!(name("now()::text"), @"now");
580 assert_snapshot!(name("cast(col_name as text)"), @"col_name");
581 assert_snapshot!(name("col_name::text"), @"col_name");
582 assert_snapshot!(name("col_name::int::text"), @"col_name");
583 assert_snapshot!(name("'1'::bigint"), @"int8");
584 assert_snapshot!(name("'1'::decimal"), @"numeric");
585 assert_snapshot!(name("'1'::boolean"), @"bool");
586 assert_snapshot!(name("'1'::int"), @"int4");
587 assert_snapshot!(name("'1'::smallint"), @"int2");
588 assert_snapshot!(name("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8");
589 assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[][]"), @"int4");
590 assert_snapshot!(name("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2");
591 assert_snapshot!(name("pg_catalog.varchar(100) '{1}'"), @"varchar");
592 assert_snapshot!(name("'{1}'::integer[];"), @"int4");
593 assert_snapshot!(name("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4");
594 assert_snapshot!(name("'1'::bigint::smallint"), @"int2");
595
596 assert_snapshot!(name(r#"'foo' as "FOO""#), @"FOO");
599 assert_snapshot!(name(r#"'foo' as "foo""#), @"foo");
600 assert_snapshot!(name(r#"'foo' as FOO"#), @"foo");
602 assert_snapshot!(name(r#"'foo' as foo"#), @"foo");
603
604 assert_snapshot!(name("(1, 2, 3)"), @"row");
606 assert_snapshot!(name("(1, 2, 3)::address"), @"row");
607
608 assert_snapshot!(name("(x).city"), @"city");
610
611 assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[]"), @"int4");
613 assert_snapshot!(name("cast('{foo}' as text[])"), @"text");
614
615 assert_snapshot!(name("cast('1010' as bit varying(10))"), @"varbit");
617
618 assert_snapshot!(name("cast('hello' as character varying(10))"), @"varchar");
620 assert_snapshot!(name("cast('hello' as char varying(5))"), @"varchar");
621 assert_snapshot!(name("cast('hello' as char(5))"), @"bpchar");
622 assert_snapshot!(name("cast('hello' as character)"), @"bpchar");
623 assert_snapshot!(name("cast('hello' as bpchar)"), @"bpchar");
624
625 assert_snapshot!(name(r#"cast('hello' as "char")"#), @"char");
626
627 assert_snapshot!(name("cast(1.5 as double precision)"), @"float8");
629 assert_snapshot!(name("cast(1.5 as real)"), @"float4");
631
632 assert_snapshot!(name("cast('1 hour' as interval hour to minute)"), @"interval");
634
635 assert_snapshot!(name("cast(foo as schema.%TYPE)"), @"foo");
637
638 assert_snapshot!(name("cast('12:00:00' as time(6) without time zone)"), @"time");
640 assert_snapshot!(name("cast('12:00:00' as time(6) with time zone)"), @"timetz");
641 assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz");
642 assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp");
643
644 #[track_caller]
645 fn name(sql: &str) -> String {
646 let sql = "select ".to_string() + sql;
647 let parse = squawk_syntax::SourceFile::parse(&sql);
648 assert_eq!(parse.errors(), vec![]);
649 let file = parse.tree();
650
651 let stmt = file.stmts().next().unwrap();
652 let ast::Stmt::Select(select) = stmt else {
653 unreachable!()
654 };
655
656 let target = select
657 .select_clause()
658 .and_then(|sc| sc.target_list())
659 .and_then(|tl| tl.targets().next())
660 .unwrap();
661
662 ColumnName::from_target(target)
663 .and_then(|x| x.0.to_string())
664 .unwrap()
665 }
666}