1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4 Statement, TableFactor,
5};
6
/// Lint rule that flags the non-portable substring aliases `SUBSTR()` and
/// `MID()`, recommending `SUBSTRING(str, start, length)` instead.
#[derive(Debug, Clone, Copy, Default)]
pub struct SubstringFunction;
8
9fn func_name_lower(func: &sqlparser::ast::Function) -> String {
11 func.name
12 .0
13 .last()
14 .map(|ident| ident.value.to_lowercase())
15 .unwrap_or_default()
16}
17
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column is measured in bytes from the start of the containing line.
/// Panics if `offset` is past the end of `source` or not on a char boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = prefix.bytes().filter(|&b| b == b'\n').count() + 1;
    let col = offset - line_start + 1;
    (line, col)
}
25
/// Returns the byte offset of the `occurrence`-th (0-based) call-like use of
/// `name` in `source`, matched ASCII case-insensitively.
///
/// A match must start at an identifier boundary (the previous byte is not an
/// ASCII alphanumeric or `_`) and be followed, after optional whitespace, by
/// `(`. Text inside quoted strings / identifiers (`'...'`, `"..."`,
/// `` `...` ``) and inside comments (`-- ...`, `/* ... */`) is skipped: the
/// parser never turns such text into a function call, so counting it would
/// shift every later occurrence onto the wrong position.
///
/// Returns `0` when no such occurrence exists (or `name` is empty).
fn find_occurrence(source: &str, name: &str, occurrence: usize) -> usize {
    let bytes = source.as_bytes();
    let needle = name.as_bytes();
    let len = bytes.len();
    if needle.is_empty() {
        return 0;
    }

    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let mut seen = 0usize;
    let mut i = 0usize;

    while i < len {
        match bytes[i] {
            // Quoted literal or identifier: jump past the closing quote (an
            // unterminated quote runs to end of input). A doubled quote such
            // as 'it''s' simply closes and reopens, which skips the same span.
            quote @ (b'\'' | b'"' | b'`') => {
                i = bytes[i + 1..]
                    .iter()
                    .position(|&b| b == quote)
                    .map_or(len, |p| i + 1 + p + 1);
                continue;
            }
            // Line comment: skip through the end of the line.
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                i = bytes[i + 2..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |p| i + 2 + p + 1);
                continue;
            }
            // Block comment: skip past the closing `*/`.
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                i = bytes[i + 2..]
                    .windows(2)
                    .position(|w| w == b"*/")
                    .map_or(len, |p| i + 2 + p + 2);
                continue;
            }
            _ => {}
        }

        let end = i + needle.len();
        let at_boundary = i == 0 || !is_ident_byte(bytes[i - 1]);
        if at_boundary && end <= len && bytes[i..end].eq_ignore_ascii_case(needle) {
            // The parser accepts whitespace between the name and `(`,
            // e.g. `SUBSTR (x, 1, 2)`, so allow it here too.
            let next = bytes[end..].iter().find(|b| !b.is_ascii_whitespace());
            if next == Some(&b'(') {
                if seen == occurrence {
                    return i;
                }
                seen += 1;
            }
        }

        i += 1;
    }

    0
}
67
/// Builds the diagnostic message for a flagged function.
///
/// `name` must be the lowercased function name, either `"substr"` or `"mid"`;
/// any other value panics via `unreachable!()`.
fn message_for(name: &str) -> String {
    let text = match name {
        "substr" => "SUBSTR() is a non-standard alias — use SUBSTRING(str, start, length) for maximum portability",
        "mid" => "MID() is MySQL-specific — use SUBSTRING(str, start, length) for portable substring extraction",
        _ => unreachable!(),
    };
    text.to_owned()
}
81
82fn walk_expr(
83 expr: &Expr,
84 source: &str,
85 counters: &mut [usize; 2],
86 rule: &'static str,
87 diags: &mut Vec<Diagnostic>,
88) {
89 match expr {
90 Expr::Function(func) => {
91 let lower = func_name_lower(func);
92 let func_key = match lower.as_str() {
94 "substr" => Some(0usize),
95 "mid" => Some(1usize),
96 _ => None,
97 };
98 if let Some(idx) = func_key {
99 let occ = counters[idx];
100 counters[idx] += 1;
101
102 let func_name_str = if idx == 0 { "SUBSTR" } else { "MID" };
103 let offset = find_occurrence(source, func_name_str, occ);
104 let (line, col) = line_col(source, offset);
105 diags.push(Diagnostic {
106 rule,
107 message: message_for(lower.as_str()),
108 line,
109 col,
110 });
111 }
112
113 if let FunctionArguments::List(list) = &func.args {
115 for arg in &list.args {
116 let inner_expr = match arg {
117 FunctionArg::Named { arg, .. }
118 | FunctionArg::Unnamed(arg)
119 | FunctionArg::ExprNamed { arg, .. } => match arg {
120 FunctionArgExpr::Expr(e) => Some(e),
121 _ => None,
122 },
123 };
124 if let Some(e) = inner_expr {
125 walk_expr(e, source, counters, rule, diags);
126 }
127 }
128 }
129 }
130 Expr::BinaryOp { left, right, .. } => {
131 walk_expr(left, source, counters, rule, diags);
132 walk_expr(right, source, counters, rule, diags);
133 }
134 Expr::UnaryOp { expr: inner, .. } => {
135 walk_expr(inner, source, counters, rule, diags);
136 }
137 Expr::Nested(inner) => {
138 walk_expr(inner, source, counters, rule, diags);
139 }
140 Expr::Case {
141 operand,
142 conditions,
143 results,
144 else_result,
145 } => {
146 if let Some(op) = operand {
147 walk_expr(op, source, counters, rule, diags);
148 }
149 for c in conditions {
150 walk_expr(c, source, counters, rule, diags);
151 }
152 for r in results {
153 walk_expr(r, source, counters, rule, diags);
154 }
155 if let Some(e) = else_result {
156 walk_expr(e, source, counters, rule, diags);
157 }
158 }
159 _ => {}
160 }
161}
162
163fn check_select(
164 sel: &Select,
165 source: &str,
166 counters: &mut [usize; 2],
167 rule: &'static str,
168 diags: &mut Vec<Diagnostic>,
169) {
170 for item in &sel.projection {
171 match item {
172 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
173 walk_expr(e, source, counters, rule, diags);
174 }
175 _ => {}
176 }
177 }
178 if let Some(selection) = &sel.selection {
179 walk_expr(selection, source, counters, rule, diags);
180 }
181 if let Some(having) = &sel.having {
182 walk_expr(having, source, counters, rule, diags);
183 }
184 for twj in &sel.from {
185 recurse_table_factor(&twj.relation, source, counters, rule, diags);
186 for join in &twj.joins {
187 recurse_table_factor(&join.relation, source, counters, rule, diags);
188 }
189 }
190}
191
192fn recurse_table_factor(
193 tf: &TableFactor,
194 source: &str,
195 counters: &mut [usize; 2],
196 rule: &'static str,
197 diags: &mut Vec<Diagnostic>,
198) {
199 if let TableFactor::Derived { subquery, .. } = tf {
200 check_query(subquery, source, counters, rule, diags);
201 }
202}
203
204fn check_set_expr(
205 expr: &SetExpr,
206 source: &str,
207 counters: &mut [usize; 2],
208 rule: &'static str,
209 diags: &mut Vec<Diagnostic>,
210) {
211 match expr {
212 SetExpr::Select(sel) => check_select(sel, source, counters, rule, diags),
213 SetExpr::Query(inner) => check_query(inner, source, counters, rule, diags),
214 SetExpr::SetOperation { left, right, .. } => {
215 check_set_expr(left, source, counters, rule, diags);
216 check_set_expr(right, source, counters, rule, diags);
217 }
218 _ => {}
219 }
220}
221
222fn check_query(
223 query: &Query,
224 source: &str,
225 counters: &mut [usize; 2],
226 rule: &'static str,
227 diags: &mut Vec<Diagnostic>,
228) {
229 if let Some(with) = &query.with {
230 for cte in &with.cte_tables {
231 check_query(&cte.query, source, counters, rule, diags);
232 }
233 }
234 check_set_expr(&query.body, source, counters, rule, diags);
235}
236
237impl Rule for SubstringFunction {
238 fn name(&self) -> &'static str {
239 "Ambiguous/SubstringFunction"
240 }
241
242 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
243 if !ctx.parse_errors.is_empty() {
244 return Vec::new();
245 }
246
247 let mut diags = Vec::new();
248 let mut counters = [0usize; 2];
250
251 for stmt in &ctx.statements {
252 if let Statement::Query(query) = stmt {
253 check_query(query, &ctx.source, &mut counters, self.name(), &mut diags);
254 }
255 }
256
257 diags
258 }
259}