// vibesql_executor/procedural/function.rs

use vibesql_ast::Expression;
use vibesql_catalog::{Function, FunctionBody};
use vibesql_parser::Parser;
use vibesql_storage::Database;
use vibesql_types::SqlValue;

use crate::{errors::ExecutorError, procedural::ExecutionContext};

18pub fn execute_user_function(
35 func: &Function,
36 args: &[Expression],
37 db: &mut Database,
38) -> Result<SqlValue, ExecutorError> {
39 if args.len() != func.parameters.len() {
41 return Err(ExecutorError::ArgumentCountMismatch {
42 expected: func.parameters.len(),
43 actual: args.len(),
44 });
45 }
46
47 let mut ctx = ExecutionContext::new_function_context();
49
50 ctx.enter_recursion().map_err(|e| ExecutorError::RecursionLimitExceeded {
52 message: e,
53 call_stack: vec![], max_depth: 100,
55 })?;
56
57 for (param, arg_expr) in func.parameters.iter().zip(args.iter()) {
60 let value = super::executor::evaluate_expression(arg_expr, db, &ctx)?;
62 ctx.set_parameter(¶m.name, value);
63 }
64
65 let result = match &func.body {
67 FunctionBody::RawSql(sql) => {
68 execute_simple_return(sql, &mut ctx, db)?
70 }
71 FunctionBody::BeginEnd(sql) => {
72 execute_begin_end_body(sql, &mut ctx, db, &func.name)?
74 }
75 };
76
77 ctx.exit_recursion();
79
80 Ok(result)
81}
82
83fn execute_simple_return(
91 sql: &str,
92 ctx: &mut ExecutionContext,
93 db: &mut Database,
94) -> Result<SqlValue, ExecutorError> {
95 let sql_trimmed = sql.trim();
97
98 let return_upper = "RETURN";
104 if !sql_trimmed.to_uppercase().starts_with(return_upper) {
105 return Err(ExecutorError::InvalidFunctionBody(
106 "Simple function body must start with RETURN".to_string(),
107 ));
108 }
109
110 let expr_str = sql_trimmed[return_upper.len()..].trim();
111
112 let select_wrapper = format!("SELECT {}", expr_str);
115 let stmt = Parser::parse_sql(&select_wrapper).map_err(|e| {
116 ExecutorError::ParseError(format!("Failed to parse RETURN expression: {}", e))
117 })?;
118
119 #[allow(clippy::collapsible_match)]
121 if let vibesql_ast::Statement::Select(select_stmt) = stmt {
122 if let Some(first_item) = select_stmt.select_list.first() {
123 if let vibesql_ast::SelectItem::Expression { expr, alias: _, .. } = first_item {
124 let value = super::executor::evaluate_expression(expr, db, ctx)?;
126 return Ok(value);
127 }
128 }
129 }
130
131 Err(ExecutorError::InvalidFunctionBody("Could not parse RETURN expression".to_string()))
132}
133
134fn execute_begin_end_body(
150 _sql: &str,
151 _ctx: &mut ExecutionContext,
152 _db: &mut Database,
153 func_name: &str,
154) -> Result<SqlValue, ExecutorError> {
155 Err(ExecutorError::UnsupportedFeature(format!(
164 "Complex BEGIN...END function bodies not yet fully supported for function '{}'. \
165 Use simple RETURN expression functions (e.g., 'RETURN x + 10') instead. \
166 Full BEGIN...END support requires updating FunctionBody to store parsed statements.",
167 func_name
168 )))
169}
170
171impl ExecutionContext {
172 pub fn new_function_context() -> Self {
174 let mut ctx = Self::new();
175 ctx.is_function = true;
176 ctx
177 }
178
179 pub fn is_function_context(&self) -> bool {
181 self.is_function
182 }
183}
184
#[cfg(test)]
mod tests {
    use vibesql_catalog::FunctionParam;
    use vibesql_types::DataType;

    use super::*;

    /// Builds a `public`-schema function whose parameters and return type are all `INTEGER`.
    fn integer_function(name: &str, params: &[&str], body: FunctionBody) -> Function {
        let parameters = params
            .iter()
            .map(|param| FunctionParam { name: String::from(*param), data_type: DataType::Integer })
            .collect::<Vec<_>>();
        Function::new(name.to_string(), "public".to_string(), parameters, DataType::Integer, body)
    }

    #[test]
    fn test_simple_return_function() {
        let mut db = Database::new();
        let func = integer_function(
            "add_ten",
            &["x"],
            FunctionBody::RawSql("RETURN x + 10".to_string()),
        );

        let args = vec![Expression::Literal(SqlValue::Integer(5))];
        let result = execute_user_function(&func, &args, &mut db).unwrap();

        assert_eq!(result, SqlValue::Integer(15));
    }

    #[test]
    fn test_return_keyword_is_case_insensitive() {
        let mut db = Database::new();
        let func = integer_function(
            "add_ten",
            &["x"],
            FunctionBody::RawSql("return x + 10".to_string()),
        );

        let args = vec![Expression::Literal(SqlValue::Integer(5))];
        let result = execute_user_function(&func, &args, &mut db).unwrap();

        assert_eq!(result, SqlValue::Integer(15));
    }

    #[test]
    fn test_multiple_parameters_bind_by_position() {
        let mut db = Database::new();
        let func = integer_function(
            "minus",
            &["a", "b"],
            FunctionBody::RawSql("RETURN a - b".to_string()),
        );

        // Subtraction is order-sensitive, so a swapped binding would yield -7.
        let args = vec![
            Expression::Literal(SqlValue::Integer(10)),
            Expression::Literal(SqlValue::Integer(3)),
        ];
        let result = execute_user_function(&func, &args, &mut db).unwrap();

        assert_eq!(result, SqlValue::Integer(7));
    }

    #[test]
    fn test_body_without_return_is_rejected() {
        let mut db = Database::new();
        let func = integer_function("one", &[], FunctionBody::RawSql("SELECT 1".to_string()));

        let result = execute_user_function(&func, &[], &mut db);

        assert!(matches!(result, Err(ExecutorError::InvalidFunctionBody(_))));
    }

    #[test]
    fn test_argument_count_mismatch() {
        let mut db = Database::new();
        let func = integer_function(
            "add_ten",
            &["x"],
            FunctionBody::RawSql("RETURN x + 10".to_string()),
        );

        let args: Vec<Expression> = vec![];
        let result = execute_user_function(&func, &args, &mut db);

        assert!(matches!(result, Err(ExecutorError::ArgumentCountMismatch { .. })));
    }

    #[test]
    fn test_begin_end_not_supported_yet() {
        let mut db = Database::new();
        let func = integer_function(
            "factorial",
            &["n"],
            FunctionBody::BeginEnd("BEGIN DECLARE x INT DEFAULT 5; RETURN x; END".to_string()),
        );

        let args = vec![Expression::Literal(SqlValue::Integer(5))];
        let result = execute_user_function(&func, &args, &mut db);

        assert!(matches!(result, Err(ExecutorError::UnsupportedFeature(_))));
    }
}