1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4 Statement, TableFactor,
5};
6
/// Lint rule (`Convention/NoIFFunction`) that reports calls to the
/// dialect-specific `IF(...)` function in parsed query statements and
/// recommends a portable `CASE WHEN ... THEN ... ELSE ... END` instead.
pub struct NoIFFunction;
8
/// Text attached to every diagnostic this rule emits.
const MESSAGE: &str = concat!(
    "IF() is dialect-specific (MySQL/BigQuery) ",
    "— use CASE WHEN ... THEN ... ELSE ... END for portable conditional logic"
);
12
13fn func_name_lower(func: &sqlparser::ast::Function) -> String {
15 func.name
16 .0
17 .last()
18 .map(|ident| ident.value.to_lowercase())
19 .unwrap_or_default()
20}
21
/// Converts a byte `offset` into `source` into a 1-based `(line, column)` pair.
///
/// Lines are delimited by `'\n'`. The column is the byte distance from the
/// start of that line plus one, so multi-byte characters earlier on the line
/// count once per byte.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or not on a `char` boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let head = &source[..offset];
    let newline_count = head.bytes().filter(|&b| b == b'\n').count();
    let line_start = head.rfind('\n').map_or(0, |nl| nl + 1);
    (newline_count + 1, offset - line_start + 1)
}
29
/// Returns the byte offset of the `occurrence`-th (0-based) `IF(` call in `source`.
///
/// A match is the word `IF` (any case) that:
/// - is not preceded by an identifier byte, and
/// - is followed by optional ASCII whitespace and then `(`.
///
/// Text inside single-quoted strings, double-quoted or backtick-quoted
/// identifiers, `--` line comments and `/* ... */` block comments is skipped.
/// This keeps the textual count aligned with the `IF` calls the parser sees.
/// Only doubled-quote escapes (`''`) are recognised inside quoted text.
///
/// Returns `0` when there are fewer than `occurrence + 1` matches. Callers
/// then report line 1, column 1.
fn find_if_occurrence(source: &str, occurrence: usize) -> usize {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut seen = 0usize;
    let mut i = 0usize;

    while i < len {
        match bytes[i] {
            quote @ (b'\'' | b'"' | b'`') => {
                i = skip_quoted(bytes, i, quote);
                continue;
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                i = skip_line_comment(bytes, i);
                continue;
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                i = skip_block_comment(bytes, i);
                continue;
            }
            _ => {}
        }

        if is_if_call_at(bytes, i) {
            if seen == occurrence {
                return i;
            }
            seen += 1;
        }
        i += 1;
    }

    0
}

/// True for bytes that can continue an identifier: ASCII alphanumerics, `_`,
/// and any byte of a multi-byte UTF-8 character.
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
}

/// True when an `IF` call (`IF`, optional whitespace, then `(`) starts at byte `i`.
fn is_if_call_at(bytes: &[u8], i: usize) -> bool {
    let spells_if = matches!(bytes.get(i).copied(), Some(b'I' | b'i'))
        && matches!(bytes.get(i + 1).copied(), Some(b'F' | b'f'));
    if !spells_if || (i > 0 && is_ident_byte(bytes[i - 1])) {
        return false;
    }
    // `IF (a, b, c)` is still a function call, so allow whitespace before `(`.
    let mut j = i + 2;
    while j < bytes.len() && bytes[j].is_ascii_whitespace() {
        j += 1;
    }
    bytes.get(j) == Some(&b'(')
}

/// Returns the offset just past the quoted run opened by `quote` at `start`.
/// A doubled quote inside the run is an escape. Unterminated runs extend to
/// the end of input.
fn skip_quoted(bytes: &[u8], start: usize, quote: u8) -> usize {
    let mut j = start + 1;
    while j < bytes.len() {
        if bytes[j] == quote {
            if bytes.get(j + 1) == Some(&quote) {
                j += 2;
                continue;
            }
            return j + 1;
        }
        j += 1;
    }
    bytes.len()
}

/// Returns the offset of the `'\n'` ending the `--` comment at `start`, or the
/// end of input.
fn skip_line_comment(bytes: &[u8], start: usize) -> usize {
    bytes[start..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |p| start + p)
}

/// Returns the offset just past the `*/` closing the comment opened at
/// `start`, or the end of input if it is unterminated.
fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
    // Search only after the opening `/*`, so `/*/` is not treated as closed.
    let body = start + 2;
    bytes[body..]
        .windows(2)
        .position(|w| w[0] == b'*' && w[1] == b'/')
        .map_or(bytes.len(), |p| body + p + 2)
}
69
70fn walk_expr(
71 expr: &Expr,
72 source: &str,
73 counter: &mut usize,
74 rule: &'static str,
75 diags: &mut Vec<Diagnostic>,
76) {
77 match expr {
78 Expr::Function(func) => {
79 let lower = func_name_lower(func);
80 if lower == "if" {
81 let occ = *counter;
82 *counter += 1;
83
84 let offset = find_if_occurrence(source, occ);
85 let (line, col) = line_col(source, offset);
86 diags.push(Diagnostic {
87 rule,
88 message: MESSAGE.to_string(),
89 line,
90 col,
91 });
92 }
93
94 if let FunctionArguments::List(list) = &func.args {
96 for arg in &list.args {
97 let inner_expr = match arg {
98 FunctionArg::Named { arg, .. }
99 | FunctionArg::Unnamed(arg)
100 | FunctionArg::ExprNamed { arg, .. } => match arg {
101 FunctionArgExpr::Expr(e) => Some(e),
102 _ => None,
103 },
104 };
105 if let Some(e) = inner_expr {
106 walk_expr(e, source, counter, rule, diags);
107 }
108 }
109 }
110 }
111 Expr::BinaryOp { left, right, .. } => {
112 walk_expr(left, source, counter, rule, diags);
113 walk_expr(right, source, counter, rule, diags);
114 }
115 Expr::UnaryOp { expr: inner, .. } => {
116 walk_expr(inner, source, counter, rule, diags);
117 }
118 Expr::Nested(inner) => {
119 walk_expr(inner, source, counter, rule, diags);
120 }
121 Expr::Case {
122 operand,
123 conditions,
124 results,
125 else_result,
126 } => {
127 if let Some(op) = operand {
128 walk_expr(op, source, counter, rule, diags);
129 }
130 for c in conditions {
131 walk_expr(c, source, counter, rule, diags);
132 }
133 for r in results {
134 walk_expr(r, source, counter, rule, diags);
135 }
136 if let Some(e) = else_result {
137 walk_expr(e, source, counter, rule, diags);
138 }
139 }
140 _ => {}
141 }
142}
143
144fn check_select(
145 sel: &Select,
146 source: &str,
147 counter: &mut usize,
148 rule: &'static str,
149 diags: &mut Vec<Diagnostic>,
150) {
151 for item in &sel.projection {
152 match item {
153 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
154 walk_expr(e, source, counter, rule, diags);
155 }
156 _ => {}
157 }
158 }
159 if let Some(selection) = &sel.selection {
160 walk_expr(selection, source, counter, rule, diags);
161 }
162 if let Some(having) = &sel.having {
163 walk_expr(having, source, counter, rule, diags);
164 }
165 for twj in &sel.from {
166 recurse_table_factor(&twj.relation, source, counter, rule, diags);
167 for join in &twj.joins {
168 recurse_table_factor(&join.relation, source, counter, rule, diags);
169 }
170 }
171}
172
173fn recurse_table_factor(
174 tf: &TableFactor,
175 source: &str,
176 counter: &mut usize,
177 rule: &'static str,
178 diags: &mut Vec<Diagnostic>,
179) {
180 if let TableFactor::Derived { subquery, .. } = tf {
181 check_query(subquery, source, counter, rule, diags);
182 }
183}
184
185fn check_set_expr(
186 expr: &SetExpr,
187 source: &str,
188 counter: &mut usize,
189 rule: &'static str,
190 diags: &mut Vec<Diagnostic>,
191) {
192 match expr {
193 SetExpr::Select(sel) => check_select(sel, source, counter, rule, diags),
194 SetExpr::Query(inner) => check_query(inner, source, counter, rule, diags),
195 SetExpr::SetOperation { left, right, .. } => {
196 check_set_expr(left, source, counter, rule, diags);
197 check_set_expr(right, source, counter, rule, diags);
198 }
199 _ => {}
200 }
201}
202
203fn check_query(
204 query: &Query,
205 source: &str,
206 counter: &mut usize,
207 rule: &'static str,
208 diags: &mut Vec<Diagnostic>,
209) {
210 if let Some(with) = &query.with {
211 for cte in &with.cte_tables {
212 check_query(&cte.query, source, counter, rule, diags);
213 }
214 }
215 check_set_expr(&query.body, source, counter, rule, diags);
216}
217
218impl Rule for NoIFFunction {
219 fn name(&self) -> &'static str {
220 "Convention/NoIFFunction"
221 }
222
223 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
224 if !ctx.parse_errors.is_empty() {
225 return Vec::new();
226 }
227
228 let mut diags = Vec::new();
229 let mut counter = 0usize;
230
231 for stmt in &ctx.statements {
232 if let Statement::Query(query) = stmt {
233 check_query(query, &ctx.source, &mut counter, self.name(), &mut diags);
234 }
235 }
236
237 diags
238 }
239}