Skip to main content

sqrust_rules/convention/
if_null_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4    Statement, TableFactor,
5};
6
/// Lint rule that reports calls to vendor-specific null-handling functions
/// (`IFNULL`, `NVL`, `NVL2`, `ISNULL`) and recommends `COALESCE()` instead.
///
/// The rule carries no state; all of its behavior lives in its `Rule` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct IfNullFunction;
8
/// Vendor-specific null-handling function names for which this rule recommends `COALESCE`.
///
/// Entries are upper-case because they are compared against the upper-cased last
/// identifier of each function call's name.
const FLAGGED_FUNCS: &[&str] = &["IFNULL", "NVL", "NVL2", "ISNULL"];
11
12/// Returns the uppercase last-ident of a function's name, or empty string.
13fn func_name_upper(func: &sqlparser::ast::Function) -> String {
14    func.name
15        .0
16        .last()
17        .map(|ident| ident.value.to_uppercase())
18        .unwrap_or_default()
19}
20
/// Translates a byte `offset` into `source` to a 1-indexed `(line, col)` position.
///
/// The line is one more than the number of `\n` bytes that precede `offset`. The column
/// is the byte distance from the start of that line plus one, so it counts bytes rather
/// than characters.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or does not fall on a UTF-8
/// character boundary, because the prefix is obtained by slicing the string.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    // `\n` is a single byte in UTF-8 and never part of a multi-byte sequence.
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let col = match prefix.rfind('\n') {
        // Bytes after the last newline, plus one for 1-based numbering.
        Some(newline_at) => offset - newline_at,
        None => offset + 1,
    };
    (line, col)
}
28
29/// Find the Nth occurrence (0-indexed) of `name` as a whole word (case-insensitive)
30/// in `source`. Returns byte offset or 0 if not found.
31fn find_occurrence(source: &str, name: &str, occurrence: usize) -> usize {
32    let bytes = source.as_bytes();
33    let name_upper: Vec<u8> = name.bytes().map(|b| b.to_ascii_uppercase()).collect();
34    let name_len = name_upper.len();
35    let len = bytes.len();
36    let mut count = 0usize;
37    let mut i = 0;
38
39    while i + name_len <= len {
40        // Word boundary before.
41        let before_ok = i == 0
42            || {
43                let b = bytes[i - 1];
44                !b.is_ascii_alphanumeric() && b != b'_'
45            };
46
47        if before_ok {
48            let matches = bytes[i..i + name_len]
49                .iter()
50                .zip(name_upper.iter())
51                .all(|(&a, &b)| a.eq_ignore_ascii_case(&b));
52
53            if matches {
54                // Word boundary after.
55                let after = i + name_len;
56                let after_ok = after >= len
57                    || {
58                        let b = bytes[after];
59                        !b.is_ascii_alphanumeric() && b != b'_'
60                    };
61
62                if after_ok {
63                    if count == occurrence {
64                        return i;
65                    }
66                    count += 1;
67                }
68            }
69        }
70
71        i += 1;
72    }
73
74    0
75}
76
77/// Walk an expression, pushing diagnostics for any flagged null-handling function.
78/// `occurrence_counters` tracks how many times each function name has been seen,
79/// so that `find_occurrence` can locate the correct source position.
80fn walk_expr(
81    expr: &Expr,
82    source: &str,
83    occurrence_counters: &mut std::collections::HashMap<String, usize>,
84    rule: &'static str,
85    diags: &mut Vec<Diagnostic>,
86) {
87    match expr {
88        Expr::Function(func) => {
89            let upper = func_name_upper(func);
90            if FLAGGED_FUNCS.contains(&upper.as_str()) {
91                let count = occurrence_counters.entry(upper.clone()).or_insert(0);
92                let occ = *count;
93                *count += 1;
94
95                let offset = find_occurrence(source, &upper, occ);
96                let (line, col) = line_col(source, offset);
97                diags.push(Diagnostic {
98                    rule,
99                    message: format!(
100                        "IFNULL/NVL is vendor-specific; use COALESCE() for portability (found {})",
101                        upper
102                    ),
103                    line,
104                    col,
105                });
106            }
107
108            // Recurse into function arguments.
109            if let FunctionArguments::List(list) = &func.args {
110                for arg in &list.args {
111                    let inner_expr = match arg {
112                        FunctionArg::Named { arg, .. }
113                        | FunctionArg::Unnamed(arg)
114                        | FunctionArg::ExprNamed { arg, .. } => match arg {
115                            FunctionArgExpr::Expr(e) => Some(e),
116                            _ => None,
117                        },
118                    };
119                    if let Some(e) = inner_expr {
120                        walk_expr(e, source, occurrence_counters, rule, diags);
121                    }
122                }
123            }
124        }
125        Expr::BinaryOp { left, right, .. } => {
126            walk_expr(left, source, occurrence_counters, rule, diags);
127            walk_expr(right, source, occurrence_counters, rule, diags);
128        }
129        Expr::UnaryOp { expr: inner, .. } => {
130            walk_expr(inner, source, occurrence_counters, rule, diags);
131        }
132        Expr::Nested(inner) => {
133            walk_expr(inner, source, occurrence_counters, rule, diags);
134        }
135        Expr::Case {
136            operand,
137            conditions,
138            results,
139            else_result,
140        } => {
141            if let Some(op) = operand {
142                walk_expr(op, source, occurrence_counters, rule, diags);
143            }
144            for c in conditions {
145                walk_expr(c, source, occurrence_counters, rule, diags);
146            }
147            for r in results {
148                walk_expr(r, source, occurrence_counters, rule, diags);
149            }
150            if let Some(e) = else_result {
151                walk_expr(e, source, occurrence_counters, rule, diags);
152            }
153        }
154        _ => {}
155    }
156}
157
158fn check_select(
159    sel: &Select,
160    source: &str,
161    occurrence_counters: &mut std::collections::HashMap<String, usize>,
162    rule: &'static str,
163    diags: &mut Vec<Diagnostic>,
164) {
165    // Projection.
166    for item in &sel.projection {
167        match item {
168            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
169                walk_expr(e, source, occurrence_counters, rule, diags);
170            }
171            _ => {}
172        }
173    }
174    // WHERE clause.
175    if let Some(selection) = &sel.selection {
176        walk_expr(selection, source, occurrence_counters, rule, diags);
177    }
178    // HAVING clause.
179    if let Some(having) = &sel.having {
180        walk_expr(having, source, occurrence_counters, rule, diags);
181    }
182    // Recurse into subqueries in FROM.
183    for twj in &sel.from {
184        recurse_table_factor(&twj.relation, source, occurrence_counters, rule, diags);
185        for join in &twj.joins {
186            recurse_table_factor(&join.relation, source, occurrence_counters, rule, diags);
187        }
188    }
189}
190
191fn recurse_table_factor(
192    tf: &TableFactor,
193    source: &str,
194    occurrence_counters: &mut std::collections::HashMap<String, usize>,
195    rule: &'static str,
196    diags: &mut Vec<Diagnostic>,
197) {
198    if let TableFactor::Derived { subquery, .. } = tf {
199        check_query(subquery, source, occurrence_counters, rule, diags);
200    }
201}
202
203fn check_set_expr(
204    expr: &SetExpr,
205    source: &str,
206    occurrence_counters: &mut std::collections::HashMap<String, usize>,
207    rule: &'static str,
208    diags: &mut Vec<Diagnostic>,
209) {
210    match expr {
211        SetExpr::Select(sel) => check_select(sel, source, occurrence_counters, rule, diags),
212        SetExpr::Query(inner) => check_query(inner, source, occurrence_counters, rule, diags),
213        SetExpr::SetOperation { left, right, .. } => {
214            check_set_expr(left, source, occurrence_counters, rule, diags);
215            check_set_expr(right, source, occurrence_counters, rule, diags);
216        }
217        _ => {}
218    }
219}
220
221fn check_query(
222    query: &Query,
223    source: &str,
224    occurrence_counters: &mut std::collections::HashMap<String, usize>,
225    rule: &'static str,
226    diags: &mut Vec<Diagnostic>,
227) {
228    if let Some(with) = &query.with {
229        for cte in &with.cte_tables {
230            check_query(&cte.query, source, occurrence_counters, rule, diags);
231        }
232    }
233    check_set_expr(&query.body, source, occurrence_counters, rule, diags);
234}
235
236impl Rule for IfNullFunction {
237    fn name(&self) -> &'static str {
238        "Convention/IfNullFunction"
239    }
240
241    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
242        // AST-based — return empty if the file did not parse.
243        if !ctx.parse_errors.is_empty() {
244            return Vec::new();
245        }
246
247        let mut diags = Vec::new();
248        let mut occurrence_counters = std::collections::HashMap::new();
249
250        for stmt in &ctx.statements {
251            if let Statement::Query(query) = stmt {
252                check_query(
253                    query,
254                    &ctx.source,
255                    &mut occurrence_counters,
256                    self.name(),
257                    &mut diags,
258                );
259            }
260        }
261
262        diags
263    }
264}