Skip to main content

sqrust_rules/structure/
unqualified_column_in_join.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, FunctionArgExpr, GroupByExpr, Query, Select, SelectItem, SetExpr,
4    Statement, TableFactor, With,
5};
6
7pub struct UnqualifiedColumnInJoin;
8
9impl Rule for UnqualifiedColumnInJoin {
10    fn name(&self) -> &'static str {
11        "Structure/UnqualifiedColumnInJoin"
12    }
13
14    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
15        if !ctx.parse_errors.is_empty() {
16            return Vec::new();
17        }
18        let mut diags = Vec::new();
19        for stmt in &ctx.statements {
20            if let Statement::Query(q) = stmt {
21                check_query(q, &ctx.source, self.name(), &mut diags);
22            }
23        }
24        diags
25    }
26}
27
28fn check_query(q: &Query, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
29    if let Some(With { cte_tables, .. }) = &q.with {
30        for cte in cte_tables {
31            check_query(&cte.query, src, rule, diags);
32        }
33    }
34    check_set_expr(&q.body, src, rule, diags);
35}
36
37fn check_set_expr(body: &SetExpr, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
38    match body {
39        SetExpr::Select(s) => check_select(s, src, rule, diags),
40        SetExpr::SetOperation { left, right, .. } => {
41            check_set_expr(left, src, rule, diags);
42            check_set_expr(right, src, rule, diags);
43        }
44        SetExpr::Query(q) => check_query(q, src, rule, diags),
45        _ => {}
46    }
47}
48
49fn check_select(sel: &Select, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
50    // Recurse into subqueries in FROM clause regardless of join presence
51    for twj in &sel.from {
52        recurse_table_factor(&twj.relation, src, rule, diags);
53        for join in &twj.joins {
54            recurse_table_factor(&join.relation, src, rule, diags);
55        }
56    }
57
58    // Only flag unqualified columns when there are explicit JOINs
59    let has_joins = sel.from.iter().any(|twj| !twj.joins.is_empty());
60    if !has_joins {
61        return;
62    }
63
64    // Check SELECT projections
65    for item in &sel.projection {
66        match item {
67            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
68                find_unqualified(e, src, rule, diags);
69            }
70            SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {}
71        }
72    }
73
74    // Check WHERE
75    if let Some(w) = &sel.selection {
76        find_unqualified(w, src, rule, diags);
77    }
78
79    // Check HAVING
80    if let Some(h) = &sel.having {
81        find_unqualified(h, src, rule, diags);
82    }
83
84    // Check GROUP BY
85    if let GroupByExpr::Expressions(exprs, _) = &sel.group_by {
86        for g in exprs {
87            find_unqualified(g, src, rule, diags);
88        }
89    }
90}
91
92fn recurse_table_factor(tf: &TableFactor, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
93    if let TableFactor::Derived { subquery, .. } = tf {
94        check_query(subquery, src, rule, diags);
95    }
96}
97
98fn find_unqualified(expr: &Expr, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
99    match expr {
100        Expr::Identifier(i) => {
101            // Unqualified column reference — flag it
102            if let Some(off) = find_word_in_source(src, &i.value, 0) {
103                let (line, col) = offset_to_line_col(src, off);
104                diags.push(Diagnostic {
105                    rule,
106                    message: format!(
107                        "Column '{}' is not qualified with a table name or alias; in a JOIN query, all columns should be table-qualified",
108                        i.value
109                    ),
110                    line,
111                    col,
112                });
113            }
114        }
115        Expr::CompoundIdentifier(_) => {} // Qualified — ok
116        Expr::BinaryOp { left, right, .. } => {
117            find_unqualified(left, src, rule, diags);
118            find_unqualified(right, src, rule, diags);
119        }
120        Expr::UnaryOp { expr, .. } | Expr::Nested(expr) => {
121            find_unqualified(expr, src, rule, diags);
122        }
123        Expr::Function(f) => {
124            if let sqlparser::ast::FunctionArguments::List(arg_list) = &f.args {
125                for arg in &arg_list.args {
126                    if let sqlparser::ast::FunctionArg::Unnamed(arg_expr) = arg {
127                        if let FunctionArgExpr::Expr(e) = arg_expr {
128                            find_unqualified(e, src, rule, diags);
129                        }
130                    }
131                }
132            }
133        }
134        Expr::IsNull(e) | Expr::IsNotNull(e) => find_unqualified(e, src, rule, diags),
135        Expr::Between { expr, low, high, .. } => {
136            find_unqualified(expr, src, rule, diags);
137            find_unqualified(low, src, rule, diags);
138            find_unqualified(high, src, rule, diags);
139        }
140        Expr::InList { expr, list, .. } => {
141            find_unqualified(expr, src, rule, diags);
142            for e in list {
143                find_unqualified(e, src, rule, diags);
144            }
145        }
146        Expr::Case { operand, conditions, results, else_result } => {
147            if let Some(e) = operand {
148                find_unqualified(e, src, rule, diags);
149            }
150            for (c, r) in conditions.iter().zip(results.iter()) {
151                find_unqualified(c, src, rule, diags);
152                find_unqualified(r, src, rule, diags);
153            }
154            if let Some(e) = else_result {
155                find_unqualified(e, src, rule, diags);
156            }
157        }
158        _ => {}
159    }
160}
161
162fn find_word_in_source(src: &str, word: &str, start: usize) -> Option<usize> {
163    let bytes = src.as_bytes();
164    let wbytes = word.as_bytes();
165    let wlen = wbytes.len();
166    if wlen == 0 {
167        return None;
168    }
169    let mut i = start;
170    while i + wlen <= bytes.len() {
171        if bytes[i..i + wlen].eq_ignore_ascii_case(wbytes) {
172            let before_ok = i == 0 || (!is_wc(bytes[i - 1]) && bytes[i - 1] != b'.');
173            let after_ok = i + wlen >= bytes.len() || !is_wc(bytes[i + wlen]);
174            if before_ok && after_ok {
175                return Some(i);
176            }
177        }
178        i += 1;
179    }
180    None
181}
182
/// Word-character test used for whole-word matching: ASCII letters, ASCII digits
/// and `_`.
fn is_wc(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
186
/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column counts characters, not bytes, from the start of the line, so
/// multi-byte UTF-8 text earlier on the same line does not push it to the right.
/// Offsets past the end are clamped to `source.len()`. An offset that falls inside
/// a multi-byte character is moved back to that character's start instead of
/// panicking on the slice.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Terminates: offset 0 is always a char boundary.
    while !source.is_char_boundary(end) {
        end -= 1;
    }
    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    let col = before[line_start..].chars().count() + 1;
    (line, col)
}