Skip to main content

sqrust_rules/ambiguous/
group_by_position.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3use crate::capitalisation::SkipMap;
4
5pub struct GroupByPosition;
6
7impl Rule for GroupByPosition {
8    fn name(&self) -> &'static str {
9        "Ambiguous/GroupByPosition"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        let source = &ctx.source;
14        let bytes = source.as_bytes();
15        let len = bytes.len();
16        let skip_map = SkipMap::build(source);
17
18        let mut diags = Vec::new();
19        let mut i = 0;
20
21        while i < len {
22            // Skip non-code positions (strings, comments).
23            if !skip_map.is_code(i) {
24                i += 1;
25                continue;
26            }
27
28            // Try to match the keyword GROUP at a word boundary.
29            if let Some(after_group) = match_keyword(bytes, &skip_map, i, b"GROUP") {
30                // Skip whitespace/newlines between GROUP and BY.
31                let after_ws = skip_whitespace(bytes, after_group);
32
33                // Try to match BY.
34                if let Some(after_by) = match_keyword(bytes, &skip_map, after_ws, b"BY") {
35                    // We found GROUP BY — scan the comma-separated list.
36                    scan_positional_list(
37                        bytes,
38                        &skip_map,
39                        source,
40                        after_by,
41                        self.name(),
42                        "Avoid positional GROUP BY references; use column names",
43                        GROUP_BY_STOP_KEYWORDS,
44                        &mut diags,
45                    );
46                    i = after_by;
47                    continue;
48                }
49            }
50
51            i += 1;
52        }
53
54        diags
55    }
56}
57
/// Keywords that end a GROUP BY item list.
///
/// The list scanner stops as soon as one of these appears as a code word
/// (matched case-insensitively on word boundaries via `match_keyword`).
const GROUP_BY_STOP_KEYWORDS: &[&[u8]] = &[
    b"HAVING",
    b"ORDER",
    b"LIMIT",
    b"UNION",
    b"INTERSECT",
    b"EXCEPT",
];

63/// Keywords that terminate an ORDER BY item list.
64pub(super) const ORDER_BY_STOP_KEYWORDS: &[&[u8]] = &[
65    b"LIMIT", b"UNION", b"INTERSECT", b"EXCEPT",
66];
67
68/// Attempts to match `keyword` (case-insensitive, ASCII) at position `i` in `bytes`,
69/// requiring:
70/// - `i` is at a code position
71/// - the character before `i` is not a word character (word boundary start)
72/// - the character after the keyword is not a word character (word boundary end)
73///
74/// Returns the byte offset just after the keyword if matched, or None.
75pub(super) fn match_keyword(
76    bytes: &[u8],
77    skip_map: &SkipMap,
78    i: usize,
79    keyword: &[u8],
80) -> Option<usize> {
81    let len = bytes.len();
82    let klen = keyword.len();
83
84    if i + klen > len {
85        return None;
86    }
87
88    // Must start at a code position.
89    if !skip_map.is_code(i) {
90        return None;
91    }
92
93    // Word boundary before: not preceded by a word character.
94    if i > 0 && is_word_char(bytes[i - 1]) {
95        return None;
96    }
97
98    // Case-insensitive match; every keyword byte must be code.
99    for k in 0..klen {
100        if !bytes[i + k].eq_ignore_ascii_case(&keyword[k]) {
101            return None;
102        }
103        if !skip_map.is_code(i + k) {
104            return None;
105        }
106    }
107
108    // Word boundary after: not followed by a word character.
109    let end = i + klen;
110    if end < len && is_word_char(bytes[end]) {
111        return None;
112    }
113
114    Some(end)
115}
116
117/// Skips ASCII whitespace (space, tab, newline, carriage return).
118pub(super) fn skip_whitespace(bytes: &[u8], mut i: usize) -> usize {
119    while i < bytes.len()
120        && (bytes[i] == b' '
121            || bytes[i] == b'\t'
122            || bytes[i] == b'\n'
123            || bytes[i] == b'\r')
124    {
125        i += 1;
126    }
127    i
128}
129
/// Reports whether `ch` can be part of an identifier-like word: an ASCII
/// letter, an ASCII digit, or `_`.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_')
}

/// Maps byte `offset` in `source` to a 1-based `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line, not in
/// characters. `offset` must lie on a UTF-8 character boundary, because the
/// function slices `source[..offset]`.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = prefix.rfind('\n').map_or(0, |nl| nl + 1);
    (newlines + 1, offset - line_start + 1)
}

/// Reports whether `slice` is an unsigned integer literal: at least one byte,
/// and every byte an ASCII digit (no sign, no decimal point, no spaces).
fn is_bare_integer(slice: &[u8]) -> bool {
    match slice {
        [] => false,
        digits => digits.iter().all(u8::is_ascii_digit),
    }
}

/// Drops leading spaces, tabs, `\n` and `\r` from `bytes`.
///
/// Other ASCII whitespace, such as form feed, is kept.
fn trim_leading(bytes: &[u8]) -> &[u8] {
    let mut rest = bytes;
    while let [b' ' | b'\t' | b'\n' | b'\r', tail @ ..] = rest {
        rest = tail;
    }
    rest
}

/// Drops trailing spaces, tabs, `\n` and `\r` from `bytes`.
///
/// Other ASCII whitespace, such as form feed, is kept.
fn trim_trailing(bytes: &[u8]) -> &[u8] {
    let mut rest = bytes;
    while let [head @ .., b' ' | b'\t' | b'\n' | b'\r'] = rest {
        rest = head;
    }
    rest
}

168/// Strips a trailing `ASC` or `DESC` token (case-insensitive) from item bytes.
169/// Only strips if the last word is exactly ASC or DESC.
170fn strip_trailing_direction(bytes: &[u8]) -> &[u8] {
171    let trimmed = trim_trailing(bytes);
172    let word_end = trimmed.len();
173    if word_end == 0 {
174        return trimmed;
175    }
176    // Walk backwards over word characters.
177    let mut word_start = word_end;
178    while word_start > 0 && is_word_char(trimmed[word_start - 1]) {
179        word_start -= 1;
180    }
181    let last_word = &trimmed[word_start..word_end];
182    if (last_word.eq_ignore_ascii_case(b"ASC") || last_word.eq_ignore_ascii_case(b"DESC"))
183        && word_start > 0
184    {
185        trim_trailing(&trimmed[..word_start])
186    } else {
187        trimmed
188    }
189}
190
/// Returns the absolute offset of the first `first_byte` in
/// `bytes[search_start..]`.
///
/// If the byte does not occur, or `search_start` is past the end of `bytes`,
/// `search_start` itself is returned as a fallback.
fn find_first_byte_offset(bytes: &[u8], search_start: usize, first_byte: u8) -> usize {
    bytes
        .get(search_start..)
        .and_then(|tail| tail.iter().position(|&b| b == first_byte))
        .map_or(search_start, |rel| search_start + rel)
}

204/// Scans the comma-separated expression list that follows GROUP BY or ORDER BY.
205///
206/// For each item that reduces to a bare integer at a code position, emits a Diagnostic.
207///
208/// - `start`: byte offset immediately after the `BY` keyword
209/// - `rule_name`: the `&'static str` rule name for diagnostics
210/// - `message`: the violation message to embed in each Diagnostic
211/// - `stop_keywords`: keywords that terminate the clause
212/// - `diags`: output vector
213pub(super) fn scan_positional_list(
214    bytes: &[u8],
215    skip_map: &SkipMap,
216    source: &str,
217    start: usize,
218    rule_name: &'static str,
219    message: &'static str,
220    stop_keywords: &[&[u8]],
221    diags: &mut Vec<Diagnostic>,
222) {
223    let len = bytes.len();
224    let mut i = start;
225
226    'outer: loop {
227        // Skip leading whitespace before the next item.
228        while i < len
229            && (bytes[i] == b' '
230                || bytes[i] == b'\t'
231                || bytes[i] == b'\n'
232                || bytes[i] == b'\r')
233        {
234            i += 1;
235        }
236
237        if i >= len {
238            break;
239        }
240
241        // Semicolon always terminates.
242        if skip_map.is_code(i) && bytes[i] == b';' {
243            break;
244        }
245
246        // Check if we've hit a stop keyword (end of clause).
247        for &stop in stop_keywords {
248            if match_keyword(bytes, skip_map, i, stop).is_some() {
249                break 'outer;
250            }
251        }
252
253        // Record where this item begins (in absolute source offsets).
254        let item_start_abs = i;
255
256        // Collect item content until ',' or stop.
257        let mut item_end_abs = i;
258        while item_end_abs < len {
259            if !skip_map.is_code(item_end_abs) {
260                item_end_abs += 1;
261                continue;
262            }
263
264            if bytes[item_end_abs] == b',' || bytes[item_end_abs] == b';' {
265                break;
266            }
267
268            let mut stopped = false;
269            for &stop in stop_keywords {
270                if match_keyword(bytes, skip_map, item_end_abs, stop).is_some() {
271                    stopped = true;
272                    break;
273                }
274            }
275            if stopped {
276                break;
277            }
278
279            item_end_abs += 1;
280        }
281
282        // Slice the item from the source bytes.
283        let item_slice = &bytes[item_start_abs..item_end_abs];
284
285        // Trim, strip direction keyword, trim again.
286        let trimmed = trim_leading(item_slice);
287        let trimmed = strip_trailing_direction(trimmed);
288        let trimmed = trim_trailing(trimmed);
289
290        if !trimmed.is_empty() && is_bare_integer(trimmed) {
291            // The trimmed content starts with a digit.  Find the exact offset
292            // in the original source of that first digit.
293            let first_digit = trimmed[0];
294            let int_abs = find_first_byte_offset(bytes, item_start_abs, first_digit);
295            let (line, col) = line_col(source, int_abs);
296            diags.push(Diagnostic {
297                rule: rule_name,
298                message: message.to_string(),
299                line,
300                col,
301            });
302        }
303
304        // Advance past comma (if any) to next item.
305        if item_end_abs < len && bytes[item_end_abs] == b',' {
306            i = item_end_abs + 1;
307        } else {
308            break;
309        }
310    }
311}