1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Query,
3 Select, SelectItem, SetExpr, Statement, TableFactor};
4
/// Lint rule that flags function calls nested more deeply than `max_depth`.
///
/// Depth counts directly nested calls: `UPPER(x)` has depth 1 and `UPPER(TRIM(x))` has
/// depth 2. Parentheses around an argument do not add depth, and calls reached through other
/// operators (e.g. the `b(x)` in `a(b(x) + 1)`) do not contribute to the enclosing call's depth.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionCallDepth {
    /// Largest nesting depth that is allowed; deeper calls produce a diagnostic.
    pub max_depth: usize,
}

impl Default for FunctionCallDepth {
    /// Returns a rule that allows up to three levels of nested function calls.
    fn default() -> Self {
        Self { max_depth: 3 }
    }
}
14
15impl Rule for FunctionCallDepth {
16 fn name(&self) -> &'static str {
17 "FunctionCallDepth"
18 }
19
20 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
21 if !ctx.parse_errors.is_empty() {
22 return Vec::new();
23 }
24 let mut diags = Vec::new();
25 for stmt in &ctx.statements {
26 if let Statement::Query(query) = stmt {
27 check_query(query, self.max_depth, &ctx.source, &mut diags);
28 }
29 }
30 diags
31 }
32}
33
34fn check_query(query: &Query, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
35 if let Some(with) = &query.with {
36 for cte in &with.cte_tables {
37 check_query(&cte.query, max_depth, source, diags);
38 }
39 }
40 check_set_expr(&query.body, max_depth, source, diags);
41}
42
43fn check_set_expr(expr: &SetExpr, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
44 match expr {
45 SetExpr::Select(sel) => check_select(sel, max_depth, source, diags),
46 SetExpr::SetOperation { left, right, .. } => {
47 check_set_expr(left, max_depth, source, diags);
48 check_set_expr(right, max_depth, source, diags);
49 }
50 SetExpr::Query(inner) => check_query(inner, max_depth, source, diags),
51 _ => {}
52 }
53}
54
55fn check_select(sel: &Select, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
56 for item in &sel.projection {
57 match item {
58 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
59 check_top_expr(e, max_depth, source, diags);
60 }
61 _ => {}
62 }
63 }
64 if let Some(selection) = &sel.selection {
65 check_top_expr(selection, max_depth, source, diags);
66 }
67 if let Some(having) = &sel.having {
68 check_top_expr(having, max_depth, source, diags);
69 }
70 for twj in &sel.from {
71 recurse_table_factor(&twj.relation, max_depth, source, diags);
72 for join in &twj.joins {
73 recurse_table_factor(&join.relation, max_depth, source, diags);
74 }
75 }
76}
77
78fn recurse_table_factor(tf: &TableFactor, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
79 if let TableFactor::Derived { subquery, .. } = tf {
80 check_query(subquery, max_depth, source, diags);
81 }
82}
83
84fn check_top_expr(expr: &Expr, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
86 walk_expr_for_depth(expr, max_depth, source, diags);
89}
90
91fn walk_expr_for_depth(expr: &Expr, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
95 match expr {
96 Expr::Function(func) => {
97 let depth = function_depth(expr);
98 if depth > max_depth {
99 let (line, col) = find_function_position(source, func);
100 diags.push(Diagnostic {
101 rule: "FunctionCallDepth",
102 message: format!(
103 "Function call nesting depth {} exceeds maximum {}",
104 depth, max_depth
105 ),
106 line,
107 col,
108 });
109 }
110 }
111 Expr::BinaryOp { left, right, .. } => {
112 walk_expr_for_depth(left, max_depth, source, diags);
113 walk_expr_for_depth(right, max_depth, source, diags);
114 }
115 Expr::UnaryOp { expr: inner, .. } => walk_expr_for_depth(inner, max_depth, source, diags),
116 Expr::Nested(inner) => walk_expr_for_depth(inner, max_depth, source, diags),
117 Expr::Case { operand, conditions, results, else_result } => {
118 if let Some(op) = operand { walk_expr_for_depth(op, max_depth, source, diags); }
119 for c in conditions { walk_expr_for_depth(c, max_depth, source, diags); }
120 for r in results { walk_expr_for_depth(r, max_depth, source, diags); }
121 if let Some(e) = else_result { walk_expr_for_depth(e, max_depth, source, diags); }
122 }
123 _ => {}
124 }
125}
126
127fn function_depth(expr: &Expr) -> usize {
131 match expr {
132 Expr::Function(func) => {
133 let max_child = max_depth_in_args(func);
134 1 + max_child
135 }
136 Expr::Nested(inner) => function_depth(inner),
137 _ => 0,
138 }
139}
140
141fn max_depth_in_args(func: &Function) -> usize {
142 let mut max = 0usize;
143 let args = match &func.args {
144 FunctionArguments::List(list) => list.args.as_slice(),
145 _ => return 0,
146 };
147 for arg in args {
148 let d = match arg {
149 FunctionArg::Named { arg, .. }
150 | FunctionArg::Unnamed(arg)
151 | FunctionArg::ExprNamed { arg, .. } => match arg {
152 FunctionArgExpr::Expr(e) => function_depth(e),
153 _ => 0,
154 },
155 };
156 if d > max {
157 max = d;
158 }
159 }
160 max
161}
162
163fn find_function_position(source: &str, func: &Function) -> (usize, usize) {
164 let name = func.name.to_string();
166 find_keyword_position(source, &name)
167}
168
/// Find the first whole-word, ASCII case-insensitive occurrence of `keyword` in `source` and
/// return its 1-based `(line, col)` position.
///
/// A match is whole-word when the bytes immediately before and after it are not ASCII
/// alphanumerics or `_`. Returns `(1, 1)` when `keyword` is empty or not found.
///
/// The comparison runs directly on `source`'s bytes with ASCII-only case folding, so every
/// match offset is a valid char boundary of `source`. Searching an uppercased copy instead
/// would be wrong: `str::to_uppercase` can change byte lengths (e.g. `ı` becomes `I`, 2 bytes
/// to 1). Offsets found in the copy would then point at the wrong place in `source`, or panic
/// when slicing off a char boundary.
fn find_keyword_position(source: &str, keyword: &str) -> (usize, usize) {
    let haystack = source.as_bytes();
    let needle = keyword.as_bytes();
    if needle.is_empty() {
        // `windows(0)` would panic; an empty keyword matches nothing meaningful.
        return (1, 1);
    }
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    haystack
        .windows(needle.len())
        .enumerate()
        .find(|&(start, window)| {
            let end = start + needle.len();
            window.eq_ignore_ascii_case(needle)
                && (start == 0 || !is_word_byte(haystack[start - 1]))
                && !matches!(haystack.get(end), Some(&b) if is_word_byte(b))
        })
        .map_or((1, 1), |(start, _)| offset_to_line_col(source, start))
}

/// Convert a byte offset into `source` to a 1-based `(line, col)` pair.
///
/// `col` counts bytes from the start of the line, the same unit the offset is given in.
/// The computation works on raw bytes, so it does not require `offset` to be a char boundary.
/// It still panics if `offset > source.len()`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source.as_bytes()[..offset];
    let line = before.iter().filter(|&&b| b == b'\n').count() + 1;
    let line_start = before
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}