sqrust_rules/ambiguous/
date_trunc_function.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2
/// Lint rule that reports calls to the date-truncation/formatting functions
/// listed in `FUNCTIONS` (`DATE_TRUNC`, `DATE_FORMAT`, `TRUNC`).
///
/// The rule carries no state; all work happens in its `Rule::check`
/// implementation, which scans the raw source text.
pub struct DateTruncFunction;
4
impl Rule for DateTruncFunction {
    /// Returns the identifier attached to every diagnostic this rule emits.
    fn name(&self) -> &'static str {
        "Ambiguous/DateTruncFunction"
    }

    /// Scans `ctx.source` as plain text and returns one diagnostic per
    /// flagged function call, delegating the actual search to
    /// `find_violations`.
    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
        find_violations(&ctx.source, self.name())
    }
}
14
/// Table of flagged functions as `(function name, diagnostic message)` pairs.
///
/// The name is the bare identifier to look for; the message is copied
/// verbatim into the diagnostic reported for each call that is found.
const FUNCTIONS: &[(&str, &str)] = &[
    // PostgreSQL/DuckDB-style truncation.
    (
        "DATE_TRUNC",
        "DATE_TRUNC() is PostgreSQL/DuckDB-specific; syntax and supported units vary across databases",
    ),
    // MySQL-style date formatting.
    (
        "DATE_FORMAT",
        "DATE_FORMAT() is MySQL-specific; use TO_CHAR() (Oracle/PostgreSQL) or FORMAT() (SQL Server) for formatting dates",
    ),
    // Oracle/PostgreSQL-style truncation.
    (
        "TRUNC",
        "TRUNC() for date truncation is Oracle/PostgreSQL-specific; use DATE_TRUNC() or DATETRUNC() depending on dialect",
    ),
];
30
31fn find_violations(source: &str, rule_name: &'static str) -> Vec<Diagnostic> {
32 let bytes = source.as_bytes();
33 let len = bytes.len();
34
35 if len == 0 {
36 return Vec::new();
37 }
38
39 let skip = build_skip_set(bytes, len);
40 let mut diags = Vec::new();
41
42 for (func_name, message) in FUNCTIONS {
43 scan_for_function(source, bytes, len, &skip, func_name, message, rule_name, &mut diags);
44 }
45
46 diags.sort_by(|a, b| a.line.cmp(&b.line).then(a.col.cmp(&b.col)));
47 diags
48}
49
50fn scan_for_function(
52 source: &str,
53 bytes: &[u8],
54 len: usize,
55 skip: &[bool],
56 func_name: &str,
57 message: &str,
58 rule_name: &'static str,
59 diags: &mut Vec<Diagnostic>,
60) {
61 let kw = func_name.as_bytes();
62 let kw_len = kw.len();
63 let mut i = 0;
64
65 while i + kw_len <= len {
66 if skip[i] {
67 i += 1;
68 continue;
69 }
70
71 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
72 if before_ok && bytes[i..i + kw_len].eq_ignore_ascii_case(kw) {
73 let after = i + kw_len;
74 let after_ok = after >= len || !is_word_char(bytes[after]);
76 if after_ok {
77 let mut j = after;
79 while j < len && (bytes[j] == b' ' || bytes[j] == b'\t') {
80 j += 1;
81 }
82 if j < len && bytes[j] == b'(' {
83 let (line, col) = line_col(source, i);
84 diags.push(Diagnostic {
85 rule: rule_name,
86 message: message.to_string(),
87 line,
88 col,
89 });
90 i += kw_len;
91 continue;
92 }
93 }
94 }
95
96 i += 1;
97 }
98}
99
/// Returns `true` when `ch` can be part of an identifier: an ASCII letter,
/// an ASCII digit, or an underscore.
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
104
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// Lines are separated by `\n`. The column is counted in bytes from the
/// start of the line, so the first byte of a line is column 1.
///
/// An `offset` past the end of `source` is clamped to `source.len()`, and
/// the same clamped value is used for both the line and the column so the
/// reported position never points beyond the text. The computation works on
/// bytes, so an offset that falls inside a multi-byte UTF-8 character does
/// not panic.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let bytes = source.as_bytes();
    let end = offset.min(bytes.len());
    let before = &bytes[..end];

    // `\n` (0x0A) never occurs inside a multi-byte UTF-8 sequence, so
    // counting newline bytes equals counting newline characters.
    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;

    let col = match before.iter().rposition(|&b| b == b'\n') {
        // Bytes after the last newline, plus one for 1-based columns.
        Some(newline_at) => end - newline_at,
        // Still on the first line.
        None => end + 1,
    };

    (line, col)
}
111
/// Builds a per-byte mask of the regions that must not be scanned for
/// function names.
///
/// `result[i]` is `true` when byte `i` belongs to one of:
/// * a single-quoted string (`'...'`, with `''` as an escaped quote);
/// * a double-quoted identifier (`"..."`, with `""` as an escaped quote);
/// * a block comment (`/* ... */`, not nested);
/// * a line comment (`-- ...`, up to but excluding the newline).
///
/// Delimiters themselves are part of the masked region. An unterminated
/// string, identifier or block comment masks everything to the end of input.
/// `len` is the number of bytes of `bytes` to consider and is also the
/// length of the returned vector.
fn build_skip_set(bytes: &[u8], len: usize) -> Vec<bool> {
    let mut mask = vec![false; len];
    let mut pos = 0;

    while pos < len {
        let current = bytes[pos];
        let next = if pos + 1 < len {
            Some(bytes[pos + 1])
        } else {
            None
        };

        match (current, next) {
            // Quoted string or quoted identifier; both use the same rules,
            // differing only in the delimiter byte.
            (quote @ (b'\'' | b'"'), _) => {
                mask[pos] = true;
                pos += 1;
                while pos < len {
                    mask[pos] = true;
                    if bytes[pos] != quote {
                        pos += 1;
                        continue;
                    }
                    if pos + 1 < len && bytes[pos + 1] == quote {
                        // Doubled delimiter is an escaped quote: stay inside.
                        mask[pos + 1] = true;
                        pos += 2;
                        continue;
                    }
                    // Closing delimiter.
                    pos += 1;
                    break;
                }
            }

            // Block comment: runs to the first `*/` or to end of input.
            (b'/', Some(b'*')) => {
                mask[pos] = true;
                mask[pos + 1] = true;
                pos += 2;
                while pos < len {
                    mask[pos] = true;
                    if bytes[pos] == b'*' && pos + 1 < len && bytes[pos + 1] == b'/' {
                        mask[pos + 1] = true;
                        pos += 2;
                        break;
                    }
                    pos += 1;
                }
            }

            // Line comment: runs up to, but not including, the newline.
            (b'-', Some(b'-')) => {
                mask[pos] = true;
                mask[pos + 1] = true;
                pos += 2;
                while pos < len && bytes[pos] != b'\n' {
                    mask[pos] = true;
                    pos += 1;
                }
            }

            // Ordinary code byte.
            _ => pos += 1,
        }
    }

    mask
}