1use std::collections::HashSet;
2
3use sqrust_core::{Diagnostic, FileContext, Rule};
4use sqlparser::ast::{
5 Expr, GroupByExpr, Ident, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
6};
7
8use crate::capitalisation::{is_word_char, SkipMap};
9
/// Lint rule `Structure/LateralColumnAlias`.
///
/// Reports unquoted identifiers in the WHERE, GROUP BY, or HAVING clause of a
/// SELECT that match (case-insensitively) an alias declared in that same
/// SELECT's projection list.
pub struct LateralColumnAlias;
11
12impl Rule for LateralColumnAlias {
13 fn name(&self) -> &'static str {
14 "Structure/LateralColumnAlias"
15 }
16
17 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
18 if !ctx.parse_errors.is_empty() {
19 return Vec::new();
20 }
21
22 let mut diags = Vec::new();
23
24 for stmt in &ctx.statements {
25 if let Statement::Query(query) = stmt {
26 check_query(query, self.name(), ctx, &mut diags);
27 }
28 }
29
30 diags
31 }
32}
33
34fn check_query(
37 query: &Query,
38 rule: &'static str,
39 ctx: &FileContext,
40 diags: &mut Vec<Diagnostic>,
41) {
42 if let Some(with) = &query.with {
43 for cte in &with.cte_tables {
44 check_query(&cte.query, rule, ctx, diags);
45 }
46 }
47
48 check_set_expr(&query.body, rule, ctx, diags);
49}
50
51fn check_set_expr(
52 expr: &SetExpr,
53 rule: &'static str,
54 ctx: &FileContext,
55 diags: &mut Vec<Diagnostic>,
56) {
57 match expr {
58 SetExpr::Select(sel) => {
59 check_select(sel, rule, ctx, diags);
60 }
61 SetExpr::Query(inner) => {
62 check_query(inner, rule, ctx, diags);
63 }
64 SetExpr::SetOperation { left, right, .. } => {
65 check_set_expr(left, rule, ctx, diags);
66 check_set_expr(right, rule, ctx, diags);
67 }
68 _ => {}
69 }
70}
71
72fn check_select(
73 sel: &Select,
74 rule: &'static str,
75 ctx: &FileContext,
76 diags: &mut Vec<Diagnostic>,
77) {
78 let aliases: HashSet<String> = sel
80 .projection
81 .iter()
82 .filter_map(|item| {
83 if let SelectItem::ExprWithAlias { alias, .. } = item {
84 Some(alias.value.to_lowercase())
85 } else {
86 None
87 }
88 })
89 .collect();
90
91 if aliases.is_empty() {
92 recurse_from(sel, rule, ctx, diags);
95 return;
96 }
97
98 if let Some(selection) = &sel.selection {
100 collect_lateral_alias_refs(selection, &aliases, rule, ctx, diags);
101 }
102
103 if let GroupByExpr::Expressions(exprs, _) = &sel.group_by {
105 for expr in exprs {
106 collect_lateral_alias_refs(expr, &aliases, rule, ctx, diags);
107 }
108 }
109
110 if let Some(having) = &sel.having {
112 collect_lateral_alias_refs(having, &aliases, rule, ctx, diags);
113 }
114
115 recurse_from(sel, rule, ctx, diags);
117}
118
119fn recurse_from(
120 sel: &Select,
121 rule: &'static str,
122 ctx: &FileContext,
123 diags: &mut Vec<Diagnostic>,
124) {
125 for twj in &sel.from {
126 recurse_table_factor(&twj.relation, rule, ctx, diags);
127 for join in &twj.joins {
128 recurse_table_factor(&join.relation, rule, ctx, diags);
129 }
130 }
131}
132
133fn recurse_table_factor(
134 tf: &TableFactor,
135 rule: &'static str,
136 ctx: &FileContext,
137 diags: &mut Vec<Diagnostic>,
138) {
139 if let TableFactor::Derived { subquery, .. } = tf {
140 check_query(subquery, rule, ctx, diags);
141 }
142}
143
144fn collect_lateral_alias_refs(
149 expr: &Expr,
150 aliases: &HashSet<String>,
151 rule: &'static str,
152 ctx: &FileContext,
153 diags: &mut Vec<Diagnostic>,
154) {
155 match expr {
156 Expr::Identifier(ident) => {
157 check_ident(ident, aliases, rule, ctx, diags);
158 }
159 Expr::BinaryOp { left, right, .. } => {
160 collect_lateral_alias_refs(left, aliases, rule, ctx, diags);
161 collect_lateral_alias_refs(right, aliases, rule, ctx, diags);
162 }
163 Expr::UnaryOp { expr: inner, .. } => {
164 collect_lateral_alias_refs(inner, aliases, rule, ctx, diags);
165 }
166 Expr::Nested(inner) => {
167 collect_lateral_alias_refs(inner, aliases, rule, ctx, diags);
168 }
169 Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
170 collect_lateral_alias_refs(inner, aliases, rule, ctx, diags);
171 }
172 Expr::Between {
173 expr: e, low, high, ..
174 } => {
175 collect_lateral_alias_refs(e, aliases, rule, ctx, diags);
176 collect_lateral_alias_refs(low, aliases, rule, ctx, diags);
177 collect_lateral_alias_refs(high, aliases, rule, ctx, diags);
178 }
179 Expr::InList { expr: e, list, .. } => {
180 collect_lateral_alias_refs(e, aliases, rule, ctx, diags);
181 for item in list {
182 collect_lateral_alias_refs(item, aliases, rule, ctx, diags);
183 }
184 }
185 Expr::Case {
186 operand,
187 conditions,
188 results,
189 else_result,
190 } => {
191 if let Some(op) = operand {
192 collect_lateral_alias_refs(op, aliases, rule, ctx, diags);
193 }
194 for cond in conditions {
195 collect_lateral_alias_refs(cond, aliases, rule, ctx, diags);
196 }
197 for res in results {
198 collect_lateral_alias_refs(res, aliases, rule, ctx, diags);
199 }
200 if let Some(else_e) = else_result {
201 collect_lateral_alias_refs(else_e, aliases, rule, ctx, diags);
202 }
203 }
204 Expr::Function(func) => {
205 if let sqlparser::ast::FunctionArguments::List(list) = &func.args {
206 for arg in &list.args {
207 let fae = match arg {
208 sqlparser::ast::FunctionArg::Named { arg, .. }
209 | sqlparser::ast::FunctionArg::ExprNamed { arg, .. }
210 | sqlparser::ast::FunctionArg::Unnamed(arg) => arg,
211 };
212 if let sqlparser::ast::FunctionArgExpr::Expr(e) = fae {
213 collect_lateral_alias_refs(e, aliases, rule, ctx, diags);
214 }
215 }
216 }
217 }
218 _ => {}
220 }
221}
222
223fn check_ident(
224 ident: &Ident,
225 aliases: &HashSet<String>,
226 rule: &'static str,
227 ctx: &FileContext,
228 diags: &mut Vec<Diagnostic>,
229) {
230 if ident.quote_style.is_some() {
232 return;
233 }
234
235 let name_lower = ident.value.to_lowercase();
236 if !aliases.contains(&name_lower) {
237 return;
238 }
239
240 let offset = find_identifier_offset(&ctx.source, &ident.value);
241 let (line, col) = offset_to_line_col(&ctx.source, offset);
242
243 diags.push(Diagnostic {
244 rule,
245 message: format!(
246 "Column alias '{}' used in WHERE/GROUP BY/HAVING — lateral column aliases are not supported by most databases",
247 ident.value
248 ),
249 line,
250 col,
251 });
252}
253
254fn find_identifier_offset(source: &str, name: &str) -> usize {
260 let bytes = source.as_bytes();
261 let skip_map = SkipMap::build(source);
262 let name_bytes: Vec<u8> = name.bytes().map(|b| b.to_ascii_lowercase()).collect();
263 let name_len = name_bytes.len();
264 let src_len = bytes.len();
265
266 let mut i = 0usize;
267
268 while i + name_len <= src_len {
269 if !skip_map.is_code(i) {
270 i += 1;
271 continue;
272 }
273
274 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
275 if !before_ok {
276 i += 1;
277 continue;
278 }
279
280 let matches = bytes[i..i + name_len]
281 .iter()
282 .zip(name_bytes.iter())
283 .all(|(&a, &b)| a.to_ascii_lowercase() == b);
284
285 if matches {
286 let after = i + name_len;
287 let after_ok = after >= src_len || !is_word_char(bytes[after]);
288 let all_code = (i..i + name_len).all(|k| skip_map.is_code(k));
289
290 if after_ok && all_code {
291 return i;
292 }
293 }
294
295 i += 1;
296 }
297
298 0
299}
300
/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column is the 1-based byte position within the line. The function
/// never panics: an offset past the end of `source` is clamped to its length,
/// and an offset that falls inside a multi-byte character is moved back to the
/// start of that character.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    // Clamp once so that slicing and the column arithmetic agree.
    let mut offset = offset.min(source.len());
    // Slicing a `str` off a char boundary panics; index 0 is always a
    // boundary, so this loop terminates.
    while !source.is_char_boundary(offset) {
        offset -= 1;
    }

    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let col = match before.rfind('\n') {
        // Distance from the newline is already 1-based.
        Some(newline) => offset - newline,
        None => offset + 1,
    };
    (line, col)
}