Skip to main content

sqrust_rules/ambiguous/
date_trunc_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2
3pub struct DateTruncFunction;
4
5impl Rule for DateTruncFunction {
6    fn name(&self) -> &'static str {
7        "Ambiguous/DateTruncFunction"
8    }
9
10    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
11        find_violations(&ctx.source, self.name())
12    }
13}
14
/// Dialect-specific date functions reported by this rule.
///
/// Each entry is `(NAME, message)`. `NAME` is written in upper case but is
/// matched against the SQL text ignoring ASCII case, as a whole word followed
/// by `(`. `message` becomes the text of the resulting diagnostic.
const FUNCTIONS: &[(&str, &str)] = &[
    ("DATE_TRUNC", "DATE_TRUNC() is PostgreSQL/DuckDB-specific; syntax and supported units vary across databases"),
    ("DATE_FORMAT", "DATE_FORMAT() is MySQL-specific; use TO_CHAR() (Oracle/PostgreSQL) or FORMAT() (SQL Server) for formatting dates"),
    ("TRUNC", "TRUNC() for date truncation is Oracle/PostgreSQL-specific; use DATE_TRUNC() or DATETRUNC() depending on dialect"),
];
30
31fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
32    let bytes = source.as_bytes();
33    let len = bytes.len();
34
35    if len == 0 {
36        return Vec::new();
37    }
38
39    let skip = build_skip_set(bytes, len);
40    let mut diags = Vec::new();
41
42    for (func_name, message) in FUNCTIONS {
43        scan_for_function(source, bytes, len, &skip, func_name, message, rule_name, &mut diags);
44    }
45
46    diags.sort_by(|a, b| a.line.cmp(&b.line).then(a.col.cmp(&b.col)));
47    diags
48}
49
50/// Scan for `func_name(` (case-insensitive) with word boundaries.
51fn scan_for_function(
52    source: &str,
53    bytes: &[u8],
54    len: usize,
55    skip: &[bool],
56    func_name: &str,
57    message: &str,
58    rule_name: &'static str,
59    diags: &mut Vec<Diagnostic>,
60) {
61    let kw = func_name.as_bytes();
62    let kw_len = kw.len();
63    let mut i = 0;
64
65    while i + kw_len <= len {
66        if skip[i] {
67            i += 1;
68            continue;
69        }
70
71        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
72        if before_ok && bytes[i..i + kw_len].eq_ignore_ascii_case(kw) {
73            let after = i + kw_len;
74            // Word boundary after: next char must not be a word char
75            let after_ok = after >= len || !is_word_char(bytes[after]);
76            if after_ok {
77                // Skip optional whitespace then check for '('
78                let mut j = after;
79                while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
80                    j += 1;
81                }
82                if j < len && bytes[j] == b'(' {
83                    let (line, col) = line_col(source, i);
84                    diags.push(Diagnostic {
85                        rule: rule_name,
86                        message: message.to_string(),
87                        line,
88                        col,
89                    });
90                    i += kw_len;
91                    continue;
92                }
93            }
94        }
95
96        i += 1;
97    }
98}
99
/// Returns `true` for bytes that this rule treats as part of an unquoted
/// word: ASCII letters, ASCII digits, and `_`. Non-ASCII bytes return `false`.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
104
/// Converts a byte `offset` into a 1-based `(line, column)` pair.
///
/// Lines are delimited by `\n`. The column is the number of bytes (not
/// characters) from the start of the line, plus one. When `offset` exceeds
/// `source.len()`, newlines are counted over the whole source while the
/// column is still computed from the unclamped `offset`. `offset` must fall
/// on a UTF-8 character boundary (or past the end), otherwise slicing panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..source.len().min(offset)];
    let newline_count = prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = prefix.rfind('\n').map_or(0, |nl| nl + 1);
    (newline_count + 1, offset - line_start + 1)
}
111
/// Builds a per-byte mask of the regions that keyword scans must ignore.
///
/// `skip[i] == true` means byte `i` lies inside one of:
/// * a single-quoted string (`''` inside is an escaped quote);
/// * a double-quoted identifier (`""` inside is an escaped quote);
/// * a `/* ... */` block comment;
/// * a `-- ...` line comment.
///
/// Opening and closing delimiters are marked too, but the newline that ends
/// a line comment is not. A region with no terminator runs to the end of the
/// input. Block comments do not nest: the first `*/` closes them.
///
/// `len` is expected to equal `bytes.len()`; the mask has `len` entries.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    let mut skip = vec![false; len];
    let mut pos = 0;

    while pos < len {
        let lead = bytes[pos];
        let peek = if pos + 1 < len { Some(bytes[pos + 1]) } else { None };

        // Exclusive end of the region that opens at `pos`, if one does.
        let region_end = match (lead, peek) {
            (b'\'' | b'"', _) => quoted_region_end(bytes, len, pos, lead),
            (b'/', Some(b'*')) => block_comment_end(bytes, len, pos),
            (b'-', Some(b'-')) => line_comment_end(bytes, len, pos),
            _ => {
                pos += 1;
                continue;
            }
        };

        skip[pos..region_end].fill(true);
        pos = region_end;
    }

    skip
}

/// Returns the exclusive end of the region opened by `quote` at `start`.
/// A doubled `quote` inside the region is an escape rather than a close.
fn quoted_region_end(bytes: &[u8], len: usize, start: usize, quote: u8) -> usize {
    let mut pos = start + 1;
    while pos < len {
        if bytes[pos] != quote {
            pos += 1;
        } else if pos + 1 < len && bytes[pos + 1] == quote {
            pos += 2;
        } else {
            return pos + 1;
        }
    }
    len
}

/// Returns the exclusive end of the block comment whose `/*` is at `start`.
/// That is just past the first `*/` that begins after the opener, or `len`
/// if there is none.
fn block_comment_end(bytes: &[u8], len: usize, start: usize) -> usize {
    let body = start + 2;
    bytes[body..len]
        .windows(2)
        .position(|pair| pair == b"*/")
        .map_or(len, |offset| body + offset + 2)
}

/// Returns the exclusive end of the line comment whose `--` is at `start`.
/// That is the index of the next `\n`, which stays unmarked, or `len`.
fn line_comment_end(bytes: &[u8], len: usize, start: usize) -> usize {
    let body = start + 2;
    bytes[body..len]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(len, |offset| body + offset)
}