Skip to main content

sqrust_rules/ambiguous/
chained_comparisons.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3    BinaryOperator, Expr, GroupByExpr, Query, Select, SelectItem, SetExpr, Statement, TableFactor,
4};
5
/// Lint rule `Ambiguous/ChainedComparisons`.
///
/// Flags comparisons whose left operand is itself a comparison, such as
/// `a < b < c`. The rule carries no state, so it is a unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct ChainedComparisons;
7
8impl Rule for ChainedComparisons {
9    fn name(&self) -> &'static str {
10        "Ambiguous/ChainedComparisons"
11    }
12
13    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
14        if !ctx.parse_errors.is_empty() {
15            return Vec::new();
16        }
17
18        let mut diags = Vec::new();
19        for stmt in &ctx.statements {
20            collect_from_statement(stmt, ctx, &mut diags);
21        }
22        diags
23    }
24}
25
26fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
27    if let Statement::Query(query) = stmt {
28        collect_from_query(query, ctx, diags);
29    }
30}
31
32fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
33    if let Some(with) = &query.with {
34        for cte in &with.cte_tables {
35            collect_from_query(&cte.query, ctx, diags);
36        }
37    }
38    // Check ORDER BY expressions (they live on Query, not Select).
39    if let Some(order_by) = &query.order_by {
40        for ob_expr in &order_by.exprs {
41            check_expr(&ob_expr.expr, ctx, diags);
42        }
43    }
44    collect_from_set_expr(&query.body, ctx, diags);
45}
46
47fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
48    match expr {
49        SetExpr::Select(select) => {
50            collect_from_select(select, ctx, diags);
51        }
52        SetExpr::Query(inner) => {
53            collect_from_query(inner, ctx, diags);
54        }
55        SetExpr::SetOperation { left, right, .. } => {
56            collect_from_set_expr(left, ctx, diags);
57            collect_from_set_expr(right, ctx, diags);
58        }
59        _ => {}
60    }
61}
62
63fn collect_from_select(select: &Select, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
64    // Check SELECT projection.
65    for item in &select.projection {
66        if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
67            check_expr(e, ctx, diags);
68        }
69    }
70
71    // Check FROM subqueries.
72    for twj in &select.from {
73        collect_from_table_factor(&twj.relation, ctx, diags);
74        for join in &twj.joins {
75            collect_from_table_factor(&join.relation, ctx, diags);
76            // Check JOIN ON conditions.
77            use sqlparser::ast::{JoinConstraint, JoinOperator};
78            let on_expr = match &join.join_operator {
79                JoinOperator::Inner(JoinConstraint::On(e))
80                | JoinOperator::LeftOuter(JoinConstraint::On(e))
81                | JoinOperator::RightOuter(JoinConstraint::On(e))
82                | JoinOperator::FullOuter(JoinConstraint::On(e)) => Some(e),
83                _ => None,
84            };
85            if let Some(e) = on_expr {
86                check_expr(e, ctx, diags);
87            }
88        }
89    }
90
91    // Check WHERE.
92    if let Some(selection) = &select.selection {
93        check_expr(selection, ctx, diags);
94    }
95
96    // Check HAVING.
97    if let Some(having) = &select.having {
98        check_expr(having, ctx, diags);
99    }
100
101    // Check GROUP BY expressions.
102    if let GroupByExpr::Expressions(exprs, _) = &select.group_by {
103        for e in exprs {
104            check_expr(e, ctx, diags);
105        }
106    }
107
108}
109
110fn collect_from_table_factor(
111    factor: &TableFactor,
112    ctx: &FileContext,
113    diags: &mut Vec<Diagnostic>,
114) {
115    if let TableFactor::Derived { subquery, .. } = factor {
116        collect_from_query(subquery, ctx, diags);
117    }
118}
119
120/// Returns true if `op` is one of the six SQL comparison operators.
121fn is_comparison_op(op: &BinaryOperator) -> bool {
122    matches!(
123        op,
124        BinaryOperator::Lt
125            | BinaryOperator::Gt
126            | BinaryOperator::LtEq
127            | BinaryOperator::GtEq
128            | BinaryOperator::Eq
129            | BinaryOperator::NotEq
130    )
131}
132
133/// Walks `expr`, flagging any `BinaryOp` node that is a comparison whose LEFT
134/// child is also a comparison — the classic "chained comparison" pattern.
135///
136/// Recursion visits both children of every `BinaryOp` so that nested chains
137/// (e.g. `a < b < c < d`) are caught at every level.
138fn check_expr(expr: &Expr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
139    match expr {
140        Expr::BinaryOp { left, op, right } => {
141            // Check this node for the chained-comparison pattern BEFORE recursing,
142            // so the outermost chain is reported first.
143            if is_comparison_op(op) {
144                if let Expr::BinaryOp { op: inner_op, .. } = left.as_ref() {
145                    if is_comparison_op(inner_op) {
146                        let (line, col) = find_keyword_position(&ctx.source, "where")
147                            .or_else(|| find_keyword_position(&ctx.source, "select"))
148                            .unwrap_or((1, 1));
149                        diags.push(Diagnostic {
150                            rule: "Ambiguous/ChainedComparisons",
151                            message:
152                                "Chained comparison 'a < b < c' is ambiguous; use 'a < b AND b < c' instead"
153                                    .to_string(),
154                            line,
155                            col,
156                        });
157                    }
158                }
159            }
160            // Recurse into both children to catch nested chains.
161            check_expr(left, ctx, diags);
162            check_expr(right, ctx, diags);
163        }
164        Expr::UnaryOp { expr: inner, .. } => {
165            check_expr(inner, ctx, diags);
166        }
167        Expr::Nested(inner) => {
168            check_expr(inner, ctx, diags);
169        }
170        Expr::Case {
171            operand,
172            conditions,
173            results,
174            else_result,
175        } => {
176            if let Some(op) = operand {
177                check_expr(op, ctx, diags);
178            }
179            for cond in conditions {
180                check_expr(cond, ctx, diags);
181            }
182            for result in results {
183                check_expr(result, ctx, diags);
184            }
185            if let Some(else_e) = else_result {
186                check_expr(else_e, ctx, diags);
187            }
188        }
189        Expr::InList { expr: inner, list, .. } => {
190            check_expr(inner, ctx, diags);
191            for e in list {
192                check_expr(e, ctx, diags);
193            }
194        }
195        Expr::Between {
196            expr: inner,
197            low,
198            high,
199            ..
200        } => {
201            check_expr(inner, ctx, diags);
202            check_expr(low, ctx, diags);
203            check_expr(high, ctx, diags);
204        }
205        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
206            check_expr(inner, ctx, diags);
207        }
208        Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
209            collect_from_query(q, ctx, diags);
210        }
211        _ => {}
212    }
213}
214
215/// Finds the first word-boundary occurrence of `keyword` (case-insensitive) in
216/// `source` and returns `Some((line, col))`, or `None` if not found.
217fn find_keyword_position(source: &str, keyword: &str) -> Option<(usize, usize)> {
218    let upper = source.to_uppercase();
219    let kw_upper = keyword.to_uppercase();
220    let bytes = upper.as_bytes();
221    let kw_bytes = kw_upper.as_bytes();
222    let kw_len = kw_bytes.len();
223
224    let mut i = 0;
225    while i + kw_len <= bytes.len() {
226        if bytes[i..i + kw_len] == *kw_bytes {
227            let before_ok =
228                i == 0 || (!bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_');
229            let after = i + kw_len;
230            let after_ok = after >= bytes.len()
231                || (!bytes[after].is_ascii_alphanumeric() && bytes[after] != b'_');
232            if before_ok && after_ok {
233                return Some(offset_to_line_col(source, i));
234            }
235        }
236        i += 1;
237    }
238    None
239}
240
/// Translates a byte `offset` into `source` to a 1-indexed `(line, col)` pair.
///
/// The line is one more than the number of `\n` bytes before `offset`; the
/// column is the byte distance from the start of that line, plus one.
/// `offset` is used to slice `source`, so it must lie on a char boundary and
/// within bounds.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines_before = prefix.bytes().filter(|&byte| byte == b'\n').count();
    // The current line starts just past the last newline, or at 0 if none.
    let line_start = match prefix.rfind('\n') {
        Some(newline_at) => newline_at + 1,
        None => 0,
    };
    (newlines_before + 1, offset - line_start + 1)
}