1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that reports a `HAVING` clause whose predicate contains no
/// recognised aggregate function call and suggests using `WHERE` instead.
pub struct HavingWithoutAggregate;
7
8impl Rule for HavingWithoutAggregate {
9 fn name(&self) -> &'static str {
10 "HavingWithoutAggregate"
11 }
12
13 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14 if !ctx.parse_errors.is_empty() {
15 return Vec::new();
16 }
17
18 let mut diags = Vec::new();
19
20 for stmt in &ctx.statements {
21 if let Statement::Query(query) = stmt {
22 check_query(query, self.name(), ctx, &mut diags);
23 }
24 }
25
26 diags
27 }
28}
29
30fn check_query(
33 query: &Query,
34 rule: &'static str,
35 ctx: &FileContext,
36 diags: &mut Vec<Diagnostic>,
37) {
38 if let Some(with) = &query.with {
40 for cte in &with.cte_tables {
41 check_query(&cte.query, rule, ctx, diags);
42 }
43 }
44 check_set_expr(&query.body, rule, ctx, diags);
45}
46
47fn check_set_expr(
48 expr: &SetExpr,
49 rule: &'static str,
50 ctx: &FileContext,
51 diags: &mut Vec<Diagnostic>,
52) {
53 match expr {
54 SetExpr::Select(sel) => {
55 check_select(sel, rule, ctx, diags);
56 }
57 SetExpr::Query(inner) => {
58 check_query(inner, rule, ctx, diags);
59 }
60 SetExpr::SetOperation { left, right, .. } => {
61 check_set_expr(left, rule, ctx, diags);
62 check_set_expr(right, rule, ctx, diags);
63 }
64 _ => {}
65 }
66}
67
68fn check_select(
69 sel: &Select,
70 rule: &'static str,
71 ctx: &FileContext,
72 diags: &mut Vec<Diagnostic>,
73) {
74 if let Some(having) = &sel.having {
75 if !has_aggregate(having) {
76 let (line, col) = find_keyword_pos(&ctx.source, "HAVING");
77 diags.push(Diagnostic {
78 rule,
79 message: "HAVING clause contains no aggregate function; use WHERE instead"
80 .to_string(),
81 line,
82 col,
83 });
84 }
85 }
86
87 for table_with_joins in &sel.from {
89 recurse_table_factor(&table_with_joins.relation, rule, ctx, diags);
90 for join in &table_with_joins.joins {
91 recurse_table_factor(&join.relation, rule, ctx, diags);
92 }
93 }
94}
95
96fn recurse_table_factor(
97 tf: &TableFactor,
98 rule: &'static str,
99 ctx: &FileContext,
100 diags: &mut Vec<Diagnostic>,
101) {
102 if let TableFactor::Derived { subquery, .. } = tf {
103 check_query(subquery, rule, ctx, diags);
104 }
105}
106
107fn has_aggregate(expr: &Expr) -> bool {
110 match expr {
111 Expr::Function(func) => {
112 let name = func
113 .name
114 .0
115 .last()
116 .map(|i| i.value.to_uppercase())
117 .unwrap_or_default();
118 matches!(
119 name.as_str(),
120 "COUNT"
121 | "SUM"
122 | "AVG"
123 | "MIN"
124 | "MAX"
125 | "ARRAY_AGG"
126 | "STRING_AGG"
127 | "GROUP_CONCAT"
128 | "STDDEV"
129 | "VARIANCE"
130 | "MEDIAN"
131 | "LISTAGG"
132 | "FIRST_VALUE"
133 | "LAST_VALUE"
134 )
135 }
136 Expr::BinaryOp { left, right, .. } => has_aggregate(left) || has_aggregate(right),
137 Expr::UnaryOp { expr, .. } => has_aggregate(expr),
138 Expr::Nested(e) => has_aggregate(e),
139 Expr::Between {
140 expr, low, high, ..
141 } => has_aggregate(expr) || has_aggregate(low) || has_aggregate(high),
142 Expr::Case {
143 operand,
144 conditions,
145 results,
146 else_result,
147 } => {
148 operand.as_ref().map_or(false, |e| has_aggregate(e))
149 || conditions.iter().any(|e| has_aggregate(e))
150 || results.iter().any(|e| has_aggregate(e))
151 || else_result.as_ref().map_or(false, |e| has_aggregate(e))
152 }
153 _ => false,
154 }
155}
156
157fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
163 let bytes = source.as_bytes();
164 let len = bytes.len();
165 let skip_map = SkipMap::build(source);
166 let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
167 let kw_len = kw_upper.len();
168
169 let mut i = 0;
170 while i + kw_len <= len {
171 if !skip_map.is_code(i) {
172 i += 1;
173 continue;
174 }
175
176 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
178 if !before_ok {
179 i += 1;
180 continue;
181 }
182
183 let matches = bytes[i..i + kw_len]
185 .iter()
186 .zip(kw_upper.iter())
187 .all(|(a, b)| a.eq_ignore_ascii_case(b));
188
189 if matches {
190 let after = i + kw_len;
192 let after_ok = after >= len || !is_word_char(bytes[after]);
193 let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
194
195 if after_ok && all_code {
196 return line_col(source, i);
197 }
198 }
199
200 i += 1;
201 }
202
203 (1, 1)
204}
205
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line is one plus the number of `\n` characters before `offset`; the
/// column is one plus the number of bytes between the start of that line and
/// `offset`.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or does not fall on a
/// character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    // Byte index at which the current line begins.
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();

    (newlines + 1, offset - line_start + 1)
}