1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, With};
3use std::collections::HashSet;
4
5pub struct ColumnAliasInWhere;
6
7impl Rule for ColumnAliasInWhere {
8 fn name(&self) -> &'static str {
9 "Lint/ColumnAliasInWhere"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16 let mut diags = Vec::new();
17 for stmt in &ctx.statements {
18 check_stmt(stmt, &ctx.source, "Lint/ColumnAliasInWhere", &mut diags);
19 }
20 diags
21 }
22}
23
24fn check_stmt(stmt: &Statement, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
25 if let Statement::Query(q) = stmt {
26 check_query(q, src, rule, diags);
27 }
28}
29
30fn check_query(q: &Query, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
31 if let Some(With { cte_tables, .. }) = &q.with {
32 for cte in cte_tables {
33 check_query(&cte.query, src, rule, diags);
34 }
35 }
36 check_set_expr(&q.body, src, rule, diags);
37}
38
39fn check_set_expr(body: &SetExpr, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
40 match body {
41 SetExpr::Select(s) => check_select(s, src, rule, diags),
42 SetExpr::SetOperation { left, right, .. } => {
43 check_set_expr(left, src, rule, diags);
44 check_set_expr(right, src, rule, diags);
45 }
46 SetExpr::Query(q) => check_query(q, src, rule, diags),
47 _ => {}
48 }
49}
50
51fn check_select(sel: &Select, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
52 let mut aliases: HashSet<String> = HashSet::new();
54 for item in &sel.projection {
55 if let SelectItem::ExprWithAlias { alias, .. } = item {
56 aliases.insert(alias.value.to_lowercase());
57 }
58 }
59
60 if aliases.is_empty() {
61 return;
62 }
63
64 if let Some(where_expr) = &sel.selection {
66 let start_offset = find_where_offset(src);
67 find_alias_refs(where_expr, &aliases, src, rule, diags, start_offset);
68 }
69}
70
71fn find_where_offset(src: &str) -> usize {
72 let bytes = src.as_bytes();
73 let kw = b"WHERE";
74 let mut i = 0;
75 while i + 5 <= bytes.len() {
76 if bytes[i..i + 5].eq_ignore_ascii_case(kw) {
77 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
78 let after_ok = i + 5 >= bytes.len() || !is_word_char(bytes[i + 5]);
79 if before_ok && after_ok {
80 return i;
81 }
82 }
83 i += 1;
84 }
85 0
86}
87
88fn find_alias_refs(
89 expr: &Expr,
90 aliases: &HashSet<String>,
91 src: &str,
92 rule: &'static str,
93 diags: &mut Vec<Diagnostic>,
94 start_offset: usize,
95) {
96 match expr {
97 Expr::Identifier(ident) => {
98 let lower = ident.value.to_lowercase();
99 if aliases.contains(&lower) {
100 if let Some(off) = find_word_in_source(src, &ident.value, start_offset) {
101 let (line, col) = offset_to_line_col(src, off);
102 diags.push(Diagnostic {
103 rule,
104 message: format!(
105 "Column alias '{}' is used in WHERE clause; aliases are not available in WHERE (evaluated before SELECT)",
106 ident.value
107 ),
108 line,
109 col,
110 });
111 }
112 }
113 }
114 Expr::BinaryOp { left, right, .. } => {
115 find_alias_refs(left, aliases, src, rule, diags, start_offset);
116 find_alias_refs(right, aliases, src, rule, diags, start_offset);
117 }
118 Expr::UnaryOp { expr, .. } | Expr::Nested(expr) => {
119 find_alias_refs(expr, aliases, src, rule, diags, start_offset);
120 }
121 Expr::Between { expr, low, high, .. } => {
122 find_alias_refs(expr, aliases, src, rule, diags, start_offset);
123 find_alias_refs(low, aliases, src, rule, diags, start_offset);
124 find_alias_refs(high, aliases, src, rule, diags, start_offset);
125 }
126 Expr::InList { expr, list, .. } => {
127 find_alias_refs(expr, aliases, src, rule, diags, start_offset);
128 for e in list {
129 find_alias_refs(e, aliases, src, rule, diags, start_offset);
130 }
131 }
132 Expr::IsNull(e) | Expr::IsNotNull(e) => {
133 find_alias_refs(e, aliases, src, rule, diags, start_offset);
134 }
135 Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => {
136 find_alias_refs(expr, aliases, src, rule, diags, start_offset);
137 find_alias_refs(pattern, aliases, src, rule, diags, start_offset);
138 }
139 _ => {}
140 }
141}
142
143fn find_word_in_source(src: &str, word: &str, start: usize) -> Option<usize> {
144 let bytes = src.as_bytes();
145 let wbytes = word.as_bytes();
146 let wlen = wbytes.len();
147 if wlen == 0 {
148 return None;
149 }
150 let mut i = start;
151 while i + wlen <= bytes.len() {
152 if bytes[i..i + wlen].eq_ignore_ascii_case(wbytes) {
153 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
154 let after_ok = i + wlen >= bytes.len() || !is_word_char(bytes[i + wlen]);
155 if before_ok && after_ok {
156 return Some(i);
157 }
158 }
159 i += 1;
160 }
161 None
162}
163
/// True for bytes that can appear inside an unquoted SQL word: ASCII letters,
/// ASCII digits and `_`. Non-ASCII bytes are never word characters here.
fn is_word_char(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
167
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// The column counts bytes from the start of the line, not characters.
///
/// Offsets past the end of `source` are clamped to its length. An offset that
/// falls inside a multi-byte UTF-8 character is moved back to that character's
/// first byte. Previously, line and column were derived from different
/// offsets in the first case, and slicing panicked in the second.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut offset = offset.min(source.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(offset) {
        offset -= 1;
    }
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    (line, offset - line_start + 1)
}