Skip to main content

sqrust_rules/structure/
case_when_count.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that flags `CASE` expressions with too many `WHEN` clauses.
///
/// The threshold is configurable through [`CaseWhenCount::max_when_clauses`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CaseWhenCount {
    /// Maximum number of WHEN clauses allowed per CASE expression.
    /// CASE expressions with more branches than this are flagged.
    pub max_when_clauses: usize,
}
11
12impl Default for CaseWhenCount {
13    fn default() -> Self {
14        CaseWhenCount {
15            max_when_clauses: 5,
16        }
17    }
18}
19
20impl Rule for CaseWhenCount {
21    fn name(&self) -> &'static str {
22        "CaseWhenCount"
23    }
24
25    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
26        if !ctx.parse_errors.is_empty() {
27            return Vec::new();
28        }
29
30        let mut diags = Vec::new();
31        // Track how many CASE keywords we've consumed so we can map each
32        // violation to the correct source position.
33        let mut case_occurrence: usize = 0;
34
35        for stmt in &ctx.statements {
36            if let Statement::Query(query) = stmt {
37                check_query(query, self.max_when_clauses, ctx, &mut case_occurrence, &mut diags);
38            }
39        }
40
41        diags
42    }
43}
44
45// ── AST walking ───────────────────────────────────────────────────────────────
46
47fn check_query(
48    query: &Query,
49    max: usize,
50    ctx: &FileContext,
51    occurrence: &mut usize,
52    diags: &mut Vec<Diagnostic>,
53) {
54    if let Some(with) = &query.with {
55        for cte in &with.cte_tables {
56            check_query(&cte.query, max, ctx, occurrence, diags);
57        }
58    }
59
60    check_set_expr(&query.body, max, ctx, occurrence, diags);
61}
62
63fn check_set_expr(
64    expr: &SetExpr,
65    max: usize,
66    ctx: &FileContext,
67    occurrence: &mut usize,
68    diags: &mut Vec<Diagnostic>,
69) {
70    match expr {
71        SetExpr::Select(sel) => check_select(sel, max, ctx, occurrence, diags),
72        SetExpr::Query(inner) => check_query(inner, max, ctx, occurrence, diags),
73        SetExpr::SetOperation { left, right, .. } => {
74            check_set_expr(left, max, ctx, occurrence, diags);
75            check_set_expr(right, max, ctx, occurrence, diags);
76        }
77        _ => {}
78    }
79}
80
81fn check_select(
82    sel: &Select,
83    max: usize,
84    ctx: &FileContext,
85    occurrence: &mut usize,
86    diags: &mut Vec<Diagnostic>,
87) {
88    // SELECT projection.
89    for item in &sel.projection {
90        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
91            check_expr(e, max, ctx, occurrence, diags);
92        }
93    }
94
95    // WHERE clause.
96    if let Some(selection) = &sel.selection {
97        check_expr(selection, max, ctx, occurrence, diags);
98    }
99
100    // FROM / JOIN — recurse into derived tables.
101    for twj in &sel.from {
102        check_table_factor(&twj.relation, max, ctx, occurrence, diags);
103        for join in &twj.joins {
104            check_table_factor(&join.relation, max, ctx, occurrence, diags);
105        }
106    }
107}
108
109fn check_table_factor(
110    tf: &TableFactor,
111    max: usize,
112    ctx: &FileContext,
113    occurrence: &mut usize,
114    diags: &mut Vec<Diagnostic>,
115) {
116    if let TableFactor::Derived { subquery, .. } = tf {
117        check_query(subquery, max, ctx, occurrence, diags);
118    }
119}
120
121fn check_expr(
122    expr: &Expr,
123    max: usize,
124    ctx: &FileContext,
125    occurrence: &mut usize,
126    diags: &mut Vec<Diagnostic>,
127) {
128    match expr {
129        Expr::Case {
130            operand,
131            conditions,
132            results,
133            else_result,
134        } => {
135            let n = conditions.len();
136            // Consume one CASE occurrence.
137            let occ = *occurrence;
138            *occurrence += 1;
139
140            if n > max {
141                let (line, col) = find_nth_keyword_pos(&ctx.source, "CASE", occ);
142                diags.push(Diagnostic {
143                    rule: "CaseWhenCount",
144                    message: format!(
145                        "CASE expression has {n} WHEN clauses, exceeding the maximum of {max}"
146                    ),
147                    line,
148                    col,
149                });
150            }
151
152            // Recurse into operand.
153            if let Some(op) = operand {
154                check_expr(op, max, ctx, occurrence, diags);
155            }
156
157            // Recurse into each condition and result expression.
158            for cond in conditions {
159                check_expr(cond, max, ctx, occurrence, diags);
160            }
161            for res in results {
162                check_expr(res, max, ctx, occurrence, diags);
163            }
164            if let Some(els) = else_result {
165                check_expr(els, max, ctx, occurrence, diags);
166            }
167        }
168
169        Expr::BinaryOp { left, right, .. } => {
170            check_expr(left, max, ctx, occurrence, diags);
171            check_expr(right, max, ctx, occurrence, diags);
172        }
173        Expr::UnaryOp { expr: inner, .. } => {
174            check_expr(inner, max, ctx, occurrence, diags);
175        }
176        Expr::Subquery(q) => check_query(q, max, ctx, occurrence, diags),
177        Expr::InSubquery { subquery, expr: e, .. } => {
178            check_expr(e, max, ctx, occurrence, diags);
179            check_query(subquery, max, ctx, occurrence, diags);
180        }
181        Expr::Exists { subquery, .. } => check_query(subquery, max, ctx, occurrence, diags),
182        Expr::Nested(inner) => check_expr(inner, max, ctx, occurrence, diags),
183        Expr::Function(f) => {
184            use sqlparser::ast::{FunctionArg, FunctionArgExpr, FunctionArguments};
185            if let FunctionArguments::List(list) = &f.args {
186                for arg in &list.args {
187                    let arg_expr = match arg {
188                        FunctionArg::Unnamed(e) => Some(e),
189                        FunctionArg::Named { arg: e, .. } => Some(e),
190                        FunctionArg::ExprNamed { arg: e, .. } => Some(e),
191                    };
192                    if let Some(FunctionArgExpr::Expr(inner)) = arg_expr {
193                        check_expr(inner, max, ctx, occurrence, diags);
194                    }
195                }
196            }
197        }
198        _ => {}
199    }
200}
201
202// ── keyword position helper ───────────────────────────────────────────────────
203
204/// Find the `nth` occurrence (0-indexed) of a keyword (case-insensitive,
205/// word-boundary, outside strings/comments) in `source`.
206/// Returns a 1-indexed (line, col) pair. Falls back to (1, 1) if not found.
207fn find_nth_keyword_pos(source: &str, keyword: &str, nth: usize) -> (usize, usize) {
208    let bytes = source.as_bytes();
209    let len = bytes.len();
210    let skip_map = SkipMap::build(source);
211    let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
212    let kw_len = kw_upper.len();
213
214    let mut count = 0usize;
215    let mut i = 0;
216    while i + kw_len <= len {
217        if !skip_map.is_code(i) {
218            i += 1;
219            continue;
220        }
221
222        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
223        if !before_ok {
224            i += 1;
225            continue;
226        }
227
228        let matches = bytes[i..i + kw_len]
229            .iter()
230            .zip(kw_upper.iter())
231            .all(|(a, b)| a.eq_ignore_ascii_case(b));
232
233        if matches {
234            let after = i + kw_len;
235            let after_ok = after >= len || !is_word_char(bytes[after]);
236            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
237
238            if after_ok && all_code {
239                if count == nth {
240                    return line_col(source, i);
241                }
242                count += 1;
243            }
244        }
245
246        i += 1;
247    }
248
249    (1, 1)
250}
251
/// Translates a byte offset into `source` to a 1-indexed `(line, col)` pair.
///
/// The column is counted in bytes from the start of the line. Panics if
/// `offset` is past the end of `source` or not on a character boundary
/// (string slicing rules).
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    // Byte index where the current line begins: just after the last newline,
    // or 0 when there is none.
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}