Skip to main content

sqrust_rules/ambiguous/
format_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4    Statement, TableFactor,
5};
6
/// Lint rule that flags calls to dialect-specific string-formatting functions
/// (`FORMAT`, `TO_CHAR`, `TO_VARCHAR`) inside query statements.
///
/// The type carries no state; all work happens in its `Rule` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct FormatFunction;
8
9/// Returns the lowercase last-ident of a function's name, or empty string.
10fn func_name_lower(func: &sqlparser::ast::Function) -> String {
11    func.name
12        .0
13        .last()
14        .map(|ident| ident.value.to_lowercase())
15        .unwrap_or_default()
16}
17
/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
///
/// `col` is counted in bytes from the start of the line. The scan works on
/// raw bytes and clamps `offset` to the length of `source`, so an offset that
/// is past the end or falls inside a multi-byte character never panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let bytes = source.as_bytes();
    let offset = offset.min(bytes.len());
    let before = &bytes[..offset];

    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    // Distance from the last newline (exclusive); the whole prefix when the
    // offset is still on the first line.
    let col = before
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(offset, |newline| offset - newline - 1)
        + 1;

    (line, col)
}
25
/// Finds the Nth occurrence (0-indexed) of `name` used as a function call in
/// `source`, comparing ASCII case-insensitively.
///
/// A position counts as a call when:
/// * the byte before it is not an ASCII identifier byte (letter, digit, `_`),
///   or it is the very start of `source`; and
/// * the first non-whitespace byte after the name is `(`. Whitespace between
///   the name and the parenthesis is allowed, so `FORMAT (x)` is counted just
///   like `FORMAT(x)`.
///
/// Returns the byte offset of the match, or 0 when there is no such
/// occurrence (including when `name` is empty).
fn find_occurrence(source: &str, name: &str, occurrence: usize) -> usize {
    let haystack = source.as_bytes();
    let needle = name.as_bytes();
    if needle.is_empty() {
        // `windows(0)` would panic and there is nothing meaningful to find.
        return 0;
    }

    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    haystack
        .windows(needle.len())
        .enumerate()
        .filter(|&(start, window)| {
            let boundary_before = start == 0 || !is_ident_byte(haystack[start - 1]);
            boundary_before
                && window.eq_ignore_ascii_case(needle)
                // Skip optional whitespace, then require the opening paren.
                && haystack[start + needle.len()..]
                    .iter()
                    .find(|b| !b.is_ascii_whitespace())
                    == Some(&b'(')
        })
        .map(|(start, _)| start)
        .nth(occurrence)
        .unwrap_or(0)
}
67
/// Builds the diagnostic text for one of the tracked function names.
///
/// `name` must be the lowercase spelling `"format"`, `"to_char"` or
/// `"to_varchar"`.
///
/// # Panics
///
/// Panics (via `unreachable!`) for any other name.
fn message_for(name: &str) -> String {
    let text = match name {
        "format" => "FORMAT() behavior differs between SQL Server and MySQL — use explicit CAST and string concatenation for portable number formatting",
        "to_char" => "TO_CHAR() is Oracle/PostgreSQL-specific — use CAST or FORMAT functions specific to your target dialect",
        "to_varchar" => "TO_VARCHAR() is Snowflake-specific — use CAST(value AS VARCHAR) for portable type conversion",
        _ => unreachable!(),
    };
    String::from(text)
}
85
/// Maps a lowercase tracked function name to its slot in the counter array:
/// `format` is 0, `to_char` is 1, `to_varchar` is 2. Any other name yields
/// `None`; the comparison is exact, so upper-case spellings do not match.
fn func_index(name: &str) -> Option<usize> {
    // Position in this table is the counter slot.
    const TRACKED: [&str; 3] = ["format", "to_char", "to_varchar"];
    TRACKED.iter().position(|&candidate| candidate == name)
}

/// Upper-case spellings of the tracked functions, in the same slot order
/// that `func_index` returns.
const FUNC_NAMES: [&str; 3] = ["FORMAT", "TO_CHAR", "TO_VARCHAR"];
97
98fn walk_expr(
99    expr: &Expr,
100    source: &str,
101    counters: &mut [usize; 3],
102    rule: &'static str,
103    diags: &mut Vec<Diagnostic>,
104) {
105    match expr {
106        Expr::Function(func) => {
107            let lower = func_name_lower(func);
108            if let Some(idx) = func_index(lower.as_str()) {
109                let occ = counters[idx];
110                counters[idx] += 1;
111
112                let offset = find_occurrence(source, FUNC_NAMES[idx], occ);
113                let (line, col) = line_col(source, offset);
114                diags.push(Diagnostic {
115                    rule,
116                    message: message_for(lower.as_str()),
117                    line,
118                    col,
119                });
120            }
121
122            // Recurse into function arguments.
123            if let FunctionArguments::List(list) = &func.args {
124                for arg in &list.args {
125                    let inner_expr = match arg {
126                        FunctionArg::Named { arg, .. }
127                        | FunctionArg::Unnamed(arg)
128                        | FunctionArg::ExprNamed { arg, .. } => match arg {
129                            FunctionArgExpr::Expr(e) => Some(e),
130                            _ => None,
131                        },
132                    };
133                    if let Some(e) = inner_expr {
134                        walk_expr(e, source, counters, rule, diags);
135                    }
136                }
137            }
138        }
139        Expr::BinaryOp { left, right, .. } => {
140            walk_expr(left, source, counters, rule, diags);
141            walk_expr(right, source, counters, rule, diags);
142        }
143        Expr::UnaryOp { expr: inner, .. } => {
144            walk_expr(inner, source, counters, rule, diags);
145        }
146        Expr::Nested(inner) => {
147            walk_expr(inner, source, counters, rule, diags);
148        }
149        Expr::Case {
150            operand,
151            conditions,
152            results,
153            else_result,
154        } => {
155            if let Some(op) = operand {
156                walk_expr(op, source, counters, rule, diags);
157            }
158            for c in conditions {
159                walk_expr(c, source, counters, rule, diags);
160            }
161            for r in results {
162                walk_expr(r, source, counters, rule, diags);
163            }
164            if let Some(e) = else_result {
165                walk_expr(e, source, counters, rule, diags);
166            }
167        }
168        _ => {}
169    }
170}
171
172fn check_select(
173    sel: &Select,
174    source: &str,
175    counters: &mut [usize; 3],
176    rule: &'static str,
177    diags: &mut Vec<Diagnostic>,
178) {
179    for item in &sel.projection {
180        match item {
181            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
182                walk_expr(e, source, counters, rule, diags);
183            }
184            _ => {}
185        }
186    }
187    if let Some(selection) = &sel.selection {
188        walk_expr(selection, source, counters, rule, diags);
189    }
190    if let Some(having) = &sel.having {
191        walk_expr(having, source, counters, rule, diags);
192    }
193    for twj in &sel.from {
194        recurse_table_factor(&twj.relation, source, counters, rule, diags);
195        for join in &twj.joins {
196            recurse_table_factor(&join.relation, source, counters, rule, diags);
197        }
198    }
199}
200
201fn recurse_table_factor(
202    tf: &TableFactor,
203    source: &str,
204    counters: &mut [usize; 3],
205    rule: &'static str,
206    diags: &mut Vec<Diagnostic>,
207) {
208    if let TableFactor::Derived { subquery, .. } = tf {
209        check_query(subquery, source, counters, rule, diags);
210    }
211}
212
213fn check_set_expr(
214    expr: &SetExpr,
215    source: &str,
216    counters: &mut [usize; 3],
217    rule: &'static str,
218    diags: &mut Vec<Diagnostic>,
219) {
220    match expr {
221        SetExpr::Select(sel) => check_select(sel, source, counters, rule, diags),
222        SetExpr::Query(inner) => check_query(inner, source, counters, rule, diags),
223        SetExpr::SetOperation { left, right, .. } => {
224            check_set_expr(left, source, counters, rule, diags);
225            check_set_expr(right, source, counters, rule, diags);
226        }
227        _ => {}
228    }
229}
230
231fn check_query(
232    query: &Query,
233    source: &str,
234    counters: &mut [usize; 3],
235    rule: &'static str,
236    diags: &mut Vec<Diagnostic>,
237) {
238    if let Some(with) = &query.with {
239        for cte in &with.cte_tables {
240            check_query(&cte.query, source, counters, rule, diags);
241        }
242    }
243    check_set_expr(&query.body, source, counters, rule, diags);
244}
245
246impl Rule for FormatFunction {
247    fn name(&self) -> &'static str {
248        "Ambiguous/FormatFunction"
249    }
250
251    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
252        if !ctx.parse_errors.is_empty() {
253            return Vec::new();
254        }
255
256        let mut diags = Vec::new();
257        // counters[0] = format, counters[1] = to_char, counters[2] = to_varchar
258        let mut counters = [0usize; 3];
259
260        for stmt in &ctx.statements {
261            if let Statement::Query(query) = stmt {
262                check_query(query, &ctx.source, &mut counters, self.name(), &mut diags);
263            }
264        }
265
266        diags
267    }
268}