1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
4use crate::capitalisation::{is_word_char, SkipMap};
5
/// Lint rule that flags `IN (...)` value lists with more than `max_values` entries.
#[derive(Debug, Clone)]
pub struct LargeInList {
    /// Largest number of values an `IN` list may contain; a list with more
    /// entries than this produces a diagnostic.
    pub max_values: usize,
}
11
12impl Default for LargeInList {
13 fn default() -> Self {
14 LargeInList { max_values: 10 }
15 }
16}
17
18impl Rule for LargeInList {
19 fn name(&self) -> &'static str {
20 "LargeInList"
21 }
22
23 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
24 if !ctx.parse_errors.is_empty() {
25 return Vec::new();
26 }
27
28 let mut diags = Vec::new();
29 let mut in_occurrence: usize = 0;
32
33 for stmt in &ctx.statements {
34 if let Statement::Query(query) = stmt {
35 check_query(query, self.max_values, ctx, &mut in_occurrence, &mut diags);
36 }
37 }
38
39 diags
40 }
41}
42
43fn check_query(
46 query: &Query,
47 max: usize,
48 ctx: &FileContext,
49 occurrence: &mut usize,
50 diags: &mut Vec<Diagnostic>,
51) {
52 if let Some(with) = &query.with {
53 for cte in &with.cte_tables {
54 check_query(&cte.query, max, ctx, occurrence, diags);
55 }
56 }
57
58 check_set_expr(&query.body, max, ctx, occurrence, diags);
59}
60
61fn check_set_expr(
62 expr: &SetExpr,
63 max: usize,
64 ctx: &FileContext,
65 occurrence: &mut usize,
66 diags: &mut Vec<Diagnostic>,
67) {
68 match expr {
69 SetExpr::Select(sel) => check_select(sel, max, ctx, occurrence, diags),
70 SetExpr::Query(inner) => check_query(inner, max, ctx, occurrence, diags),
71 SetExpr::SetOperation { left, right, .. } => {
72 check_set_expr(left, max, ctx, occurrence, diags);
73 check_set_expr(right, max, ctx, occurrence, diags);
74 }
75 _ => {}
76 }
77}
78
79fn check_select(
80 sel: &Select,
81 max: usize,
82 ctx: &FileContext,
83 occurrence: &mut usize,
84 diags: &mut Vec<Diagnostic>,
85) {
86 for item in &sel.projection {
88 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
89 check_expr(e, max, ctx, occurrence, diags);
90 }
91 }
92
93 if let Some(selection) = &sel.selection {
95 check_expr(selection, max, ctx, occurrence, diags);
96 }
97
98 for twj in &sel.from {
100 check_table_factor(&twj.relation, max, ctx, occurrence, diags);
101 for join in &twj.joins {
102 check_table_factor(&join.relation, max, ctx, occurrence, diags);
103 }
104 }
105}
106
107fn check_table_factor(
108 tf: &TableFactor,
109 max: usize,
110 ctx: &FileContext,
111 occurrence: &mut usize,
112 diags: &mut Vec<Diagnostic>,
113) {
114 if let TableFactor::Derived { subquery, .. } = tf {
115 check_query(subquery, max, ctx, occurrence, diags);
116 }
117}
118
119fn check_expr(
120 expr: &Expr,
121 max: usize,
122 ctx: &FileContext,
123 occurrence: &mut usize,
124 diags: &mut Vec<Diagnostic>,
125) {
126 match expr {
127 Expr::InList { expr: inner, list, .. } => {
128 let occ = *occurrence;
130 *occurrence += 1;
131
132 let n = list.len();
133 if n > max {
134 let (line, col) = find_nth_keyword_pos(&ctx.source, "IN", occ);
135 diags.push(Diagnostic {
136 rule: "LargeInList",
137 message: format!(
138 "IN list has {n} values, exceeding the maximum of {max}"
139 ),
140 line,
141 col,
142 });
143 }
144
145 check_expr(inner, max, ctx, occurrence, diags);
147 for val in list {
148 check_expr(val, max, ctx, occurrence, diags);
149 }
150 }
151
152 Expr::BinaryOp { left, right, .. } => {
153 check_expr(left, max, ctx, occurrence, diags);
154 check_expr(right, max, ctx, occurrence, diags);
155 }
156 Expr::UnaryOp { expr: inner, .. } => {
157 check_expr(inner, max, ctx, occurrence, diags);
158 }
159 Expr::Subquery(q) => check_query(q, max, ctx, occurrence, diags),
160 Expr::InSubquery { subquery, expr: e, .. } => {
161 check_expr(e, max, ctx, occurrence, diags);
162 check_query(subquery, max, ctx, occurrence, diags);
163 }
164 Expr::Exists { subquery, .. } => check_query(subquery, max, ctx, occurrence, diags),
165 Expr::Nested(inner) => check_expr(inner, max, ctx, occurrence, diags),
166 Expr::Case {
167 operand,
168 conditions,
169 results,
170 else_result,
171 } => {
172 if let Some(op) = operand {
173 check_expr(op, max, ctx, occurrence, diags);
174 }
175 for cond in conditions {
176 check_expr(cond, max, ctx, occurrence, diags);
177 }
178 for res in results {
179 check_expr(res, max, ctx, occurrence, diags);
180 }
181 if let Some(els) = else_result {
182 check_expr(els, max, ctx, occurrence, diags);
183 }
184 }
185 Expr::Function(f) => {
186 use sqlparser::ast::{FunctionArg, FunctionArgExpr, FunctionArguments};
187 if let FunctionArguments::List(list) = &f.args {
188 for arg in &list.args {
189 let arg_expr = match arg {
190 FunctionArg::Unnamed(e) => Some(e),
191 FunctionArg::Named { arg: e, .. } => Some(e),
192 FunctionArg::ExprNamed { arg: e, .. } => Some(e),
193 };
194 if let Some(FunctionArgExpr::Expr(inner)) = arg_expr {
195 check_expr(inner, max, ctx, occurrence, diags);
196 }
197 }
198 }
199 }
200 _ => {}
201 }
202}
203
204fn find_nth_keyword_pos(source: &str, keyword: &str, nth: usize) -> (usize, usize) {
210 let bytes = source.as_bytes();
211 let len = bytes.len();
212 let skip_map = SkipMap::build(source);
213 let kw_upper: Vec<u8> = keyword.bytes().map(|b| b.to_ascii_uppercase()).collect();
214 let kw_len = kw_upper.len();
215
216 let mut count = 0usize;
217 let mut i = 0;
218 while i + kw_len <= len {
219 if !skip_map.is_code(i) {
220 i += 1;
221 continue;
222 }
223
224 let before_ok = i == 0 || !is_word_char(bytes[i - 1]);
225 if !before_ok {
226 i += 1;
227 continue;
228 }
229
230 let matches = bytes[i..i + kw_len]
231 .iter()
232 .zip(kw_upper.iter())
233 .all(|(a, b)| a.eq_ignore_ascii_case(b));
234
235 if matches {
236 let after = i + kw_len;
237 let after_ok = after >= len || !is_word_char(bytes[after]);
238 let all_code = (i..i + kw_len).all(|k| skip_map.is_code(k));
239
240 if after_ok && all_code {
241 if count == nth {
242 return line_col(source, i);
243 }
244 count += 1;
245 }
246 }
247
248 i += 1;
249 }
250
251 (1, 1)
252}
253
/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// Lines are split on `'\n'`. The column is the number of bytes since the
/// start of the line, plus one. `offset` must lie on a char boundary.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let (line, line_start) = source[..offset]
        .match_indices('\n')
        .fold((1, 0), |(line_no, _), (newline_at, _)| (line_no + 1, newline_at + 1));
    (line, offset - line_start + 1)
}