sqrust_rules/ambiguous/
add_months_function.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule that reports calls to the Oracle-specific date functions
/// `ADD_MONTHS()` and `MONTHS_BETWEEN()` (listed in `FUNCTIONS`), which are
/// not portable across SQL dialects.
pub struct AddMonthsFunction;
4
/// Oracle-only date functions flagged by this rule, each paired with the
/// message attached to every diagnostic reported for it.
///
/// Names are spelled in upper case here; the scanner compares them ignoring
/// ASCII case.
const FUNCTIONS: &[(&str, &str)] = &[
    (
        "ADD_MONTHS",
        concat!(
            "ADD_MONTHS() is Oracle-specific; ",
            "use standard interval arithmetic (date + INTERVAL n MONTH) for portable SQL"
        ),
    ),
    (
        "MONTHS_BETWEEN",
        concat!(
            "MONTHS_BETWEEN() is Oracle-specific; ",
            "use DATEDIFF or interval subtraction depending on your dialect"
        ),
    ),
];
16
17impl Rule for AddMonthsFunction {
18 fn name(&self) -> &'static str {
19 "Ambiguous/AddMonthsFunction"
20 }
21
22 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
23 find_violations(&ctx.source, self.name())
24 }
25}
26
27fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
28 let bytes = source.as_bytes();
29 let len = bytes.len();
30
31 if len == 0 {
32 return Vec::new();
33 }
34
35 let skip = build_skip_set(bytes, len);
36 let mut diags = Vec::new();
37
38 for (func_name, message) in FUNCTIONS {
39 scan_for_function(source, bytes, len, &skip, func_name, message, rule_name, &mut diags);
40 }
41
42 diags.sort_by(|a, b| a.line.cmp(&b.line).then(a.col.cmp(&b.col)));
43 diags
44}
45
46fn scan_for_function(
48 source: &str,
49 bytes: &[u8],
50 len: usize,
51 skip: &[bool],
52 func_name: &str,
53 message: &str,
54 rule_name: &'static str,
55 diags: &mut Vec<Diagnostic>,
56) {
57 let kw = func_name.as_bytes();
58 let kw_len = kw.len();
59 let mut i = 0;
60
61 while i + kw_len <= len {
62 if skip[i] {
63 i += 1;
64 continue;
65 }
66
67 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
68 if before_ok && bytes[i..i + kw_len].eq_ignore_ascii_case(kw) {
69 let after = i + kw_len;
70 let after_ok = after >= len || !is_word_char(bytes[after]);
72 if after_ok {
73 let mut j = after;
75 while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
76 j += 1;
77 }
78 if j < len && bytes[j] == b'(' {
79 let (line, col) = line_col(source, i);
80 diags.push(Diagnostic {
81 rule: rule_name,
82 message: message.to_string(),
83 line,
84 col,
85 });
86 i += kw_len;
87 continue;
88 }
89 }
90 }
91
92 i += 1;
93 }
94}
95
/// Returns `true` for ASCII letters, ASCII digits and `_`. These are the bytes
/// treated as part of a word when checking keyword boundaries. Every other
/// byte, including any non-ASCII byte, returns `false`.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
100
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line is one plus the number of `\n` bytes before `offset`. The column
/// is the 1-based byte distance from the byte after the last preceding `\n`,
/// or from the start of `source` if there is none.
///
/// `offset` is clamped to `source.len()`, and the clamped value is used for
/// both the line and the column. Counting raw bytes also means an offset that
/// falls inside a multi-byte UTF-8 character no longer panics; the previous
/// `&source[..offset]` slice did.
///
/// NOTE(review): columns count bytes, not characters, so multi-byte text
/// earlier on the same line yields a larger column than a character count
/// would. Confirm this is what consumers of `Diagnostic::col` expect.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let offset = offset.min(source.len());
    let before = &source.as_bytes()[..offset];
    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    let line_start = before
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}
107
/// Returns a per-byte mask (`true` = ignore) marking every region of `bytes`
/// in which function names must not be reported:
///
/// * `'...'` string literals, with `''` as an escaped quote;
/// * `"..."` quoted identifiers, with `""` as an escaped quote;
/// * `/* ... */` block comments (nesting is not recognised);
/// * `-- ...` line comments, up to but not including the next `\n`.
///
/// Opening and closing delimiters are masked together with the contents. An
/// unterminated region masks everything up to `len`. `len` is expected to
/// equal `bytes.len()`, which is what `find_violations` passes.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    // Masks a quoted span whose opening `quote` sits at `start`. Returns the
    // index just past the closing quote, or `len` if the span never closes.
    fn mask_quoted(
        bytes: &[u8],
        len: usize,
        mask: &mut [bool],
        start: usize,
        quote: u8,
    ) -> usize {
        mask[start] = true;
        let mut pos = start + 1;
        while pos < len {
            mask[pos] = true;
            if bytes[pos] == quote {
                let doubled = pos + 1 < len && bytes[pos + 1] == quote;
                if !doubled {
                    return pos + 1;
                }
                // A doubled quote is an escaped quote character: stay inside.
                mask[pos + 1] = true;
                pos += 2;
            } else {
                pos += 1;
            }
        }
        pos
    }

    // Masks a `/* ... */` comment opening at `start`. Returns the index just
    // past `*/`, or `len` if the comment is unterminated.
    fn mask_block_comment(bytes: &[u8], len: usize, mask: &mut [bool], start: usize) -> usize {
        mask[start] = true;
        mask[start + 1] = true;
        let mut pos = start + 2;
        while pos < len {
            mask[pos] = true;
            let closes = pos + 1 < len && bytes[pos] == b'*' && bytes[pos + 1] == b'/';
            if closes {
                mask[pos + 1] = true;
                return pos + 2;
            }
            pos += 1;
        }
        pos
    }

    // Masks a `--` comment opening at `start`. Returns the index of the
    // terminating `\n` (left unmasked), or `len`.
    fn mask_line_comment(bytes: &[u8], len: usize, mask: &mut [bool], start: usize) -> usize {
        let mut end = start + 2;
        while end < len && bytes[end] != b'\n' {
            end += 1;
        }
        for flag in &mut mask[start..end] {
            *flag = true;
        }
        end
    }

    let mut mask = vec![false; len];
    let mut pos = 0;

    while pos < len {
        pos = match bytes[pos] {
            b'\'' => mask_quoted(bytes, len, &mut mask, pos, b'\''),
            b'"' => mask_quoted(bytes, len, &mut mask, pos, b'"'),
            b'/' if pos + 1 < len && bytes[pos + 1] == b'*' => {
                mask_block_comment(bytes, len, &mut mask, pos)
            }
            b'-' if pos + 1 < len && bytes[pos + 1] == b'-' => {
                mask_line_comment(bytes, len, &mut mask, pos)
            }
            _ => pos + 1,
        };
    }

    mask
}