//! `Convention/NotEqual` rule (`sqrust_rules/convention/not_equal.rs`): flags the
//! `!=` operator and rewrites it to the ANSI SQL `<>` operator.

use sqrust_core::{Diagnostic, FileContext, Rule};

/// The `Convention/NotEqual` rule: reports the non-standard `!=` comparison
/// operator and can rewrite it to the ANSI SQL `<>` operator.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotEqual;

/// Calls `visitor` for every character of `source` that lies outside SQL
/// string literals, quoted identifiers, and comments.
///
/// Skipped regions (the visitor sees none of their characters, delimiters
/// included):
/// - single-quoted strings `'...'`, in which `''` is an escaped quote;
/// - double-quoted identifiers `"..."`;
/// - line comments from `--` up to, but not including, the next `\n`
///   (the terminating newline itself IS visited);
/// - block comments `/* ... */`, which do not nest.
///
/// A string, identifier, or block comment that is never closed runs to the
/// end of `source`.
///
/// `visitor` receives `(byte_offset, char, line, col)`:
/// - the UTF-8 byte offset of the character in `source`;
/// - the character itself;
/// - its 1-indexed line and column.
///
/// Columns count characters, not bytes. Newlines inside skipped regions still
/// advance the line counter.
fn visit_outside_tokens<F>(source: &str, mut visitor: F)
where
    F: FnMut(usize, char, usize, usize),
{
    // One entry per char: (byte offset in `source`, char).
    let chars: Vec<(usize, char)> = source.char_indices().collect();
    let mut line = 1usize;
    let mut col = 1usize;
    let mut i = 0;

    while i < chars.len() {
        if let Some(end) = skipped_region_end(&chars, i) {
            // Keep line/col in sync with everything inside the skipped region
            // so positions reported after it stay correct.
            for &(_, c) in &chars[i..end] {
                advance_position(c, &mut line, &mut col);
            }
            i = end;
            continue;
        }

        let (byte_off, ch) = chars[i];
        visitor(byte_off, ch, line, col);
        advance_position(ch, &mut line, &mut col);
        i += 1;
    }
}

/// Moves a 1-indexed `line`/`col` cursor past the character `c`.
fn advance_position(c: char, line: &mut usize, col: &mut usize) {
    if c == '\n' {
        *line += 1;
        *col = 1;
    } else {
        *col += 1;
    }
}

/// Checks whether a skipped region (string literal, quoted identifier, or
/// comment) starts at `chars[start]`.
///
/// Returns the index one past the region's last character, or `chars.len()`
/// if the region is unterminated. Returns `None` when no region starts there.
fn skipped_region_end(chars: &[(usize, char)], start: usize) -> Option<usize> {
    let len = chars.len();
    let char_at = |j: usize| chars.get(j).map(|&(_, c)| c);

    match (char_at(start)?, char_at(start + 1)) {
        // Line comment: stop *before* the newline so the caller visits it.
        ('-', Some('-')) => Some(
            (start + 2..len)
                .find(|&j| char_at(j) == Some('\n'))
                .unwrap_or(len),
        ),
        // Block comment: the closing `*/` is searched for after the opening
        // `/*`, so `/*/` does not close itself.
        ('/', Some('*')) => Some(
            (start + 2..len)
                .find(|&j| char_at(j) == Some('*') && char_at(j + 1) == Some('/'))
                .map_or(len, |j| j + 2),
        ),
        // Single-quoted string literal with `''` as an escaped quote.
        ('\'', _) => {
            let mut j = start + 1;
            while j < len {
                if char_at(j) == Some('\'') {
                    if char_at(j + 1) == Some('\'') {
                        j += 2;
                    } else {
                        return Some(j + 1);
                    }
                } else {
                    j += 1;
                }
            }
            Some(len)
        }
        // Double-quoted identifier.
        ('"', _) => Some(
            (start + 1..len)
                .find(|&j| char_at(j) == Some('"'))
                .map_or(len, |j| j + 1),
        ),
        _ => None,
    }
}

131impl Rule for NotEqual {
132    fn name(&self) -> &'static str {
133        "Convention/NotEqual"
134    }
135
136    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
137        let mut diags = Vec::new();
138        let source = &ctx.source;
139
140        // Collect positions of '!' that are outside strings/comments and followed by '='
141        visit_outside_tokens(source, |byte_off, ch, line, col| {
142            if ch == '!' {
143                // Check the next byte character
144                let rest = &source[byte_off..];
145                if rest.starts_with("!=") {
146                    diags.push(Diagnostic {
147                        rule: "Convention/NotEqual",
148                        message: "Use '<>' instead of '!=' for ANSI SQL compatibility".to_string(),
149                        line,
150                        col,
151                    });
152                }
153            }
154        });
155
156        diags
157    }
158
159    fn fix(&self, ctx: &FileContext) -> Option<String> {
160        let source = &ctx.source;
161        let chars: Vec<char> = source.chars().collect();
162        let len = chars.len();
163
164        // Build a set of byte offsets that are inside strings/comments
165        let mut skip_ranges: Vec<(usize, usize)> = Vec::new();
166
167        let byte_offsets: Vec<usize> = {
168            let mut offs = Vec::with_capacity(len);
169            let mut off = 0;
170            for ch in &chars {
171                offs.push(off);
172                off += ch.len_utf8();
173            }
174            offs
175        };
176
177        let mut i = 0;
178        while i < len {
179            let ch = chars[i];
180
181            if ch == '-' && i + 1 < len && chars[i + 1] == '-' {
182                let start = byte_offsets[i];
183                while i < len && chars[i] != '\n' {
184                    i += 1;
185                }
186                let end = if i < len { byte_offsets[i] } else { source.len() };
187                skip_ranges.push((start, end));
188                continue;
189            }
190
191            if ch == '/' && i + 1 < len && chars[i + 1] == '*' {
192                let start = byte_offsets[i];
193                i += 2;
194                while i < len {
195                    if chars[i] == '*' && i + 1 < len && chars[i + 1] == '/' {
196                        i += 2;
197                        break;
198                    }
199                    i += 1;
200                }
201                let end = if i < len { byte_offsets[i] } else { source.len() };
202                skip_ranges.push((start, end));
203                continue;
204            }
205
206            if ch == '\'' {
207                let start = byte_offsets[i];
208                i += 1;
209                while i < len {
210                    if chars[i] == '\'' {
211                        if i + 1 < len && chars[i + 1] == '\'' {
212                            i += 2;
213                        } else {
214                            i += 1;
215                            break;
216                        }
217                    } else {
218                        i += 1;
219                    }
220                }
221                let end = if i < len { byte_offsets[i] } else { source.len() };
222                skip_ranges.push((start, end));
223                continue;
224            }
225
226            if ch == '"' {
227                let start = byte_offsets[i];
228                i += 1;
229                while i < len {
230                    if chars[i] == '"' {
231                        i += 1;
232                        break;
233                    }
234                    i += 1;
235                }
236                let end = if i < len { byte_offsets[i] } else { source.len() };
237                skip_ranges.push((start, end));
238                continue;
239            }
240
241            i += 1;
242        }
243
244        // Replace `!=` occurrences that are NOT inside any skip range
245        let source_bytes = source.as_bytes();
246        let src_len = source.len();
247        let mut result = String::with_capacity(src_len);
248        let mut pos = 0;
249
250        while pos < src_len {
251            // Check if this position is inside a skip range
252            let in_skip = skip_ranges.iter().any(|&(s, e)| pos >= s && pos < e);
253            if in_skip {
254                // Find the range end and copy verbatim
255                let range_end = skip_ranges
256                    .iter()
257                    .filter(|&&(s, _)| pos >= s)
258                    .map(|&(_, e)| e)
259                    .min()
260                    .unwrap_or(pos + 1);
261                result.push_str(&source[pos..range_end]);
262                pos = range_end;
263                continue;
264            }
265
266            if pos + 1 < src_len && source_bytes[pos] == b'!' && source_bytes[pos + 1] == b'=' {
267                result.push_str("<>");
268                pos += 2;
269            } else {
270                let ch = source[pos..].chars().next().unwrap();
271                result.push(ch);
272                pos += ch.len_utf8();
273            }
274        }
275
276        Some(result)
277    }
278}