// sqrust_rules/structure/lateral_column_alias.rs

use std::collections::HashSet;

use sqrust_core::{Diagnostic, FileContext, Rule};
use sqlparser::ast::{
    Expr, GroupByExpr, Ident, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
};

use crate::capitalisation::{is_word_char, SkipMap};

10pub struct LateralColumnAlias;
11
12impl Rule for LateralColumnAlias {
13    fn name(&self) -> &'static str {
14        "Structure/LateralColumnAlias"
15    }
16
17    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
18        if !ctx.parse_errors.is_empty() {
19            return Vec::new();
20        }
21
22        let mut diags = Vec::new();
23
24        for stmt in &ctx.statements {
25            if let Statement::Query(query) = stmt {
26                check_query(query, self.name(), ctx, &mut diags);
27            }
28        }
29
30        diags
31    }
32}
33
34// ── AST walking ───────────────────────────────────────────────────────────────
35
36fn check_query(
37    query: &Query,
38    rule: &'static str,
39    ctx: &FileContext,
40    diags: &mut Vec<Diagnostic>,
41) {
42    if let Some(with) = &query.with {
43        for cte in &with.cte_tables {
44            check_query(&cte.query, rule, ctx, diags);
45        }
46    }
47
48    check_set_expr(&query.body, rule, ctx, diags);
49}
50
51fn check_set_expr(
52    expr: &SetExpr,
53    rule: &'static str,
54    ctx: &FileContext,
55    diags: &mut Vec<Diagnostic>,
56) {
57    match expr {
58        SetExpr::Select(sel) => {
59            check_select(sel, rule, ctx, diags);
60        }
61        SetExpr::Query(inner) => {
62            check_query(inner, rule, ctx, diags);
63        }
64        SetExpr::SetOperation { left, right, .. } => {
65            check_set_expr(left, rule, ctx, diags);
66            check_set_expr(right, rule, ctx, diags);
67        }
68        _ => {}
69    }
70}
71
72fn check_select(
73    sel: &Select,
74    rule: &'static str,
75    ctx: &FileContext,
76    diags: &mut Vec<Diagnostic>,
77) {
78    // Collect all aliases defined in the SELECT projection.
79    let aliases: HashSet<String> = sel
80        .projection
81        .iter()
82        .filter_map(|item| {
83            if let SelectItem::ExprWithAlias { alias, .. } = item {
84                Some(alias.value.to_lowercase())
85            } else {
86                None
87            }
88        })
89        .collect();
90
91    if aliases.is_empty() {
92        // No aliases — nothing can be a lateral alias reference.
93        // Still recurse into subqueries in FROM.
94        recurse_from(sel, rule, ctx, diags);
95        return;
96    }
97
98    // Check WHERE clause.
99    if let Some(selection) = &sel.selection {
100        collect_lateral_alias_refs(selection, &aliases, rule, ctx, diags);
101    }
102
103    // Check GROUP BY expressions.
104    if let GroupByExpr::Expressions(exprs, _) = &sel.group_by {
105        for expr in exprs {
106            collect_lateral_alias_refs(expr, &aliases, rule, ctx, diags);
107        }
108    }
109
110    // Check HAVING clause.
111    if let Some(having) = &sel.having {
112        collect_lateral_alias_refs(having, &aliases, rule, ctx, diags);
113    }
114
115    // Recurse into subqueries in the FROM clause.
116    recurse_from(sel, rule, ctx, diags);
117}
118
119fn recurse_from(
120    sel: &Select,
121    rule: &'static str,
122    ctx: &FileContext,
123    diags: &mut Vec<Diagnostic>,
124) {
125    for twj in &sel.from {
126        recurse_table_factor(&twj.relation, rule, ctx, diags);
127        for join in &twj.joins {
128            recurse_table_factor(&join.relation, rule, ctx, diags);
129        }
130    }
131}
132
133fn recurse_table_factor(
134    tf: &TableFactor,
135    rule: &'static str,
136    ctx: &FileContext,
137    diags: &mut Vec<Diagnostic>,
138) {
139    if let TableFactor::Derived { subquery, .. } = tf {
140        check_query(subquery, rule, ctx, diags);
141    }
142}
143
144// ── Lateral alias detection ───────────────────────────────────────────────────
145
146/// Recursively walks `expr` and emits a Diagnostic for every unquoted
147/// identifier that matches one of the SELECT-list aliases.
148fn collect_lateral_alias_refs(
149    expr: &Expr,
150    aliases: &HashSet<String>,
151    rule: &'static str,
152    ctx: &FileContext,
153    diags: &mut Vec<Diagnostic>,
154) {
155    match expr {
156        Expr::Identifier(ident) => {
157            check_ident(ident, aliases, rule, ctx, diags);
158        }
159        Expr::BinaryOp { left, right, .. } => {
160            collect_lateral_alias_refs(left, aliases, rule, ctx, diags);
161            collect_lateral_alias_refs(right, aliases, rule, ctx, diags);
162        }
163        Expr::UnaryOp { expr: inner, .. } => {
164            collect_lateral_alias_refs(inner, aliases, rule, ctx, diags);
165        }
166        Expr::Nested(inner) => {
167            collect_lateral_alias_refs(inner, aliases, rule, ctx, diags);
168        }
169        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
170            collect_lateral_alias_refs(inner, aliases, rule, ctx, diags);
171        }
172        Expr::Between {
173            expr: e, low, high, ..
174        } => {
175            collect_lateral_alias_refs(e, aliases, rule, ctx, diags);
176            collect_lateral_alias_refs(low, aliases, rule, ctx, diags);
177            collect_lateral_alias_refs(high, aliases, rule, ctx, diags);
178        }
179        Expr::InList { expr: e, list, .. } => {
180            collect_lateral_alias_refs(e, aliases, rule, ctx, diags);
181            for item in list {
182                collect_lateral_alias_refs(item, aliases, rule, ctx, diags);
183            }
184        }
185        Expr::Case {
186            operand,
187            conditions,
188            results,
189            else_result,
190        } => {
191            if let Some(op) = operand {
192                collect_lateral_alias_refs(op, aliases, rule, ctx, diags);
193            }
194            for cond in conditions {
195                collect_lateral_alias_refs(cond, aliases, rule, ctx, diags);
196            }
197            for res in results {
198                collect_lateral_alias_refs(res, aliases, rule, ctx, diags);
199            }
200            if let Some(else_e) = else_result {
201                collect_lateral_alias_refs(else_e, aliases, rule, ctx, diags);
202            }
203        }
204        Expr::Function(func) => {
205            if let sqlparser::ast::FunctionArguments::List(list) = &func.args {
206                for arg in &list.args {
207                    let fae = match arg {
208                        sqlparser::ast::FunctionArg::Named { arg, .. }
209                        | sqlparser::ast::FunctionArg::ExprNamed { arg, .. }
210                        | sqlparser::ast::FunctionArg::Unnamed(arg) => arg,
211                    };
212                    if let sqlparser::ast::FunctionArgExpr::Expr(e) = fae {
213                        collect_lateral_alias_refs(e, aliases, rule, ctx, diags);
214                    }
215                }
216            }
217        }
218        // Do not descend into subqueries here — they have their own scope.
219        _ => {}
220    }
221}
222
223fn check_ident(
224    ident: &Ident,
225    aliases: &HashSet<String>,
226    rule: &'static str,
227    ctx: &FileContext,
228    diags: &mut Vec<Diagnostic>,
229) {
230    // Only flag unquoted identifiers (quote_style is None).
231    if ident.quote_style.is_some() {
232        return;
233    }
234
235    let name_lower = ident.value.to_lowercase();
236    if !aliases.contains(&name_lower) {
237        return;
238    }
239
240    let offset = find_identifier_offset(&ctx.source, &ident.value);
241    let (line, col) = offset_to_line_col(&ctx.source, offset);
242
243    diags.push(Diagnostic {
244        rule,
245        message: format!(
246            "Column alias '{}' used in WHERE/GROUP BY/HAVING — lateral column aliases are not supported by most databases",
247            ident.value
248        ),
249        line,
250        col,
251    });
252}
253
254// ── Source-text helpers ───────────────────────────────────────────────────────
255
256/// Finds the byte offset of the first whole-word, case-insensitive occurrence
257/// of `name` in `source`, skipping positions inside strings/comments.
258/// Returns 0 if not found.
259fn find_identifier_offset(source: &str, name: &str) -> usize {
260    let bytes = source.as_bytes();
261    let skip_map = SkipMap::build(source);
262    let name_bytes: Vec<u8> = name.bytes().map(|b| b.to_ascii_lowercase()).collect();
263    let name_len = name_bytes.len();
264    let src_len = bytes.len();
265
266    let mut i = 0usize;
267
268    while i + name_len <= src_len {
269        if !skip_map.is_code(i) {
270            i += 1;
271            continue;
272        }
273
274        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
275        if !before_ok {
276            i += 1;
277            continue;
278        }
279
280        let matches = bytes[i..i + name_len]
281            .iter()
282            .zip(name_bytes.iter())
283            .all(|(&a, &b)| a.to_ascii_lowercase() == b);
284
285        if matches {
286            let after = i + name_len;
287            let after_ok = after >= src_len || !is_word_char(bytes[after]);
288            let all_code = (i..i + name_len).all(|k| skip_map.is_code(k));
289
290            if after_ok && all_code {
291                return i;
292            }
293        }
294
295        i += 1;
296    }
297
298    0
299}
300
/// Translates a byte `offset` into `source` into a 1-indexed `(line, col)` pair.
///
/// Lines are delimited by `'\n'`. The column is the byte distance from the
/// start of the line, plus one. When counting lines, an offset past the end
/// of `source` is clamped to the end.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset.min(source.len())];
    let newline_count = prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = match prefix.rfind('\n') {
        Some(nl) => nl + 1,
        None => 0,
    };
    (newline_count + 1, offset - line_start + 1)
}