1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
6pub struct NestedCaseInElse;
7
8impl Rule for NestedCaseInElse {
9 fn name(&self) -> &'static str {
10 "Structure/NestedCaseInElse"
11 }
12
13 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14 if !ctx.parse_errors.is_empty() {
15 return Vec::new();
16 }
17 let mut diags = Vec::new();
18 for stmt in &ctx.statements {
19 let mut else_count = 0usize;
23 if let Statement::Query(q) = stmt {
24 check_query(q, ctx, &mut else_count, &mut diags);
25 }
26 }
27 diags
28 }
29}
30
31fn check_query(
34 q: &Query,
35 ctx: &FileContext,
36 else_count: &mut usize,
37 diags: &mut Vec<Diagnostic>,
38) {
39 if let Some(with) = &q.with {
40 for cte in &with.cte_tables {
41 check_query(&cte.query, ctx, else_count, diags);
42 }
43 }
44 check_set_expr(&q.body, ctx, else_count, diags);
45
46 if let Some(order_by) = &q.order_by {
48 for ob_expr in &order_by.exprs {
49 walk_expr(&ob_expr.expr, ctx, else_count, diags);
50 }
51 }
52}
53
54fn check_set_expr(
55 expr: &SetExpr,
56 ctx: &FileContext,
57 else_count: &mut usize,
58 diags: &mut Vec<Diagnostic>,
59) {
60 match expr {
61 SetExpr::Select(sel) => check_select(sel, ctx, else_count, diags),
62 SetExpr::Query(inner) => check_query(inner, ctx, else_count, diags),
63 SetExpr::SetOperation { left, right, .. } => {
64 check_set_expr(left, ctx, else_count, diags);
65 check_set_expr(right, ctx, else_count, diags);
66 }
67 _ => {}
68 }
69}
70
71fn check_select(
72 sel: &Select,
73 ctx: &FileContext,
74 else_count: &mut usize,
75 diags: &mut Vec<Diagnostic>,
76) {
77 for item in &sel.projection {
79 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
80 walk_expr(e, ctx, else_count, diags);
81 }
82 }
83
84 if let Some(selection) = &sel.selection {
86 walk_expr(selection, ctx, else_count, diags);
87 }
88
89 if let Some(having) = &sel.having {
91 walk_expr(having, ctx, else_count, diags);
92 }
93
94 if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &sel.group_by {
96 for e in exprs {
97 walk_expr(e, ctx, else_count, diags);
98 }
99 }
100
101 for twj in &sel.from {
103 recurse_table_factor(&twj.relation, ctx, else_count, diags);
104 for join in &twj.joins {
105 recurse_table_factor(&join.relation, ctx, else_count, diags);
106 }
107 }
108}
109
110fn recurse_table_factor(
111 tf: &TableFactor,
112 ctx: &FileContext,
113 else_count: &mut usize,
114 diags: &mut Vec<Diagnostic>,
115) {
116 if let TableFactor::Derived { subquery, .. } = tf {
117 check_query(subquery, ctx, else_count, diags);
118 }
119}
120
121fn walk_expr(
124 expr: &Expr,
125 ctx: &FileContext,
126 else_count: &mut usize,
127 diags: &mut Vec<Diagnostic>,
128) {
129 match expr {
130 Expr::Case {
131 operand,
132 conditions,
133 results,
134 else_result,
135 } => {
136 if let Some(op) = operand {
140 walk_expr(op, ctx, else_count, diags);
141 }
142 for cond in conditions {
143 walk_expr(cond, ctx, else_count, diags);
144 }
145 for res in results {
146 walk_expr(res, ctx, else_count, diags);
147 }
148
149 if let Some(else_expr) = else_result {
150 if matches!(else_expr.as_ref(), Expr::Case { .. }) {
152 let nth = *else_count;
155 let offset =
156 find_nth_keyword(&ctx.source, "ELSE", nth).unwrap_or(0);
157 let (line, col) = offset_to_line_col(&ctx.source, offset);
158 diags.push(Diagnostic {
159 rule: "Structure/NestedCaseInElse",
160 message:
161 "CASE expression has a nested CASE in its ELSE clause; \
162 flatten with additional WHEN branches instead"
163 .to_string(),
164 line,
165 col,
166 });
167 }
168 *else_count += 1;
170 walk_expr(else_expr, ctx, else_count, diags);
172 }
173 }
174
175 Expr::BinaryOp { left, right, .. } => {
177 walk_expr(left, ctx, else_count, diags);
178 walk_expr(right, ctx, else_count, diags);
179 }
180 Expr::UnaryOp { expr: inner, .. } => walk_expr(inner, ctx, else_count, diags),
181 Expr::Nested(inner) => walk_expr(inner, ctx, else_count, diags),
182 Expr::Cast { expr: inner, .. } => walk_expr(inner, ctx, else_count, diags),
183 Expr::IsNull(inner) | Expr::IsNotNull(inner) => walk_expr(inner, ctx, else_count, diags),
184 Expr::Between {
185 expr: e,
186 low,
187 high,
188 ..
189 } => {
190 walk_expr(e, ctx, else_count, diags);
191 walk_expr(low, ctx, else_count, diags);
192 walk_expr(high, ctx, else_count, diags);
193 }
194 Expr::InList { expr: inner, list, .. } => {
195 walk_expr(inner, ctx, else_count, diags);
196 for e in list {
197 walk_expr(e, ctx, else_count, diags);
198 }
199 }
200 Expr::Function(f) => {
201 if let sqlparser::ast::FunctionArguments::List(arg_list) = &f.args {
202 for arg in &arg_list.args {
203 if let sqlparser::ast::FunctionArg::Unnamed(
204 sqlparser::ast::FunctionArgExpr::Expr(e),
205 ) = arg
206 {
207 walk_expr(e, ctx, else_count, diags);
208 }
209 }
210 }
211 }
212 Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
213 check_query(q, ctx, else_count, diags);
214 }
215 _ => {}
216 }
217}
218
219fn find_nth_keyword(source: &str, keyword: &str, nth: usize) -> Option<usize> {
225 let bytes = source.as_bytes();
226 let kw: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
227 let kw_len = kw.len();
228 let src_len = bytes.len();
229 let skip = SkipMap::build(source);
230
231 let mut count = 0usize;
232 let mut i = 0usize;
233
234 while i + kw_len <= src_len {
235 if !skip.is_code(i) {
236 i += 1;
237 continue;
238 }
239
240 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
241 if !before_ok {
242 i += 1;
243 continue;
244 }
245
246 let matches = bytes[i..i + kw_len]
247 .iter()
248 .zip(kw.iter())
249 .all(|(&a, &b)| a.to_ascii_uppercase() == b);
250
251 if matches {
252 let end = i + kw_len;
253 let after_ok = end >= src_len || !is_word_char(bytes[end]);
254 let all_code = (i..end).all(|k| skip.is_code(k));
255
256 if after_ok && all_code {
257 if count == nth {
258 return Some(i);
259 }
260 count += 1;
261 }
262 }
263
264 i += 1;
265 }
266
267 None
268}
269
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// Offsets past the end are clamped to `source.len()`. The column is counted
/// in bytes from the start of the line. The offset is expected to fall on a
/// UTF-8 character boundary; slicing panics otherwise.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let end = source.len().min(offset);
    let prefix = &source[..end];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let col = end - line_start + 1;
    (line, col)
}