1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{BinaryOperator, Expr, Query, Select, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that flags `HAVING` clauses with too many conditions.
///
/// The condition count of a clause is the number of `AND`/`OR` operators it
/// contains plus one.
pub struct HavingConditionsCount {
    /// Largest condition count tolerated before a diagnostic is emitted.
    pub max_conditions: usize,
}

impl Default for HavingConditionsCount {
    /// Tolerates up to five conditions per `HAVING` clause.
    fn default() -> Self {
        Self { max_conditions: 5 }
    }
}
18
19impl Rule for HavingConditionsCount {
20 fn name(&self) -> &'static str {
21 "Structure/HavingConditionsCount"
22 }
23
24 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
25 if !ctx.parse_errors.is_empty() {
26 return Vec::new();
27 }
28
29 let mut diags = Vec::new();
30
31 for stmt in &ctx.statements {
32 if let Statement::Query(query) = stmt {
33 check_query(query, self.max_conditions, ctx, &mut diags);
34 }
35 }
36
37 diags
38 }
39}
40
41fn check_query(
44 query: &Query,
45 max: usize,
46 ctx: &FileContext,
47 diags: &mut Vec<Diagnostic>,
48) {
49 if let Some(with) = &query.with {
51 for cte in &with.cte_tables {
52 check_query(&cte.query, max, ctx, diags);
53 }
54 }
55
56 check_set_expr(&query.body, max, ctx, diags);
57}
58
59fn check_set_expr(
60 expr: &SetExpr,
61 max: usize,
62 ctx: &FileContext,
63 diags: &mut Vec<Diagnostic>,
64) {
65 match expr {
66 SetExpr::Select(sel) => {
67 check_select(sel, max, ctx, diags);
68 }
69 SetExpr::Query(inner) => {
70 check_query(inner, max, ctx, diags);
71 }
72 SetExpr::SetOperation { left, right, .. } => {
73 check_set_expr(left, max, ctx, diags);
74 check_set_expr(right, max, ctx, diags);
75 }
76 _ => {}
77 }
78}
79
80fn check_select(
81 sel: &Select,
82 max: usize,
83 ctx: &FileContext,
84 diags: &mut Vec<Diagnostic>,
85) {
86 if let Some(having) = &sel.having {
87 let ops = count_and_or_ops(having);
89 let conditions = ops + 1;
90 if conditions > max {
91 let (line, col) = find_keyword_pos(&ctx.source, "HAVING");
92 diags.push(Diagnostic {
93 rule: "Structure/HavingConditionsCount",
94 message: format!(
95 "HAVING clause has {conditions} conditions, exceeding the maximum of {max}",
96 ),
97 line,
98 col,
99 });
100 }
101 }
102
103 for twj in &sel.from {
105 check_table_factor(&twj.relation, max, ctx, diags);
106 for join in &twj.joins {
107 check_table_factor(&join.relation, max, ctx, diags);
108 }
109 }
110}
111
112fn check_table_factor(
113 tf: &TableFactor,
114 max: usize,
115 ctx: &FileContext,
116 diags: &mut Vec<Diagnostic>,
117) {
118 if let TableFactor::Derived { subquery, .. } = tf {
119 check_query(subquery, max, ctx, diags);
120 }
121}
122
123fn count_and_or_ops(expr: &Expr) -> usize {
129 match expr {
130 Expr::BinaryOp {
131 left,
132 op: BinaryOperator::And | BinaryOperator::Or,
133 right,
134 } => 1 + count_and_or_ops(left) + count_and_or_ops(right),
135 Expr::BinaryOp { left, right, .. } => {
136 count_and_or_ops(left) + count_and_or_ops(right)
137 }
138 Expr::UnaryOp { expr: inner, .. } => count_and_or_ops(inner),
139 Expr::Nested(inner) => count_and_or_ops(inner),
140 _ => 0,
141 }
142}
143
144fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
150 let bytes = source.as_bytes();
151 let len = bytes.len();
152 let skip_map = SkipMap::build(source);
153 let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
154 let kw_len = kw_upper.len();
155
156 let mut i = 0;
157 while i + kw_len <= len {
158 if !skip_map.is_code(i) {
159 i += 1;
160 continue;
161 }
162
163 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
165 if !before_ok {
166 i += 1;
167 continue;
168 }
169
170 let matches = bytes[i..i + kw_len]
172 .iter()
173 .zip(kw_upper.iter())
174 .all(|(a, b)| a.eq_ignore_ascii_case(b));
175
176 if matches {
177 let after = i + kw_len;
179 let after_ok = after >= len || !is_word_char(bytes[after]);
180 let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
181
182 if after_ok && all_code {
183 return line_col(source, i);
184 }
185 }
186
187 i += 1;
188 }
189
190 (1, 1)
191}
192
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line. `offset` must
/// lie on a char boundary of `source`.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    match prefix.rfind('\n') {
        Some(newline) => (prefix.matches('\n').count() + 1, offset - newline),
        None => (1, offset + 1),
    }
}