Skip to main content

sqrust_rules/ambiguous/
inconsistent_column_reference.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3use crate::capitalisation::SkipMap;
4use super::group_by_position::{match_keyword, skip_whitespace};
5
6pub struct InconsistentColumnReference;
7
8impl Rule for InconsistentColumnReference {
9    fn name(&self) -> &'static str {
10        "Ambiguous/InconsistentColumnReference"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        let source = &ctx.source;
15        let bytes = source.as_bytes();
16        let len = bytes.len();
17        let skip_map = SkipMap::build(source);
18        let mut diags = Vec::new();
19        let mut i = 0;
20
21        while i < len {
22            if !skip_map.is_code(i) {
23                i += 1;
24                continue;
25            }
26
27            // Check ORDER BY.
28            if let Some(after_order) = match_keyword(bytes, &skip_map, i, b"ORDER") {
29                let after_ws = skip_whitespace(bytes, after_order);
30                if let Some(after_by) = match_keyword(bytes, &skip_map, after_ws, b"BY") {
31                    if has_mixed_refs(bytes, &skip_map, after_by, ORDER_BY_STOP) {
32                        let (line, col) = offset_to_line_col(source, i);
33                        diags.push(Diagnostic {
34                            rule: "Ambiguous/InconsistentColumnReference",
35                            message: "ORDER BY mixes positional column references (e.g. 1) with named references; use one style consistently".to_string(),
36                            line,
37                            col,
38                        });
39                    }
40                    i = after_by;
41                    continue;
42                }
43            }
44
45            // Check GROUP BY.
46            if let Some(after_group) = match_keyword(bytes, &skip_map, i, b"GROUP") {
47                let after_ws = skip_whitespace(bytes, after_group);
48                if let Some(after_by) = match_keyword(bytes, &skip_map, after_ws, b"BY") {
49                    if has_mixed_refs(bytes, &skip_map, after_by, GROUP_BY_STOP) {
50                        let (line, col) = offset_to_line_col(source, i);
51                        diags.push(Diagnostic {
52                            rule: "Ambiguous/InconsistentColumnReference",
53                            message: "GROUP BY mixes positional column references (e.g. 1) with named references; use one style consistently".to_string(),
54                            line,
55                            col,
56                        });
57                    }
58                    i = after_by;
59                    continue;
60                }
61            }
62
63            i += 1;
64        }
65
66        diags
67    }
68}
69
/// Keywords that end the item list of an ORDER BY clause.
const ORDER_BY_STOP: &[&[u8]] = &[
    // Row limiting.
    b"LIMIT",
    // Set operators.
    b"UNION",
    b"INTERSECT",
    b"EXCEPT",
    // Further row limiting / paging.
    b"FETCH",
    b"OFFSET",
    // Trailing `FOR ...` clauses.
    b"FOR",
];
74
/// Keywords that end the item list of a GROUP BY clause.
///
/// Besides HAVING / ORDER / LIMIT and the set operators, this covers clauses
/// that may directly follow the grouping list: `WINDOW`, `QUALIFY`,
/// `WITH ROLLUP` / `WITH CUBE` / `WITH TOTALS`, and `OFFSET` / `FETCH` in
/// queries that have no ORDER BY. Without these, the trailing clause text is
/// absorbed into the last item (e.g. `2 WINDOW w AS (...)`), which is then
/// classified as a named expression and reported as a mix.
const GROUP_BY_STOP: &[&[u8]] = &[
    b"HAVING",
    b"WINDOW",
    b"QUALIFY",
    b"WITH",
    b"ORDER",
    b"LIMIT",
    b"OFFSET",
    b"FETCH",
    b"UNION",
    b"INTERSECT",
    b"EXCEPT",
];
79
/// Maps a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line, not in
/// characters. Slicing panics if `offset` is past the end of `source` or does
/// not fall on a UTF-8 character boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    match prefix.rfind('\n') {
        // The column is the distance from the last newline; the line number
        // is one more than the number of newlines seen so far.
        Some(newline_at) => (prefix.matches('\n').count() + 1, offset - newline_at),
        // No newline before the offset: still on the first line.
        None => (1, offset + 1),
    }
}
87
/// Reports whether `ch` is one of the ASCII digits `0` through `9`.
#[inline]
fn is_digit(ch: u8) -> bool {
    matches!(ch, b'0'..=b'9')
}
93
/// Reports whether `ch` can begin a named reference: an ASCII letter, an
/// underscore, or an opening identifier quote (`"` or a backtick).
///
/// Digits are deliberately not accepted here.
#[inline]
fn is_word_start(ch: u8) -> bool {
    matches!(ch, b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'"' | b'`')
}
99
100/// Scans the comma-separated expression list that follows GROUP BY or ORDER BY.
101/// Returns true if the clause mixes positional (integer) and named references.
102fn has_mixed_refs(
103    bytes: &[u8],
104    skip_map: &SkipMap,
105    start: usize,
106    stop_keywords: &[&[u8]],
107) -> bool {
108    let len = bytes.len();
109    let mut i = start;
110    let mut has_positional = false;
111    let mut has_named = false;
112
113    'outer: loop {
114        // Skip leading whitespace.
115        while i < len
116            && (bytes[i] == b' '
117                || bytes[i] == b'\t'
118                || bytes[i] == b'\n'
119                || bytes[i] == b'\r')
120        {
121            i += 1;
122        }
123
124        if i >= len {
125            break;
126        }
127
128        // Semicolon or closing paren terminates.
129        if skip_map.is_code(i) && (bytes[i] == b';' || bytes[i] == b')') {
130            break;
131        }
132
133        // Check stop keywords.
134        for &stop in stop_keywords {
135            if match_keyword(bytes, skip_map, i, stop).is_some() {
136                break 'outer;
137            }
138        }
139
140        // Find the first significant code token in this item.
141        // Collect item until next comma at depth 0 or stop.
142        let item_start = i;
143        let mut item_end = i;
144        let mut depth = 0usize;
145
146        while item_end < len {
147            if !skip_map.is_code(item_end) {
148                item_end += 1;
149                continue;
150            }
151
152            let b = bytes[item_end];
153
154            if b == b'(' {
155                depth += 1;
156                item_end += 1;
157                continue;
158            }
159            if b == b')' {
160                if depth == 0 {
161                    break;
162                }
163                depth -= 1;
164                item_end += 1;
165                continue;
166            }
167            if depth == 0 {
168                if b == b',' || b == b';' {
169                    break;
170                }
171                let mut stopped = false;
172                for &stop in stop_keywords {
173                    if match_keyword(bytes, skip_map, item_end, stop).is_some() {
174                        stopped = true;
175                        break;
176                    }
177                }
178                if stopped {
179                    break;
180                }
181            }
182
183            item_end += 1;
184        }
185
186        // Inspect the first significant code token in this item.
187        let mut j = item_start;
188        // Skip leading whitespace inside item.
189        while j < item_end
190            && (bytes[j] == b' '
191                || bytes[j] == b'\t'
192                || bytes[j] == b'\n'
193                || bytes[j] == b'\r')
194        {
195            j += 1;
196        }
197
198        if j < item_end && skip_map.is_code(j) {
199            let ch = bytes[j];
200            if is_digit(ch) {
201                // Check it's a pure integer token (all digits, not e.g. 1+expr).
202                let mut k = j;
203                while k < item_end && skip_map.is_code(k) && bytes[k].is_ascii_digit() {
204                    k += 1;
205                }
206                // After digits there should be whitespace, comma or end-of-item for a positional ref.
207                let next_code = {
208                    let mut n = k;
209                    while n < item_end
210                        && (bytes[n] == b' '
211                            || bytes[n] == b'\t'
212                            || bytes[n] == b'\n'
213                            || bytes[n] == b'\r')
214                    {
215                        n += 1;
216                    }
217                    n
218                };
219                // If after the integer there's nothing meaningful in this item (possibly ASC/DESC), it's positional.
220                let after_int_word: &[u8] = if next_code < item_end {
221                    let word_start = next_code;
222                    let mut word_end = next_code;
223                    while word_end < item_end && skip_map.is_code(word_end) && (bytes[word_end].is_ascii_alphanumeric() || bytes[word_end] == b'_') {
224                        word_end += 1;
225                    }
226                    &bytes[word_start..word_end]
227                } else {
228                    &[]
229                };
230                // Only count as positional if the token after the digits is ASC, DESC, NULLS, or end.
231                if after_int_word.is_empty()
232                    || after_int_word.eq_ignore_ascii_case(b"ASC")
233                    || after_int_word.eq_ignore_ascii_case(b"DESC")
234                    || after_int_word.eq_ignore_ascii_case(b"NULLS")
235                {
236                    has_positional = true;
237                } else {
238                    // e.g. `1 + col` — treat as named/expression
239                    has_named = true;
240                }
241            } else if is_word_start(ch) {
242                has_named = true;
243            }
244        }
245
246        // Advance past comma.
247        if item_end < len && bytes[item_end] == b',' {
248            i = item_end + 1;
249        } else {
250            break;
251        }
252    }
253
254    has_positional && has_named
255}