vibesql_executor/procedural/
function.rs1use crate::errors::ExecutorError;
11use crate::procedural::ExecutionContext;
12use vibesql_ast::Expression;
13use vibesql_catalog::{Function, FunctionBody};
14use vibesql_parser::Parser;
15use vibesql_storage::Database;
16use vibesql_types::SqlValue;
17
18pub fn execute_user_function(
35 func: &Function,
36 args: &[Expression],
37 db: &mut Database,
38) -> Result<SqlValue, ExecutorError> {
39 if args.len() != func.parameters.len() {
41 return Err(ExecutorError::ArgumentCountMismatch {
42 expected: func.parameters.len(),
43 actual: args.len(),
44 });
45 }
46
47 let mut ctx = ExecutionContext::new_function_context();
49
50 ctx.enter_recursion().map_err(|e| ExecutorError::RecursionLimitExceeded {
52 message: e,
53 call_stack: vec![], max_depth: 100,
55 })?;
56
57 for (param, arg_expr) in func.parameters.iter().zip(args.iter()) {
60 let value = super::executor::evaluate_expression(arg_expr, db, &ctx)?;
62 ctx.set_parameter(¶m.name, value);
63 }
64
65 let result = match &func.body {
67 FunctionBody::RawSql(sql) => {
68 execute_simple_return(sql, &mut ctx, db)?
70 }
71 FunctionBody::BeginEnd(sql) => {
72 execute_begin_end_body(sql, &mut ctx, db, &func.name)?
74 }
75 };
76
77 ctx.exit_recursion();
79
80 Ok(result)
81}
82
83fn execute_simple_return(
91 sql: &str,
92 ctx: &mut ExecutionContext,
93 db: &mut Database,
94) -> Result<SqlValue, ExecutorError> {
95 let sql_trimmed = sql.trim();
97
98 let return_upper = "RETURN";
103 if !sql_trimmed.to_uppercase().starts_with(return_upper) {
104 return Err(ExecutorError::InvalidFunctionBody(
105 "Simple function body must start with RETURN".to_string(),
106 ));
107 }
108
109 let expr_str = sql_trimmed[return_upper.len()..].trim();
110
111 let select_wrapper = format!("SELECT {}", expr_str);
114 let stmt = Parser::parse_sql(&select_wrapper).map_err(|e| {
115 ExecutorError::ParseError(format!("Failed to parse RETURN expression: {}", e))
116 })?;
117
118 #[allow(clippy::collapsible_match)]
120 if let vibesql_ast::Statement::Select(select_stmt) = stmt {
121 if let Some(first_item) = select_stmt.select_list.first() {
122 if let vibesql_ast::SelectItem::Expression { expr, alias: _ } = first_item {
123 let value = super::executor::evaluate_expression(expr, db, ctx)?;
125 return Ok(value);
126 }
127 }
128 }
129
130 Err(ExecutorError::InvalidFunctionBody("Could not parse RETURN expression".to_string()))
131}
132
133fn execute_begin_end_body(
149 _sql: &str,
150 _ctx: &mut ExecutionContext,
151 _db: &mut Database,
152 func_name: &str,
153) -> Result<SqlValue, ExecutorError> {
154 Err(ExecutorError::UnsupportedFeature(format!(
163 "Complex BEGIN...END function bodies not yet fully supported for function '{}'. \
164 Use simple RETURN expression functions (e.g., 'RETURN x + 10') instead. \
165 Full BEGIN...END support requires updating FunctionBody to store parsed statements.",
166 func_name
167 )))
168}
169
170impl ExecutionContext {
171 pub fn new_function_context() -> Self {
173 let mut ctx = Self::new();
174 ctx.is_function = true;
175 ctx
176 }
177
178 pub fn is_function_context(&self) -> bool {
180 self.is_function
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::FunctionParam;
    use vibesql_types::DataType;

    /// Builds a `public` function `name` with one INTEGER parameter `param` returning INTEGER.
    fn int_function(name: &str, param: &str, body: FunctionBody) -> Function {
        Function::new(
            name.to_string(),
            "public".to_string(),
            vec![FunctionParam { name: param.to_string(), data_type: DataType::Integer }],
            DataType::Integer,
            body,
        )
    }

    /// The `add_ten(x) = x + 10` function used by several tests.
    fn add_ten(body_sql: &str) -> Function {
        int_function("add_ten", "x", FunctionBody::RawSql(body_sql.to_string()))
    }

    #[test]
    fn test_simple_return_function() {
        let mut db = Database::new();
        let func = add_ten("RETURN x + 10");

        let args = vec![Expression::Literal(SqlValue::Integer(5))];
        let result = execute_user_function(&func, &args, &mut db).unwrap();

        assert_eq!(result, SqlValue::Integer(15));
    }

    #[test]
    fn test_return_keyword_is_case_insensitive() {
        let mut db = Database::new();
        let func = add_ten("return x + 10");

        let args = vec![Expression::Literal(SqlValue::Integer(5))];
        let result = execute_user_function(&func, &args, &mut db).unwrap();

        assert_eq!(result, SqlValue::Integer(15));
    }

    #[test]
    fn test_argument_count_mismatch() {
        let mut db = Database::new();
        let func = add_ten("RETURN x + 10");

        let args: Vec<Expression> = Vec::new();
        let result = execute_user_function(&func, &args, &mut db);

        assert!(matches!(
            result,
            Err(ExecutorError::ArgumentCountMismatch { expected: 1, actual: 0 })
        ));
    }

    #[test]
    fn test_argument_count_mismatch_too_many() {
        let mut db = Database::new();
        let func = add_ten("RETURN x + 10");

        let args = vec![
            Expression::Literal(SqlValue::Integer(1)),
            Expression::Literal(SqlValue::Integer(2)),
        ];
        let result = execute_user_function(&func, &args, &mut db);

        assert!(matches!(
            result,
            Err(ExecutorError::ArgumentCountMismatch { expected: 1, actual: 2 })
        ));
    }

    #[test]
    fn test_begin_end_not_supported_yet() {
        let mut db = Database::new();
        let func = int_function(
            "factorial",
            "n",
            FunctionBody::BeginEnd("BEGIN DECLARE x INT DEFAULT 5; RETURN x; END".to_string()),
        );

        let args = vec![Expression::Literal(SqlValue::Integer(5))];
        let result = execute_user_function(&func, &args, &mut db);

        assert!(matches!(result, Err(ExecutorError::UnsupportedFeature(_))));
    }
}