1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4 Statement, TableFactor,
5};
6
/// Lint rule that flags calls to `FORMAT()`, `TO_CHAR()` and `TO_VARCHAR()`,
/// whose behaviour differs between (or only exists in some) SQL dialects.
pub struct FormatFunction;
8
9fn func_name_lower(func: &sqlparser::ast::Function) -> String {
11 func.name
12 .0
13 .last()
14 .map(|ident| ident.value.to_lowercase())
15 .unwrap_or_default()
16}
17
/// Converts a byte `offset` into `source` into a 1-based `(line, column)` pair.
///
/// The column is the number of bytes between the start of the line and
/// `offset`, plus one. Panics if `offset` is past the end of `source` or not
/// on a character boundary (it slices `source`).
///
/// NOTE(review): the column counts bytes, not characters, so lines with
/// non-ASCII text before the match report a larger column — confirm this is
/// the convention other rules use.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    // '\n' is a single ASCII byte, so counting bytes equals counting chars here.
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = prefix.rfind('\n').map_or(0, |nl| nl + 1);
    (newlines + 1, offset - line_start + 1)
}
25
/// Returns the byte offset of the `occurrence`-th (0-based) call-like use of
/// `name` in `source`.
///
/// A candidate matches when it:
/// * equals `name` ignoring ASCII case,
/// * starts at an identifier boundary (not preceded by an ASCII alphanumeric
///   byte or `_`), and
/// * is followed, after optional whitespace, by `(` — `FORMAT (x)` is a call
///   just like `FORMAT(x)`.
///
/// Text inside quoted regions (`'...'`, `"..."`, `` `...` ``, with `\` escaping
/// the next byte) and inside `--` line comments or `/* ... */` block comments
/// is skipped, so names mentioned there cannot shift the count away from the
/// real calls.
///
/// Returns `0` when fewer than `occurrence + 1` matches exist, so the caller
/// falls back to reporting the start of the file.
fn find_occurrence(source: &str, name: &str, occurrence: usize) -> usize {
    let bytes = source.as_bytes();
    let needle = name.as_bytes();
    let len = bytes.len();
    let mut count = 0usize;
    let mut i = 0;

    while i < len {
        match bytes[i] {
            // Quoted string / identifier: skip to the matching closing quote.
            // A doubled quote ('') is handled as two adjacent quoted regions.
            quote @ (b'\'' | b'"' | b'`') => {
                i += 1;
                while i < len && bytes[i] != quote {
                    // Backslash escapes the next byte (MySQL-style `\'`).
                    i += if bytes[i] == b'\\' { 2 } else { 1 };
                }
                i += 1; // closing quote (or past the end if unterminated)
                continue;
            }
            // `--` line comment: skip to the end of the line.
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                while i < len && bytes[i] != b'\n' {
                    i += 1;
                }
                continue;
            }
            // `/* ... */` block comment: skip past the terminator.
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                i += 2;
                while i < len && !(bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/')) {
                    i += 1;
                }
                i += 2;
                continue;
            }
            _ => {}
        }

        let at_boundary = i == 0 || {
            let prev = bytes[i - 1];
            !prev.is_ascii_alphanumeric() && prev != b'_'
        };
        let name_matches =
            i + needle.len() <= len && bytes[i..i + needle.len()].eq_ignore_ascii_case(needle);

        if at_boundary && name_matches {
            // Allow whitespace between the name and the opening parenthesis.
            let mut after = i + needle.len();
            while after < len && bytes[after].is_ascii_whitespace() {
                after += 1;
            }
            if after < len && bytes[after] == b'(' {
                if count == occurrence {
                    return i;
                }
                count += 1;
            }
        }

        i += 1;
    }

    0
}
66}
67
/// Returns the diagnostic message for a flagged function.
///
/// `name` must be one of the lower-case names accepted by `func_index`
/// (`"format"`, `"to_char"`, `"to_varchar"`); any other value panics.
fn message_for(name: &str) -> String {
    let text = match name {
        "format" => "FORMAT() behavior differs between SQL Server and MySQL — use explicit CAST and string concatenation for portable number formatting",
        "to_char" => "TO_CHAR() is Oracle/PostgreSQL-specific — use CAST or FORMAT functions specific to your target dialect",
        "to_varchar" => "TO_VARCHAR() is Snowflake-specific — use CAST(value AS VARCHAR) for portable type conversion",
        _ => unreachable!(),
    };
    String::from(text)
}
84}
85
/// Maps a lower-case function name to its slot in the per-function counters
/// (and in `FUNC_NAMES`), or `None` if the function is not flagged.
///
/// Matching is exact, so callers must lower-case the name first.
fn func_index(name: &str) -> Option<usize> {
    ["format", "to_char", "to_varchar"]
        .iter()
        .position(|&candidate| candidate == name)
}
94}
95
/// Upper-case spellings of the flagged functions, indexed by the value
/// `func_index` returns; used as the search needle for `find_occurrence`.
/// Order must stay in sync with `func_index`.
const FUNC_NAMES: [&str; 3] = ["FORMAT", "TO_CHAR", "TO_VARCHAR"];
97
98fn walk_expr(
99 expr: &Expr,
100 source: &str,
101 counters: &mut [usize; 3],
102 rule: &'static str,
103 diags: &mut Vec<Diagnostic>,
104) {
105 match expr {
106 Expr::Function(func) => {
107 let lower = func_name_lower(func);
108 if let Some(idx) = func_index(lower.as_str()) {
109 let occ = counters[idx];
110 counters[idx] += 1;
111
112 let offset = find_occurrence(source, FUNC_NAMES[idx], occ);
113 let (line, col) = line_col(source, offset);
114 diags.push(Diagnostic {
115 rule,
116 message: message_for(lower.as_str()),
117 line,
118 col,
119 });
120 }
121
122 if let FunctionArguments::List(list) = &func.args {
124 for arg in &list.args {
125 let inner_expr = match arg {
126 FunctionArg::Named { arg, .. }
127 | FunctionArg::Unnamed(arg)
128 | FunctionArg::ExprNamed { arg, .. } => match arg {
129 FunctionArgExpr::Expr(e) => Some(e),
130 _ => None,
131 },
132 };
133 if let Some(e) = inner_expr {
134 walk_expr(e, source, counters, rule, diags);
135 }
136 }
137 }
138 }
139 Expr::BinaryOp { left, right, .. } => {
140 walk_expr(left, source, counters, rule, diags);
141 walk_expr(right, source, counters, rule, diags);
142 }
143 Expr::UnaryOp { expr: inner, .. } => {
144 walk_expr(inner, source, counters, rule, diags);
145 }
146 Expr::Nested(inner) => {
147 walk_expr(inner, source, counters, rule, diags);
148 }
149 Expr::Case {
150 operand,
151 conditions,
152 results,
153 else_result,
154 } => {
155 if let Some(op) = operand {
156 walk_expr(op, source, counters, rule, diags);
157 }
158 for c in conditions {
159 walk_expr(c, source, counters, rule, diags);
160 }
161 for r in results {
162 walk_expr(r, source, counters, rule, diags);
163 }
164 if let Some(e) = else_result {
165 walk_expr(e, source, counters, rule, diags);
166 }
167 }
168 _ => {}
169 }
170}
171
172fn check_select(
173 sel: &Select,
174 source: &str,
175 counters: &mut [usize; 3],
176 rule: &'static str,
177 diags: &mut Vec<Diagnostic>,
178) {
179 for item in &sel.projection {
180 match item {
181 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
182 walk_expr(e, source, counters, rule, diags);
183 }
184 _ => {}
185 }
186 }
187 if let Some(selection) = &sel.selection {
188 walk_expr(selection, source, counters, rule, diags);
189 }
190 if let Some(having) = &sel.having {
191 walk_expr(having, source, counters, rule, diags);
192 }
193 for twj in &sel.from {
194 recurse_table_factor(&twj.relation, source, counters, rule, diags);
195 for join in &twj.joins {
196 recurse_table_factor(&join.relation, source, counters, rule, diags);
197 }
198 }
199}
200
201fn recurse_table_factor(
202 tf: &TableFactor,
203 source: &str,
204 counters: &mut [usize; 3],
205 rule: &'static str,
206 diags: &mut Vec<Diagnostic>,
207) {
208 if let TableFactor::Derived { subquery, .. } = tf {
209 check_query(subquery, source, counters, rule, diags);
210 }
211}
212
213fn check_set_expr(
214 expr: &SetExpr,
215 source: &str,
216 counters: &mut [usize; 3],
217 rule: &'static str,
218 diags: &mut Vec<Diagnostic>,
219) {
220 match expr {
221 SetExpr::Select(sel) => check_select(sel, source, counters, rule, diags),
222 SetExpr::Query(inner) => check_query(inner, source, counters, rule, diags),
223 SetExpr::SetOperation { left, right, .. } => {
224 check_set_expr(left, source, counters, rule, diags);
225 check_set_expr(right, source, counters, rule, diags);
226 }
227 _ => {}
228 }
229}
230
231fn check_query(
232 query: &Query,
233 source: &str,
234 counters: &mut [usize; 3],
235 rule: &'static str,
236 diags: &mut Vec<Diagnostic>,
237) {
238 if let Some(with) = &query.with {
239 for cte in &with.cte_tables {
240 check_query(&cte.query, source, counters, rule, diags);
241 }
242 }
243 check_set_expr(&query.body, source, counters, rule, diags);
244}
245
246impl Rule for FormatFunction {
247 fn name(&self) -> &'static str {
248 "Ambiguous/FormatFunction"
249 }
250
251 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
252 if !ctx.parse_errors.is_empty() {
253 return Vec::new();
254 }
255
256 let mut diags = Vec::new();
257 let mut counters = [0usize; 3];
259
260 for stmt in &ctx.statements {
261 if let Statement::Query(query) = stmt {
262 check_query(query, &ctx.source, &mut counters, self.name(), &mut diags);
263 }
264 }
265
266 diags
267 }
268}