Skip to main content

sqrust_rules/structure/
cross_apply.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule that reports every `CROSS APPLY` and `OUTER APPLY` join.
///
/// `APPLY` is vendor-specific table-valued-function join syntax. It is not
/// part of standard SQL and most analytical databases do not accept it; this
/// rule's diagnostics describe it as SQL Server / PostgreSQL-specific and
/// point at `LATERAL` joins as the portable spelling.
// NOTE(review): PostgreSQL itself appears to spell this construct `LATERAL`
// rather than `APPLY` — confirm the dialect list used in the description and
// in the diagnostic messages.
pub struct CrossApply;
7
8impl Rule for CrossApply {
9    fn name(&self) -> &'static str {
10        "Structure/CrossApply"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        find_violations(&ctx.source, self.name())
15    }
16}
17
18fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
19    let bytes = source.as_bytes();
20    let len = bytes.len();
21
22    if len == 0 {
23        return Vec::new();
24    }
25
26    let skip = build_skip_set(bytes, len);
27    let mut diags = Vec::new();
28    let mut i = 0;
29
30    while i < len {
31        if skip[i] {
32            i += 1;
33            continue;
34        }
35
36        // Word boundary before current position.
37        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
38        if !before_ok {
39            i += 1;
40            continue;
41        }
42
43        // Try to match CROSS APPLY.
44        if let Some(end) = match_two_word_keyword(bytes, len, &skip, i, b"CROSS", b"APPLY") {
45            let (line, col) = offset_to_line_col(source, i);
46            diags.push(Diagnostic {
47                rule: rule_name,
48                message: "CROSS APPLY is SQL Server/PostgreSQL-specific; use a LATERAL JOIN for standard SQL".to_string(),
49                line,
50                col,
51            });
52            i = end;
53            continue;
54        }
55
56        // Try to match OUTER APPLY.
57        if let Some(end) = match_two_word_keyword(bytes, len, &skip, i, b"OUTER", b"APPLY") {
58            let (line, col) = offset_to_line_col(source, i);
59            diags.push(Diagnostic {
60                rule: rule_name,
61                message: "OUTER APPLY is SQL Server/PostgreSQL-specific; use a LEFT JOIN LATERAL for standard SQL".to_string(),
62                line,
63                col,
64            });
65            i = end;
66            continue;
67        }
68
69        i += 1;
70    }
71
72    diags
73}
74
75/// Try to match two keywords (word1 followed by optional whitespace then word2)
76/// starting at `start`. Returns `Some(end_offset)` if matched, `None` otherwise.
77/// `end_offset` is one past the last character of the second keyword.
78fn match_two_word_keyword(
79    bytes: &[u8],
80    len: usize,
81    skip: &[bool],
82    start: usize,
83    word1: &[u8],
84    word2: &[u8],
85) -> Option<usize> {
86    let w1_len = word1.len();
87    let w2_len = word2.len();
88
89    if start + w1_len > len {
90        return None;
91    }
92
93    // Match word1 (case-insensitive).
94    let matches_w1 = bytes[start..start + w1_len]
95        .iter()
96        .zip(word1.iter())
97        .all(|(a, b)| a.eq_ignore_ascii_case(b));
98
99    if !matches_w1 {
100        return None;
101    }
102
103    // Ensure none of word1's bytes are skipped.
104    if (start..start + w1_len).any(|k| skip[k]) {
105        return None;
106    }
107
108    // After word1 must be a word boundary (not a word char).
109    let after_w1 = start + w1_len;
110    if after_w1 < len && is_word_char(bytes[after_w1]) {
111        return None;
112    }
113
114    // Skip whitespace between word1 and word2.
115    let mut j = after_w1;
116    while j < len && is_whitespace(bytes[j]) {
117        j += 1;
118    }
119
120    if j + w2_len > len {
121        return None;
122    }
123
124    // Match word2 (case-insensitive).
125    let matches_w2 = bytes[j..j + w2_len]
126        .iter()
127        .zip(word2.iter())
128        .all(|(a, b)| a.eq_ignore_ascii_case(b));
129
130    if !matches_w2 {
131        return None;
132    }
133
134    // Ensure none of word2's bytes are skipped.
135    if (j..j + w2_len).any(|k| skip[k]) {
136        return None;
137    }
138
139    // After word2 must be a word boundary.
140    let after_w2 = j + w2_len;
141    if after_w2 < len && is_word_char(bytes[after_w2]) {
142        return None;
143    }
144
145    Some(after_w2)
146}
147
/// Returns `true` for bytes that can be part of a word: ASCII letters,
/// ASCII digits and underscore. Every other byte (including all non-ASCII
/// bytes) is treated as a word boundary.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
152
/// Returns `true` for the four whitespace bytes recognised between keywords:
/// space, horizontal tab, line feed and carriage return.
#[inline]
fn is_whitespace(ch: u8) -> bool {
    matches!(ch, b' ' | b'\t' | b'\n' | b'\r')
}
157
/// Translates a byte `offset` into `source` to a 1-indexed `(line, col)` pair.
///
/// Lines are separated by `'\n'`. The column is measured in bytes from the
/// start of the line, so a multi-byte character before `offset` advances the
/// column by its encoded length. `offset` must fall on a character boundary
/// within `source`; the slice below panics otherwise.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    // Byte index at which the current line begins.
    let line_start = match prefix.rfind('\n') {
        Some(newline_at) => newline_at + 1,
        None => 0,
    };

    let newlines_seen = prefix.bytes().filter(|&b| b == b'\n').count();

    (newlines_seen + 1, offset - line_start + 1)
}
165
/// Computes, for each of the first `len` bytes of `bytes`, whether that byte
/// should be ignored by keyword matching.
///
/// `skip[i]` is `true` when byte `i` belongs to one of:
/// - a single-quoted string (`'...'`, with `''` as an escaped quote),
/// - a double-quoted identifier (`"..."`, with `""` as an escaped quote),
/// - a block comment (`/* ... */`, not nested),
/// - a line comment (`--` up to, but not including, the newline).
///
/// Delimiters themselves are marked as skipped. An unterminated string or
/// comment extends to the end of the input.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    let mut skip = vec![false; len];
    let mut pos = 0;

    while pos < len {
        let current = bytes[pos];
        let next = if pos + 1 < len {
            Some(bytes[pos + 1])
        } else {
            None
        };

        match (current, next) {
            // Quoted string or identifier; a doubled quote is an escape.
            (quote @ (b'\'' | b'"'), _) => {
                skip[pos] = true;
                pos += 1;
                while pos < len {
                    skip[pos] = true;
                    if bytes[pos] != quote {
                        pos += 1;
                        continue;
                    }
                    if pos + 1 < len && bytes[pos + 1] == quote {
                        // Escaped quote: consume both bytes and stay inside.
                        skip[pos + 1] = true;
                        pos += 2;
                    } else {
                        // Closing quote.
                        pos += 1;
                        break;
                    }
                }
            }

            // Block comment: runs until the first `*/`.
            (b'/', Some(b'*')) => {
                skip[pos] = true;
                skip[pos + 1] = true;
                pos += 2;
                while pos < len {
                    skip[pos] = true;
                    if bytes[pos] == b'*' && pos + 1 < len && bytes[pos + 1] == b'/' {
                        skip[pos + 1] = true;
                        pos += 2;
                        break;
                    }
                    pos += 1;
                }
            }

            // Line comment: runs to the end of the line, newline excluded.
            (b'-', Some(b'-')) => {
                skip[pos] = true;
                skip[pos + 1] = true;
                pos += 2;
                while pos < len && bytes[pos] != b'\n' {
                    skip[pos] = true;
                    pos += 1;
                }
            }

            _ => pos += 1,
        }
    }

    skip
}