1use squawk_syntax::{
2 SyntaxKind, SyntaxNode,
3 ast::{self, AstNode},
4};
5
6use crate::quote::normalize_identifier;
7
/// The output column name inferred for a select target.
///
/// Distinguishes a definite name from a weak fallback: callers such as the
/// `CASE` handling in `name_from_expr` prefer a definite name over a weak one.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum ColumnName {
    /// A definite name, e.g. a column reference, function name or alias.
    Column(String),
    /// A weak name. `None` renders as `?column?`; `Some` carries a fallback
    /// label, such as one derived from a cast's target type.
    UnknownColumn(Option<String>),
    /// A `*` target, which has no single column label.
    Star,
}
25
26impl ColumnName {
27 pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
29 if let Some(as_name) = target.as_name()
30 && let Some(name_node) = as_name.name()
31 {
32 let text = name_node.text();
33 let normalized = normalize_identifier(&text);
34 return Some((ColumnName::Column(normalized), name_node.syntax().clone()));
35 }
36 Self::inferred_from_target(target)
37 }
38
39 pub(crate) fn inferred_from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
41 if let Some(expr) = target.expr()
42 && let Some(name) = name_from_expr(expr, false)
43 {
44 return Some(name);
45 } else if target.star_token().is_some() {
46 return Some((ColumnName::Star, target.syntax().clone()));
47 }
48 None
49 }
50
51 fn new(name: String, unknown_column: bool) -> ColumnName {
52 if unknown_column {
53 ColumnName::UnknownColumn(Some(name))
54 } else {
55 ColumnName::Column(name)
56 }
57 }
58
59 pub(crate) fn to_string(&self) -> Option<String> {
60 match self {
61 ColumnName::Column(string) => Some(string.to_string()),
62 ColumnName::Star => None,
63 ColumnName::UnknownColumn(c) => {
64 Some(c.clone().unwrap_or_else(|| "?column?".to_string()))
65 }
66 }
67 }
68}
69
70fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> {
71 match ty {
72 ast::Type::PathType(path_type) => {
73 if let Some(name_ref) = path_type
74 .path()
75 .and_then(|x| x.segment())
76 .and_then(|x| x.name_ref())
77 {
78 return name_from_name_ref(name_ref, true).map(|(column, node)| {
79 let column = match column {
80 ColumnName::Column(c) => ColumnName::new(c, unknown_column),
81 _ => column,
82 };
83 (column, node)
84 });
85 }
86 }
87 ast::Type::BitType(bit_type) => {
88 let name = if bit_type.varying_token().is_some() {
89 "varbit"
90 } else {
91 "bit"
92 };
93 return Some((
94 ColumnName::new(name.to_string(), unknown_column),
95 bit_type.syntax().clone(),
96 ));
97 }
98 ast::Type::CharType(char_type) => {
99 let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
100 {
101 "varchar"
102 } else {
103 "bpchar"
104 };
105 return Some((
106 ColumnName::new(name.to_string(), unknown_column),
107 char_type.syntax().clone(),
108 ));
109 }
110 ast::Type::DoubleType(double_type) => {
111 return Some((
112 ColumnName::new("float8".to_string(), unknown_column),
113 double_type.syntax().clone(),
114 ));
115 }
116 ast::Type::IntervalType(interval_type) => {
117 return Some((
118 ColumnName::new("interval".to_string(), unknown_column),
119 interval_type.syntax().clone(),
120 ));
121 }
122 ast::Type::TimeType(time_type) => {
123 let mut name = if time_type.timestamp_token().is_some() {
124 "timestamp".to_owned()
125 } else {
126 "time".to_owned()
127 };
128 if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
129 name.push_str("tz");
132 };
133 return Some((
134 ColumnName::new(name.to_string(), unknown_column),
135 time_type.syntax().clone(),
136 ));
137 }
138 ast::Type::ArrayType(array_type) => {
139 if let Some(inner_ty) = array_type.ty() {
140 return name_from_type(inner_ty, unknown_column);
141 }
142 }
143 ast::Type::PercentType(_) => return None,
146 ast::Type::ExprType(expr_type) => {
147 if let Some(expr) = expr_type.expr() {
148 return name_from_expr(expr, true).map(|(column, node)| {
149 let column = match column {
150 ColumnName::Column(c) => ColumnName::new(c, unknown_column),
151 _ => column,
152 };
153 (column, node)
154 });
155 }
156 }
157 }
158 None
159}
160
161fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
162 if in_type {
163 for node in name_ref.syntax().children_with_tokens() {
164 match node.kind() {
165 SyntaxKind::BIGINT_KW => {
166 return Some((
167 ColumnName::Column("int8".to_owned()),
168 name_ref.syntax().clone(),
169 ));
170 }
171 SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => {
172 return Some((
173 ColumnName::Column("int4".to_owned()),
174 name_ref.syntax().clone(),
175 ));
176 }
177 SyntaxKind::SMALLINT_KW => {
178 return Some((
179 ColumnName::Column("int2".to_owned()),
180 name_ref.syntax().clone(),
181 ));
182 }
183 _ => (),
184 }
185 }
186 }
187 let text = name_ref.text();
188 let normalized = normalize_identifier(&text);
189 return Some((ColumnName::Column(normalized), name_ref.syntax().clone()));
190}
191
192fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
208 let node = expr.syntax().clone();
209 match expr {
210 ast::Expr::ArrayExpr(_) => {
211 return Some((ColumnName::Column("array".to_string()), node));
212 }
213 ast::Expr::BetweenExpr(_) | ast::Expr::BinExpr(_) => {
214 return Some((ColumnName::UnknownColumn(None), node));
215 }
216 ast::Expr::CallExpr(call_expr) => {
217 if let Some(func_name) = call_expr.expr() {
218 match func_name {
219 ast::Expr::ArrayExpr(_)
220 | ast::Expr::BetweenExpr(_)
221 | ast::Expr::ParenExpr(_)
222 | ast::Expr::BinExpr(_)
223 | ast::Expr::CallExpr(_)
224 | ast::Expr::CaseExpr(_)
225 | ast::Expr::CastExpr(_)
226 | ast::Expr::Literal(_)
227 | ast::Expr::PostfixExpr(_)
228 | ast::Expr::PrefixExpr(_)
229 | ast::Expr::TupleExpr(_)
230 | ast::Expr::IndexExpr(_)
231 | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"),
232 ast::Expr::FieldExpr(field_expr) => {
233 if let Some(name_ref) = field_expr.field() {
234 return name_from_name_ref(name_ref, in_type);
235 }
236 }
237 ast::Expr::NameRef(name_ref) => {
238 return name_from_name_ref(name_ref, in_type);
239 }
240 }
241 }
242 }
243 ast::Expr::CaseExpr(case) => {
244 if let Some(else_clause) = case.else_clause()
245 && let Some(expr) = else_clause.expr()
246 && let Some((column, node)) = name_from_expr(expr, in_type)
247 {
248 if !matches!(column, ColumnName::UnknownColumn(_)) {
249 return Some((column, node));
250 }
251 }
252 return Some((ColumnName::Column("case".to_string()), node));
253 }
254 ast::Expr::CastExpr(cast_expr) => {
255 let mut unknown_column = false;
256 if let Some(expr) = cast_expr.expr()
257 && let Some((column, node)) = name_from_expr(expr, in_type)
258 {
259 match column {
260 ColumnName::Column(_) => return Some((column, node)),
261 ColumnName::UnknownColumn(_) => unknown_column = true,
262 ColumnName::Star => (),
263 }
264 }
265 if let Some(ty) = cast_expr.ty() {
266 return name_from_type(ty, unknown_column);
267 }
268 }
269 ast::Expr::FieldExpr(field_expr) => {
270 if let Some(name_ref) = field_expr.field() {
271 return name_from_name_ref(name_ref, in_type);
272 }
273 }
274 ast::Expr::IndexExpr(index_expr) => {
275 if let Some(base) = index_expr.base() {
276 return name_from_expr(base, in_type);
277 }
278 }
279 ast::Expr::SliceExpr(slice_expr) => {
280 if let Some(base) = slice_expr.base() {
281 return name_from_expr(base, in_type);
282 }
283 }
284 ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) | ast::Expr::PostfixExpr(_) => {
285 return Some((ColumnName::UnknownColumn(None), node));
286 }
287 ast::Expr::NameRef(name_ref) => {
288 return name_from_name_ref(name_ref, in_type);
289 }
290 ast::Expr::ParenExpr(paren_expr) => {
291 if let Some(expr) = paren_expr.expr() {
292 return name_from_expr(expr, in_type);
293 } else if let Some(select) = paren_expr.select()
294 && let Some(mut targets) = select
295 .select_clause()
296 .and_then(|x| x.target_list())
297 .map(|x| x.targets())
298 && let Some(target) = targets.next()
299 {
300 return ColumnName::from_target(target);
301 }
302 }
303 ast::Expr::TupleExpr(_) => {
304 return Some((ColumnName::Column("row".to_string()), node));
305 }
306 }
307 None
308}
309
// Snapshot tests for column-name inference: each case parses `select <expr>`
// and snapshots the label inferred for its first target. The inline `@"..."`
// snapshots are maintained by `cargo insta`, so keep one call per case.
#[test]
fn examples() {
    use insta::assert_snapshot;

    // Array constructors, in subquery and bracket form, are named `array`.
    assert_snapshot!(name("array(select 1)"), @"array");
    assert_snapshot!(name("array[1, 2, 3]"), @"array");

    // Operators and literals (also when parenthesized) have no name.
    assert_snapshot!(name("1 between 0 and 10"), @"?column?");
    assert_snapshot!(name("1 + 2"), @"?column?");
    assert_snapshot!(name("42"), @"?column?");
    assert_snapshot!(name("'string'"), @"?column?");
    assert_snapshot!(name("-42"), @"?column?");
    assert_snapshot!(name("|/ 42"), @"?column?");
    assert_snapshot!(name("x is null"), @"?column?");
    assert_snapshot!(name("x is not null"), @"?column?");
    assert_snapshot!(name("(1 * 2)"), @"?column?");
    // A scalar subquery takes the name of its target.
    assert_snapshot!(name("(select 1 as a)"), @"a");

    // Function calls are named after the unqualified function name.
    assert_snapshot!(name("count(*)"), @"count");
    assert_snapshot!(name("schema.func_name(1)"), @"func_name");

    // Subscripts are named after the subscripted expression.
    assert_snapshot!(name("foo[bar]"), @"foo");
    assert_snapshot!(name("foo[1]"), @"foo");

    // Column references use their last component; parentheses are transparent.
    assert_snapshot!(name("database.schema.table.column"), @"column");
    assert_snapshot!(name("t.a"), @"a");
    assert_snapshot!(name("col_name"), @"col_name");
    assert_snapshot!(name("(c)"), @"c");

    // `case` is named `case` unless its `else` branch has a definite name;
    // a name that only comes from casting a literal does not count.
    assert_snapshot!(name("case when true then 'foo' end"), @"case");
    assert_snapshot!(name("case when true then 'foo' else now()::text end"), @"now");
    assert_snapshot!(name("case when true then 'foo' else 'bar' end"), @"case");
    assert_snapshot!(name("case when true then 'foo' else '1'::bigint::text end"), @"case");

    // Casts keep the operand's name. Unnamed operands fall back to the type
    // name, with integer types reported under their internal names and
    // array types under their element type.
    assert_snapshot!(name("now()::text"), @"now");
    assert_snapshot!(name("cast(col_name as text)"), @"col_name");
    assert_snapshot!(name("col_name::text"), @"col_name");
    assert_snapshot!(name("col_name::int::text"), @"col_name");
    assert_snapshot!(name("'1'::bigint"), @"int8");
    assert_snapshot!(name("'1'::int"), @"int4");
    assert_snapshot!(name("'1'::smallint"), @"int2");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[][]"), @"int4");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2");
    assert_snapshot!(name("pg_catalog.varchar(100) '{1}'"), @"varchar");
    assert_snapshot!(name("'{1}'::integer[];"), @"int4");
    assert_snapshot!(name("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4");
    assert_snapshot!(name("'1'::bigint::smallint"), @"int2");

    // Aliases: quoted identifiers keep their case, unquoted ones are lowercased.
    assert_snapshot!(name(r#"'foo' as "FOO""#), @"FOO");
    assert_snapshot!(name(r#"'foo' as "foo""#), @"foo");
    assert_snapshot!(name(r#"'foo' as FOO"#), @"foo");
    assert_snapshot!(name(r#"'foo' as foo"#), @"foo");

    // Row constructors are named `row`, even when cast to a composite type.
    assert_snapshot!(name("(1, 2, 3)"), @"row");
    assert_snapshot!(name("(1, 2, 3)::address"), @"row");

    // Field selection on a parenthesized expression uses the field name.
    assert_snapshot!(name("(x).city"), @"city");

    // Array target types are named after the element type.
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[]"), @"int4");
    assert_snapshot!(name("cast('{foo}' as text[])"), @"text");

    // `bit varying` -> `varbit`.
    assert_snapshot!(name("cast('1010' as bit varying(10))"), @"varbit");

    // Varying character types -> `varchar`; fixed-length `char` -> `bpchar`.
    assert_snapshot!(name("cast('hello' as character varying(10))"), @"varchar");
    assert_snapshot!(name("cast('hello' as char varying(5))"), @"varchar");
    assert_snapshot!(name("cast('hello' as char(5))"), @"bpchar");

    // `double precision` -> `float8`.
    assert_snapshot!(name("cast(1.5 as double precision)"), @"float8");

    // Interval field qualifiers don't affect the name.
    assert_snapshot!(name("cast('1 hour' as interval hour to minute)"), @"interval");

    // With a `%TYPE` target, the operand's name is used.
    assert_snapshot!(name("cast(foo as schema.%TYPE)"), @"foo");

    // `with time zone` selects the `tz` variant of `time`/`timestamp`.
    assert_snapshot!(name("cast('12:00:00' as time(6) without time zone)"), @"time");
    assert_snapshot!(name("cast('12:00:00' as time(6) with time zone)"), @"timetz");
    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz");
    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp");

    // Parses `select <sql>`, requiring a parse without errors, and returns the
    // label inferred for the first target. Panics if the statement is not a
    // select, if there is no target, or if no label can be produced (e.g. `*`).
    #[track_caller]
    fn name(sql: &str) -> String {
        let sql = "select ".to_string() + sql;
        let parse = squawk_syntax::SourceFile::parse(&sql);
        assert_eq!(parse.errors(), vec![]);
        let file = parse.tree();

        let stmt = file.stmts().next().unwrap();
        let ast::Stmt::Select(select) = stmt else {
            unreachable!()
        };

        let target = select
            .select_clause()
            .and_then(|sc| sc.target_list())
            .and_then(|tl| tl.targets().next())
            .unwrap();

        ColumnName::from_target(target)
            .and_then(|x| x.0.to_string())
            .unwrap()
    }
}