1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4pub struct CaseElse;
5
6impl Rule for CaseElse {
7 fn name(&self) -> &'static str {
8 "Convention/CaseElse"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
13 return Vec::new();
14 }
15
16 let mut diags = Vec::new();
17
18 for stmt in &ctx.statements {
19 collect_from_statement(stmt, ctx, &mut diags);
20 }
21
22 diags
23 }
24}
25
26fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
27 if let Statement::Query(query) = stmt {
28 collect_from_query(query, ctx, diags);
29 }
30}
31
32fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
33 if let Some(with) = &query.with {
35 for cte in &with.cte_tables {
36 collect_from_query(&cte.query, ctx, diags);
37 }
38 }
39
40 if let Some(order_by) = &query.order_by {
42 for ob_expr in &order_by.exprs {
43 check_expr(&ob_expr.expr, ctx, diags);
44 }
45 }
46
47 collect_from_set_expr(&query.body, ctx, diags);
49}
50
51fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
52 match expr {
53 SetExpr::Select(select) => {
54 collect_from_select(select, ctx, diags);
55 }
56 SetExpr::Query(inner) => {
57 collect_from_query(inner, ctx, diags);
58 }
59 SetExpr::SetOperation { left, right, .. } => {
60 collect_from_set_expr(left, ctx, diags);
61 collect_from_set_expr(right, ctx, diags);
62 }
63 _ => {}
64 }
65}
66
67fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
68 for item in &select.projection {
70 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
71 check_expr(e, ctx, diags);
72 }
73 }
74
75 for table_with_joins in &select.from {
77 collect_from_table_factor(&table_with_joins.relation, ctx, diags);
78 for join in &table_with_joins.joins {
79 collect_from_table_factor(&join.relation, ctx, diags);
80 }
81 }
82
83 if let Some(selection) = &select.selection {
85 check_expr(selection, ctx, diags);
86 }
87
88 if let Some(having) = &select.having {
90 check_expr(having, ctx, diags);
91 }
92}
93
94fn collect_from_table_factor(
95 factor: &TableFactor,
96 ctx: &FileContext,
97 diags: &mut Vec<Diagnostic>,
98) {
99 if let TableFactor::Derived { subquery, .. } = factor {
100 collect_from_query(subquery, ctx, diags);
101 }
102}
103
104fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
106 match expr {
107 Expr::Case {
108 operand,
109 conditions,
110 results,
111 else_result,
112 } => {
113 if else_result.is_none() {
115 let (line, col) = find_keyword_pos(&ctx.source, "CASE");
116 diags.push(Diagnostic {
117 rule: "Convention/CaseElse",
118 message: "CASE expression has no ELSE clause; unmatched conditions will return NULL"
119 .to_string(),
120 line,
121 col,
122 });
123 }
124
125 if let Some(op) = operand {
127 check_expr(op, ctx, diags);
128 }
129
130 for cond in conditions {
132 check_expr(cond, ctx, diags);
133 }
134
135 for result in results {
137 check_expr(result, ctx, diags);
138 }
139
140 if let Some(else_e) = else_result {
142 check_expr(else_e, ctx, diags);
143 }
144 }
145
146 Expr::BinaryOp { left, right, .. } => {
148 check_expr(left, ctx, diags);
149 check_expr(right, ctx, diags);
150 }
151 Expr::UnaryOp { expr: inner, .. } => {
152 check_expr(inner, ctx, diags);
153 }
154 Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
155 check_expr(inner, ctx, diags);
156 }
157 Expr::IsDistinctFrom(left, right) | Expr::IsNotDistinctFrom(left, right) => {
158 check_expr(left, ctx, diags);
159 check_expr(right, ctx, diags);
160 }
161 Expr::InList { expr: inner, list, .. } => {
162 check_expr(inner, ctx, diags);
163 for e in list {
164 check_expr(e, ctx, diags);
165 }
166 }
167 Expr::Between {
168 expr: inner,
169 low,
170 high,
171 ..
172 } => {
173 check_expr(inner, ctx, diags);
174 check_expr(low, ctx, diags);
175 check_expr(high, ctx, diags);
176 }
177 Expr::Function(f) => {
178 if let sqlparser::ast::FunctionArguments::List(arg_list) = &f.args {
179 for arg in &arg_list.args {
180 if let sqlparser::ast::FunctionArg::Unnamed(
181 sqlparser::ast::FunctionArgExpr::Expr(e),
182 ) = arg
183 {
184 check_expr(e, ctx, diags);
185 }
186 }
187 }
188 }
189 Expr::Cast { expr: inner, .. } => {
190 check_expr(inner, ctx, diags);
191 }
192 Expr::Nested(inner) => {
193 check_expr(inner, ctx, diags);
194 }
195 Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
196 collect_from_query(q, ctx, diags);
197 }
198 _ => {}
199 }
200}
201
202fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
206 let upper = source.to_uppercase();
207 let kw_upper = keyword.to_uppercase();
208 let kw_len = kw_upper.len();
209 let bytes = upper.as_bytes();
210 let len = bytes.len();
211
212 let mut pos = 0;
213 while pos + kw_len <= len {
214 if let Some(rel) = upper[pos..].find(kw_upper.as_str()) {
215 let abs = pos + rel;
216
217 let before_ok = abs == 0 || {
218 let b = bytes[abs - 1];
219 !b.is_ascii_alphanumeric() && b != b'_'
220 };
221 let after = abs + kw_len;
222 let after_ok = after >= len || {
223 let b = bytes[after];
224 !b.is_ascii_alphanumeric() && b != b'_'
225 };
226
227 if before_ok && after_ok {
228 return line_col(source, abs);
229 }
230
231 pos = abs + 1;
232 } else {
233 break;
234 }
235 }
236
237 (1, 1)
238}
239
/// Converts a byte `offset` into `source` to a 1-based `(line, col)` pair.
///
/// `col` counts bytes from the start of the line, not characters. Panics if
/// `offset` is past the end of `source` or not on a char boundary, because
/// the prefix is taken by slicing.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = prefix.rfind('\n').map_or(0, |nl| nl + 1);
    let line = 1 + prefix.matches('\n').count();
    (line, offset - line_start + 1)
}
246}