Skip to main content

sqrust_rules/convention/
is_null.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule `Convention/IsNull`: reports comparisons against `NULL` written
/// with a comparison operator (e.g. `= NULL`, `<> NULL`, `!= NULL`) and can
/// rewrite them to `IS NULL` / `IS NOT NULL`.
#[derive(Debug, Clone, Copy, Default)]
pub struct IsNull;
4
/// Reports whether `pattern` occurs in `source` at byte `offset`, comparing
/// ASCII letters case-insensitively. A window that would run past the end of
/// `source` yields `false` instead of panicking.
fn starts_with_ci(source: &[u8], offset: usize, pattern: &[u8]) -> bool {
    let end = offset + pattern.len();
    match source.get(offset..end) {
        Some(window) => window.eq_ignore_ascii_case(pattern),
        None => false,
    }
}
17
/// Maps a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column is the byte distance from the start of the line plus one, so it
/// equals the character column only for ASCII text.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or not on a UTF-8
/// character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}
25
/// Marks which bytes of `source` belong to a string literal, comment or quoted
/// identifier. The returned vector has one entry per byte of `source`; `true`
/// means "not live SQL code".
///
/// Recognised constructs:
/// - `-- ...` up to, but not including, the next `\n`;
/// - `/* ... */` (not nested);
/// - `'...'`, where `''` inside the literal is an escaped quote;
/// - `"..."` and `` `...` `` quoted identifiers.
///
/// Any of these left unterminated extends to the end of `source`.
fn build_skip(source: &[u8]) -> Vec<bool> {
    let len = source.len();
    let mut skip = vec![false; len];

    // Index just past the first `delim` at or after `from`, or `len` if none.
    let past_closing = |from: usize, delim: u8| -> usize {
        source[from..]
            .iter()
            .position(|&b| b == delim)
            .map_or(len, |rel| from + rel + 1)
    };

    let mut pos = 0;
    while pos < len {
        let head = source[pos];
        let end = match (head, source.get(pos + 1).copied()) {
            // Line comment: stop *before* the newline so the newline stays code.
            (b'-', Some(b'-')) => source[pos + 2..]
                .iter()
                .position(|&b| b == b'\n')
                .map_or(len, |rel| pos + 2 + rel),
            // Block comment: look for `*/` only after the opening `/*`,
            // so `/*/` does not close itself.
            (b'/', Some(b'*')) => source[pos + 2..]
                .windows(2)
                .position(|w| w == b"*/")
                .map_or(len, |rel| pos + 2 + rel + 2),
            // String literal; a doubled quote is an escape, not the end.
            (b'\'', _) => {
                let mut k = pos + 1;
                loop {
                    match (source.get(k).copied(), source.get(k + 1).copied()) {
                        (None, _) => break len,
                        (Some(b'\''), Some(b'\'')) => k += 2,
                        (Some(b'\''), _) => break k + 1,
                        _ => k += 1,
                    }
                }
            }
            // Quoted identifiers close at the first matching delimiter.
            (b'"', _) | (b'`', _) => past_closing(pos + 1, head),
            _ => {
                pos += 1;
                continue;
            }
        };
        skip[pos..end].fill(true);
        pos = end;
    }

    skip
}
120
/// One `<operator> NULL` occurrence found in the source, together with the
/// text to report for it and the text to substitute when fixing it.
struct NullMatch {
    /// Diagnostic text shown to the user.
    message: &'static str,
    /// Text that replaces the matched span (e.g. `IS NULL` or `IS NOT NULL`).
    replacement: &'static str,
    /// Byte offset at which the comparison operator begins.
    op_offset: usize,
    /// Byte length of the span from the operator through the end of `NULL`.
    full_len: usize,
}
132
133/// Tries to match `NULL` (case-insensitive) after skipping whitespace from
134/// byte index `after_op`. Returns `Some(NullMatch)` on success.
135fn try_match_null(
136    bytes: &[u8],
137    skip: &[bool],
138    op_start: usize,
139    op_len: usize,
140    replacement: &'static str,
141    message: &'static str,
142) -> Option<NullMatch> {
143    let len = bytes.len();
144    let mut j = op_start + op_len;
145
146    // Require at least one whitespace between operator and NULL
147    if j >= len || !bytes[j].is_ascii_whitespace() {
148        return None;
149    }
150
151    // Skip whitespace — all must be code (outside strings/comments)
152    while j < len && bytes[j].is_ascii_whitespace() {
153        if skip[j] {
154            return None;
155        }
156        j += 1;
157    }
158
159    // Check for NULL (case-insensitive) as a word boundary
160    if j + 4 > len {
161        return None;
162    }
163    if !starts_with_ci(bytes, j, b"NULL") {
164        return None;
165    }
166
167    // Make sure the 4 NULL bytes are all outside strings/comments
168    for k in j..j + 4 {
169        if skip[k] {
170            return None;
171        }
172    }
173
174    // Word boundary after NULL: must not be followed by `[a-zA-Z0-9_]`
175    let end = j + 4;
176    if end < len && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') {
177        return None;
178    }
179
180    Some(NullMatch {
181        op_offset: op_start,
182        full_len: end - op_start,
183        replacement,
184        message,
185    })
186}
187
188/// Scans `source` for `= NULL`, `<> NULL`, `!= NULL` outside strings/comments.
189fn find_null_matches(source: &str, skip: &[bool]) -> Vec<NullMatch> {
190    let bytes = source.as_bytes();
191    let len = bytes.len();
192    let mut matches = Vec::new();
193    let mut i = 0;
194
195    while i < len {
196        if skip[i] {
197            i += 1;
198            continue;
199        }
200
201        // `= NULL` — but not when preceded by `!` (!=) or `<` (<>)
202        if bytes[i] == b'=' {
203            let preceded_by_bang = i > 0 && bytes[i - 1] == b'!';
204            let preceded_by_lt = i > 0 && bytes[i - 1] == b'<';
205            if !preceded_by_bang && !preceded_by_lt {
206                if let Some(m) = try_match_null(
207                    bytes,
208                    skip,
209                    i,
210                    1,
211                    "IS NULL",
212                    "Use IS NULL instead of = NULL",
213                ) {
214                    matches.push(m);
215                    i += 1;
216                    continue;
217                }
218            }
219        }
220
221        // `<> NULL`
222        if bytes[i] == b'<' && i + 1 < len && bytes[i + 1] == b'>' {
223            if let Some(m) = try_match_null(
224                bytes,
225                skip,
226                i,
227                2,
228                "IS NOT NULL",
229                "Use IS NOT NULL instead of <> NULL",
230            ) {
231                matches.push(m);
232                i += 2;
233                continue;
234            }
235        }
236
237        // `!= NULL`
238        if bytes[i] == b'!' && i + 1 < len && bytes[i + 1] == b'=' {
239            if let Some(m) = try_match_null(
240                bytes,
241                skip,
242                i,
243                2,
244                "IS NOT NULL",
245                "Use IS NOT NULL instead of != NULL",
246            ) {
247                matches.push(m);
248                i += 2;
249                continue;
250            }
251        }
252
253        i += 1;
254    }
255
256    matches
257}
258
259impl Rule for IsNull {
260    fn name(&self) -> &'static str {
261        "Convention/IsNull"
262    }
263
264    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
265        let source = &ctx.source;
266        let skip = build_skip(source.as_bytes());
267        let matches = find_null_matches(source, &skip);
268
269        matches
270            .into_iter()
271            .map(|m| {
272                let (line, col) = line_col(source, m.op_offset);
273                Diagnostic {
274                    rule: self.name(),
275                    message: m.message.to_string(),
276                    line,
277                    col,
278                }
279            })
280            .collect()
281    }
282
283    fn fix(&self, ctx: &FileContext) -> Option<String> {
284        let source = &ctx.source;
285        let skip = build_skip(source.as_bytes());
286        let matches = find_null_matches(source, &skip);
287
288        if matches.is_empty() {
289            return None;
290        }
291
292        // Apply replacements in reverse order to keep earlier byte offsets valid
293        let mut result = source.clone();
294        for m in matches.into_iter().rev() {
295            result.replace_range(m.op_offset..m.op_offset + m.full_len, m.replacement);
296        }
297
298        Some(result)
299    }
300}