squawk_ide/
column_name.rs

1use squawk_syntax::{
2    SyntaxKind, SyntaxNode,
3    ast::{self, AstNode},
4};
5
/// Normalizes an identifier so that two spellings of the same name compare
/// equal.
///
/// * Text wrapped in a pair of double quotes has the surrounding quotes
///   removed and its case preserved.
/// * Any other text is lowercased.
///
/// A lone `"` has no closing quote, so it is treated as unquoted text rather
/// than being sliced (which would panic on an inverted range).
///
/// NOTE: doubled quotes (`""`) inside a quoted identifier are returned as
/// written; they are not collapsed.
fn normalize_identifier(text: &str) -> String {
    // `strip_suffix` runs on what is left after the opening quote, so a
    // single `"` cannot count as both the opening and the closing quote.
    let quoted = text
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));
    match quoted {
        Some(inner) => inner.to_string(),
        None => text.to_lowercase(),
    }
}
13
/// The name inferred for a column of a select list.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ColumnName {
    /// A definite name, e.g. an alias, a referenced column, a function name
    /// or a fixed label such as `case`, `array` or `row`.
    Column(String),
    /// There's a fallback mechanism that we need to propagate through the
    /// expressions/types.
    ///
    /// `None` means no name could be derived at all; `Some(name)` carries a
    /// tentative name (e.g. the type of a cast applied to a literal) that an
    /// enclosing expression may decide to ignore.
    ///
    /// We can see this with:
    /// ```sql
    /// select case when true then 'a' else now()::text end;
    /// -- column named `now`, propagating the function name
    /// -- vs
    /// select case when true then 'a' else 'b' end;
    /// -- column named `case`
    /// ```
    UnknownColumn(Option<String>),
    /// The target is a `*`.
    Star,
}
31
32impl ColumnName {
33    #[allow(dead_code)]
34    pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> {
35        if let Some(as_name) = target.as_name()
36            && let Some(name_node) = as_name.name()
37        {
38            let text = name_node.text();
39            let normalized = normalize_identifier(&text);
40            return Some((ColumnName::Column(normalized), name_node.syntax().clone()));
41        } else if let Some(expr) = target.expr()
42            && let Some(name) = name_from_expr(expr, false)
43        {
44            return Some(name);
45        } else if target.star_token().is_some() {
46            return Some((ColumnName::Star, target.syntax().clone()));
47        }
48        None
49    }
50
51    fn new(name: String, unknown_column: bool) -> ColumnName {
52        if unknown_column {
53            ColumnName::UnknownColumn(Some(name))
54        } else {
55            ColumnName::Column(name)
56        }
57    }
58}
59
60fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> {
61    match ty {
62        ast::Type::PathType(path_type) => {
63            if let Some(name_ref) = path_type
64                .path()
65                .and_then(|x| x.segment())
66                .and_then(|x| x.name_ref())
67            {
68                return name_from_name_ref(name_ref, true).map(|(column, node)| {
69                    let column = match column {
70                        ColumnName::Column(c) => ColumnName::new(c, unknown_column),
71                        _ => column,
72                    };
73                    (column, node)
74                });
75            }
76        }
77        ast::Type::BitType(bit_type) => {
78            let name = if bit_type.varying_token().is_some() {
79                "varbit"
80            } else {
81                "bit"
82            };
83            return Some((
84                ColumnName::new(name.to_string(), unknown_column),
85                bit_type.syntax().clone(),
86            ));
87        }
88        ast::Type::CharType(char_type) => {
89            let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
90            {
91                "varchar"
92            } else {
93                "bpchar"
94            };
95            return Some((
96                ColumnName::new(name.to_string(), unknown_column),
97                char_type.syntax().clone(),
98            ));
99        }
100        ast::Type::DoubleType(double_type) => {
101            return Some((
102                ColumnName::new("float8".to_string(), unknown_column),
103                double_type.syntax().clone(),
104            ));
105        }
106        ast::Type::IntervalType(interval_type) => {
107            return Some((
108                ColumnName::new("interval".to_string(), unknown_column),
109                interval_type.syntax().clone(),
110            ));
111        }
112        ast::Type::TimeType(time_type) => {
113            let mut name = if time_type.timestamp_token().is_some() {
114                "timestamp".to_owned()
115            } else {
116                "time".to_owned()
117            };
118            if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
119                // time -> timetz
120                // timestamp -> timestamptz
121                name.push_str("tz");
122            };
123            return Some((
124                ColumnName::new(name.to_string(), unknown_column),
125                time_type.syntax().clone(),
126            ));
127        }
128        ast::Type::ArrayType(array_type) => {
129            if let Some(inner_ty) = array_type.ty() {
130                return name_from_type(inner_ty, unknown_column);
131            }
132        }
133        // we shouldn't ever hit this since the following isn't valid syntax:
134        // select cast('foo' as t.a%TYPE);
135        ast::Type::PercentType(_) => return None,
136        ast::Type::ExprType(expr_type) => {
137            if let Some(expr) = expr_type.expr() {
138                return name_from_expr(expr, true).map(|(column, node)| {
139                    let column = match column {
140                        ColumnName::Column(c) => ColumnName::new(c, unknown_column),
141                        _ => column,
142                    };
143                    (column, node)
144                });
145            }
146        }
147    }
148    None
149}
150
151fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
152    if in_type {
153        for node in name_ref.syntax().children_with_tokens() {
154            match node.kind() {
155                SyntaxKind::BIGINT_KW => {
156                    return Some((
157                        ColumnName::Column("int8".to_owned()),
158                        name_ref.syntax().clone(),
159                    ));
160                }
161                SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => {
162                    return Some((
163                        ColumnName::Column("int4".to_owned()),
164                        name_ref.syntax().clone(),
165                    ));
166                }
167                SyntaxKind::SMALLINT_KW => {
168                    return Some((
169                        ColumnName::Column("int2".to_owned()),
170                        name_ref.syntax().clone(),
171                    ));
172                }
173                _ => (),
174            }
175        }
176    }
177    return Some((
178        ColumnName::Column(name_ref.text().to_string()),
179        name_ref.syntax().clone(),
180    ));
181}
182
/*
TODO:

unnest(anyarray, anyarray [, ... ]) → setof anyelement, anyelement [, ... ]

select * from unnest(ARRAY[1,2], ARRAY['foo','bar','baz']) →
 unnest | unnest
--------+--------
      1 | foo
      2 | bar
        | baz
*/
195
196// NOTE: we have to have this in_type param because we parse some casts as exprs
197// instead of types.
198fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> {
199    let node = expr.syntax().clone();
200    match expr {
201        ast::Expr::ArrayExpr(_) => {
202            return Some((ColumnName::Column("array".to_string()), node));
203        }
204        ast::Expr::BetweenExpr(_) | ast::Expr::BinExpr(_) => {
205            return Some((ColumnName::UnknownColumn(None), node));
206        }
207        ast::Expr::CallExpr(call_expr) => {
208            if let Some(func_name) = call_expr.expr() {
209                match func_name {
210                    ast::Expr::ArrayExpr(_)
211                    | ast::Expr::BetweenExpr(_)
212                    | ast::Expr::ParenExpr(_)
213                    | ast::Expr::BinExpr(_)
214                    | ast::Expr::CallExpr(_)
215                    | ast::Expr::CaseExpr(_)
216                    | ast::Expr::CastExpr(_)
217                    | ast::Expr::Literal(_)
218                    | ast::Expr::PostfixExpr(_)
219                    | ast::Expr::PrefixExpr(_)
220                    | ast::Expr::TupleExpr(_)
221                    | ast::Expr::IndexExpr(_)
222                    | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"),
223                    ast::Expr::FieldExpr(field_expr) => {
224                        if let Some(name_ref) = field_expr.field() {
225                            return name_from_name_ref(name_ref, in_type);
226                        }
227                    }
228                    ast::Expr::NameRef(name_ref) => {
229                        return name_from_name_ref(name_ref, in_type);
230                    }
231                }
232            }
233        }
234        ast::Expr::CaseExpr(case) => {
235            if let Some(else_clause) = case.else_clause()
236                && let Some(expr) = else_clause.expr()
237                && let Some((column, node)) = name_from_expr(expr, in_type)
238            {
239                if !matches!(column, ColumnName::UnknownColumn(_)) {
240                    return Some((column, node));
241                }
242            }
243            return Some((ColumnName::Column("case".to_string()), node));
244        }
245        ast::Expr::CastExpr(cast_expr) => {
246            let mut unknown_column = false;
247            if let Some(expr) = cast_expr.expr()
248                && let Some((column, node)) = name_from_expr(expr, in_type)
249            {
250                match column {
251                    ColumnName::Column(_) => return Some((column, node)),
252                    ColumnName::UnknownColumn(_) => unknown_column = true,
253                    ColumnName::Star => (),
254                }
255            }
256            if let Some(ty) = cast_expr.ty() {
257                return name_from_type(ty, unknown_column);
258            }
259        }
260        ast::Expr::FieldExpr(field_expr) => {
261            if let Some(name_ref) = field_expr.field() {
262                return name_from_name_ref(name_ref, in_type);
263            }
264        }
265        ast::Expr::IndexExpr(index_expr) => {
266            if let Some(base) = index_expr.base() {
267                return name_from_expr(base, in_type);
268            }
269        }
270        ast::Expr::SliceExpr(slice_expr) => {
271            if let Some(base) = slice_expr.base() {
272                return name_from_expr(base, in_type);
273            }
274        }
275        ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) | ast::Expr::PostfixExpr(_) => {
276            return Some((ColumnName::UnknownColumn(None), node));
277        }
278        ast::Expr::NameRef(name_ref) => {
279            return name_from_name_ref(name_ref, in_type);
280        }
281        ast::Expr::ParenExpr(paren_expr) => {
282            if let Some(expr) = paren_expr.expr() {
283                return name_from_expr(expr, in_type);
284            } else if let Some(select) = paren_expr.select()
285                && let Some(mut targets) = select
286                    .select_clause()
287                    .and_then(|x| x.target_list())
288                    .map(|x| x.targets())
289                && let Some(target) = targets.next()
290            {
291                return ColumnName::from_target(target);
292            }
293        }
294        ast::Expr::TupleExpr(_) => {
295            return Some((ColumnName::Column("row".to_string()), node));
296        }
297    }
298    None
299}
300
/// Snapshot checks of the inferred column name for a range of select targets.
#[test]
fn examples() {
    use insta::assert_snapshot;

    /// Parses `select <sql>` and renders the name inferred for its first
    /// target; a fully unknown column is rendered as `?column?`.
    #[track_caller]
    fn name(sql: &str) -> String {
        let query = format!("select {sql}");
        let parsed = squawk_syntax::SourceFile::parse(&query);
        assert_eq!(parsed.errors(), vec![]);
        let tree = parsed.tree();

        let first_stmt = tree.stmts().next().unwrap();
        let ast::Stmt::Select(select) = first_stmt else {
            unreachable!()
        };

        let first_target = select
            .select_clause()
            .and_then(|clause| clause.target_list())
            .and_then(|list| list.targets().next())
            .unwrap();

        let (column, _) = ColumnName::from_target(first_target).unwrap();
        match column {
            ColumnName::Column(text) => text,
            ColumnName::UnknownColumn(fallback) => {
                fallback.unwrap_or_else(|| "?column?".to_string())
            }
            ColumnName::Star => unreachable!(),
        }
    }

    // array
    assert_snapshot!(name("array(select 1)"), @"array");
    assert_snapshot!(name("array[1, 2, 3]"), @"array");

    // unknown columns
    assert_snapshot!(name("1 between 0 and 10"), @"?column?");
    assert_snapshot!(name("1 + 2"), @"?column?");
    assert_snapshot!(name("42"), @"?column?");
    assert_snapshot!(name("'string'"), @"?column?");
    // prefix
    assert_snapshot!(name("-42"), @"?column?");
    assert_snapshot!(name("|/ 42"), @"?column?");
    // postfix
    assert_snapshot!(name("x is null"), @"?column?");
    assert_snapshot!(name("x is not null"), @"?column?");
    // paren expr
    assert_snapshot!(name("(1 * 2)"), @"?column?");
    assert_snapshot!(name("(select 1 as a)"), @"a");

    // func
    assert_snapshot!(name("count(*)"), @"count");
    assert_snapshot!(name("schema.func_name(1)"), @"func_name");

    // index
    assert_snapshot!(name("foo[bar]"), @"foo");
    assert_snapshot!(name("foo[1]"), @"foo");

    // column
    assert_snapshot!(name("database.schema.table.column"), @"column");
    assert_snapshot!(name("t.a"), @"a");
    assert_snapshot!(name("col_name"), @"col_name");
    assert_snapshot!(name("(c)"), @"c");

    // case
    assert_snapshot!(name("case when true then 'foo' end"), @"case");
    assert_snapshot!(name("case when true then 'foo' else now()::text end"), @"now");
    assert_snapshot!(name("case when true then 'foo' else 'bar' end"), @"case");
    assert_snapshot!(name("case when true then 'foo' else '1'::bigint::text end"), @"case");

    // casts
    assert_snapshot!(name("now()::text"), @"now");
    assert_snapshot!(name("cast(col_name as text)"), @"col_name");
    assert_snapshot!(name("col_name::text"), @"col_name");
    assert_snapshot!(name("col_name::int::text"), @"col_name");
    assert_snapshot!(name("'1'::bigint"), @"int8");
    assert_snapshot!(name("'1'::int"), @"int4");
    assert_snapshot!(name("'1'::smallint"), @"int2");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[][]"), @"int4");
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2");
    assert_snapshot!(name("pg_catalog.varchar(100) '{1}'"), @"varchar");
    assert_snapshot!(name("'{1}'::integer[];"), @"int4");
    assert_snapshot!(name("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4");
    assert_snapshot!(name("'1'::bigint::smallint"), @"int2");

    // alias
    // with quoting
    assert_snapshot!(name(r#"'foo' as "FOO""#), @"FOO");
    assert_snapshot!(name(r#"'foo' as "foo""#), @"foo");
    // without quoting
    assert_snapshot!(name(r#"'foo' as FOO"#), @"foo");
    assert_snapshot!(name(r#"'foo' as foo"#), @"foo");

    // tuple
    assert_snapshot!(name("(1, 2, 3)"), @"row");
    assert_snapshot!(name("(1, 2, 3)::address"), @"row");

    // composite type
    assert_snapshot!(name("(x).city"), @"city");

    // array types
    assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[]"), @"int4");
    assert_snapshot!(name("cast('{foo}' as text[])"), @"text");

    // bit types
    assert_snapshot!(name("cast('1010' as bit varying(10))"), @"varbit");

    // char types
    assert_snapshot!(name("cast('hello' as character varying(10))"), @"varchar");
    assert_snapshot!(name("cast('hello' as char varying(5))"), @"varchar");
    assert_snapshot!(name("cast('hello' as char(5))"), @"bpchar");

    // double types
    assert_snapshot!(name("cast(1.5 as double precision)"), @"float8");

    // interval types
    assert_snapshot!(name("cast('1 hour' as interval hour to minute)"), @"interval");

    // percent types
    assert_snapshot!(name("cast(foo as schema.%TYPE)"), @"foo");

    // time types
    assert_snapshot!(name("cast('12:00:00' as time(6) without time zone)"), @"time");
    assert_snapshot!(name("cast('12:00:00' as time(6) with time zone)"), @"timetz");
    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz");
    assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp");
}