1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that flags `CASE` expressions with too many `WHEN` branches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CaseWhenCount {
    /// Largest number of `WHEN` clauses a single `CASE` may have; any `CASE`
    /// with strictly more than this many is reported.
    pub max_when_clauses: usize,
}
11
12impl Default for CaseWhenCount {
13 fn default() -> Self {
14 CaseWhenCount {
15 max_when_clauses: 5,
16 }
17 }
18}
19
20impl Rule for CaseWhenCount {
21 fn name(&self) -> &'static str {
22 "CaseWhenCount"
23 }
24
25 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
26 if !ctx.parse_errors.is_empty() {
27 return Vec::new();
28 }
29
30 let mut diags = Vec::new();
31 let mut case_occurrence: usize = 0;
34
35 for stmt in &ctx.statements {
36 if let Statement::Query(query) = stmt {
37 check_query(query, self.max_when_clauses, ctx, &mut case_occurrence, &mut diags);
38 }
39 }
40
41 diags
42 }
43}
44
45fn check_query(
48 query: &Query,
49 max: usize,
50 ctx: &FileContext,
51 occurrence: &mut usize,
52 diags: &mut Vec<Diagnostic>,
53) {
54 if let Some(with) = &query.with {
55 for cte in &with.cte_tables {
56 check_query(&cte.query, max, ctx, occurrence, diags);
57 }
58 }
59
60 check_set_expr(&query.body, max, ctx, occurrence, diags);
61}
62
63fn check_set_expr(
64 expr: &SetExpr,
65 max: usize,
66 ctx: &FileContext,
67 occurrence: &mut usize,
68 diags: &mut Vec<Diagnostic>,
69) {
70 match expr {
71 SetExpr::Select(sel) => check_select(sel, max, ctx, occurrence, diags),
72 SetExpr::Query(inner) => check_query(inner, max, ctx, occurrence, diags),
73 SetExpr::SetOperation { left, right, .. } => {
74 check_set_expr(left, max, ctx, occurrence, diags);
75 check_set_expr(right, max, ctx, occurrence, diags);
76 }
77 _ => {}
78 }
79}
80
81fn check_select(
82 sel: &Select,
83 max: usize,
84 ctx: &FileContext,
85 occurrence: &mut usize,
86 diags: &mut Vec<Diagnostic>,
87) {
88 for item in &sel.projection {
90 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
91 check_expr(e, max, ctx, occurrence, diags);
92 }
93 }
94
95 if let Some(selection) = &sel.selection {
97 check_expr(selection, max, ctx, occurrence, diags);
98 }
99
100 for twj in &sel.from {
102 check_table_factor(&twj.relation, max, ctx, occurrence, diags);
103 for join in &twj.joins {
104 check_table_factor(&join.relation, max, ctx, occurrence, diags);
105 }
106 }
107}
108
109fn check_table_factor(
110 tf: &TableFactor,
111 max: usize,
112 ctx: &FileContext,
113 occurrence: &mut usize,
114 diags: &mut Vec<Diagnostic>,
115) {
116 if let TableFactor::Derived { subquery, .. } = tf {
117 check_query(subquery, max, ctx, occurrence, diags);
118 }
119}
120
121fn check_expr(
122 expr: &Expr,
123 max: usize,
124 ctx: &FileContext,
125 occurrence: &mut usize,
126 diags: &mut Vec<Diagnostic>,
127) {
128 match expr {
129 Expr::Case {
130 operand,
131 conditions,
132 results,
133 else_result,
134 } => {
135 let n = conditions.len();
136 let occ = *occurrence;
138 *occurrence += 1;
139
140 if n > max {
141 let (line, col) = find_nth_keyword_pos(&ctx.source, "CASE", occ);
142 diags.push(Diagnostic {
143 rule: "CaseWhenCount",
144 message: format!(
145 "CASE expression has {n} WHEN clauses, exceeding the maximum of {max}"
146 ),
147 line,
148 col,
149 });
150 }
151
152 if let Some(op) = operand {
154 check_expr(op, max, ctx, occurrence, diags);
155 }
156
157 for cond in conditions {
159 check_expr(cond, max, ctx, occurrence, diags);
160 }
161 for res in results {
162 check_expr(res, max, ctx, occurrence, diags);
163 }
164 if let Some(els) = else_result {
165 check_expr(els, max, ctx, occurrence, diags);
166 }
167 }
168
169 Expr::BinaryOp { left, right, .. } => {
170 check_expr(left, max, ctx, occurrence, diags);
171 check_expr(right, max, ctx, occurrence, diags);
172 }
173 Expr::UnaryOp { expr: inner, .. } => {
174 check_expr(inner, max, ctx, occurrence, diags);
175 }
176 Expr::Subquery(q) => check_query(q, max, ctx, occurrence, diags),
177 Expr::InSubquery { subquery, expr: e, .. } => {
178 check_expr(e, max, ctx, occurrence, diags);
179 check_query(subquery, max, ctx, occurrence, diags);
180 }
181 Expr::Exists { subquery, .. } => check_query(subquery, max, ctx, occurrence, diags),
182 Expr::Nested(inner) => check_expr(inner, max, ctx, occurrence, diags),
183 Expr::Function(f) => {
184 use sqlparser::ast::{FunctionArg, FunctionArgExpr, FunctionArguments};
185 if let FunctionArguments::List(list) = &f.args {
186 for arg in &list.args {
187 let arg_expr = match arg {
188 FunctionArg::Unnamed(e) => Some(e),
189 FunctionArg::Named { arg: e, .. } => Some(e),
190 FunctionArg::ExprNamed { arg: e, .. } => Some(e),
191 };
192 if let Some(FunctionArgExpr::Expr(inner)) = arg_expr {
193 check_expr(inner, max, ctx, occurrence, diags);
194 }
195 }
196 }
197 }
198 _ => {}
199 }
200}
201
202fn find_nth_keyword_pos(source: &str, keyword: &str, nth: usize) -> (usize, usize) {
208 let bytes = source.as_bytes();
209 let len = bytes.len();
210 let skip_map = SkipMap::build(source);
211 let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
212 let kw_len = kw_upper.len();
213
214 let mut count = 0usize;
215 let mut i = 0;
216 while i + kw_len <= len {
217 if !skip_map.is_code(i) {
218 i += 1;
219 continue;
220 }
221
222 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
223 if !before_ok {
224 i += 1;
225 continue;
226 }
227
228 let matches = bytes[i..i + kw_len]
229 .iter()
230 .zip(kw_upper.iter())
231 .all(|(a, b)| a.eq_ignore_ascii_case(b));
232
233 if matches {
234 let after = i + kw_len;
235 let after_ok = after >= len || !is_word_char(bytes[after]);
236 let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
237
238 if after_ok && all_code {
239 if count == nth {
240 return line_col(source, i);
241 }
242 count += 1;
243 }
244 }
245
246 i += 1;
247 }
248
249 (1, 1)
250}
251
/// Converts a byte `offset` into `source` into a 1-based `(line, col)` pair.
///
/// The column is measured in bytes from the start of the line. Panics if
/// `offset` is not on a `char` boundary or exceeds `source.len()`.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let col = offset - line_start + 1;
    (line, col)
}