Skip to main content

sqrust_rules/ambiguous/case_when_same_result.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use crate::capitalisation::SkipMap;
3
4pub struct CaseWhenSameResult;
5
6impl Rule for CaseWhenSameResult {
7    fn name(&self) -> &'static str {
8        "Ambiguous/CaseWhenSameResult"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        let source = &ctx.source;
13        let bytes = source.as_bytes();
14        let len = bytes.len();
15        let skip = SkipMap::build(source);
16
17        let mut diags = Vec::new();
18        let mut i = 0;
19
20        while i < len {
21            if !skip.is_code(i) {
22                i += 1;
23                continue;
24            }
25
26            if let Some(after_case) = match_keyword_ci(bytes, len, &skip, i, b"CASE") {
27                if let Some(violation_pos) = check_case_expr(bytes, len, &skip, source, after_case) {
28                    let (line, col) = offset_to_line_col(source, violation_pos);
29                    diags.push(Diagnostic {
30                        rule: self.name(),
31                        message: "All CASE branches return the same value — the CASE expression is redundant".to_string(),
32                        line,
33                        col,
34                    });
35                }
36            }
37
38            i += 1;
39        }
40
41        diags
42    }
43}
44
45/// Checks a CASE expression starting right after the CASE keyword.
46/// Returns the offset of CASE (for error reporting) if all branches have the same
47/// single-token literal result, else None.
48fn check_case_expr(bytes: &[u8], len: usize, skip: &SkipMap, source: &str, after_case: usize) -> Option<usize> {
49    let case_kw_start = {
50        // Find the actual CASE position (4 bytes before after_case).
51        after_case.saturating_sub(4)
52    };
53
54    let mut pos = skip_code_whitespace_bytes(bytes, len, after_case);
55    let mut branch_values: Vec<String> = Vec::new();
56    let mut has_else = false;
57
58    loop {
59        pos = skip_code_whitespace_bytes(bytes, len, pos);
60        if pos >= len {
61            break;
62        }
63
64        if let Some(after_when) = match_keyword_ci(bytes, len, skip, pos, b"WHEN") {
65            // Skip past the condition until we reach THEN.
66            pos = after_when;
67            loop {
68                pos = skip_code_whitespace_bytes(bytes, len, pos);
69                if pos >= len {
70                    return None;
71                }
72                if let Some(after_then) = match_keyword_ci(bytes, len, skip, pos, b"THEN") {
73                    pos = after_then;
74                    break;
75                }
76                // Skip one character of the condition.
77                pos += 1;
78            }
79
80            // Parse the result literal after THEN.
81            pos = skip_code_whitespace_bytes(bytes, len, pos);
82            match extract_single_token_literal(bytes, len, skip, source, pos) {
83                Some((val, end)) => {
84                    branch_values.push(val);
85                    pos = end;
86                }
87                None => return None, // complex expression — skip this CASE
88            }
89        } else if let Some(after_else) = match_keyword_ci(bytes, len, skip, pos, b"ELSE") {
90            has_else = true;
91            pos = after_else;
92            pos = skip_code_whitespace_bytes(bytes, len, pos);
93
94            match extract_single_token_literal(bytes, len, skip, source, pos) {
95                Some((val, end)) => {
96                    branch_values.push(val);
97                    pos = end;
98                }
99                None => return None,
100            }
101        } else if match_keyword_ci(bytes, len, skip, pos, b"END").is_some() {
102            break;
103        } else if skip.is_code(pos) {
104            pos += 1;
105        } else {
106            pos += 1;
107        }
108    }
109
110    // Need at least 2 branches and they must all be the same value.
111    let total = branch_values.len();
112    if total < 2 {
113        return None;
114    }
115
116    // If no ELSE, we need at least 2 WHEN branches.
117    if !has_else && total < 2 {
118        return None;
119    }
120
121    let first = branch_values[0].to_lowercase();
122    let all_same = branch_values.iter().all(|v| v.to_lowercase() == first);
123
124    if all_same {
125        Some(case_kw_start)
126    } else {
127        None
128    }
129}
130
131/// Extracts a single-token literal starting at `pos`.
132/// Accepted literals: single-quoted strings, integers, NULL.
133/// Returns `(normalized_value, end_offset)` or None if complex.
134fn extract_single_token_literal(bytes: &[u8], len: usize, skip: &SkipMap, _source: &str, pos: usize) -> Option<(String, usize)> {
135    if pos >= len {
136        return None;
137    }
138
139    // Single-quoted string — the opening quote is marked as non-code in SkipMap,
140    // but we detect it by checking the raw byte before skip classification.
141    // Actually the opening quote byte is marked skip=true by SkipMap.build().
142    // We need to look at the raw bytes here.
143    if bytes[pos] == b'\'' {
144        // Find end of string (scan raw bytes).
145        let start = pos;
146        let mut p = pos + 1;
147        while p < len {
148            if bytes[p] == b'\'' {
149                if p + 1 < len && bytes[p + 1] == b'\'' {
150                    p += 2; // escaped quote
151                } else {
152                    p += 1;
153                    break;
154                }
155            } else {
156                p += 1;
157            }
158        }
159        let raw = std::str::from_utf8(&bytes[start..p]).ok()?;
160        // Normalize: lowercase the content (strip outer quotes for comparison).
161        let inner = &raw[1..raw.len().saturating_sub(1)];
162        return Some((inner.to_lowercase(), p));
163    }
164
165    // NULL keyword.
166    if let Some(after_null) = match_keyword_ci(bytes, len, skip, pos, b"NULL") {
167        return Some(("null".to_string(), after_null));
168    }
169
170    // Integer (possibly negative).
171    let mut p = pos;
172    let negative = skip.is_code(p) && bytes[p] == b'-';
173    if negative {
174        p += 1;
175        p = skip_code_whitespace_bytes(bytes, len, p);
176    }
177
178    if p < len && skip.is_code(p) && bytes[p].is_ascii_digit() {
179        let num_start = if negative { pos } else { p };
180        while p < len && skip.is_code(p) && bytes[p].is_ascii_digit() {
181            p += 1;
182        }
183        // Word boundary.
184        if p < len && skip.is_code(p) && (bytes[p].is_ascii_alphanumeric() || bytes[p] == b'_') {
185            return None;
186        }
187        let raw = std::str::from_utf8(&bytes[num_start..p]).ok()?;
188        return Some((raw.to_string(), p));
189    }
190
191    None
192}
193
194fn match_keyword_ci(bytes: &[u8], len: usize, skip: &SkipMap, pos: usize, keyword: &[u8]) -> Option<usize> {
195    let kw_len = keyword.len();
196    if pos + kw_len > len {
197        return None;
198    }
199    for k in 0..kw_len {
200        let b = pos + k;
201        if !skip.is_code(b) {
202            return None;
203        }
204        if bytes[b].to_ascii_uppercase() != keyword[k] {
205            return None;
206        }
207    }
208    let end = pos + kw_len;
209    if end < len && skip.is_code(end) && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') {
210        return None;
211    }
212    Some(end)
213}
214
/// Returns the first offset at or after `pos` (bounded by `len`) that is not an
/// ASCII space, tab, LF or CR.
///
/// Despite the name, code/non-code classification is not consulted; only raw
/// bytes are inspected. If `pos` is already at or past `len`, it is returned
/// unchanged.
fn skip_code_whitespace_bytes(bytes: &[u8], len: usize, pos: usize) -> usize {
    let is_blank = |b: u8| matches!(b, b' ' | b'\t' | b'\n' | b'\r');
    (pos..len)
        .find(|&i| !is_blank(bytes[i]))
        .unwrap_or_else(|| pos.max(len))
}

/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// Lines are counted by `'\n'` characters before the offset (clamped to the
/// source length). The column is the byte distance from the start of the
/// current line plus one, computed from the unclamped `offset`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..source.len().min(offset)];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    (line, offset - line_start + 1)
}