// vibesql_executor/procedural/context.rs

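//! Execution context for procedural statements.
//!
//! Manages:
//! - Local variables (DECLARE)
//! - Parameters (IN, OUT, INOUT)
//! - Scope management (nested blocks)
//!   NOTE(review): this type keeps all variables in one flat, case-insensitive
//!   map and exposes no scope push/pop; confirm nested-block scoping is
//!   handled by the caller.
//! - Label tracking (for LEAVE/ITERATE)
//! - Recursion depth limiting
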
10use std::collections::HashMap;
11use vibesql_types::SqlValue;
12
/// Default limit on nested function/procedure call depth.
///
/// Every new execution context starts with this limit: this many nested
/// entries are accepted and the next one is rejected.
const MAX_RECURSION_DEPTH: usize = 100;

16/// Control flow state returned by procedural statement execution
17#[derive(Debug, Clone, PartialEq)]
18pub enum ControlFlow {
19    /// Continue to next statement
20    Continue,
21    /// Return from function/procedure with value
22    Return(SqlValue),
23    /// Leave a labeled block/loop
24    Leave(String),
25    /// Iterate (continue) a labeled loop
26    Iterate(String),
27}
28
29/// Execution context for procedural statements
30#[derive(Debug, Clone)]
31pub struct ExecutionContext {
32    /// Local variables (DECLARE)
33    variables: HashMap<String, SqlValue>,
34    /// Parameters (IN, OUT, INOUT)
35    parameters: HashMap<String, SqlValue>,
36    /// Active labels for LEAVE/ITERATE
37    labels: HashMap<String, bool>,
38    /// Current recursion depth
39    recursion_depth: usize,
40    /// Maximum allowed recursion depth
41    max_recursion: usize,
42    /// Whether this is a function context (read-only, cannot modify data)
43    pub(crate) is_function: bool,
44    /// Track which parameters are OUT/INOUT and their target variable names
45    /// Key: parameter name (uppercase), Value: target variable name (as specified in CALL)
46    out_parameters: HashMap<String, String>,
47}
48
49impl ExecutionContext {
50    /// Create a new execution context
51    pub fn new() -> Self {
52        Self {
53            variables: HashMap::new(),
54            parameters: HashMap::new(),
55            labels: HashMap::new(),
56            recursion_depth: 0,
57            max_recursion: MAX_RECURSION_DEPTH,
58            is_function: false,
59            out_parameters: HashMap::new(),
60        }
61    }
62
63    /// Create a new context with specified recursion depth
64    pub fn with_recursion_depth(depth: usize) -> Self {
65        Self {
66            variables: HashMap::new(),
67            parameters: HashMap::new(),
68            labels: HashMap::new(),
69            recursion_depth: depth,
70            max_recursion: MAX_RECURSION_DEPTH,
71            is_function: false,
72            out_parameters: HashMap::new(),
73        }
74    }
75
76    /// Set a local variable value
77    pub fn set_variable(&mut self, name: &str, value: SqlValue) {
78        self.variables.insert(name.to_uppercase(), value);
79    }
80
81    /// Get a local variable value
82    pub fn get_variable(&self, name: &str) -> Option<&SqlValue> {
83        self.variables.get(&name.to_uppercase())
84    }
85
86    /// Check if a variable exists
87    pub fn has_variable(&self, name: &str) -> bool {
88        self.variables.contains_key(&name.to_uppercase())
89    }
90
91    /// Set a parameter value
92    pub fn set_parameter(&mut self, name: &str, value: SqlValue) {
93        self.parameters.insert(name.to_uppercase(), value);
94    }
95
96    /// Get a parameter value
97    pub fn get_parameter(&self, name: &str) -> Option<&SqlValue> {
98        self.parameters.get(&name.to_uppercase())
99    }
100
101    /// Get a mutable reference to a parameter value
102    pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut SqlValue> {
103        self.parameters.get_mut(&name.to_uppercase())
104    }
105
106    /// Check if a parameter exists
107    pub fn has_parameter(&self, name: &str) -> bool {
108        self.parameters.contains_key(&name.to_uppercase())
109    }
110
111    /// Get a value (variable or parameter)
112    pub fn get_value(&self, name: &str) -> Option<&SqlValue> {
113        self.get_variable(name).or_else(|| self.get_parameter(name))
114    }
115
116    /// Push a label onto the stack
117    pub fn push_label(&mut self, label: &str) {
118        self.labels.insert(label.to_uppercase(), true);
119    }
120
121    /// Pop a label from the stack
122    pub fn pop_label(&mut self, label: &str) {
123        self.labels.remove(&label.to_uppercase());
124    }
125
126    /// Check if a label is active
127    pub fn has_label(&self, label: &str) -> bool {
128        self.labels.contains_key(&label.to_uppercase())
129    }
130
131    /// Increment recursion depth and check limit
132    pub fn enter_recursion(&mut self) -> Result<(), String> {
133        self.recursion_depth += 1;
134        if self.recursion_depth > self.max_recursion {
135            Err(format!(
136                "Maximum recursion depth ({}) exceeded",
137                self.max_recursion
138            ))
139        } else {
140            Ok(())
141        }
142    }
143
144    /// Decrement recursion depth
145    pub fn exit_recursion(&mut self) {
146        if self.recursion_depth > 0 {
147            self.recursion_depth -= 1;
148        }
149    }
150
151    /// Get current recursion depth
152    pub fn recursion_depth(&self) -> usize {
153        self.recursion_depth
154    }
155
156    /// Get all parameters (for OUT/INOUT return)
157    pub fn get_all_parameters(&self) -> &HashMap<String, SqlValue> {
158        &self.parameters
159    }
160
161    /// Register an OUT or INOUT parameter with its target variable name
162    pub fn register_out_parameter(&mut self, param_name: &str, target_var_name: String) {
163        self.out_parameters.insert(param_name.to_uppercase(), target_var_name);
164    }
165
166    /// Get all OUT/INOUT parameters for return
167    /// Returns a HashMap of parameter name -> target variable name
168    pub fn get_out_parameters(&self) -> &HashMap<String, String> {
169        &self.out_parameters
170    }
171
172    /// Get all available variable and parameter names (for error messages)
173    pub fn get_available_names(&self) -> Vec<String> {
174        let mut names: Vec<String> = self.variables.keys()
175            .chain(self.parameters.keys())
176            .cloned()
177            .collect();
178        names.sort();
179        names
180    }
181}
182
183impl Default for ExecutionContext {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variable_storage() {
        let mut context = ExecutionContext::new();
        let answer = SqlValue::Integer(42);

        context.set_variable("x", answer.clone());

        // Both spellings must resolve to the same entry (case-insensitive).
        for spelling in ["x", "X"] {
            assert_eq!(context.get_variable(spelling), Some(&answer));
        }
        assert!(context.has_variable("x"));
    }

    #[test]
    fn test_parameter_storage() {
        let mut context = ExecutionContext::new();
        let hundred = SqlValue::Integer(100);

        context.set_parameter("param1", hundred.clone());

        assert_eq!(context.get_parameter("param1"), Some(&hundred));
        assert!(context.has_parameter("PARAM1"));
    }

    #[test]
    fn test_get_value_precedence() {
        let mut context = ExecutionContext::new();

        // A local variable shadows a parameter with the same name.
        context.set_parameter("x", SqlValue::Integer(1));
        context.set_variable("x", SqlValue::Integer(2));

        let resolved = context.get_value("x");
        assert_eq!(resolved, Some(&SqlValue::Integer(2)));
    }

    #[test]
    fn test_label_management() {
        let mut context = ExecutionContext::new();

        context.push_label("loop1");
        assert!(context.has_label("loop1"));
        assert!(context.has_label("LOOP1")); // case-insensitive

        context.pop_label("loop1");
        assert!(!context.has_label("loop1"));
    }

    #[test]
    fn test_recursion_limit() {
        let mut context = ExecutionContext::new();

        // Every entry up to the limit is accepted...
        for attempt in 1..=MAX_RECURSION_DEPTH {
            assert!(
                context.enter_recursion().is_ok(),
                "entry #{} should be allowed",
                attempt
            );
        }

        // ...and the first one beyond it is refused.
        assert!(context.enter_recursion().is_err());
    }
}