1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, Value};
3
4pub struct NullInNotIn;
5
6impl Rule for NullInNotIn {
7 fn name(&self) -> &'static str {
8 "Lint/NullInNotIn"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16
17 let mut diags = Vec::new();
18 let mut occurrence: usize = 0;
20
21 for stmt in &ctx.statements {
22 check_statement(stmt, &mut diags, &ctx.source, &mut occurrence);
23 }
24
25 diags
26 }
27}
28
29fn check_statement(stmt: &Statement, diags: &mut Vec<Diagnostic>, source: &str, occ: &mut usize) {
32 match stmt {
33 Statement::Query(q) => check_query(q, diags, source, occ),
34 Statement::Insert(insert) => {
35 if let Some(src) = &insert.source {
36 check_query(src, diags, source, occ);
37 }
38 }
39 Statement::Update { selection, .. } => {
40 if let Some(expr) = selection {
41 check_expr(expr, diags, source, occ);
42 }
43 }
44 Statement::Delete(delete) => {
45 if let Some(expr) = &delete.selection {
46 check_expr(expr, diags, source, occ);
47 }
48 }
49 _ => {}
50 }
51}
52
53fn check_query(query: &Query, diags: &mut Vec<Diagnostic>, source: &str, occ: &mut usize) {
54 match query.body.as_ref() {
55 SetExpr::Select(select) => check_select(select, diags, source, occ),
56 SetExpr::Query(q) => check_query(q, diags, source, occ),
57 SetExpr::SetOperation { left, right, .. } => {
58 match left.as_ref() {
59 SetExpr::Select(s) => check_select(s, diags, source, occ),
60 SetExpr::Query(q) => check_query(q, diags, source, occ),
61 _ => {}
62 }
63 match right.as_ref() {
64 SetExpr::Select(s) => check_select(s, diags, source, occ),
65 SetExpr::Query(q) => check_query(q, diags, source, occ),
66 _ => {}
67 }
68 }
69 _ => {}
70 }
71}
72
73fn check_select(select: &Select, diags: &mut Vec<Diagnostic>, source: &str, occ: &mut usize) {
74 for item in &select.projection {
76 if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = item {
77 check_expr(expr, diags, source, occ);
78 }
79 }
80
81 if let Some(expr) = &select.selection {
83 check_expr(expr, diags, source, occ);
84 }
85
86 if let Some(expr) = &select.having {
88 check_expr(expr, diags, source, occ);
89 }
90}
91
92fn check_expr(expr: &Expr, diags: &mut Vec<Diagnostic>, source: &str, occ: &mut usize) {
95 match expr {
96 Expr::InList {
98 expr: inner,
99 list,
100 negated: true,
101 } => {
102 let has_null = list.iter().any(|e| matches!(e, Expr::Value(Value::Null)));
103 if has_null {
104 let (line, col) = find_nth_phrase(source, "NOT IN", *occ);
106 *occ += 1;
107 diags.push(Diagnostic {
108 rule: "Lint/NullInNotIn",
109 message:
110 "NOT IN list contains NULL; this will always produce an empty result set"
111 .to_string(),
112 line,
113 col,
114 });
115 }
116 check_expr(inner, diags, source, occ);
118 for e in list {
119 check_expr(e, diags, source, occ);
120 }
121 }
122
123 Expr::InList {
125 expr: inner,
126 list,
127 negated: false,
128 } => {
129 check_expr(inner, diags, source, occ);
130 for e in list {
131 check_expr(e, diags, source, occ);
132 }
133 }
134
135 Expr::BinaryOp { left, right, .. } => {
137 check_expr(left, diags, source, occ);
138 check_expr(right, diags, source, occ);
139 }
140
141 Expr::Nested(inner) => check_expr(inner, diags, source, occ),
143
144 Expr::Subquery(q) => check_query(q, diags, source, occ),
146
147 Expr::InSubquery {
149 expr: inner,
150 subquery,
151 ..
152 } => {
153 check_expr(inner, diags, source, occ);
154 check_query(subquery, diags, source, occ);
155 }
156
157 Expr::Exists { subquery, .. } => check_query(subquery, diags, source, occ),
159
160 Expr::Function(f) => {
162 use sqlparser::ast::FunctionArguments;
163 if let FunctionArguments::List(arg_list) = &f.args {
164 for arg in &arg_list.args {
165 if let sqlparser::ast::FunctionArg::Unnamed(
166 sqlparser::ast::FunctionArgExpr::Expr(e),
167 ) = arg
168 {
169 check_expr(e, diags, source, occ);
170 }
171 }
172 }
173 }
174
175 Expr::Case {
177 operand,
178 conditions,
179 results,
180 else_result,
181 } => {
182 if let Some(op) = operand {
183 check_expr(op, diags, source, occ);
184 }
185 for cond in conditions {
186 check_expr(cond, diags, source, occ);
187 }
188 for res in results {
189 check_expr(res, diags, source, occ);
190 }
191 if let Some(el) = else_result {
192 check_expr(el, diags, source, occ);
193 }
194 }
195
196 Expr::UnaryOp { expr: inner, .. } => check_expr(inner, diags, source, occ),
198
199 Expr::IsNull(inner) | Expr::IsNotNull(inner) => check_expr(inner, diags, source, occ),
201
202 Expr::Between {
204 expr: inner,
205 low,
206 high,
207 ..
208 } => {
209 check_expr(inner, diags, source, occ);
210 check_expr(low, diags, source, occ);
211 check_expr(high, diags, source, occ);
212 }
213
214 Expr::Like {
216 expr: inner,
217 pattern,
218 ..
219 }
220 | Expr::ILike {
221 expr: inner,
222 pattern,
223 ..
224 } => {
225 check_expr(inner, diags, source, occ);
226 check_expr(pattern, diags, source, occ);
227 }
228
229 _ => {}
231 }
232}
233
234fn find_nth_phrase(source: &str, phrase: &str, nth: usize) -> (usize, usize) {
239 let phrase_upper = phrase.to_uppercase();
240 let source_upper = source.to_uppercase();
241 let phrase_bytes = phrase_upper.as_bytes();
242 let src_bytes = source_upper.as_bytes();
243 let phrase_len = phrase_bytes.len();
244 let src_len = src_bytes.len();
245
246 let mut count = 0usize;
247 let mut i = 0usize;
248
249 while i + phrase_len <= src_len {
250 if src_bytes[i..i + phrase_len] == *phrase_bytes {
252 let before_ok = i == 0 || {
254 let b = src_bytes[i - 1];
255 !b.is_ascii_alphanumeric() && b != b'_'
256 };
257 let after = i + phrase_len;
259 let after_ok = after >= src_len || {
260 let b = src_bytes[after];
261 !b.is_ascii_alphanumeric() && b != b'_'
262 };
263
264 if before_ok && after_ok {
265 if count == nth {
266 return offset_to_line_col(source, i);
267 }
268 count += 1;
269 }
270 }
271 i += 1;
272 }
273
274 (1, 1)
275}
276
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The line is one more than the number of `'\n'` bytes before `offset`. The column is the
/// byte distance from the start of that line, plus one.
///
/// Panics if `offset` is past the end of `source` or not on a UTF-8 character boundary,
/// because it slices `source[..offset]`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    match prefix.rfind('\n') {
        Some(newline) => (prefix.matches('\n').count() + 1, offset - newline),
        None => (1, offset + 1),
    }
}