sql_cli/execution/context.rs

1//! Execution context for managing state during query execution
2//!
3//! This module provides a unified context for both script mode and single query mode,
4//! managing temp tables, variables, and source table resolution.
5
6use anyhow::Result;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::data::datatable::DataTable;
11use crate::data::temp_table_registry::TempTableRegistry;
12
13/// Context for statement execution
14///
15/// Holds all state needed to execute SQL statements including:
16/// - Base source table(s)
17/// - Temporary tables (#table syntax)
18/// - Variables (for future use)
19/// - Execution configuration
20#[derive(Clone)]
21pub struct ExecutionContext {
22    /// The primary source table (typically from CSV/JSON file)
23    pub source_table: Arc<DataTable>,
24
25    /// Registry of temporary tables created during script execution
26    pub temp_tables: TempTableRegistry,
27
28    /// Variables for template expansion and substitution
29    pub variables: HashMap<String, String>,
30}
31
32impl ExecutionContext {
33    /// Create a new execution context with a source table
34    pub fn new(source_table: Arc<DataTable>) -> Self {
35        Self {
36            source_table,
37            temp_tables: TempTableRegistry::new(),
38            variables: HashMap::new(),
39        }
40    }
41
42    /// Create a context with DUAL table (for queries without FROM clause)
43    pub fn with_dual() -> Self {
44        Self::new(Arc::new(DataTable::dual()))
45    }
46
47    /// Resolve a table name to an Arc<DataTable>
48    ///
49    /// # Arguments
50    /// * `name` - Table name, which may be:
51    ///   - A temp table (starts with '#')
52    ///   - The base table name
53    ///   - "DUAL" for the special DUAL table
54    ///
55    /// # Returns
56    /// Arc<DataTable> for the requested table, or the source table as fallback
57    pub fn resolve_table(&self, name: &str) -> Arc<DataTable> {
58        if name.starts_with('#') {
59            // Temporary table lookup
60            self.temp_tables
61                .get(name)
62                .unwrap_or_else(|| self.source_table.clone())
63        } else if name.eq_ignore_ascii_case("DUAL") {
64            // Special DUAL table
65            Arc::new(DataTable::dual())
66        } else {
67            // Base source table
68            self.source_table.clone()
69        }
70    }
71
72    /// Try to resolve a table, returning an error if temp table not found
73    pub fn resolve_table_strict(&self, name: &str) -> Result<Arc<DataTable>> {
74        if name.starts_with('#') {
75            self.temp_tables
76                .get(name)
77                .ok_or_else(|| anyhow::anyhow!("Temporary table '{}' not found", name))
78        } else if name.eq_ignore_ascii_case("DUAL") {
79            Ok(Arc::new(DataTable::dual()))
80        } else {
81            Ok(self.source_table.clone())
82        }
83    }
84
85    /// Store a result as a temporary table
86    pub fn store_temp_table(&mut self, name: String, table: Arc<DataTable>) -> Result<()> {
87        self.temp_tables.insert(name, table)
88    }
89
90    /// Check if a temporary table exists
91    pub fn has_temp_table(&self, name: &str) -> bool {
92        self.temp_tables.contains(name)
93    }
94
95    /// Get all temporary table names
96    pub fn temp_table_names(&self) -> Vec<String> {
97        self.temp_tables.list_tables()
98    }
99
100    /// Set a variable for template expansion
101    pub fn set_variable(&mut self, name: String, value: String) {
102        self.variables.insert(name, value);
103    }
104
105    /// Get a variable value
106    pub fn get_variable(&self, name: &str) -> Option<&String> {
107        self.variables.get(name)
108    }
109
110    /// Clear all temporary tables (useful between script executions)
111    pub fn clear_temp_tables(&mut self) {
112        self.temp_tables = TempTableRegistry::new();
113    }
114
115    /// Clear all variables
116    pub fn clear_variables(&mut self) {
117        self.variables.clear();
118    }
119
120    /// Get source table metadata
121    pub fn source_table_info(&self) -> (String, usize, usize) {
122        (
123            self.source_table.name.clone(),
124            self.source_table.row_count(),
125            self.source_table.column_count(),
126        )
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow, DataType, DataValue};

    /// Build a single-column (`id: Integer`) table named `name` holding
    /// `rows` rows whose ids are `0..rows`.
    fn create_test_table(name: &str, rows: usize) -> DataTable {
        let mut table = DataTable::new(name);
        table.add_column(DataColumn::new("id").with_type(DataType::Integer));

        for i in 0..rows {
            // Fail loudly here: a silently dropped row would otherwise show up
            // later as a confusing row-count mismatch in an unrelated assert.
            let added = table.add_row(DataRow {
                values: vec![DataValue::Integer(i as i64)],
            });
            assert!(
                added.is_ok(),
                "failed to add row {} to test table '{}'",
                i,
                name
            );
        }

        table
    }

    #[test]
    fn test_new_context() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("test", 10)));

        assert_eq!(ctx.source_table.name, "test");
        assert_eq!(ctx.source_table.row_count(), 10);
        assert_eq!(ctx.temp_tables.list_tables().len(), 0);
    }

    #[test]
    fn test_dual_context() {
        let ctx = ExecutionContext::with_dual();
        assert_eq!(ctx.source_table.name, "DUAL");
        assert_eq!(ctx.source_table.row_count(), 1);
    }

    #[test]
    fn test_resolve_source_table() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("customers", 5)));

        let resolved = ctx.resolve_table("customers");
        assert_eq!(resolved.name, "customers");
        assert_eq!(resolved.row_count(), 5);
    }

    #[test]
    fn test_resolve_dual_table() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("test", 10)));

        let resolved = ctx.resolve_table("DUAL");
        assert_eq!(resolved.name, "DUAL");
        assert_eq!(resolved.row_count(), 1);
    }

    #[test]
    fn test_resolve_dual_is_case_insensitive() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("test", 10)));

        for spelling in ["dual", "Dual", "DUAL"] {
            let resolved = ctx.resolve_table(spelling);
            assert_eq!(resolved.name, "DUAL");
            assert_eq!(resolved.row_count(), 1);
        }
    }

    #[test]
    fn test_store_and_resolve_temp_table() {
        let mut ctx = ExecutionContext::new(Arc::new(create_test_table("base", 10)));

        // Store a temp table
        ctx.store_temp_table(
            "#temp1".to_string(),
            Arc::new(create_test_table("#temp1", 5)),
        )
        .unwrap();

        // Verify it exists
        assert!(ctx.has_temp_table("#temp1"));
        assert_eq!(ctx.temp_table_names(), vec!["#temp1"]);

        // Resolve it
        let resolved = ctx.resolve_table("#temp1");
        assert_eq!(resolved.name, "#temp1");
        assert_eq!(resolved.row_count(), 5);
    }

    #[test]
    fn test_resolve_missing_temp_table_fallback() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("base", 10)));

        // Should fall back to source table
        let resolved = ctx.resolve_table("#nonexistent");
        assert_eq!(resolved.name, "base");
    }

    #[test]
    fn test_resolve_missing_temp_table_strict() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("base", 10)));

        // Should error
        let result = ctx.resolve_table_strict("#nonexistent");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_resolve_strict_found_tables() {
        let mut ctx = ExecutionContext::new(Arc::new(create_test_table("base", 4)));
        ctx.store_temp_table(
            "#temp1".to_string(),
            Arc::new(create_test_table("#temp1", 2)),
        )
        .unwrap();

        // Non-temp names behave exactly like the lenient resolver.
        let base = ctx.resolve_table_strict("base").unwrap();
        assert_eq!(base.name, "base");
        assert_eq!(base.row_count(), 4);

        let dual = ctx.resolve_table_strict("dual").unwrap();
        assert_eq!(dual.name, "DUAL");
        assert_eq!(dual.row_count(), 1);

        // Registered temp tables resolve successfully.
        let temp = ctx.resolve_table_strict("#temp1").unwrap();
        assert_eq!(temp.name, "#temp1");
        assert_eq!(temp.row_count(), 2);
    }

    #[test]
    fn test_variables() {
        let mut ctx = ExecutionContext::new(Arc::new(create_test_table("test", 5)));

        // Set variables
        ctx.set_variable("user_id".to_string(), "123".to_string());
        ctx.set_variable("dept".to_string(), "sales".to_string());

        // Get variables
        assert_eq!(ctx.get_variable("user_id"), Some(&"123".to_string()));
        assert_eq!(ctx.get_variable("dept"), Some(&"sales".to_string()));
        assert_eq!(ctx.get_variable("nonexistent"), None);

        // Clear variables
        ctx.clear_variables();
        assert_eq!(ctx.get_variable("user_id"), None);
    }

    #[test]
    fn test_clear_temp_tables() {
        let mut ctx = ExecutionContext::new(Arc::new(create_test_table("base", 10)));

        // Add temp tables
        ctx.store_temp_table(
            "#temp1".to_string(),
            Arc::new(create_test_table("#temp1", 5)),
        )
        .unwrap();
        ctx.store_temp_table(
            "#temp2".to_string(),
            Arc::new(create_test_table("#temp2", 3)),
        )
        .unwrap();

        assert_eq!(ctx.temp_table_names().len(), 2);

        // Clear
        ctx.clear_temp_tables();
        assert_eq!(ctx.temp_table_names().len(), 0);
    }

    #[test]
    fn test_source_table_info() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("sales", 100)));

        let (name, rows, cols) = ctx.source_table_info();
        assert_eq!(name, "sales");
        assert_eq!(rows, 100);
        assert_eq!(cols, 1); // Just 'id' column
    }
}