vibesql_executor/procedural/
context.rs1use std::collections::HashMap;
11use vibesql_types::SqlValue;
12
/// Default upper bound on the recursion depth tracked by `ExecutionContext`.
///
/// Newly constructed contexts copy this value into their `max_recursion`
/// field; `enter_recursion` fails once the depth grows past it.
const MAX_RECURSION_DEPTH: usize = 100;
15
16#[derive(Debug, Clone, PartialEq)]
18pub enum ControlFlow {
19 Continue,
21 Return(SqlValue),
23 Leave(String),
25 Iterate(String),
27}
28
29#[derive(Debug, Clone)]
31pub struct ExecutionContext {
32 variables: HashMap<String, SqlValue>,
34 parameters: HashMap<String, SqlValue>,
36 labels: HashMap<String, bool>,
38 recursion_depth: usize,
40 max_recursion: usize,
42 pub(crate) is_function: bool,
44 out_parameters: HashMap<String, String>,
47}
48
49impl ExecutionContext {
50 pub fn new() -> Self {
52 Self {
53 variables: HashMap::new(),
54 parameters: HashMap::new(),
55 labels: HashMap::new(),
56 recursion_depth: 0,
57 max_recursion: MAX_RECURSION_DEPTH,
58 is_function: false,
59 out_parameters: HashMap::new(),
60 }
61 }
62
63 pub fn with_recursion_depth(depth: usize) -> Self {
65 Self {
66 variables: HashMap::new(),
67 parameters: HashMap::new(),
68 labels: HashMap::new(),
69 recursion_depth: depth,
70 max_recursion: MAX_RECURSION_DEPTH,
71 is_function: false,
72 out_parameters: HashMap::new(),
73 }
74 }
75
76 pub fn set_variable(&mut self, name: &str, value: SqlValue) {
78 self.variables.insert(name.to_uppercase(), value);
79 }
80
81 pub fn get_variable(&self, name: &str) -> Option<&SqlValue> {
83 self.variables.get(&name.to_uppercase())
84 }
85
86 pub fn has_variable(&self, name: &str) -> bool {
88 self.variables.contains_key(&name.to_uppercase())
89 }
90
91 pub fn set_parameter(&mut self, name: &str, value: SqlValue) {
93 self.parameters.insert(name.to_uppercase(), value);
94 }
95
96 pub fn get_parameter(&self, name: &str) -> Option<&SqlValue> {
98 self.parameters.get(&name.to_uppercase())
99 }
100
101 pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut SqlValue> {
103 self.parameters.get_mut(&name.to_uppercase())
104 }
105
106 pub fn has_parameter(&self, name: &str) -> bool {
108 self.parameters.contains_key(&name.to_uppercase())
109 }
110
111 pub fn get_value(&self, name: &str) -> Option<&SqlValue> {
113 self.get_variable(name).or_else(|| self.get_parameter(name))
114 }
115
116 pub fn push_label(&mut self, label: &str) {
118 self.labels.insert(label.to_uppercase(), true);
119 }
120
121 pub fn pop_label(&mut self, label: &str) {
123 self.labels.remove(&label.to_uppercase());
124 }
125
126 pub fn has_label(&self, label: &str) -> bool {
128 self.labels.contains_key(&label.to_uppercase())
129 }
130
131 pub fn enter_recursion(&mut self) -> Result<(), String> {
133 self.recursion_depth += 1;
134 if self.recursion_depth > self.max_recursion {
135 Err(format!(
136 "Maximum recursion depth ({}) exceeded",
137 self.max_recursion
138 ))
139 } else {
140 Ok(())
141 }
142 }
143
144 pub fn exit_recursion(&mut self) {
146 if self.recursion_depth > 0 {
147 self.recursion_depth -= 1;
148 }
149 }
150
151 pub fn recursion_depth(&self) -> usize {
153 self.recursion_depth
154 }
155
156 pub fn get_all_parameters(&self) -> &HashMap<String, SqlValue> {
158 &self.parameters
159 }
160
161 pub fn register_out_parameter(&mut self, param_name: &str, target_var_name: String) {
163 self.out_parameters.insert(param_name.to_uppercase(), target_var_name);
164 }
165
166 pub fn get_out_parameters(&self) -> &HashMap<String, String> {
169 &self.out_parameters
170 }
171
172 pub fn get_available_names(&self) -> Vec<String> {
174 let mut names: Vec<String> = self.variables.keys()
175 .chain(self.parameters.keys())
176 .cloned()
177 .collect();
178 names.sort();
179 names
180 }
181}
182
183impl Default for ExecutionContext {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Variables are stored and looked up case-insensitively.
    #[test]
    fn test_variable_storage() {
        let mut ctx = ExecutionContext::new();

        ctx.set_variable("x", SqlValue::Integer(42));
        assert_eq!(ctx.get_variable("x"), Some(&SqlValue::Integer(42)));
        // Lookup ignores case.
        assert_eq!(ctx.get_variable("X"), Some(&SqlValue::Integer(42)));
        assert!(ctx.has_variable("x"));
        assert!(!ctx.has_variable("y"));
    }

    /// Parameters are stored and looked up case-insensitively.
    #[test]
    fn test_parameter_storage() {
        let mut ctx = ExecutionContext::new();

        ctx.set_parameter("param1", SqlValue::Integer(100));
        assert_eq!(ctx.get_parameter("param1"), Some(&SqlValue::Integer(100)));
        assert!(ctx.has_parameter("PARAM1"));
    }

    /// `get_value` prefers a variable over a parameter of the same name.
    #[test]
    fn test_get_value_precedence() {
        let mut ctx = ExecutionContext::new();

        ctx.set_parameter("x", SqlValue::Integer(1));
        ctx.set_variable("x", SqlValue::Integer(2));

        assert_eq!(ctx.get_value("x"), Some(&SqlValue::Integer(2)));
    }

    /// Labels can be registered, queried case-insensitively, and removed.
    #[test]
    fn test_label_management() {
        let mut ctx = ExecutionContext::new();

        ctx.push_label("loop1");
        assert!(ctx.has_label("loop1"));
        // Lookup ignores case.
        assert!(ctx.has_label("LOOP1"));

        ctx.pop_label("loop1");
        assert!(!ctx.has_label("loop1"));
    }

    /// Entering recursion succeeds up to the limit and fails one past it.
    #[test]
    fn test_recursion_limit() {
        let mut ctx = ExecutionContext::new();

        // Tie the loop to the constant so the test follows the limit.
        for _ in 0..MAX_RECURSION_DEPTH {
            assert!(ctx.enter_recursion().is_ok());
        }
        assert_eq!(ctx.recursion_depth(), MAX_RECURSION_DEPTH);

        // One more call goes past the limit.
        assert!(ctx.enter_recursion().is_err());
    }
}