1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
4 Statement, TableFactor, Value,
5};
6
7pub struct ConcatFunctionNullArg;
8
9impl Rule for ConcatFunctionNullArg {
10 fn name(&self) -> &'static str {
11 "Ambiguous/ConcatFunctionNullArg"
12 }
13
14 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
15 if !ctx.parse_errors.is_empty() {
16 return Vec::new();
17 }
18
19 let mut diags = Vec::new();
20 for stmt in &ctx.statements {
21 collect_from_statement(stmt, ctx, &mut diags);
22 }
23 diags
24 }
25}
26
27fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
28 if let Statement::Query(query) = stmt {
29 collect_from_query(query, ctx, diags);
30 }
31}
32
33fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
34 if let Some(with) = &query.with {
35 for cte in &with.cte_tables {
36 collect_from_query(&cte.query, ctx, diags);
37 }
38 }
39 collect_from_set_expr(&query.body, ctx, diags);
40}
41
42fn collect_from_set_expr(set_expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
43 match set_expr {
44 SetExpr::Select(select) => {
45 collect_from_select(select, ctx, diags);
46 }
47 SetExpr::Query(inner) => {
48 collect_from_query(inner, ctx, diags);
49 }
50 SetExpr::SetOperation { left, right, .. } => {
51 collect_from_set_expr(left, ctx, diags);
52 collect_from_set_expr(right, ctx, diags);
53 }
54 _ => {}
55 }
56}
57
58fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
59 for item in &select.projection {
60 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
61 check_expr(e, ctx, diags);
62 }
63 }
64
65 for twj in &select.from {
66 collect_from_table_factor(&twj.relation, ctx, diags);
67 for join in &twj.joins {
68 collect_from_table_factor(&join.relation, ctx, diags);
69 }
70 }
71
72 if let Some(selection) = &select.selection {
73 check_expr(selection, ctx, diags);
74 }
75
76 if let Some(having) = &select.having {
77 check_expr(having, ctx, diags);
78 }
79}
80
81fn collect_from_table_factor(
82 factor: &TableFactor,
83 ctx: &FileContext,
84 diags: &mut Vec<Diagnostic>,
85) {
86 if let TableFactor::Derived { subquery, .. } = factor {
87 collect_from_query(subquery, ctx, diags);
88 }
89}
90
91fn is_null_literal(expr: &Expr) -> bool {
93 matches!(expr, Expr::Value(Value::Null))
94}
95
96fn is_concat_with_null(expr: &Expr) -> bool {
99 let Expr::Function(func) = expr else {
100 return false;
101 };
102
103 let func_name = func
104 .name
105 .0
106 .last()
107 .map(|ident| ident.value.to_uppercase())
108 .unwrap_or_default();
109
110 if func_name != "CONCAT" {
112 return false;
113 }
114
115 let FunctionArguments::List(arg_list) = &func.args else {
116 return false;
117 };
118
119 arg_list.args.iter().any(|arg| {
120 let expr_arg = match arg {
121 FunctionArg::Named { arg, .. }
122 | FunctionArg::ExprNamed { arg, .. }
123 | FunctionArg::Unnamed(arg) => arg,
124 };
125 if let FunctionArgExpr::Expr(e) = expr_arg {
126 is_null_literal(e)
127 } else {
128 false
129 }
130 })
131}
132
133fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
136 if is_concat_with_null(expr) {
137 let (line, col) = find_keyword_pos(&ctx.source, "CONCAT");
138 diags.push(Diagnostic {
139 rule: "Ambiguous/ConcatFunctionNullArg",
140 message: "CONCAT() with a NULL argument always returns NULL — use COALESCE to provide a fallback value".to_string(),
141 line,
142 col,
143 });
144 if let Expr::Function(func) = expr {
146 if let FunctionArguments::List(arg_list) = &func.args {
147 for arg in &arg_list.args {
148 let expr_arg = match arg {
149 FunctionArg::Named { arg, .. }
150 | FunctionArg::ExprNamed { arg, .. }
151 | FunctionArg::Unnamed(arg) => arg,
152 };
153 if let FunctionArgExpr::Expr(e) = expr_arg {
154 check_expr(e, ctx, diags);
155 }
156 }
157 }
158 }
159 return;
160 }
161
162 match expr {
163 Expr::Function(func) => {
164 if let FunctionArguments::List(arg_list) = &func.args {
165 for arg in &arg_list.args {
166 let expr_arg = match arg {
167 FunctionArg::Named { arg, .. }
168 | FunctionArg::ExprNamed { arg, .. }
169 | FunctionArg::Unnamed(arg) => arg,
170 };
171 if let FunctionArgExpr::Expr(e) = expr_arg {
172 check_expr(e, ctx, diags);
173 }
174 }
175 }
176 }
177 Expr::BinaryOp { left, right, .. } => {
178 check_expr(left, ctx, diags);
179 check_expr(right, ctx, diags);
180 }
181 Expr::UnaryOp { expr: inner, .. } => {
182 check_expr(inner, ctx, diags);
183 }
184 Expr::Nested(inner) => {
185 check_expr(inner, ctx, diags);
186 }
187 Expr::Case {
188 operand,
189 conditions,
190 results,
191 else_result,
192 } => {
193 if let Some(op) = operand {
194 check_expr(op, ctx, diags);
195 }
196 for cond in conditions {
197 check_expr(cond, ctx, diags);
198 }
199 for result in results {
200 check_expr(result, ctx, diags);
201 }
202 if let Some(else_e) = else_result {
203 check_expr(else_e, ctx, diags);
204 }
205 }
206 Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
207 check_expr(inner, ctx, diags);
208 }
209 Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
210 collect_from_query(q, ctx, diags);
211 }
212 _ => {}
213 }
214}
215
216fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
219 let upper = source.to_uppercase();
220 let kw_upper = keyword.to_uppercase();
221 let bytes = upper.as_bytes();
222 let kw_bytes = kw_upper.as_bytes();
223 let kw_len = kw_bytes.len();
224
225 let mut i = 0;
226 while i + kw_len <= bytes.len() {
227 if bytes[i..i + kw_len] == *kw_bytes {
228 let before_ok =
229 i == 0 || (!bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_');
230 let after = i + kw_len;
231 let after_ok = after >= bytes.len()
232 || (!bytes[after].is_ascii_alphanumeric() && bytes[after] != b'_');
233 if before_ok && after_ok {
234 return offset_to_line_col(source, i);
235 }
236 }
237 i += 1;
238 }
239 (1, 1)
240}
241
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// - Lines are delimited by `\n`.
/// - The column is the byte distance from the start of the line, plus one.
/// - Panics if `offset` is past the end of `source` or not on a char boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_start = match prefix.rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    };
    let line = 1 + prefix.matches('\n').count();
    let col = 1 + offset - line_start;
    (line, col)
}
248}