sqrust_rules/ambiguous/
case_when_same_result.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use crate::capitalisation::SkipMap;
3
4pub struct CaseWhenSameResult;
5
6impl Rule for CaseWhenSameResult {
7 fn name(&self) -> &'static str {
8 "Ambiguous/CaseWhenSameResult"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 let source = &ctx.source;
13 let bytes = source.as_bytes();
14 let len = bytes.len();
15 let skip = SkipMap::build(source);
16
17 let mut diags = Vec::new();
18 let mut i = 0;
19
20 while i < len {
21 if !skip.is_code(i) {
22 i += 1;
23 continue;
24 }
25
26 if let Some(after_case) = match_keyword_ci(bytes, len, &skip, i, b"CASE") {
27 if let Some(violation_pos) = check_case_expr(bytes, len, &skip, source, after_case) {
28 let (line, col) = offset_to_line_col(source, violation_pos);
29 diags.push(Diagnostic {
30 rule: self.name(),
31 message: "All CASE branches return the same value — the CASE expression is redundant".to_string(),
32 line,
33 col,
34 });
35 }
36 }
37
38 i += 1;
39 }
40
41 diags
42 }
43}
44
45fn check_case_expr(bytes: &[u8], len: usize, skip: &SkipMap, source: &str, after_case: usize) -> Option<usize> {
49 let case_kw_start = {
50 after_case.saturating_sub(4)
52 };
53
54 let mut pos = skip_code_whitespace_bytes(bytes, len, after_case);
55 let mut branch_values: Vec<String> = Vec::new();
56 let mut has_else = false;
57
58 loop {
59 pos = skip_code_whitespace_bytes(bytes, len, pos);
60 if pos >= len {
61 break;
62 }
63
64 if let Some(after_when) = match_keyword_ci(bytes, len, skip, pos, b"WHEN") {
65 pos = after_when;
67 loop {
68 pos = skip_code_whitespace_bytes(bytes, len, pos);
69 if pos >= len {
70 return None;
71 }
72 if let Some(after_then) = match_keyword_ci(bytes, len, skip, pos, b"THEN") {
73 pos = after_then;
74 break;
75 }
76 pos += 1;
78 }
79
80 pos = skip_code_whitespace_bytes(bytes, len, pos);
82 match extract_single_token_literal(bytes, len, skip, source, pos) {
83 Some((val, end)) => {
84 branch_values.push(val);
85 pos = end;
86 }
87 None => return None, }
89 } else if let Some(after_else) = match_keyword_ci(bytes, len, skip, pos, b"ELSE") {
90 has_else = true;
91 pos = after_else;
92 pos = skip_code_whitespace_bytes(bytes, len, pos);
93
94 match extract_single_token_literal(bytes, len, skip, source, pos) {
95 Some((val, end)) => {
96 branch_values.push(val);
97 pos = end;
98 }
99 None => return None,
100 }
101 } else if match_keyword_ci(bytes, len, skip, pos, b"END").is_some() {
102 break;
103 } else if skip.is_code(pos) {
104 pos += 1;
105 } else {
106 pos += 1;
107 }
108 }
109
110 let total = branch_values.len();
112 if total < 2 {
113 return None;
114 }
115
116 if !has_else && total < 2 {
118 return None;
119 }
120
121 let first = branch_values[0].to_lowercase();
122 let all_same = branch_values.iter().all(|v| v.to_lowercase() == first);
123
124 if all_same {
125 Some(case_kw_start)
126 } else {
127 None
128 }
129}
130
131fn extract_single_token_literal(bytes: &[u8], len: usize, skip: &SkipMap, _source: &str, pos: usize) -> Option<(String, usize)> {
135 if pos >= len {
136 return None;
137 }
138
139 if bytes[pos] == b'\'' {
144 let start = pos;
146 let mut p = pos + 1;
147 while p < len {
148 if bytes[p] == b'\'' {
149 if p + 1 < len && bytes[p + 1] == b'\'' {
150 p += 2; } else {
152 p += 1;
153 break;
154 }
155 } else {
156 p += 1;
157 }
158 }
159 let raw = std::str::from_utf8(&bytes[start..p]).ok()?;
160 let inner = &raw[1..raw.len().saturating_sub(1)];
162 return Some((inner.to_lowercase(), p));
163 }
164
165 if let Some(after_null) = match_keyword_ci(bytes, len, skip, pos, b"NULL") {
167 return Some(("null".to_string(), after_null));
168 }
169
170 let mut p = pos;
172 let negative = skip.is_code(p) && bytes[p] == b'-';
173 if negative {
174 p += 1;
175 p = skip_code_whitespace_bytes(bytes, len, p);
176 }
177
178 if p < len && skip.is_code(p) && bytes[p].is_ascii_digit() {
179 let num_start = if negative { pos } else { p };
180 while p < len && skip.is_code(p) && bytes[p].is_ascii_digit() {
181 p += 1;
182 }
183 if p < len && skip.is_code(p) && (bytes[p].is_ascii_alphanumeric() || bytes[p] == b'_') {
185 return None;
186 }
187 let raw = std::str::from_utf8(&bytes[num_start..p]).ok()?;
188 return Some((raw.to_string(), p));
189 }
190
191 None
192}
193
194fn match_keyword_ci(bytes: &[u8], len: usize, skip: &SkipMap, pos: usize, keyword: &[u8]) -> Option<usize> {
195 let kw_len = keyword.len();
196 if pos + kw_len > len {
197 return None;
198 }
199 for k in 0..kw_len {
200 let b = pos + k;
201 if !skip.is_code(b) {
202 return None;
203 }
204 if bytes[b].to_ascii_uppercase() != keyword[k] {
205 return None;
206 }
207 }
208 let end = pos + kw_len;
209 if end < len && skip.is_code(end) && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') {
210 return None;
211 }
212 Some(end)
213}
214
/// Returns the first offset at or after `pos` whose byte is not ASCII
/// whitespace (space, tab, LF or CR), bounded by `len`. If `pos` is already at
/// or past `len`, it is returned unchanged.
///
/// Only whitespace is skipped. The skip map is not consulted, so comments are
/// not passed over.
fn skip_code_whitespace_bytes(bytes: &[u8], len: usize, pos: usize) -> usize {
    if pos >= len {
        return pos;
    }
    let blanks = bytes[pos..len]
        .iter()
        .take_while(|&&byte| matches!(byte, b' ' | b'\t' | b'\n' | b'\r'))
        .count();
    pos + blanks
}
221
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// Lines are counted by `\n`. The column is measured in bytes from the start of
/// the line. Offsets past the end of `source` are clamped when counting lines.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset.min(source.len())];
    let line = 1 + prefix.bytes().filter(|&byte| byte == b'\n').count();
    let col = match prefix.rfind('\n') {
        Some(newline_at) => offset - newline_at,
        None => offset + 1,
    };
    (line, col)
}