Skip to main content

sqrust_rules/convention/
in_null_comparison.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule `Convention/InNullComparison`.
///
/// Flags `IN (NULL)` and `NOT IN (NULL)` written outside strings, comments,
/// and quoted identifiers, suggesting `IS NULL` / `IS NOT NULL` instead.
pub struct InNullComparison;
4
/// Reports whether `pattern` occurs at `bytes[offset..]` as a whole word.
///
/// The comparison is ASCII case-insensitive. "Whole word" means that the byte
/// just before the candidate (if any) and the byte just after it (if any) are
/// neither ASCII alphanumeric nor `_`. Returns `false` when the pattern would
/// extend past the end of `bytes`.
fn keyword_at_boundary(bytes: &[u8], offset: usize, pattern: &[u8]) -> bool {
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let end = offset + pattern.len();
    let candidate = match bytes.get(offset..end) {
        Some(slice) => slice,
        None => return false,
    };

    // `offset <= end <= bytes.len()` here, so both indexings are in bounds.
    let clear_before = offset == 0 || !is_word_byte(bytes[offset - 1]);
    let clear_after = end == bytes.len() || !is_word_byte(bytes[end]);

    clear_before && clear_after && candidate.eq_ignore_ascii_case(pattern)
}
30
/// Converts a byte offset into `source` to a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of `\n` bytes before `offset`. The
/// column is counted in characters (not bytes) from the start of that line, so
/// multi-byte UTF-8 text earlier on the same line does not inflate it.
///
/// # Panics
///
/// Panics if `offset > source.len()` or `offset` is not on a UTF-8 character
/// boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    // Byte index of the first character on the current line.
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    let col = before[line_start..].chars().count() + 1;
    (line, col)
}
38
/// Builds a per-byte mask that is `true` for every byte belonging to a line
/// comment (`-- …`), block comment (`/* … */`), single-quoted string (with
/// `''` escapes), double-quoted identifier, or backtick-quoted identifier.
///
/// Delimiters are part of their region. Unterminated regions extend to the
/// end of the input. The `\n` that ends a line comment is not marked.
fn build_skip(bytes: &[u8]) -> Vec<bool> {
    let mut skip = vec![false; bytes.len()];
    let mut pos = 0;

    while pos < bytes.len() {
        match skipped_region_end(bytes, pos) {
            Some(end) => {
                for flag in &mut skip[pos..end] {
                    *flag = true;
                }
                pos = end;
            }
            None => pos += 1,
        }
    }

    skip
}

/// If a comment, string literal, or quoted identifier opens at `start`,
/// returns the exclusive end index of that region; otherwise `None`.
///
/// Requires `start < bytes.len()`.
fn skipped_region_end(bytes: &[u8], start: usize) -> Option<usize> {
    let len = bytes.len();
    let rest = &bytes[start..];

    if rest.starts_with(b"--") {
        // Runs up to (but not including) the next newline.
        let end = rest[2..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(len, |p| start + 2 + p);
        return Some(end);
    }

    if rest.starts_with(b"/*") {
        // The search for `*/` starts after the opener, so `/*/` does not close.
        let end = rest[2..]
            .windows(2)
            .position(|w| w == b"*/")
            .map_or(len, |p| start + 2 + p + 2);
        return Some(end);
    }

    let opener = *rest.first()?;
    let end = match opener {
        b'\'' => {
            let mut i = start + 1;
            loop {
                match bytes.get(i) {
                    None => break i,
                    // A doubled quote is an escaped quote inside the literal.
                    Some(b'\'') if bytes.get(i + 1) == Some(&b'\'') => i += 2,
                    Some(b'\'') => break i + 1,
                    Some(_) => i += 1,
                }
            }
        }
        b'"' | b'`' => rest[1..]
            .iter()
            .position(|&b| b == opener)
            .map_or(len, |p| start + 1 + p + 1),
        _ => return None,
    };
    Some(end)
}
133
/// One `IN (NULL)` or `NOT IN (NULL)` occurrence found in the source.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Match {
    /// Byte offset of the `IN` keyword (also for `NOT IN`, where it is the
    /// offset of the `IN` that follows `NOT`).
    in_offset: usize,
    /// `true` if the keyword was preceded by `NOT`.
    is_not_in: bool,
}
141
142/// Scans `source` for `IN (NULL)` and `NOT IN (NULL)` patterns outside
143/// strings/comments.
144fn find_matches(source: &str, skip: &[bool]) -> Vec<Match> {
145    let bytes = source.as_bytes();
146    let len = bytes.len();
147    let mut matches = Vec::new();
148    let mut i = 0;
149
150    while i < len {
151        if skip[i] {
152            i += 1;
153            continue;
154        }
155
156        // Try to match `NOT` keyword, then check for `IN` after it.
157        // Also try plain `IN`.
158        let is_not_in = keyword_at_boundary(bytes, i, b"NOT") && !skip[i];
159        if is_not_in {
160            // Skip NOT + whitespace to find IN
161            let mut j = i + 3; // past "NOT"
162            while j < len && bytes[j].is_ascii_whitespace() {
163                j += 1;
164            }
165            if j < len && !skip[j] && keyword_at_boundary(bytes, j, b"IN") {
166                let in_offset = j;
167                // Skip past IN + whitespace
168                let mut k = j + 2;
169                while k < len && bytes[k].is_ascii_whitespace() {
170                    k += 1;
171                }
172                if k < len && bytes[k] == b'(' && !skip[k] {
173                    if let Some(m) = check_paren_null(bytes, skip, k) {
174                        if m {
175                            matches.push(Match { in_offset, is_not_in: true });
176                            i = k + 1;
177                            continue;
178                        }
179                    }
180                }
181            }
182        }
183
184        // Try plain `IN` (word boundary, outside skip)
185        if !skip[i] && keyword_at_boundary(bytes, i, b"IN") {
186            let in_offset = i;
187            let mut j = i + 2; // past "IN"
188            while j < len && bytes[j].is_ascii_whitespace() {
189                j += 1;
190            }
191            if j < len && bytes[j] == b'(' && !skip[j] {
192                if let Some(m) = check_paren_null(bytes, skip, j) {
193                    if m {
194                        matches.push(Match { in_offset, is_not_in: false });
195                        i = j + 1;
196                        continue;
197                    }
198                }
199            }
200        }
201
202        i += 1;
203    }
204
205    matches
206}
207
/// Given the index of an opening `(`, checks whether the parenthesised
/// content is exactly `NULL` (ASCII case-insensitive, optionally padded with
/// ASCII whitespace) and nothing else.
///
/// Returns:
/// - `Some(true)` if `NULL` is followed (after optional whitespace) by an
///   unskipped `)`;
/// - `None` if the content is `NULL` but the source ends before any closing
///   `)` appears;
/// - `Some(false)` otherwise: a different or additional token, a `NULL` lying
///   in a skipped region, or `NULL` being only the prefix of a longer word
///   such as `NULLS`.
///
/// `skip` must be at least as long as `bytes` (as produced by `build_skip`).
fn check_paren_null(bytes: &[u8], skip: &[bool], open_paren: usize) -> Option<bool> {
    let len = bytes.len();
    // Advances `pos` past any run of ASCII whitespace.
    let past_whitespace = |mut pos: usize| {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        pos
    };

    let null_start = past_whitespace(open_paren + 1);
    let null_end = null_start + 4;

    // Fewer than four bytes remain, so the content cannot be `NULL`.
    let word = match bytes.get(null_start..null_end) {
        Some(word) => word,
        None => return Some(false),
    };
    if !word.eq_ignore_ascii_case(b"NULL") {
        return Some(false);
    }
    // A `NULL` inside a string/comment/quoted identifier is not the keyword.
    if skip[null_start..null_end].contains(&true) {
        return Some(false);
    }
    // Word boundary: reject `NULLS`, `NULL_X`, `NULL1`, ...
    if null_end < len && (bytes[null_end].is_ascii_alphanumeric() || bytes[null_end] == b'_') {
        return Some(false);
    }

    let close = past_whitespace(null_end);
    match bytes.get(close) {
        // `(NULL` runs to the end of input without a closing parenthesis.
        None => None,
        Some(&b')') if !skip[close] => Some(true),
        Some(_) => Some(false),
    }
}
262
263impl Rule for InNullComparison {
264    fn name(&self) -> &'static str {
265        "Convention/InNullComparison"
266    }
267
268    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
269        let source = &ctx.source;
270        let skip = build_skip(source.as_bytes());
271        let matches = find_matches(source, &skip);
272
273        matches
274            .into_iter()
275            .map(|m| {
276                let (line, col) = line_col(source, m.in_offset);
277                let message = if m.is_not_in {
278                    "Use IS NOT NULL instead of NOT IN (NULL)".to_string()
279                } else {
280                    "Use IS NULL instead of IN (NULL)".to_string()
281                };
282                Diagnostic {
283                    rule: self.name(),
284                    message,
285                    line,
286                    col,
287                }
288            })
289            .collect()
290    }
291}