vibesql_executor/procedural/
function.rs

//! User-defined function execution
//!
//! This module handles execution of user-defined functions (UDFs) with:
//! - Parameter binding
//! - BEGIN...END body execution (not yet supported; such bodies currently fail with
//!   `UnsupportedFeature`)
//! - Local variables and control flow (part of BEGIN...END bodies, so not yet supported)
//! - Recursion depth limiting
//! - Read-only enforcement
9
10use vibesql_ast::Expression;
11use vibesql_catalog::{Function, FunctionBody};
12use vibesql_parser::Parser;
13use vibesql_storage::Database;
14use vibesql_types::SqlValue;
15
16use crate::{errors::ExecutorError, procedural::ExecutionContext};
17
18/// Execute a user-defined function
19///
20/// ## Steps:
21/// 1. Validate argument count matches parameter count
22/// 2. Create function execution context (read-only, isolated)
23/// 3. Check and increment recursion depth
24/// 4. Bind argument values to function parameters
25/// 5. Parse and execute function body
26/// 6. Handle RETURN control flow
27/// 7. Decrement recursion depth and return value
28///
29/// ## Errors:
30/// - ArgumentCountMismatch: Wrong number of arguments
31/// - RecursionLimitExceeded: Too many nested function calls
32/// - FunctionMustReturn: Function exited without RETURN statement
33/// - Other execution errors from function body
34pub fn execute_user_function(
35    func: &Function,
36    args: &[Expression],
37    db: &mut Database,
38) -> Result<SqlValue, ExecutorError> {
39    // 1. Validate argument count
40    if args.len() != func.parameters.len() {
41        return Err(ExecutorError::ArgumentCountMismatch {
42            expected: func.parameters.len(),
43            actual: args.len(),
44        });
45    }
46
47    // 2. Create function execution context (read-only)
48    let mut ctx = ExecutionContext::new_function_context();
49
50    // 3. Check recursion depth
51    ctx.enter_recursion().map_err(|e| ExecutorError::RecursionLimitExceeded {
52        message: e,
53        call_stack: vec![], // TODO: Track call stack in Phase 7
54        max_depth: 100,
55    })?;
56
57    // 4. Bind arguments to parameters
58    // Note: All function parameters are IN only (no OUT/INOUT)
59    for (param, arg_expr) in func.parameters.iter().zip(args.iter()) {
60        // Evaluate the argument expression
61        let value = super::executor::evaluate_expression(arg_expr, db, &ctx)?;
62        ctx.set_parameter(&param.name, value);
63    }
64
65    // 5. Parse and execute function body
66    let result = match &func.body {
67        FunctionBody::RawSql(sql) => {
68            // Simple RETURN expression function (e.g., "RETURN x + 10")
69            execute_simple_return(sql, &mut ctx, db)?
70        }
71        FunctionBody::BeginEnd(sql) => {
72            // Complex BEGIN...END function body
73            execute_begin_end_body(sql, &mut ctx, db, &func.name)?
74        }
75    };
76
77    // 6. Decrement recursion depth
78    ctx.exit_recursion();
79
80    Ok(result)
81}
82
83/// Execute a simple RETURN expression function
84///
85/// Handles functions like:
86/// ```sql
87/// CREATE FUNCTION add_ten(x INT) RETURNS INT
88///   RETURN x + 10;
89/// ```
90fn execute_simple_return(
91    sql: &str,
92    ctx: &mut ExecutionContext,
93    db: &mut Database,
94) -> Result<SqlValue, ExecutorError> {
95    // The sql should just be the RETURN expression: "RETURN x + 10"
96    let sql_trimmed = sql.trim();
97
98    // For simple RETURN functions, we need to parse and execute the RETURN statement
99    // Since there's no Statement::Procedural variant, we'll need to evaluate the expression
100    // directly
101
102    // Extract the expression after RETURN keyword
103    let return_upper = "RETURN";
104    if !sql_trimmed.to_uppercase().starts_with(return_upper) {
105        return Err(ExecutorError::InvalidFunctionBody(
106            "Simple function body must start with RETURN".to_string(),
107        ));
108    }
109
110    let expr_str = sql_trimmed[return_upper.len()..].trim();
111
112    // Parse just the expression
113    // We'll use Parser::parse_sql with a SELECT wrapper to get the expression parsed
114    let select_wrapper = format!("SELECT {}", expr_str);
115    let stmt = Parser::parse_sql(&select_wrapper).map_err(|e| {
116        ExecutorError::ParseError(format!("Failed to parse RETURN expression: {}", e))
117    })?;
118
119    // Extract the expression from the SELECT
120    #[allow(clippy::collapsible_match)]
121    if let vibesql_ast::Statement::Select(select_stmt) = stmt {
122        if let Some(first_item) = select_stmt.select_list.first() {
123            if let vibesql_ast::SelectItem::Expression { expr, alias: _, .. } = first_item {
124                // Evaluate the expression in the procedural context
125                let value = super::executor::evaluate_expression(expr, db, ctx)?;
126                return Ok(value);
127            }
128        }
129    }
130
131    Err(ExecutorError::InvalidFunctionBody("Could not parse RETURN expression".to_string()))
132}
133
134/// Execute a BEGIN...END function body
135///
136/// Handles functions like:
137/// ```sql
138/// CREATE FUNCTION factorial(n INT) RETURNS INT
139/// BEGIN
140///   DECLARE result INT DEFAULT 1;
141///   DECLARE i INT DEFAULT 2;
142///   WHILE i <= n DO
143///     SET result = result * i;
144///     SET i = i + 1;
145///   END WHILE;
146///   RETURN result;
147/// END;
148/// ```
149fn execute_begin_end_body(
150    _sql: &str,
151    _ctx: &mut ExecutionContext,
152    _db: &mut Database,
153    func_name: &str,
154) -> Result<SqlValue, ExecutorError> {
155    // TODO: Complex BEGIN...END functions require FunctionBody to store Vec<ProceduralStatement>
156    // like ProcedureBody does, rather than raw SQL strings.
157    // This is tracked in the issue as a known limitation.
158    //
159    // For now, we'll return an error directing users to use simple RETURN functions.
160    // The proper fix is to update FunctionBody::BeginEnd to store Vec<ProceduralStatement>
161    // and update the parser to populate it correctly.
162
163    Err(ExecutorError::UnsupportedFeature(format!(
164        "Complex BEGIN...END function bodies not yet fully supported for function '{}'. \
165             Use simple RETURN expression functions (e.g., 'RETURN x + 10') instead. \
166             Full BEGIN...END support requires updating FunctionBody to store parsed statements.",
167        func_name
168    )))
169}
170
171impl ExecutionContext {
172    /// Create a new function execution context (read-only)
173    pub fn new_function_context() -> Self {
174        let mut ctx = Self::new();
175        ctx.is_function = true;
176        ctx
177    }
178
179    /// Check if this is a function context (read-only)
180    pub fn is_function_context(&self) -> bool {
181        self.is_function
182    }
183}
184
#[cfg(test)]
mod tests {
    use vibesql_catalog::FunctionParam;
    use vibesql_types::DataType;

    use super::*;

    /// Build `name(<param> INT) RETURNS INT` in schema `public` with the given body.
    fn single_int_param_function(name: &str, param: &str, body: FunctionBody) -> Function {
        Function::new(
            name.to_string(),
            "public".to_string(),
            vec![FunctionParam { name: param.to_string(), data_type: DataType::Integer }],
            DataType::Integer,
            body,
        )
    }

    /// `CREATE FUNCTION add_ten(x INT) RETURNS INT RETURN x + 10;`
    fn add_ten() -> Function {
        single_int_param_function("add_ten", "x", FunctionBody::RawSql("RETURN x + 10".to_string()))
    }

    #[test]
    fn test_simple_return_function() {
        let mut db = Database::new();
        let five = [Expression::Literal(SqlValue::Integer(5))];

        let outcome = execute_user_function(&add_ten(), &five, &mut db).unwrap();

        assert_eq!(outcome, SqlValue::Integer(15));
    }

    #[test]
    fn test_argument_count_mismatch() {
        let mut db = Database::new();

        // Calling a one-parameter function with no arguments must be rejected.
        let outcome = execute_user_function(&add_ten(), &[], &mut db);

        assert!(matches!(outcome, Err(ExecutorError::ArgumentCountMismatch { .. })));
    }

    #[test]
    fn test_begin_end_not_supported_yet() {
        let mut db = Database::new();

        // CREATE FUNCTION factorial(n INT) RETURNS INT BEGIN...END
        let factorial = single_int_param_function(
            "factorial",
            "n",
            FunctionBody::BeginEnd("BEGIN DECLARE x INT DEFAULT 5; RETURN x; END".to_string()),
        );
        let five = [Expression::Literal(SqlValue::Integer(5))];

        let outcome = execute_user_function(&factorial, &five, &mut db);

        // BEGIN...END bodies currently report UnsupportedFeature.
        assert!(matches!(outcome, Err(ExecutorError::UnsupportedFeature(_))));
    }
}