vibesql_executor/procedural/
context.rs1use std::collections::HashMap;
11use vibesql_types::SqlValue;
12
/// Default upper bound on nested routine calls an [`ExecutionContext`] will accept.
const MAX_RECURSION_DEPTH: usize = 100;
15
16#[derive(Debug, Clone, PartialEq)]
18pub enum ControlFlow {
19 Continue,
21 Return(SqlValue),
23 Leave(String),
25 Iterate(String),
27}
28
29#[derive(Debug, Clone)]
31pub struct ExecutionContext {
32 variables: HashMap<String, SqlValue>,
34 parameters: HashMap<String, SqlValue>,
36 labels: HashMap<String, bool>,
38 recursion_depth: usize,
40 max_recursion: usize,
42 pub(crate) is_function: bool,
44 out_parameters: HashMap<String, String>,
47}
48
49impl ExecutionContext {
50 pub fn new() -> Self {
52 Self {
53 variables: HashMap::new(),
54 parameters: HashMap::new(),
55 labels: HashMap::new(),
56 recursion_depth: 0,
57 max_recursion: MAX_RECURSION_DEPTH,
58 is_function: false,
59 out_parameters: HashMap::new(),
60 }
61 }
62
63 pub fn with_recursion_depth(depth: usize) -> Self {
65 Self {
66 variables: HashMap::new(),
67 parameters: HashMap::new(),
68 labels: HashMap::new(),
69 recursion_depth: depth,
70 max_recursion: MAX_RECURSION_DEPTH,
71 is_function: false,
72 out_parameters: HashMap::new(),
73 }
74 }
75
76 pub fn set_variable(&mut self, name: &str, value: SqlValue) {
78 self.variables.insert(name.to_uppercase(), value);
79 }
80
81 pub fn get_variable(&self, name: &str) -> Option<&SqlValue> {
83 self.variables.get(&name.to_uppercase())
84 }
85
86 pub fn has_variable(&self, name: &str) -> bool {
88 self.variables.contains_key(&name.to_uppercase())
89 }
90
91 pub fn set_parameter(&mut self, name: &str, value: SqlValue) {
93 self.parameters.insert(name.to_uppercase(), value);
94 }
95
96 pub fn get_parameter(&self, name: &str) -> Option<&SqlValue> {
98 self.parameters.get(&name.to_uppercase())
99 }
100
101 pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut SqlValue> {
103 self.parameters.get_mut(&name.to_uppercase())
104 }
105
106 pub fn has_parameter(&self, name: &str) -> bool {
108 self.parameters.contains_key(&name.to_uppercase())
109 }
110
111 pub fn get_value(&self, name: &str) -> Option<&SqlValue> {
113 self.get_variable(name).or_else(|| self.get_parameter(name))
114 }
115
116 pub fn push_label(&mut self, label: &str) {
118 self.labels.insert(label.to_uppercase(), true);
119 }
120
121 pub fn pop_label(&mut self, label: &str) {
123 self.labels.remove(&label.to_uppercase());
124 }
125
126 pub fn has_label(&self, label: &str) -> bool {
128 self.labels.contains_key(&label.to_uppercase())
129 }
130
131 pub fn enter_recursion(&mut self) -> Result<(), String> {
133 self.recursion_depth += 1;
134 if self.recursion_depth > self.max_recursion {
135 Err(format!("Maximum recursion depth ({}) exceeded", self.max_recursion))
136 } else {
137 Ok(())
138 }
139 }
140
141 pub fn exit_recursion(&mut self) {
143 if self.recursion_depth > 0 {
144 self.recursion_depth -= 1;
145 }
146 }
147
148 pub fn recursion_depth(&self) -> usize {
150 self.recursion_depth
151 }
152
153 pub fn get_all_parameters(&self) -> &HashMap<String, SqlValue> {
155 &self.parameters
156 }
157
158 pub fn register_out_parameter(&mut self, param_name: &str, target_var_name: String) {
160 self.out_parameters.insert(param_name.to_uppercase(), target_var_name);
161 }
162
163 pub fn get_out_parameters(&self) -> &HashMap<String, String> {
166 &self.out_parameters
167 }
168
169 pub fn get_available_names(&self) -> Vec<String> {
171 let mut names: Vec<String> =
172 self.variables.keys().chain(self.parameters.keys()).cloned().collect();
173 names.sort();
174 names
175 }
176}
177
178impl Default for ExecutionContext {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variable_storage() {
        let mut ctx = ExecutionContext::new();

        ctx.set_variable("x", SqlValue::Integer(42));
        assert_eq!(ctx.get_variable("x"), Some(&SqlValue::Integer(42)));
        assert_eq!(ctx.get_variable("X"), Some(&SqlValue::Integer(42)));
        assert!(ctx.has_variable("x"));
    }

    #[test]
    fn test_parameter_storage() {
        let mut ctx = ExecutionContext::new();

        ctx.set_parameter("param1", SqlValue::Integer(100));
        assert_eq!(ctx.get_parameter("param1"), Some(&SqlValue::Integer(100)));
        assert!(ctx.has_parameter("PARAM1"));
    }

    #[test]
    fn test_get_value_precedence() {
        let mut ctx = ExecutionContext::new();

        // A variable shadows a parameter with the same name.
        ctx.set_parameter("x", SqlValue::Integer(1));
        ctx.set_variable("x", SqlValue::Integer(2));

        assert_eq!(ctx.get_value("x"), Some(&SqlValue::Integer(2)));
    }

    #[test]
    fn test_get_value_falls_back_to_parameter() {
        let mut ctx = ExecutionContext::new();

        ctx.set_parameter("p", SqlValue::Integer(7));
        assert_eq!(ctx.get_value("P"), Some(&SqlValue::Integer(7)));
        assert_eq!(ctx.get_value("missing"), None);
    }

    #[test]
    fn test_label_management() {
        let mut ctx = ExecutionContext::new();

        ctx.push_label("loop1");
        assert!(ctx.has_label("loop1"));
        assert!(ctx.has_label("LOOP1"));

        ctx.pop_label("loop1");
        assert!(!ctx.has_label("loop1"));
    }

    #[test]
    fn test_recursion_limit() {
        let mut ctx = ExecutionContext::new();

        for _ in 0..MAX_RECURSION_DEPTH {
            assert!(ctx.enter_recursion().is_ok());
        }

        assert!(ctx.enter_recursion().is_err());
    }

    #[test]
    fn test_exit_recursion_saturates_at_zero() {
        let mut ctx = ExecutionContext::new();

        ctx.exit_recursion();
        assert_eq!(ctx.recursion_depth(), 0);

        assert!(ctx.enter_recursion().is_ok());
        assert_eq!(ctx.recursion_depth(), 1);

        ctx.exit_recursion();
        ctx.exit_recursion();
        assert_eq!(ctx.recursion_depth(), 0);
    }

    #[test]
    fn test_with_recursion_depth_starts_at_given_depth() {
        let mut ctx = ExecutionContext::with_recursion_depth(MAX_RECURSION_DEPTH - 1);
        assert_eq!(ctx.recursion_depth(), MAX_RECURSION_DEPTH - 1);

        assert!(ctx.enter_recursion().is_ok());
        assert!(ctx.enter_recursion().is_err());
    }

    #[test]
    fn test_out_parameter_registration() {
        let mut ctx = ExecutionContext::new();

        ctx.register_out_parameter("result", "@target".to_string());
        assert_eq!(
            ctx.get_out_parameters().get("RESULT").map(String::as_str),
            Some("@target")
        );
    }

    #[test]
    fn test_available_names_sorted_and_uppercased() {
        let mut ctx = ExecutionContext::new();

        ctx.set_variable("beta", SqlValue::Integer(1));
        ctx.set_parameter("alpha", SqlValue::Integer(2));
        assert_eq!(ctx.get_available_names(), vec!["ALPHA", "BETA"]);
    }

    #[test]
    fn test_default_is_empty_context() {
        let ctx = ExecutionContext::default();

        assert_eq!(ctx.recursion_depth(), 0);
        assert!(!ctx.is_function);
        assert!(ctx.get_all_parameters().is_empty());
        assert!(ctx.get_out_parameters().is_empty());
        assert!(ctx.get_available_names().is_empty());
    }
}