// sqrust_rules/structure/union_all.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3use crate::capitalisation::{is_word_char, SkipMap};
4
5pub struct UnionAll;
6
7impl Rule for UnionAll {
8    fn name(&self) -> &'static str {
9        "Structure/UnionAll"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        let source = &ctx.source;
14        let bytes = source.as_bytes();
15        let len = bytes.len();
16        let skip_map = SkipMap::build(source);
17
18        let mut diags = Vec::new();
19        let mut i = 0;
20
21        while i < len {
22            // Skip non-code bytes (strings, comments, quoted identifiers).
23            if !skip_map.is_code(i) {
24                i += 1;
25                continue;
26            }
27
28            // Look for a word boundary start.
29            if !is_word_char(bytes[i]) || (i > 0 && is_word_char(bytes[i - 1])) {
30                i += 1;
31                continue;
32            }
33
34            // Identify the end of this word token.
35            let word_start = i;
36            let mut j = i;
37            while j < len && is_word_char(bytes[j]) {
38                j += 1;
39            }
40            let word_end = j;
41
42            // Ensure the entire word is in code (no skip inside the word).
43            let all_code = (word_start..word_end).all(|k| skip_map.is_code(k));
44
45            if all_code {
46                let word_bytes = &bytes[word_start..word_end];
47
48                // Case-insensitive match for "UNION".
49                let is_union = word_bytes.len() == 5
50                    && b"UNION"
51                        .iter()
52                        .zip(word_bytes.iter())
53                        .all(|(a, b)| a.eq_ignore_ascii_case(b));
54
55                if is_union {
56                    // Skip whitespace (including newlines) after UNION.
57                    let mut k = word_end;
58                    while k < len && (bytes[k] == b' ' || bytes[k] == b'\t' || bytes[k] == b'\n' || bytes[k] == b'\r') {
59                        k += 1;
60                    }
61
62                    // Read the next word.
63                    let next_word_start = k;
64                    while k < len && is_word_char(bytes[k]) {
65                        k += 1;
66                    }
67                    let next_word_end = k;
68
69                    let next_word = &bytes[next_word_start..next_word_end];
70
71                    let is_all = next_word.len() == 3
72                        && b"ALL"
73                            .iter()
74                            .zip(next_word.iter())
75                            .all(|(a, b)| a.eq_ignore_ascii_case(b));
76
77                    let is_distinct = next_word.len() == 8
78                        && b"DISTINCT"
79                            .iter()
80                            .zip(next_word.iter())
81                            .all(|(a, b)| a.eq_ignore_ascii_case(b));
82
83                    if !is_all && !is_distinct {
84                        let (line, col) = line_col(source, word_start);
85                        diags.push(Diagnostic {
86                            rule: self.name(),
87                            message: "Prefer UNION ALL or UNION DISTINCT over bare UNION to make intent explicit".to_string(),
88                            line,
89                            col,
90                        });
91                    }
92                }
93            }
94
95            i = word_end;
96        }
97
98        diags
99    }
100}
101
/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
///
/// `line` is one plus the number of `\n` bytes before `offset`. `col` is one
/// plus the number of *characters* (Unicode scalar values, not bytes) between
/// the start of that line and `offset`. Multi-byte UTF-8 text earlier on the
/// same line therefore does not inflate the column.
///
/// # Panics
///
/// Panics if `offset` is greater than `source.len()` or does not lie on a
/// `char` boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    let col = before[line_start..].chars().count() + 1;
    (line, col)
}