// sqrust_rules/structure/aggregate_star.rs
use sqrust_core::{Diagnostic, FileContext, Rule};
/// Lint rule that reports aggregate calls written with a `*` argument (for example
/// `SUM(*)`), which only `COUNT` accepts.
///
/// The rule carries no state, so it is a unit struct that is cheap to copy.
#[derive(Debug, Clone, Copy, Default)]
pub struct AggregateStar;
/// Aggregate function names that must not be called with a `*` argument.
///
/// `COUNT` is deliberately absent: `COUNT(*)` is the one form this rule accepts.
/// Entries are lowercase; the scanner compares them ASCII case-insensitively.
const FLAGGED_AGGREGATES: &[&str] = &[
    // Basic numeric aggregates.
    "sum",
    "avg",
    "min",
    "max",
    // Standard deviation variants.
    "stddev",
    "stddev_pop",
    "stddev_samp",
    // Variance variants.
    "variance",
    "var_pop",
    "var_samp",
    // Other.
    "median",
];
15impl Rule for AggregateStar {
16 fn name(&self) -> &'static str {
17 "Structure/AggregateStar"
18 }
19
20 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
21 find_violations(&ctx.source, self.name())
22 }
23}
24
25fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
26 let bytes = source.as_bytes();
27 let len = bytes.len();
28
29 if len == 0 {
30 return Vec::new();
31 }
32
33 let skip = build_skip_set(bytes, len);
34 let mut diags = Vec::new();
35 let mut line: usize = 1;
36 let mut line_start: usize = 0;
37 let mut i = 0;
38
39 while i < len {
40 if bytes[i] == b'\n' {
41 line += 1;
42 line_start = i + 1;
43 i += 1;
44 continue;
45 }
46
47 if skip[i] {
48 i += 1;
49 continue;
50 }
51
52 let mut matched = false;
54 for &func in FLAGGED_AGGREGATES {
55 let flen = func.len();
56 if i + flen + 2 > len {
57 continue;
59 }
60
61 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
63 if !before_ok {
64 continue;
65 }
66
67 let name_matches = bytes[i..i + flen]
69 .iter()
70 .zip(func.bytes())
71 .all(|(a, b)| a.eq_ignore_ascii_case(&b));
72 if !name_matches {
73 continue;
74 }
75
76 let paren_pos = i + flen;
78 if paren_pos >= len || bytes[paren_pos] != b'(' {
79 continue;
80 }
81
82 let star_pos = paren_pos + 1;
84 if star_pos >= len || bytes[star_pos] != b'*' {
85 continue;
86 }
87
88 let close_pos = star_pos + 1;
90 if close_pos >= len || bytes[close_pos] != b')' {
91 continue;
92 }
93
94 if skip[paren_pos] || skip[star_pos] || skip[close_pos] {
96 continue;
97 }
98
99 let display_name: String = bytes[i..i + flen]
101 .iter()
102 .map(|b| b.to_ascii_uppercase() as char)
103 .collect();
104
105 let col = i - line_start + 1;
106 diags.push(Diagnostic {
107 rule: rule_name,
108 message: format!(
109 "{display_name}(*) is not valid SQL — only COUNT(*) supports wildcard \
110 argument; use {display_name}(column_name) instead"
111 ),
112 line,
113 col,
114 });
115
116 i = close_pos + 1;
118 matched = true;
119 break;
120 }
121
122 if !matched {
123 i += 1;
124 }
125 }
126
127 diags
128}
129
/// Returns `true` for the bytes this rule treats as identifier characters: ASCII
/// letters, ASCII digits and `_`. Every other byte, including the bytes of multi-byte
/// UTF-8 characters, returns `false`.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
/// Computes a per-byte mask of `bytes[..len]` whose `true` entries the matcher must
/// ignore. Masked regions are `'...'` string literals, `"..."` quoted identifiers,
/// `/* ... */` block comments and `-- ...` line comments, including their delimiters.
///
/// A doubled quote (`''` or `""`) inside a quoted run is an escaped quote and does not
/// close the run. Unterminated quoted runs and block comments extend to `len`. Block
/// comments close at the first `*/` (no nesting). A line comment stops before its
/// terminating `\n`, which stays unmarked.
///
/// `len` must not exceed `bytes.len()`.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    let mut skip = vec![false; len];
    let mut cursor = 0;

    while cursor < len {
        let lookahead = if cursor + 1 < len {
            Some(bytes[cursor + 1])
        } else {
            None
        };
        cursor = match (bytes[cursor], lookahead) {
            (b'\'', _) => mark_quoted_run(bytes, len, &mut skip, cursor, b'\''),
            (b'"', _) => mark_quoted_run(bytes, len, &mut skip, cursor, b'"'),
            (b'/', Some(b'*')) => mark_block_comment(bytes, len, &mut skip, cursor),
            (b'-', Some(b'-')) => mark_line_comment(bytes, len, &mut skip, cursor),
            _ => cursor + 1,
        };
    }

    skip
}

/// Marks a run delimited by `quote` that opens at `open`. Returns the index just past
/// the closing delimiter, or `len` if the run is unterminated.
fn mark_quoted_run(bytes: &[u8], len: usize, skip: &mut [bool], open: usize, quote: u8) -> usize {
    skip[open] = true;
    let mut pos = open + 1;
    while pos < len {
        skip[pos] = true;
        if bytes[pos] != quote {
            pos += 1;
        } else if pos + 1 < len && bytes[pos + 1] == quote {
            // Doubled delimiter: an escaped quote that keeps the run open.
            skip[pos + 1] = true;
            pos += 2;
        } else {
            return pos + 1;
        }
    }
    pos
}

/// Marks a `/* ... */` comment that opens at `open`. Returns the index just past the
/// closing `*/`, or `len` if the comment is unterminated.
fn mark_block_comment(bytes: &[u8], len: usize, skip: &mut [bool], open: usize) -> usize {
    skip[open] = true;
    skip[open + 1] = true;
    let mut pos = open + 2;
    while pos < len {
        skip[pos] = true;
        if pos + 1 < len && bytes[pos] == b'*' && bytes[pos + 1] == b'/' {
            skip[pos + 1] = true;
            return pos + 2;
        }
        pos += 1;
    }
    pos
}

/// Marks a `--` comment that opens at `open`, up to but not including the next `\n`.
/// Returns the index of that `\n`, or `len` if none follows.
fn mark_line_comment(bytes: &[u8], len: usize, skip: &mut [bool], open: usize) -> usize {
    let end = bytes[open + 2..len]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(len, |offset| open + 2 + offset);
    for flag in &mut skip[open..end] {
        *flag = true;
    }
    end
}