Skip to main content

sqrust_rules/convention/
no_decode_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3pub struct NoDecodeFunction;
4
5const MESSAGE: &str =
6    "DECODE() is an Oracle-specific function; use CASE WHEN ... THEN ... END instead";
7
8impl Rule for NoDecodeFunction {
9    fn name(&self) -> &'static str {
10        "Convention/NoDecodeFunction"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        find_violations(&ctx.source, self.name())
15    }
16}
17
18fn build_skip_set(source: &str) -> std::collections::HashSet<usize> {
19    let mut skip = std::collections::HashSet::new();
20    let bytes = source.as_bytes();
21    let len = bytes.len();
22    let mut i = 0;
23    while i < len {
24        if bytes[i] == b'\'' {
25            i += 1;
26            while i < len {
27                if bytes[i] == b'\'' {
28                    if i + 1 < len && bytes[i + 1] == b'\'' {
29                        skip.insert(i);
30                        i += 2;
31                    } else {
32                        i += 1;
33                        break;
34                    }
35                } else {
36                    skip.insert(i);
37                    i += 1;
38                }
39            }
40        } else if i + 1 < len && bytes[i] == b'-' && bytes[i + 1] == b'-' {
41            while i < len && bytes[i] != b'\n' {
42                skip.insert(i);
43                i += 1;
44            }
45        } else {
46            i += 1;
47        }
48    }
49    skip
50}
51
/// Maps a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line, so a
/// multi-byte UTF-8 character earlier on the same line advances it by more
/// than one.
///
/// Panics if `offset` is past the end of `source` or not on a char boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = match prefix.rfind('\n') {
        Some(nl) => nl + 1,
        None => 0,
    };
    (newlines + 1, offset - line_start + 1)
}
59
/// Returns `true` for ASCII letters, ASCII digits and `_`; these are the
/// bytes treated as part of a word when checking keyword boundaries.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
64
65fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
66    let bytes = source.as_bytes();
67    let len = bytes.len();
68
69    if len == 0 {
70        return Vec::new();
71    }
72
73    let skip = build_skip_set(source);
74    let mut diags = Vec::new();
75
76    // "DECODE" is 6 characters
77    let keyword = b"DECODE";
78    let kw_len = keyword.len();
79
80    let mut i = 0;
81    while i + kw_len <= len {
82        // Skip positions inside string literals or comments
83        if skip.contains(&i) {
84            i += 1;
85            continue;
86        }
87
88        // Check word boundary before
89        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
90        if !before_ok {
91            i += 1;
92            continue;
93        }
94
95        // Case-insensitive match of "DECODE"
96        if !bytes[i..i + kw_len].eq_ignore_ascii_case(keyword) {
97            i += 1;
98            continue;
99        }
100
101        // Ensure none of the keyword bytes are in string/comment
102        let all_code = (0..kw_len).all(|k| !skip.contains(&(i + k)));
103        if !all_code {
104            i += 1;
105            continue;
106        }
107
108        let kw_end = i + kw_len;
109
110        // Must be immediately followed by '(' to be a function call
111        if kw_end >= len || bytes[kw_end] != b'(' {
112            i += 1;
113            continue;
114        }
115
116        let (line, col) = line_col(source, i);
117        diags.push(Diagnostic {
118            rule: rule_name,
119            message: MESSAGE.to_string(),
120            line,
121            col,
122        });
123
124        i = kw_end + 1;
125    }
126
127    diags
128}