vibesql_executor/procedural/
function.rs

1//! User-defined function execution
2//!
3//! This module handles execution of user-defined functions (UDFs) with:
4//! - Parameter binding
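//! - Simple `RETURN <expr>` body evaluation
//! - BEGIN...END bodies with local variables and control flow (not yet
//!   supported; such bodies currently fail with an `UnsupportedFeature` error)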
7//! - Recursion depth limiting
8//! - Read-only enforcement
9
10use crate::errors::ExecutorError;
11use crate::procedural::ExecutionContext;
12use vibesql_ast::Expression;
13use vibesql_catalog::{Function, FunctionBody};
14use vibesql_parser::Parser;
15use vibesql_storage::Database;
16use vibesql_types::SqlValue;
17
18/// Execute a user-defined function
19///
20/// ## Steps:
21/// 1. Validate argument count matches parameter count
22/// 2. Create function execution context (read-only, isolated)
23/// 3. Check and increment recursion depth
24/// 4. Bind argument values to function parameters
25/// 5. Parse and execute function body
26/// 6. Handle RETURN control flow
27/// 7. Decrement recursion depth and return value
28///
29/// ## Errors:
30/// - ArgumentCountMismatch: Wrong number of arguments
31/// - RecursionLimitExceeded: Too many nested function calls
32/// - FunctionMustReturn: Function exited without RETURN statement
33/// - Other execution errors from function body
34pub fn execute_user_function(
35    func: &Function,
36    args: &[Expression],
37    db: &mut Database,
38) -> Result<SqlValue, ExecutorError> {
39    // 1. Validate argument count
40    if args.len() != func.parameters.len() {
41        return Err(ExecutorError::ArgumentCountMismatch {
42            expected: func.parameters.len(),
43            actual: args.len(),
44        });
45    }
46
47    // 2. Create function execution context (read-only)
48    let mut ctx = ExecutionContext::new_function_context();
49
50    // 3. Check recursion depth
51    ctx.enter_recursion()
52        .map_err(|e| ExecutorError::RecursionLimitExceeded {
53            message: e,
54            call_stack: vec![],  // TODO: Track call stack in Phase 7
55            max_depth: 100,
56        })?;
57
58    // 4. Bind arguments to parameters
59    // Note: All function parameters are IN only (no OUT/INOUT)
60    for (param, arg_expr) in func.parameters.iter().zip(args.iter()) {
61        // Evaluate the argument expression
62        let value = super::executor::evaluate_expression(arg_expr, db, &ctx)?;
63        ctx.set_parameter(&param.name, value);
64    }
65
66    // 5. Parse and execute function body
67    let result = match &func.body {
68        FunctionBody::RawSql(sql) => {
69            // Simple RETURN expression function (e.g., "RETURN x + 10")
70            execute_simple_return(sql, &mut ctx, db)?
71        }
72        FunctionBody::BeginEnd(sql) => {
73            // Complex BEGIN...END function body
74            execute_begin_end_body(sql, &mut ctx, db, &func.name)?
75        }
76    };
77
78    // 6. Decrement recursion depth
79    ctx.exit_recursion();
80
81    Ok(result)
82}
83
84/// Execute a simple RETURN expression function
85///
86/// Handles functions like:
87/// ```sql
88/// CREATE FUNCTION add_ten(x INT) RETURNS INT
89///   RETURN x + 10;
90/// ```
91fn execute_simple_return(
92    sql: &str,
93    ctx: &mut ExecutionContext,
94    db: &mut Database,
95) -> Result<SqlValue, ExecutorError> {
96    // The sql should just be the RETURN expression: "RETURN x + 10"
97    let sql_trimmed = sql.trim();
98
99    // For simple RETURN functions, we need to parse and execute the RETURN statement
100    // Since there's no Statement::Procedural variant, we'll need to evaluate the expression directly
101
102    // Extract the expression after RETURN keyword
103    let return_upper = "RETURN";
104    if !sql_trimmed.to_uppercase().starts_with(return_upper) {
105        return Err(ExecutorError::InvalidFunctionBody(
106            "Simple function body must start with RETURN".to_string()
107        ));
108    }
109
110    let expr_str = sql_trimmed[return_upper.len()..].trim();
111
112    // Parse just the expression
113    // We'll use Parser::parse_sql with a SELECT wrapper to get the expression parsed
114    let select_wrapper = format!("SELECT {}", expr_str);
115    let stmt = Parser::parse_sql(&select_wrapper)
116        .map_err(|e| ExecutorError::ParseError(format!("Failed to parse RETURN expression: {}", e)))?;
117
118    // Extract the expression from the SELECT
119    #[allow(clippy::collapsible_match)]
120    if let vibesql_ast::Statement::Select(select_stmt) = stmt {
121        if let Some(first_item) = select_stmt.select_list.first() {
122            if let vibesql_ast::SelectItem::Expression { expr, alias: _ } = first_item {
123                // Evaluate the expression in the procedural context
124                let value = super::executor::evaluate_expression(expr, db, ctx)?;
125                return Ok(value);
126            }
127        }
128    }
129
130    Err(ExecutorError::InvalidFunctionBody(
131        "Could not parse RETURN expression".to_string()
132    ))
133}
134
135/// Execute a BEGIN...END function body
136///
137/// Handles functions like:
138/// ```sql
139/// CREATE FUNCTION factorial(n INT) RETURNS INT
140/// BEGIN
141///   DECLARE result INT DEFAULT 1;
142///   DECLARE i INT DEFAULT 2;
143///   WHILE i <= n DO
144///     SET result = result * i;
145///     SET i = i + 1;
146///   END WHILE;
147///   RETURN result;
148/// END;
149/// ```
150fn execute_begin_end_body(
151    _sql: &str,
152    _ctx: &mut ExecutionContext,
153    _db: &mut Database,
154    func_name: &str,
155) -> Result<SqlValue, ExecutorError> {
156    // TODO: Complex BEGIN...END functions require FunctionBody to store Vec<ProceduralStatement>
157    // like ProcedureBody does, rather than raw SQL strings.
158    // This is tracked in the issue as a known limitation.
159    //
160    // For now, we'll return an error directing users to use simple RETURN functions.
161    // The proper fix is to update FunctionBody::BeginEnd to store Vec<ProceduralStatement>
162    // and update the parser to populate it correctly.
163
164    Err(ExecutorError::UnsupportedFeature(
165        format!(
166            "Complex BEGIN...END function bodies not yet fully supported for function '{}'. \
167             Use simple RETURN expression functions (e.g., 'RETURN x + 10') instead. \
168             Full BEGIN...END support requires updating FunctionBody to store parsed statements.",
169            func_name
170        )
171    ))
172}
173
174impl ExecutionContext {
175    /// Create a new function execution context (read-only)
176    pub fn new_function_context() -> Self {
177        let mut ctx = Self::new();
178        ctx.is_function = true;
179        ctx
180    }
181
182    /// Check if this is a function context (read-only)
183    pub fn is_function_context(&self) -> bool {
184        self.is_function
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::FunctionParam;
    use vibesql_types::DataType;

    /// Build a function in the `public` schema that takes one INTEGER
    /// parameter and returns INTEGER, with the given body.
    fn single_int_param_function(name: &str, param: &str, body: FunctionBody) -> Function {
        let parameters = vec![FunctionParam {
            name: param.to_string(),
            data_type: DataType::Integer,
        }];

        Function::new(
            name.to_string(),
            "public".to_string(),
            parameters,
            DataType::Integer,
            body,
        )
    }

    #[test]
    fn test_simple_return_function() {
        // CREATE FUNCTION add_ten(x INT) RETURNS INT RETURN x + 10;
        let add_ten = single_int_param_function(
            "add_ten",
            "x",
            FunctionBody::RawSql("RETURN x + 10".to_string()),
        );
        let mut db = Database::new();

        // add_ten(5) must evaluate to 15
        let call_args = [Expression::Literal(SqlValue::Integer(5))];
        let value = execute_user_function(&add_ten, &call_args, &mut db).unwrap();

        assert_eq!(value, SqlValue::Integer(15));
    }

    #[test]
    fn test_argument_count_mismatch() {
        let add_ten = single_int_param_function(
            "add_ten",
            "x",
            FunctionBody::RawSql("RETURN x + 10".to_string()),
        );
        let mut db = Database::new();

        // Calling a one-parameter function with no arguments must be rejected
        let outcome = execute_user_function(&add_ten, &[], &mut db);

        assert!(matches!(outcome, Err(ExecutorError::ArgumentCountMismatch { .. })));
    }

    #[test]
    fn test_begin_end_not_supported_yet() {
        // CREATE FUNCTION factorial(n INT) RETURNS INT BEGIN...END
        let factorial = single_int_param_function(
            "factorial",
            "n",
            FunctionBody::BeginEnd("BEGIN DECLARE x INT DEFAULT 5; RETURN x; END".to_string()),
        );
        let mut db = Database::new();

        let call_args = [Expression::Literal(SqlValue::Integer(5))];
        let outcome = execute_user_function(&factorial, &call_args, &mut db);

        // BEGIN...END bodies are expected to report UnsupportedFeature for now
        assert!(matches!(outcome, Err(ExecutorError::UnsupportedFeature(_))));
    }
}
261}