sql_cli/
non_interactive.rs

1use anyhow::{Context, Result};
2use std::fs;
3use std::io::{self, Write};
4use std::path::Path;
5use std::time::Instant;
6use tracing::{debug, info};
7
8use crate::config::config::Config;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataTable, DataValue};
11use crate::data::datatable_loaders::{load_csv_to_datatable, load_json_to_datatable};
12use crate::services::query_execution_service::QueryExecutionService;
13use crate::sql::script_parser::{ScriptParser, ScriptResult};
14
/// Output format for query results.
///
/// The enum carries no data, so it derives `Copy` and `Eq`: it can be passed
/// by value repeatedly and compared with `==` without cloning or `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Comma-separated values with a header row.
    Csv,
    /// Pretty-printed JSON array with one object per row.
    Json,
    /// Human-readable ASCII table.
    Table,
    /// Tab-separated values with a header row.
    Tsv,
}
23
24impl OutputFormat {
25    pub fn from_str(s: &str) -> Result<Self> {
26        match s.to_lowercase().as_str() {
27            "csv" => Ok(OutputFormat::Csv),
28            "json" => Ok(OutputFormat::Json),
29            "table" => Ok(OutputFormat::Table),
30            "tsv" => Ok(OutputFormat::Tsv),
31            _ => Err(anyhow::anyhow!(
32                "Invalid output format: {}. Use csv, json, table, or tsv",
33                s
34            )),
35        }
36    }
37}
38
39/// Configuration for non-interactive query execution
40pub struct NonInteractiveConfig {
41    pub data_file: String,
42    pub query: String,
43    pub output_format: OutputFormat,
44    pub output_file: Option<String>,
45    pub case_insensitive: bool,
46    pub auto_hide_empty: bool,
47    pub limit: Option<usize>,
48    pub query_plan: bool,
49}
50
51/// Execute a query in non-interactive mode
52pub fn execute_non_interactive(config: NonInteractiveConfig) -> Result<()> {
53    let start_time = Instant::now();
54
55    // Check if query uses DUAL or has no FROM clause
56    use crate::sql::recursive_parser::{Parser, SelectStatement};
57
58    fn check_statement_for_range(stmt: &SelectStatement) -> bool {
59        // Check main query
60        if stmt.from_function.is_some() {
61            return true;
62        }
63
64        // Check if it's DUAL or no FROM
65        if stmt
66            .from_table
67            .as_ref()
68            .is_some_and(|t| t.to_uppercase() == "DUAL")
69        {
70            return true;
71        }
72
73        if stmt.from_table.is_none() && stmt.from_subquery.is_none() && stmt.from_function.is_none()
74        {
75            return true;
76        }
77
78        // Recursively check CTEs
79        for cte in &stmt.ctes {
80            if check_statement_for_range(&cte.query) {
81                return true;
82            }
83        }
84
85        // Check subqueries
86        if let Some(ref subquery) = stmt.from_subquery {
87            if check_statement_for_range(subquery) {
88                return true;
89            }
90        }
91
92        false
93    }
94
95    let mut parser = Parser::new(&config.query);
96    let statement = parser
97        .parse()
98        .map_err(|e| anyhow::anyhow!("Parse error: {}", e))?;
99
100    // 1. Load the data file or create DUAL table
101    let (data_table, is_dual) =
102        if check_statement_for_range(&statement) || config.data_file.is_empty() {
103            info!("Using DUAL table for expression evaluation");
104            (crate::data::datatable::DataTable::dual(), true)
105        } else {
106            info!("Loading data from: {}", config.data_file);
107            let table = load_data_file(&config.data_file)?;
108            info!(
109                "Loaded {} rows with {} columns",
110                table.row_count(),
111                table.column_count()
112            );
113            (table, false)
114        };
115    let table_name = data_table.name.clone();
116
117    // 2. Create a DataView from the table
118    let dataview = DataView::new(std::sync::Arc::new(data_table));
119
120    // 3. Execute the query
121    info!("Executing query: {}", config.query);
122
123    // If query_plan is requested, parse and display the AST
124    if config.query_plan {
125        use crate::sql::recursive_parser::Parser;
126        let mut parser = Parser::new(&config.query);
127        match parser.parse() {
128            Ok(statement) => {
129                println!("\n=== QUERY PLAN (AST) ===");
130                println!("{statement:#?}");
131                println!("=== END QUERY PLAN ===\n");
132            }
133            Err(e) => {
134                eprintln!("Failed to parse query for plan: {e}");
135            }
136        }
137    }
138
139    let query_start = Instant::now();
140
141    // Load configuration file to get date notation and other settings
142    let app_config = Config::load().unwrap_or_else(|e| {
143        debug!("Could not load config file: {}. Using defaults.", e);
144        Config::default()
145    });
146
147    // Initialize global config for function registry
148    crate::config::global::init_config(app_config.clone());
149
150    // Use QueryExecutionService with full BehaviorConfig
151    let mut behavior_config = app_config.behavior.clone();
152    debug!(
153        "Using date notation: {}",
154        behavior_config.default_date_notation
155    );
156    // Command line args override config file settings
157    if config.case_insensitive {
158        behavior_config.case_insensitive_default = true;
159    }
160    if config.auto_hide_empty {
161        behavior_config.hide_empty_columns = true;
162    }
163
164    let query_service = QueryExecutionService::with_behavior_config(behavior_config);
165    let result = query_service.execute(&config.query, Some(&dataview), Some(dataview.source()))?;
166
167    let query_time = query_start.elapsed();
168    info!("Query executed in {:?}", query_time);
169    info!(
170        "Result: {} rows, {} columns",
171        result.dataview.row_count(),
172        result.dataview.column_count()
173    );
174
175    // 4. Apply limit if specified
176    let final_view = if let Some(limit) = config.limit {
177        let limited_table = limit_results(&result.dataview, limit)?;
178        DataView::new(std::sync::Arc::new(limited_table))
179    } else {
180        result.dataview
181    };
182
183    // 5. Output the results
184    let output_result = if let Some(ref path) = config.output_file {
185        let mut file = fs::File::create(path)
186            .with_context(|| format!("Failed to create output file: {path}"))?;
187        output_results(&final_view, config.output_format, &mut file)?;
188        info!("Results written to: {}", path);
189        Ok(())
190    } else {
191        output_results(&final_view, config.output_format, &mut io::stdout())?;
192        Ok(())
193    };
194
195    let total_time = start_time.elapsed();
196    debug!("Total execution time: {:?}", total_time);
197
198    // Print stats to stderr so they don't interfere with output
199    if config.output_file.is_none() {
200        eprintln!(
201            "\n# Query completed: {} rows in {:?}",
202            final_view.row_count(),
203            query_time
204        );
205    }
206
207    output_result
208}
209
210/// Execute a script file with multiple SQL statements separated by GO
211pub fn execute_script(config: NonInteractiveConfig) -> Result<()> {
212    let start_time = Instant::now();
213
214    // Parse the script into individual statements
215    let parser = ScriptParser::new(&config.query);
216    let statements = parser.parse_and_validate()?;
217
218    info!("Found {} statements in script", statements.len());
219
220    // Load the data file once (or use DUAL)
221    let (data_table, is_dual) = if config.data_file.is_empty() {
222        info!("Using DUAL table for script execution");
223        (DataTable::dual(), true)
224    } else {
225        info!("Loading data from: {}", config.data_file);
226        let table = load_data_file(&config.data_file)?;
227        info!(
228            "Loaded {} rows with {} columns",
229            table.row_count(),
230            table.column_count()
231        );
232        (table, false)
233    };
234
235    // Track script results
236    let mut script_result = ScriptResult::new();
237    let mut output = Vec::new();
238
239    // Execute each statement
240    for (idx, statement) in statements.iter().enumerate() {
241        let statement_num = idx + 1;
242        let stmt_start = Instant::now();
243
244        // Print separator for table format
245        if matches!(config.output_format, OutputFormat::Table) {
246            if idx > 0 {
247                output.push(String::new()); // Empty line between queries
248            }
249            output.push(format!("-- Query {} --", statement_num));
250        }
251
252        // Create a fresh DataView for each statement
253        let dataview = DataView::new(std::sync::Arc::new(data_table.clone()));
254
255        // Execute the statement
256        let service = QueryExecutionService::new(config.case_insensitive, config.auto_hide_empty);
257        match service.execute(statement, Some(&dataview), None) {
258            Ok(result) => {
259                let exec_time = stmt_start.elapsed().as_secs_f64() * 1000.0;
260                let final_view = result.dataview;
261
262                // Format the output based on the output format
263                let mut statement_output = Vec::new();
264                match config.output_format {
265                    OutputFormat::Csv => {
266                        output_csv(&final_view, &mut statement_output, ',')?;
267                    }
268                    OutputFormat::Json => {
269                        output_json(&final_view, &mut statement_output)?;
270                    }
271                    OutputFormat::Table => {
272                        output_table(&final_view, &mut statement_output)?;
273                        writeln!(
274                            &mut statement_output,
275                            "Query completed: {} rows in {:.2}ms",
276                            final_view.row_count(),
277                            exec_time
278                        )?;
279                    }
280                    OutputFormat::Tsv => {
281                        output_csv(&final_view, &mut statement_output, '\t')?;
282                    }
283                }
284
285                // Add to overall output
286                output.extend(
287                    String::from_utf8_lossy(&statement_output)
288                        .lines()
289                        .map(String::from),
290                );
291
292                script_result.add_success(
293                    statement_num,
294                    statement.clone(),
295                    final_view.row_count(),
296                    exec_time,
297                );
298            }
299            Err(e) => {
300                let exec_time = stmt_start.elapsed().as_secs_f64() * 1000.0;
301                let error_msg = format!("Query {} failed: {}", statement_num, e);
302
303                if matches!(config.output_format, OutputFormat::Table) {
304                    output.push(error_msg.clone());
305                }
306
307                script_result.add_failure(
308                    statement_num,
309                    statement.clone(),
310                    e.to_string(),
311                    exec_time,
312                );
313
314                // Continue to next statement (don't stop on error)
315            }
316        }
317    }
318
319    // Write output
320    if let Some(ref output_file) = config.output_file {
321        let mut file = fs::File::create(output_file)?;
322        for line in &output {
323            writeln!(file, "{}", line)?;
324        }
325        info!("Results written to: {}", output_file);
326    } else {
327        for line in &output {
328            println!("{}", line);
329        }
330    }
331
332    // Print summary if in table mode
333    if matches!(config.output_format, OutputFormat::Table) {
334        println!("\n=== Script Summary ===");
335        println!("Total statements: {}", script_result.total_statements);
336        println!("Successful: {}", script_result.successful_statements);
337        println!("Failed: {}", script_result.failed_statements);
338        println!(
339            "Total execution time: {:.2}ms",
340            script_result.total_execution_time_ms
341        );
342    }
343
344    if !script_result.all_successful() {
345        return Err(anyhow::anyhow!(
346            "{} of {} statements failed",
347            script_result.failed_statements,
348            script_result.total_statements
349        ));
350    }
351
352    Ok(())
353}
354
355/// Load a data file (CSV or JSON) into a `DataTable`
356fn load_data_file(path: &str) -> Result<DataTable> {
357    let path = Path::new(path);
358
359    if !path.exists() {
360        return Err(anyhow::anyhow!("File not found: {}", path.display()));
361    }
362
363    // Determine file type by extension
364    let extension = path
365        .extension()
366        .and_then(|ext| ext.to_str())
367        .map(str::to_lowercase)
368        .unwrap_or_default();
369
370    let table_name = path
371        .file_stem()
372        .and_then(|stem| stem.to_str())
373        .unwrap_or("data")
374        .to_string();
375
376    match extension.as_str() {
377        "csv" => load_csv_to_datatable(path, &table_name)
378            .with_context(|| format!("Failed to load CSV file: {}", path.display())),
379        "json" => load_json_to_datatable(path, &table_name)
380            .with_context(|| format!("Failed to load JSON file: {}", path.display())),
381        _ => Err(anyhow::anyhow!(
382            "Unsupported file type: {}. Use .csv or .json",
383            extension
384        )),
385    }
386}
387
388/// Limit the number of rows in results
389fn limit_results(dataview: &DataView, limit: usize) -> Result<DataTable> {
390    let source = dataview.source();
391    let mut limited_table = DataTable::new(&source.name);
392
393    // Copy columns
394    for col in &source.columns {
395        limited_table.add_column(col.clone());
396    }
397
398    // Copy limited rows
399    let rows_to_copy = dataview.row_count().min(limit);
400    for i in 0..rows_to_copy {
401        if let Some(row) = dataview.get_row(i) {
402            limited_table.add_row(row.clone());
403        }
404    }
405
406    Ok(limited_table)
407}
408
409/// Output query results in the specified format
410fn output_results<W: Write>(
411    dataview: &DataView,
412    format: OutputFormat,
413    writer: &mut W,
414) -> Result<()> {
415    match format {
416        OutputFormat::Csv => output_csv(dataview, writer, ','),
417        OutputFormat::Tsv => output_csv(dataview, writer, '\t'),
418        OutputFormat::Json => output_json(dataview, writer),
419        OutputFormat::Table => output_table(dataview, writer),
420    }
421}
422
423/// Output results as CSV/TSV
424fn output_csv<W: Write>(dataview: &DataView, writer: &mut W, delimiter: char) -> Result<()> {
425    // Write headers
426    let columns = dataview.column_names();
427    for (i, col) in columns.iter().enumerate() {
428        if i > 0 {
429            write!(writer, "{delimiter}")?;
430        }
431        write!(writer, "{}", escape_csv_field(col, delimiter))?;
432    }
433    writeln!(writer)?;
434
435    // Write rows
436    for row_idx in 0..dataview.row_count() {
437        if let Some(row) = dataview.get_row(row_idx) {
438            for (i, value) in row.values.iter().enumerate() {
439                if i > 0 {
440                    write!(writer, "{delimiter}")?;
441                }
442                write!(
443                    writer,
444                    "{}",
445                    escape_csv_field(&format_value(value), delimiter)
446                )?;
447            }
448            writeln!(writer)?;
449        }
450    }
451
452    Ok(())
453}
454
455/// Output results as JSON
456fn output_json<W: Write>(dataview: &DataView, writer: &mut W) -> Result<()> {
457    let columns = dataview.column_names();
458    let mut rows = Vec::new();
459
460    for row_idx in 0..dataview.row_count() {
461        if let Some(row) = dataview.get_row(row_idx) {
462            let mut json_row = serde_json::Map::new();
463            for (col_idx, value) in row.values.iter().enumerate() {
464                if col_idx < columns.len() {
465                    json_row.insert(columns[col_idx].clone(), value_to_json(value));
466                }
467            }
468            rows.push(serde_json::Value::Object(json_row));
469        }
470    }
471
472    let json = serde_json::to_string_pretty(&rows)?;
473    writeln!(writer, "{json}")?;
474
475    Ok(())
476}
477
478/// Output results as an ASCII table
479fn output_table<W: Write>(dataview: &DataView, writer: &mut W) -> Result<()> {
480    let columns = dataview.column_names();
481
482    // Calculate column widths
483    let mut widths = vec![0; columns.len()];
484    for (i, col) in columns.iter().enumerate() {
485        widths[i] = col.len();
486    }
487
488    // Check first 100 rows for width calculation
489    let sample_size = dataview.row_count().min(100);
490    for row_idx in 0..sample_size {
491        if let Some(row) = dataview.get_row(row_idx) {
492            for (i, value) in row.values.iter().enumerate() {
493                if i < widths.len() {
494                    let value_str = format_value(value);
495                    widths[i] = widths[i].max(value_str.len());
496                }
497            }
498        }
499    }
500
501    // Limit column widths to 50 characters
502    for width in &mut widths {
503        *width = (*width).min(50);
504    }
505
506    // Print header separator
507    write!(writer, "+")?;
508    for width in &widths {
509        write!(writer, "-{}-+", "-".repeat(*width))?;
510    }
511    writeln!(writer)?;
512
513    // Print headers
514    write!(writer, "|")?;
515    for (i, col) in columns.iter().enumerate() {
516        write!(writer, " {:^width$} |", col, width = widths[i])?;
517    }
518    writeln!(writer)?;
519
520    // Print header separator
521    write!(writer, "+")?;
522    for width in &widths {
523        write!(writer, "-{}-+", "-".repeat(*width))?;
524    }
525    writeln!(writer)?;
526
527    // Print rows
528    for row_idx in 0..dataview.row_count() {
529        if let Some(row) = dataview.get_row(row_idx) {
530            write!(writer, "|")?;
531            for (i, value) in row.values.iter().enumerate() {
532                if i < widths.len() {
533                    let value_str = format_value(value);
534                    let truncated = if value_str.len() > widths[i] {
535                        format!("{}...", &value_str[..widths[i] - 3])
536                    } else {
537                        value_str
538                    };
539                    write!(writer, " {:<width$} |", truncated, width = widths[i])?;
540                }
541            }
542            writeln!(writer)?;
543        }
544    }
545
546    // Print bottom separator
547    write!(writer, "+")?;
548    for width in &widths {
549        write!(writer, "-{}-+", "-".repeat(*width))?;
550    }
551    writeln!(writer)?;
552
553    Ok(())
554}
555
556/// Format a `DataValue` for display
557fn format_value(value: &DataValue) -> String {
558    match value {
559        DataValue::Null => String::new(),
560        DataValue::Integer(i) => i.to_string(),
561        DataValue::Float(f) => f.to_string(),
562        DataValue::String(s) => s.clone(),
563        DataValue::InternedString(s) => s.to_string(),
564        DataValue::Boolean(b) => b.to_string(),
565        DataValue::DateTime(dt) => dt.to_string(),
566    }
567}
568
569/// Convert `DataValue` to JSON
570fn value_to_json(value: &DataValue) -> serde_json::Value {
571    match value {
572        DataValue::Null => serde_json::Value::Null,
573        DataValue::Integer(i) => serde_json::Value::Number((*i).into()),
574        DataValue::Float(f) => {
575            if let Some(n) = serde_json::Number::from_f64(*f) {
576                serde_json::Value::Number(n)
577            } else {
578                serde_json::Value::Null
579            }
580        }
581        DataValue::String(s) => serde_json::Value::String(s.clone()),
582        DataValue::InternedString(s) => serde_json::Value::String(s.to_string()),
583        DataValue::Boolean(b) => serde_json::Value::Bool(*b),
584        DataValue::DateTime(dt) => serde_json::Value::String(dt.to_string()),
585    }
586}
587
/// Quote a CSV/TSV field when it contains special characters.
///
/// A field is wrapped in double quotes, with embedded quotes doubled, when it
/// contains the delimiter, a double quote, `\n`, or `\r`. Otherwise it is
/// returned unchanged.
fn escape_csv_field(field: &str, delimiter: char) -> String {
    let needs_quoting = field
        .chars()
        .any(|c| c == delimiter || matches!(c, '"' | '\n' | '\r'));
    if !needs_quoting {
        return field.to_string();
    }

    let mut quoted = String::with_capacity(field.len() + 2);
    quoted.push('"');
    quoted.push_str(&field.replace('"', "\"\""));
    quoted.push('"');
    quoted
}