sql_cli/
non_interactive.rs

1use anyhow::{Context, Result};
2use std::fs;
3use std::io::{self, Write};
4use std::path::Path;
5use std::time::Instant;
6use tracing::{debug, info};
7
8use crate::config::config::Config;
9use crate::data::data_view::DataView;
10use crate::data::datatable::{DataTable, DataValue};
11use crate::data::datatable_loaders::{load_csv_to_datatable, load_json_to_datatable};
12use crate::services::query_execution_service::QueryExecutionService;
13use crate::sql::script_parser::{ScriptParser, ScriptResult};
14
/// Output format for query results.
///
/// The enum carries no data, so it is `Copy` and can be passed by value
/// without moving it out of a configuration struct; `PartialEq`/`Eq` allow
/// direct comparison in addition to pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Comma-separated values.
    Csv,
    /// JSON output.
    Json,
    /// Bordered ASCII table.
    Table,
    /// Tab-separated values.
    Tsv,
}
23
24impl OutputFormat {
25    pub fn from_str(s: &str) -> Result<Self> {
26        match s.to_lowercase().as_str() {
27            "csv" => Ok(OutputFormat::Csv),
28            "json" => Ok(OutputFormat::Json),
29            "table" => Ok(OutputFormat::Table),
30            "tsv" => Ok(OutputFormat::Tsv),
31            _ => Err(anyhow::anyhow!(
32                "Invalid output format: {}. Use csv, json, table, or tsv",
33                s
34            )),
35        }
36    }
37}
38
/// Configuration for non-interactive query execution
///
/// Consumed by `execute_non_interactive` (single query) and `execute_script`
/// (multi-statement script). Not every field is read by both entry points;
/// see the per-field notes.
pub struct NonInteractiveConfig {
    /// Path of the data file to load. An empty string means no file was given.
    pub data_file: String,
    /// SQL text to run: a single query for `execute_non_interactive`, or the
    /// whole script text for `execute_script`.
    pub query: String,
    /// Format used when writing the results.
    pub output_format: OutputFormat,
    /// File to write results to; `None` writes them to stdout.
    pub output_file: Option<String>,
    /// When true, case-insensitive behavior is switched on for the query
    /// service (overriding the config file in single-query mode).
    pub case_insensitive: bool,
    /// When true, hiding of empty columns is switched on for the query
    /// service (overriding the config file in single-query mode).
    pub auto_hide_empty: bool,
    /// Maximum number of result rows to output. Read by
    /// `execute_non_interactive` only.
    pub limit: Option<usize>,
    /// When true, `execute_non_interactive` prints the parsed AST before
    /// running the query.
    pub query_plan: bool,
    /// Path to the script file for relative path resolution
    pub script_file: Option<String>,
}
51
52/// Execute a query in non-interactive mode
53pub fn execute_non_interactive(config: NonInteractiveConfig) -> Result<()> {
54    let start_time = Instant::now();
55
56    // Check if query uses DUAL or has no FROM clause
57    use crate::sql::recursive_parser::{Parser, SelectStatement};
58
59    fn check_statement_for_range(stmt: &SelectStatement) -> bool {
60        // Check main query
61        if stmt.from_function.is_some() {
62            return true;
63        }
64
65        // Check if it's DUAL or no FROM
66        if stmt
67            .from_table
68            .as_ref()
69            .is_some_and(|t| t.to_uppercase() == "DUAL")
70        {
71            return true;
72        }
73
74        if stmt.from_table.is_none() && stmt.from_subquery.is_none() && stmt.from_function.is_none()
75        {
76            return true;
77        }
78
79        // Recursively check CTEs
80        for cte in &stmt.ctes {
81            if check_statement_for_range(&cte.query) {
82                return true;
83            }
84        }
85
86        // Check subqueries
87        if let Some(ref subquery) = stmt.from_subquery {
88            if check_statement_for_range(subquery) {
89                return true;
90            }
91        }
92
93        false
94    }
95
96    let mut parser = Parser::new(&config.query);
97    let statement = parser
98        .parse()
99        .map_err(|e| anyhow::anyhow!("Parse error: {}", e))?;
100
101    // 1. Load the data file or create DUAL table
102    let (data_table, _is_dual) =
103        if check_statement_for_range(&statement) || config.data_file.is_empty() {
104            info!("Using DUAL table for expression evaluation");
105            (crate::data::datatable::DataTable::dual(), true)
106        } else {
107            info!("Loading data from: {}", config.data_file);
108            let table = load_data_file(&config.data_file)?;
109            info!(
110                "Loaded {} rows with {} columns",
111                table.row_count(),
112                table.column_count()
113            );
114            (table, false)
115        };
116    let _table_name = data_table.name.clone();
117
118    // 2. Create a DataView from the table
119    let dataview = DataView::new(std::sync::Arc::new(data_table));
120
121    // 3. Execute the query
122    info!("Executing query: {}", config.query);
123
124    // If query_plan is requested, parse and display the AST
125    if config.query_plan {
126        use crate::sql::recursive_parser::Parser;
127        let mut parser = Parser::new(&config.query);
128        match parser.parse() {
129            Ok(statement) => {
130                println!("\n=== QUERY PLAN (AST) ===");
131                println!("{statement:#?}");
132                println!("=== END QUERY PLAN ===\n");
133            }
134            Err(e) => {
135                eprintln!("Failed to parse query for plan: {e}");
136            }
137        }
138    }
139
140    let query_start = Instant::now();
141
142    // Load configuration file to get date notation and other settings
143    let app_config = Config::load().unwrap_or_else(|e| {
144        debug!("Could not load config file: {}. Using defaults.", e);
145        Config::default()
146    });
147
148    // Initialize global config for function registry
149    crate::config::global::init_config(app_config.clone());
150
151    // Use QueryExecutionService with full BehaviorConfig
152    let mut behavior_config = app_config.behavior.clone();
153    debug!(
154        "Using date notation: {}",
155        behavior_config.default_date_notation
156    );
157    // Command line args override config file settings
158    if config.case_insensitive {
159        behavior_config.case_insensitive_default = true;
160    }
161    if config.auto_hide_empty {
162        behavior_config.hide_empty_columns = true;
163    }
164
165    let query_service = QueryExecutionService::with_behavior_config(behavior_config);
166    let result = query_service.execute(&config.query, Some(&dataview), Some(dataview.source()))?;
167
168    let query_time = query_start.elapsed();
169    info!("Query executed in {:?}", query_time);
170    info!(
171        "Result: {} rows, {} columns",
172        result.dataview.row_count(),
173        result.dataview.column_count()
174    );
175
176    // 4. Apply limit if specified
177    let final_view = if let Some(limit) = config.limit {
178        let limited_table = limit_results(&result.dataview, limit)?;
179        DataView::new(std::sync::Arc::new(limited_table))
180    } else {
181        result.dataview
182    };
183
184    // 5. Output the results
185    let output_result = if let Some(ref path) = config.output_file {
186        let mut file = fs::File::create(path)
187            .with_context(|| format!("Failed to create output file: {path}"))?;
188        output_results(&final_view, config.output_format, &mut file)?;
189        info!("Results written to: {}", path);
190        Ok(())
191    } else {
192        output_results(&final_view, config.output_format, &mut io::stdout())?;
193        Ok(())
194    };
195
196    let total_time = start_time.elapsed();
197    debug!("Total execution time: {:?}", total_time);
198
199    // Print stats to stderr so they don't interfere with output
200    if config.output_file.is_none() {
201        eprintln!(
202            "\n# Query completed: {} rows in {:?}",
203            final_view.row_count(),
204            query_time
205        );
206    }
207
208    output_result
209}
210
211/// Execute a script file with multiple SQL statements separated by GO
212pub fn execute_script(config: NonInteractiveConfig) -> Result<()> {
213    let _start_time = Instant::now();
214
215    // Parse the script into individual statements
216    let parser = ScriptParser::new(&config.query);
217    let statements = parser.parse_and_validate()?;
218
219    info!("Found {} statements in script", statements.len());
220
221    // Determine data file to use (command-line overrides script hint)
222    let data_file = if !config.data_file.is_empty() {
223        // Command-line argument takes precedence
224        config.data_file.clone()
225    } else if let Some(hint) = parser.data_file_hint() {
226        // Use data file hint from script
227        info!("Using data file from script hint: {}", hint);
228
229        // Resolve relative paths relative to script file if provided
230        if let Some(script_path) = config.script_file.as_ref() {
231            let script_dir = std::path::Path::new(script_path)
232                .parent()
233                .unwrap_or(std::path::Path::new("."));
234            let hint_path = std::path::Path::new(hint);
235
236            if hint_path.is_relative() {
237                script_dir.join(hint_path).to_string_lossy().to_string()
238            } else {
239                hint.to_string()
240            }
241        } else {
242            hint.to_string()
243        }
244    } else {
245        String::new()
246    };
247
248    // Check if script needs a data file
249    let needs_data_file = if data_file.is_empty() {
250        // Parse statements to check if they use DUAL or RANGE
251        use crate::sql::recursive_parser::Parser;
252
253        let mut needs_file = false;
254        for statement_sql in &statements {
255            let mut parser = Parser::new(statement_sql);
256            match parser.parse() {
257                Ok(stmt) => {
258                    // Check if statement uses DUAL, RANGE, or has no FROM clause
259                    let uses_dual = stmt
260                        .from_table
261                        .as_ref()
262                        .map_or(false, |t| t.eq_ignore_ascii_case("dual"));
263                    let uses_range = stmt.from_function.is_some();
264                    let no_from = stmt.from_table.is_none()
265                        && stmt.from_subquery.is_none()
266                        && stmt.from_function.is_none();
267
268                    // Check CTEs for RANGE usage
269                    let cte_has_range = stmt.ctes.iter().any(|cte| {
270                        cte.query.from_function.is_some()
271                            || cte
272                                .query
273                                .from_table
274                                .as_ref()
275                                .map_or(false, |t| t.eq_ignore_ascii_case("dual"))
276                    });
277
278                    if !uses_dual && !uses_range && !no_from && !cte_has_range {
279                        needs_file = true;
280                        break;
281                    }
282                }
283                Err(_) => {
284                    // If we can't parse, assume it needs a data file
285                    needs_file = true;
286                    break;
287                }
288            }
289        }
290        needs_file
291    } else {
292        // Data file was specified, so we'll use it
293        false
294    };
295
296    // Load the data file once (or use DUAL)
297    let (data_table, _is_dual) = if data_file.is_empty() {
298        if needs_data_file {
299            anyhow::bail!(
300                "Script requires a data file. Either:\n\
301                1. Provide a data file: sql-cli data.csv -f script.sql\n\
302                2. Add a data hint to your script: -- #!data: path/to/data.csv"
303            );
304        } else {
305            // Script doesn't need a data file, use DUAL
306            info!("Using DUAL table for script execution");
307            (DataTable::dual(), true)
308        }
309    } else {
310        // Check if file exists before trying to load
311        if !std::path::Path::new(&data_file).exists() {
312            anyhow::bail!(
313                "Data file not found: {}\n\
314                Please check the path is correct",
315                data_file
316            );
317        }
318
319        info!("Loading data from: {}", data_file);
320        let table = load_data_file(&data_file)?;
321        info!(
322            "Loaded {} rows with {} columns",
323            table.row_count(),
324            table.column_count()
325        );
326        (table, false)
327    };
328
329    // Track script results
330    let mut script_result = ScriptResult::new();
331    let mut output = Vec::new();
332
333    // Create Arc<DataTable> once for all statements - avoids expensive cloning
334    let arc_data_table = std::sync::Arc::new(data_table);
335
336    // Execute each statement
337    for (idx, statement) in statements.iter().enumerate() {
338        let statement_num = idx + 1;
339        let stmt_start = Instant::now();
340
341        // Print separator for table format
342        if matches!(config.output_format, OutputFormat::Table) {
343            if idx > 0 {
344                output.push(String::new()); // Empty line between queries
345            }
346            output.push(format!("-- Query {} --", statement_num));
347        }
348
349        // Create a fresh DataView for each statement (reuses the Arc)
350        let dataview = DataView::new(arc_data_table.clone());
351
352        // Execute the statement
353        let service = QueryExecutionService::new(config.case_insensitive, config.auto_hide_empty);
354        match service.execute(statement, Some(&dataview), None) {
355            Ok(result) => {
356                let exec_time = stmt_start.elapsed().as_secs_f64() * 1000.0;
357                let final_view = result.dataview;
358
359                // Format the output based on the output format
360                let mut statement_output = Vec::new();
361                match config.output_format {
362                    OutputFormat::Csv => {
363                        output_csv(&final_view, &mut statement_output, ',')?;
364                    }
365                    OutputFormat::Json => {
366                        output_json(&final_view, &mut statement_output)?;
367                    }
368                    OutputFormat::Table => {
369                        output_table(&final_view, &mut statement_output)?;
370                        writeln!(
371                            &mut statement_output,
372                            "Query completed: {} rows in {:.2}ms",
373                            final_view.row_count(),
374                            exec_time
375                        )?;
376                    }
377                    OutputFormat::Tsv => {
378                        output_csv(&final_view, &mut statement_output, '\t')?;
379                    }
380                }
381
382                // Add to overall output
383                output.extend(
384                    String::from_utf8_lossy(&statement_output)
385                        .lines()
386                        .map(String::from),
387                );
388
389                script_result.add_success(
390                    statement_num,
391                    statement.clone(),
392                    final_view.row_count(),
393                    exec_time,
394                );
395            }
396            Err(e) => {
397                let exec_time = stmt_start.elapsed().as_secs_f64() * 1000.0;
398                let error_msg = format!("Query {} failed: {}", statement_num, e);
399
400                if matches!(config.output_format, OutputFormat::Table) {
401                    output.push(error_msg.clone());
402                }
403
404                script_result.add_failure(
405                    statement_num,
406                    statement.clone(),
407                    e.to_string(),
408                    exec_time,
409                );
410
411                // Continue to next statement (don't stop on error)
412            }
413        }
414    }
415
416    // Write output
417    if let Some(ref output_file) = config.output_file {
418        let mut file = fs::File::create(output_file)?;
419        for line in &output {
420            writeln!(file, "{}", line)?;
421        }
422        info!("Results written to: {}", output_file);
423    } else {
424        for line in &output {
425            println!("{}", line);
426        }
427    }
428
429    // Print summary if in table mode
430    if matches!(config.output_format, OutputFormat::Table) {
431        println!("\n=== Script Summary ===");
432        println!("Total statements: {}", script_result.total_statements);
433        println!("Successful: {}", script_result.successful_statements);
434        println!("Failed: {}", script_result.failed_statements);
435        println!(
436            "Total execution time: {:.2}ms",
437            script_result.total_execution_time_ms
438        );
439    }
440
441    if !script_result.all_successful() {
442        return Err(anyhow::anyhow!(
443            "{} of {} statements failed",
444            script_result.failed_statements,
445            script_result.total_statements
446        ));
447    }
448
449    Ok(())
450}
451
452/// Load a data file (CSV or JSON) into a `DataTable`
453fn load_data_file(path: &str) -> Result<DataTable> {
454    let path = Path::new(path);
455
456    if !path.exists() {
457        return Err(anyhow::anyhow!("File not found: {}", path.display()));
458    }
459
460    // Determine file type by extension
461    let extension = path
462        .extension()
463        .and_then(|ext| ext.to_str())
464        .map(str::to_lowercase)
465        .unwrap_or_default();
466
467    let table_name = path
468        .file_stem()
469        .and_then(|stem| stem.to_str())
470        .unwrap_or("data")
471        .to_string();
472
473    match extension.as_str() {
474        "csv" => load_csv_to_datatable(path, &table_name)
475            .with_context(|| format!("Failed to load CSV file: {}", path.display())),
476        "json" => load_json_to_datatable(path, &table_name)
477            .with_context(|| format!("Failed to load JSON file: {}", path.display())),
478        _ => Err(anyhow::anyhow!(
479            "Unsupported file type: {}. Use .csv or .json",
480            extension
481        )),
482    }
483}
484
485/// Limit the number of rows in results
486fn limit_results(dataview: &DataView, limit: usize) -> Result<DataTable> {
487    let source = dataview.source();
488    let mut limited_table = DataTable::new(&source.name);
489
490    // Copy columns
491    for col in &source.columns {
492        limited_table.add_column(col.clone());
493    }
494
495    // Copy limited rows
496    let rows_to_copy = dataview.row_count().min(limit);
497    for i in 0..rows_to_copy {
498        if let Some(row) = dataview.get_row(i) {
499            limited_table.add_row(row.clone());
500        }
501    }
502
503    Ok(limited_table)
504}
505
506/// Output query results in the specified format
507fn output_results<W: Write>(
508    dataview: &DataView,
509    format: OutputFormat,
510    writer: &mut W,
511) -> Result<()> {
512    match format {
513        OutputFormat::Csv => output_csv(dataview, writer, ','),
514        OutputFormat::Tsv => output_csv(dataview, writer, '\t'),
515        OutputFormat::Json => output_json(dataview, writer),
516        OutputFormat::Table => output_table(dataview, writer),
517    }
518}
519
520/// Output results as CSV/TSV
521fn output_csv<W: Write>(dataview: &DataView, writer: &mut W, delimiter: char) -> Result<()> {
522    // Write headers
523    let columns = dataview.column_names();
524    for (i, col) in columns.iter().enumerate() {
525        if i > 0 {
526            write!(writer, "{delimiter}")?;
527        }
528        write!(writer, "{}", escape_csv_field(col, delimiter))?;
529    }
530    writeln!(writer)?;
531
532    // Write rows
533    for row_idx in 0..dataview.row_count() {
534        if let Some(row) = dataview.get_row(row_idx) {
535            for (i, value) in row.values.iter().enumerate() {
536                if i > 0 {
537                    write!(writer, "{delimiter}")?;
538                }
539                write!(
540                    writer,
541                    "{}",
542                    escape_csv_field(&format_value(value), delimiter)
543                )?;
544            }
545            writeln!(writer)?;
546        }
547    }
548
549    Ok(())
550}
551
552/// Output results as JSON
553fn output_json<W: Write>(dataview: &DataView, writer: &mut W) -> Result<()> {
554    let columns = dataview.column_names();
555    let mut rows = Vec::new();
556
557    for row_idx in 0..dataview.row_count() {
558        if let Some(row) = dataview.get_row(row_idx) {
559            let mut json_row = serde_json::Map::new();
560            for (col_idx, value) in row.values.iter().enumerate() {
561                if col_idx < columns.len() {
562                    json_row.insert(columns[col_idx].clone(), value_to_json(value));
563                }
564            }
565            rows.push(serde_json::Value::Object(json_row));
566        }
567    }
568
569    let json = serde_json::to_string_pretty(&rows)?;
570    writeln!(writer, "{json}")?;
571
572    Ok(())
573}
574
575/// Output results as an ASCII table
576fn output_table<W: Write>(dataview: &DataView, writer: &mut W) -> Result<()> {
577    let columns = dataview.column_names();
578
579    // Calculate column widths
580    let mut widths = vec![0; columns.len()];
581    for (i, col) in columns.iter().enumerate() {
582        widths[i] = col.len();
583    }
584
585    // Check first 100 rows for width calculation
586    let sample_size = dataview.row_count().min(100);
587    for row_idx in 0..sample_size {
588        if let Some(row) = dataview.get_row(row_idx) {
589            for (i, value) in row.values.iter().enumerate() {
590                if i < widths.len() {
591                    let value_str = format_value(value);
592                    widths[i] = widths[i].max(value_str.len());
593                }
594            }
595        }
596    }
597
598    // Limit column widths to 50 characters
599    for width in &mut widths {
600        *width = (*width).min(50);
601    }
602
603    // Print header separator
604    write!(writer, "+")?;
605    for width in &widths {
606        write!(writer, "-{}-+", "-".repeat(*width))?;
607    }
608    writeln!(writer)?;
609
610    // Print headers
611    write!(writer, "|")?;
612    for (i, col) in columns.iter().enumerate() {
613        write!(writer, " {:^width$} |", col, width = widths[i])?;
614    }
615    writeln!(writer)?;
616
617    // Print header separator
618    write!(writer, "+")?;
619    for width in &widths {
620        write!(writer, "-{}-+", "-".repeat(*width))?;
621    }
622    writeln!(writer)?;
623
624    // Print rows
625    for row_idx in 0..dataview.row_count() {
626        if let Some(row) = dataview.get_row(row_idx) {
627            write!(writer, "|")?;
628            for (i, value) in row.values.iter().enumerate() {
629                if i < widths.len() {
630                    let value_str = format_value(value);
631                    let truncated = if value_str.len() > widths[i] {
632                        format!("{}...", &value_str[..widths[i] - 3])
633                    } else {
634                        value_str
635                    };
636                    write!(writer, " {:<width$} |", truncated, width = widths[i])?;
637                }
638            }
639            writeln!(writer)?;
640        }
641    }
642
643    // Print bottom separator
644    write!(writer, "+")?;
645    for width in &widths {
646        write!(writer, "-{}-+", "-".repeat(*width))?;
647    }
648    writeln!(writer)?;
649
650    Ok(())
651}
652
653/// Format a `DataValue` for display
654fn format_value(value: &DataValue) -> String {
655    match value {
656        DataValue::Null => String::new(),
657        DataValue::Integer(i) => i.to_string(),
658        DataValue::Float(f) => f.to_string(),
659        DataValue::String(s) => s.clone(),
660        DataValue::InternedString(s) => s.to_string(),
661        DataValue::Boolean(b) => b.to_string(),
662        DataValue::DateTime(dt) => dt.to_string(),
663    }
664}
665
666/// Convert `DataValue` to JSON
667fn value_to_json(value: &DataValue) -> serde_json::Value {
668    match value {
669        DataValue::Null => serde_json::Value::Null,
670        DataValue::Integer(i) => serde_json::Value::Number((*i).into()),
671        DataValue::Float(f) => {
672            if let Some(n) = serde_json::Number::from_f64(*f) {
673                serde_json::Value::Number(n)
674            } else {
675                serde_json::Value::Null
676            }
677        }
678        DataValue::String(s) => serde_json::Value::String(s.clone()),
679        DataValue::InternedString(s) => serde_json::Value::String(s.to_string()),
680        DataValue::Boolean(b) => serde_json::Value::Bool(*b),
681        DataValue::DateTime(dt) => serde_json::Value::String(dt.to_string()),
682    }
683}
684
/// Escape a CSV field if it contains special characters
///
/// A field containing the delimiter, a double quote, a line feed or a
/// carriage return is wrapped in double quotes with each embedded quote
/// doubled; any other field is returned unchanged.
fn escape_csv_field(field: &str, delimiter: char) -> String {
    let needs_quoting = field
        .chars()
        .any(|ch| ch == delimiter || matches!(ch, '"' | '\n' | '\r'));

    if !needs_quoting {
        return field.to_owned();
    }

    // Two bytes for the surrounding quotes; doubled quotes may grow it further
    let mut quoted = String::with_capacity(field.len() + 2);
    quoted.push('"');
    for ch in field.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
696}