Skip to main content

sqrust_rules/convention/
in_single_value.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
6pub struct InSingleValue;
7
8impl Rule for InSingleValue {
9    fn name(&self) -> &'static str {
10        "InSingleValue"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        if !ctx.parse_errors.is_empty() {
15            return Vec::new();
16        }
17
18        // Collect all byte offsets of `IN` keywords (word-boundary, outside
19        // strings/comments, case-insensitive) in source order.
20        let in_offsets = collect_in_offsets(&ctx.source);
21        let mut in_index: usize = 0;
22        let mut diags = Vec::new();
23
24        for stmt in &ctx.statements {
25            if let Statement::Query(query) = stmt {
26                check_query(
27                    query,
28                    self.name(),
29                    &ctx.source,
30                    &in_offsets,
31                    &mut in_index,
32                    &mut diags,
33                );
34            }
35        }
36
37        diags
38    }
39}
40
41// ── AST walking ───────────────────────────────────────────────────────────────
42
43fn check_query(
44    query: &Query,
45    rule: &'static str,
46    source: &str,
47    offsets: &[usize],
48    idx: &mut usize,
49    diags: &mut Vec<Diagnostic>,
50) {
51    if let Some(with) = &query.with {
52        for cte in &with.cte_tables {
53            check_query(&cte.query, rule, source, offsets, idx, diags);
54        }
55    }
56    check_set_expr(&query.body, rule, source, offsets, idx, diags);
57}
58
59fn check_set_expr(
60    body: &SetExpr,
61    rule: &'static str,
62    source: &str,
63    offsets: &[usize],
64    idx: &mut usize,
65    diags: &mut Vec<Diagnostic>,
66) {
67    match body {
68        SetExpr::Select(sel) => check_select(sel, rule, source, offsets, idx, diags),
69        SetExpr::Query(q) => check_query(q, rule, source, offsets, idx, diags),
70        SetExpr::SetOperation { left, right, .. } => {
71            check_set_expr(left, rule, source, offsets, idx, diags);
72            check_set_expr(right, rule, source, offsets, idx, diags);
73        }
74        _ => {}
75    }
76}
77
78fn check_select(
79    sel: &Select,
80    rule: &'static str,
81    source: &str,
82    offsets: &[usize],
83    idx: &mut usize,
84    diags: &mut Vec<Diagnostic>,
85) {
86    // Recurse into subqueries in the FROM clause.
87    for table in &sel.from {
88        recurse_table_factor(&table.relation, rule, source, offsets, idx, diags);
89        for join in &table.joins {
90            recurse_table_factor(&join.relation, rule, source, offsets, idx, diags);
91        }
92    }
93
94    // Check the WHERE clause.
95    if let Some(selection) = &sel.selection {
96        check_expr(selection, rule, source, offsets, idx, diags);
97    }
98
99    // Check expressions in the projection (HAVING-style subqueries, etc.).
100    for item in &sel.projection {
101        use sqlparser::ast::SelectItem;
102        match item {
103            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
104                check_expr(e, rule, source, offsets, idx, diags);
105            }
106            _ => {}
107        }
108    }
109
110    // Check HAVING clause.
111    if let Some(having) = &sel.having {
112        check_expr(having, rule, source, offsets, idx, diags);
113    }
114}
115
116fn recurse_table_factor(
117    tf: &TableFactor,
118    rule: &'static str,
119    source: &str,
120    offsets: &[usize],
121    idx: &mut usize,
122    diags: &mut Vec<Diagnostic>,
123) {
124    if let TableFactor::Derived { subquery, .. } = tf {
125        check_query(subquery, rule, source, offsets, idx, diags);
126    }
127}
128
129fn check_expr(
130    expr: &Expr,
131    rule: &'static str,
132    source: &str,
133    offsets: &[usize],
134    idx: &mut usize,
135    diags: &mut Vec<Diagnostic>,
136) {
137    match expr {
138        Expr::InList {
139            list,
140            negated,
141            expr: inner,
142        } => {
143            // Check the inner expression first (it may contain nested IN).
144            check_expr(inner, rule, source, offsets, idx, diags);
145
146            if !negated && list.len() == 1 {
147                // Consume the next IN offset as a violation.
148                let offset = offsets.get(*idx).copied().unwrap_or(0);
149                let (line, col) = line_col(source, offset);
150                diags.push(Diagnostic {
151                    rule,
152                    message: "IN list with a single value; use = instead".to_string(),
153                    line,
154                    col,
155                });
156                *idx += 1;
157            } else {
158                // Consume the IN offset without flagging.
159                if *idx < offsets.len() {
160                    *idx += 1;
161                }
162            }
163
164            // Recurse into the list elements.
165            for e in list {
166                check_expr(e, rule, source, offsets, idx, diags);
167            }
168        }
169
170        Expr::BinaryOp { left, right, .. } => {
171            check_expr(left, rule, source, offsets, idx, diags);
172            check_expr(right, rule, source, offsets, idx, diags);
173        }
174
175        Expr::UnaryOp { expr: inner, .. } => {
176            check_expr(inner, rule, source, offsets, idx, diags);
177        }
178
179        Expr::Nested(inner) => {
180            check_expr(inner, rule, source, offsets, idx, diags);
181        }
182
183        Expr::Subquery(q) => {
184            check_query(q, rule, source, offsets, idx, diags);
185        }
186
187        Expr::InSubquery {
188            expr: inner,
189            subquery,
190            ..
191        } => {
192            check_expr(inner, rule, source, offsets, idx, diags);
193            check_query(subquery, rule, source, offsets, idx, diags);
194        }
195
196        Expr::Exists { subquery, .. } => {
197            check_query(subquery, rule, source, offsets, idx, diags);
198        }
199
200        Expr::Case {
201            operand,
202            conditions,
203            results,
204            else_result,
205        } => {
206            if let Some(op) = operand {
207                check_expr(op, rule, source, offsets, idx, diags);
208            }
209            for cond in conditions {
210                check_expr(cond, rule, source, offsets, idx, diags);
211            }
212            for res in results {
213                check_expr(res, rule, source, offsets, idx, diags);
214            }
215            if let Some(else_r) = else_result {
216                check_expr(else_r, rule, source, offsets, idx, diags);
217            }
218        }
219
220        _ => {}
221    }
222}
223
224// ── helpers ───────────────────────────────────────────────────────────────────
225
226/// Collect byte offsets of every `IN` keyword (case-insensitive, word-boundary,
227/// outside strings/comments) in source order.
228///
229/// We must exclude `IN` that is part of `NOT IN` — we skip those because the
230/// AST's `negated` flag handles them, and we don't want to consume an offset
231/// slot that the AST will never fire on.  We therefore collect ALL `IN`
232/// occurrences (including those following `NOT`) and let the AST traversal
233/// consume or skip them in lock-step.
234fn collect_in_offsets(source: &str) -> Vec<usize> {
235    let bytes = source.as_bytes();
236    let len = bytes.len();
237    let skip_map = SkipMap::build(source);
238    let kw = b"IN";
239    let kw_len = kw.len();
240    let mut offsets = Vec::new();
241
242    let mut i = 0;
243    while i + kw_len <= len {
244        if !skip_map.is_code(i) {
245            i += 1;
246            continue;
247        }
248
249        // Word boundary before.
250        let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
251        if !before_ok {
252            i += 1;
253            continue;
254        }
255
256        // Case-insensitive match for "IN".
257        let matches = bytes[i] == b'I' || bytes[i] == b'i';
258        let matches = matches && (bytes[i + 1] == b'N' || bytes[i + 1] == b'n');
259
260        if matches {
261            // Word boundary after.
262            let after = i + kw_len;
263            let after_ok = after >= len || !is_word_char(bytes[after]);
264            let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
265
266            if after_ok && all_code {
267                offsets.push(i);
268                i += kw_len;
269                continue;
270            }
271        }
272
273        i += 1;
274    }
275
276    offsets
277}
278
/// Translates a byte `offset` into `source` to a 1-indexed `(line, col)` pair.
///
/// The column is the byte distance from the start of the line, plus one.
/// Slicing panics if `offset` is past the end of `source` or does not fall on
/// a character boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines = prefix.bytes().filter(|&b| b == b'\n').count();
    let col = match prefix.rfind('\n') {
        // Distance from the last newline to `offset` is already 1-based.
        Some(newline_at) => offset - newline_at,
        // Still on the first line.
        None => offset + 1,
    };
    (newlines + 1, col)
}