// sqrust_rules/convention/trailing_comma.rs

use sqrust_core::{Diagnostic, FileContext, Rule};

/// Lint rule that reports a comma placed directly before a SQL keyword that
/// ends a SELECT list (for example `SELECT a, b, FROM t`).
pub struct TrailingComma;

/// Reports whether the keyword `pattern` occurs in `bytes` starting at `offset`.
///
/// The comparison is ASCII case-insensitive, and the match must end on a word
/// boundary: the byte right after the match must not be `[a-zA-Z0-9_]`.
/// Reaching the end of the input also counts as a boundary. The byte *before*
/// `offset` is not inspected.
///
/// Returns `false` when the pattern would extend past the end of `bytes`.
fn keyword_at(bytes: &[u8], offset: usize, pattern: &[u8]) -> bool {
    // `checked_add` + `get` turn any out-of-range offset into a plain `false`
    // instead of an arithmetic overflow or a slice-index panic.
    let end = match offset.checked_add(pattern.len()) {
        Some(end) => end,
        None => return false,
    };
    let candidate = match bytes.get(offset..end) {
        Some(candidate) => candidate,
        None => return false,
    };
    if !candidate.eq_ignore_ascii_case(pattern) {
        return false;
    }

    // Word boundary: the following byte, if any, must not continue an identifier.
    match bytes.get(end) {
        Some(&next) => !(next.is_ascii_alphanumeric() || next == b'_'),
        None => true,
    }
}

/// SQL keywords that terminate a SELECT list; a comma directly in front of one
/// of them is treated as a trailing comma.
///
/// Entries are listed longest first and are tried in this order. Stored in
/// upper case; callers are expected to compare case-insensitively.
const TERMINATORS: &[&[u8]] = &[
    b"INTERSECT",
    b"EXCEPT",
    b"HAVING",
    b"UNION",
    b"GROUP",
    b"ORDER",
    b"WHERE",
    b"LIMIT",
    b"FROM",
];

/// Converts a byte offset into a 1-indexed `(line, col)` pair.
///
/// `line` is one more than the number of `\n` bytes before `offset`. `col` is
/// one more than the number of *bytes* between the start of that line and
/// `offset` (it is a byte column, not a character column).
///
/// The computation works on raw bytes, so an offset that falls inside a
/// multi-byte character does not panic. Offsets past the end of `source` are
/// clamped to `source.len()`.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let offset = offset.min(source.len());
    let before = &source.as_bytes()[..offset];

    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    // First byte of the current line: just after the last newline, or 0.
    let line_start = before
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);

    (line, offset - line_start + 1)
}

/// Builds a skip mask with one entry per input byte.
///
/// An entry is `true` when the byte belongs to one of:
/// - a line comment `-- ...` (up to, but not including, the newline),
/// - a block comment `/* ... */` (not nested),
/// - a single-quoted string `'...'`, where `''` is an escaped quote,
/// - a double-quoted identifier `"..."`,
/// - a backtick-quoted identifier `` `...` ``.
///
/// Delimiters themselves are marked too. An unterminated construct extends to
/// the end of the input.
fn build_skip(bytes: &[u8]) -> Vec<bool> {
    let total = bytes.len();
    let mut mask = vec![false; total];
    let mut pos = 0;

    while pos < total {
        let rest = &bytes[pos..];

        // `end` is the exclusive end of the construct that starts at `pos`.
        let end = if rest.starts_with(b"--") {
            // The newline terminates the comment but is not part of it.
            rest.iter()
                .position(|&b| b == b'\n')
                .map_or(total, |rel| pos + rel)
        } else if rest.starts_with(b"/*") {
            // Search for the closer only after the opener so `/*/` stays open.
            rest[2..]
                .windows(2)
                .position(|pair| pair == b"*/")
                .map_or(total, |rel| pos + 2 + rel + 2)
        } else {
            match rest[0] {
                b'\'' => {
                    // A quote followed by another quote is an escape; the
                    // string ends at the first quote that stands alone.
                    let mut cursor = pos + 1;
                    loop {
                        match bytes[cursor..].iter().position(|&b| b == b'\'') {
                            None => break total,
                            Some(rel) => {
                                let quote = cursor + rel;
                                if bytes.get(quote + 1) == Some(&b'\'') {
                                    cursor = quote + 2;
                                } else {
                                    break quote + 1;
                                }
                            }
                        }
                    }
                }
                b'"' | b'`' => {
                    let quote = rest[0];
                    rest[1..]
                        .iter()
                        .position(|&b| b == quote)
                        .map_or(total, |rel| pos + 1 + rel + 1)
                }
                _ => {
                    // Ordinary byte: leave it unmarked.
                    pos += 1;
                    continue;
                }
            }
        };

        mask[pos..end].fill(true);
        pos = end;
    }

    mask
}

143/// Finds all trailing-comma byte offsets: a `,` (outside strings/comments) whose
144/// next non-whitespace token starts with a terminator keyword.
145fn find_trailing_commas(source: &str, skip: &[bool]) -> Vec<usize> {
146    let bytes = source.as_bytes();
147    let len = bytes.len();
148    let mut positions = Vec::new();
149    let mut i = 0;
150
151    while i < len {
152        if skip[i] {
153            i += 1;
154            continue;
155        }
156
157        if bytes[i] == b',' {
158            // Look ahead past whitespace
159            let comma_offset = i;
160            let mut j = i + 1;
161            while j < len && bytes[j].is_ascii_whitespace() {
162                j += 1;
163            }
164
165            // Check if next non-whitespace is a terminator keyword
166            for &kw in TERMINATORS {
167                if keyword_at(bytes, j, kw) {
168                    positions.push(comma_offset);
169                    break;
170                }
171            }
172        }
173
174        i += 1;
175    }
176
177    positions
178}
179
180impl Rule for TrailingComma {
181    fn name(&self) -> &'static str {
182        "Convention/TrailingComma"
183    }
184
185    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
186        let source = &ctx.source;
187        let skip = build_skip(source.as_bytes());
188        let positions = find_trailing_commas(source, &skip);
189
190        positions
191            .into_iter()
192            .map(|offset| {
193                let (line, col) = line_col(source, offset);
194                Diagnostic {
195                    rule: self.name(),
196                    message: "Trailing comma before SQL keyword".to_string(),
197                    line,
198                    col,
199                }
200            })
201            .collect()
202    }
203
204    fn fix(&self, ctx: &FileContext) -> Option<String> {
205        let source = &ctx.source;
206        let skip = build_skip(source.as_bytes());
207        let positions = find_trailing_commas(source, &skip);
208
209        if positions.is_empty() {
210            return None;
211        }
212
213        // Remove commas in reverse order so earlier offsets stay valid
214        let mut result = source.clone();
215        for offset in positions.into_iter().rev() {
216            result.remove(offset);
217        }
218
219        Some(result)
220    }
221}