//! Lint rule `Convention/BooleanComparison` (`sqrust_rules/convention/boolean_comparison.rs`).

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// The `Convention/BooleanComparison` lint rule.
///
/// Reports explicit comparisons of an expression against the boolean
/// literals `TRUE` / `FALSE` (for example `WHERE active = TRUE`), which can
/// be written using the expression directly.
#[derive(Debug, Default, Clone, Copy)]
pub struct BooleanComparison;
4
/// Maps a byte `offset` in `source` to a 1-based `(line, column)` position.
///
/// Lines are delimited by `\n`. The column is measured in **bytes** from the
/// start of the line (so a preceding multi-byte UTF-8 character counts as
/// more than one column).
///
/// # Panics
/// Panics if `offset` is past the end of `source` or not on a char boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.matches('\n').count();
    let col = offset - line_start + 1;
    (line, col)
}
12
/// Returns a per-byte mask that is `true` for every byte belonging to a
/// region the rule must not inspect:
///
/// - `-- …` line comments (up to, but not including, the newline),
/// - `/* … */` block comments (not nested; an unterminated one runs to EOF),
/// - `'…'` string literals, where `''` is an escaped quote,
/// - `"…"` and `` `…` `` quoted identifiers.
///
/// The returned vector always has the same length as `bytes`.
fn build_skip(bytes: &[u8]) -> Vec<bool> {
    let n = bytes.len();
    let mut skip = vec![false; n];
    let mut pos = 0;

    while pos < n {
        // Compute the exclusive end of the region starting at `pos`, if any.
        let end = match bytes[pos] {
            b'-' if bytes.get(pos + 1) == Some(&b'-') => {
                let mut k = pos + 2;
                while k < n && bytes[k] != b'\n' {
                    k += 1;
                }
                k
            }
            b'/' if bytes.get(pos + 1) == Some(&b'*') => {
                let mut k = pos + 2;
                loop {
                    if k >= n {
                        break n;
                    }
                    if bytes[k] == b'*' && bytes.get(k + 1) == Some(&b'/') {
                        break k + 2;
                    }
                    k += 1;
                }
            }
            b'\'' => {
                let mut k = pos + 1;
                loop {
                    if k >= n {
                        break n;
                    }
                    if bytes[k] == b'\'' {
                        // `''` inside a literal is an escaped quote, not the end.
                        if bytes.get(k + 1) == Some(&b'\'') {
                            k += 2;
                            continue;
                        }
                        break k + 1;
                    }
                    k += 1;
                }
            }
            quote @ (b'"' | b'`') => {
                let mut k = pos + 1;
                while k < n && bytes[k] != quote {
                    k += 1;
                }
                (k + 1).min(n)
            }
            _ => {
                pos += 1;
                continue;
            }
        };

        skip[pos..end].fill(true);
        pos = end;
    }

    skip
}
107
/// Returns `true` if `b` can appear inside an unquoted identifier or keyword
/// (ASCII letter, digit, or underscore).
fn is_word_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_'
}

/// Checks whether `bytes[offset..]` starts with `pattern` case-insensitively,
/// followed by a word boundary (end of input or non-alphanumeric/non-underscore).
///
/// Returns `false` when fewer than `pattern.len()` bytes remain at `offset`.
fn bool_keyword_at(bytes: &[u8], offset: usize, pattern: &[u8]) -> bool {
    let end = offset + pattern.len();
    match bytes.get(offset..end) {
        Some(candidate) if candidate.eq_ignore_ascii_case(pattern) => {
            // Word boundary after the keyword, so `TRUEISH` is not `TRUE`.
            !matches!(bytes.get(end), Some(&next) if is_word_byte(next))
        }
        _ => false,
    }
}

/// Like `bool_keyword_at`, but additionally requires a word boundary *before*
/// `offset`: start of input, or a byte that is neither an identifier byte nor
/// `.`. This keeps `RESET` or `t.set` from being read as the keyword `SET`.
///
/// `offset` must be `<= bytes.len()`.
fn keyword_at(bytes: &[u8], offset: usize, pattern: &[u8]) -> bool {
    let starts_word = offset == 0 || {
        let prev = bytes[offset - 1];
        !is_word_byte(prev) && prev != b'.'
    };
    starts_word && bool_keyword_at(bytes, offset, pattern)
}

/// Returns `true` if, after skipping ASCII spaces/tabs/newlines from `from`,
/// the next byte is not masked by `skip` and begins a `TRUE` or `FALSE`
/// keyword.
fn followed_by_bool_literal(bytes: &[u8], skip: &[bool], from: usize) -> bool {
    let mut j = from;
    while j < bytes.len() && matches!(bytes[j], b' ' | b'\t' | b'\n' | b'\r') {
        j += 1;
    }
    j < bytes.len()
        && !skip[j]
        && (bool_keyword_at(bytes, j, b"TRUE") || bool_keyword_at(bytes, j, b"FALSE"))
}

/// Scanner state while inside the assignment list of a `SET` clause
/// (`UPDATE t SET a = TRUE, b = FALSE`, `SET flag = TRUE`, ...).
struct SetClause {
    /// Parenthesis depth at which the `SET` keyword appeared.
    depth: usize,
    /// `true` after `SET` or a comma at `depth`, until the next `=` / `:=` at
    /// `depth` (the assignment operator of that pair) has been consumed.
    awaiting_assignment: bool,
}

/// Scans `source` for `= TRUE`, `= FALSE`, `<> TRUE`, `<> FALSE`,
/// `!= TRUE`, `!= FALSE` patterns outside strings/comments.
/// Returns the byte offset of the operator (`=`, `<>`, `!=`), in source order.
///
/// `skip` must have one entry per byte of `source`; bytes marked `true`
/// (strings, comments, quoted identifiers) are ignored.
///
/// Assignments are not comparisons and are not reported:
/// - `:=` (e.g. MySQL `@v := TRUE`, PostgreSQL named argument `f(x := TRUE)`);
/// - the `=` of each `target = value` pair in a `SET` list, e.g.
///   `UPDATE t SET active = TRUE` or `SET flag = FALSE`.
///
/// Comparisons nested inside a `SET` list (`SET a = (b = TRUE)`,
/// `SET a = CASE WHEN b = TRUE THEN 1 END`) and after it
/// (`... WHERE c = TRUE`) are still reported. A `SET` list ends at `WHERE`,
/// `FROM` or `RETURNING` at the `SET`'s own parenthesis depth, at `;`, or at
/// a `)` that closes the parenthesis enclosing the `SET`.
fn find_boolean_comparisons(source: &str, skip: &[bool]) -> Vec<usize> {
    const SET_CLAUSE_TERMINATORS: &[&[u8]] = &[b"WHERE", b"FROM", b"RETURNING"];

    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut results = Vec::new();
    let mut depth: usize = 0;
    let mut set_clause: Option<SetClause> = None;
    let mut i = 0;

    while i < len {
        if skip[i] {
            i += 1;
            continue;
        }

        let byte = bytes[i];

        // Punctuation that structures SET lists. None of these bytes is an
        // operator, so scanning simply continues below.
        match byte {
            b'(' => depth += 1,
            b')' => {
                depth = depth.saturating_sub(1);
                if matches!(&set_clause, Some(clause) if depth < clause.depth) {
                    set_clause = None;
                }
            }
            b',' => {
                if let Some(clause) = set_clause.as_mut() {
                    if clause.depth == depth {
                        // Next `target = value` pair of the same SET list.
                        clause.awaiting_assignment = true;
                    }
                }
            }
            b';' => set_clause = None,
            _ => {}
        }

        if byte.is_ascii_alphabetic() {
            if keyword_at(bytes, i, b"SET") {
                set_clause = Some(SetClause {
                    depth,
                    awaiting_assignment: true,
                });
            } else if matches!(&set_clause, Some(clause) if clause.depth == depth)
                && SET_CLAUSE_TERMINATORS
                    .iter()
                    .any(|kw| keyword_at(bytes, i, kw))
            {
                set_clause = None;
            }
            i += 1;
            continue;
        }

        // Operators: `!=`, `<>`, `:=` (two bytes) or `=` (one byte).
        let next = bytes.get(i + 1).copied();
        let (op_len, is_colon_equals) = match (byte, next) {
            (b'!', Some(b'=')) | (b'<', Some(b'>')) => (2, false),
            (b':', Some(b'=')) => (2, true),
            (b'=', _) => (1, false),
            _ => {
                i += 1;
                continue;
            }
        };

        // The first `=` / `:=` after a SET target is that pair's assignment.
        let is_set_assignment = match set_clause.as_mut() {
            Some(clause)
                if clause.awaiting_assignment
                    && clause.depth == depth
                    && (byte == b'=' || is_colon_equals) =>
            {
                clause.awaiting_assignment = false;
                true
            }
            _ => false,
        };

        if !is_colon_equals
            && !is_set_assignment
            && followed_by_bool_literal(bytes, skip, i + op_len)
        {
            results.push(i);
        }
        i += op_len;
    }

    results
}
178
179impl Rule for BooleanComparison {
180    fn name(&self) -> &'static str {
181        "Convention/BooleanComparison"
182    }
183
184    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
185        if !ctx.parse_errors.is_empty() {
186            return Vec::new();
187        }
188
189        let source = &ctx.source;
190        let bytes = source.as_bytes();
191        let skip = build_skip(bytes);
192        let offsets = find_boolean_comparisons(source, &skip);
193
194        offsets
195            .into_iter()
196            .map(|op_offset| {
197                let (line, col) = line_col(source, op_offset);
198                Diagnostic {
199                    rule: self.name(),
200                    message: "Explicit comparison with boolean literal; use the expression directly"
201                        .to_string(),
202                    line,
203                    col,
204                }
205            })
206            .collect()
207    }
208}