// sqrust_rules/convention/trailing_comma.rs
use sqrust_core::{Diagnostic, FileContext, Rule};

/// Lint rule `Convention/TrailingComma`.
///
/// Flags a comma that directly precedes a clause keyword such as `FROM` or
/// `WHERE` (e.g. `SELECT a, b, FROM t`). The rule can also fix the source by
/// deleting the offending comma.
pub struct TrailingComma;

/// Returns `true` when `pattern` occurs at `bytes[offset..]` as a whole word,
/// compared ASCII case-insensitively.
///
/// "Whole word" constrains only the right edge. The byte right after the
/// match must not be an ASCII letter, digit or `_`, so `FROM` matches in
/// `FROM t` but not in `FROM_DATE`. The left edge is not checked, so callers
/// must pass an offset that already sits on a token boundary. A match that
/// would run past the end of `bytes` yields `false`.
fn keyword_at(bytes: &[u8], offset: usize, pattern: &[u8]) -> bool {
    let end = offset + pattern.len();
    // An identifier character right after the match means we hit a prefix
    // of a longer word, not the keyword itself.
    let word_continues = matches!(
        bytes.get(end),
        Some(&next) if next.is_ascii_alphanumeric() || next == b'_'
    );
    match bytes.get(offset..end) {
        Some(window) => window.eq_ignore_ascii_case(pattern) && !word_continues,
        None => false,
    }
}

/// Clause keywords that must not directly follow a comma.
///
/// Entries are stored uppercase, and matching against them is ASCII
/// case-insensitive. The order does not affect which commas are reported:
/// a comma is flagged as soon as any one keyword matches.
const TERMINATORS: &[&[u8]] = &[
    // Set operators and post-aggregation filtering.
    b"INTERSECT", b"EXCEPT", b"HAVING", b"UNION",
    // Clause starters.
    b"GROUP", b"ORDER", b"WHERE", b"LIMIT", b"FROM",
];

/// Converts a byte `offset` into `source` into a 1-based `(line, column)` pair.
///
/// Lines are delimited by `\n`. The column is measured in bytes from the start
/// of the line, so a multi-byte UTF-8 character earlier on the same line
/// advances it by more than one.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or is not on a `char`
/// boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = source[..offset].as_bytes();
    let line_start = match prefix.iter().rposition(|&b| b == b'\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.iter().filter(|&&b| b == b'\n').count();
    let col = 1 + offset - line_start;
    (line, col)
}

/// Marks every byte of `bytes` that lies inside a comment, string literal or
/// quoted identifier, so later scans can ignore it.
///
/// Recognised regions, each marked together with its delimiters:
/// - `-- …` line comments, up to but *not* including the terminating `\n`;
/// - `/* … */` block comments (not nested);
/// - `'…'` string literals, where `''` is an escaped quote;
/// - `"…"` and `` `…` `` quoted identifiers (no escape handling).
///
/// An unterminated region extends to the end of the input.
fn build_skip(bytes: &[u8]) -> Vec<bool> {
    /// Exclusive end of a `--` comment whose body starts at `from`.
    fn line_comment_end(bytes: &[u8], from: usize) -> usize {
        bytes[from..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |rel| from + rel)
    }

    /// Exclusive end of a `/* … */` comment whose body starts at `from`.
    fn block_comment_end(bytes: &[u8], from: usize) -> usize {
        bytes[from..]
            .windows(2)
            .position(|pair| pair == b"*/")
            .map_or(bytes.len(), |rel| from + rel + 2)
    }

    /// Exclusive end of a `'…'` literal whose body starts at `from`.
    fn single_quoted_end(bytes: &[u8], from: usize) -> usize {
        let mut cursor = from;
        while let Some(&b) = bytes.get(cursor) {
            if b != b'\'' {
                cursor += 1;
            } else if bytes.get(cursor + 1) == Some(&b'\'') {
                // `''` is an escaped quote: stay inside the literal.
                cursor += 2;
            } else {
                return cursor + 1;
            }
        }
        bytes.len()
    }

    /// Exclusive end of a region closed by the single byte `close`.
    fn delimited_end(bytes: &[u8], from: usize, close: u8) -> usize {
        bytes[from..]
            .iter()
            .position(|&b| b == close)
            .map_or(bytes.len(), |rel| from + rel + 1)
    }

    let mut skip = vec![false; bytes.len()];
    let mut start = 0;
    while start < bytes.len() {
        let end = match (bytes[start], bytes.get(start + 1).copied()) {
            (b'-', Some(b'-')) => line_comment_end(bytes, start + 2),
            (b'/', Some(b'*')) => block_comment_end(bytes, start + 2),
            (b'\'', _) => single_quoted_end(bytes, start + 1),
            (quote @ b'"', _) | (quote @ b'`', _) => delimited_end(bytes, start + 1, quote),
            _ => {
                start += 1;
                continue;
            }
        };
        skip[start..end].fill(true);
        start = end;
    }
    skip
}

143fn find_trailing_commas(source: &str, skip: &[bool]) -> Vec<usize> {
146 let bytes = source.as_bytes();
147 let len = bytes.len();
148 let mut positions = Vec::new();
149 let mut i = 0;
150
151 while i < len {
152 if skip[i] {
153 i += 1;
154 continue;
155 }
156
157 if bytes[i] == b',' {
158 let comma_offset = i;
160 let mut j = i + 1;
161 while j < len && bytes[j].is_ascii_whitespace() {
162 j += 1;
163 }
164
165 for &kw in TERMINATORS {
167 if keyword_at(bytes, j, kw) {
168 positions.push(comma_offset);
169 break;
170 }
171 }
172 }
173
174 i += 1;
175 }
176
177 positions
178}
179
180impl Rule for TrailingComma {
181 fn name(&self) -> &'static str {
182 "Convention/TrailingComma"
183 }
184
185 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
186 let source = &ctx.source;
187 let skip = build_skip(source.as_bytes());
188 let positions = find_trailing_commas(source, &skip);
189
190 positions
191 .into_iter()
192 .map(|offset| {
193 let (line, col) = line_col(source, offset);
194 Diagnostic {
195 rule: self.name(),
196 message: "Trailing comma before SQL keyword".to_string(),
197 line,
198 col,
199 }
200 })
201 .collect()
202 }
203
204 fn fix(&self, ctx: &FileContext) -> Option<String> {
205 let source = &ctx.source;
206 let skip = build_skip(source.as_bytes());
207 let positions = find_trailing_commas(source, &skip);
208
209 if positions.is_empty() {
210 return None;
211 }
212
213 let mut result = source.clone();
215 for offset in positions.into_iter().rev() {
216 result.remove(offset);
217 }
218
219 Some(result)
220 }
221}