Skip to main content

sqrust_rules/structure/
having_without_aggregate.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
6pub struct HavingWithoutAggregate;
7
8impl Rule for HavingWithoutAggregate {
9    fn name(&self) -> &'static str {
10        "HavingWithoutAggregate"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        if !ctx.parse_errors.is_empty() {
15            return Vec::new();
16        }
17
18        let mut diags = Vec::new();
19
20        for stmt in &ctx.statements {
21            if let Statement::Query(query) = stmt {
22                check_query(query, self.name(), ctx, &mut diags);
23            }
24        }
25
26        diags
27    }
28}
29
30// ── AST walking ───────────────────────────────────────────────────────────────
31
32fn check_query(
33    query: &Query,
34    rule: &'static str,
35    ctx: &FileContext,
36    diags: &mut Vec<Diagnostic>,
37) {
38    // Visit CTEs.
39    if let Some(with) = &query.with {
40        for cte in &with.cte_tables {
41            check_query(&cte.query, rule, ctx, diags);
42        }
43    }
44    check_set_expr(&query.body, rule, ctx, diags);
45}
46
47fn check_set_expr(
48    expr: &SetExpr,
49    rule: &'static str,
50    ctx: &FileContext,
51    diags: &mut Vec<Diagnostic>,
52) {
53    match expr {
54        SetExpr::Select(sel) => {
55            check_select(sel, rule, ctx, diags);
56        }
57        SetExpr::Query(inner) => {
58            check_query(inner, rule, ctx, diags);
59        }
60        SetExpr::SetOperation { left, right, .. } => {
61            check_set_expr(left, rule, ctx, diags);
62            check_set_expr(right, rule, ctx, diags);
63        }
64        _ => {}
65    }
66}
67
68fn check_select(
69    sel: &Select,
70    rule: &'static str,
71    ctx: &FileContext,
72    diags: &mut Vec<Diagnostic>,
73) {
74    if let Some(having) = &sel.having {
75        if !has_aggregate(having) {
76            let (line, col) = find_keyword_pos(&ctx.source, "HAVING");
77            diags.push(Diagnostic {
78                rule,
79                message: "HAVING clause contains no aggregate function; use WHERE instead"
80                    .to_string(),
81                line,
82                col,
83            });
84        }
85    }
86
87    // Recurse into subqueries in the FROM clause.
88    for table_with_joins in &sel.from {
89        recurse_table_factor(&table_with_joins.relation, rule, ctx, diags);
90        for join in &table_with_joins.joins {
91            recurse_table_factor(&join.relation, rule, ctx, diags);
92        }
93    }
94}
95
96fn recurse_table_factor(
97    tf: &TableFactor,
98    rule: &'static str,
99    ctx: &FileContext,
100    diags: &mut Vec<Diagnostic>,
101) {
102    if let TableFactor::Derived { subquery, .. } = tf {
103        check_query(subquery, rule, ctx, diags);
104    }
105}
106
107// ── aggregate detection ───────────────────────────────────────────────────────
108
109fn has_aggregate(expr: &Expr) -> bool {
110    match expr {
111        Expr::Function(func) => {
112            let name = func
113                .name
114                .0
115                .last()
116                .map(|i| i.value.to_uppercase())
117                .unwrap_or_default();
118            matches!(
119                name.as_str(),
120                "COUNT"
121                    | "SUM"
122                    | "AVG"
123                    | "MIN"
124                    | "MAX"
125                    | "ARRAY_AGG"
126                    | "STRING_AGG"
127                    | "GROUP_CONCAT"
128                    | "STDDEV"
129                    | "VARIANCE"
130                    | "MEDIAN"
131                    | "LISTAGG"
132                    | "FIRST_VALUE"
133                    | "LAST_VALUE"
134            )
135        }
136        Expr::BinaryOp { left, right, .. } => has_aggregate(left) || has_aggregate(right),
137        Expr::UnaryOp { expr, .. } => has_aggregate(expr),
138        Expr::Nested(e) => has_aggregate(e),
139        Expr::Between {
140            expr, low, high, ..
141        } => has_aggregate(expr) || has_aggregate(low) || has_aggregate(high),
142        Expr::Case {
143            operand,
144            conditions,
145            results,
146            else_result,
147        } => {
148            operand.as_ref().map_or(false, |e| has_aggregate(e))
149                || conditions.iter().any(|e| has_aggregate(e))
150                || results.iter().any(|e| has_aggregate(e))
151                || else_result.as_ref().map_or(false, |e| has_aggregate(e))
152        }
153        _ => false,
154    }
155}
156
157// ── keyword position helper ───────────────────────────────────────────────────
158
159/// Find the first occurrence of a keyword (case-insensitive, word-boundary,
160/// outside strings/comments) in `source`. Returns a 1-indexed (line, col)
161/// pair. Falls back to (1, 1) if not found.
162fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
163    let bytes = source.as_bytes();
164    let len = bytes.len();
165    let skip_map = SkipMap::build(source);
166    let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
167    let kw_len = kw_upper.len();
168
169    let mut i = 0;
170    while i + kw_len <= len {
171        if !skip_map.is_code(i) {
172            i += 1;
173            continue;
174        }
175
176        // Word boundary before.
177        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
178        if !before_ok {
179            i += 1;
180            continue;
181        }
182
183        // Case-insensitive match.
184        let matches = bytes[i..i + kw_len]
185            .iter()
186            .zip(kw_upper.iter())
187            .all(|(a, b)| a.eq_ignore_ascii_case(b));
188
189        if matches {
190            // Word boundary after.
191            let after = i + kw_len;
192            let after_ok = after >= len || !is_word_char(bytes[after]);
193            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
194
195            if after_ok && all_code {
196                return line_col(source, i);
197            }
198        }
199
200        i += 1;
201    }
202
203    (1, 1)
204}
205
206/// Converts a byte offset in `source` to a 1-indexed (line, col) pair.
207fn line_col(source: &str, offset: usize) -> (usize, usize) {
208    let before = &source[..offset];
209    let line = before.chars().filter(|&c| c == '\n').count() + 1;
210    let col = before.rfind('\n').map(|p| offset - p - 1).unwrap_or(offset) + 1;
211    (line, col)
212}