1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor};
3
/// Lint rule that reports every `WITH RECURSIVE` clause found in a query.
///
/// The type is a stateless marker; all behavior lives in its `Rule` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct RecursiveCte;
5
6impl Rule for RecursiveCte {
7 fn name(&self) -> &'static str {
8 "Lint/RecursiveCte"
9 }
10
11 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16
17 let mut diags = Vec::new();
18 let source = &ctx.source;
19 let source_upper = source.to_uppercase();
20 let mut search_from = 0usize;
21
22 for stmt in &ctx.statements {
23 if let Statement::Query(q) = stmt {
24 collect_from_query(q, source, &source_upper, &mut search_from, self.name(), &mut diags);
25 }
26 }
27
28 diags
29 }
30}
31
32fn collect_from_query(
34 query: &Query,
35 source: &str,
36 source_upper: &str,
37 search_from: &mut usize,
38 rule_name: &'static str,
39 diags: &mut Vec<Diagnostic>,
40) {
41 if let Some(with) = &query.with {
42 if with.recursive {
43 let (line, col) =
44 find_keyword_position(source, source_upper, "WITH RECURSIVE", search_from);
45 diags.push(Diagnostic {
46 rule: rule_name,
47 message:
48 "WITH RECURSIVE CTE may loop indefinitely; ensure a correct termination condition"
49 .to_string(),
50 line,
51 col,
52 });
53 } else {
54 advance_past_keyword(source, source_upper, "WITH", search_from);
56 }
57
58 for cte in &with.cte_tables {
60 collect_from_query(&cte.query, source, source_upper, search_from, rule_name, diags);
61 }
62 }
63
64 collect_from_set_expr(&query.body, source, source_upper, search_from, rule_name, diags);
66}
67
68fn collect_from_set_expr(
69 expr: &SetExpr,
70 source: &str,
71 source_upper: &str,
72 search_from: &mut usize,
73 rule_name: &'static str,
74 diags: &mut Vec<Diagnostic>,
75) {
76 match expr {
77 SetExpr::Select(select) => {
78 collect_from_select(select, source, source_upper, search_from, rule_name, diags);
79 }
80 SetExpr::Query(inner) => {
81 collect_from_query(inner, source, source_upper, search_from, rule_name, diags);
82 }
83 SetExpr::SetOperation { left, right, .. } => {
84 collect_from_set_expr(left, source, source_upper, search_from, rule_name, diags);
85 collect_from_set_expr(right, source, source_upper, search_from, rule_name, diags);
86 }
87 _ => {}
88 }
89}
90
91fn collect_from_select(
92 select: &Select,
93 source: &str,
94 source_upper: &str,
95 search_from: &mut usize,
96 rule_name: &'static str,
97 diags: &mut Vec<Diagnostic>,
98) {
99 for twj in &select.from {
101 collect_from_table_factor(&twj.relation, source, source_upper, search_from, rule_name, diags);
102 for join in &twj.joins {
103 collect_from_table_factor(&join.relation, source, source_upper, search_from, rule_name, diags);
104 }
105 }
106
107 for item in &select.projection {
109 let expr = match item {
110 SelectItem::UnnamedExpr(e) => Some(e),
111 SelectItem::ExprWithAlias { expr: e, .. } => Some(e),
112 _ => None,
113 };
114 if let Some(e) = expr {
115 collect_from_expr(e, source, source_upper, search_from, rule_name, diags);
116 }
117 }
118
119 if let Some(selection) = &select.selection {
121 collect_from_expr(selection, source, source_upper, search_from, rule_name, diags);
122 }
123}
124
125fn collect_from_table_factor(
126 factor: &TableFactor,
127 source: &str,
128 source_upper: &str,
129 search_from: &mut usize,
130 rule_name: &'static str,
131 diags: &mut Vec<Diagnostic>,
132) {
133 if let TableFactor::Derived { subquery, .. } = factor {
134 collect_from_query(subquery, source, source_upper, search_from, rule_name, diags);
135 }
136}
137
138fn collect_from_expr(
139 expr: &Expr,
140 source: &str,
141 source_upper: &str,
142 search_from: &mut usize,
143 rule_name: &'static str,
144 diags: &mut Vec<Diagnostic>,
145) {
146 match expr {
147 Expr::Subquery(q) | Expr::InSubquery { subquery: q, .. } | Expr::Exists { subquery: q, .. } => {
148 collect_from_query(q, source, source_upper, search_from, rule_name, diags);
149 }
150 _ => {}
151 }
152}
153
154fn find_keyword_position(
158 source: &str,
159 source_upper: &str,
160 keyword: &str,
161 search_from: &mut usize,
162) -> (usize, usize) {
163 let (line, col, new_from) = find_keyword_inner(source, source_upper, keyword, *search_from);
164 *search_from = new_from;
165 (line, col)
166}
167
168fn advance_past_keyword(
171 source: &str,
172 source_upper: &str,
173 keyword: &str,
174 search_from: &mut usize,
175) {
176 let (_, _, new_from) = find_keyword_inner(source, source_upper, keyword, *search_from);
177 *search_from = new_from;
178}
179
180fn find_keyword_inner(
183 source: &str,
184 source_upper: &str,
185 keyword: &str,
186 start: usize,
187) -> (usize, usize, usize) {
188 let kw_len = keyword.len();
189 let bytes = source_upper.as_bytes();
190 let text_len = bytes.len();
191
192 let mut pos = start;
193 while pos < text_len {
194 let Some(rel) = source_upper[pos..].find(keyword) else {
195 break;
196 };
197 let abs = pos + rel;
198
199 let before_ok = abs == 0
201 || {
202 let b = bytes[abs - 1];
203 !b.is_ascii_alphanumeric() && b != b'_'
204 };
205 let after = abs + kw_len;
207 let after_ok = after >= text_len
208 || {
209 let b = bytes[after];
210 !b.is_ascii_alphanumeric() && b != b'_'
211 };
212
213 if before_ok && after_ok {
214 let (line, col) = offset_to_line_col(source, abs);
215 return (line, col, after);
216 }
217 pos = abs + 1;
218 }
219
220 (1, 1, start)
221}
222
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// The column is counted in bytes from the start of the line. An offset past the
/// end of `source`, or one that falls inside a multi-byte character, is moved back
/// to the nearest valid boundary instead of panicking on the slice.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Offset 0 is always a boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }

    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let col = before.rfind('\n').map_or(end, |nl| end - nl - 1) + 1;
    (line, col)
}