Skip to main content

sqrust_rules/structure/
nested_case_in_else.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
6pub struct NestedCaseInElse;
7
8impl Rule for NestedCaseInElse {
9    fn name(&self) -> &'static str {
10        "Structure/NestedCaseInElse"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        if !ctx.parse_errors.is_empty() {
15            return Vec::new();
16        }
17        let mut diags = Vec::new();
18        for stmt in &ctx.statements {
19            // Counter tracks how many ELSE keywords we've already used so we can
20            // pinpoint the correct ELSE that introduces the nested CASE.
21            // Reset per statement so the offset search stays within the current statement.
22            let mut else_count = 0usize;
23            if let Statement::Query(q) = stmt {
24                check_query(q, ctx, &mut else_count, &mut diags);
25            }
26        }
27        diags
28    }
29}
30
31// ── AST walking ───────────────────────────────────────────────────────────────
32
33fn check_query(
34    q: &Query,
35    ctx: &FileContext,
36    else_count: &mut usize,
37    diags: &mut Vec<Diagnostic>,
38) {
39    if let Some(with) = &q.with {
40        for cte in &with.cte_tables {
41            check_query(&cte.query, ctx, else_count, diags);
42        }
43    }
44    check_set_expr(&q.body, ctx, else_count, diags);
45
46    // ORDER BY expressions.
47    if let Some(order_by) = &q.order_by {
48        for ob_expr in &order_by.exprs {
49            walk_expr(&ob_expr.expr, ctx, else_count, diags);
50        }
51    }
52}
53
54fn check_set_expr(
55    expr: &SetExpr,
56    ctx: &FileContext,
57    else_count: &mut usize,
58    diags: &mut Vec<Diagnostic>,
59) {
60    match expr {
61        SetExpr::Select(sel) => check_select(sel, ctx, else_count, diags),
62        SetExpr::Query(inner) => check_query(inner, ctx, else_count, diags),
63        SetExpr::SetOperation { left, right, .. } => {
64            check_set_expr(left, ctx, else_count, diags);
65            check_set_expr(right, ctx, else_count, diags);
66        }
67        _ => {}
68    }
69}
70
71fn check_select(
72    sel: &Select,
73    ctx: &FileContext,
74    else_count: &mut usize,
75    diags: &mut Vec<Diagnostic>,
76) {
77    // Projection.
78    for item in &sel.projection {
79        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
80            walk_expr(e, ctx, else_count, diags);
81        }
82    }
83
84    // WHERE.
85    if let Some(selection) = &sel.selection {
86        walk_expr(selection, ctx, else_count, diags);
87    }
88
89    // HAVING.
90    if let Some(having) = &sel.having {
91        walk_expr(having, ctx, else_count, diags);
92    }
93
94    // GROUP BY.
95    if let sqlparser::ast::GroupByExpr::Expressions(exprs, _) = &sel.group_by {
96        for e in exprs {
97            walk_expr(e, ctx, else_count, diags);
98        }
99    }
100
101    // FROM — recurse into subqueries.
102    for twj in &sel.from {
103        recurse_table_factor(&twj.relation, ctx, else_count, diags);
104        for join in &twj.joins {
105            recurse_table_factor(&join.relation, ctx, else_count, diags);
106        }
107    }
108}
109
110fn recurse_table_factor(
111    tf: &TableFactor,
112    ctx: &FileContext,
113    else_count: &mut usize,
114    diags: &mut Vec<Diagnostic>,
115) {
116    if let TableFactor::Derived { subquery, .. } = tf {
117        check_query(subquery, ctx, else_count, diags);
118    }
119}
120
121/// Walk an expression. When we encounter a CASE whose ELSE is itself a CASE,
122/// emit a diagnostic, then continue recursing into both branches.
123fn walk_expr(
124    expr: &Expr,
125    ctx: &FileContext,
126    else_count: &mut usize,
127    diags: &mut Vec<Diagnostic>,
128) {
129    match expr {
130        Expr::Case {
131            operand,
132            conditions,
133            results,
134            else_result,
135        } => {
136            // Recurse into operand, WHEN/THEN branches first (they appear
137            // before ELSE in source text, so their ELSE keywords are counted
138            // before the outer ELSE).
139            if let Some(op) = operand {
140                walk_expr(op, ctx, else_count, diags);
141            }
142            for cond in conditions {
143                walk_expr(cond, ctx, else_count, diags);
144            }
145            for res in results {
146                walk_expr(res, ctx, else_count, diags);
147            }
148
149            if let Some(else_expr) = else_result {
150                // Check if the ELSE value is itself a CASE.
151                if matches!(else_expr.as_ref(), Expr::Case { .. }) {
152                    // Find the Nth ELSE keyword (the one that introduces this
153                    // nested CASE) and emit a diagnostic at it.
154                    let nth = *else_count;
155                    let offset =
156                        find_nth_keyword(&ctx.source, "ELSE", nth).unwrap_or(0);
157                    let (line, col) = offset_to_line_col(&ctx.source, offset);
158                    diags.push(Diagnostic {
159                        rule: "Structure/NestedCaseInElse",
160                        message:
161                            "CASE expression has a nested CASE in its ELSE clause; \
162                             flatten with additional WHEN branches instead"
163                                .to_string(),
164                        line,
165                        col,
166                    });
167                }
168                // Count the ELSE keyword we just processed.
169                *else_count += 1;
170                // Recurse into the ELSE expression (catches deeper nesting).
171                walk_expr(else_expr, ctx, else_count, diags);
172            }
173        }
174
175        // Pass-through recursion for other expression types.
176        Expr::BinaryOp { left, right, .. } => {
177            walk_expr(left, ctx, else_count, diags);
178            walk_expr(right, ctx, else_count, diags);
179        }
180        Expr::UnaryOp { expr: inner, .. } => walk_expr(inner, ctx, else_count, diags),
181        Expr::Nested(inner) => walk_expr(inner, ctx, else_count, diags),
182        Expr::Cast { expr: inner, .. } => walk_expr(inner, ctx, else_count, diags),
183        Expr::IsNull(inner) | Expr::IsNotNull(inner) => walk_expr(inner, ctx, else_count, diags),
184        Expr::Between {
185            expr: e,
186            low,
187            high,
188            ..
189        } => {
190            walk_expr(e, ctx, else_count, diags);
191            walk_expr(low, ctx, else_count, diags);
192            walk_expr(high, ctx, else_count, diags);
193        }
194        Expr::InList { expr: inner, list, .. } => {
195            walk_expr(inner, ctx, else_count, diags);
196            for e in list {
197                walk_expr(e, ctx, else_count, diags);
198            }
199        }
200        Expr::Function(f) => {
201            if let sqlparser::ast::FunctionArguments::List(arg_list) = &f.args {
202                for arg in &arg_list.args {
203                    if let sqlparser::ast::FunctionArg::Unnamed(
204                        sqlparser::ast::FunctionArgExpr::Expr(e),
205                    ) = arg
206                    {
207                        walk_expr(e, ctx, else_count, diags);
208                    }
209                }
210            }
211        }
212        Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
213            check_query(q, ctx, else_count, diags);
214        }
215        _ => {}
216    }
217}
218
219// ── Source-text helpers ───────────────────────────────────────────────────────
220
221/// Find the `nth` (0-indexed) whole-word, case-insensitive occurrence of
222/// `keyword` in `source`, skipping inside strings/comments.
223/// Returns `Some(byte_offset)` or `None`.
224fn find_nth_keyword(source: &str, keyword: &str, nth: usize) -> Option<usize> {
225    let bytes = source.as_bytes();
226    let kw: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
227    let kw_len = kw.len();
228    let src_len = bytes.len();
229    let skip = SkipMap::build(source);
230
231    let mut count = 0usize;
232    let mut i = 0usize;
233
234    while i + kw_len <= src_len {
235        if !skip.is_code(i) {
236            i += 1;
237            continue;
238        }
239
240        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
241        if !before_ok {
242            i += 1;
243            continue;
244        }
245
246        let matches = bytes[i..i + kw_len]
247            .iter()
248            .zip(kw.iter())
249            .all(|(&a, &b)| a.to_ascii_uppercase() == b);
250
251        if matches {
252            let end = i + kw_len;
253            let after_ok = end >= src_len || !is_word_char(bytes[end]);
254            let all_code = (i..end).all(|k| skip.is_code(k));
255
256            if after_ok && all_code {
257                if count == nth {
258                    return Some(i);
259                }
260                count += 1;
261            }
262        }
263
264        i += 1;
265    }
266
267    None
268}
269
/// Translate a byte offset into a 1-based `(line, column)` pair.
///
/// Offsets past the end of `source` are clamped to its length. The column is
/// the number of bytes since the start of the line, plus one.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let clamped = source.len().min(offset);
    let prefix = &source[..clamped];

    // Byte index where the current line begins: just past the last newline,
    // or the start of the text when there is none.
    let line_start = match prefix.rfind('\n') {
        Some(newline_at) => newline_at + 1,
        None => 0,
    };

    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let col = clamped - line_start + 1;
    (line, col)
}