1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
6pub struct InSingleValue;
7
8impl Rule for InSingleValue {
9 fn name(&self) -> &'static str {
10 "InSingleValue"
11 }
12
13 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14 if !ctx.parse_errors.is_empty() {
15 return Vec::new();
16 }
17
18 let in_offsets = collect_in_offsets(&ctx.source);
21 let mut in_index: usize = 0;
22 let mut diags = Vec::new();
23
24 for stmt in &ctx.statements {
25 if let Statement::Query(query) = stmt {
26 check_query(
27 query,
28 self.name(),
29 &ctx.source,
30 &in_offsets,
31 &mut in_index,
32 &mut diags,
33 );
34 }
35 }
36
37 diags
38 }
39}
40
41fn check_query(
44 query: &Query,
45 rule: &'static str,
46 source: &str,
47 offsets: &[usize],
48 idx: &mut usize,
49 diags: &mut Vec<Diagnostic>,
50) {
51 if let Some(with) = &query.with {
52 for cte in &with.cte_tables {
53 check_query(&cte.query, rule, source, offsets, idx, diags);
54 }
55 }
56 check_set_expr(&query.body, rule, source, offsets, idx, diags);
57}
58
59fn check_set_expr(
60 body: &SetExpr,
61 rule: &'static str,
62 source: &str,
63 offsets: &[usize],
64 idx: &mut usize,
65 diags: &mut Vec<Diagnostic>,
66) {
67 match body {
68 SetExpr::Select(sel) => check_select(sel, rule, source, offsets, idx, diags),
69 SetExpr::Query(q) => check_query(q, rule, source, offsets, idx, diags),
70 SetExpr::SetOperation { left, right, .. } => {
71 check_set_expr(left, rule, source, offsets, idx, diags);
72 check_set_expr(right, rule, source, offsets, idx, diags);
73 }
74 _ => {}
75 }
76}
77
78fn check_select(
79 sel: &Select,
80 rule: &'static str,
81 source: &str,
82 offsets: &[usize],
83 idx: &mut usize,
84 diags: &mut Vec<Diagnostic>,
85) {
86 for table in &sel.from {
88 recurse_table_factor(&table.relation, rule, source, offsets, idx, diags);
89 for join in &table.joins {
90 recurse_table_factor(&join.relation, rule, source, offsets, idx, diags);
91 }
92 }
93
94 if let Some(selection) = &sel.selection {
96 check_expr(selection, rule, source, offsets, idx, diags);
97 }
98
99 for item in &sel.projection {
101 use sqlparser::ast::SelectItem;
102 match item {
103 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
104 check_expr(e, rule, source, offsets, idx, diags);
105 }
106 _ => {}
107 }
108 }
109
110 if let Some(having) = &sel.having {
112 check_expr(having, rule, source, offsets, idx, diags);
113 }
114}
115
116fn recurse_table_factor(
117 tf: &TableFactor,
118 rule: &'static str,
119 source: &str,
120 offsets: &[usize],
121 idx: &mut usize,
122 diags: &mut Vec<Diagnostic>,
123) {
124 if let TableFactor::Derived { subquery, .. } = tf {
125 check_query(subquery, rule, source, offsets, idx, diags);
126 }
127}
128
129fn check_expr(
130 expr: &Expr,
131 rule: &'static str,
132 source: &str,
133 offsets: &[usize],
134 idx: &mut usize,
135 diags: &mut Vec<Diagnostic>,
136) {
137 match expr {
138 Expr::InList {
139 list,
140 negated,
141 expr: inner,
142 } => {
143 check_expr(inner, rule, source, offsets, idx, diags);
145
146 if !negated && list.len() == 1 {
147 let offset = offsets.get(*idx).copied().unwrap_or(0);
149 let (line, col) = line_col(source, offset);
150 diags.push(Diagnostic {
151 rule,
152 message: "IN list with a single value; use = instead".to_string(),
153 line,
154 col,
155 });
156 *idx += 1;
157 } else {
158 if *idx < offsets.len() {
160 *idx += 1;
161 }
162 }
163
164 for e in list {
166 check_expr(e, rule, source, offsets, idx, diags);
167 }
168 }
169
170 Expr::BinaryOp { left, right, .. } => {
171 check_expr(left, rule, source, offsets, idx, diags);
172 check_expr(right, rule, source, offsets, idx, diags);
173 }
174
175 Expr::UnaryOp { expr: inner, .. } => {
176 check_expr(inner, rule, source, offsets, idx, diags);
177 }
178
179 Expr::Nested(inner) => {
180 check_expr(inner, rule, source, offsets, idx, diags);
181 }
182
183 Expr::Subquery(q) => {
184 check_query(q, rule, source, offsets, idx, diags);
185 }
186
187 Expr::InSubquery {
188 expr: inner,
189 subquery,
190 ..
191 } => {
192 check_expr(inner, rule, source, offsets, idx, diags);
193 check_query(subquery, rule, source, offsets, idx, diags);
194 }
195
196 Expr::Exists { subquery, .. } => {
197 check_query(subquery, rule, source, offsets, idx, diags);
198 }
199
200 Expr::Case {
201 operand,
202 conditions,
203 results,
204 else_result,
205 } => {
206 if let Some(op) = operand {
207 check_expr(op, rule, source, offsets, idx, diags);
208 }
209 for cond in conditions {
210 check_expr(cond, rule, source, offsets, idx, diags);
211 }
212 for res in results {
213 check_expr(res, rule, source, offsets, idx, diags);
214 }
215 if let Some(else_r) = else_result {
216 check_expr(else_r, rule, source, offsets, idx, diags);
217 }
218 }
219
220 _ => {}
221 }
222}
223
224fn collect_in_offsets(source: &str) -> Vec<usize> {
235 let bytes = source.as_bytes();
236 let len = bytes.len();
237 let skip_map = SkipMap::build(source);
238 let kw = b"IN";
239 let kw_len = kw.len();
240 let mut offsets = Vec::new();
241
242 let mut i = 0;
243 while i + kw_len <= len {
244 if !skip_map.is_code(i) {
245 i += 1;
246 continue;
247 }
248
249 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
251 if !before_ok {
252 i += 1;
253 continue;
254 }
255
256 let matches = bytes[i] == b'I' || bytes[i] == b'i';
258 let matches = matches && (bytes[i + 1] == b'N' || bytes[i + 1] == b'n');
259
260 if matches {
261 let after = i + kw_len;
263 let after_ok = after >= len || !is_word_char(bytes[after]);
264 let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
265
266 if after_ok && all_code {
267 offsets.push(i);
268 i += kw_len;
269 continue;
270 }
271 }
272
273 i += 1;
274 }
275
276 offsets
277}
278
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. `offset` must
/// not exceed the length of `source` and must fall on a character boundary,
/// because the text before it is obtained by slicing.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines_before = prefix.chars().filter(|ch| *ch == '\n').count();

    let column = match prefix.rfind('\n') {
        // Distance from the last newline is already 1-based.
        Some(last_newline) => offset - last_newline,
        // Still on the first line: the offset itself is the 0-based column.
        None => offset + 1,
    };

    (newlines_before + 1, column)
}