1use anyhow::Result;
8use std::sync::Arc;
9use std::time::Instant;
10
11use crate::data::data_view::DataView;
12use crate::data::datatable::DataTable;
13use crate::data::query_engine::QueryEngine;
14use crate::query_plan::{create_pipeline_with_config, IntoClauseRemover};
15use crate::sql::parser::ast::SelectStatement;
16
17use super::config::ExecutionConfig;
18use super::context::ExecutionContext;
19
20#[derive(Debug)]
22pub struct ExecutionResult {
23 pub dataview: DataView,
25
26 pub stats: ExecutionStats,
28
29 pub transformed_ast: Option<SelectStatement>,
31}
32
/// Timing and size statistics collected while executing one statement.
///
/// Durations are wall-clock milliseconds. `Default` yields the all-zero record
/// with `preprocessing_applied == false`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ExecutionStats {
    /// Time spent in the preprocessing pipeline, in milliseconds.
    pub preprocessing_time_ms: f64,

    /// Time spent executing the statement in the query engine, in milliseconds.
    pub execution_time_ms: f64,

    /// Total wall-clock time of the whole execution, in milliseconds.
    pub total_time_ms: f64,

    /// Number of rows in the result.
    pub row_count: usize,

    /// Number of columns in the result.
    pub column_count: usize,

    /// Whether the output of the preprocessing pipeline was used.
    pub preprocessing_applied: bool,
}
54
55impl ExecutionStats {
56 fn new() -> Self {
57 Self {
58 preprocessing_time_ms: 0.0,
59 execution_time_ms: 0.0,
60 total_time_ms: 0.0,
61 row_count: 0,
62 column_count: 0,
63 preprocessing_applied: false,
64 }
65 }
66}
67
/// Executes parsed `SELECT` statements.
///
/// For each statement it does four things:
/// - resolves the table the statement reads from;
/// - runs the SQL preprocessing pipeline;
/// - executes the result through `QueryEngine`;
/// - when the statement has an `INTO` clause, stores the result as a temp table
///   in the `ExecutionContext`.
pub struct StatementExecutor {
    /// Settings read during execution. These cover case sensitivity, whether
    /// preprocessing and SQL transformations are shown, and the transformer
    /// configuration handed to the preprocessing pipeline.
    config: ExecutionConfig,
}
79
80impl StatementExecutor {
81 pub fn new() -> Self {
83 Self {
84 config: ExecutionConfig::default(),
85 }
86 }
87
88 pub fn with_config(config: ExecutionConfig) -> Self {
90 Self { config }
91 }
92
93 pub fn execute(
110 &self,
111 stmt: SelectStatement,
112 context: &mut ExecutionContext,
113 ) -> Result<ExecutionResult> {
114 let total_start = Instant::now();
115 let mut stats = ExecutionStats::new();
116
117 let into_table_name = stmt.into_table.as_ref().map(|it| it.name.clone());
119
120 let source_table = if let Some(ref from_source) = stmt.from_source {
126 match from_source {
127 crate::sql::parser::ast::TableSource::Table(table_name) => {
128 context.resolve_table(table_name)
130 }
131 crate::sql::parser::ast::TableSource::DerivedTable { query, .. } => {
132 Self::extract_base_table(&**query, context)
134 }
135 crate::sql::parser::ast::TableSource::Pivot { source, .. } => {
136 Self::extract_base_table_from_source(source, context, &stmt)
138 }
139 }
140 } else {
141 #[allow(deprecated)]
143 if let Some(ref from_table) = stmt.from_table {
144 context.resolve_table(from_table)
145 } else {
146 Arc::new(DataTable::dual())
148 }
149 };
150
151 let preprocess_start = Instant::now();
153 let (transformed_stmt, preprocessing_applied) = self.apply_preprocessing(stmt)?;
154 stats.preprocessing_time_ms = preprocess_start.elapsed().as_secs_f64() * 1000.0;
155 stats.preprocessing_applied = preprocessing_applied;
156
157 let final_source_table = if !transformed_stmt.ctes.is_empty() {
162 Self::extract_base_table(&transformed_stmt, context)
165 } else {
166 source_table
167 };
168
169 let exec_start = Instant::now();
171 let result_view =
172 self.execute_ast(transformed_stmt.clone(), final_source_table, context)?;
173 stats.execution_time_ms = exec_start.elapsed().as_secs_f64() * 1000.0;
174
175 if let Some(table_name) = into_table_name {
177 let engine = QueryEngine::with_case_insensitive(self.config.case_insensitive);
179 let temp_table = engine.materialize_view(result_view.clone())?;
180
181 context.store_temp_table(table_name.clone(), Arc::new(temp_table))?;
183 tracing::debug!("Stored temp table: {}", table_name);
184 }
185
186 stats.total_time_ms = total_start.elapsed().as_secs_f64() * 1000.0;
188 stats.row_count = result_view.row_count();
189 stats.column_count = result_view.column_count();
190
191 Ok(ExecutionResult {
192 dataview: result_view,
193 stats,
194 transformed_ast: Some(transformed_stmt),
195 })
196 }
197
198 fn apply_preprocessing(&self, mut stmt: SelectStatement) -> Result<(SelectStatement, bool)> {
202 let has_from_clause = if stmt.from_source.is_some() {
205 true
206 } else {
207 #[allow(deprecated)]
209 {
210 stmt.from_table.is_some()
211 || stmt.from_subquery.is_some()
212 || stmt.from_function.is_some()
213 }
214 };
215
216 if !has_from_clause {
217 return Ok((stmt, false));
219 }
220
221 let mut pipeline = create_pipeline_with_config(
223 self.config.show_preprocessing,
224 self.config.show_sql_transformations,
225 self.config.transformer_config.clone(),
226 );
227
228 match pipeline.process(stmt.clone()) {
230 Ok(transformed) => {
231 let final_stmt = if transformed.into_table.is_some() {
233 IntoClauseRemover::remove_into_clause(transformed)
234 } else {
235 transformed
236 };
237
238 Ok((final_stmt, true))
239 }
240 Err(e) => {
241 tracing::debug!("Preprocessing failed: {}, using original statement", e);
243
244 let fallback = if stmt.into_table.is_some() {
246 IntoClauseRemover::remove_into_clause(stmt)
247 } else {
248 stmt
249 };
250
251 Ok((fallback, false))
252 }
253 }
254 }
255
256 fn execute_ast(
261 &self,
262 stmt: SelectStatement,
263 source_table: Arc<DataTable>,
264 context: &ExecutionContext,
265 ) -> Result<DataView> {
266 let engine = QueryEngine::with_case_insensitive(self.config.case_insensitive);
268
269 engine.execute_statement_with_temp_tables(source_table, stmt, Some(&context.temp_tables))
272 }
273
274 pub fn config(&self) -> &ExecutionConfig {
276 &self.config
277 }
278
279 pub fn set_config(&mut self, config: ExecutionConfig) {
281 self.config = config;
282 }
283
284 fn extract_base_table(stmt: &SelectStatement, context: &ExecutionContext) -> Arc<DataTable> {
287 if !stmt.ctes.is_empty() {
289 return Self::extract_base_table_from_ctes(stmt, context);
291 }
292
293 if let Some(ref from_source) = stmt.from_source {
294 Self::extract_base_table_from_source(from_source, context, stmt)
295 } else {
296 #[allow(deprecated)]
298 if let Some(ref from_table) = stmt.from_table {
299 context.resolve_table(from_table)
300 } else {
301 Arc::new(DataTable::dual())
302 }
303 }
304 }
305
306 fn extract_base_table_from_ctes(
308 stmt: &SelectStatement,
309 context: &ExecutionContext,
310 ) -> Arc<DataTable> {
311 use crate::sql::parser::ast::CTEType;
312
313 if let Some(ref from_source) = stmt.from_source {
315 match from_source {
316 crate::sql::parser::ast::TableSource::Table(table_name) => {
317 for cte in &stmt.ctes {
319 if &cte.name == table_name {
320 if let CTEType::Standard(cte_query) = &cte.cte_type {
322 return Self::extract_base_table(cte_query, context);
323 }
324 }
325 }
326 context.resolve_table(table_name)
328 }
329 crate::sql::parser::ast::TableSource::DerivedTable { query, .. } => {
330 Self::extract_base_table(&**query, context)
331 }
332 crate::sql::parser::ast::TableSource::Pivot { source, .. } => {
333 Self::extract_base_table_from_source(&**source, context, stmt)
334 }
335 }
336 } else {
337 Arc::new(DataTable::dual())
338 }
339 }
340
341 fn extract_base_table_from_source(
343 source: &crate::sql::parser::ast::TableSource,
344 context: &ExecutionContext,
345 stmt: &SelectStatement,
346 ) -> Arc<DataTable> {
347 match source {
348 crate::sql::parser::ast::TableSource::Table(table_name) => {
349 for cte in &stmt.ctes {
351 if &cte.name == table_name {
352 use crate::sql::parser::ast::CTEType;
354 if let CTEType::Standard(cte_query) = &cte.cte_type {
355 return Self::extract_base_table(cte_query, context);
356 }
357 }
358 }
359 context.resolve_table(table_name)
361 }
362 crate::sql::parser::ast::TableSource::DerivedTable { query, .. } => {
363 Self::extract_base_table(&**query, context)
365 }
366 crate::sql::parser::ast::TableSource::Pivot { source, .. } => {
367 Self::extract_base_table_from_source(&**source, context, stmt)
369 }
370 }
371 }
372}
373
374impl Default for StatementExecutor {
375 fn default() -> Self {
376 Self::new()
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow, DataType, DataValue};
    use crate::sql::recursive_parser::Parser;

    /// Builds a table called `name` with two columns, an integer `id` and a
    /// string `name`. It holds `rows` rows of the form `(i, "name_i")`.
    fn create_test_table(name: &str, rows: usize) -> DataTable {
        let mut table = DataTable::new(name);
        table.add_column(DataColumn::new("id").with_type(DataType::Integer));
        table.add_column(DataColumn::new("name").with_type(DataType::String));

        for i in 0..rows {
            let values = vec![
                DataValue::Integer(i as i64),
                DataValue::String(format!("name_{}", i)),
            ];
            // The fixture deliberately ignores the insertion result.
            let _ = table.add_row(DataRow { values });
        }

        table
    }

    /// Returns an execution context whose base table is a fresh
    /// `create_test_table(name, rows)`.
    fn context_with_table(name: &str, rows: usize) -> ExecutionContext {
        ExecutionContext::new(Arc::new(create_test_table(name, rows)))
    }

    /// Parses `sql` (panicking if it does not parse) and executes it with
    /// `executor`.
    fn run(
        executor: &StatementExecutor,
        context: &mut ExecutionContext,
        sql: &str,
    ) -> Result<ExecutionResult> {
        let stmt = Parser::new(sql).parse().unwrap();
        executor.execute(stmt, context)
    }

    #[test]
    fn test_new_executor() {
        let executor = StatementExecutor::new();
        let config = executor.config();
        assert!(!config.case_insensitive);
        assert!(!config.show_preprocessing);
    }

    #[test]
    fn test_executor_with_config() {
        let executor = StatementExecutor::with_config(
            ExecutionConfig::new()
                .with_case_insensitive(true)
                .with_show_preprocessing(true),
        );
        let config = executor.config();
        assert!(config.case_insensitive);
        assert!(config.show_preprocessing);
    }

    #[test]
    fn test_execute_simple_select() {
        let mut context = context_with_table("test", 10);
        let executor = StatementExecutor::new();

        let result = run(
            &executor,
            &mut context,
            "SELECT id, name FROM test WHERE id < 5",
        )
        .unwrap();

        assert_eq!(result.dataview.row_count(), 5);
        assert_eq!(result.dataview.column_count(), 2);
        assert!(result.stats.total_time_ms >= 0.0);
    }

    #[test]
    fn test_execute_select_star() {
        let mut context = context_with_table("test", 5);
        let executor = StatementExecutor::new();

        let result = run(&executor, &mut context, "SELECT * FROM test").unwrap();

        assert_eq!(result.dataview.row_count(), 5);
        assert_eq!(result.dataview.column_count(), 2);
    }

    #[test]
    fn test_execute_with_dual() {
        let mut context = context_with_table("test", 5);
        let executor = StatementExecutor::new();

        let result = run(&executor, &mut context, "SELECT 1+1 as result").unwrap();

        assert_eq!(result.dataview.row_count(), 1);
        assert_eq!(result.dataview.column_count(), 1);
    }

    #[test]
    fn test_execute_with_temp_table() {
        let mut context = context_with_table("base", 10);
        let executor = StatementExecutor::new();

        context
            .store_temp_table("#temp".to_string(), Arc::new(create_test_table("#temp", 3)))
            .unwrap();

        let result = run(&executor, &mut context, "SELECT * FROM #temp").unwrap();

        assert_eq!(result.dataview.row_count(), 3);
    }

    #[test]
    fn test_preprocessing_applied_with_from() {
        let mut context = context_with_table("test", 10);
        let executor = StatementExecutor::new();

        let result = run(&executor, &mut context, "SELECT id FROM test WHERE id > 0").unwrap();

        assert!(result.stats.preprocessing_time_ms >= 0.0);
    }

    #[test]
    fn test_no_preprocessing_without_from() {
        let mut context = context_with_table("test", 10);
        let executor = StatementExecutor::new();

        let result = run(&executor, &mut context, "SELECT 42 as answer").unwrap();

        assert!(!result.stats.preprocessing_applied);
    }

    #[test]
    fn test_execution_stats() {
        let mut context = context_with_table("test", 100);
        let executor = StatementExecutor::new();

        let stats = run(&executor, &mut context, "SELECT * FROM test WHERE id < 50")
            .unwrap()
            .stats;

        assert_eq!(stats.row_count, 50);
        assert_eq!(stats.column_count, 2);
        assert!(stats.total_time_ms >= 0.0);
        assert!(stats.total_time_ms >= stats.preprocessing_time_ms);
        assert!(stats.total_time_ms >= stats.execution_time_ms);
    }

    #[test]
    fn test_case_insensitive_execution() {
        let mut context = context_with_table("test", 10);
        let executor =
            StatementExecutor::with_config(ExecutionConfig::new().with_case_insensitive(true));

        let result = run(&executor, &mut context, "SELECT ID FROM test");

        assert!(result.is_ok());
    }
}