1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 BinaryOperator, Expr, GroupByExpr, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
4};
5
6pub struct ChainedComparisons;
7
8impl Rule for ChainedComparisons {
9 fn name(&self) -> &'static str {
10 "Ambiguous/ChainedComparisons"
11 }
12
13 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14 if !ctx.parse_errors.is_empty() {
15 return Vec::new();
16 }
17
18 let mut diags = Vec::new();
19 for stmt in &ctx.statements {
20 collect_from_statement(stmt, ctx, &mut diags);
21 }
22 diags
23 }
24}
25
26fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
27 if let Statement::Query(query) = stmt {
28 collect_from_query(query, ctx, diags);
29 }
30}
31
32fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
33 if let Some(with) = &query.with {
34 for cte in &with.cte_tables {
35 collect_from_query(&cte.query, ctx, diags);
36 }
37 }
38 if let Some(order_by) = &query.order_by {
40 for ob_expr in &order_by.exprs {
41 check_expr(&ob_expr.expr, ctx, diags);
42 }
43 }
44 collect_from_set_expr(&query.body, ctx, diags);
45}
46
47fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
48 match expr {
49 SetExpr::Select(select) => {
50 collect_from_select(select, ctx, diags);
51 }
52 SetExpr::Query(inner) => {
53 collect_from_query(inner, ctx, diags);
54 }
55 SetExpr::SetOperation { left, right, .. } => {
56 collect_from_set_expr(left, ctx, diags);
57 collect_from_set_expr(right, ctx, diags);
58 }
59 _ => {}
60 }
61}
62
63fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
64 for item in &select.projection {
66 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
67 check_expr(e, ctx, diags);
68 }
69 }
70
71 for twj in &select.from {
73 collect_from_table_factor(&twj.relation, ctx, diags);
74 for join in &twj.joins {
75 collect_from_table_factor(&join.relation, ctx, diags);
76 use sqlparser::ast::{JoinConstraint, JoinOperator};
78 let on_expr = match &join.join_operator {
79 JoinOperator::Inner(JoinConstraint::On(e))
80 | JoinOperator::LeftOuter(JoinConstraint::On(e))
81 | JoinOperator::RightOuter(JoinConstraint::On(e))
82 | JoinOperator::FullOuter(JoinConstraint::On(e)) => Some(e),
83 _ => None,
84 };
85 if let Some(e) = on_expr {
86 check_expr(e, ctx, diags);
87 }
88 }
89 }
90
91 if let Some(selection) = &select.selection {
93 check_expr(selection, ctx, diags);
94 }
95
96 if let Some(having) = &select.having {
98 check_expr(having, ctx, diags);
99 }
100
101 if let GroupByExpr::Expressions(exprs, _) = &select.group_by {
103 for e in exprs {
104 check_expr(e, ctx, diags);
105 }
106 }
107
108}
109
110fn collect_from_table_factor(
111 factor: &TableFactor,
112 ctx: &FileContext,
113 diags: &mut Vec<Diagnostic>,
114) {
115 if let TableFactor::Derived { subquery, .. } = factor {
116 collect_from_query(subquery, ctx, diags);
117 }
118}
119
120fn is_comparison_op(op: &BinaryOperator) -> bool {
122 matches!(
123 op,
124 BinaryOperator::Lt
125 | BinaryOperator::Gt
126 | BinaryOperator::LtEq
127 | BinaryOperator::GtEq
128 | BinaryOperator::Eq
129 | BinaryOperator::NotEq
130 )
131}
132
133fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
139 match expr {
140 Expr::BinaryOp { left, op, right } => {
141 if is_comparison_op(op) {
144 if let Expr::BinaryOp { op: inner_op, .. } = left.as_ref() {
145 if is_comparison_op(inner_op) {
146 let (line, col) = find_keyword_position(&ctx.source, "where")
147 .or_else(|| find_keyword_position(&ctx.source, "select"))
148 .unwrap_or((1, 1));
149 diags.push(Diagnostic {
150 rule: "Ambiguous/ChainedComparisons",
151 message:
152 "Chained comparison 'a < b < c' is ambiguous; use 'a < b AND b < c' instead"
153 .to_string(),
154 line,
155 col,
156 });
157 }
158 }
159 }
160 check_expr(left, ctx, diags);
162 check_expr(right, ctx, diags);
163 }
164 Expr::UnaryOp { expr: inner, .. } => {
165 check_expr(inner, ctx, diags);
166 }
167 Expr::Nested(inner) => {
168 check_expr(inner, ctx, diags);
169 }
170 Expr::Case {
171 operand,
172 conditions,
173 results,
174 else_result,
175 } => {
176 if let Some(op) = operand {
177 check_expr(op, ctx, diags);
178 }
179 for cond in conditions {
180 check_expr(cond, ctx, diags);
181 }
182 for result in results {
183 check_expr(result, ctx, diags);
184 }
185 if let Some(else_e) = else_result {
186 check_expr(else_e, ctx, diags);
187 }
188 }
189 Expr::InList { expr: inner, list, .. } => {
190 check_expr(inner, ctx, diags);
191 for e in list {
192 check_expr(e, ctx, diags);
193 }
194 }
195 Expr::Between {
196 expr: inner,
197 low,
198 high,
199 ..
200 } => {
201 check_expr(inner, ctx, diags);
202 check_expr(low, ctx, diags);
203 check_expr(high, ctx, diags);
204 }
205 Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
206 check_expr(inner, ctx, diags);
207 }
208 Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
209 collect_from_query(q, ctx, diags);
210 }
211 _ => {}
212 }
213}
214
/// Finds the first whole-word, ASCII case-insensitive occurrence of `keyword`
/// in `source` and returns its 1-based `(line, column)`.
///
/// A "word" boundary means the neighbouring bytes are not ASCII alphanumerics
/// or `_`. The scan is purely textual: matches inside string literals or
/// comments are not skipped. Returns `None` when there is no match or when
/// `keyword` is empty.
///
/// Matching is done on the original bytes rather than on an uppercased copy.
/// `str::to_uppercase` can change the byte length of non-ASCII characters
/// (e.g. `ı` -> `I`), so offsets taken from an uppercased copy do not line up
/// with `source` and can even fall inside a multi-byte character.
fn find_keyword_position(source: &str, keyword: &str) -> Option<(usize, usize)> {
    let haystack = source.as_bytes();
    let needle = keyword.as_bytes();
    if needle.is_empty() || needle.len() > haystack.len() {
        return None;
    }

    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    (0..=haystack.len() - needle.len())
        .find(|&start| {
            let end = start + needle.len();
            haystack[start..end].eq_ignore_ascii_case(needle)
                && (start == 0 || !is_word_byte(haystack[start - 1]))
                && (end == haystack.len() || !is_word_byte(haystack[end]))
        })
        // `start` is always a char boundary. A match begins with the first
        // byte of `keyword`, which is either ASCII or a UTF-8 lead byte
        // (non-ASCII bytes are compared exactly), so never a continuation byte.
        .map(|start| offset_to_line_col(source, start))
}

/// Converts a byte offset into `source` to a 1-based `(line, column)` pair.
///
/// The column counts bytes, not characters, from the start of the line.
///
/// # Panics
///
/// Panics if `offset` is past the end of `source` or not on a char boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |nl| nl + 1);
    (line, offset - line_start + 1)
}
247}