Skip to main content

sqrust_rules/structure/
function_call_depth.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, Query,
3    Select, SelectItem, SetExpr, Statement, TableFactor};
4
/// Lint rule that flags function calls nested more deeply than `max_depth`
/// (for example `f(g(h(i(x))))` has a nesting depth of 4).
pub struct FunctionCallDepth {
    /// Largest permitted nesting depth; a call chain deeper than this is reported.
    pub max_depth: usize,
}

impl Default for FunctionCallDepth {
    /// Permits up to three levels of nested function calls.
    fn default() -> Self {
        Self { max_depth: 3 }
    }
}
14
15impl Rule for FunctionCallDepth {
16    fn name(&self) -> &'static str {
17        "FunctionCallDepth"
18    }
19
20    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
21        if !ctx.parse_errors.is_empty() {
22            return Vec::new();
23        }
24        let mut diags = Vec::new();
25        for stmt in &ctx.statements {
26            if let Statement::Query(query) = stmt {
27                check_query(query, self.max_depth, &ctx.source, &mut diags);
28            }
29        }
30        diags
31    }
32}
33
34fn check_query(query: &Query, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
35    if let Some(with) = &query.with {
36        for cte in &with.cte_tables {
37            check_query(&cte.query, max_depth, source, diags);
38        }
39    }
40    check_set_expr(&query.body, max_depth, source, diags);
41}
42
43fn check_set_expr(expr: &SetExpr, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
44    match expr {
45        SetExpr::Select(sel) => check_select(sel, max_depth, source, diags),
46        SetExpr::SetOperation { left, right, .. } => {
47            check_set_expr(left, max_depth, source, diags);
48            check_set_expr(right, max_depth, source, diags);
49        }
50        SetExpr::Query(inner) => check_query(inner, max_depth, source, diags),
51        _ => {}
52    }
53}
54
55fn check_select(sel: &Select, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
56    for item in &sel.projection {
57        match item {
58            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
59                check_top_expr(e, max_depth, source, diags);
60            }
61            _ => {}
62        }
63    }
64    if let Some(selection) = &sel.selection {
65        check_top_expr(selection, max_depth, source, diags);
66    }
67    if let Some(having) = &sel.having {
68        check_top_expr(having, max_depth, source, diags);
69    }
70    for twj in &sel.from {
71        recurse_table_factor(&twj.relation, max_depth, source, diags);
72        for join in &twj.joins {
73            recurse_table_factor(&join.relation, max_depth, source, diags);
74        }
75    }
76}
77
78fn recurse_table_factor(tf: &TableFactor, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
79    if let TableFactor::Derived { subquery, .. } = tf {
80        check_query(subquery, max_depth, source, diags);
81    }
82}
83
84/// Entry point for an expression; checks if any function call chain exceeds max_depth.
85fn check_top_expr(expr: &Expr, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
86    // Walk the expression tree. For each function call node, compute its depth
87    // and report if over max_depth.
88    walk_expr_for_depth(expr, max_depth, source, diags);
89}
90
91/// Returns the nesting depth of function calls starting at `expr` as the top-level root.
92/// A plain function call (no nested functions) has depth 1.
93/// Reports violations for any root function call that exceeds max_depth.
94fn walk_expr_for_depth(expr: &Expr, max_depth: usize, source: &str, diags: &mut Vec<Diagnostic>) {
95    match expr {
96        Expr::Function(func) => {
97            let depth = function_depth(expr);
98            if depth > max_depth {
99                let (line, col) = find_function_position(source, func);
100                diags.push(Diagnostic {
101                    rule: "FunctionCallDepth",
102                    message: format!(
103                        "Function call nesting depth {} exceeds maximum {}",
104                        depth, max_depth
105                    ),
106                    line,
107                    col,
108                });
109            }
110        }
111        Expr::BinaryOp { left, right, .. } => {
112            walk_expr_for_depth(left, max_depth, source, diags);
113            walk_expr_for_depth(right, max_depth, source, diags);
114        }
115        Expr::UnaryOp { expr: inner, .. } => walk_expr_for_depth(inner, max_depth, source, diags),
116        Expr::Nested(inner) => walk_expr_for_depth(inner, max_depth, source, diags),
117        Expr::Case { operand, conditions, results, else_result } => {
118            if let Some(op) = operand { walk_expr_for_depth(op, max_depth, source, diags); }
119            for c in conditions { walk_expr_for_depth(c, max_depth, source, diags); }
120            for r in results { walk_expr_for_depth(r, max_depth, source, diags); }
121            if let Some(e) = else_result { walk_expr_for_depth(e, max_depth, source, diags); }
122        }
123        _ => {}
124    }
125}
126
127/// Computes the maximum function call nesting depth of a subtree.
128/// `Expr::Function` at a leaf → depth 1.
129/// `f(g(x))` → depth 2.
130fn function_depth(expr: &Expr) -> usize {
131    match expr {
132        Expr::Function(func) => {
133            let max_child = max_depth_in_args(func);
134            1 + max_child
135        }
136        Expr::Nested(inner) => function_depth(inner),
137        _ => 0,
138    }
139}
140
141fn max_depth_in_args(func: &Function) -> usize {
142    let mut max = 0usize;
143    let args = match &func.args {
144        FunctionArguments::List(list) => list.args.as_slice(),
145        _ => return 0,
146    };
147    for arg in args {
148        let d = match arg {
149            FunctionArg::Named { arg, .. }
150            | FunctionArg::Unnamed(arg)
151            | FunctionArg::ExprNamed { arg, .. } => match arg {
152                FunctionArgExpr::Expr(e) => function_depth(e),
153                _ => 0,
154            },
155        };
156        if d > max {
157            max = d;
158        }
159    }
160    max
161}
162
163fn find_function_position(source: &str, func: &Function) -> (usize, usize) {
164    // Use the function name to find the first occurrence in source
165    let name = func.name.to_string();
166    find_keyword_position(source, &name)
167}
168
/// Finds the first whole-word, ASCII-case-insensitive occurrence of `keyword`
/// in `source` and returns its 1-based `(line, column)`.
///
/// A match is a whole word when it is not directly preceded or followed by an
/// ASCII letter, digit, or `_`. Returns `(1, 1)` when `keyword` is empty or
/// does not occur.
///
/// Matching runs on the raw bytes of `source` and folds only ASCII case, so a
/// match offset is always a char boundary of `source`. (An ASCII first byte is
/// always a boundary; a non-ASCII first byte must match a UTF-8 lead byte
/// exactly.) Upper-casing the whole text with `str::to_uppercase` is unsafe
/// here: some characters change byte length (e.g. `'ﬁ'` → `"FI"`, `'ı'` →
/// `'I'`), which misaligns offsets against `source` and can make the final
/// slice panic.
fn find_keyword_position(source: &str, keyword: &str) -> (usize, usize) {
    let haystack = source.as_bytes();
    let needle = keyword.as_bytes();
    if needle.is_empty() {
        return (1, 1);
    }
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    haystack
        .windows(needle.len())
        .enumerate()
        .find(|&(start, window)| {
            let end = start + needle.len();
            window.eq_ignore_ascii_case(needle)
                && (start == 0 || !is_word_byte(haystack[start - 1]))
                && haystack.get(end).map_or(true, |&b| !is_word_byte(b))
        })
        .map_or((1, 1), |(start, _)| offset_to_line_col(source, start))
}

/// Converts a byte `offset` into `source` into a 1-based `(line, column)` pair.
///
/// Lines are delimited by `'\n'`. The column counts bytes (not characters)
/// from the start of the line, plus one.
///
/// # Panics
/// Panics if `offset` is past the end of `source` or not on a char boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let before = &source[..offset];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let line_start = before.rfind('\n').map_or(0, |newline| newline + 1);
    (line, offset - line_start + 1)
}