Skip to main content

sqrust_rules/ambiguous/
substring_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4    Statement, TableFactor,
5};
6
/// Lint rule that reports calls to `SUBSTR()` and `MID()` found in query statements and
/// recommends `SUBSTRING(str, start, length)` instead.
pub struct SubstringFunction;
8
9/// Returns the lowercase last-ident of a function's name, or empty string.
10fn func_name_lower(func: &sqlparser::ast::Function) -> String {
11    func.name
12        .0
13        .last()
14        .map(|ident| ident.value.to_lowercase())
15        .unwrap_or_default()
16}
17
/// Translates a byte `offset` into `source` into a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of `\n` bytes before `offset`; the column is the
/// byte distance from the start of that line, plus one.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or does not fall on a UTF-8 character
/// boundary, because the prefix is obtained by slicing the string.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    // `\n` is a single byte that never occurs inside a multi-byte UTF-8 sequence, so
    // counting bytes gives the same answer as counting chars.
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();
    let column = match prefix.rfind('\n') {
        // Bytes after the last newline, shifted to be 1-indexed.
        Some(newline_at) => offset - newline_at,
        // Still on the first line.
        None => offset + 1,
    };
    (newlines + 1, column)
}
25
/// Finds the `occurrence`-th (0-indexed) appearance of `name` used as a function call in
/// `source` and returns its byte offset.
///
/// A position counts as a match when the text there
/// - equals `name` ignoring ASCII case,
/// - is not directly preceded by an ASCII alphanumeric byte or `_` (so `my_substr(` does
///   not match `substr`), and
/// - is followed by `(`, optionally after ASCII whitespace, so `SUBSTR (x)` and a `(` on
///   the following line are recognised as calls as well.
///
/// The scan is purely textual: string literals and comments are not skipped.
///
/// Returns `0` when fewer than `occurrence + 1` matches exist; callers cannot tell that
/// apart from a genuine match at offset 0.
fn find_occurrence(source: &str, name: &str, occurrence: usize) -> usize {
    let haystack = source.as_bytes();
    let needle = name.as_bytes();
    let mut seen = 0usize;
    let mut start = 0usize;

    while start + needle.len() <= haystack.len() {
        let end = start + needle.len();

        // Reject matches that are merely the tail of a longer identifier.
        let at_word_start = start == 0 || {
            let prev = haystack[start - 1];
            !prev.is_ascii_alphanumeric() && prev != b'_'
        };

        if at_word_start && haystack[start..end].eq_ignore_ascii_case(needle) {
            // The SQL tokenizer ignores whitespace between a function name and its
            // opening parenthesis, so look past it before testing for `(`.
            let next_significant = haystack[end..]
                .iter()
                .copied()
                .find(|b| !b.is_ascii_whitespace());

            if next_significant == Some(b'(') {
                if seen == occurrence {
                    return start;
                }
                seen += 1;
            }
        }

        start += 1;
    }

    0
}
67
/// Builds the diagnostic text for a flagged function, keyed by its lowercase name.
///
/// # Panics
///
/// Panics (via `unreachable!`) for any name other than `"substr"` or `"mid"`.
fn message_for(name: &str) -> String {
    let text = match name {
        "substr" => {
            "SUBSTR() is a non-standard alias — use SUBSTRING(str, start, length) for maximum portability"
        }
        "mid" => {
            "MID() is MySQL-specific — use SUBSTRING(str, start, length) for portable substring extraction"
        }
        _ => unreachable!(),
    };
    String::from(text)
}
81
82fn walk_expr(
83    expr: &Expr,
84    source: &str,
85    counters: &mut [usize; 2],
86    rule: &'static str,
87    diags: &mut Vec<Diagnostic>,
88) {
89    match expr {
90        Expr::Function(func) => {
91            let lower = func_name_lower(func);
92            // counters[0] = substr occurrences, counters[1] = mid occurrences
93            let func_key = match lower.as_str() {
94                "substr" => Some(0usize),
95                "mid" => Some(1usize),
96                _ => None,
97            };
98            if let Some(idx) = func_key {
99                let occ = counters[idx];
100                counters[idx] += 1;
101
102                let func_name_str = if idx == 0 { "SUBSTR" } else { "MID" };
103                let offset = find_occurrence(source, func_name_str, occ);
104                let (line, col) = line_col(source, offset);
105                diags.push(Diagnostic {
106                    rule,
107                    message: message_for(lower.as_str()),
108                    line,
109                    col,
110                });
111            }
112
113            // Recurse into function arguments.
114            if let FunctionArguments::List(list) = &func.args {
115                for arg in &list.args {
116                    let inner_expr = match arg {
117                        FunctionArg::Named { arg, .. }
118                        | FunctionArg::Unnamed(arg)
119                        | FunctionArg::ExprNamed { arg, .. } => match arg {
120                            FunctionArgExpr::Expr(e) => Some(e),
121                            _ => None,
122                        },
123                    };
124                    if let Some(e) = inner_expr {
125                        walk_expr(e, source, counters, rule, diags);
126                    }
127                }
128            }
129        }
130        Expr::BinaryOp { left, right, .. } => {
131            walk_expr(left, source, counters, rule, diags);
132            walk_expr(right, source, counters, rule, diags);
133        }
134        Expr::UnaryOp { expr: inner, .. } => {
135            walk_expr(inner, source, counters, rule, diags);
136        }
137        Expr::Nested(inner) => {
138            walk_expr(inner, source, counters, rule, diags);
139        }
140        Expr::Case {
141            operand,
142            conditions,
143            results,
144            else_result,
145        } => {
146            if let Some(op) = operand {
147                walk_expr(op, source, counters, rule, diags);
148            }
149            for c in conditions {
150                walk_expr(c, source, counters, rule, diags);
151            }
152            for r in results {
153                walk_expr(r, source, counters, rule, diags);
154            }
155            if let Some(e) = else_result {
156                walk_expr(e, source, counters, rule, diags);
157            }
158        }
159        _ => {}
160    }
161}
162
163fn check_select(
164    sel: &Select,
165    source: &str,
166    counters: &mut [usize; 2],
167    rule: &'static str,
168    diags: &mut Vec<Diagnostic>,
169) {
170    for item in &sel.projection {
171        match item {
172            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
173                walk_expr(e, source, counters, rule, diags);
174            }
175            _ => {}
176        }
177    }
178    if let Some(selection) = &sel.selection {
179        walk_expr(selection, source, counters, rule, diags);
180    }
181    if let Some(having) = &sel.having {
182        walk_expr(having, source, counters, rule, diags);
183    }
184    for twj in &sel.from {
185        recurse_table_factor(&twj.relation, source, counters, rule, diags);
186        for join in &twj.joins {
187            recurse_table_factor(&join.relation, source, counters, rule, diags);
188        }
189    }
190}
191
192fn recurse_table_factor(
193    tf: &TableFactor,
194    source: &str,
195    counters: &mut [usize; 2],
196    rule: &'static str,
197    diags: &mut Vec<Diagnostic>,
198) {
199    if let TableFactor::Derived { subquery, .. } = tf {
200        check_query(subquery, source, counters, rule, diags);
201    }
202}
203
204fn check_set_expr(
205    expr: &SetExpr,
206    source: &str,
207    counters: &mut [usize; 2],
208    rule: &'static str,
209    diags: &mut Vec<Diagnostic>,
210) {
211    match expr {
212        SetExpr::Select(sel) => check_select(sel, source, counters, rule, diags),
213        SetExpr::Query(inner) => check_query(inner, source, counters, rule, diags),
214        SetExpr::SetOperation { left, right, .. } => {
215            check_set_expr(left, source, counters, rule, diags);
216            check_set_expr(right, source, counters, rule, diags);
217        }
218        _ => {}
219    }
220}
221
222fn check_query(
223    query: &Query,
224    source: &str,
225    counters: &mut [usize; 2],
226    rule: &'static str,
227    diags: &mut Vec<Diagnostic>,
228) {
229    if let Some(with) = &query.with {
230        for cte in &with.cte_tables {
231            check_query(&cte.query, source, counters, rule, diags);
232        }
233    }
234    check_set_expr(&query.body, source, counters, rule, diags);
235}
236
237impl Rule for SubstringFunction {
238    fn name(&self) -> &'static str {
239        "Ambiguous/SubstringFunction"
240    }
241
242    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
243        if !ctx.parse_errors.is_empty() {
244            return Vec::new();
245        }
246
247        let mut diags = Vec::new();
248        // counters[0] = substr occurrences, counters[1] = mid occurrences
249        let mut counters = [0usize; 2];
250
251        for stmt in &ctx.statements {
252            if let Statement::Query(query) = stmt {
253                check_query(query, &ctx.source, &mut counters, self.name(), &mut diags);
254            }
255        }
256
257        diags
258    }
259}