1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that reports aggregate function calls appearing in a `WHERE`
/// clause, where `HAVING` should be used instead.
pub struct AggregateInWhere;
7
/// Upper-case names of the functions this rule treats as aggregates.
///
/// A function call is matched by upper-casing the last identifier of its
/// (possibly qualified) name and looking it up in this list.
const AGGREGATES: &[&str] = &[
    "COUNT",
    "SUM",
    "AVG",
    "MIN",
    "MAX",
    "ARRAY_AGG",
    "STRING_AGG",
    "GROUP_CONCAT",
    "EVERY",
    "COUNT_IF",
    "ANY_VALUE",
];
22
23impl Rule for AggregateInWhere {
24 fn name(&self) -> &'static str {
25 "Structure/AggregateInWhere"
26 }
27
28 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
29 if !ctx.parse_errors.is_empty() {
30 return Vec::new();
31 }
32
33 let mut diags = Vec::new();
34 let mut counters: std::collections::HashMap<String, usize> =
37 std::collections::HashMap::new();
38
39 for stmt in &ctx.statements {
40 if let Statement::Query(query) = stmt {
41 check_query(query, ctx, &mut counters, &mut diags);
42 }
43 }
44
45 diags
46 }
47}
48
49fn check_query(
52 query: &Query,
53 ctx: &FileContext,
54 counters: &mut std::collections::HashMap<String, usize>,
55 diags: &mut Vec<Diagnostic>,
56) {
57 if let Some(with) = &query.with {
59 for cte in &with.cte_tables {
60 check_query(&cte.query, ctx, counters, diags);
61 }
62 }
63 check_set_expr(&query.body, ctx, counters, diags);
64}
65
66fn check_set_expr(
67 expr: &SetExpr,
68 ctx: &FileContext,
69 counters: &mut std::collections::HashMap<String, usize>,
70 diags: &mut Vec<Diagnostic>,
71) {
72 match expr {
73 SetExpr::Select(sel) => {
74 check_select(sel, ctx, counters, diags);
75 }
76 SetExpr::Query(inner) => {
77 check_query(inner, ctx, counters, diags);
78 }
79 SetExpr::SetOperation { left, right, .. } => {
80 check_set_expr(left, ctx, counters, diags);
81 check_set_expr(right, ctx, counters, diags);
82 }
83 _ => {}
84 }
85}
86
87fn check_select(
88 sel: &Select,
89 ctx: &FileContext,
90 counters: &mut std::collections::HashMap<String, usize>,
91 diags: &mut Vec<Diagnostic>,
92) {
93 if let Some(selection) = &sel.selection {
95 collect_aggregates_in_expr(selection, ctx, counters, diags);
96 }
97
98 for table_with_joins in &sel.from {
100 recurse_table_factor(&table_with_joins.relation, ctx, counters, diags);
101 for join in &table_with_joins.joins {
102 recurse_table_factor(&join.relation, ctx, counters, diags);
103 }
104 }
105}
106
107fn recurse_table_factor(
108 tf: &TableFactor,
109 ctx: &FileContext,
110 counters: &mut std::collections::HashMap<String, usize>,
111 diags: &mut Vec<Diagnostic>,
112) {
113 if let TableFactor::Derived { subquery, .. } = tf {
114 check_query(subquery, ctx, counters, diags);
115 }
116}
117
118fn collect_aggregates_in_expr(
123 expr: &Expr,
124 ctx: &FileContext,
125 counters: &mut std::collections::HashMap<String, usize>,
126 diags: &mut Vec<Diagnostic>,
127) {
128 match expr {
129 Expr::Function(func) => {
130 let name_upper = func
131 .name
132 .0
133 .last()
134 .map(|ident| ident.value.to_uppercase())
135 .unwrap_or_default();
136
137 if AGGREGATES.contains(&name_upper.as_str()) {
138 let occ = counters.entry(name_upper.clone()).or_insert(0);
139 let occurrence = *occ;
140 *occ += 1;
141
142 let offset = find_nth_occurrence(&ctx.source, &name_upper, occurrence);
143 let (line, col) = offset_to_line_col(&ctx.source, offset);
144
145 diags.push(Diagnostic {
146 rule: "Structure/AggregateInWhere",
147 message: "Aggregate function in WHERE clause; use HAVING instead".to_string(),
148 line,
149 col,
150 });
151 }
152 }
156 Expr::BinaryOp { left, right, .. } => {
157 collect_aggregates_in_expr(left, ctx, counters, diags);
158 collect_aggregates_in_expr(right, ctx, counters, diags);
159 }
160 Expr::UnaryOp { expr: inner, .. } => {
161 collect_aggregates_in_expr(inner, ctx, counters, diags);
162 }
163 Expr::Nested(inner) => {
164 collect_aggregates_in_expr(inner, ctx, counters, diags);
165 }
166 Expr::Between {
167 expr: e,
168 low,
169 high,
170 ..
171 } => {
172 collect_aggregates_in_expr(e, ctx, counters, diags);
173 collect_aggregates_in_expr(low, ctx, counters, diags);
174 collect_aggregates_in_expr(high, ctx, counters, diags);
175 }
176 Expr::Case {
177 operand,
178 conditions,
179 results,
180 else_result,
181 } => {
182 if let Some(op) = operand {
183 collect_aggregates_in_expr(op, ctx, counters, diags);
184 }
185 for cond in conditions {
186 collect_aggregates_in_expr(cond, ctx, counters, diags);
187 }
188 for res in results {
189 collect_aggregates_in_expr(res, ctx, counters, diags);
190 }
191 if let Some(else_e) = else_result {
192 collect_aggregates_in_expr(else_e, ctx, counters, diags);
193 }
194 }
195 Expr::InList { expr: inner, list, .. } => {
196 collect_aggregates_in_expr(inner, ctx, counters, diags);
197 for e in list {
198 collect_aggregates_in_expr(e, ctx, counters, diags);
199 }
200 }
201 Expr::InSubquery {
202 expr: inner,
203 subquery,
204 ..
205 } => {
206 collect_aggregates_in_expr(inner, ctx, counters, diags);
207 check_query(subquery, ctx, counters, diags);
208 }
209 Expr::Exists { subquery, .. } => {
210 check_query(subquery, ctx, counters, diags);
211 }
212 Expr::Subquery(q) => {
213 check_query(q, ctx, counters, diags);
214 }
215 _ => {}
216 }
217}
218
219fn find_nth_occurrence(source: &str, name: &str, nth: usize) -> usize {
225 let bytes = source.as_bytes();
226 let skip_map = SkipMap::build(source);
227 let name_bytes: Vec<u8> = name.bytes().map(|b| b.to_ascii_uppercase()).collect();
228 let name_len = name_bytes.len();
229 let src_len = bytes.len();
230
231 let mut count = 0usize;
232 let mut i = 0usize;
233
234 while i + name_len <= src_len {
235 if !skip_map.is_code(i) {
236 i += 1;
237 continue;
238 }
239
240 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
241 if !before_ok {
242 i += 1;
243 continue;
244 }
245
246 let matches = bytes[i..i + name_len]
247 .iter()
248 .zip(name_bytes.iter())
249 .all(|(&a, &b)| a.to_ascii_uppercase() == b);
250
251 if matches {
252 let after = i + name_len;
253 let after_ok = after >= src_len || !is_word_char(bytes[after]);
254 let all_code = (i..i + name_len).all(|k| skip_map.is_code(k));
255
256 if after_ok && all_code {
257 if count == nth {
258 return i;
259 }
260 count += 1;
261 }
262 }
263
264 i += 1;
265 }
266
267 0
268}
269
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// The line is one more than the number of `'\n'` bytes before `offset`; the
/// column is the 1-based byte distance from the start of that line.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or not on a UTF-8 character
/// boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    // '\n' is a single byte in UTF-8 and never part of a multi-byte sequence,
    // so counting bytes is equivalent to counting newline characters.
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();

    // Byte index where the current line begins (just past the last newline).
    let line_start = prefix.rfind('\n').map_or(0, |nl| nl + 1);

    (newlines + 1, offset - line_start + 1)
}