Skip to main content

sqrust_rules/ambiguous/
table_alias_conflict.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Query, SetExpr, Statement, TableFactor, TableWithJoins};
3use std::collections::HashMap;
4
5pub struct TableAliasConflict;
6
7impl Rule for TableAliasConflict {
8    fn name(&self) -> &'static str {
9        "Ambiguous/TableAliasConflict"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16
17        let mut diags = Vec::new();
18        for stmt in &ctx.statements {
19            if let Statement::Query(query) = stmt {
20                check_query(query, self.name(), &mut diags);
21            }
22        }
23        diags
24    }
25}
26
27fn check_query(query: &Query, rule: &'static str, diags: &mut Vec<Diagnostic>) {
28    // Recurse into CTEs.
29    if let Some(with) = &query.with {
30        for cte in &with.cte_tables {
31            check_query(&cte.query, rule, diags);
32        }
33    }
34    check_set_expr(&query.body, rule, diags);
35}
36
37fn check_set_expr(expr: &SetExpr, rule: &'static str, diags: &mut Vec<Diagnostic>) {
38    match expr {
39        SetExpr::Select(sel) => {
40            // Collect effective names for all table references in this SELECT's FROM clause.
41            // Key: lowercased effective name. Value: first occurrence line/col.
42            let mut seen: HashMap<String, (usize, usize)> = HashMap::new();
43            let mut reported: std::collections::HashSet<String> = std::collections::HashSet::new();
44
45            for twj in &sel.from {
46                collect_from_item(twj, rule, &mut seen, &mut reported, diags);
47            }
48
49            // Recurse into subqueries in FROM / JOIN — but as separate scopes.
50            for twj in &sel.from {
51                recurse_subqueries_in_from(twj, rule, diags);
52            }
53        }
54        SetExpr::SetOperation { left, right, .. } => {
55            check_set_expr(left, rule, diags);
56            check_set_expr(right, rule, diags);
57        }
58        SetExpr::Query(inner) => {
59            check_query(inner, rule, diags);
60        }
61        _ => {}
62    }
63}
64
65/// Collects all table reference effective names from a `TableWithJoins` and checks for conflicts.
66/// Does NOT recurse into subqueries — they are handled separately as independent scopes.
67fn collect_from_item(
68    twj: &TableWithJoins,
69    rule: &'static str,
70    seen: &mut HashMap<String, (usize, usize)>,
71    reported: &mut std::collections::HashSet<String>,
72    diags: &mut Vec<Diagnostic>,
73) {
74    collect_table_factor_name(&twj.relation, rule, seen, reported, diags);
75    for join in &twj.joins {
76        collect_table_factor_name(&join.relation, rule, seen, reported, diags);
77    }
78}
79
80/// Extracts the effective name from a `TableFactor` and checks for conflicts.
81/// For `TableFactor::Table`: uses alias if present, else last part of table name.
82/// For other variants (Derived, etc.): skips (handled as subquery scope).
83fn collect_table_factor_name(
84    tf: &TableFactor,
85    rule: &'static str,
86    seen: &mut HashMap<String, (usize, usize)>,
87    reported: &mut std::collections::HashSet<String>,
88    diags: &mut Vec<Diagnostic>,
89) {
90    if let TableFactor::Table { name, alias, .. } = tf {
91        let effective = if let Some(table_alias) = alias {
92            table_alias.name.value.to_lowercase()
93        } else {
94            // Use the last part of the qualified name as the effective name.
95            name.0
96                .last()
97                .map(|ident| ident.value.to_lowercase())
98                .unwrap_or_default()
99        };
100
101        if effective.is_empty() {
102            return;
103        }
104
105        if seen.contains_key(&effective) {
106            // Only report once per conflicting alias per SELECT scope.
107            if reported.insert(effective.clone()) {
108                diags.push(Diagnostic {
109                    rule,
110                    message: format!(
111                        "Table alias '{}' is used more than once in this FROM clause",
112                        effective
113                    ),
114                    line: 1,
115                    col: 1,
116                });
117            }
118        } else {
119            seen.insert(effective, (1, 1));
120        }
121    }
122}
123
124/// Recurses into subqueries that appear as `Derived` table factors (independent scopes).
125fn recurse_subqueries_in_from(twj: &TableWithJoins, rule: &'static str, diags: &mut Vec<Diagnostic>) {
126    recurse_table_factor_subquery(&twj.relation, rule, diags);
127    for join in &twj.joins {
128        recurse_table_factor_subquery(&join.relation, rule, diags);
129    }
130}
131
132fn recurse_table_factor_subquery(
133    tf: &TableFactor,
134    rule: &'static str,
135    diags: &mut Vec<Diagnostic>,
136) {
137    if let TableFactor::Derived { subquery, .. } = tf {
138        check_query(subquery, rule, diags);
139    }
140}