Skip to main content

sqrust_rules/convention/
no_if_function.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4    Statement, TableFactor,
5};
6
/// Lint rule `Convention/NoIFFunction`: reports calls to a function named `IF`
/// found in query statements and suggests rewriting them as a `CASE` expression.
pub struct NoIFFunction;
8
/// Diagnostic text attached to every `IF()` call this rule reports.
const MESSAGE: &str = concat!(
    "IF() is dialect-specific (MySQL/BigQuery) ",
    "— use CASE WHEN ... THEN ... ELSE ... END for portable conditional logic"
);
12
13/// Returns the lowercase last-ident of a function's name, or empty string.
14fn func_name_lower(func: &sqlparser::ast::Function) -> String {
15    func.name
16        .0
17        .last()
18        .map(|ident| ident.value.to_lowercase())
19        .unwrap_or_default()
20}
21
/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
///
/// The column counts characters (Unicode scalar values) from the start of the
/// line, not bytes, so multi-byte characters earlier on the same line do not
/// inflate it. For pure-ASCII lines this equals the byte-based column.
///
/// # Panics
///
/// Panics if `offset` is greater than `source.len()` or does not fall on a
/// `char` boundary (it is used to slice `source`).
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    // Byte index where the line containing `offset` begins.
    let line_start = before.rfind('\n').map_or(0, |nl| nl + 1);
    let col = before[line_start..].chars().count() + 1;
    (line, col)
}
29
/// Finds the byte offset of the `occurrence`-th (0-indexed) `IF` function call
/// in `source`, matching the keyword case-insensitively.
///
/// A match needs a word boundary before `IF` (so `IFNULL(` or `NOTIF(` are not
/// counted) and an opening `(` after it. ASCII whitespace may separate the two,
/// because a parser that skips whitespace between tokens also treats
/// `IF (a, b, c)` as a call.
///
/// The following are skipped, so an `if(` mentioned there cannot shift the
/// ordinal of real calls:
/// - text inside single-quoted strings, and double-quoted or backtick-quoted
///   identifiers;
/// - `--` line comments;
/// - `/* ... */` block comments.
///
/// A doubled quote (`''`) works out as two adjacent quoted runs. Backslash
/// escapes inside quotes are not interpreted.
///
/// Returns `0` when fewer than `occurrence + 1` calls are found.
fn find_if_occurrence(source: &str, occurrence: usize) -> usize {
    let bytes = source.as_bytes();
    let len = bytes.len();
    let mut count = 0usize;
    let mut i = 0;

    while i < len {
        match bytes[i] {
            quote @ (b'\'' | b'"' | b'`') => {
                // Jump past the closing quote (or past the end if unterminated).
                i += 1;
                while i < len && bytes[i] != quote {
                    i += 1;
                }
                i += 1;
                continue;
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Line comment: resume scanning at the newline.
                while i < len && bytes[i] != b'\n' {
                    i += 1;
                }
                continue;
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Block comment: resume after the closing `*/` (or at the end).
                i += 2;
                while i + 1 < len && !(bytes[i] == b'*' && bytes[i + 1] == b'/') {
                    i += 1;
                }
                i += 2;
                continue;
            }
            _ => {}
        }

        let at_word_start = i == 0 || {
            let prev = bytes[i - 1];
            !prev.is_ascii_alphanumeric() && prev != b'_'
        };
        let is_if = i + 1 < len
            && bytes[i].eq_ignore_ascii_case(&b'i')
            && bytes[i + 1].eq_ignore_ascii_case(&b'f');

        if at_word_start && is_if {
            // The next non-whitespace byte must open the argument list; this
            // also guarantees the word ends right after `IF`.
            let mut next = i + 2;
            while next < len && bytes[next].is_ascii_whitespace() {
                next += 1;
            }
            if next < len && bytes[next] == b'(' {
                if count == occurrence {
                    return i;
                }
                count += 1;
            }
        }

        i += 1;
    }

    0
}
69
70fn walk_expr(
71    expr: &Expr,
72    source: &str,
73    counter: &mut usize,
74    rule: &'static str,
75    diags: &mut Vec<Diagnostic>,
76) {
77    match expr {
78        Expr::Function(func) => {
79            let lower = func_name_lower(func);
80            if lower == "if" {
81                let occ = *counter;
82                *counter += 1;
83
84                let offset = find_if_occurrence(source, occ);
85                let (line, col) = line_col(source, offset);
86                diags.push(Diagnostic {
87                    rule,
88                    message: MESSAGE.to_string(),
89                    line,
90                    col,
91                });
92            }
93
94            // Recurse into function arguments.
95            if let FunctionArguments::List(list) = &func.args {
96                for arg in &list.args {
97                    let inner_expr = match arg {
98                        FunctionArg::Named { arg, .. }
99                        | FunctionArg::Unnamed(arg)
100                        | FunctionArg::ExprNamed { arg, .. } => match arg {
101                            FunctionArgExpr::Expr(e) => Some(e),
102                            _ => None,
103                        },
104                    };
105                    if let Some(e) = inner_expr {
106                        walk_expr(e, source, counter, rule, diags);
107                    }
108                }
109            }
110        }
111        Expr::BinaryOp { left, right, .. } => {
112            walk_expr(left, source, counter, rule, diags);
113            walk_expr(right, source, counter, rule, diags);
114        }
115        Expr::UnaryOp { expr: inner, .. } => {
116            walk_expr(inner, source, counter, rule, diags);
117        }
118        Expr::Nested(inner) => {
119            walk_expr(inner, source, counter, rule, diags);
120        }
121        Expr::Case {
122            operand,
123            conditions,
124            results,
125            else_result,
126        } => {
127            if let Some(op) = operand {
128                walk_expr(op, source, counter, rule, diags);
129            }
130            for c in conditions {
131                walk_expr(c, source, counter, rule, diags);
132            }
133            for r in results {
134                walk_expr(r, source, counter, rule, diags);
135            }
136            if let Some(e) = else_result {
137                walk_expr(e, source, counter, rule, diags);
138            }
139        }
140        _ => {}
141    }
142}
143
144fn check_select(
145    sel: &Select,
146    source: &str,
147    counter: &mut usize,
148    rule: &'static str,
149    diags: &mut Vec<Diagnostic>,
150) {
151    for item in &sel.projection {
152        match item {
153            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
154                walk_expr(e, source, counter, rule, diags);
155            }
156            _ => {}
157        }
158    }
159    if let Some(selection) = &sel.selection {
160        walk_expr(selection, source, counter, rule, diags);
161    }
162    if let Some(having) = &sel.having {
163        walk_expr(having, source, counter, rule, diags);
164    }
165    for twj in &sel.from {
166        recurse_table_factor(&twj.relation, source, counter, rule, diags);
167        for join in &twj.joins {
168            recurse_table_factor(&join.relation, source, counter, rule, diags);
169        }
170    }
171}
172
173fn recurse_table_factor(
174    tf: &TableFactor,
175    source: &str,
176    counter: &mut usize,
177    rule: &'static str,
178    diags: &mut Vec<Diagnostic>,
179) {
180    if let TableFactor::Derived { subquery, .. } = tf {
181        check_query(subquery, source, counter, rule, diags);
182    }
183}
184
185fn check_set_expr(
186    expr: &SetExpr,
187    source: &str,
188    counter: &mut usize,
189    rule: &'static str,
190    diags: &mut Vec<Diagnostic>,
191) {
192    match expr {
193        SetExpr::Select(sel) => check_select(sel, source, counter, rule, diags),
194        SetExpr::Query(inner) => check_query(inner, source, counter, rule, diags),
195        SetExpr::SetOperation { left, right, .. } => {
196            check_set_expr(left, source, counter, rule, diags);
197            check_set_expr(right, source, counter, rule, diags);
198        }
199        _ => {}
200    }
201}
202
203fn check_query(
204    query: &Query,
205    source: &str,
206    counter: &mut usize,
207    rule: &'static str,
208    diags: &mut Vec<Diagnostic>,
209) {
210    if let Some(with) = &query.with {
211        for cte in &with.cte_tables {
212            check_query(&cte.query, source, counter, rule, diags);
213        }
214    }
215    check_set_expr(&query.body, source, counter, rule, diags);
216}
217
218impl Rule for NoIFFunction {
219    fn name(&self) -> &'static str {
220        "Convention/NoIFFunction"
221    }
222
223    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
224        if !ctx.parse_errors.is_empty() {
225            return Vec::new();
226        }
227
228        let mut diags = Vec::new();
229        let mut counter = 0usize;
230
231        for stmt in &ctx.statements {
232            if let Statement::Query(query) = stmt {
233                check_query(query, &ctx.source, &mut counter, self.name(), &mut diags);
234            }
235        }
236
237        diags
238    }
239}