Skip to main content

sqrust_rules/convention/
case_else.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4pub struct CaseElse;
5
6impl Rule for CaseElse {
7    fn name(&self) -> &'static str {
8        "Convention/CaseElse"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17
18        for stmt in &ctx.statements {
19            collect_from_statement(stmt, ctx, &mut diags);
20        }
21
22        diags
23    }
24}
25
26fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
27    if let Statement::Query(query) = stmt {
28        collect_from_query(query, ctx, diags);
29    }
30}
31
32fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
33    // Check CTEs.
34    if let Some(with) = &query.with {
35        for cte in &with.cte_tables {
36            collect_from_query(&cte.query, ctx, diags);
37        }
38    }
39
40    // Check ORDER BY expressions.
41    if let Some(order_by) = &query.order_by {
42        for ob_expr in &order_by.exprs {
43            check_expr(&ob_expr.expr, ctx, diags);
44        }
45    }
46
47    // Recurse into the body.
48    collect_from_set_expr(&query.body, ctx, diags);
49}
50
51fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
52    match expr {
53        SetExpr::Select(select) => {
54            collect_from_select(select, ctx, diags);
55        }
56        SetExpr::Query(inner) => {
57            collect_from_query(inner, ctx, diags);
58        }
59        SetExpr::SetOperation { left, right, .. } => {
60            collect_from_set_expr(left, ctx, diags);
61            collect_from_set_expr(right, ctx, diags);
62        }
63        _ => {}
64    }
65}
66
67fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
68    // SELECT projection.
69    for item in &select.projection {
70        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
71            check_expr(e, ctx, diags);
72        }
73    }
74
75    // FROM clause — derived subqueries.
76    for table_with_joins in &select.from {
77        collect_from_table_factor(&table_with_joins.relation, ctx, diags);
78        for join in &table_with_joins.joins {
79            collect_from_table_factor(&join.relation, ctx, diags);
80        }
81    }
82
83    // WHERE clause.
84    if let Some(selection) = &select.selection {
85        check_expr(selection, ctx, diags);
86    }
87
88    // HAVING clause.
89    if let Some(having) = &select.having {
90        check_expr(having, ctx, diags);
91    }
92}
93
94fn collect_from_table_factor(
95    factor: &TableFactor,
96    ctx: &FileContext,
97    diags: &mut Vec<Diagnostic>,
98) {
99    if let TableFactor::Derived { subquery, .. } = factor {
100        collect_from_query(subquery, ctx, diags);
101    }
102}
103
104/// Recursively checks an expression tree for CASE expressions missing ELSE.
105fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
106    match expr {
107        Expr::Case {
108            operand,
109            conditions,
110            results,
111            else_result,
112        } => {
113            // Flag if no ELSE clause.
114            if else_result.is_none() {
115                let (line, col) = find_keyword_pos(&ctx.source, "CASE");
116                diags.push(Diagnostic {
117                    rule: "Convention/CaseElse",
118                    message: "CASE expression has no ELSE clause; unmatched conditions will return NULL"
119                        .to_string(),
120                    line,
121                    col,
122                });
123            }
124
125            // Recurse into operand.
126            if let Some(op) = operand {
127                check_expr(op, ctx, diags);
128            }
129
130            // Recurse into WHEN conditions.
131            for cond in conditions {
132                check_expr(cond, ctx, diags);
133            }
134
135            // Recurse into THEN results.
136            for result in results {
137                check_expr(result, ctx, diags);
138            }
139
140            // Recurse into ELSE result.
141            if let Some(else_e) = else_result {
142                check_expr(else_e, ctx, diags);
143            }
144        }
145
146        // Recurse through other expression types.
147        Expr::BinaryOp { left, right, .. } => {
148            check_expr(left, ctx, diags);
149            check_expr(right, ctx, diags);
150        }
151        Expr::UnaryOp { expr: inner, .. } => {
152            check_expr(inner, ctx, diags);
153        }
154        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
155            check_expr(inner, ctx, diags);
156        }
157        Expr::IsDistinctFrom(left, right) | Expr::IsNotDistinctFrom(left, right) => {
158            check_expr(left, ctx, diags);
159            check_expr(right, ctx, diags);
160        }
161        Expr::InList { expr: inner, list, .. } => {
162            check_expr(inner, ctx, diags);
163            for e in list {
164                check_expr(e, ctx, diags);
165            }
166        }
167        Expr::Between {
168            expr: inner,
169            low,
170            high,
171            ..
172        } => {
173            check_expr(inner, ctx, diags);
174            check_expr(low, ctx, diags);
175            check_expr(high, ctx, diags);
176        }
177        Expr::Function(f) => {
178            if let sqlparser::ast::FunctionArguments::List(arg_list) = &f.args {
179                for arg in &arg_list.args {
180                    if let sqlparser::ast::FunctionArg::Unnamed(
181                        sqlparser::ast::FunctionArgExpr::Expr(e),
182                    ) = arg
183                    {
184                        check_expr(e, ctx, diags);
185                    }
186                }
187            }
188        }
189        Expr::Cast { expr: inner, .. } => {
190            check_expr(inner, ctx, diags);
191        }
192        Expr::Nested(inner) => {
193            check_expr(inner, ctx, diags);
194        }
195        Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
196            collect_from_query(q, ctx, diags);
197        }
198        _ => {}
199    }
200}
201
202/// Finds the first occurrence of `keyword` (case-insensitive, word-boundary)
203/// in `source` and returns a 1-indexed (line, col) pair.
204/// Falls back to (1, 1) if not found.
205fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
206    let upper = source.to_uppercase();
207    let kw_upper = keyword.to_uppercase();
208    let kw_len = kw_upper.len();
209    let bytes = upper.as_bytes();
210    let len = bytes.len();
211
212    let mut pos = 0;
213    while pos + kw_len <= len {
214        if let Some(rel) = upper[pos..].find(kw_upper.as_str()) {
215            let abs = pos + rel;
216
217            let before_ok = abs == 0 || {
218                let b = bytes[abs - 1];
219                !b.is_ascii_alphanumeric() && b != b'_'
220            };
221            let after = abs + kw_len;
222            let after_ok = after >= len || {
223                let b = bytes[after];
224                !b.is_ascii_alphanumeric() && b != b'_'
225            };
226
227            if before_ok && after_ok {
228                return line_col(source, abs);
229            }
230
231            pos = abs + 1;
232        } else {
233            break;
234        }
235    }
236
237    (1, 1)
238}
239
/// Converts the byte `offset` into `source` to a 1-indexed `(line, col)` pair.
///
/// Lines are delimited by `'\n'`. The column counts bytes (not characters)
/// from the start of the line. Panics if `offset` is past the end of
/// `source` or does not fall on a `char` boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    (line, offset - line_start + 1)
}