1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4 Statement, TableFactor,
5};
6
/// Lint rule (reported as `Convention/IfNullFunction`) that flags calls to
/// the null-handling functions listed in `FLAGGED_FUNCS` and recommends
/// `COALESCE()` instead.
pub struct IfNullFunction;
8
/// Upper-case names of the functions this rule reports. Function names taken
/// from the AST are upper-cased before being compared against this list.
const FLAGGED_FUNCS: &[&str] = &["IFNULL", "NVL", "NVL2", "ISNULL"];
11
12fn func_name_upper(func: &sqlparser::ast::Function) -> String {
14 func.name
15 .0
16 .last()
17 .map(|ident| ident.value.to_uppercase())
18 .unwrap_or_default()
19}
20
/// Converts a byte offset into a 1-indexed `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line (the byte
/// right after the preceding `'\n'`, or the start of `source`).
///
/// An `offset` past the end of `source` is clamped to `source.len()`, and an
/// offset that falls inside a multi-byte character is rounded down to the
/// start of that character, so this function never panics on slicing.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    // Clamp, then back up to a char boundary; index 0 is always a boundary,
    // so the loop terminates.
    let mut end = offset.min(source.len());
    while !source.is_char_boundary(end) {
        end -= 1;
    }

    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let col = match before.rfind('\n') {
        // Bytes after the newline, plus one for 1-based numbering.
        Some(newline) => end - newline,
        None => end + 1,
    };
    (line, col)
}
28
/// Finds the byte offset of the `occurrence`-th (0-indexed) whole-word,
/// ASCII-case-insensitive match of `name` in `source`.
///
/// A match counts as a whole word when the bytes immediately before and
/// after it are neither ASCII alphanumerics nor `_` (or the match touches
/// the start/end of `source`). The scan is purely textual.
///
/// Returns `0` when there are not enough matches.
fn find_occurrence(source: &str, name: &str, occurrence: usize) -> usize {
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    let haystack = source.as_bytes();
    let needle = name.as_bytes();

    // Last index at which the needle could still fit; none if it is longer
    // than the haystack.
    let last_start = match haystack.len().checked_sub(needle.len()) {
        Some(n) => n,
        None => return 0,
    };

    (0..=last_start)
        .filter(|&start| {
            let end = start + needle.len();
            let left_clear = start == 0 || !is_word_byte(haystack[start - 1]);
            let right_clear = end == haystack.len() || !is_word_byte(haystack[end]);
            left_clear && right_clear && haystack[start..end].eq_ignore_ascii_case(needle)
        })
        .nth(occurrence)
        .unwrap_or(0)
}
76
77fn walk_expr(
81 expr: &Expr,
82 source: &str,
83 occurrence_counters: &mut std::collections::HashMap<String, usize>,
84 rule: &'static str,
85 diags: &mut Vec<Diagnostic>,
86) {
87 match expr {
88 Expr::Function(func) => {
89 let upper = func_name_upper(func);
90 if FLAGGED_FUNCS.contains(&upper.as_str()) {
91 let count = occurrence_counters.entry(upper.clone()).or_insert(0);
92 let occ = *count;
93 *count += 1;
94
95 let offset = find_occurrence(source, &upper, occ);
96 let (line, col) = line_col(source, offset);
97 diags.push(Diagnostic {
98 rule,
99 message: format!(
100 "IFNULL/NVL is vendor-specific; use COALESCE() for portability (found {})",
101 upper
102 ),
103 line,
104 col,
105 });
106 }
107
108 if let FunctionArguments::List(list) = &func.args {
110 for arg in &list.args {
111 let inner_expr = match arg {
112 FunctionArg::Named { arg, .. }
113 | FunctionArg::Unnamed(arg)
114 | FunctionArg::ExprNamed { arg, .. } => match arg {
115 FunctionArgExpr::Expr(e) => Some(e),
116 _ => None,
117 },
118 };
119 if let Some(e) = inner_expr {
120 walk_expr(e, source, occurrence_counters, rule, diags);
121 }
122 }
123 }
124 }
125 Expr::BinaryOp { left, right, .. } => {
126 walk_expr(left, source, occurrence_counters, rule, diags);
127 walk_expr(right, source, occurrence_counters, rule, diags);
128 }
129 Expr::UnaryOp { expr: inner, .. } => {
130 walk_expr(inner, source, occurrence_counters, rule, diags);
131 }
132 Expr::Nested(inner) => {
133 walk_expr(inner, source, occurrence_counters, rule, diags);
134 }
135 Expr::Case {
136 operand,
137 conditions,
138 results,
139 else_result,
140 } => {
141 if let Some(op) = operand {
142 walk_expr(op, source, occurrence_counters, rule, diags);
143 }
144 for c in conditions {
145 walk_expr(c, source, occurrence_counters, rule, diags);
146 }
147 for r in results {
148 walk_expr(r, source, occurrence_counters, rule, diags);
149 }
150 if let Some(e) = else_result {
151 walk_expr(e, source, occurrence_counters, rule, diags);
152 }
153 }
154 _ => {}
155 }
156}
157
158fn check_select(
159 sel: &Select,
160 source: &str,
161 occurrence_counters: &mut std::collections::HashMap<String, usize>,
162 rule: &'static str,
163 diags: &mut Vec<Diagnostic>,
164) {
165 for item in &sel.projection {
167 match item {
168 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
169 walk_expr(e, source, occurrence_counters, rule, diags);
170 }
171 _ => {}
172 }
173 }
174 if let Some(selection) = &sel.selection {
176 walk_expr(selection, source, occurrence_counters, rule, diags);
177 }
178 if let Some(having) = &sel.having {
180 walk_expr(having, source, occurrence_counters, rule, diags);
181 }
182 for twj in &sel.from {
184 recurse_table_factor(&twj.relation, source, occurrence_counters, rule, diags);
185 for join in &twj.joins {
186 recurse_table_factor(&join.relation, source, occurrence_counters, rule, diags);
187 }
188 }
189}
190
191fn recurse_table_factor(
192 tf: &TableFactor,
193 source: &str,
194 occurrence_counters: &mut std::collections::HashMap<String, usize>,
195 rule: &'static str,
196 diags: &mut Vec<Diagnostic>,
197) {
198 if let TableFactor::Derived { subquery, .. } = tf {
199 check_query(subquery, source, occurrence_counters, rule, diags);
200 }
201}
202
203fn check_set_expr(
204 expr: &SetExpr,
205 source: &str,
206 occurrence_counters: &mut std::collections::HashMap<String, usize>,
207 rule: &'static str,
208 diags: &mut Vec<Diagnostic>,
209) {
210 match expr {
211 SetExpr::Select(sel) => check_select(sel, source, occurrence_counters, rule, diags),
212 SetExpr::Query(inner) => check_query(inner, source, occurrence_counters, rule, diags),
213 SetExpr::SetOperation { left, right, .. } => {
214 check_set_expr(left, source, occurrence_counters, rule, diags);
215 check_set_expr(right, source, occurrence_counters, rule, diags);
216 }
217 _ => {}
218 }
219}
220
221fn check_query(
222 query: &Query,
223 source: &str,
224 occurrence_counters: &mut std::collections::HashMap<String, usize>,
225 rule: &'static str,
226 diags: &mut Vec<Diagnostic>,
227) {
228 if let Some(with) = &query.with {
229 for cte in &with.cte_tables {
230 check_query(&cte.query, source, occurrence_counters, rule, diags);
231 }
232 }
233 check_set_expr(&query.body, source, occurrence_counters, rule, diags);
234}
235
236impl Rule for IfNullFunction {
237 fn name(&self) -> &'static str {
238 "Convention/IfNullFunction"
239 }
240
241 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
242 if !ctx.parse_errors.is_empty() {
244 return Vec::new();
245 }
246
247 let mut diags = Vec::new();
248 let mut occurrence_counters = std::collections::HashMap::new();
249
250 for stmt in &ctx.statements {
251 if let Statement::Query(query) = stmt {
252 check_query(
253 query,
254 &ctx.source,
255 &mut occurrence_counters,
256 self.name(),
257 &mut diags,
258 );
259 }
260 }
261
262 diags
263 }
264}