// sqrust_rules/lint/null_in_not_in.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, Value};
3
4pub struct NullInNotIn;
5
6impl Rule for NullInNotIn {
7    fn name(&self) -> &'static str {
8        "Lint/NullInNotIn"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        // Skip files that failed to parse — AST may be incomplete.
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16
17        let mut diags = Vec::new();
18        // Track occurrence count so we can locate the Nth `NOT IN` phrase in source.
19        let mut occurrence: usize = 0;
20
21        for stmt in &ctx.statements {
22            check_statement(stmt, &mut diags, &ctx.source, &mut occurrence);
23        }
24
25        diags
26    }
27}
28
29// ── Statement-level walker ────────────────────────────────────────────────────
30
31fn check_statement(stmt: &Statement, diags: &mut Vec<Diagnostic>, source: &str, occ: &mut usize) {
32    match stmt {
33        Statement::Query(q) => check_query(q, diags, source, occ),
34        Statement::Insert(insert) => {
35            if let Some(src) = &insert.source {
36                check_query(src, diags, source, occ);
37            }
38        }
39        Statement::Update { selection, .. } => {
40            if let Some(expr) = selection {
41                check_expr(expr, diags, source, occ);
42            }
43        }
44        Statement::Delete(delete) => {
45            if let Some(expr) = &delete.selection {
46                check_expr(expr, diags, source, occ);
47            }
48        }
49        _ => {}
50    }
51}
52
53fn check_query(query: &Query, diags: &mut Vec<Diagnostic>, source: &str, occ: &mut usize) {
54    match query.body.as_ref() {
55        SetExpr::Select(select) => check_select(select, diags, source, occ),
56        SetExpr::Query(q) => check_query(q, diags, source, occ),
57        SetExpr::SetOperation { left, right, .. } => {
58            match left.as_ref() {
59                SetExpr::Select(s) => check_select(s, diags, source, occ),
60                SetExpr::Query(q) => check_query(q, diags, source, occ),
61                _ => {}
62            }
63            match right.as_ref() {
64                SetExpr::Select(s) => check_select(s, diags, source, occ),
65                SetExpr::Query(q) => check_query(q, diags, source, occ),
66                _ => {}
67            }
68        }
69        _ => {}
70    }
71}
72
73fn check_select(select: &Select, diags: &mut Vec<Diagnostic>, source: &str, occ: &mut usize) {
74    // Projection expressions
75    for item in &select.projection {
76        if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = item {
77            check_expr(expr, diags, source, occ);
78        }
79    }
80
81    // WHERE clause
82    if let Some(expr) = &select.selection {
83        check_expr(expr, diags, source, occ);
84    }
85
86    // HAVING clause
87    if let Some(expr) = &select.having {
88        check_expr(expr, diags, source, occ);
89    }
90}
91
92// ── Expression walker ─────────────────────────────────────────────────────────
93
94fn check_expr(expr: &Expr, diags: &mut Vec<Diagnostic>, source: &str, occ: &mut usize) {
95    match expr {
96        // NOT IN (...) — the case we want to flag when list contains NULL
97        Expr::InList {
98            expr: inner,
99            list,
100            negated: true,
101        } => {
102            let has_null = list.iter().any(|e| matches!(e, Expr::Value(Value::Null)));
103            if has_null {
104                // Find the Nth occurrence of "NOT IN" in source (case-insensitive).
105                let (line, col) = find_nth_phrase(source, "NOT IN", *occ);
106                *occ += 1;
107                diags.push(Diagnostic {
108                    rule: "Lint/NullInNotIn",
109                    message:
110                        "NOT IN list contains NULL; this will always produce an empty result set"
111                            .to_string(),
112                    line,
113                    col,
114                });
115            }
116            // Recurse into inner and list elements.
117            check_expr(inner, diags, source, occ);
118            for e in list {
119                check_expr(e, diags, source, occ);
120            }
121        }
122
123        // Positive IN (...) — not flagged, but recurse in case list has subqueries
124        Expr::InList {
125            expr: inner,
126            list,
127            negated: false,
128        } => {
129            check_expr(inner, diags, source, occ);
130            for e in list {
131                check_expr(e, diags, source, occ);
132            }
133        }
134
135        // Binary operators — recurse both sides
136        Expr::BinaryOp { left, right, .. } => {
137            check_expr(left, diags, source, occ);
138            check_expr(right, diags, source, occ);
139        }
140
141        // Parenthesised expression
142        Expr::Nested(inner) => check_expr(inner, diags, source, occ),
143
144        // Subquery (scalar)
145        Expr::Subquery(q) => check_query(q, diags, source, occ),
146
147        // [NOT] IN (SELECT ...) — recurse into the inner expression and the subquery
148        Expr::InSubquery {
149            expr: inner,
150            subquery,
151            ..
152        } => {
153            check_expr(inner, diags, source, occ);
154            check_query(subquery, diags, source, occ);
155        }
156
157        // EXISTS (SELECT ...) — recurse into the subquery
158        Expr::Exists { subquery, .. } => check_query(subquery, diags, source, occ),
159
160        // Function calls — check arguments
161        Expr::Function(f) => {
162            use sqlparser::ast::FunctionArguments;
163            if let FunctionArguments::List(arg_list) = &f.args {
164                for arg in &arg_list.args {
165                    if let sqlparser::ast::FunctionArg::Unnamed(
166                        sqlparser::ast::FunctionArgExpr::Expr(e),
167                    ) = arg
168                    {
169                        check_expr(e, diags, source, occ);
170                    }
171                }
172            }
173        }
174
175        // CASE WHEN expressions
176        Expr::Case {
177            operand,
178            conditions,
179            results,
180            else_result,
181        } => {
182            if let Some(op) = operand {
183                check_expr(op, diags, source, occ);
184            }
185            for cond in conditions {
186                check_expr(cond, diags, source, occ);
187            }
188            for res in results {
189                check_expr(res, diags, source, occ);
190            }
191            if let Some(el) = else_result {
192                check_expr(el, diags, source, occ);
193            }
194        }
195
196        // Unary operators
197        Expr::UnaryOp { expr: inner, .. } => check_expr(inner, diags, source, occ),
198
199        // IS NULL / IS NOT NULL
200        Expr::IsNull(inner) | Expr::IsNotNull(inner) => check_expr(inner, diags, source, occ),
201
202        // BETWEEN
203        Expr::Between {
204            expr: inner,
205            low,
206            high,
207            ..
208        } => {
209            check_expr(inner, diags, source, occ);
210            check_expr(low, diags, source, occ);
211            check_expr(high, diags, source, occ);
212        }
213
214        // LIKE, ILIKE
215        Expr::Like {
216            expr: inner,
217            pattern,
218            ..
219        }
220        | Expr::ILike {
221            expr: inner,
222            pattern,
223            ..
224        } => {
225            check_expr(inner, diags, source, occ);
226            check_expr(pattern, diags, source, occ);
227        }
228
229        // Everything else (literals, identifiers, etc.) — nothing to recurse into
230        _ => {}
231    }
232}
233
234// ── Source-text helpers ───────────────────────────────────────────────────────
235
236/// Finds the (line, col) of the `nth` (0-indexed) whole-word, case-insensitive
237/// occurrence of `phrase` in `source`. Returns (1, 1) if not found.
238fn find_nth_phrase(source: &str, phrase: &str, nth: usize) -> (usize, usize) {
239    let phrase_upper = phrase.to_uppercase();
240    let source_upper = source.to_uppercase();
241    let phrase_bytes = phrase_upper.as_bytes();
242    let src_bytes = source_upper.as_bytes();
243    let phrase_len = phrase_bytes.len();
244    let src_len = src_bytes.len();
245
246    let mut count = 0usize;
247    let mut i = 0usize;
248
249    while i + phrase_len <= src_len {
250        // Check if source_upper[i..i+phrase_len] == phrase_upper
251        if src_bytes[i..i + phrase_len] == *phrase_bytes {
252            // Word boundary before
253            let before_ok = i == 0 || {
254                let b = src_bytes[i - 1];
255                !b.is_ascii_alphanumeric() && b != b'_'
256            };
257            // Word boundary after
258            let after = i + phrase_len;
259            let after_ok = after >= src_len || {
260                let b = src_bytes[after];
261                !b.is_ascii_alphanumeric() && b != b'_'
262            };
263
264            if before_ok && after_ok {
265                if count == nth {
266                    return offset_to_line_col(source, i);
267                }
268                count += 1;
269            }
270        }
271        i += 1;
272    }
273
274    (1, 1)
275}
276
/// Converts a byte offset in `source` to a 1-indexed (line, col) pair.
///
/// `col` is counted in bytes from the start of the line. Out-of-range offsets
/// are clamped to `source.len()`. An offset that falls inside a multi-byte UTF-8
/// character is moved back to that character's first byte, so this never panics
/// on slicing.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut offset = offset.min(source.len());
    while !source.is_char_boundary(offset) {
        offset -= 1;
    }
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |p| p + 1);
    (line, offset - line_start + 1)
}