Skip to main content

sqrust_rules/ambiguous/
subquery_in_order_by.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, OrderBy, Query, SetExpr, Statement, TableFactor};
3
4pub struct SubqueryInOrderBy;
5
6impl Rule for SubqueryInOrderBy {
7    fn name(&self) -> &'static str {
8        "Ambiguous/SubqueryInOrderBy"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17        for stmt in &ctx.statements {
18            collect_from_statement(stmt, ctx, &mut diags);
19        }
20        diags
21    }
22}
23
24fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
25    if let Statement::Query(query) = stmt {
26        collect_from_query(query, ctx, diags);
27    }
28}
29
30fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
31    // Recurse into CTEs first.
32    if let Some(with) = &query.with {
33        for cte in &with.cte_tables {
34            collect_from_query(&cte.query, ctx, diags);
35        }
36    }
37
38    // Check ORDER BY on this query.
39    if let Some(order_by) = &query.order_by {
40        check_order_by(order_by, ctx, diags);
41    }
42
43    // Recurse into body (subqueries in FROM, UNION arms, etc.).
44    collect_from_set_expr(&query.body, ctx, diags);
45}
46
47fn check_order_by(order_by: &OrderBy, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
48    for order_expr in &order_by.exprs {
49        if contains_subquery(&order_expr.expr) {
50            let (line, col) = find_order_by_position(&ctx.source).unwrap_or((1, 1));
51            diags.push(Diagnostic {
52                rule: "Ambiguous/SubqueryInOrderBy",
53                message: "Subquery in ORDER BY is ambiguous and potentially expensive".to_string(),
54                line,
55                col,
56            });
57        }
58    }
59}
60
61fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
62    match expr {
63        SetExpr::Select(select) => {
64            // Recurse into derived tables in FROM.
65            for twj in &select.from {
66                collect_from_table_factor(&twj.relation, ctx, diags);
67                for join in &twj.joins {
68                    collect_from_table_factor(&join.relation, ctx, diags);
69                }
70            }
71        }
72        SetExpr::Query(inner) => {
73            collect_from_query(inner, ctx, diags);
74        }
75        SetExpr::SetOperation { left, right, .. } => {
76            collect_from_set_expr(left, ctx, diags);
77            collect_from_set_expr(right, ctx, diags);
78        }
79        _ => {}
80    }
81}
82
83fn collect_from_table_factor(
84    factor: &TableFactor,
85    ctx: &FileContext,
86    diags: &mut Vec<Diagnostic>,
87) {
88    if let TableFactor::Derived { subquery, .. } = factor {
89        collect_from_query(subquery, ctx, diags);
90    }
91}
92
93/// Returns `true` if `expr` is or contains a subquery (`Subquery`, `InSubquery`,
94/// or `Exists`). Recurses into `BinaryOp`, `UnaryOp`, and `Nested` to catch
95/// subqueries embedded inside larger expressions.
96fn contains_subquery(expr: &Expr) -> bool {
97    match expr {
98        Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => true,
99        Expr::BinaryOp { left, right, .. } => {
100            contains_subquery(left) || contains_subquery(right)
101        }
102        Expr::UnaryOp { expr: inner, .. } => contains_subquery(inner),
103        Expr::Nested(inner) => contains_subquery(inner),
104        Expr::Case {
105            operand,
106            conditions,
107            results,
108            else_result,
109        } => {
110            operand.as_deref().is_some_and(contains_subquery)
111                || conditions.iter().any(contains_subquery)
112                || results.iter().any(contains_subquery)
113                || else_result.as_deref().is_some_and(contains_subquery)
114        }
115        _ => false,
116    }
117}
118
/// Finds the first `ORDER BY` keyword pair in `source` and returns its
/// 1-indexed `(line, col)` position, or `None` when there is none.
///
/// Matching rules:
/// * case-insensitive (ASCII);
/// * `ORDER` and `BY` may be separated by any non-empty run of ASCII
///   whitespace (spaces, tabs, newlines), not just a single space;
/// * the pair must sit on word boundaries, so `REORDER BY` or `ORDER BYTES`
///   do not match;
/// * text inside single-quoted string literals (with `''` as the escaped
///   quote), `--` line comments, and `/* ... */` block comments is skipped.
///
/// Block comments are treated as non-nesting, and quoted identifiers are not
/// recognised as a special region.
fn find_order_by_position(source: &str) -> Option<(usize, usize)> {
    let bytes = source.as_bytes();
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'\'' => i = skip_string_literal(bytes, i),
            b'-' if bytes.get(i + 1) == Some(&b'-') => i = skip_line_comment(bytes, i),
            b'/' if bytes.get(i + 1) == Some(&b'*') => i = skip_block_comment(bytes, i),
            _ => {
                if is_order_by_at(bytes, i) {
                    // `i` points at an ASCII letter, so it is a valid offset.
                    return Some(offset_to_line_col(source, i));
                }
                i += 1;
            }
        }
    }

    None
}

/// Returns `true` when an `ORDER <whitespace> BY` keyword pair starts at
/// byte `start` and is delimited by word boundaries on both sides.
fn is_order_by_at(bytes: &[u8], start: usize) -> bool {
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    // Must not be the tail of a longer identifier (e.g. `reorder`).
    if start > 0 && is_ident(bytes[start - 1]) {
        return false;
    }

    let order_end = start + b"ORDER".len();
    let has_order = bytes
        .get(start..order_end)
        .is_some_and(|word| word.eq_ignore_ascii_case(b"ORDER"));
    if !has_order {
        return false;
    }

    // At least one whitespace byte is required between the two keywords.
    let gap = bytes[order_end..]
        .iter()
        .take_while(|b| b.is_ascii_whitespace())
        .count();
    if gap == 0 {
        return false;
    }

    let by_start = order_end + gap;
    let by_end = by_start + b"BY".len();
    bytes
        .get(by_start..by_end)
        .is_some_and(|word| word.eq_ignore_ascii_case(b"BY"))
        // Must not be the head of a longer identifier (e.g. `by_name`).
        && bytes.get(by_end).map_or(true, |&b| !is_ident(b))
}

/// Given the offset of an opening `'`, returns the offset just past the
/// matching closing quote (`''` counts as an escaped quote). An unterminated
/// literal extends to the end of the input.
fn skip_string_literal(bytes: &[u8], start: usize) -> usize {
    let mut i = start + 1;
    while i < bytes.len() {
        if bytes[i] == b'\'' {
            if bytes.get(i + 1) == Some(&b'\'') {
                i += 2;
                continue;
            }
            return i + 1;
        }
        i += 1;
    }
    bytes.len()
}

/// Given the offset of a `--`, returns the offset just past the end of that
/// line (or the end of the input when there is no trailing newline).
fn skip_line_comment(bytes: &[u8], start: usize) -> usize {
    bytes[start..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |p| start + p + 1)
}

/// Given the offset of a `/*`, returns the offset just past the closing `*/`
/// (or the end of the input when the comment is unterminated).
fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
    let body = start + 2;
    bytes[body..]
        .windows(2)
        .position(|w| w == b"*/")
        .map_or(bytes.len(), |p| body + p + 2)
}

/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
///
/// Lines are delimited by `\n`. The column is counted in bytes from the
/// start of the line, so a multi-byte character earlier on the line advances
/// the column by more than one.
///
/// # Panics
///
/// Panics if `offset` is greater than `source.len()`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source.as_bytes()[..offset];
    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    let col = match before.iter().rposition(|&b| b == b'\n') {
        // Distance from the newline; the byte right after it is column 1.
        Some(newline) => offset - newline,
        None => offset + 1,
    };
    (line, col)
}