sqrust_rules/ambiguous/
subquery_in_order_by.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, OrderBy, Query, SetExpr, Statement, TableFactor};
3
/// Lint rule `Ambiguous/SubqueryInOrderBy`: reports ORDER BY items whose
/// expression contains a subquery.
///
/// The rule carries no configuration, so it is a zero-sized unit value that is
/// cheap to copy and trivially default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct SubqueryInOrderBy;
5
6impl Rule for SubqueryInOrderBy {
7 fn name(&self) -> &'static str {
8 "Ambiguous/SubqueryInOrderBy"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
13 return Vec::new();
14 }
15
16 let mut diags = Vec::new();
17 for stmt in &ctx.statements {
18 collect_from_statement(stmt, ctx, &mut diags);
19 }
20 diags
21 }
22}
23
24fn collect_from_statement(stmt: &Statement, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
25 if let Statement::Query(query) = stmt {
26 collect_from_query(query, ctx, diags);
27 }
28}
29
30fn collect_from_query(query: &Query, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
31 if let Some(with) = &query.with {
33 for cte in &with.cte_tables {
34 collect_from_query(&cte.query, ctx, diags);
35 }
36 }
37
38 if let Some(order_by) = &query.order_by {
40 check_order_by(order_by, ctx, diags);
41 }
42
43 collect_from_set_expr(&query.body, ctx, diags);
45}
46
47fn check_order_by(order_by: &OrderBy, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
48 for order_expr in &order_by.exprs {
49 if contains_subquery(&order_expr.expr) {
50 let (line, col) = find_order_by_position(&ctx.source).unwrap_or((1, 1));
51 diags.push(Diagnostic {
52 rule: "Ambiguous/SubqueryInOrderBy",
53 message: "Subquery in ORDER BY is ambiguous and potentially expensive".to_string(),
54 line,
55 col,
56 });
57 }
58 }
59}
60
61fn collect_from_set_expr(expr: &SetExpr, ctx: &FileContext, diags: &mut Vec<Diagnostic>) {
62 match expr {
63 SetExpr::Select(select) => {
64 for twj in &select.from {
66 collect_from_table_factor(&twj.relation, ctx, diags);
67 for join in &twj.joins {
68 collect_from_table_factor(&join.relation, ctx, diags);
69 }
70 }
71 }
72 SetExpr::Query(inner) => {
73 collect_from_query(inner, ctx, diags);
74 }
75 SetExpr::SetOperation { left, right, .. } => {
76 collect_from_set_expr(left, ctx, diags);
77 collect_from_set_expr(right, ctx, diags);
78 }
79 _ => {}
80 }
81}
82
83fn collect_from_table_factor(
84 factor: &TableFactor,
85 ctx: &FileContext,
86 diags: &mut Vec<Diagnostic>,
87) {
88 if let TableFactor::Derived { subquery, .. } = factor {
89 collect_from_query(subquery, ctx, diags);
90 }
91}
92
93fn contains_subquery(expr: &Expr) -> bool {
97 match expr {
98 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => true,
99 Expr::BinaryOp { left, right, .. } => {
100 contains_subquery(left) || contains_subquery(right)
101 }
102 Expr::UnaryOp { expr: inner, .. } => contains_subquery(inner),
103 Expr::Nested(inner) => contains_subquery(inner),
104 Expr::Case {
105 operand,
106 conditions,
107 results,
108 else_result,
109 } => {
110 operand.as_deref().is_some_and(contains_subquery)
111 || conditions.iter().any(contains_subquery)
112 || results.iter().any(contains_subquery)
113 || else_result.as_deref().is_some_and(contains_subquery)
114 }
115 _ => false,
116 }
117}
118
/// Locates the first `ORDER BY` in `source` and returns its 1-based
/// `(line, column)`, or `None` if there is none.
///
/// The scan is purely lexical:
/// - Text inside single-quoted strings, double-quoted identifiers, `--` line
///   comments and `/* ... */` block comments is skipped.
/// - `ORDER` and `BY` may be separated by any run of ASCII whitespace,
///   including newlines.
/// - Both words must stand alone, not as part of a longer identifier such as
///   `reorder` or `order_by`.
/// - Matching is ASCII case-insensitive.
fn find_order_by_position(source: &str) -> Option<(usize, usize)> {
    find_order_by_offset(source).map(|offset| offset_to_line_col(source, offset))
}

/// Byte offset of the first standalone `ORDER <whitespace> BY` that lies
/// outside quoted text and comments.
fn find_order_by_offset(source: &str) -> Option<usize> {
    let bytes = source.as_bytes();
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'\'' | b'"' => i = skip_quoted(bytes, i),
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Line comment: resume at the newline, or stop at end of input.
                i = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(bytes.len(), |p| i + p);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Block comment: resume just after the closing `*/`. An
                // unterminated comment swallows the rest of the input.
                i = bytes[i + 2..]
                    .windows(2)
                    .position(|w| w == b"*/")
                    .map_or(bytes.len(), |p| i + 2 + p + 2);
            }
            _ => {
                if is_keyword_at(bytes, i, b"ORDER") {
                    let after_order = i + b"ORDER".len();
                    let gap = bytes[after_order..]
                        .iter()
                        .take_while(|b| b.is_ascii_whitespace())
                        .count();
                    if gap > 0 && is_keyword_at(bytes, after_order + gap, b"BY") {
                        return Some(i);
                    }
                }
                i += 1;
            }
        }
    }

    None
}

/// True when `keyword` (ASCII, compared case-insensitively) occurs at byte
/// `at` and is not glued to a neighbouring identifier character.
fn is_keyword_at(bytes: &[u8], at: usize, keyword: &[u8]) -> bool {
    let end = at + keyword.len();
    end <= bytes.len()
        && bytes[at..end].eq_ignore_ascii_case(keyword)
        && (at == 0 || !is_ident_byte(bytes[at - 1]))
        && !matches!(bytes.get(end), Some(&b) if is_ident_byte(b))
}

/// Bytes that can continue an unquoted identifier.
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_'
}

/// Returns the index just past the quoted run that opens at `start`.
///
/// A doubled quote character (`''` or `""`) is an escaped quote, not the
/// terminator. An unterminated run extends to the end of `bytes`.
fn skip_quoted(bytes: &[u8], start: usize) -> usize {
    let quote = bytes[start];
    let mut i = start + 1;
    while i < bytes.len() {
        if bytes[i] == quote {
            if bytes.get(i + 1) == Some(&quote) {
                i += 2;
                continue;
            }
            return i + 1;
        }
        i += 1;
    }
    bytes.len()
}

/// Converts a byte offset into `source` to a 1-based `(line, column)`.
///
/// The line is the number of `\n` characters before `offset`, plus one. The
/// column counts bytes from the start of that line, so multi-byte characters
/// earlier on the same line each advance it by more than one. Panics if
/// `offset` is past the end of `source` or not on a char boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line = prefix.matches('\n').count() + 1;
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}