squawk_ide/
column_name.rs

1use squawk_syntax::{
2    SyntaxKind, SyntaxNode,
3    ast::{self, AstNode},
4};
5
6use crate::quote::normalize_identifier;
7
/// The inferred output column name of a select target.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum ColumnName {
    /// A name taken from an alias or inferred from the target's expression
    /// (a column reference, a function name, ...).
    Column(String),
    /// There's a fallback mechanism that we need to propagate through the
    /// expressions/types.
    ///
    /// We can see this with:
    /// ```sql
    /// select case when true then 'a' else now()::text end;
    /// -- column named `now`, propagating the function name
    /// -- vs
    /// select case when true then 'a' else 'b' end;
    /// -- column named `case`
    /// ```
    ///
    /// `Some(name)` carries a fallback name (e.g. the type of a cast);
    /// `None` means there is no name at all, displayed as `?column?`.
    UnknownColumn(Option<String>),
    /// A `*` target, which has no single column name.
    Star,
}
25
26impl ColumnName {
27    // Get the alias, otherwise infer the column name.
28    pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
29        if let Some(as_name) = target.as_name()
30            && let Some(name_node) = as_name.name()
31        {
32            let text = name_node.text();
33            let normalized = normalize_identifier(&text);
34            return Some((ColumnName::Column(normalized), name_node.syntax().clone()));
35        }
36        Self::inferred_from_target(target)
37    }
38
39    // Ignore any aliases, just infer the what the column name.
40    pub(crate) fn inferred_from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
41        if let Some(expr) = target.expr()
42            && let Some(name) = name_from_expr(expr, false)
43        {
44            return Some(name);
45        } else if target.star_token().is_some() {
46            return Some((ColumnName::Star, target.syntax().clone()));
47        }
48        None
49    }
50
51    fn new(name: String, unknown_column: bool) -> ColumnName {
52        if unknown_column {
53            ColumnName::UnknownColumn(Some(name))
54        } else {
55            ColumnName::Column(name)
56        }
57    }
58
59    pub(crate) fn to_string(&self) -> Option<String> {
60        match self {
61            ColumnName::Column(string) => Some(string.to_string()),
62            ColumnName::Star => None,
63            ColumnName::UnknownColumn(c) => {
64                Some(c.clone().unwrap_or_else(|| "?column?".to_string()))
65            }
66        }
67    }
68}
69
70fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> {
71    match ty {
72        ast::Type::PathType(path_type) => {
73            if let Some(name_ref) = path_type
74                .path()
75                .and_then(|x| x.segment())
76                .and_then(|x| x.name_ref())
77            {
78                return name_from_name_ref(name_ref, true).map(|(column, node)| {
79                    let column = match column {
80                        ColumnName::Column(c) => ColumnName::new(c, unknown_column),
81                        _ => column,
82                    };
83                    (column, node)
84                });
85            }
86        }
87        ast::Type::BitType(bit_type) => {
88            let name = if bit_type.varying_token().is_some() {
89                "varbit"
90            } else {
91                "bit"
92            };
93            return Some((
94                ColumnName::new(name.to_string(), unknown_column),
95                bit_type.syntax().clone(),
96            ));
97        }
98        ast::Type::CharType(char_type) => {
99            let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
100            {
101                "varchar"
102            } else {
103                "bpchar"
104            };
105            return Some((
106                ColumnName::new(name.to_string(), unknown_column),
107                char_type.syntax().clone(),
108            ));
109        }
110        ast::Type::DoubleType(double_type) => {
111            return Some((
112                ColumnName::new("float8".to_string(), unknown_column),
113                double_type.syntax().clone(),
114            ));
115        }
116        ast::Type::IntervalType(interval_type) => {
117            return Some((
118                ColumnName::new("interval".to_string(), unknown_column),
119                interval_type.syntax().clone(),
120            ));
121        }
122        ast::Type::TimeType(time_type) => {
123            let mut name = if time_type.timestamp_token().is_some() {
124                "timestamp".to_owned()
125            } else {
126                "time".to_owned()
127            };
128            if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
129                // time -> timetz
130                // timestamp -> timestamptz
131                name.push_str("tz");
132            };
133            return Some((
134                ColumnName::new(name.to_string(), unknown_column),
135                time_type.syntax().clone(),
136            ));
137        }
138        ast::Type::ArrayType(array_type) => {
139            if let Some(inner_ty) = array_type.ty() {
140                return name_from_type(inner_ty, unknown_column);
141            }
142        }
143        // we shouldn't ever hit this since the following isn't valid syntax:
144        // select cast('foo' as t.a%TYPE);
145        ast::Type::PercentType(_) => return None,
146        ast::Type::ExprType(expr_type) => {
147            if let Some(expr) = expr_type.expr() {
148                return name_from_expr(expr, true).map(|(column, node)| {
149                    let column = match column {
150                        ColumnName::Column(c) => ColumnName::new(c, unknown_column),
151                        _ => column,
152                    };
153                    (column, node)
154                });
155            }
156        }
157    }
158    None
159}
160
161fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
162    if in_type {
163        for node in name_ref.syntax().children_with_tokens() {
164            match node.kind() {
165                SyntaxKind::BIGINT_KW => {
166                    return Some((
167                        ColumnName::Column("int8".to_owned()),
168                        name_ref.syntax().clone(),
169                    ));
170                }
171                SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => {
172                    return Some((
173                        ColumnName::Column("int4".to_owned()),
174                        name_ref.syntax().clone(),
175                    ));
176                }
177                SyntaxKind::SMALLINT_KW => {
178                    return Some((
179                        ColumnName::Column("int2".to_owned()),
180                        name_ref.syntax().clone(),
181                    ));
182                }
183                _ => (),
184            }
185        }
186    }
187    let text = name_ref.text();
188    let normalized = normalize_identifier(&text);
189    return Some((ColumnName::Column(normalized), name_ref.syntax().clone()));
190}
191
/*
TODO: infer names for multi-argument `unnest`, which returns one column per
array argument, each named `unnest`:

unnest(anyarray, anyarray [, ... ]) → setof anyelement, anyelement [, ... ]

select * from unnest(ARRAY[1,2], ARRAY['foo','bar','baz']) →
 unnest | unnest
--------+--------
      1 | foo
      2 | bar
        | baz
*/
204
205// NOTE: we have to have this in_type param because we parse some casts as exprs
206// instead of types.
207fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
208    let node = expr.syntax().clone();
209    match expr {
210        ast::Expr::ArrayExpr(_) => {
211            return Some((ColumnName::Column("array".to_string()), node));
212        }
213        ast::Expr::BetweenExpr(_) | ast::Expr::BinExpr(_) => {
214            return Some((ColumnName::UnknownColumn(None), node));
215        }
216        ast::Expr::CallExpr(call_expr) => {
217            if let Some(exists_fn) = call_expr.exists_fn() {
218                return Some((
219                    ColumnName::Column("exists".to_string()),
220                    exists_fn.syntax().clone(),
221                ));
222            }
223            if let Some(extract_fn) = call_expr.extract_fn() {
224                return Some((
225                    ColumnName::Column("extract".to_string()),
226                    extract_fn.syntax().clone(),
227                ));
228            }
229            if let Some(json_exists_fn) = call_expr.json_exists_fn() {
230                return Some((
231                    ColumnName::Column("json_exists".to_string()),
232                    json_exists_fn.syntax().clone(),
233                ));
234            }
235            if let Some(json_array_fn) = call_expr.json_array_fn() {
236                return Some((
237                    ColumnName::Column("json_array".to_string()),
238                    json_array_fn.syntax().clone(),
239                ));
240            }
241            if let Some(json_object_fn) = call_expr.json_object_fn() {
242                return Some((
243                    ColumnName::Column("json_object".to_string()),
244                    json_object_fn.syntax().clone(),
245                ));
246            }
247            if let Some(json_object_agg_fn) = call_expr.json_object_agg_fn() {
248                return Some((
249                    ColumnName::Column("json_objectagg".to_string()),
250                    json_object_agg_fn.syntax().clone(),
251                ));
252            }
253            if let Some(json_array_agg_fn) = call_expr.json_array_agg_fn() {
254                return Some((
255                    ColumnName::Column("json_arrayagg".to_string()),
256                    json_array_agg_fn.syntax().clone(),
257                ));
258            }
259            if let Some(json_query_fn) = call_expr.json_query_fn() {
260                return Some((
261                    ColumnName::Column("json_query".to_string()),
262                    json_query_fn.syntax().clone(),
263                ));
264            }
265            if let Some(json_scalar_fn) = call_expr.json_scalar_fn() {
266                return Some((
267                    ColumnName::Column("json_scalar".to_string()),
268                    json_scalar_fn.syntax().clone(),
269                ));
270            }
271            if let Some(json_serialize_fn) = call_expr.json_serialize_fn() {
272                return Some((
273                    ColumnName::Column("json_serialize".to_string()),
274                    json_serialize_fn.syntax().clone(),
275                ));
276            }
277            if let Some(json_value_fn) = call_expr.json_value_fn() {
278                return Some((
279                    ColumnName::Column("json_value".to_string()),
280                    json_value_fn.syntax().clone(),
281                ));
282            }
283            if let Some(json_fn) = call_expr.json_fn() {
284                return Some((
285                    ColumnName::Column("json".to_string()),
286                    json_fn.syntax().clone(),
287                ));
288            }
289            if let Some(substring_fn) = call_expr.substring_fn() {
290                return Some((
291                    ColumnName::Column("substring".to_string()),
292                    substring_fn.syntax().clone(),
293                ));
294            }
295            if let Some(position_fn) = call_expr.position_fn() {
296                return Some((
297                    ColumnName::Column("position".to_string()),
298                    position_fn.syntax().clone(),
299                ));
300            }
301            if let Some(overlay_fn) = call_expr.overlay_fn() {
302                return Some((
303                    ColumnName::Column("overlay".to_string()),
304                    overlay_fn.syntax().clone(),
305                ));
306            }
307            if let Some(trim_fn) = call_expr.trim_fn() {
308                return Some((
309                    ColumnName::Column("trim".to_string()),
310                    trim_fn.syntax().clone(),
311                ));
312            }
313            if let Some(xml_root_fn) = call_expr.xml_root_fn() {
314                return Some((
315                    ColumnName::Column("xml_root".to_string()),
316                    xml_root_fn.syntax().clone(),
317                ));
318            }
319            if let Some(xml_serialize_fn) = call_expr.xml_serialize_fn() {
320                return Some((
321                    ColumnName::Column("xml_serialize".to_string()),
322                    xml_serialize_fn.syntax().clone(),
323                ));
324            }
325            if let Some(xml_element_fn) = call_expr.xml_element_fn() {
326                return Some((
327                    ColumnName::Column("xml_element".to_string()),
328                    xml_element_fn.syntax().clone(),
329                ));
330            }
331            if let Some(xml_forest_fn) = call_expr.xml_forest_fn() {
332                return Some((
333                    ColumnName::Column("xml_forest".to_string()),
334                    xml_forest_fn.syntax().clone(),
335                ));
336            }
337            if let Some(xml_exists_fn) = call_expr.xml_exists_fn() {
338                return Some((
339                    ColumnName::Column("xml_exists".to_string()),
340                    xml_exists_fn.syntax().clone(),
341                ));
342            }
343            if let Some(xml_parse_fn) = call_expr.xml_parse_fn() {
344                return Some((
345                    ColumnName::Column("xml_parse".to_string()),
346                    xml_parse_fn.syntax().clone(),
347                ));
348            }
349            if let Some(xml_pi_fn) = call_expr.xml_pi_fn() {
350                return Some((
351                    ColumnName::Column("xml_pi".to_string()),
352                    xml_pi_fn.syntax().clone(),
353                ));
354            }
355            if let Some(func_name) = call_expr.expr() {
356                match func_name {
357                    ast::Expr::ArrayExpr(_)
358                    | ast::Expr::BetweenExpr(_)
359                    | ast::Expr::ParenExpr(_)
360                    | ast::Expr::BinExpr(_)
361                    | ast::Expr::CallExpr(_)
362                    | ast::Expr::CaseExpr(_)
363                    | ast::Expr::CastExpr(_)
364                    | ast::Expr::Literal(_)
365                    | ast::Expr::PostfixExpr(_)
366                    | ast::Expr::PrefixExpr(_)
367                    | ast::Expr::TupleExpr(_)
368                    | ast::Expr::IndexExpr(_)
369                    | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"),
370                    ast::Expr::FieldExpr(field_expr) => {
371                        if let Some(name_ref) = field_expr.field() {
372                            return name_from_name_ref(name_ref, in_type);
373                        }
374                    }
375                    ast::Expr::NameRef(name_ref) => {
376                        return name_from_name_ref(name_ref, in_type);
377                    }
378                }
379            }
380        }
381        ast::Expr::CaseExpr(case) => {
382            if let Some(else_clause) = case.else_clause()
383                && let Some(expr) = else_clause.expr()
384                && let Some((column, node)) = name_from_expr(expr, in_type)
385            {
386                if !matches!(column, ColumnName::UnknownColumn(_)) {
387                    return Some((column, node));
388                }
389            }
390            return Some((ColumnName::Column("case".to_string()), node));
391        }
392        ast::Expr::CastExpr(cast_expr) => {
393            let mut unknown_column = false;
394            if let Some(expr) = cast_expr.expr()
395                && let Some((column, node)) = name_from_expr(expr, in_type)
396            {
397                match column {
398                    ColumnName::Column(_) => return Some((column, node)),
399                    ColumnName::UnknownColumn(_) => unknown_column = true,
400                    ColumnName::Star => (),
401                }
402            }
403            if let Some(ty) = cast_expr.ty() {
404                return name_from_type(ty, unknown_column);
405            }
406        }
407        ast::Expr::FieldExpr(field_expr) => {
408            if let Some(name_ref) = field_expr.field() {
409                return name_from_name_ref(name_ref, in_type);
410            }
411        }
412        ast::Expr::IndexExpr(index_expr) => {
413            if let Some(base) = index_expr.base() {
414                return name_from_expr(base, in_type);
415            }
416        }
417        ast::Expr::SliceExpr(slice_expr) => {
418            if let Some(base) = slice_expr.base() {
419                return name_from_expr(base, in_type);
420            }
421        }
422        ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) | ast::Expr::PostfixExpr(_) => {
423            return Some((ColumnName::UnknownColumn(None), node));
424        }
425        ast::Expr::NameRef(name_ref) => {
426            return name_from_name_ref(name_ref, in_type);
427        }
428        ast::Expr::ParenExpr(paren_expr) => {
429            if let Some(expr) = paren_expr.expr() {
430                return name_from_expr(expr, in_type);
431            } else if let Some(select) = paren_expr.select()
432                && let Some(mut targets) = select
433                    .select_clause()
434                    .and_then(|x| x.target_list())
435                    .map(|x| x.targets())
436                && let Some(target) = targets.next()
437            {
438                return ColumnName::from_target(target);
439            }
440        }
441        ast::Expr::TupleExpr(_) => {
442            return Some((ColumnName::Column("row".to_string()), node));
443        }
444    }
445    None
446}
447
448#[test]
449fn examples() {
450    use insta::assert_snapshot;
451
452    // array
453    assert_snapshot!(name("array(select 1)"), @"array");
454    assert_snapshot!(name("array[1, 2, 3]"), @"array");
455
456    // unknown columns
457    assert_snapshot!(name("1 between 0 and 10"), @"?column?");
458    assert_snapshot!(name("1 + 2"), @"?column?");
459    assert_snapshot!(name("42"), @"?column?");
460    assert_snapshot!(name("'string'"), @"?column?");
461    // prefix
462    assert_snapshot!(name("-42"), @"?column?");
463    assert_snapshot!(name("|/ 42"), @"?column?");
464    // postfix
465    assert_snapshot!(name("x is null"), @"?column?");
466    assert_snapshot!(name("x is not null"), @"?column?");
467    // paren expr
468    assert_snapshot!(name("(1 * 2)"), @"?column?");
469    assert_snapshot!(name("(select 1 as a)"), @"a");
470
471    // func
472    assert_snapshot!(name("count(*)"), @"count");
473    assert_snapshot!(name("schema.func_name(1)"), @"func_name");
474
475    // special funcs
476    assert_snapshot!(name("extract(year from now())"), @"extract");
477    assert_snapshot!(name("exists(select 1)"), @"exists");
478    assert_snapshot!(name(r#"json_exists('{"a":1}', '$.a')"#), @"json_exists");
479    assert_snapshot!(name("json_array(1, 2)"), @"json_array");
480    assert_snapshot!(name("json_object('a': 1)"), @"json_object");
481    assert_snapshot!(name("json_objectagg('a': 1)"), @"json_objectagg");
482    assert_snapshot!(name("json_arrayagg(1)"), @"json_arrayagg");
483    assert_snapshot!(name(r#"json_query('{"a":1}', '$.a')"#), @"json_query");
484    assert_snapshot!(name("json_scalar(1)"), @"json_scalar");
485    assert_snapshot!(name(r#"json_serialize('{"a":1}')"#), @"json_serialize");
486    assert_snapshot!(name(r#"json_value('{"a":1}', '$.a')"#), @"json_value");
487    assert_snapshot!(name(r#"json('{"a":1}')"#), @"json");
488    assert_snapshot!(name("substring('hello' from 2 for 3)"), @"substring");
489    assert_snapshot!(name("position('a' in 'abc')"), @"position");
490    assert_snapshot!(name("overlay('hello' placing 'X' from 2)"), @"overlay");
491    assert_snapshot!(name("trim('  hi  ')"), @"trim");
492    assert_snapshot!(name("xmlroot('<a/>', version '1.0')"), @"xml_root");
493    assert_snapshot!(name("xmlserialize(document '<a/>' as text)"), @"xml_serialize");
494    assert_snapshot!(name("xmlelement(name foo, 'bar')"), @"xml_element");
495    assert_snapshot!(name("xmlforest('bar' as foo)"), @"xml_forest");
496    assert_snapshot!(name("xmlexists('//a' passing '<a/>')"), @"xml_exists");
497    assert_snapshot!(name("xmlparse(document '<a/>')"), @"xml_parse");
498    assert_snapshot!(name("xmlpi(name foo, 'bar')"), @"xml_pi");
499
500    // index
501    assert_snapshot!(name("foo[bar]"), @"foo");
502    assert_snapshot!(name("foo[1]"), @"foo");
503
504    // column
505    assert_snapshot!(name("database.schema.table.column"), @"column");
506    assert_snapshot!(name("t.a"), @"a");
507    assert_snapshot!(name("col_name"), @"col_name");
508    assert_snapshot!(name("(c)"), @"c");
509
510    // case
511    assert_snapshot!(name("case when true then 'foo' end"), @"case");
512    assert_snapshot!(name("case when true then 'foo' else now()::text end"), @"now");
513    assert_snapshot!(name("case when true then 'foo' else 'bar' end"), @"case");
514    assert_snapshot!(name("case when true then 'foo' else '1'::bigint::text end"), @"case");
515
516    // casts
517    assert_snapshot!(name("now()::text"), @"now");
518    assert_snapshot!(name("cast(col_name as text)"), @"col_name");
519    assert_snapshot!(name("col_name::text"), @"col_name");
520    assert_snapshot!(name("col_name::int::text"), @"col_name");
521    assert_snapshot!(name("'1'::bigint"), @"int8");
522    assert_snapshot!(name("'1'::int"), @"int4");
523    assert_snapshot!(name("'1'::smallint"), @"int2");
524    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8");
525    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[][]"), @"int4");
526    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2");
527    assert_snapshot!(name("pg_catalog.varchar(100) '{1}'"), @"varchar");
528    assert_snapshot!(name("'{1}'::integer[];"), @"int4");
529    assert_snapshot!(name("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4");
530    assert_snapshot!(name("'1'::bigint::smallint"), @"int2");
531
532    // alias
533    // with quoting
534    assert_snapshot!(name(r#"'foo' as "FOO""#), @"FOO");
535    assert_snapshot!(name(r#"'foo' as "foo""#), @"foo");
536    // without quoting
537    assert_snapshot!(name(r#"'foo' as FOO"#), @"foo");
538    assert_snapshot!(name(r#"'foo' as foo"#), @"foo");
539
540    // tuple
541    assert_snapshot!(name("(1, 2, 3)"), @"row");
542    assert_snapshot!(name("(1, 2, 3)::address"), @"row");
543
544    // composite type
545    assert_snapshot!(name("(x).city"), @"city");
546
547    // array types
548    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[]"), @"int4");
549    assert_snapshot!(name("cast('{foo}' as text[])"), @"text");
550
551    // bit types
552    assert_snapshot!(name("cast('1010' as bit varying(10))"), @"varbit");
553
554    // char types
555    assert_snapshot!(name("cast('hello' as character varying(10))"), @"varchar");
556    assert_snapshot!(name("cast('hello' as char varying(5))"), @"varchar");
557    assert_snapshot!(name("cast('hello' as char(5))"), @"bpchar");
558
559    // double types
560    assert_snapshot!(name("cast(1.5 as double precision)"), @"float8");
561
562    // interval types
563    assert_snapshot!(name("cast('1 hour' as interval hour to minute)"), @"interval");
564
565    // percent types
566    assert_snapshot!(name("cast(foo as schema.%TYPE)"), @"foo");
567
568    // time types
569    assert_snapshot!(name("cast('12:00:00' as time(6) without time zone)"), @"time");
570    assert_snapshot!(name("cast('12:00:00' as time(6) with time zone)"), @"timetz");
571    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz");
572    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp");
573
574    #[track_caller]
575    fn name(sql: &str) -> String {
576        let sql = "select ".to_string() + sql;
577        let parse = squawk_syntax::SourceFile::parse(&sql);
578        assert_eq!(parse.errors(), vec![]);
579        let file = parse.tree();
580
581        let stmt = file.stmts().next().unwrap();
582        let ast::Stmt::Select(select) = stmt else {
583            unreachable!()
584        };
585
586        let target = select
587            .select_clause()
588            .and_then(|sc| sc.target_list())
589            .and_then(|tl| tl.targets().next())
590            .unwrap();
591
592        ColumnName::from_target(target)
593            .and_then(|x| x.0.to_string())
594            .unwrap()
595    }
596}