sql_cli/execution/
context.rs1use anyhow::Result;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::data::datatable::DataTable;
11use crate::data::temp_table_registry::TempTableRegistry;
12
13#[derive(Clone)]
21pub struct ExecutionContext {
22 pub source_table: Arc<DataTable>,
24
25 pub temp_tables: TempTableRegistry,
27
28 pub variables: HashMap<String, String>,
30}
31
32impl ExecutionContext {
33 pub fn new(source_table: Arc<DataTable>) -> Self {
35 Self {
36 source_table,
37 temp_tables: TempTableRegistry::new(),
38 variables: HashMap::new(),
39 }
40 }
41
42 pub fn with_dual() -> Self {
44 Self::new(Arc::new(DataTable::dual()))
45 }
46
47 pub fn resolve_table(&self, name: &str) -> Arc<DataTable> {
58 if name.starts_with('#') {
59 self.temp_tables
61 .get(name)
62 .unwrap_or_else(|| self.source_table.clone())
63 } else if name.eq_ignore_ascii_case("DUAL") {
64 Arc::new(DataTable::dual())
66 } else {
67 self.source_table.clone()
69 }
70 }
71
72 pub fn resolve_table_strict(&self, name: &str) -> Result<Arc<DataTable>> {
74 if name.starts_with('#') {
75 self.temp_tables
76 .get(name)
77 .ok_or_else(|| anyhow::anyhow!("Temporary table '{}' not found", name))
78 } else if name.eq_ignore_ascii_case("DUAL") {
79 Ok(Arc::new(DataTable::dual()))
80 } else {
81 Ok(self.source_table.clone())
82 }
83 }
84
85 pub fn store_temp_table(&mut self, name: String, table: Arc<DataTable>) -> Result<()> {
87 self.temp_tables.insert(name, table)
88 }
89
90 pub fn has_temp_table(&self, name: &str) -> bool {
92 self.temp_tables.contains(name)
93 }
94
95 pub fn temp_table_names(&self) -> Vec<String> {
97 self.temp_tables.list_tables()
98 }
99
100 pub fn set_variable(&mut self, name: String, value: String) {
102 self.variables.insert(name, value);
103 }
104
105 pub fn get_variable(&self, name: &str) -> Option<&String> {
107 self.variables.get(name)
108 }
109
110 pub fn clear_temp_tables(&mut self) {
112 self.temp_tables = TempTableRegistry::new();
113 }
114
115 pub fn clear_variables(&mut self) {
117 self.variables.clear();
118 }
119
120 pub fn source_table_info(&self) -> (String, usize, usize) {
122 (
123 self.source_table.name.clone(),
124 self.source_table.row_count(),
125 self.source_table.column_count(),
126 )
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow, DataType, DataValue};

    /// Builds a table named `name` with one integer `id` column holding `0..rows`.
    fn create_test_table(name: &str, rows: usize) -> DataTable {
        let mut table = DataTable::new(name);
        table.add_column(DataColumn::new("id").with_type(DataType::Integer));

        for i in 0..rows {
            // Fail loudly here. A silently dropped row would only show up later
            // as a confusing row-count mismatch in an unrelated assertion.
            let added = table.add_row(DataRow {
                values: vec![DataValue::Integer(i as i64)],
            });
            assert!(
                added.is_ok(),
                "failed to add row {} to test table '{}'",
                i,
                name
            );
        }

        table
    }

    #[test]
    fn test_new_context() {
        let table = create_test_table("test", 10);
        let ctx = ExecutionContext::new(Arc::new(table));

        assert_eq!(ctx.source_table.name, "test");
        assert_eq!(ctx.source_table.row_count(), 10);
        assert_eq!(ctx.temp_tables.list_tables().len(), 0);
        assert!(ctx.variables.is_empty());
    }

    #[test]
    fn test_dual_context() {
        let ctx = ExecutionContext::with_dual();
        assert_eq!(ctx.source_table.name, "DUAL");
        assert_eq!(ctx.source_table.row_count(), 1);
    }

    #[test]
    fn test_resolve_source_table() {
        let table = create_test_table("customers", 5);
        let ctx = ExecutionContext::new(Arc::new(table));

        let resolved = ctx.resolve_table("customers");
        assert_eq!(resolved.name, "customers");
        assert_eq!(resolved.row_count(), 5);

        // Any non-temp, non-DUAL name resolves to the source table.
        let other = ctx.resolve_table("some_other_name");
        assert_eq!(other.name, "customers");
    }

    #[test]
    fn test_resolve_dual_table() {
        let table = create_test_table("test", 10);
        let ctx = ExecutionContext::new(Arc::new(table));

        let resolved = ctx.resolve_table("DUAL");
        assert_eq!(resolved.name, "DUAL");
        assert_eq!(resolved.row_count(), 1);
    }

    #[test]
    fn test_resolve_dual_is_case_insensitive() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("test", 10)));

        for spelling in ["dual", "Dual", "DUAL"] {
            let resolved = ctx.resolve_table(spelling);
            assert_eq!(resolved.name, "DUAL");
            assert_eq!(resolved.row_count(), 1);
        }
    }

    #[test]
    fn test_store_and_resolve_temp_table() {
        let base_table = create_test_table("base", 10);
        let mut ctx = ExecutionContext::new(Arc::new(base_table));

        let temp_table = create_test_table("#temp1", 5);
        ctx.store_temp_table("#temp1".to_string(), Arc::new(temp_table))
            .unwrap();

        assert!(ctx.has_temp_table("#temp1"));
        assert_eq!(ctx.temp_table_names(), vec!["#temp1"]);

        let resolved = ctx.resolve_table("#temp1");
        assert_eq!(resolved.name, "#temp1");
        assert_eq!(resolved.row_count(), 5);

        let strict = ctx.resolve_table_strict("#temp1").unwrap();
        assert_eq!(strict.name, "#temp1");
        assert_eq!(strict.row_count(), 5);
    }

    #[test]
    fn test_has_temp_table_missing() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("base", 1)));

        assert!(!ctx.has_temp_table("#missing"));
        assert!(ctx.temp_table_names().is_empty());
    }

    #[test]
    fn test_resolve_missing_temp_table_fallback() {
        let base_table = create_test_table("base", 10);
        let ctx = ExecutionContext::new(Arc::new(base_table));

        let resolved = ctx.resolve_table("#nonexistent");
        assert_eq!(resolved.name, "base");
        assert_eq!(resolved.row_count(), 10);
    }

    #[test]
    fn test_resolve_missing_temp_table_strict() {
        let base_table = create_test_table("base", 10);
        let ctx = ExecutionContext::new(Arc::new(base_table));

        let result = ctx.resolve_table_strict("#nonexistent");
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("not found"));
        assert!(message.contains("#nonexistent"));
    }

    #[test]
    fn test_resolve_strict_non_temp_names() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("base", 4)));

        let source = ctx.resolve_table_strict("anything").unwrap();
        assert_eq!(source.name, "base");
        assert_eq!(source.row_count(), 4);

        let dual = ctx.resolve_table_strict("dual").unwrap();
        assert_eq!(dual.name, "DUAL");
        assert_eq!(dual.row_count(), 1);
    }

    #[test]
    fn test_variables() {
        let table = create_test_table("test", 5);
        let mut ctx = ExecutionContext::new(Arc::new(table));

        ctx.set_variable("user_id".to_string(), "123".to_string());
        ctx.set_variable("dept".to_string(), "sales".to_string());

        assert_eq!(ctx.get_variable("user_id"), Some(&"123".to_string()));
        assert_eq!(ctx.get_variable("dept"), Some(&"sales".to_string()));
        assert_eq!(ctx.get_variable("nonexistent"), None);

        // Setting an existing variable overwrites its value.
        ctx.set_variable("user_id".to_string(), "456".to_string());
        assert_eq!(ctx.get_variable("user_id").map(String::as_str), Some("456"));

        ctx.clear_variables();
        assert_eq!(ctx.get_variable("user_id"), None);
        assert_eq!(ctx.get_variable("dept"), None);
    }

    #[test]
    fn test_clear_temp_tables() {
        let base_table = create_test_table("base", 10);
        let mut ctx = ExecutionContext::new(Arc::new(base_table));

        ctx.store_temp_table(
            "#temp1".to_string(),
            Arc::new(create_test_table("#temp1", 5)),
        )
        .unwrap();
        ctx.store_temp_table(
            "#temp2".to_string(),
            Arc::new(create_test_table("#temp2", 3)),
        )
        .unwrap();

        assert_eq!(ctx.temp_table_names().len(), 2);

        ctx.clear_temp_tables();
        assert_eq!(ctx.temp_table_names().len(), 0);
        assert!(!ctx.has_temp_table("#temp1"));
        assert!(!ctx.has_temp_table("#temp2"));
    }

    #[test]
    fn test_source_table_info() {
        let table = create_test_table("sales", 100);
        let ctx = ExecutionContext::new(Arc::new(table));

        let (name, rows, cols) = ctx.source_table_info();
        assert_eq!(name, "sales");
        assert_eq!(rows, 100);
        assert_eq!(cols, 1);
    }
}