// vibesql_executor/procedural/context.rs

//! Execution context for procedural statements.
//!
//! Manages:
//! - Local variables (DECLARE)
//! - Parameters (IN, OUT, INOUT)
//! - Scope management (nested blocks)
//! - Label tracking (for LEAVE/ITERATE)
//! - Recursion depth limiting
9
10use std::collections::HashMap;
11
12use vibesql_types::SqlValue;
13
/// Default ceiling on nested function/procedure invocations.
///
/// Every `ExecutionContext` constructor in this module installs this value
/// as its recursion limit.
const MAX_RECURSION_DEPTH: usize = 100;
16
17/// Control flow state returned by procedural statement execution
18#[derive(Debug, Clone, PartialEq)]
19pub enum ControlFlow {
20    /// Continue to next statement
21    Continue,
22    /// Return from function/procedure with value
23    Return(SqlValue),
24    /// Leave a labeled block/loop
25    Leave(String),
26    /// Iterate (continue) a labeled loop
27    Iterate(String),
28}
29
30/// Execution context for procedural statements
31#[derive(Debug, Clone)]
32pub struct ExecutionContext {
33    /// Local variables (DECLARE)
34    variables: HashMap<String, SqlValue>,
35    /// Parameters (IN, OUT, INOUT)
36    parameters: HashMap<String, SqlValue>,
37    /// Active labels for LEAVE/ITERATE
38    labels: HashMap<String, bool>,
39    /// Current recursion depth
40    recursion_depth: usize,
41    /// Maximum allowed recursion depth
42    max_recursion: usize,
43    /// Whether this is a function context (read-only, cannot modify data)
44    pub(crate) is_function: bool,
45    /// Track which parameters are OUT/INOUT and their target variable names
46    /// Key: parameter name (uppercase), Value: target variable name (as specified in CALL)
47    out_parameters: HashMap<String, String>,
48}
49
50impl ExecutionContext {
51    /// Create a new execution context
52    pub fn new() -> Self {
53        Self {
54            variables: HashMap::new(),
55            parameters: HashMap::new(),
56            labels: HashMap::new(),
57            recursion_depth: 0,
58            max_recursion: MAX_RECURSION_DEPTH,
59            is_function: false,
60            out_parameters: HashMap::new(),
61        }
62    }
63
64    /// Create a new context with specified recursion depth
65    pub fn with_recursion_depth(depth: usize) -> Self {
66        Self {
67            variables: HashMap::new(),
68            parameters: HashMap::new(),
69            labels: HashMap::new(),
70            recursion_depth: depth,
71            max_recursion: MAX_RECURSION_DEPTH,
72            is_function: false,
73            out_parameters: HashMap::new(),
74        }
75    }
76
77    /// Set a local variable value
78    pub fn set_variable(&mut self, name: &str, value: SqlValue) {
79        self.variables.insert(name.to_uppercase(), value);
80    }
81
82    /// Get a local variable value
83    pub fn get_variable(&self, name: &str) -> Option<&SqlValue> {
84        self.variables.get(&name.to_uppercase())
85    }
86
87    /// Check if a variable exists
88    pub fn has_variable(&self, name: &str) -> bool {
89        self.variables.contains_key(&name.to_uppercase())
90    }
91
92    /// Set a parameter value
93    pub fn set_parameter(&mut self, name: &str, value: SqlValue) {
94        self.parameters.insert(name.to_uppercase(), value);
95    }
96
97    /// Get a parameter value
98    pub fn get_parameter(&self, name: &str) -> Option<&SqlValue> {
99        self.parameters.get(&name.to_uppercase())
100    }
101
102    /// Get a mutable reference to a parameter value
103    pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut SqlValue> {
104        self.parameters.get_mut(&name.to_uppercase())
105    }
106
107    /// Check if a parameter exists
108    pub fn has_parameter(&self, name: &str) -> bool {
109        self.parameters.contains_key(&name.to_uppercase())
110    }
111
112    /// Get a value (variable or parameter)
113    pub fn get_value(&self, name: &str) -> Option<&SqlValue> {
114        self.get_variable(name).or_else(|| self.get_parameter(name))
115    }
116
117    /// Push a label onto the stack
118    pub fn push_label(&mut self, label: &str) {
119        self.labels.insert(label.to_uppercase(), true);
120    }
121
122    /// Pop a label from the stack
123    pub fn pop_label(&mut self, label: &str) {
124        self.labels.remove(&label.to_uppercase());
125    }
126
127    /// Check if a label is active
128    pub fn has_label(&self, label: &str) -> bool {
129        self.labels.contains_key(&label.to_uppercase())
130    }
131
132    /// Increment recursion depth and check limit
133    pub fn enter_recursion(&mut self) -> Result<(), String> {
134        self.recursion_depth += 1;
135        if self.recursion_depth > self.max_recursion {
136            Err(format!("Maximum recursion depth ({}) exceeded", self.max_recursion))
137        } else {
138            Ok(())
139        }
140    }
141
142    /// Decrement recursion depth
143    pub fn exit_recursion(&mut self) {
144        if self.recursion_depth > 0 {
145            self.recursion_depth -= 1;
146        }
147    }
148
149    /// Get current recursion depth
150    pub fn recursion_depth(&self) -> usize {
151        self.recursion_depth
152    }
153
154    /// Get all parameters (for OUT/INOUT return)
155    pub fn get_all_parameters(&self) -> &HashMap<String, SqlValue> {
156        &self.parameters
157    }
158
159    /// Register an OUT or INOUT parameter with its target variable name
160    pub fn register_out_parameter(&mut self, param_name: &str, target_var_name: String) {
161        self.out_parameters.insert(param_name.to_uppercase(), target_var_name);
162    }
163
164    /// Get all OUT/INOUT parameters for return
165    /// Returns a HashMap of parameter name -> target variable name
166    pub fn get_out_parameters(&self) -> &HashMap<String, String> {
167        &self.out_parameters
168    }
169
170    /// Get all available variable and parameter names (for error messages)
171    pub fn get_available_names(&self) -> Vec<String> {
172        let mut names: Vec<String> =
173            self.variables.keys().chain(self.parameters.keys()).cloned().collect();
174        names.sort();
175        names
176    }
177}
178
179impl Default for ExecutionContext {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variable_storage() {
        let mut ctx = ExecutionContext::new();
        let forty_two = SqlValue::Integer(42);

        ctx.set_variable("x", SqlValue::Integer(42));

        // Reads must succeed whichever case the name is spelled in.
        for spelling in ["x", "X"] {
            assert_eq!(ctx.get_variable(spelling), Some(&forty_two));
        }
        assert!(ctx.has_variable("x"));
    }

    #[test]
    fn test_parameter_storage() {
        let mut ctx = ExecutionContext::new();
        let hundred = SqlValue::Integer(100);

        ctx.set_parameter("param1", hundred.clone());

        assert_eq!(ctx.get_parameter("param1"), Some(&hundred));
        assert!(ctx.has_parameter("PARAM1"));
    }

    #[test]
    fn test_get_value_precedence() {
        let mut ctx = ExecutionContext::new();

        // Register a parameter, then shadow it with a local variable of the
        // same name: the variable must win.
        ctx.set_parameter("x", SqlValue::Integer(1));
        ctx.set_variable("x", SqlValue::Integer(2));

        let resolved = ctx.get_value("x");
        assert_eq!(resolved, Some(&SqlValue::Integer(2)));
    }

    #[test]
    fn test_label_management() {
        let mut ctx = ExecutionContext::new();
        ctx.push_label("loop1");

        // Active under either spelling.
        for spelling in ["loop1", "LOOP1"] {
            assert!(ctx.has_label(spelling));
        }

        ctx.pop_label("loop1");
        assert!(!ctx.has_label("loop1"));
    }

    #[test]
    fn test_recursion_limit() {
        let mut ctx = ExecutionContext::new();

        // The first 100 nested entries fit within the limit...
        (0..100).for_each(|_| assert!(ctx.enter_recursion().is_ok()));

        // ...and the 101st is rejected.
        assert!(ctx.enter_recursion().is_err());
    }
}