Skip to main content

sqrust_rules/structure/
aggregate_in_where.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that reports aggregate function calls (any name listed in
/// `AGGREGATES`) appearing in a `WHERE` clause, suggesting `HAVING` instead.
#[derive(Debug, Clone, Copy, Default)]
pub struct AggregateInWhere;
7
/// Upper-cased names of the aggregate functions this rule refuses to see in a
/// `WHERE` clause. A call is matched by upper-casing the last component of
/// its function name and looking it up here.
const AGGREGATES: &[&str] = &[
    // Classic numeric / counting aggregates.
    "COUNT", "SUM", "AVG", "MIN", "MAX",
    // Aggregates that collect values into an array or string.
    "ARRAY_AGG", "STRING_AGG", "GROUP_CONCAT",
    // Remaining aggregates.
    "EVERY", "COUNT_IF", "ANY_VALUE",
];
22
23impl Rule for AggregateInWhere {
24    fn name(&self) -> &'static str {
25        "Structure/AggregateInWhere"
26    }
27
28    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
29        if !ctx.parse_errors.is_empty() {
30            return Vec::new();
31        }
32
33        let mut diags = Vec::new();
34        // Per-function-name occurrence counter so `find_nth_occurrence` can
35        // locate the correct source position when a name appears multiple times.
36        let mut counters: std::collections::HashMap<String, usize> =
37            std::collections::HashMap::new();
38
39        for stmt in &ctx.statements {
40            if let Statement::Query(query) = stmt {
41                check_query(query, ctx, &mut counters, &mut diags);
42            }
43        }
44
45        diags
46    }
47}
48
49// ── AST walking ───────────────────────────────────────────────────────────────
50
51fn check_query(
52    query: &Query,
53    ctx: &FileContext,
54    counters: &mut std::collections::HashMap<String, usize>,
55    diags: &mut Vec<Diagnostic>,
56) {
57    // Visit CTEs.
58    if let Some(with) = &query.with {
59        for cte in &with.cte_tables {
60            check_query(&cte.query, ctx, counters, diags);
61        }
62    }
63    check_set_expr(&query.body, ctx, counters, diags);
64}
65
66fn check_set_expr(
67    expr: &SetExpr,
68    ctx: &FileContext,
69    counters: &mut std::collections::HashMap<String, usize>,
70    diags: &mut Vec<Diagnostic>,
71) {
72    match expr {
73        SetExpr::Select(sel) => {
74            check_select(sel, ctx, counters, diags);
75        }
76        SetExpr::Query(inner) => {
77            check_query(inner, ctx, counters, diags);
78        }
79        SetExpr::SetOperation { left, right, .. } => {
80            check_set_expr(left, ctx, counters, diags);
81            check_set_expr(right, ctx, counters, diags);
82        }
83        _ => {}
84    }
85}
86
87fn check_select(
88    sel: &Select,
89    ctx: &FileContext,
90    counters: &mut std::collections::HashMap<String, usize>,
91    diags: &mut Vec<Diagnostic>,
92) {
93    // Check the WHERE clause for aggregate functions.
94    if let Some(selection) = &sel.selection {
95        collect_aggregates_in_expr(selection, ctx, counters, diags);
96    }
97
98    // Recurse into subqueries in the FROM clause.
99    for table_with_joins in &sel.from {
100        recurse_table_factor(&table_with_joins.relation, ctx, counters, diags);
101        for join in &table_with_joins.joins {
102            recurse_table_factor(&join.relation, ctx, counters, diags);
103        }
104    }
105}
106
107fn recurse_table_factor(
108    tf: &TableFactor,
109    ctx: &FileContext,
110    counters: &mut std::collections::HashMap<String, usize>,
111    diags: &mut Vec<Diagnostic>,
112) {
113    if let TableFactor::Derived { subquery, .. } = tf {
114        check_query(subquery, ctx, counters, diags);
115    }
116}
117
118// ── Aggregate detection in expressions ───────────────────────────────────────
119
120/// Recursively walks `expr` and emits a Diagnostic for every aggregate function
121/// call found directly inside the expression.
122fn collect_aggregates_in_expr(
123    expr: &Expr,
124    ctx: &FileContext,
125    counters: &mut std::collections::HashMap<String, usize>,
126    diags: &mut Vec<Diagnostic>,
127) {
128    match expr {
129        Expr::Function(func) => {
130            let name_upper = func
131                .name
132                .0
133                .last()
134                .map(|ident| ident.value.to_uppercase())
135                .unwrap_or_default();
136
137            if AGGREGATES.contains(&name_upper.as_str()) {
138                let occ = counters.entry(name_upper.clone()).or_insert(0);
139                let occurrence = *occ;
140                *occ += 1;
141
142                let offset = find_nth_occurrence(&ctx.source, &name_upper, occurrence);
143                let (line, col) = offset_to_line_col(&ctx.source, offset);
144
145                diags.push(Diagnostic {
146                    rule: "Structure/AggregateInWhere",
147                    message: "Aggregate function in WHERE clause; use HAVING instead".to_string(),
148                    line,
149                    col,
150                });
151            }
152            // Do not recurse into function args — nested aggregates inside an
153            // aggregate arg are a different issue and the outer call is what
154            // sits in WHERE.
155        }
156        Expr::BinaryOp { left, right, .. } => {
157            collect_aggregates_in_expr(left, ctx, counters, diags);
158            collect_aggregates_in_expr(right, ctx, counters, diags);
159        }
160        Expr::UnaryOp { expr: inner, .. } => {
161            collect_aggregates_in_expr(inner, ctx, counters, diags);
162        }
163        Expr::Nested(inner) => {
164            collect_aggregates_in_expr(inner, ctx, counters, diags);
165        }
166        Expr::Between {
167            expr: e,
168            low,
169            high,
170            ..
171        } => {
172            collect_aggregates_in_expr(e, ctx, counters, diags);
173            collect_aggregates_in_expr(low, ctx, counters, diags);
174            collect_aggregates_in_expr(high, ctx, counters, diags);
175        }
176        Expr::Case {
177            operand,
178            conditions,
179            results,
180            else_result,
181        } => {
182            if let Some(op) = operand {
183                collect_aggregates_in_expr(op, ctx, counters, diags);
184            }
185            for cond in conditions {
186                collect_aggregates_in_expr(cond, ctx, counters, diags);
187            }
188            for res in results {
189                collect_aggregates_in_expr(res, ctx, counters, diags);
190            }
191            if let Some(else_e) = else_result {
192                collect_aggregates_in_expr(else_e, ctx, counters, diags);
193            }
194        }
195        Expr::InList { expr: inner, list, .. } => {
196            collect_aggregates_in_expr(inner, ctx, counters, diags);
197            for e in list {
198                collect_aggregates_in_expr(e, ctx, counters, diags);
199            }
200        }
201        Expr::InSubquery {
202            expr: inner,
203            subquery,
204            ..
205        } => {
206            collect_aggregates_in_expr(inner, ctx, counters, diags);
207            check_query(subquery, ctx, counters, diags);
208        }
209        Expr::Exists { subquery, .. } => {
210            check_query(subquery, ctx, counters, diags);
211        }
212        Expr::Subquery(q) => {
213            check_query(q, ctx, counters, diags);
214        }
215        _ => {}
216    }
217}
218
219// ── Source-text helpers ───────────────────────────────────────────────────────
220
221/// Finds the byte offset of the `nth` (0-indexed) whole-word,
222/// case-insensitive occurrence of `name` (already uppercased) in `source`,
223/// skipping positions inside strings/comments. Returns 0 if not found.
224fn find_nth_occurrence(source: &str, name: &str, nth: usize) -> usize {
225    let bytes = source.as_bytes();
226    let skip_map = SkipMap::build(source);
227    let name_bytes: Vec<u8> = name.bytes().map(|b| b.to_ascii_uppercase()).collect();
228    let name_len = name_bytes.len();
229    let src_len = bytes.len();
230
231    let mut count = 0usize;
232    let mut i = 0usize;
233
234    while i + name_len <= src_len {
235        if !skip_map.is_code(i) {
236            i += 1;
237            continue;
238        }
239
240        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
241        if !before_ok {
242            i += 1;
243            continue;
244        }
245
246        let matches = bytes[i..i + name_len]
247            .iter()
248            .zip(name_bytes.iter())
249            .all(|(&a, &b)| a.to_ascii_uppercase() == b);
250
251        if matches {
252            let after = i + name_len;
253            let after_ok = after >= src_len || !is_word_char(bytes[after]);
254            let all_code = (i..i + name_len).all(|k| skip_map.is_code(k));
255
256            if after_ok && all_code {
257                if count == nth {
258                    return i;
259                }
260                count += 1;
261            }
262        }
263
264        i += 1;
265    }
266
267    0
268}
269
/// Maps a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line is one more than the number of `\n` bytes before `offset`. The
/// column is the number of bytes between the start of that line and `offset`,
/// plus one. Panics if `offset` is past the end of `source` or not on a char
/// boundary (it slices `source[..offset]`).
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let col = offset - line_start + 1;

    (line, col)
}