vibesql_executor/procedural/
context.rs1use std::collections::HashMap;
11
12use vibesql_types::SqlValue;
13
/// Default recursion limit for an `ExecutionContext`.
///
/// `ExecutionContext::enter_recursion` fails once entering another level
/// would take the recursion depth past this value.
const MAX_RECURSION_DEPTH: usize = 100;

17#[derive(Debug, Clone, PartialEq)]
19pub enum ControlFlow {
20 Continue,
22 Return(SqlValue),
24 Leave(String),
26 Iterate(String),
28}
29
30#[derive(Debug, Clone)]
32pub struct ExecutionContext {
33 variables: HashMap<String, SqlValue>,
35 parameters: HashMap<String, SqlValue>,
37 labels: HashMap<String, bool>,
39 recursion_depth: usize,
41 max_recursion: usize,
43 pub(crate) is_function: bool,
45 out_parameters: HashMap<String, String>,
48}
49
50impl ExecutionContext {
51 pub fn new() -> Self {
53 Self {
54 variables: HashMap::new(),
55 parameters: HashMap::new(),
56 labels: HashMap::new(),
57 recursion_depth: 0,
58 max_recursion: MAX_RECURSION_DEPTH,
59 is_function: false,
60 out_parameters: HashMap::new(),
61 }
62 }
63
64 pub fn with_recursion_depth(depth: usize) -> Self {
66 Self {
67 variables: HashMap::new(),
68 parameters: HashMap::new(),
69 labels: HashMap::new(),
70 recursion_depth: depth,
71 max_recursion: MAX_RECURSION_DEPTH,
72 is_function: false,
73 out_parameters: HashMap::new(),
74 }
75 }
76
77 pub fn set_variable(&mut self, name: &str, value: SqlValue) {
79 self.variables.insert(name.to_uppercase(), value);
80 }
81
82 pub fn get_variable(&self, name: &str) -> Option<&SqlValue> {
84 self.variables.get(&name.to_uppercase())
85 }
86
87 pub fn has_variable(&self, name: &str) -> bool {
89 self.variables.contains_key(&name.to_uppercase())
90 }
91
92 pub fn set_parameter(&mut self, name: &str, value: SqlValue) {
94 self.parameters.insert(name.to_uppercase(), value);
95 }
96
97 pub fn get_parameter(&self, name: &str) -> Option<&SqlValue> {
99 self.parameters.get(&name.to_uppercase())
100 }
101
102 pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut SqlValue> {
104 self.parameters.get_mut(&name.to_uppercase())
105 }
106
107 pub fn has_parameter(&self, name: &str) -> bool {
109 self.parameters.contains_key(&name.to_uppercase())
110 }
111
112 pub fn get_value(&self, name: &str) -> Option<&SqlValue> {
114 self.get_variable(name).or_else(|| self.get_parameter(name))
115 }
116
117 pub fn push_label(&mut self, label: &str) {
119 self.labels.insert(label.to_uppercase(), true);
120 }
121
122 pub fn pop_label(&mut self, label: &str) {
124 self.labels.remove(&label.to_uppercase());
125 }
126
127 pub fn has_label(&self, label: &str) -> bool {
129 self.labels.contains_key(&label.to_uppercase())
130 }
131
132 pub fn enter_recursion(&mut self) -> Result<(), String> {
134 self.recursion_depth += 1;
135 if self.recursion_depth > self.max_recursion {
136 Err(format!("Maximum recursion depth ({}) exceeded", self.max_recursion))
137 } else {
138 Ok(())
139 }
140 }
141
142 pub fn exit_recursion(&mut self) {
144 if self.recursion_depth > 0 {
145 self.recursion_depth -= 1;
146 }
147 }
148
149 pub fn recursion_depth(&self) -> usize {
151 self.recursion_depth
152 }
153
154 pub fn get_all_parameters(&self) -> &HashMap<String, SqlValue> {
156 &self.parameters
157 }
158
159 pub fn register_out_parameter(&mut self, param_name: &str, target_var_name: String) {
161 self.out_parameters.insert(param_name.to_uppercase(), target_var_name);
162 }
163
164 pub fn get_out_parameters(&self) -> &HashMap<String, String> {
167 &self.out_parameters
168 }
169
170 pub fn get_available_names(&self) -> Vec<String> {
172 let mut names: Vec<String> =
173 self.variables.keys().chain(self.parameters.keys()).cloned().collect();
174 names.sort();
175 names
176 }
177}
178
179impl Default for ExecutionContext {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variable_storage() {
        let mut ctx = ExecutionContext::new();

        ctx.set_variable("x", SqlValue::Integer(42));
        assert_eq!(ctx.get_variable("x"), Some(&SqlValue::Integer(42)));
        // Lookups are case-insensitive.
        assert_eq!(ctx.get_variable("X"), Some(&SqlValue::Integer(42)));
        assert!(ctx.has_variable("x"));
        assert!(!ctx.has_variable("y"));
    }

    #[test]
    fn test_parameter_storage() {
        let mut ctx = ExecutionContext::new();

        ctx.set_parameter("param1", SqlValue::Integer(100));
        assert_eq!(ctx.get_parameter("param1"), Some(&SqlValue::Integer(100)));
        assert!(ctx.has_parameter("PARAM1"));
    }

    #[test]
    fn test_parameter_mutation() {
        let mut ctx = ExecutionContext::new();

        ctx.set_parameter("p", SqlValue::Integer(1));
        *ctx.get_parameter_mut("P").expect("parameter was just set") = SqlValue::Integer(5);
        assert_eq!(ctx.get_parameter("p"), Some(&SqlValue::Integer(5)));
        assert_eq!(ctx.get_all_parameters().len(), 1);
    }

    #[test]
    fn test_get_value_precedence() {
        let mut ctx = ExecutionContext::new();

        // A local variable shadows a parameter with the same name.
        ctx.set_parameter("x", SqlValue::Integer(1));
        ctx.set_variable("x", SqlValue::Integer(2));

        assert_eq!(ctx.get_value("x"), Some(&SqlValue::Integer(2)));
    }

    #[test]
    fn test_label_management() {
        let mut ctx = ExecutionContext::new();

        ctx.push_label("loop1");
        assert!(ctx.has_label("loop1"));
        assert!(ctx.has_label("LOOP1"));
        ctx.pop_label("loop1");
        assert!(!ctx.has_label("loop1"));
    }

    #[test]
    fn test_out_parameter_registration() {
        let mut ctx = ExecutionContext::new();

        ctx.register_out_parameter("result", "myVar".to_string());
        // Parameter name is normalized; the target name is stored verbatim.
        assert_eq!(ctx.get_out_parameters().get("RESULT"), Some(&"myVar".to_string()));
    }

    #[test]
    fn test_available_names_sorted() {
        let mut ctx = ExecutionContext::new();

        ctx.set_variable("b", SqlValue::Integer(1));
        ctx.set_parameter("a", SqlValue::Integer(2));
        assert_eq!(ctx.get_available_names(), vec!["A".to_string(), "B".to_string()]);
    }

    #[test]
    fn test_recursion_limit() {
        let mut ctx = ExecutionContext::new();

        for _ in 0..MAX_RECURSION_DEPTH {
            assert!(ctx.enter_recursion().is_ok());
        }

        assert!(ctx.enter_recursion().is_err());
    }

    #[test]
    fn test_with_recursion_depth() {
        let mut ctx = ExecutionContext::with_recursion_depth(MAX_RECURSION_DEPTH - 1);

        assert_eq!(ctx.recursion_depth(), MAX_RECURSION_DEPTH - 1);
        assert!(ctx.enter_recursion().is_ok());
        assert!(ctx.enter_recursion().is_err());
    }

    #[test]
    fn test_exit_recursion_never_underflows() {
        let mut ctx = ExecutionContext::default();

        assert_eq!(ctx.recursion_depth(), 0);
        ctx.exit_recursion();
        assert_eq!(ctx.recursion_depth(), 0);
    }
}