Skip to main content

sqlglot_rust/optimizer/
qualify_columns.rs

1//! Qualify columns optimizer pass.
2//!
3//! Resolves column references by:
4//! - Expanding `SELECT *` to explicit column lists from schema
5//! - Expanding `SELECT t.*` to columns from table `t`
6//! - Adding table qualifiers to unqualified columns (e.g., `col` → `t.col`)
7//! - Validating column existence against the schema
8//! - Resolving columns through CTEs and derived tables
9
10use std::collections::HashMap;
11
12use crate::ast::*;
13use crate::dialects::Dialect;
14use crate::schema::{Schema, normalize_identifier};
15
16/// Qualify columns in a statement using the provided schema.
17///
18/// This adds table qualifiers to unqualified column references and expands
19/// wildcard selects (`*`, `t.*`) into explicit column lists.
20pub fn qualify_columns<S: Schema>(statement: Statement, schema: &S) -> Statement {
21    let dialect = schema.dialect();
22    match statement {
23        Statement::Select(sel) => {
24            let qualified = qualify_select(sel, schema, dialect, &HashMap::new());
25            Statement::Select(qualified)
26        }
27        Statement::SetOperation(mut set_op) => {
28            set_op.left = Box::new(qualify_columns(*set_op.left, schema));
29            set_op.right = Box::new(qualify_columns(*set_op.right, schema));
30            Statement::SetOperation(set_op)
31        }
32        other => other,
33    }
34}
35
/// Metadata about the columns exposed by one source in a SELECT scope.
///
/// One value is recorded per FROM/JOIN source name or alias: a schema table,
/// a CTE reference, or an aliased derived table, `UNNEST`, or table function.
#[derive(Debug, Clone)]
struct SourceColumns {
    /// Column names in definition order. Left empty for aliased `UNNEST` and
    /// table-function sources, whose output columns are not tracked here.
    columns: Vec<String>,
}
42
43/// Build a mapping of source name/alias → available columns for a SELECT scope.
44fn resolve_source_columns<S: Schema>(
45    sel: &SelectStatement,
46    schema: &S,
47    dialect: Dialect,
48    cte_columns: &HashMap<String, Vec<String>>,
49) -> HashMap<String, SourceColumns> {
50    let mut source_map: HashMap<String, SourceColumns> = HashMap::new();
51
52    // Process FROM source
53    if let Some(from) = &sel.from {
54        collect_source_columns(&from.source, schema, dialect, cte_columns, &mut source_map);
55    }
56
57    // Process JOINs
58    for join in &sel.joins {
59        collect_source_columns(&join.table, schema, dialect, cte_columns, &mut source_map);
60    }
61
62    source_map
63}
64
65/// Collect columns from a single table source.
66fn collect_source_columns<S: Schema>(
67    source: &TableSource,
68    schema: &S,
69    dialect: Dialect,
70    cte_columns: &HashMap<String, Vec<String>>,
71    source_map: &mut HashMap<String, SourceColumns>,
72) {
73    match source {
74        TableSource::Table(table_ref) => {
75            let key = table_ref
76                .alias
77                .as_deref()
78                .unwrap_or(&table_ref.name)
79                .to_string();
80            let norm_key = normalize_identifier(&key, dialect);
81
82            // Check if this is a CTE reference
83            let norm_name = normalize_identifier(&table_ref.name, dialect);
84            if let Some(cols) = cte_columns.get(&norm_name) {
85                source_map.insert(
86                    norm_key,
87                    SourceColumns {
88                        columns: cols.clone(),
89                    },
90                );
91                return;
92            }
93
94            // Build the table path for schema lookup
95            let path = build_table_path(table_ref, dialect);
96            let path_refs: Vec<&str> = path.iter().map(|s| s.as_str()).collect();
97
98            if let Ok(cols) = schema.column_names(&path_refs) {
99                source_map.insert(norm_key, SourceColumns { columns: cols });
100            }
101        }
102        TableSource::Subquery { query, alias } => {
103            if let Some(alias) = alias {
104                let norm_alias = normalize_identifier(alias, dialect);
105                let cols = extract_output_columns(query, schema, dialect, cte_columns);
106                source_map.insert(norm_alias, SourceColumns { columns: cols });
107            }
108        }
109        TableSource::Lateral { source: inner } => {
110            collect_source_columns(inner, schema, dialect, cte_columns, source_map);
111        }
112        TableSource::Unnest { alias, .. } => {
113            if let Some(alias) = alias {
114                let norm_alias = normalize_identifier(alias, dialect);
115                // Unnest typically produces unnamed columns; skip
116                source_map.insert(norm_alias, SourceColumns { columns: vec![] });
117            }
118        }
119        TableSource::TableFunction { alias, .. } => {
120            if let Some(alias) = alias {
121                let norm_alias = normalize_identifier(alias, dialect);
122                source_map.insert(norm_alias, SourceColumns { columns: vec![] });
123            }
124        }
125    }
126}
127
128/// Build a normalized table path for schema lookup.
129fn build_table_path(table_ref: &TableRef, dialect: Dialect) -> Vec<String> {
130    let mut path = Vec::new();
131    if let Some(cat) = &table_ref.catalog {
132        path.push(normalize_identifier(cat, dialect));
133    }
134    if let Some(sch) = &table_ref.schema {
135        path.push(normalize_identifier(sch, dialect));
136    }
137    path.push(normalize_identifier(&table_ref.name, dialect));
138    path
139}
140
141/// Extract output column names from a subquery statement.
142fn extract_output_columns<S: Schema>(
143    stmt: &Statement,
144    schema: &S,
145    dialect: Dialect,
146    cte_columns: &HashMap<String, Vec<String>>,
147) -> Vec<String> {
148    match stmt {
149        Statement::Select(sel) => {
150            let inner_sources = resolve_source_columns(sel, schema, dialect, cte_columns);
151            let mut cols = Vec::new();
152            for item in &sel.columns {
153                match item {
154                    SelectItem::Wildcard => {
155                        // Expand * from all sources (in definition order)
156                        for_each_source_ordered(sel, dialect, &inner_sources, |sc| {
157                            cols.extend(sc.columns.iter().cloned());
158                        });
159                    }
160                    SelectItem::QualifiedWildcard { table } => {
161                        let norm_table = normalize_identifier(table, dialect);
162                        if let Some(sc) = inner_sources.get(&norm_table) {
163                            cols.extend(sc.columns.iter().cloned());
164                        }
165                    }
166                    SelectItem::Expr { alias, expr } => {
167                        if let Some(alias) = alias {
168                            cols.push(alias.clone());
169                        } else {
170                            cols.push(expr_output_name(expr));
171                        }
172                    }
173                }
174            }
175            cols
176        }
177        Statement::SetOperation(set_op) => {
178            // Output columns come from the left branch
179            extract_output_columns(&set_op.left, schema, dialect, cte_columns)
180        }
181        _ => vec![],
182    }
183}
184
185/// Get the output name of an expression (column name, function name, or a placeholder).
186fn expr_output_name(expr: &Expr) -> String {
187    match expr {
188        Expr::Column { name, .. } => name.clone(),
189        Expr::Function { name, .. } => name.clone(),
190        Expr::TypedFunction { .. } => "_col".to_string(),
191        _ => "_col".to_string(),
192    }
193}
194
195/// Iterate source columns in FROM/JOIN order for deterministic wildcard expansion.
196fn for_each_source_ordered<F>(
197    sel: &SelectStatement,
198    dialect: Dialect,
199    source_map: &HashMap<String, SourceColumns>,
200    mut callback: F,
201) where
202    F: FnMut(&SourceColumns),
203{
204    // FROM source first
205    if let Some(from) = &sel.from {
206        let key = source_key_for(&from.source, dialect);
207        if let Some(sc) = source_map.get(&key) {
208            callback(sc);
209        }
210    }
211    // Then JOINs in order
212    for join in &sel.joins {
213        let key = source_key_for(&join.table, dialect);
214        if let Some(sc) = source_map.get(&key) {
215            callback(sc);
216        }
217    }
218}
219
220/// Get the source key (alias or name) for a table source.
221fn source_key_for(source: &TableSource, dialect: Dialect) -> String {
222    match source {
223        TableSource::Table(tr) => {
224            let name = tr.alias.as_deref().unwrap_or(&tr.name);
225            normalize_identifier(name, dialect)
226        }
227        TableSource::Subquery { alias, .. } => alias
228            .as_deref()
229            .map(|a| normalize_identifier(a, dialect))
230            .unwrap_or_default(),
231        TableSource::Lateral { source } => source_key_for(source, dialect),
232        TableSource::Unnest { alias, .. } | TableSource::TableFunction { alias, .. } => alias
233            .as_deref()
234            .map(|a| normalize_identifier(a, dialect))
235            .unwrap_or_default(),
236    }
237}
238
239/// Qualify a SELECT statement: expand wildcards and qualify column references.
240fn qualify_select<S: Schema>(
241    mut sel: SelectStatement,
242    schema: &S,
243    dialect: Dialect,
244    outer_cte_columns: &HashMap<String, Vec<String>>,
245) -> SelectStatement {
246    // ── 1. Build CTE column map (CTEs defined in this SELECT + inherited) ──
247    let mut cte_columns = outer_cte_columns.clone();
248    for cte in &sel.ctes {
249        let cols = if !cte.columns.is_empty() {
250            // Explicit CTE column list: WITH cte(a, b) AS (...)
251            cte.columns.clone()
252        } else {
253            extract_output_columns(&cte.query, schema, dialect, &cte_columns)
254        };
255        let norm_name = normalize_identifier(&cte.name, dialect);
256        cte_columns.insert(norm_name, cols);
257    }
258
259    // ── 2. Recursively qualify CTE bodies ────────────────────────────
260    sel.ctes = sel
261        .ctes
262        .into_iter()
263        .map(|mut cte| {
264            cte.query = Box::new(qualify_columns(*cte.query, schema));
265            cte
266        })
267        .collect();
268
269    // ── 3. Recursively qualify derived tables and join subqueries ─────
270    if let Some(ref mut from) = sel.from {
271        qualify_table_source(&mut from.source, schema, dialect, &cte_columns);
272    }
273    for join in &mut sel.joins {
274        qualify_table_source(&mut join.table, schema, dialect, &cte_columns);
275    }
276
277    // ── 4. Resolve source columns for this scope ─────────────────────
278    let source_map = resolve_source_columns(&sel, schema, dialect, &cte_columns);
279
280    // ── 5. Expand wildcards in SELECT list ────────────────────────────
281    let mut new_columns = Vec::new();
282    let old_columns = std::mem::take(&mut sel.columns);
283    for item in old_columns {
284        match item {
285            SelectItem::Wildcard => {
286                // Expand to all columns from all sources in order
287                for_each_source_ordered(&sel, dialect, &source_map, |sc| {
288                    for col_name in &sc.columns {
289                        new_columns.push(SelectItem::Expr {
290                            expr: Expr::Column {
291                                table: None,
292                                name: col_name.clone(),
293                                quote_style: QuoteStyle::None,
294                                table_quote_style: QuoteStyle::None,
295                            },
296                            alias: None,
297                        });
298                    }
299                });
300            }
301            SelectItem::QualifiedWildcard { table } => {
302                let norm_table = normalize_identifier(&table, dialect);
303                if let Some(sc) = source_map.get(&norm_table) {
304                    for col_name in &sc.columns {
305                        new_columns.push(SelectItem::Expr {
306                            expr: Expr::Column {
307                                table: Some(table.clone()),
308                                name: col_name.clone(),
309                                quote_style: QuoteStyle::None,
310                                table_quote_style: QuoteStyle::None,
311                            },
312                            alias: None,
313                        });
314                    }
315                } else {
316                    // Unknown table — preserve as-is
317                    new_columns.push(SelectItem::QualifiedWildcard { table });
318                }
319            }
320            SelectItem::Expr { expr, alias } => {
321                let qualified_expr = qualify_expr(expr, &source_map, schema, dialect, &cte_columns);
322                new_columns.push(SelectItem::Expr {
323                    expr: qualified_expr,
324                    alias,
325                });
326            }
327        }
328    }
329    sel.columns = new_columns;
330
331    // ── 6. Qualify expressions in WHERE, GROUP BY, HAVING, ORDER BY ──
332    if let Some(wh) = sel.where_clause {
333        sel.where_clause = Some(qualify_expr(wh, &source_map, schema, dialect, &cte_columns));
334    }
335    sel.group_by = sel
336        .group_by
337        .into_iter()
338        .map(|e| qualify_expr(e, &source_map, schema, dialect, &cte_columns))
339        .collect();
340    if let Some(having) = sel.having {
341        sel.having = Some(qualify_expr(
342            having,
343            &source_map,
344            schema,
345            dialect,
346            &cte_columns,
347        ));
348    }
349    sel.order_by = sel
350        .order_by
351        .into_iter()
352        .map(|mut item| {
353            item.expr = qualify_expr(item.expr, &source_map, schema, dialect, &cte_columns);
354            item
355        })
356        .collect();
357    if let Some(qualify) = sel.qualify {
358        sel.qualify = Some(qualify_expr(
359            qualify,
360            &source_map,
361            schema,
362            dialect,
363            &cte_columns,
364        ));
365    }
366
367    // ── 7. Qualify JOIN ON expressions ────────────────────────────────
368    for join in &mut sel.joins {
369        if let Some(on) = join.on.take() {
370            join.on = Some(qualify_expr(on, &source_map, schema, dialect, &cte_columns));
371        }
372    }
373
374    sel
375}
376
377/// Recursively qualify columns inside subquery table sources.
378fn qualify_table_source<S: Schema>(
379    source: &mut TableSource,
380    schema: &S,
381    dialect: Dialect,
382    cte_columns: &HashMap<String, Vec<String>>,
383) {
384    match source {
385        TableSource::Subquery { query, .. } => {
386            *query = Box::new(qualify_columns_inner(
387                *query.clone(),
388                schema,
389                dialect,
390                cte_columns,
391            ));
392        }
393        TableSource::Lateral { source: inner } => {
394            qualify_table_source(inner, schema, dialect, cte_columns);
395        }
396        _ => {}
397    }
398}
399
400/// Inner qualify entry point that passes CTE context.
401fn qualify_columns_inner<S: Schema>(
402    statement: Statement,
403    schema: &S,
404    dialect: Dialect,
405    cte_columns: &HashMap<String, Vec<String>>,
406) -> Statement {
407    match statement {
408        Statement::Select(sel) => {
409            Statement::Select(qualify_select(sel, schema, dialect, cte_columns))
410        }
411        Statement::SetOperation(mut set_op) => {
412            set_op.left = Box::new(qualify_columns_inner(
413                *set_op.left,
414                schema,
415                dialect,
416                cte_columns,
417            ));
418            set_op.right = Box::new(qualify_columns_inner(
419                *set_op.right,
420                schema,
421                dialect,
422                cte_columns,
423            ));
424            Statement::SetOperation(set_op)
425        }
426        other => other,
427    }
428}
429
430/// Qualify column references in an expression by adding table qualifiers.
431/// Also recursively qualifies any subqueries found inside the expression.
432fn qualify_expr<S: Schema>(
433    expr: Expr,
434    source_map: &HashMap<String, SourceColumns>,
435    schema: &S,
436    dialect: Dialect,
437    cte_columns: &HashMap<String, Vec<String>>,
438) -> Expr {
439    expr.transform(&|e| match e {
440        Expr::Column {
441            table: None,
442            name,
443            quote_style,
444            table_quote_style,
445        } => {
446            let norm_name = normalize_identifier(&name, dialect);
447            // Find which source contains this column
448            let resolved_source = resolve_column(&norm_name, source_map);
449            if let Some(source_name) = resolved_source {
450                Expr::Column {
451                    table: Some(source_name),
452                    name,
453                    quote_style,
454                    table_quote_style,
455                }
456            } else {
457                // Column not found in any source — leave unqualified
458                // (could be an alias reference, positional, etc.)
459                Expr::Column {
460                    table: None,
461                    name,
462                    quote_style,
463                    table_quote_style,
464                }
465            }
466        }
467        // Recursively qualify subqueries inside expressions
468        Expr::InSubquery {
469            expr,
470            subquery,
471            negated,
472        } => Expr::InSubquery {
473            expr,
474            subquery: Box::new(qualify_columns_inner(
475                *subquery,
476                schema,
477                dialect,
478                cte_columns,
479            )),
480            negated,
481        },
482        Expr::Subquery(stmt) => Expr::Subquery(Box::new(qualify_columns_inner(
483            *stmt,
484            schema,
485            dialect,
486            cte_columns,
487        ))),
488        Expr::Exists { subquery, negated } => Expr::Exists {
489            subquery: Box::new(qualify_columns_inner(
490                *subquery,
491                schema,
492                dialect,
493                cte_columns,
494            )),
495            negated,
496        },
497        other => other,
498    })
499}
500
501/// Find which source owns a column name.
502/// If exactly one source has it, return that source's name.
503/// If multiple sources have it or none do, return None.
504fn resolve_column(
505    norm_col_name: &str,
506    source_map: &HashMap<String, SourceColumns>,
507) -> Option<String> {
508    let mut matches: Vec<&str> = Vec::new();
509    for (source_name, sc) in source_map {
510        if sc
511            .columns
512            .iter()
513            .any(|c| c.eq_ignore_ascii_case(norm_col_name))
514        {
515            matches.push(source_name);
516        }
517    }
518    if matches.len() == 1 {
519        Some(matches[0].to_string())
520    } else {
521        None
522    }
523}
524
525// ═══════════════════════════════════════════════════════════════════════
526// Tests
527// ═══════════════════════════════════════════════════════════════════════
528
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::generate;
    use crate::parser::parse;
    use crate::schema::MappingSchema;

    /// `DECIMAL(10, 2)`, the type shared by the monetary fixture columns.
    fn decimal_10_2() -> DataType {
        DataType::Decimal {
            precision: Some(10),
            scale: Some(2),
        }
    }

    /// Register `table` with the given `(column, type)` pairs; panics if the
    /// schema rejects it.
    fn register(schema: &mut MappingSchema, table: &str, columns: Vec<(&str, DataType)>) {
        let columns: Vec<(String, DataType)> = columns
            .into_iter()
            .map(|(name, data_type)| (name.to_string(), data_type))
            .collect();
        schema.add_table(&[table], columns).unwrap();
    }

    /// ANSI fixture schema with the `users`, `orders` and `products` tables.
    fn make_schema() -> MappingSchema {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        register(
            &mut schema,
            "users",
            vec![
                ("id", DataType::Int),
                ("name", DataType::Varchar(Some(255))),
                ("email", DataType::Text),
            ],
        );
        register(
            &mut schema,
            "orders",
            vec![
                ("id", DataType::Int),
                ("user_id", DataType::Int),
                ("amount", decimal_10_2()),
                ("status", DataType::Varchar(Some(50))),
            ],
        );
        register(
            &mut schema,
            "products",
            vec![
                ("id", DataType::Int),
                ("name", DataType::Varchar(Some(255))),
                ("price", decimal_10_2()),
            ],
        );
        schema
    }

    /// Parse `sql`, run the pass against `schema`, and render the result.
    fn qualify(sql: &str, schema: &MappingSchema) -> String {
        let stmt = parse(sql, Dialect::Ansi).unwrap();
        let qualified = qualify_columns(stmt, schema);
        generate(&qualified, Dialect::Ansi)
    }

    /// Assert that qualifying `sql` against the fixture schema yields `expected`.
    fn assert_qualifies(sql: &str, expected: &str) {
        let schema = make_schema();
        assert_eq!(qualify(sql, &schema), expected);
    }

    #[test]
    fn test_expand_star() {
        assert_qualifies("SELECT * FROM users", "SELECT id, name, email FROM users");
    }

    #[test]
    fn test_expand_qualified_wildcard() {
        assert_qualifies(
            "SELECT users.* FROM users",
            "SELECT users.id, users.name, users.email FROM users",
        );
    }

    #[test]
    fn test_expand_star_with_alias() {
        assert_qualifies(
            "SELECT * FROM users AS u",
            "SELECT id, name, email FROM users AS u",
        );
    }

    #[test]
    fn test_expand_qualified_wildcard_alias() {
        assert_qualifies(
            "SELECT u.* FROM users AS u",
            "SELECT u.id, u.name, u.email FROM users AS u",
        );
    }

    #[test]
    fn test_qualify_unqualified_single_table() {
        assert_qualifies(
            "SELECT id, name FROM users",
            "SELECT users.id, users.name FROM users",
        );
    }

    #[test]
    fn test_qualify_unqualified_single_table_alias() {
        assert_qualifies(
            "SELECT id, name FROM users AS u",
            "SELECT u.id, u.name FROM users AS u",
        );
    }

    #[test]
    fn test_qualify_already_qualified() {
        assert_qualifies(
            "SELECT users.id, users.name FROM users",
            "SELECT users.id, users.name FROM users",
        );
    }

    #[test]
    fn test_qualify_join_unambiguous() {
        assert_qualifies(
            "SELECT name, amount FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT users.name, orders.amount FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_qualify_join_ambiguous_left_unqualified() {
        // 'id' exists in both users and orders, so it is ambiguous and must
        // stay unqualified.
        assert_qualifies(
            "SELECT id FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT id FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_qualify_where_clause() {
        assert_qualifies(
            "SELECT name FROM users WHERE email = 'test@test.com'",
            "SELECT users.name FROM users WHERE users.email = 'test@test.com'",
        );
    }

    #[test]
    fn test_qualify_order_by() {
        assert_qualifies(
            "SELECT name FROM users ORDER BY email",
            "SELECT users.name FROM users ORDER BY users.email",
        );
    }

    #[test]
    fn test_qualify_group_by_having() {
        assert_qualifies(
            "SELECT status, COUNT(*) FROM orders GROUP BY status HAVING COUNT(*) > 1",
            "SELECT orders.status, COUNT(*) FROM orders GROUP BY orders.status HAVING COUNT(*) > 1",
        );
    }

    #[test]
    fn test_expand_star_join() {
        assert_qualifies(
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
            "SELECT id, name, email, id, user_id, amount, status FROM users INNER JOIN orders ON users.id = orders.user_id",
        );
    }

    #[test]
    fn test_cte_column_resolution() {
        assert_qualifies(
            "WITH active AS (SELECT id, name FROM users) SELECT id, name FROM active",
            "WITH active AS (SELECT users.id, users.name FROM users) SELECT active.id, active.name FROM active",
        );
    }

    #[test]
    fn test_derived_table_column_resolution() {
        assert_qualifies(
            "SELECT id FROM (SELECT id, name FROM users) AS sub",
            "SELECT sub.id FROM (SELECT users.id, users.name FROM users) AS sub",
        );
    }

    #[test]
    fn test_preserve_expression_aliases() {
        assert_qualifies(
            "SELECT name AS user_name FROM users",
            "SELECT users.name AS user_name FROM users",
        );
    }

    #[test]
    fn test_qualify_join_on() {
        // 'id' is ambiguous (in both users and orders) so stays unqualified;
        // 'user_id' is unique to orders so gets qualified.
        assert_qualifies(
            "SELECT name FROM users JOIN orders ON id = user_id",
            "SELECT users.name FROM users INNER JOIN orders ON id = orders.user_id",
        );
    }

    #[test]
    fn test_no_schema_columns_passthrough() {
        // Table not in schema — columns pass through unchanged.
        assert_qualifies(
            "SELECT x, y FROM unknown_table",
            "SELECT x, y FROM unknown_table",
        );
    }
}