sqrust_rules/convention/
in_null_comparison.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule `Convention/InNullComparison`: flags `IN (NULL)` and
/// `NOT IN (NULL)` comparisons and suggests `IS NULL` / `IS NOT NULL`
/// instead.
#[derive(Debug, Default, Clone, Copy)]
pub struct InNullComparison;
4
/// Returns `true` when `pattern` occurs at byte `offset` of `bytes`,
/// compared ASCII case-insensitively, as a whole word.
///
/// "Whole word" means the match is not glued to an identifier byte
/// (`[A-Za-z0-9_]`) on either side. A pattern that would run past the end of
/// `bytes` never matches.
fn keyword_at_boundary(bytes: &[u8], offset: usize, pattern: &[u8]) -> bool {
    // Bytes that would make the candidate part of a longer identifier.
    let is_word_byte = |b: &u8| b.is_ascii_alphanumeric() || *b == b'_';

    let end = offset + pattern.len();
    match bytes.get(offset..end) {
        None => false,
        Some(window) => {
            let glued_before = offset
                .checked_sub(1)
                .and_then(|prev| bytes.get(prev))
                .map_or(false, is_word_byte);
            let glued_after = bytes.get(end).map_or(false, is_word_byte);
            !glued_before && window.eq_ignore_ascii_case(pattern) && !glued_after
        }
    }
}
30
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column counts bytes since the start of the line, not characters.
/// `offset` must lie on a UTF-8 character boundary; otherwise the slice
/// panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    (line, offset - line_start + 1)
}
38
/// Builds a per-byte mask that is `true` for every byte that keyword
/// matching must ignore. The ignored regions are:
///
/// * `--` line comments, up to but not including the newline;
/// * `/* ... */` block comments;
/// * `'...'` string literals, where a doubled `''` is an escaped quote;
/// * `"..."` and `` `...` `` quoted identifiers.
///
/// Opening and closing delimiters are marked too. An unterminated construct
/// extends to the end of the input. The result has exactly one entry per
/// input byte.
fn build_skip(bytes: &[u8]) -> Vec<bool> {
    /// Length of the single-quoted literal at the start of `rest`
    /// (`rest[0] == b'\''`), treating `''` as an escaped quote.
    fn quoted_literal_len(rest: &[u8]) -> usize {
        let mut k = 1;
        while k < rest.len() {
            if rest[k] == b'\'' {
                if rest.get(k + 1) == Some(&b'\'') {
                    k += 2;
                    continue;
                }
                return k + 1;
            }
            k += 1;
        }
        rest.len()
    }

    let mut mask = vec![false; bytes.len()];
    let mut pos = 0;

    while pos < bytes.len() {
        let rest = &bytes[pos..];
        // Number of bytes in the ignorable region starting at `pos`
        // (0 when `pos` is ordinary code).
        let span = if rest.starts_with(b"--") {
            rest.iter().position(|&b| b == b'\n').unwrap_or(rest.len())
        } else if rest.starts_with(b"/*") {
            rest[2..]
                .windows(2)
                .position(|w| w == b"*/")
                .map_or(rest.len(), |at| at + 4)
        } else if rest[0] == b'\'' {
            quoted_literal_len(rest)
        } else if rest[0] == b'"' || rest[0] == b'`' {
            // Quoted identifier: runs to the next identical quote byte.
            let quote = rest[0];
            rest[1..]
                .iter()
                .position(|&b| b == quote)
                .map_or(rest.len(), |at| at + 2)
        } else {
            0
        };

        if span == 0 {
            pos += 1;
        } else {
            for flag in &mut mask[pos..pos + span] {
                *flag = true;
            }
            pos += span;
        }
    }

    mask
}
133
/// One `IN (NULL)` / `NOT IN (NULL)` occurrence found in the source.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Match {
    /// Byte offset of the `IN` keyword. For `NOT IN` this still points at
    /// `IN`, not at `NOT`.
    in_offset: usize,
    /// `true` for `NOT IN (NULL)`, `false` for a bare `IN (NULL)`.
    is_not_in: bool,
}
141
142fn find_matches(source: &str, skip: &[bool]) -> Vec<Match> {
145 let bytes = source.as_bytes();
146 let len = bytes.len();
147 let mut matches = Vec::new();
148 let mut i = 0;
149
150 while i < len {
151 if skip[i] {
152 i += 1;
153 continue;
154 }
155
156 let is_not_in = keyword_at_boundary(bytes, i, b"NOT") && !skip[i];
159 if is_not_in {
160 let mut j = i + 3; while j < len && bytes[j].is_ascii_whitespace() {
163 j += 1;
164 }
165 if j < len && !skip[j] && keyword_at_boundary(bytes, j, b"IN") {
166 let in_offset = j;
167 let mut k = j + 2;
169 while k < len && bytes[k].is_ascii_whitespace() {
170 k += 1;
171 }
172 if k < len && bytes[k] == b'(' && !skip[k] {
173 if let Some(m) = check_paren_null(bytes, skip, k) {
174 if m {
175 matches.push(Match { in_offset, is_not_in: true });
176 i = k + 1;
177 continue;
178 }
179 }
180 }
181 }
182 }
183
184 if !skip[i] && keyword_at_boundary(bytes, i, b"IN") {
186 let in_offset = i;
187 let mut j = i + 2; while j < len && bytes[j].is_ascii_whitespace() {
189 j += 1;
190 }
191 if j < len && bytes[j] == b'(' && !skip[j] {
192 if let Some(m) = check_paren_null(bytes, skip, j) {
193 if m {
194 matches.push(Match { in_offset, is_not_in: false });
195 i = j + 1;
196 continue;
197 }
198 }
199 }
200 }
201
202 i += 1;
203 }
204
205 matches
206}
207
/// Decides whether the parenthesised list whose `(` sits at `open_paren` is
/// exactly `( NULL )`. That means, in order:
///
/// * optional ASCII whitespace;
/// * the keyword `NULL` in any letter case, unskipped and not the prefix of
///   a longer identifier;
/// * optional whitespace;
/// * an unskipped `)`.
///
/// The result is always `Some(..)`: `Some(true)` for a lone `NULL`,
/// `Some(false)` otherwise.
fn check_paren_null(bytes: &[u8], skip: &[bool], open_paren: usize) -> Option<bool> {
    let len = bytes.len();
    // First non-whitespace index at or after `from`. If `from` is already
    // past the end it is returned unchanged; otherwise `len` is returned
    // when only whitespace remains.
    let past_whitespace = |from: usize| {
        (from..len)
            .find(|&idx| !bytes[idx].is_ascii_whitespace())
            .unwrap_or_else(|| from.max(len))
    };

    let null_start = past_whitespace(open_paren + 1);
    let null_end = null_start + 4;

    let word = match bytes.get(null_start..null_end) {
        Some(word) => word,
        None => return Some(false),
    };
    if !word.eq_ignore_ascii_case(b"NULL") || (null_start..null_end).any(|k| skip[k]) {
        return Some(false);
    }

    let glued_to_identifier = bytes
        .get(null_end)
        .map_or(false, |&b| b.is_ascii_alphanumeric() || b == b'_');
    if glued_to_identifier {
        return Some(false);
    }

    let close = past_whitespace(null_end);
    Some(close < len && bytes[close] == b')' && !skip[close])
}
262
263impl Rule for InNullComparison {
264 fn name(&self) -> &'static str {
265 "Convention/InNullComparison"
266 }
267
268 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
269 let source = &ctx.source;
270 let skip = build_skip(source.as_bytes());
271 let matches = find_matches(source, &skip);
272
273 matches
274 .into_iter()
275 .map(|m| {
276 let (line, col) = line_col(source, m.in_offset);
277 let message = if m.is_not_in {
278 "Use IS NOT NULL instead of NOT IN (NULL)".to_string()
279 } else {
280 "Use IS NULL instead of IN (NULL)".to_string()
281 };
282 Diagnostic {
283 rule: self.name(),
284 message,
285 line,
286 col,
287 }
288 })
289 .collect()
290 }
291}