vibesql_executor/procedural/
context.rs

1//! Execution context for procedural statements
2//!
3//! Manages:
4//! - Local variables (DECLARE)
5//! - Parameters (IN, OUT, INOUT)
6//! - Scope management (nested blocks)
7//! - Label tracking (for LEAVE/ITERATE)
8//! - Recursion depth limiting
9
10use std::collections::HashMap;
11use vibesql_types::SqlValue;
12
/// Default upper bound on nested function/procedure calls.
///
/// Every new `ExecutionContext` uses this value as its recursion limit.
const MAX_RECURSION_DEPTH: usize = 100;
15
16/// Control flow state returned by procedural statement execution
17#[derive(Debug, Clone, PartialEq)]
18pub enum ControlFlow {
19    /// Continue to next statement
20    Continue,
21    /// Return from function/procedure with value
22    Return(SqlValue),
23    /// Leave a labeled block/loop
24    Leave(String),
25    /// Iterate (continue) a labeled loop
26    Iterate(String),
27}
28
29/// Execution context for procedural statements
30#[derive(Debug, Clone)]
31pub struct ExecutionContext {
32    /// Local variables (DECLARE)
33    variables: HashMap<String, SqlValue>,
34    /// Parameters (IN, OUT, INOUT)
35    parameters: HashMap<String, SqlValue>,
36    /// Active labels for LEAVE/ITERATE
37    labels: HashMap<String, bool>,
38    /// Current recursion depth
39    recursion_depth: usize,
40    /// Maximum allowed recursion depth
41    max_recursion: usize,
42    /// Whether this is a function context (read-only, cannot modify data)
43    pub(crate) is_function: bool,
44    /// Track which parameters are OUT/INOUT and their target variable names
45    /// Key: parameter name (uppercase), Value: target variable name (as specified in CALL)
46    out_parameters: HashMap<String, String>,
47}
48
49impl ExecutionContext {
50    /// Create a new execution context
51    pub fn new() -> Self {
52        Self {
53            variables: HashMap::new(),
54            parameters: HashMap::new(),
55            labels: HashMap::new(),
56            recursion_depth: 0,
57            max_recursion: MAX_RECURSION_DEPTH,
58            is_function: false,
59            out_parameters: HashMap::new(),
60        }
61    }
62
63    /// Create a new context with specified recursion depth
64    pub fn with_recursion_depth(depth: usize) -> Self {
65        Self {
66            variables: HashMap::new(),
67            parameters: HashMap::new(),
68            labels: HashMap::new(),
69            recursion_depth: depth,
70            max_recursion: MAX_RECURSION_DEPTH,
71            is_function: false,
72            out_parameters: HashMap::new(),
73        }
74    }
75
76    /// Set a local variable value
77    pub fn set_variable(&mut self, name: &str, value: SqlValue) {
78        self.variables.insert(name.to_uppercase(), value);
79    }
80
81    /// Get a local variable value
82    pub fn get_variable(&self, name: &str) -> Option<&SqlValue> {
83        self.variables.get(&name.to_uppercase())
84    }
85
86    /// Check if a variable exists
87    pub fn has_variable(&self, name: &str) -> bool {
88        self.variables.contains_key(&name.to_uppercase())
89    }
90
91    /// Set a parameter value
92    pub fn set_parameter(&mut self, name: &str, value: SqlValue) {
93        self.parameters.insert(name.to_uppercase(), value);
94    }
95
96    /// Get a parameter value
97    pub fn get_parameter(&self, name: &str) -> Option<&SqlValue> {
98        self.parameters.get(&name.to_uppercase())
99    }
100
101    /// Get a mutable reference to a parameter value
102    pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut SqlValue> {
103        self.parameters.get_mut(&name.to_uppercase())
104    }
105
106    /// Check if a parameter exists
107    pub fn has_parameter(&self, name: &str) -> bool {
108        self.parameters.contains_key(&name.to_uppercase())
109    }
110
111    /// Get a value (variable or parameter)
112    pub fn get_value(&self, name: &str) -> Option<&SqlValue> {
113        self.get_variable(name).or_else(|| self.get_parameter(name))
114    }
115
116    /// Push a label onto the stack
117    pub fn push_label(&mut self, label: &str) {
118        self.labels.insert(label.to_uppercase(), true);
119    }
120
121    /// Pop a label from the stack
122    pub fn pop_label(&mut self, label: &str) {
123        self.labels.remove(&label.to_uppercase());
124    }
125
126    /// Check if a label is active
127    pub fn has_label(&self, label: &str) -> bool {
128        self.labels.contains_key(&label.to_uppercase())
129    }
130
131    /// Increment recursion depth and check limit
132    pub fn enter_recursion(&mut self) -> Result<(), String> {
133        self.recursion_depth += 1;
134        if self.recursion_depth > self.max_recursion {
135            Err(format!("Maximum recursion depth ({}) exceeded", self.max_recursion))
136        } else {
137            Ok(())
138        }
139    }
140
141    /// Decrement recursion depth
142    pub fn exit_recursion(&mut self) {
143        if self.recursion_depth > 0 {
144            self.recursion_depth -= 1;
145        }
146    }
147
148    /// Get current recursion depth
149    pub fn recursion_depth(&self) -> usize {
150        self.recursion_depth
151    }
152
153    /// Get all parameters (for OUT/INOUT return)
154    pub fn get_all_parameters(&self) -> &HashMap<String, SqlValue> {
155        &self.parameters
156    }
157
158    /// Register an OUT or INOUT parameter with its target variable name
159    pub fn register_out_parameter(&mut self, param_name: &str, target_var_name: String) {
160        self.out_parameters.insert(param_name.to_uppercase(), target_var_name);
161    }
162
163    /// Get all OUT/INOUT parameters for return
164    /// Returns a HashMap of parameter name -> target variable name
165    pub fn get_out_parameters(&self) -> &HashMap<String, String> {
166        &self.out_parameters
167    }
168
169    /// Get all available variable and parameter names (for error messages)
170    pub fn get_available_names(&self) -> Vec<String> {
171        let mut names: Vec<String> =
172            self.variables.keys().chain(self.parameters.keys()).cloned().collect();
173        names.sort();
174        names
175    }
176}
177
178impl Default for ExecutionContext {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Variables are stored and found regardless of the case of the name.
    #[test]
    fn test_variable_storage() {
        let mut context = ExecutionContext::new();
        context.set_variable("x", SqlValue::Integer(42));

        let expected = SqlValue::Integer(42);
        assert_eq!(context.get_variable("x"), Some(&expected));
        // Lookup with a different case resolves to the same variable.
        assert_eq!(context.get_variable("X"), Some(&expected));
        assert!(context.has_variable("x"));
    }

    /// Parameters are stored and found regardless of the case of the name.
    #[test]
    fn test_parameter_storage() {
        let mut context = ExecutionContext::new();
        context.set_parameter("param1", SqlValue::Integer(100));

        let expected = SqlValue::Integer(100);
        assert_eq!(context.get_parameter("param1"), Some(&expected));
        assert!(context.has_parameter("PARAM1"));
    }

    /// When a variable and a parameter share a name, the variable wins.
    #[test]
    fn test_get_value_precedence() {
        let mut context = ExecutionContext::new();
        context.set_parameter("x", SqlValue::Integer(1));
        context.set_variable("x", SqlValue::Integer(2));

        let resolved = context.get_value("x");
        assert_eq!(resolved, Some(&SqlValue::Integer(2)));
    }

    /// Labels become active when pushed and inactive when popped.
    #[test]
    fn test_label_management() {
        let mut context = ExecutionContext::new();

        context.push_label("loop1");
        // Both spellings refer to the same label.
        for spelling in ["loop1", "LOOP1"] {
            assert!(context.has_label(spelling));
        }

        context.pop_label("loop1");
        assert!(!context.has_label("loop1"));
    }

    /// The first 100 nested entries succeed; the 101st is rejected.
    #[test]
    fn test_recursion_limit() {
        let mut context = ExecutionContext::new();

        let all_within_limit = (0..100).all(|_| context.enter_recursion().is_ok());
        assert!(all_within_limit);

        assert!(context.enter_recursion().is_err());
    }
}