1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, OrderByExpr, Query, Select,
4 SelectItem, SetExpr, Statement, TableFactor, WindowFrameBound, WindowFrameUnits, WindowType,
5};
6
/// Lint rule that flags window functions whose frame is
/// `ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` and whose
/// `OVER` clause has no `PARTITION BY`.
pub struct WindowFrameAllRows;
8
9impl Rule for WindowFrameAllRows {
10 fn name(&self) -> &'static str {
11 "Structure/WindowFrameAllRows"
12 }
13
14 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
15 if !ctx.parse_errors.is_empty() {
16 return Vec::new();
17 }
18
19 let mut diags = Vec::new();
20
21 for stmt in &ctx.statements {
22 if let Statement::Query(query) = stmt {
23 check_query(query, self.name(), ctx, &mut diags);
24 }
25 }
26
27 diags
28 }
29}
30
31fn check_query(
34 query: &Query,
35 rule: &'static str,
36 ctx: &FileContext,
37 diags: &mut Vec<Diagnostic>,
38) {
39 if let Some(with) = &query.with {
41 for cte in &with.cte_tables {
42 check_query(&cte.query, rule, ctx, diags);
43 }
44 }
45
46 check_set_expr(&query.body, rule, ctx, diags);
47
48 if let Some(order_by) = &query.order_by {
50 for ob in &order_by.exprs {
51 check_order_by_expr(ob, rule, ctx, diags);
52 }
53 }
54}
55
56fn check_set_expr(
57 expr: &SetExpr,
58 rule: &'static str,
59 ctx: &FileContext,
60 diags: &mut Vec<Diagnostic>,
61) {
62 match expr {
63 SetExpr::Select(sel) => {
64 check_select(sel, rule, ctx, diags);
65 }
66 SetExpr::Query(inner) => {
67 check_query(inner, rule, ctx, diags);
68 }
69 SetExpr::SetOperation { left, right, .. } => {
70 check_set_expr(left, rule, ctx, diags);
71 check_set_expr(right, rule, ctx, diags);
72 }
73 _ => {}
74 }
75}
76
77fn check_select(
78 sel: &Select,
79 rule: &'static str,
80 ctx: &FileContext,
81 diags: &mut Vec<Diagnostic>,
82) {
83 for item in &sel.projection {
84 if let SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } = item {
85 check_expr(e, rule, ctx, diags);
86 }
87 }
88
89 if let Some(selection) = &sel.selection {
90 check_expr(selection, rule, ctx, diags);
91 }
92
93 if let Some(having) = &sel.having {
94 check_expr(having, rule, ctx, diags);
95 }
96
97 for twj in &sel.from {
98 check_table_factor(&twj.relation, rule, ctx, diags);
99 for join in &twj.joins {
100 check_table_factor(&join.relation, rule, ctx, diags);
101 }
102 }
103}
104
105fn check_table_factor(
106 tf: &TableFactor,
107 rule: &'static str,
108 ctx: &FileContext,
109 diags: &mut Vec<Diagnostic>,
110) {
111 if let TableFactor::Derived { subquery, .. } = tf {
112 check_query(subquery, rule, ctx, diags);
113 }
114}
115
116fn check_order_by_expr(
117 ob: &OrderByExpr,
118 rule: &'static str,
119 ctx: &FileContext,
120 diags: &mut Vec<Diagnostic>,
121) {
122 check_expr(&ob.expr, rule, ctx, diags);
123}
124
125fn check_expr(
126 expr: &Expr,
127 rule: &'static str,
128 ctx: &FileContext,
129 diags: &mut Vec<Diagnostic>,
130) {
131 match expr {
132 Expr::Function(func) => {
133 check_function(func, rule, ctx, diags);
134 }
135 Expr::BinaryOp { left, right, .. } => {
136 check_expr(left, rule, ctx, diags);
137 check_expr(right, rule, ctx, diags);
138 }
139 Expr::UnaryOp { expr, .. } => {
140 check_expr(expr, rule, ctx, diags);
141 }
142 Expr::Nested(e) => {
143 check_expr(e, rule, ctx, diags);
144 }
145 Expr::IsNull(e) => {
146 check_expr(e, rule, ctx, diags);
147 }
148 Expr::IsNotNull(e) => {
149 check_expr(e, rule, ctx, diags);
150 }
151 Expr::Case {
152 operand,
153 conditions,
154 results,
155 else_result,
156 } => {
157 if let Some(op) = operand {
158 check_expr(op, rule, ctx, diags);
159 }
160 for c in conditions {
161 check_expr(c, rule, ctx, diags);
162 }
163 for r in results {
164 check_expr(r, rule, ctx, diags);
165 }
166 if let Some(el) = else_result {
167 check_expr(el, rule, ctx, diags);
168 }
169 }
170 Expr::Subquery(q) => {
171 check_query(q, rule, ctx, diags);
172 }
173 Expr::InSubquery { subquery, .. } => {
174 check_query(subquery, rule, ctx, diags);
175 }
176 Expr::Exists { subquery, .. } => {
177 check_query(subquery, rule, ctx, diags);
178 }
179 _ => {}
180 }
181}
182
183fn check_function(
186 func: &Function,
187 rule: &'static str,
188 ctx: &FileContext,
189 diags: &mut Vec<Diagnostic>,
190) {
191 if let Some(WindowType::WindowSpec(spec)) = &func.over {
192 if spec.partition_by.is_empty() {
193 if let Some(frame) = &spec.window_frame {
194 if is_rows_unbounded_all(frame) {
195 let (line, col) = find_over_pos(&ctx.source);
196 diags.push(Diagnostic {
197 rule,
198 message: "Window function with ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING and no PARTITION BY processes the entire table — verify this is intentional".to_string(),
199 line,
200 col,
201 });
202 }
203 }
204 }
205 }
206
207 if let FunctionArguments::List(list) = &func.args {
209 for arg in &list.args {
210 let fae = match arg {
211 FunctionArg::Named { arg, .. }
212 | FunctionArg::ExprNamed { arg, .. }
213 | FunctionArg::Unnamed(arg) => arg,
214 };
215 if let FunctionArgExpr::Expr(e) = fae {
216 check_expr(e, rule, ctx, diags);
217 }
218 }
219 }
220}
221
222fn is_rows_unbounded_all(frame: &sqlparser::ast::WindowFrame) -> bool {
225 if frame.units != WindowFrameUnits::Rows {
226 return false;
227 }
228 let start_ok = matches!(frame.start_bound, WindowFrameBound::Preceding(None));
229 let end_ok = matches!(
230 frame.end_bound,
231 Some(WindowFrameBound::Following(None))
232 );
233 start_ok && end_ok
234}
235
/// Locates the first stand-alone `OVER` keyword (case-insensitive) in `source`
/// and returns its 1-based `(line, column)`; returns `(1, 1)` when there is
/// none.
///
/// "Stand-alone" means the bytes immediately before and after the match are
/// neither ASCII alphanumerics nor `_`, so identifiers such as `overlap` or
/// `x_over` are skipped. The scan works on the raw text, so an `OVER` inside a
/// comment or string literal is matched as well.
fn find_over_pos(source: &str) -> (usize, usize) {
    const KEYWORD: &[u8] = b"OVER";

    /// Bytes that may form part of an identifier.
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    // Search the original bytes rather than an upper-cased copy: Unicode
    // upper-casing can change a character's byte length, which would make the
    // offsets found in the copy invalid for `source`. An ASCII match can never
    // start inside a multi-byte UTF-8 sequence, so `start` is always a char
    // boundary.
    let bytes = source.as_bytes();
    bytes
        .windows(KEYWORD.len())
        .enumerate()
        .find(|&(start, window)| {
            window.eq_ignore_ascii_case(KEYWORD)
                && (start == 0 || !is_word_byte(bytes[start - 1]))
                && bytes
                    .get(start + KEYWORD.len())
                    .map_or(true, |&next| !is_word_byte(next))
        })
        .map_or((1, 1), |(start, _)| line_col(source, start))
}

/// Converts a byte `offset` into `source` to a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. `offset` must
/// lie on a character boundary of `source`, otherwise the slice panics.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.matches('\n').count() + 1;
    // Distance from the last newline, or from the start of the text when the
    // offset is on the first line.
    let col = before.rfind('\n').map_or(offset, |nl| offset - nl - 1) + 1;
    (line, col)
}