1use squawk_syntax::{
2 SyntaxKind, SyntaxNode,
3 ast::{self, AstNode},
4};
5
6use crate::quote::normalize_identifier;
7
/// The output name of a select-list column.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum ColumnName {
    /// A definite name: an alias, a column or function reference, or a fixed
    /// name such as `array` or `row`.
    Column(String),
    /// A fallback name that an enclosing cast (or `case`) may replace, e.g. the
    /// type name produced by a cast. `None` is rendered as `?column?`.
    UnknownColumn(Option<String>),
    /// A `*` target; it has no single name.
    Star,
}
25
26impl ColumnName {
27 pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
29 if let Some(as_name) = target.as_name()
30 && let Some(name_node) = as_name.name()
31 {
32 let text = name_node.text();
33 let normalized = normalize_identifier(&text);
34 return Some((ColumnName::Column(normalized), name_node.syntax().clone()));
35 }
36 Self::inferred_from_target(target)
37 }
38
39 pub(crate) fn inferred_from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
41 if let Some(expr) = target.expr()
42 && let Some(name) = name_from_expr(expr, false)
43 {
44 return Some(name);
45 } else if target.star_token().is_some() {
46 return Some((ColumnName::Star, target.syntax().clone()));
47 }
48 None
49 }
50
51 fn new(name: String, unknown_column: bool) -> ColumnName {
52 if unknown_column {
53 ColumnName::UnknownColumn(Some(name))
54 } else {
55 ColumnName::Column(name)
56 }
57 }
58
59 pub(crate) fn to_string(&self) -> Option<String> {
60 match self {
61 ColumnName::Column(string) => Some(string.to_string()),
62 ColumnName::Star => None,
63 ColumnName::UnknownColumn(c) => {
64 Some(c.clone().unwrap_or_else(|| "?column?".to_string()))
65 }
66 }
67 }
68}
69
70fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> {
71 match ty {
72 ast::Type::PathType(path_type) => {
73 if let Some(name_ref) = path_type
74 .path()
75 .and_then(|x| x.segment())
76 .and_then(|x| x.name_ref())
77 {
78 return name_from_name_ref(name_ref, true).map(|(column, node)| {
79 let column = match column {
80 ColumnName::Column(c) => ColumnName::new(c, unknown_column),
81 _ => column,
82 };
83 (column, node)
84 });
85 }
86 }
87 ast::Type::BitType(bit_type) => {
88 let name = if bit_type.varying_token().is_some() {
89 "varbit"
90 } else {
91 "bit"
92 };
93 return Some((
94 ColumnName::new(name.to_string(), unknown_column),
95 bit_type.syntax().clone(),
96 ));
97 }
98 ast::Type::CharType(char_type) => {
99 let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
100 {
101 "varchar"
102 } else {
103 "bpchar"
104 };
105 return Some((
106 ColumnName::new(name.to_string(), unknown_column),
107 char_type.syntax().clone(),
108 ));
109 }
110 ast::Type::DoubleType(double_type) => {
111 return Some((
112 ColumnName::new("float8".to_string(), unknown_column),
113 double_type.syntax().clone(),
114 ));
115 }
116 ast::Type::IntervalType(interval_type) => {
117 return Some((
118 ColumnName::new("interval".to_string(), unknown_column),
119 interval_type.syntax().clone(),
120 ));
121 }
122 ast::Type::TimeType(time_type) => {
123 let mut name = if time_type.timestamp_token().is_some() {
124 "timestamp".to_owned()
125 } else {
126 "time".to_owned()
127 };
128 if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
129 name.push_str("tz");
132 };
133 return Some((
134 ColumnName::new(name.to_string(), unknown_column),
135 time_type.syntax().clone(),
136 ));
137 }
138 ast::Type::ArrayType(array_type) => {
139 if let Some(inner_ty) = array_type.ty() {
140 return name_from_type(inner_ty, unknown_column);
141 }
142 }
143 ast::Type::PercentType(_) => return None,
146 ast::Type::ExprType(expr_type) => {
147 if let Some(expr) = expr_type.expr() {
148 return name_from_expr(expr, true).map(|(column, node)| {
149 let column = match column {
150 ColumnName::Column(c) => ColumnName::new(c, unknown_column),
151 _ => column,
152 };
153 (column, node)
154 });
155 }
156 }
157 }
158 None
159}
160
161fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
162 if in_type {
163 for node in name_ref.syntax().children_with_tokens() {
164 match node.kind() {
165 SyntaxKind::BIGINT_KW => {
166 return Some((
167 ColumnName::Column("int8".to_owned()),
168 name_ref.syntax().clone(),
169 ));
170 }
171 SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => {
172 return Some((
173 ColumnName::Column("int4".to_owned()),
174 name_ref.syntax().clone(),
175 ));
176 }
177 SyntaxKind::SMALLINT_KW => {
178 return Some((
179 ColumnName::Column("int2".to_owned()),
180 name_ref.syntax().clone(),
181 ));
182 }
183 SyntaxKind::REAL_KW => {
184 return Some((
185 ColumnName::Column("float4".to_owned()),
186 name_ref.syntax().clone(),
187 ));
188 }
189 _ => (),
190 }
191 }
192 }
193 let text = name_ref.text();
194 let normalized = normalize_identifier(&text);
195 return Some((ColumnName::Column(normalized), name_ref.syntax().clone()));
196}
197
198fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
214 let node = expr.syntax().clone();
215 match expr {
216 ast::Expr::ArrayExpr(_) => {
217 return Some((ColumnName::Column("array".to_string()), node));
218 }
219 ast::Expr::BetweenExpr(_) | ast::Expr::BinExpr(_) => {
220 return Some((ColumnName::UnknownColumn(None), node));
221 }
222 ast::Expr::CallExpr(call_expr) => {
223 if let Some(exists_fn) = call_expr.exists_fn() {
224 return Some((
225 ColumnName::Column("exists".to_string()),
226 exists_fn.syntax().clone(),
227 ));
228 }
229 if let Some(extract_fn) = call_expr.extract_fn() {
230 return Some((
231 ColumnName::Column("extract".to_string()),
232 extract_fn.syntax().clone(),
233 ));
234 }
235 if let Some(json_exists_fn) = call_expr.json_exists_fn() {
236 return Some((
237 ColumnName::Column("json_exists".to_string()),
238 json_exists_fn.syntax().clone(),
239 ));
240 }
241 if let Some(json_array_fn) = call_expr.json_array_fn() {
242 return Some((
243 ColumnName::Column("json_array".to_string()),
244 json_array_fn.syntax().clone(),
245 ));
246 }
247 if let Some(json_object_fn) = call_expr.json_object_fn() {
248 return Some((
249 ColumnName::Column("json_object".to_string()),
250 json_object_fn.syntax().clone(),
251 ));
252 }
253 if let Some(json_object_agg_fn) = call_expr.json_object_agg_fn() {
254 return Some((
255 ColumnName::Column("json_objectagg".to_string()),
256 json_object_agg_fn.syntax().clone(),
257 ));
258 }
259 if let Some(json_array_agg_fn) = call_expr.json_array_agg_fn() {
260 return Some((
261 ColumnName::Column("json_arrayagg".to_string()),
262 json_array_agg_fn.syntax().clone(),
263 ));
264 }
265 if let Some(json_query_fn) = call_expr.json_query_fn() {
266 return Some((
267 ColumnName::Column("json_query".to_string()),
268 json_query_fn.syntax().clone(),
269 ));
270 }
271 if let Some(json_scalar_fn) = call_expr.json_scalar_fn() {
272 return Some((
273 ColumnName::Column("json_scalar".to_string()),
274 json_scalar_fn.syntax().clone(),
275 ));
276 }
277 if let Some(json_serialize_fn) = call_expr.json_serialize_fn() {
278 return Some((
279 ColumnName::Column("json_serialize".to_string()),
280 json_serialize_fn.syntax().clone(),
281 ));
282 }
283 if let Some(json_value_fn) = call_expr.json_value_fn() {
284 return Some((
285 ColumnName::Column("json_value".to_string()),
286 json_value_fn.syntax().clone(),
287 ));
288 }
289 if let Some(json_fn) = call_expr.json_fn() {
290 return Some((
291 ColumnName::Column("json".to_string()),
292 json_fn.syntax().clone(),
293 ));
294 }
295 if let Some(substring_fn) = call_expr.substring_fn() {
296 return Some((
297 ColumnName::Column("substring".to_string()),
298 substring_fn.syntax().clone(),
299 ));
300 }
301 if let Some(position_fn) = call_expr.position_fn() {
302 return Some((
303 ColumnName::Column("position".to_string()),
304 position_fn.syntax().clone(),
305 ));
306 }
307 if let Some(overlay_fn) = call_expr.overlay_fn() {
308 return Some((
309 ColumnName::Column("overlay".to_string()),
310 overlay_fn.syntax().clone(),
311 ));
312 }
313 if let Some(trim_fn) = call_expr.trim_fn() {
314 return Some((
315 ColumnName::Column("trim".to_string()),
316 trim_fn.syntax().clone(),
317 ));
318 }
319 if let Some(xml_root_fn) = call_expr.xml_root_fn() {
320 return Some((
321 ColumnName::Column("xml_root".to_string()),
322 xml_root_fn.syntax().clone(),
323 ));
324 }
325 if let Some(xml_serialize_fn) = call_expr.xml_serialize_fn() {
326 return Some((
327 ColumnName::Column("xml_serialize".to_string()),
328 xml_serialize_fn.syntax().clone(),
329 ));
330 }
331 if let Some(xml_element_fn) = call_expr.xml_element_fn() {
332 return Some((
333 ColumnName::Column("xml_element".to_string()),
334 xml_element_fn.syntax().clone(),
335 ));
336 }
337 if let Some(xml_forest_fn) = call_expr.xml_forest_fn() {
338 return Some((
339 ColumnName::Column("xml_forest".to_string()),
340 xml_forest_fn.syntax().clone(),
341 ));
342 }
343 if let Some(xml_exists_fn) = call_expr.xml_exists_fn() {
344 return Some((
345 ColumnName::Column("xml_exists".to_string()),
346 xml_exists_fn.syntax().clone(),
347 ));
348 }
349 if let Some(xml_parse_fn) = call_expr.xml_parse_fn() {
350 return Some((
351 ColumnName::Column("xml_parse".to_string()),
352 xml_parse_fn.syntax().clone(),
353 ));
354 }
355 if let Some(xml_pi_fn) = call_expr.xml_pi_fn() {
356 return Some((
357 ColumnName::Column("xml_pi".to_string()),
358 xml_pi_fn.syntax().clone(),
359 ));
360 }
361 if let Some(func_name) = call_expr.expr() {
362 match func_name {
363 ast::Expr::ArrayExpr(_)
364 | ast::Expr::BetweenExpr(_)
365 | ast::Expr::ParenExpr(_)
366 | ast::Expr::BinExpr(_)
367 | ast::Expr::CallExpr(_)
368 | ast::Expr::CaseExpr(_)
369 | ast::Expr::CastExpr(_)
370 | ast::Expr::Literal(_)
371 | ast::Expr::PostfixExpr(_)
372 | ast::Expr::PrefixExpr(_)
373 | ast::Expr::TupleExpr(_)
374 | ast::Expr::IndexExpr(_)
375 | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"),
376 ast::Expr::FieldExpr(field_expr) => {
377 if let Some(name_ref) = field_expr.field() {
378 return name_from_name_ref(name_ref, in_type);
379 }
380 }
381 ast::Expr::NameRef(name_ref) => {
382 return name_from_name_ref(name_ref, in_type);
383 }
384 }
385 }
386 }
387 ast::Expr::CaseExpr(case) => {
388 if let Some(else_clause) = case.else_clause()
389 && let Some(expr) = else_clause.expr()
390 && let Some((column, node)) = name_from_expr(expr, in_type)
391 {
392 if !matches!(column, ColumnName::UnknownColumn(_)) {
393 return Some((column, node));
394 }
395 }
396 return Some((ColumnName::Column("case".to_string()), node));
397 }
398 ast::Expr::CastExpr(cast_expr) => {
399 let mut unknown_column = false;
400 if let Some(expr) = cast_expr.expr()
401 && let Some((column, node)) = name_from_expr(expr, in_type)
402 {
403 match column {
404 ColumnName::Column(_) => return Some((column, node)),
405 ColumnName::UnknownColumn(_) => unknown_column = true,
406 ColumnName::Star => (),
407 }
408 }
409 if let Some(ty) = cast_expr.ty() {
410 return name_from_type(ty, unknown_column);
411 }
412 }
413 ast::Expr::FieldExpr(field_expr) => {
414 if let Some(name_ref) = field_expr.field() {
415 return name_from_name_ref(name_ref, in_type);
416 }
417 }
418 ast::Expr::IndexExpr(index_expr) => {
419 if let Some(base) = index_expr.base() {
420 return name_from_expr(base, in_type);
421 }
422 }
423 ast::Expr::SliceExpr(slice_expr) => {
424 if let Some(base) = slice_expr.base() {
425 return name_from_expr(base, in_type);
426 }
427 }
428 ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) | ast::Expr::PostfixExpr(_) => {
429 return Some((ColumnName::UnknownColumn(None), node));
430 }
431 ast::Expr::NameRef(name_ref) => {
432 return name_from_name_ref(name_ref, in_type);
433 }
434 ast::Expr::ParenExpr(paren_expr) => {
435 if let Some(expr) = paren_expr.expr() {
436 return name_from_expr(expr, in_type);
437 } else if let Some(select) = paren_expr.select()
438 && let Some(mut targets) = select
439 .select_clause()
440 .and_then(|x| x.target_list())
441 .map(|x| x.targets())
442 && let Some(target) = targets.next()
443 {
444 return ColumnName::from_target(target);
445 }
446 }
447 ast::Expr::TupleExpr(_) => {
448 return Some((ColumnName::Column("row".to_string()), node));
449 }
450 }
451 None
452}
453
454#[test]
455fn examples() {
456 use insta::assert_snapshot;
457
458 assert_snapshot!(name("array(select 1)"), @"array");
460 assert_snapshot!(name("array[1, 2, 3]"), @"array");
461
462 assert_snapshot!(name("1 between 0 and 10"), @"?column?");
464 assert_snapshot!(name("1 + 2"), @"?column?");
465 assert_snapshot!(name("42"), @"?column?");
466 assert_snapshot!(name("'string'"), @"?column?");
467 assert_snapshot!(name("-42"), @"?column?");
469 assert_snapshot!(name("|/ 42"), @"?column?");
470 assert_snapshot!(name("x is null"), @"?column?");
472 assert_snapshot!(name("x is not null"), @"?column?");
473 assert_snapshot!(name("(1 * 2)"), @"?column?");
475 assert_snapshot!(name("(select 1 as a)"), @"a");
476
477 assert_snapshot!(name("count(*)"), @"count");
479 assert_snapshot!(name("schema.func_name(1)"), @"func_name");
480
481 assert_snapshot!(name("extract(year from now())"), @"extract");
483 assert_snapshot!(name("exists(select 1)"), @"exists");
484 assert_snapshot!(name(r#"json_exists('{"a":1}', '$.a')"#), @"json_exists");
485 assert_snapshot!(name("json_array(1, 2)"), @"json_array");
486 assert_snapshot!(name("json_object('a': 1)"), @"json_object");
487 assert_snapshot!(name("json_objectagg('a': 1)"), @"json_objectagg");
488 assert_snapshot!(name("json_arrayagg(1)"), @"json_arrayagg");
489 assert_snapshot!(name(r#"json_query('{"a":1}', '$.a')"#), @"json_query");
490 assert_snapshot!(name("json_scalar(1)"), @"json_scalar");
491 assert_snapshot!(name(r#"json_serialize('{"a":1}')"#), @"json_serialize");
492 assert_snapshot!(name(r#"json_value('{"a":1}', '$.a')"#), @"json_value");
493 assert_snapshot!(name(r#"json('{"a":1}')"#), @"json");
494 assert_snapshot!(name("substring('hello' from 2 for 3)"), @"substring");
495 assert_snapshot!(name("position('a' in 'abc')"), @"position");
496 assert_snapshot!(name("overlay('hello' placing 'X' from 2)"), @"overlay");
497 assert_snapshot!(name("trim(' hi ')"), @"trim");
498 assert_snapshot!(name("xmlroot('<a/>', version '1.0')"), @"xml_root");
499 assert_snapshot!(name("xmlserialize(document '<a/>' as text)"), @"xml_serialize");
500 assert_snapshot!(name("xmlelement(name foo, 'bar')"), @"xml_element");
501 assert_snapshot!(name("xmlforest('bar' as foo)"), @"xml_forest");
502 assert_snapshot!(name("xmlexists('//a' passing '<a/>')"), @"xml_exists");
503 assert_snapshot!(name("xmlparse(document '<a/>')"), @"xml_parse");
504 assert_snapshot!(name("xmlpi(name foo, 'bar')"), @"xml_pi");
505
506 assert_snapshot!(name("foo[bar]"), @"foo");
508 assert_snapshot!(name("foo[1]"), @"foo");
509
510 assert_snapshot!(name("database.schema.table.column"), @"column");
512 assert_snapshot!(name("t.a"), @"a");
513 assert_snapshot!(name("col_name"), @"col_name");
514 assert_snapshot!(name("(c)"), @"c");
515
516 assert_snapshot!(name("case when true then 'foo' end"), @"case");
518 assert_snapshot!(name("case when true then 'foo' else now()::text end"), @"now");
519 assert_snapshot!(name("case when true then 'foo' else 'bar' end"), @"case");
520 assert_snapshot!(name("case when true then 'foo' else '1'::bigint::text end"), @"case");
521
522 assert_snapshot!(name("now()::text"), @"now");
524 assert_snapshot!(name("cast(col_name as text)"), @"col_name");
525 assert_snapshot!(name("col_name::text"), @"col_name");
526 assert_snapshot!(name("col_name::int::text"), @"col_name");
527 assert_snapshot!(name("'1'::bigint"), @"int8");
528 assert_snapshot!(name("'1'::int"), @"int4");
529 assert_snapshot!(name("'1'::smallint"), @"int2");
530 assert_snapshot!(name("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8");
531 assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[][]"), @"int4");
532 assert_snapshot!(name("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2");
533 assert_snapshot!(name("pg_catalog.varchar(100) '{1}'"), @"varchar");
534 assert_snapshot!(name("'{1}'::integer[];"), @"int4");
535 assert_snapshot!(name("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4");
536 assert_snapshot!(name("'1'::bigint::smallint"), @"int2");
537
538 assert_snapshot!(name(r#"'foo' as "FOO""#), @"FOO");
541 assert_snapshot!(name(r#"'foo' as "foo""#), @"foo");
542 assert_snapshot!(name(r#"'foo' as FOO"#), @"foo");
544 assert_snapshot!(name(r#"'foo' as foo"#), @"foo");
545
546 assert_snapshot!(name("(1, 2, 3)"), @"row");
548 assert_snapshot!(name("(1, 2, 3)::address"), @"row");
549
550 assert_snapshot!(name("(x).city"), @"city");
552
553 assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[]"), @"int4");
555 assert_snapshot!(name("cast('{foo}' as text[])"), @"text");
556
557 assert_snapshot!(name("cast('1010' as bit varying(10))"), @"varbit");
559
560 assert_snapshot!(name("cast('hello' as character varying(10))"), @"varchar");
562 assert_snapshot!(name("cast('hello' as char varying(5))"), @"varchar");
563 assert_snapshot!(name("cast('hello' as char(5))"), @"bpchar");
564 assert_snapshot!(name("cast('hello' as character)"), @"bpchar");
565 assert_snapshot!(name("cast('hello' as bpchar)"), @"bpchar");
566
567 assert_snapshot!(name(r#"cast('hello' as "char")"#), @"char");
568
569 assert_snapshot!(name("cast(1.5 as double precision)"), @"float8");
571 assert_snapshot!(name("cast(1.5 as real)"), @"float4");
573
574 assert_snapshot!(name("cast('1 hour' as interval hour to minute)"), @"interval");
576
577 assert_snapshot!(name("cast(foo as schema.%TYPE)"), @"foo");
579
580 assert_snapshot!(name("cast('12:00:00' as time(6) without time zone)"), @"time");
582 assert_snapshot!(name("cast('12:00:00' as time(6) with time zone)"), @"timetz");
583 assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz");
584 assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp");
585
586 #[track_caller]
587 fn name(sql: &str) -> String {
588 let sql = "select ".to_string() + sql;
589 let parse = squawk_syntax::SourceFile::parse(&sql);
590 assert_eq!(parse.errors(), vec![]);
591 let file = parse.tree();
592
593 let stmt = file.stmts().next().unwrap();
594 let ast::Stmt::Select(select) = stmt else {
595 unreachable!()
596 };
597
598 let target = select
599 .select_clause()
600 .and_then(|sc| sc.target_list())
601 .and_then(|tl| tl.targets().next())
602 .unwrap();
603
604 ColumnName::from_target(target)
605 .and_then(|x| x.0.to_string())
606 .unwrap()
607 }
608}