vibesql_executor/procedural/
function.rs

//! User-defined function execution
//!
//! This module handles execution of user-defined functions (UDFs) with:
//! - Parameter binding
//! - Simple `RETURN <expr>` bodies
//! - BEGIN...END bodies with local variables and control flow (planned; currently
//!   rejected with an `UnsupportedFeature` error)
//! - Recursion depth limiting
//! - Read-only enforcement
9
10use crate::errors::ExecutorError;
11use crate::procedural::ExecutionContext;
12use vibesql_ast::Expression;
13use vibesql_catalog::{Function, FunctionBody};
14use vibesql_parser::Parser;
15use vibesql_storage::Database;
16use vibesql_types::SqlValue;
17
18/// Execute a user-defined function
19///
20/// ## Steps:
21/// 1. Validate argument count matches parameter count
22/// 2. Create function execution context (read-only, isolated)
23/// 3. Check and increment recursion depth
24/// 4. Bind argument values to function parameters
25/// 5. Parse and execute function body
26/// 6. Handle RETURN control flow
27/// 7. Decrement recursion depth and return value
28///
29/// ## Errors:
30/// - ArgumentCountMismatch: Wrong number of arguments
31/// - RecursionLimitExceeded: Too many nested function calls
32/// - FunctionMustReturn: Function exited without RETURN statement
33/// - Other execution errors from function body
34pub fn execute_user_function(
35    func: &Function,
36    args: &[Expression],
37    db: &mut Database,
38) -> Result<SqlValue, ExecutorError> {
39    // 1. Validate argument count
40    if args.len() != func.parameters.len() {
41        return Err(ExecutorError::ArgumentCountMismatch {
42            expected: func.parameters.len(),
43            actual: args.len(),
44        });
45    }
46
47    // 2. Create function execution context (read-only)
48    let mut ctx = ExecutionContext::new_function_context();
49
50    // 3. Check recursion depth
51    ctx.enter_recursion().map_err(|e| ExecutorError::RecursionLimitExceeded {
52        message: e,
53        call_stack: vec![], // TODO: Track call stack in Phase 7
54        max_depth: 100,
55    })?;
56
57    // 4. Bind arguments to parameters
58    // Note: All function parameters are IN only (no OUT/INOUT)
59    for (param, arg_expr) in func.parameters.iter().zip(args.iter()) {
60        // Evaluate the argument expression
61        let value = super::executor::evaluate_expression(arg_expr, db, &ctx)?;
62        ctx.set_parameter(&param.name, value);
63    }
64
65    // 5. Parse and execute function body
66    let result = match &func.body {
67        FunctionBody::RawSql(sql) => {
68            // Simple RETURN expression function (e.g., "RETURN x + 10")
69            execute_simple_return(sql, &mut ctx, db)?
70        }
71        FunctionBody::BeginEnd(sql) => {
72            // Complex BEGIN...END function body
73            execute_begin_end_body(sql, &mut ctx, db, &func.name)?
74        }
75    };
76
77    // 6. Decrement recursion depth
78    ctx.exit_recursion();
79
80    Ok(result)
81}
82
83/// Execute a simple RETURN expression function
84///
85/// Handles functions like:
86/// ```sql
87/// CREATE FUNCTION add_ten(x INT) RETURNS INT
88///   RETURN x + 10;
89/// ```
90fn execute_simple_return(
91    sql: &str,
92    ctx: &mut ExecutionContext,
93    db: &mut Database,
94) -> Result<SqlValue, ExecutorError> {
95    // The sql should just be the RETURN expression: "RETURN x + 10"
96    let sql_trimmed = sql.trim();
97
98    // For simple RETURN functions, we need to parse and execute the RETURN statement
99    // Since there's no Statement::Procedural variant, we'll need to evaluate the expression directly
100
101    // Extract the expression after RETURN keyword
102    let return_upper = "RETURN";
103    if !sql_trimmed.to_uppercase().starts_with(return_upper) {
104        return Err(ExecutorError::InvalidFunctionBody(
105            "Simple function body must start with RETURN".to_string(),
106        ));
107    }
108
109    let expr_str = sql_trimmed[return_upper.len()..].trim();
110
111    // Parse just the expression
112    // We'll use Parser::parse_sql with a SELECT wrapper to get the expression parsed
113    let select_wrapper = format!("SELECT {}", expr_str);
114    let stmt = Parser::parse_sql(&select_wrapper).map_err(|e| {
115        ExecutorError::ParseError(format!("Failed to parse RETURN expression: {}", e))
116    })?;
117
118    // Extract the expression from the SELECT
119    #[allow(clippy::collapsible_match)]
120    if let vibesql_ast::Statement::Select(select_stmt) = stmt {
121        if let Some(first_item) = select_stmt.select_list.first() {
122            if let vibesql_ast::SelectItem::Expression { expr, alias: _ } = first_item {
123                // Evaluate the expression in the procedural context
124                let value = super::executor::evaluate_expression(expr, db, ctx)?;
125                return Ok(value);
126            }
127        }
128    }
129
130    Err(ExecutorError::InvalidFunctionBody("Could not parse RETURN expression".to_string()))
131}
132
133/// Execute a BEGIN...END function body
134///
135/// Handles functions like:
136/// ```sql
137/// CREATE FUNCTION factorial(n INT) RETURNS INT
138/// BEGIN
139///   DECLARE result INT DEFAULT 1;
140///   DECLARE i INT DEFAULT 2;
141///   WHILE i <= n DO
142///     SET result = result * i;
143///     SET i = i + 1;
144///   END WHILE;
145///   RETURN result;
146/// END;
147/// ```
148fn execute_begin_end_body(
149    _sql: &str,
150    _ctx: &mut ExecutionContext,
151    _db: &mut Database,
152    func_name: &str,
153) -> Result<SqlValue, ExecutorError> {
154    // TODO: Complex BEGIN...END functions require FunctionBody to store Vec<ProceduralStatement>
155    // like ProcedureBody does, rather than raw SQL strings.
156    // This is tracked in the issue as a known limitation.
157    //
158    // For now, we'll return an error directing users to use simple RETURN functions.
159    // The proper fix is to update FunctionBody::BeginEnd to store Vec<ProceduralStatement>
160    // and update the parser to populate it correctly.
161
162    Err(ExecutorError::UnsupportedFeature(format!(
163        "Complex BEGIN...END function bodies not yet fully supported for function '{}'. \
164             Use simple RETURN expression functions (e.g., 'RETURN x + 10') instead. \
165             Full BEGIN...END support requires updating FunctionBody to store parsed statements.",
166        func_name
167    )))
168}
169
170impl ExecutionContext {
171    /// Create a new function execution context (read-only)
172    pub fn new_function_context() -> Self {
173        let mut ctx = Self::new();
174        ctx.is_function = true;
175        ctx
176    }
177
178    /// Check if this is a function context (read-only)
179    pub fn is_function_context(&self) -> bool {
180        self.is_function
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::FunctionParam;
    use vibesql_types::DataType;

    /// Build a function in the `public` schema with one `INT` parameter and an `INT`
    /// return type.
    fn int_function(name: &str, param: &str, body: FunctionBody) -> Function {
        Function::new(
            name.to_string(),
            "public".to_string(),
            vec![FunctionParam { name: param.to_string(), data_type: DataType::Integer }],
            DataType::Integer,
            body,
        )
    }

    /// `CREATE FUNCTION add_ten(x INT) RETURNS INT RETURN x + 10;`
    fn add_ten() -> Function {
        int_function("add_ten", "x", FunctionBody::RawSql("RETURN x + 10".to_string()))
    }

    #[test]
    fn test_simple_return_function() {
        let mut db = Database::new();
        let args = [Expression::Literal(SqlValue::Integer(5))];

        // add_ten(5) must evaluate to 15
        let outcome = execute_user_function(&add_ten(), &args, &mut db).unwrap();

        assert_eq!(outcome, SqlValue::Integer(15));
    }

    #[test]
    fn test_argument_count_mismatch() {
        let mut db = Database::new();

        // add_ten() called with no arguments must be rejected
        let outcome = execute_user_function(&add_ten(), &[], &mut db);

        assert!(matches!(outcome, Err(ExecutorError::ArgumentCountMismatch { .. })));
    }

    #[test]
    fn test_begin_end_not_supported_yet() {
        let mut db = Database::new();

        // CREATE FUNCTION factorial(n INT) RETURNS INT BEGIN...END
        let factorial = int_function(
            "factorial",
            "n",
            FunctionBody::BeginEnd("BEGIN DECLARE x INT DEFAULT 5; RETURN x; END".to_string()),
        );
        let args = [Expression::Literal(SqlValue::Integer(5))];

        let outcome = execute_user_function(&factorial, &args, &mut db);

        // BEGIN...END bodies report UnsupportedFeature until they are implemented
        assert!(matches!(outcome, Err(ExecutorError::UnsupportedFeature(_))));
    }
}