Skip to main content

sqrust_rules/capitalisation/
keywords.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3use super::{is_word_char, SkipMap};
4
/// SQL keywords that this rule expects to appear in UPPERCASE.
///
/// Every entry is stored in its uppercase form; source words are compared
/// against this table without regard to ASCII case.
const KEYWORDS: &[&str] = &[
    // Core query clauses
    "SELECT", "FROM", "WHERE",
    // Joins
    "JOIN", "LEFT", "RIGHT", "INNER", "OUTER", "FULL", "CROSS", "ON",
    // Boolean operators and predicates
    "AND", "OR", "NOT", "IN", "LIKE", "IS", "NULL",
    // Aliasing, set operations and result shaping
    "AS", "BY", "HAVING", "UNION", "ALL", "DISTINCT", "LIMIT", "OFFSET", "WITH",
    // CASE expressions
    "CASE", "WHEN", "THEN", "ELSE", "END",
    // Grouping and ordering
    "GROUP", "ORDER", "ASC", "DESC",
    // Data modification and schema statements
    "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TABLE", "INDEX", "VIEW", "SET",
    "INTO", "VALUES",
    // Remaining predicates, window functions and modifiers
    "EXISTS", "BETWEEN", "OVER", "PARTITION", "USING", "NATURAL", "LATERAL", "RECURSIVE",
    "RETURNING", "EXCEPT", "INTERSECT", "FILTER",
];
15
16pub struct Keywords;
17
18impl Rule for Keywords {
19    fn name(&self) -> &'static str {
20        "Capitalisation/Keywords"
21    }
22
23    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
24        let source = &ctx.source;
25        let bytes = source.as_bytes();
26        let len = bytes.len();
27        let skip_map = SkipMap::build(source);
28
29        let mut diags = Vec::new();
30
31        // Walk every byte; when we find the start of a word that is code, try
32        // to match it against the keyword list.
33        let mut i = 0;
34        while i < len {
35            // Only enter keyword detection on a word-start that is code and is
36            // not preceded by a word character.
37            if skip_map.is_code(i) && is_word_char(bytes[i]) {
38                let preceded_by_word = i > 0 && is_word_char(bytes[i - 1]);
39                if !preceded_by_word {
40                    // Find end of this word token
41                    let word_start = i;
42                    let mut j = i;
43                    while j < len && is_word_char(bytes[j]) {
44                        j += 1;
45                    }
46                    let word_end = j; // exclusive
47
48                    // The whole word must be in code (no skip bytes inside it)
49                    let all_code = (word_start..word_end).all(|k| skip_map.is_code(k));
50
51                    if all_code {
52                        let word_bytes = &bytes[word_start..word_end];
53
54                        // Check against keyword list (case-insensitive)
55                        for kw in KEYWORDS {
56                            if kw.len() == word_bytes.len()
57                                && kw
58                                    .bytes()
59                                    .zip(word_bytes.iter())
60                                    .all(|(a, &b)| a.eq_ignore_ascii_case(&b))
61                            {
62                                // It matches a keyword — is it already uppercase?
63                                let already_upper = word_bytes
64                                    .iter()
65                                    .all(|b| b.is_ascii_uppercase() || !b.is_ascii_alphabetic());
66                                if !already_upper {
67                                    // Compute line + col (1-indexed)
68                                    let (line, col) = line_col(source, word_start);
69                                    let found =
70                                        std::str::from_utf8(word_bytes).unwrap_or("?").to_string();
71                                    let upper = found.to_uppercase();
72                                    diags.push(Diagnostic {
73                                        rule: self.name(),
74                                        message: format!(
75                                            "Keyword '{}' should be UPPERCASE (use '{}')",
76                                            found, upper
77                                        ),
78                                        line,
79                                        col,
80                                    });
81                                }
82                                break;
83                            }
84                        }
85                    }
86
87                    i = word_end;
88                    continue;
89                }
90            }
91            i += 1;
92        }
93
94        diags
95    }
96}
97
/// Maps a byte `offset` into `source` to a 1-indexed `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line containing
/// `offset`. Slicing panics if `offset` is past the end of `source` or does
/// not fall on a UTF-8 character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];

    // Each newline before `offset` pushes the position one line down.
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();

    // Byte index at which the current line begins.
    let line_start = match prefix.rfind('\n') {
        Some(newline_at) => newline_at + 1,
        None => 0,
    };

    (newlines + 1, offset - line_start + 1)
}