vibesql_executor/procedural/
function.rs1use crate::errors::ExecutorError;
11use crate::procedural::ExecutionContext;
12use vibesql_ast::Expression;
13use vibesql_catalog::{Function, FunctionBody};
14use vibesql_parser::Parser;
15use vibesql_storage::Database;
16use vibesql_types::SqlValue;
17
18pub fn execute_user_function(
35 func: &Function,
36 args: &[Expression],
37 db: &mut Database,
38) -> Result<SqlValue, ExecutorError> {
39 if args.len() != func.parameters.len() {
41 return Err(ExecutorError::ArgumentCountMismatch {
42 expected: func.parameters.len(),
43 actual: args.len(),
44 });
45 }
46
47 let mut ctx = ExecutionContext::new_function_context();
49
50 ctx.enter_recursion()
52 .map_err(|e| ExecutorError::RecursionLimitExceeded {
53 message: e,
54 call_stack: vec![], max_depth: 100,
56 })?;
57
58 for (param, arg_expr) in func.parameters.iter().zip(args.iter()) {
61 let value = super::executor::evaluate_expression(arg_expr, db, &ctx)?;
63 ctx.set_parameter(¶m.name, value);
64 }
65
66 let result = match &func.body {
68 FunctionBody::RawSql(sql) => {
69 execute_simple_return(sql, &mut ctx, db)?
71 }
72 FunctionBody::BeginEnd(sql) => {
73 execute_begin_end_body(sql, &mut ctx, db, &func.name)?
75 }
76 };
77
78 ctx.exit_recursion();
80
81 Ok(result)
82}
83
84fn execute_simple_return(
92 sql: &str,
93 ctx: &mut ExecutionContext,
94 db: &mut Database,
95) -> Result<SqlValue, ExecutorError> {
96 let sql_trimmed = sql.trim();
98
99 let return_upper = "RETURN";
104 if !sql_trimmed.to_uppercase().starts_with(return_upper) {
105 return Err(ExecutorError::InvalidFunctionBody(
106 "Simple function body must start with RETURN".to_string()
107 ));
108 }
109
110 let expr_str = sql_trimmed[return_upper.len()..].trim();
111
112 let select_wrapper = format!("SELECT {}", expr_str);
115 let stmt = Parser::parse_sql(&select_wrapper)
116 .map_err(|e| ExecutorError::ParseError(format!("Failed to parse RETURN expression: {}", e)))?;
117
118 #[allow(clippy::collapsible_match)]
120 if let vibesql_ast::Statement::Select(select_stmt) = stmt {
121 if let Some(first_item) = select_stmt.select_list.first() {
122 if let vibesql_ast::SelectItem::Expression { expr, alias: _ } = first_item {
123 let value = super::executor::evaluate_expression(expr, db, ctx)?;
125 return Ok(value);
126 }
127 }
128 }
129
130 Err(ExecutorError::InvalidFunctionBody(
131 "Could not parse RETURN expression".to_string()
132 ))
133}
134
135fn execute_begin_end_body(
151 _sql: &str,
152 _ctx: &mut ExecutionContext,
153 _db: &mut Database,
154 func_name: &str,
155) -> Result<SqlValue, ExecutorError> {
156 Err(ExecutorError::UnsupportedFeature(
165 format!(
166 "Complex BEGIN...END function bodies not yet fully supported for function '{}'. \
167 Use simple RETURN expression functions (e.g., 'RETURN x + 10') instead. \
168 Full BEGIN...END support requires updating FunctionBody to store parsed statements.",
169 func_name
170 )
171 ))
172}
173
174impl ExecutionContext {
175 pub fn new_function_context() -> Self {
177 let mut ctx = Self::new();
178 ctx.is_function = true;
179 ctx
180 }
181
182 pub fn is_function_context(&self) -> bool {
184 self.is_function
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::FunctionParam;
    use vibesql_types::DataType;

    /// Builds a `public` function with one INTEGER parameter and an INTEGER result.
    fn int_function(name: &str, param: &str, body: FunctionBody) -> Function {
        Function::new(
            name.to_string(),
            "public".to_string(),
            vec![FunctionParam {
                name: param.to_string(),
                data_type: DataType::Integer,
            }],
            DataType::Integer,
            body,
        )
    }

    #[test]
    fn test_simple_return_function() {
        let mut db = Database::new();
        let func = int_function(
            "add_ten",
            "x",
            FunctionBody::RawSql("RETURN x + 10".to_string()),
        );

        let value = execute_user_function(
            &func,
            &[Expression::Literal(SqlValue::Integer(5))],
            &mut db,
        )
        .expect("RETURN x + 10 with x = 5 should evaluate");

        assert_eq!(value, SqlValue::Integer(15));
    }

    #[test]
    fn test_argument_count_mismatch() {
        let mut db = Database::new();
        let func = int_function(
            "add_ten",
            "x",
            FunctionBody::RawSql("RETURN x + 10".to_string()),
        );

        // One declared parameter, zero supplied arguments.
        let outcome = execute_user_function(&func, &[], &mut db);

        assert!(matches!(
            outcome,
            Err(ExecutorError::ArgumentCountMismatch { .. })
        ));
    }

    #[test]
    fn test_begin_end_not_supported_yet() {
        let mut db = Database::new();
        let func = int_function(
            "factorial",
            "n",
            FunctionBody::BeginEnd("BEGIN DECLARE x INT DEFAULT 5; RETURN x; END".to_string()),
        );

        let outcome = execute_user_function(
            &func,
            &[Expression::Literal(SqlValue::Integer(5))],
            &mut db,
        );

        assert!(matches!(outcome, Err(ExecutorError::UnsupportedFeature(_))));
    }
}