Skip to main content

sqrust_rules/ambiguous/
concat_function_null_arg.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4    Statement, TableFactor, Value,
5};
6
7pub struct ConcatFunctionNullArg;
8
9impl Rule for ConcatFunctionNullArg {
10    fn name(&self) -> &'static str {
11        "Ambiguous/ConcatFunctionNullArg"
12    }
13
14    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
15        if !ctx.parse_errors.is_empty() {
16            return Vec::new();
17        }
18
19        let mut diags = Vec::new();
20        for stmt in &ctx.statements {
21            collect_from_statement(stmt, ctx, &mut diags);
22        }
23        diags
24    }
25}
26
27fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
28    if let Statement::Query(query) = stmt {
29        collect_from_query(query, ctx, diags);
30    }
31}
32
33fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
34    if let Some(with) = &query.with {
35        for cte in &with.cte_tables {
36            collect_from_query(&cte.query, ctx, diags);
37        }
38    }
39    collect_from_set_expr(&query.body, ctx, diags);
40}
41
42fn collect_from_set_expr(set_expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
43    match set_expr {
44        SetExpr::Select(select) => {
45            collect_from_select(select, ctx, diags);
46        }
47        SetExpr::Query(inner) => {
48            collect_from_query(inner, ctx, diags);
49        }
50        SetExpr::SetOperation { left, right, .. } => {
51            collect_from_set_expr(left, ctx, diags);
52            collect_from_set_expr(right, ctx, diags);
53        }
54        _ => {}
55    }
56}
57
58fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
59    for item in &select.projection {
60        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
61            check_expr(e, ctx, diags);
62        }
63    }
64
65    for twj in &select.from {
66        collect_from_table_factor(&twj.relation, ctx, diags);
67        for join in &twj.joins {
68            collect_from_table_factor(&join.relation, ctx, diags);
69        }
70    }
71
72    if let Some(selection) = &select.selection {
73        check_expr(selection, ctx, diags);
74    }
75
76    if let Some(having) = &select.having {
77        check_expr(having, ctx, diags);
78    }
79}
80
81fn collect_from_table_factor(
82    factor: &TableFactor,
83    ctx: &FileContext,
84    diags: &mut Vec<Diagnostic>,
85) {
86    if let TableFactor::Derived { subquery, .. } = factor {
87        collect_from_query(subquery, ctx, diags);
88    }
89}
90
91/// Returns `true` if the expression is a literal NULL value.
92fn is_null_literal(expr: &Expr) -> bool {
93    matches!(expr, Expr::Value(Value::Null))
94}
95
96/// Returns `true` if `expr` is a CONCAT(…) call (not CONCAT_WS) that
97/// contains at least one NULL literal argument.
98fn is_concat_with_null(expr: &Expr) -> bool {
99    let Expr::Function(func) = expr else {
100        return false;
101    };
102
103    let func_name = func
104        .name
105        .0
106        .last()
107        .map(|ident| ident.value.to_uppercase())
108        .unwrap_or_default();
109
110    // Only flag CONCAT, not CONCAT_WS or any other function.
111    if func_name != "CONCAT" {
112        return false;
113    }
114
115    let FunctionArguments::List(arg_list) = &func.args else {
116        return false;
117    };
118
119    arg_list.args.iter().any(|arg| {
120        let expr_arg = match arg {
121            FunctionArg::Named { arg, .. }
122            | FunctionArg::ExprNamed { arg, .. }
123            | FunctionArg::Unnamed(arg) => arg,
124        };
125        if let FunctionArgExpr::Expr(e) = expr_arg {
126            is_null_literal(e)
127        } else {
128            false
129        }
130    })
131}
132
133/// Recursively walks an expression, flagging every CONCAT(…) call that has a
134/// NULL literal argument and recursing into nested expressions.
135fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
136    if is_concat_with_null(expr) {
137        let (line, col) = find_keyword_pos(&ctx.source, "CONCAT");
138        diags.push(Diagnostic {
139            rule: "Ambiguous/ConcatFunctionNullArg",
140            message: "CONCAT() with a NULL argument always returns NULL — use COALESCE to provide a fallback value".to_string(),
141            line,
142            col,
143        });
144        // Still recurse into arguments in case there are nested CONCAT calls.
145        if let Expr::Function(func) = expr {
146            if let FunctionArguments::List(arg_list) = &func.args {
147                for arg in &arg_list.args {
148                    let expr_arg = match arg {
149                        FunctionArg::Named { arg, .. }
150                        | FunctionArg::ExprNamed { arg, .. }
151                        | FunctionArg::Unnamed(arg) => arg,
152                    };
153                    if let FunctionArgExpr::Expr(e) = expr_arg {
154                        check_expr(e, ctx, diags);
155                    }
156                }
157            }
158        }
159        return;
160    }
161
162    match expr {
163        Expr::Function(func) => {
164            if let FunctionArguments::List(arg_list) = &func.args {
165                for arg in &arg_list.args {
166                    let expr_arg = match arg {
167                        FunctionArg::Named { arg, .. }
168                        | FunctionArg::ExprNamed { arg, .. }
169                        | FunctionArg::Unnamed(arg) => arg,
170                    };
171                    if let FunctionArgExpr::Expr(e) = expr_arg {
172                        check_expr(e, ctx, diags);
173                    }
174                }
175            }
176        }
177        Expr::BinaryOp { left, right, .. } => {
178            check_expr(left, ctx, diags);
179            check_expr(right, ctx, diags);
180        }
181        Expr::UnaryOp { expr: inner, .. } => {
182            check_expr(inner, ctx, diags);
183        }
184        Expr::Nested(inner) => {
185            check_expr(inner, ctx, diags);
186        }
187        Expr::Case {
188            operand,
189            conditions,
190            results,
191            else_result,
192        } => {
193            if let Some(op) = operand {
194                check_expr(op, ctx, diags);
195            }
196            for cond in conditions {
197                check_expr(cond, ctx, diags);
198            }
199            for result in results {
200                check_expr(result, ctx, diags);
201            }
202            if let Some(else_e) = else_result {
203                check_expr(else_e, ctx, diags);
204            }
205        }
206        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
207            check_expr(inner, ctx, diags);
208        }
209        Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
210            collect_from_query(q, ctx, diags);
211        }
212        _ => {}
213    }
214}
215
/// Finds the first occurrence of `keyword` in `source` that stands alone as a
/// word and returns its 1-indexed (line, col). Falls back to (1, 1) when no
/// such occurrence exists.
///
/// Matching is ASCII case-insensitive and is done directly on the bytes of
/// `source`. Searching an upper-cased copy instead is unsafe: Unicode
/// upper-casing can change byte lengths (e.g. `ı` becomes `I`), so an offset
/// found in the copy may point at the wrong place in `source`, or into the
/// middle of a multi-byte character, which panics when the source is sliced.
///
/// An occurrence counts only if the bytes directly before and after it are
/// neither ASCII alphanumerics nor `_` (so `CONCAT_WS` does not match
/// `CONCAT`).
fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
    let bytes = source.as_bytes();
    let kw_bytes = keyword.as_bytes();

    // A keyword longer than the source cannot occur in it.
    let Some(last_start) = bytes.len().checked_sub(kw_bytes.len()) else {
        return (1, 1);
    };

    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    (0..=last_start)
        .find(|&start| {
            let end = start + kw_bytes.len();
            // The boundary check keeps the later slice in `offset_to_line_col`
            // valid even for degenerate keywords (e.g. the empty string).
            source.is_char_boundary(start)
                && bytes[start..end].eq_ignore_ascii_case(kw_bytes)
                && (start == 0 || !is_word_byte(bytes[start - 1]))
                && (end == bytes.len() || !is_word_byte(bytes[end]))
        })
        .map_or((1, 1), |start| offset_to_line_col(source, start))
}

/// Converts a byte offset to 1-indexed (line, col).
///
/// The column is counted in bytes from the start of the line containing
/// `offset`.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or does not fall on a
/// character boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.matches('\n').count() + 1;
    // Byte index at which the current line begins.
    let line_start = before.rfind('\n').map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}