Skip to main content

sqrust_rules/ambiguous/
add_months_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule that flags calls to Oracle-only month arithmetic functions
/// (`ADD_MONTHS`, `MONTHS_BETWEEN`), which are not portable across SQL dialects.
pub struct AddMonthsFunction;
4
/// Oracle-only functions reported by this rule.
///
/// Each entry pairs the function name with the message used for its diagnostics.
/// The name is written in upper case, but matching against the source is ASCII
/// case-insensitive.
const FUNCTIONS: &[(&str, &str)] = &[
    (
        "ADD_MONTHS",
        "ADD_MONTHS() is Oracle-specific; use standard interval arithmetic (date + INTERVAL n MONTH) for portable SQL",
    ),
    (
        "MONTHS_BETWEEN",
        "MONTHS_BETWEEN() is Oracle-specific; use DATEDIFF or interval subtraction depending on your dialect",
    ),
];
16
17impl Rule for AddMonthsFunction {
18    fn name(&self) -> &'static str {
19        "Ambiguous/AddMonthsFunction"
20    }
21
22    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
23        find_violations(&ctx.source, self.name())
24    }
25}
26
27fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
28    let bytes = source.as_bytes();
29    let len = bytes.len();
30
31    if len == 0 {
32        return Vec::new();
33    }
34
35    let skip = build_skip_set(bytes, len);
36    let mut diags = Vec::new();
37
38    for (func_name, message) in FUNCTIONS {
39        scan_for_function(source, bytes, len, &skip, func_name, message, rule_name, &mut diags);
40    }
41
42    diags.sort_by(|a, b| a.line.cmp(&b.line).then(a.col.cmp(&b.col)));
43    diags
44}
45
46/// Scan for `func_name(` (case-insensitive) with word boundaries on both sides.
47fn scan_for_function(
48    source: &str,
49    bytes: &[u8],
50    len: usize,
51    skip: &[bool],
52    func_name: &str,
53    message: &str,
54    rule_name: &'static str,
55    diags: &mut Vec<Diagnostic>,
56) {
57    let kw = func_name.as_bytes();
58    let kw_len = kw.len();
59    let mut i = 0;
60
61    while i + kw_len <= len {
62        if skip[i] {
63            i += 1;
64            continue;
65        }
66
67        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
68        if before_ok && bytes[i..i + kw_len].eq_ignore_ascii_case(kw) {
69            let after = i + kw_len;
70            // Word boundary after: next char must not be a word char
71            let after_ok = after >= len || !is_word_char(bytes[after]);
72            if after_ok {
73                // Skip optional whitespace then check for '('
74                let mut j = after;
75                while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
76                    j += 1;
77                }
78                if j < len && bytes[j] == b'(' {
79                    let (line, col) = line_col(source, i);
80                    diags.push(Diagnostic {
81                        rule: rule_name,
82                        message: message.to_string(),
83                        line,
84                        col,
85                    });
86                    i += kw_len;
87                    continue;
88                }
89            }
90        }
91
92        i += 1;
93    }
94}
95
/// Returns `true` when `ch` counts as part of a word for boundary checks:
/// an ASCII letter, an ASCII digit, or `_`.
///
/// Any other byte, including every non-ASCII byte, is a boundary.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
100
/// Convert a byte `offset` in `source` into a 1-based `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line. `offset` must
/// fall on a UTF-8 character boundary, otherwise slicing the prefix panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset.min(source.len())];
    match prefix.rfind('\n') {
        Some(newline) => (prefix.matches('\n').count() + 1, offset - newline),
        None => (1, offset + 1),
    }
}
107
/// Build a per-byte mask in which `true` marks bytes that belong to any of:
///
/// * a single-quoted string literal;
/// * a double-quoted identifier;
/// * a `/* ... */` block comment;
/// * a `--` line comment.
///
/// Delimiters are included in the marked span. A doubled delimiter (`''` or
/// `""`) inside a quoted span is an escape and does not close it. Unterminated
/// strings and block comments run to the end of the input. A line comment
/// stops just before the `\n` that ends it. Block comments do not nest.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    let mut mask = vec![false; len];
    let mut cursor = 0;

    while cursor < len {
        let lookahead = if cursor + 1 < len { Some(bytes[cursor + 1]) } else { None };
        cursor = match (bytes[cursor], lookahead) {
            (delim @ (b'\'' | b'"'), _) => mark_quoted(bytes, len, &mut mask, cursor, delim),
            (b'/', Some(b'*')) => mark_block_comment(bytes, len, &mut mask, cursor),
            (b'-', Some(b'-')) => mark_line_comment(bytes, len, &mut mask, cursor),
            _ => cursor + 1,
        };
    }

    mask
}

/// Mark a quoted span that opens with `delim` at `start`.
/// Returns the index just past the closing delimiter, or `len` if the span is unterminated.
fn mark_quoted(bytes: &[u8], len: usize, mask: &mut [bool], start: usize, delim: u8) -> usize {
    mask[start] = true;
    let mut pos = start + 1;
    while pos < len {
        mask[pos] = true;
        if bytes[pos] == delim {
            let escaped = pos + 1 < len && bytes[pos + 1] == delim;
            if !escaped {
                return pos + 1;
            }
            mask[pos + 1] = true;
            pos += 2;
        } else {
            pos += 1;
        }
    }
    pos
}

/// Mark a `/* ... */` comment whose `/*` sits at `start`.
/// Returns the index just past `*/`, or `len` if the comment is unterminated.
fn mark_block_comment(bytes: &[u8], len: usize, mask: &mut [bool], start: usize) -> usize {
    mask[start] = true;
    mask[start + 1] = true;
    let mut pos = start + 2;
    while pos < len {
        mask[pos] = true;
        if bytes[pos] == b'*' && pos + 1 < len && bytes[pos + 1] == b'/' {
            mask[pos + 1] = true;
            return pos + 2;
        }
        pos += 1;
    }
    pos
}

/// Mark a `--` comment starting at `start`, up to but excluding the next `\n`.
/// Returns the index of that `\n`, or `len` when no newline follows.
fn mark_line_comment(bytes: &[u8], len: usize, mask: &mut [bool], start: usize) -> usize {
    let end = bytes[start..len]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(len, |offset| start + offset);
    mask[start..end].fill(true);
    end
}