Skip to main content

sqrust_rules/capitalisation/
functions.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3use super::{is_word_char, SkipMap};
4
/// Names of the SQL built-in functions this rule knows about.
///
/// Every entry is stored in its canonical UPPERCASE spelling. The rule only
/// consults this table for a word that is immediately followed by `(`, and
/// compares ASCII-case-insensitively, so the order of entries does not
/// matter; the grouping below is purely for readability.
const FUNCTIONS: &[&str] = &[
    // Aggregates
    "COUNT",
    "SUM",
    "MIN",
    "MAX",
    "AVG",
    "ARRAY_AGG",
    "STRING_AGG",
    "BOOL_AND",
    "BOOL_OR",
    "VARIANCE",
    "STDDEV",
    // NULL handling and type conversion
    "COALESCE",
    "NULLIF",
    "CAST",
    "CONVERT",
    // Strings
    "UPPER",
    "LOWER",
    "LENGTH",
    "TRIM",
    "LTRIM",
    "RTRIM",
    "SUBSTR",
    "SUBSTRING",
    "REPLACE",
    "CONCAT",
    // Dates and times
    "NOW",
    "CURRENT_DATE",
    "CURRENT_TIMESTAMP",
    "DATE_TRUNC",
    "EXTRACT",
    // Numeric
    "ROUND",
    "FLOOR",
    "CEIL",
    "ABS",
    "MOD",
    "POWER",
    "SQRT",
    // Window / ranking
    "ROW_NUMBER",
    "RANK",
    "DENSE_RANK",
    "PERCENT_RANK",
    "CUME_DIST",
    "NTILE",
    "LAG",
    "LEAD",
    "FIRST_VALUE",
    "LAST_VALUE",
    "NTH_VALUE",
    // Miscellaneous
    "UNNEST",
    "GREATEST",
    "LEAST",
];
60
61pub struct Functions;
62
63impl Rule for Functions {
64    fn name(&self) -> &'static str {
65        "Capitalisation/Functions"
66    }
67
68    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
69        let source = &ctx.source;
70        let bytes = source.as_bytes();
71        let len = bytes.len();
72        let skip_map = SkipMap::build(source);
73
74        let mut diags = Vec::new();
75
76        let mut i = 0;
77        while i < len {
78            // Enter token detection on a word-start in code that is not
79            // preceded by a word character.
80            if skip_map.is_code(i) && is_word_char(bytes[i]) {
81                let preceded_by_word = i > 0 && is_word_char(bytes[i - 1]);
82                if !preceded_by_word {
83                    // Find end of this word token
84                    let word_start = i;
85                    let mut j = i;
86                    while j < len && is_word_char(bytes[j]) {
87                        j += 1;
88                    }
89                    let word_end = j; // exclusive
90
91                    // The whole word must be in code
92                    let all_code = (word_start..word_end).all(|k| skip_map.is_code(k));
93
94                    if all_code {
95                        // A function call requires an immediate `(` right after the word
96                        // (word_end must point to `(` in code).
97                        let followed_by_paren = word_end < len
98                            && bytes[word_end] == b'('
99                            && skip_map.is_code(word_end);
100
101                        if followed_by_paren {
102                            let word_bytes = &bytes[word_start..word_end];
103
104                            for func in FUNCTIONS {
105                                if func.len() == word_bytes.len()
106                                    && func
107                                        .bytes()
108                                        .zip(word_bytes.iter())
109                                        .all(|(a, &b)| a.eq_ignore_ascii_case(&b))
110                                {
111                                    // Matched — is it already uppercase?
112                                    let already_upper = word_bytes.iter().all(|b| {
113                                        b.is_ascii_uppercase() || !b.is_ascii_alphabetic()
114                                    });
115                                    if !already_upper {
116                                        let (line, col) = line_col(source, word_start);
117                                        let found = std::str::from_utf8(word_bytes)
118                                            .unwrap_or("?")
119                                            .to_string();
120                                        let upper = found.to_uppercase();
121                                        diags.push(Diagnostic {
122                                            rule: self.name(),
123                                            message: format!(
124                                                "Function '{}' should be UPPERCASE (use '{}')",
125                                                found, upper
126                                            ),
127                                            line,
128                                            col,
129                                        });
130                                    }
131                                    break;
132                                }
133                            }
134                        }
135                    }
136
137                    i = word_end;
138                    continue;
139                }
140            }
141            i += 1;
142        }
143
144        diags
145    }
146}
147
/// Translates a byte offset into `source` to a 1-indexed `(line, column)` pair.
///
/// The line is one more than the number of `\n` characters before `offset`.
/// The column is measured in bytes from the start of that line, so a
/// multi-byte character earlier on the line advances it by more than one.
///
/// # Panics
///
/// Panics if `offset` is greater than `source.len()` or does not fall on a
/// UTF-8 character boundary, because the prefix is obtained by slicing.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = prefix.matches('\n').count() + 1;
    let col = match prefix.rfind('\n') {
        // Distance from the last newline; the byte right after it is column 1.
        Some(newline) => offset - newline,
        // No newline yet: still on the first line.
        None => offset + 1,
    };
    (line, col)
}