Skip to main content

sqrust_rules/structure/
aggregate_star.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule that flags aggregate functions other than `COUNT` called with a
/// bare `*` argument (e.g. `SUM(*)`, `AVG(*)`, `MIN(*)`, `MAX(*)`).
///
/// Only `COUNT(*)` is valid SQL; `*` passed to any other aggregate is almost
/// always a typo or a logic error.
#[derive(Debug, Clone, Copy, Default)]
pub struct AggregateStar;
7
/// Names of the aggregate functions that must NOT be called with `*`.
///
/// Every entry is lowercase ASCII; the scanner compares source text against
/// these names case-insensitively. `count` is deliberately absent because
/// `COUNT(*)` is valid.
const FLAGGED_AGGREGATES: &[&str] = &[
    "sum", "avg", "min", "max", "stddev", "stddev_pop", "stddev_samp", "variance", "var_pop",
    "var_samp", "median",
];
14
15impl Rule for AggregateStar {
16    fn name(&self) -> &'static str {
17        "Structure/AggregateStar"
18    }
19
20    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
21        find_violations(&ctx.source, self.name())
22    }
23}
24
25fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
26    let bytes = source.as_bytes();
27    let len = bytes.len();
28
29    if len == 0 {
30        return Vec::new();
31    }
32
33    let skip = build_skip_set(bytes, len);
34    let mut diags = Vec::new();
35    let mut line: usize = 1;
36    let mut line_start: usize = 0;
37    let mut i = 0;
38
39    while i < len {
40        if bytes[i] == b'\n' {
41            line += 1;
42            line_start = i + 1;
43            i += 1;
44            continue;
45        }
46
47        if skip[i] {
48            i += 1;
49            continue;
50        }
51
52        // Try to match each flagged aggregate starting at position i.
53        let mut matched = false;
54        for &func in FLAGGED_AGGREGATES {
55            let flen = func.len();
56            if i + flen + 2 > len {
57                // Need at least funcname + "(*)"
58                continue;
59            }
60
61            // Word-boundary before the function name.
62            let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
63            if !before_ok {
64                continue;
65            }
66
67            // Match function name (case-insensitive).
68            let name_matches = bytes[i..i + flen]
69                .iter()
70                .zip(func.bytes())
71                .all(|(a, b)| a.eq_ignore_ascii_case(&b));
72            if !name_matches {
73                continue;
74            }
75
76            // Immediately after the name must be '(' (word-boundary + paren).
77            let paren_pos = i + flen;
78            if paren_pos >= len || bytes[paren_pos] != b'(' {
79                continue;
80            }
81
82            // After '(' must be '*'.
83            let star_pos = paren_pos + 1;
84            if star_pos >= len || bytes[star_pos] != b'*' {
85                continue;
86            }
87
88            // After '*' must be ')'.
89            let close_pos = star_pos + 1;
90            if close_pos >= len || bytes[close_pos] != b')' {
91                continue;
92            }
93
94            // None of these positions should be in a skip region.
95            if skip[paren_pos] || skip[star_pos] || skip[close_pos] {
96                continue;
97            }
98
99            // Build the display name (preserve case from source).
100            let display_name: String = bytes[i..i + flen]
101                .iter()
102                .map(|b| b.to_ascii_uppercase() as char)
103                .collect();
104
105            let col = i - line_start + 1;
106            diags.push(Diagnostic {
107                rule: rule_name,
108                message: format!(
109                    "{display_name}(*) is not valid SQL — only COUNT(*) supports wildcard \
110                     argument; use {display_name}(column_name) instead"
111                ),
112                line,
113                col,
114            });
115
116            // Advance past the matched pattern so we don't double-count.
117            i = close_pos + 1;
118            matched = true;
119            break;
120        }
121
122        if !matched {
123            i += 1;
124        }
125    }
126
127    diags
128}
129
/// Returns `true` when `ch` can be part of an identifier: an ASCII letter,
/// an ASCII digit, or `_`. Every non-ASCII byte yields `false`.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
134
/// Builds a per-byte mask of the regions the scanner must ignore.
///
/// In the returned vector (length `len`), `skip[i] == true` means byte `i`
/// belongs to one of the following, delimiters included:
///
/// * a single-quoted string `'...'`, where `''` is an escaped quote,
/// * a double-quoted identifier `"..."`, where `""` is an escaped quote,
/// * a block comment `/* ... */` (not nested),
/// * a line comment from `--` up to, but not including, the next `\n`.
///
/// A region that is never closed extends to the end of the scanned input.
/// Only the first `len` bytes of `bytes` are examined; the caller in this
/// file passes `bytes.len()`.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    /// Exclusive end of the quoted region whose opening quote sits at `open`.
    /// The quote byte at `open` is also the closing delimiter; a doubled
    /// delimiter inside the region is consumed as an escape.
    fn quoted_end(bytes: &[u8], len: usize, open: usize) -> usize {
        let quote = bytes[open];
        let mut j = open + 1;
        while j < len {
            if bytes[j] != quote {
                j += 1;
            } else if j + 1 < len && bytes[j + 1] == quote {
                j += 2;
            } else {
                return j + 1;
            }
        }
        len
    }

    /// Exclusive end of the block comment whose `/*` starts at `open`.
    /// The search begins after the opener, so `/*/` does not close itself.
    fn block_comment_end(bytes: &[u8], len: usize, open: usize) -> usize {
        let body = open + 2;
        bytes[body..len]
            .windows(2)
            .position(|pair| pair == b"*/")
            .map_or(len, |off| body + off + 2)
    }

    /// Exclusive end of the line comment whose `--` starts at `open`;
    /// the terminating newline itself is left outside the region.
    fn line_comment_end(bytes: &[u8], len: usize, open: usize) -> usize {
        let body = open + 2;
        bytes[body..len]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(len, |off| body + off)
    }

    let mut skip = vec![false; len];
    let mut i = 0;

    while i < len {
        let next = if i + 1 < len { Some(bytes[i + 1]) } else { None };

        let end = match (bytes[i], next) {
            (b'\'' | b'"', _) => quoted_end(bytes, len, i),
            (b'/', Some(b'*')) => block_comment_end(bytes, len, i),
            (b'-', Some(b'-')) => line_comment_end(bytes, len, i),
            _ => {
                i += 1;
                continue;
            }
        };

        // Every region is at least one byte long, so the scan always advances.
        skip[i..end].fill(true);
        i = end;
    }

    skip
}