1use squawk_syntax::{
2 SyntaxKind, SyntaxNode,
3 ast::{self, AstNode},
4};
5
/// Folds an SQL identifier into the form used for output column names.
///
/// A double-quoted identifier keeps its case: the surrounding quotes are
/// removed and every doubled quote (`""`) inside it is unescaped to a single
/// `"`. Anything else is treated as an unquoted identifier and lowercased.
///
/// Text that merely starts with a quote but is too short to be a quoted
/// identifier (e.g. a lone `"`) is handled as unquoted instead of panicking.
fn normalize_identifier(text: &str) -> String {
    match text
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
    {
        // Inside a quoted identifier, `""` is the escape for a literal `"`.
        Some(inner) => inner.replace("\"\"", "\""),
        None => text.to_lowercase(),
    }
}
13
/// The inferred name of one output column of a `SELECT` target list.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum ColumnName {
    /// A definite name, e.g. from an alias, a column reference or a function
    /// call.
    Column(String),
    /// A weak name, or none at all.
    ///
    /// `None` means the expression has no natural name (rendered as
    /// `?column?`). `Some(name)` carries a fallback, such as the type name of
    /// a cast, that an enclosing cast or `CASE` may still replace.
    UnknownColumn(Option<String>),
    /// A bare `*` target.
    Star,
}
31
32impl ColumnName {
33 #[allow(dead_code)]
34 pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
35 if let Some(as_name) = target.as_name()
36 && let Some(name_node) = as_name.name()
37 {
38 let text = name_node.text();
39 let normalized = normalize_identifier(&text);
40 return Some((ColumnName::Column(normalized), name_node.syntax().clone()));
41 } else if let Some(expr) = target.expr()
42 && let Some(name) = name_from_expr(expr, false)
43 {
44 return Some(name);
45 } else if target.star_token().is_some() {
46 return Some((ColumnName::Star, target.syntax().clone()));
47 }
48 None
49 }
50
51 fn new(name: String, unknown_column: bool) -> ColumnName {
52 if unknown_column {
53 ColumnName::UnknownColumn(Some(name))
54 } else {
55 ColumnName::Column(name)
56 }
57 }
58}
59
60fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> {
61 match ty {
62 ast::Type::PathType(path_type) => {
63 if let Some(name_ref) = path_type
64 .path()
65 .and_then(|x| x.segment())
66 .and_then(|x| x.name_ref())
67 {
68 return name_from_name_ref(name_ref, true).map(|(column, node)| {
69 let column = match column {
70 ColumnName::Column(c) => ColumnName::new(c, unknown_column),
71 _ => column,
72 };
73 (column, node)
74 });
75 }
76 }
77 ast::Type::BitType(bit_type) => {
78 let name = if bit_type.varying_token().is_some() {
79 "varbit"
80 } else {
81 "bit"
82 };
83 return Some((
84 ColumnName::new(name.to_string(), unknown_column),
85 bit_type.syntax().clone(),
86 ));
87 }
88 ast::Type::CharType(char_type) => {
89 let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
90 {
91 "varchar"
92 } else {
93 "bpchar"
94 };
95 return Some((
96 ColumnName::new(name.to_string(), unknown_column),
97 char_type.syntax().clone(),
98 ));
99 }
100 ast::Type::DoubleType(double_type) => {
101 return Some((
102 ColumnName::new("float8".to_string(), unknown_column),
103 double_type.syntax().clone(),
104 ));
105 }
106 ast::Type::IntervalType(interval_type) => {
107 return Some((
108 ColumnName::new("interval".to_string(), unknown_column),
109 interval_type.syntax().clone(),
110 ));
111 }
112 ast::Type::TimeType(time_type) => {
113 let mut name = if time_type.timestamp_token().is_some() {
114 "timestamp".to_owned()
115 } else {
116 "time".to_owned()
117 };
118 if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
119 name.push_str("tz");
122 };
123 return Some((
124 ColumnName::new(name.to_string(), unknown_column),
125 time_type.syntax().clone(),
126 ));
127 }
128 ast::Type::ArrayType(array_type) => {
129 if let Some(inner_ty) = array_type.ty() {
130 return name_from_type(inner_ty, unknown_column);
131 }
132 }
133 ast::Type::PercentType(_) => return None,
136 ast::Type::ExprType(expr_type) => {
137 if let Some(expr) = expr_type.expr() {
138 return name_from_expr(expr, true).map(|(column, node)| {
139 let column = match column {
140 ColumnName::Column(c) => ColumnName::new(c, unknown_column),
141 _ => column,
142 };
143 (column, node)
144 });
145 }
146 }
147 }
148 None
149}
150
151fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
152 if in_type {
153 for node in name_ref.syntax().children_with_tokens() {
154 match node.kind() {
155 SyntaxKind::BIGINT_KW => {
156 return Some((
157 ColumnName::Column("int8".to_owned()),
158 name_ref.syntax().clone(),
159 ));
160 }
161 SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => {
162 return Some((
163 ColumnName::Column("int4".to_owned()),
164 name_ref.syntax().clone(),
165 ));
166 }
167 SyntaxKind::SMALLINT_KW => {
168 return Some((
169 ColumnName::Column("int2".to_owned()),
170 name_ref.syntax().clone(),
171 ));
172 }
173 _ => (),
174 }
175 }
176 }
177 return Some((
178 ColumnName::Column(name_ref.text().to_string()),
179 name_ref.syntax().clone(),
180 ));
181}
182
183fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
199 let node = expr.syntax().clone();
200 match expr {
201 ast::Expr::ArrayExpr(_) => {
202 return Some((ColumnName::Column("array".to_string()), node));
203 }
204 ast::Expr::BetweenExpr(_) | ast::Expr::BinExpr(_) => {
205 return Some((ColumnName::UnknownColumn(None), node));
206 }
207 ast::Expr::CallExpr(call_expr) => {
208 if let Some(func_name) = call_expr.expr() {
209 match func_name {
210 ast::Expr::ArrayExpr(_)
211 | ast::Expr::BetweenExpr(_)
212 | ast::Expr::ParenExpr(_)
213 | ast::Expr::BinExpr(_)
214 | ast::Expr::CallExpr(_)
215 | ast::Expr::CaseExpr(_)
216 | ast::Expr::CastExpr(_)
217 | ast::Expr::Literal(_)
218 | ast::Expr::PostfixExpr(_)
219 | ast::Expr::PrefixExpr(_)
220 | ast::Expr::TupleExpr(_)
221 | ast::Expr::IndexExpr(_)
222 | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"),
223 ast::Expr::FieldExpr(field_expr) => {
224 if let Some(name_ref) = field_expr.field() {
225 return name_from_name_ref(name_ref, in_type);
226 }
227 }
228 ast::Expr::NameRef(name_ref) => {
229 return name_from_name_ref(name_ref, in_type);
230 }
231 }
232 }
233 }
234 ast::Expr::CaseExpr(case) => {
235 if let Some(else_clause) = case.else_clause()
236 && let Some(expr) = else_clause.expr()
237 && let Some((column, node)) = name_from_expr(expr, in_type)
238 {
239 if !matches!(column, ColumnName::UnknownColumn(_)) {
240 return Some((column, node));
241 }
242 }
243 return Some((ColumnName::Column("case".to_string()), node));
244 }
245 ast::Expr::CastExpr(cast_expr) => {
246 let mut unknown_column = false;
247 if let Some(expr) = cast_expr.expr()
248 && let Some((column, node)) = name_from_expr(expr, in_type)
249 {
250 match column {
251 ColumnName::Column(_) => return Some((column, node)),
252 ColumnName::UnknownColumn(_) => unknown_column = true,
253 ColumnName::Star => (),
254 }
255 }
256 if let Some(ty) = cast_expr.ty() {
257 return name_from_type(ty, unknown_column);
258 }
259 }
260 ast::Expr::FieldExpr(field_expr) => {
261 if let Some(name_ref) = field_expr.field() {
262 return name_from_name_ref(name_ref, in_type);
263 }
264 }
265 ast::Expr::IndexExpr(index_expr) => {
266 if let Some(base) = index_expr.base() {
267 return name_from_expr(base, in_type);
268 }
269 }
270 ast::Expr::SliceExpr(slice_expr) => {
271 if let Some(base) = slice_expr.base() {
272 return name_from_expr(base, in_type);
273 }
274 }
275 ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) | ast::Expr::PostfixExpr(_) => {
276 return Some((ColumnName::UnknownColumn(None), node));
277 }
278 ast::Expr::NameRef(name_ref) => {
279 return name_from_name_ref(name_ref, in_type);
280 }
281 ast::Expr::ParenExpr(paren_expr) => {
282 if let Some(expr) = paren_expr.expr() {
283 return name_from_expr(expr, in_type);
284 } else if let Some(select) = paren_expr.select()
285 && let Some(mut targets) = select
286 .select_clause()
287 .and_then(|x| x.target_list())
288 .map(|x| x.targets())
289 && let Some(target) = targets.next()
290 {
291 return ColumnName::from_target(target);
292 }
293 }
294 ast::Expr::TupleExpr(_) => {
295 return Some((ColumnName::Column("row".to_string()), node));
296 }
297 }
298 None
299}
300
/// Snapshot examples of the column name inferred for the first target of
/// `select <expr>`. Weak names without a fallback render as `?column?`.
#[test]
fn examples() {
    use insta::assert_snapshot;

    // Array constructors and array subqueries.
    assert_snapshot!(name("array(select 1)"), @"array");
    assert_snapshot!(name("array[1, 2, 3]"), @"array");

    // Operators, literals and parenthesized operators have no natural name.
    assert_snapshot!(name("1 between 0 and 10"), @"?column?");
    assert_snapshot!(name("1 + 2"), @"?column?");
    assert_snapshot!(name("42"), @"?column?");
    assert_snapshot!(name("'string'"), @"?column?");
    assert_snapshot!(name("-42"), @"?column?");
    assert_snapshot!(name("|/ 42"), @"?column?");
    assert_snapshot!(name("x is null"), @"?column?");
    assert_snapshot!(name("x is not null"), @"?column?");
    assert_snapshot!(name("(1 * 2)"), @"?column?");
    // A parenthesized subquery takes the name of its first target.
    assert_snapshot!(name("(select 1 as a)"), @"a");

    // Function calls are named after the (unqualified) function name.
    assert_snapshot!(name("count(*)"), @"count");
    assert_snapshot!(name("schema.func_name(1)"), @"func_name");

    // Subscripts keep the name of the subscripted expression.
    assert_snapshot!(name("foo[bar]"), @"foo");
    assert_snapshot!(name("foo[1]"), @"foo");

    // Column references use their last segment; parentheses are transparent.
    assert_snapshot!(name("database.schema.table.column"), @"column");
    assert_snapshot!(name("t.a"), @"a");
    assert_snapshot!(name("col_name"), @"col_name");
    assert_snapshot!(name("(c)"), @"c");

    // CASE takes a definite name from its ELSE branch, else falls back to `case`.
    assert_snapshot!(name("case when true then 'foo' end"), @"case");
    assert_snapshot!(name("case when true then 'foo' else now()::text end"), @"now");
    assert_snapshot!(name("case when true then 'foo' else 'bar' end"), @"case");
    assert_snapshot!(name("case when true then 'foo' else '1'::bigint::text end"), @"case");

    // Casts keep a definite name from the casted expression; otherwise the
    // (canonical) type name is used.
    assert_snapshot!(name("now()::text"), @"now");
    assert_snapshot!(name("cast(col_name as text)"), @"col_name");
    assert_snapshot!(name("col_name::text"), @"col_name");
    assert_snapshot!(name("col_name::int::text"), @"col_name");
    assert_snapshot!(name("'1'::bigint"), @"int8");
    assert_snapshot!(name("'1'::int"), @"int4");
    assert_snapshot!(name("'1'::smallint"), @"int2");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[][]"), @"int4");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2");
    assert_snapshot!(name("pg_catalog.varchar(100) '{1}'"), @"varchar");
    assert_snapshot!(name("'{1}'::integer[];"), @"int4");
    assert_snapshot!(name("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4");
    assert_snapshot!(name("'1'::bigint::smallint"), @"int2");

    // Aliases: quoted keeps case, unquoted is lowercased.
    assert_snapshot!(name(r#"'foo' as "FOO""#), @"FOO");
    assert_snapshot!(name(r#"'foo' as "foo""#), @"foo");
    assert_snapshot!(name(r#"'foo' as FOO"#), @"foo");
    assert_snapshot!(name(r#"'foo' as foo"#), @"foo");

    // Row constructors are named `row`, even when cast.
    assert_snapshot!(name("(1, 2, 3)"), @"row");
    assert_snapshot!(name("(1, 2, 3)::address"), @"row");

    // Field selection uses the field name.
    assert_snapshot!(name("(x).city"), @"city");

    // Array casts are named after the element type.
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[]"), @"int4");
    assert_snapshot!(name("cast('{foo}' as text[])"), @"text");

    // Keyword-spelled built-in types map to canonical names.
    assert_snapshot!(name("cast('1010' as bit varying(10))"), @"varbit");

    assert_snapshot!(name("cast('hello' as character varying(10))"), @"varchar");
    assert_snapshot!(name("cast('hello' as char varying(5))"), @"varchar");
    assert_snapshot!(name("cast('hello' as char(5))"), @"bpchar");

    assert_snapshot!(name("cast(1.5 as double precision)"), @"float8");

    assert_snapshot!(name("cast('1 hour' as interval hour to minute)"), @"interval");

    // `%TYPE` gives no name, so the casted column's name is used.
    assert_snapshot!(name("cast(foo as schema.%TYPE)"), @"foo");

    assert_snapshot!(name("cast('12:00:00' as time(6) without time zone)"), @"time");
    assert_snapshot!(name("cast('12:00:00' as time(6) with time zone)"), @"timetz");
    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz");
    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp");

    /// Parses `select <sql>` (asserting there are no parse errors) and renders
    /// the column name inferred for its first target.
    ///
    /// `UnknownColumn(None)` renders as `?column?`; a `*` target is not
    /// expected here.
    #[track_caller]
    fn name(sql: &str) -> String {
        let sql = "select ".to_string() + sql;
        let parse = squawk_syntax::SourceFile::parse(&sql);
        assert_eq!(parse.errors(), vec![]);
        let file = parse.tree();

        let stmt = file.stmts().next().unwrap();
        let ast::Stmt::Select(select) = stmt else {
            unreachable!()
        };

        let target = select
            .select_clause()
            .and_then(|sc| sc.target_list())
            .and_then(|tl| tl.targets().next())
            .unwrap();

        ColumnName::from_target(target)
            .map(|x| match x.0 {
                ColumnName::Column(string) => string,
                ColumnName::Star => unreachable!(),
                // A weak name renders as its fallback when it has one.
                ColumnName::UnknownColumn(c) => c.unwrap_or_else(|| "?column?".to_string()),
            })
            .unwrap()
    }
}