Skip to main content

sqry_cli/commands/
batch.rs

1//! Execute multiple semantic queries from a batch file using the unified graph.
2//!
3//! The `sqry batch` command loads queries from a file and executes them using
4//! the unified graph format (`.sqry/graph/snapshot.sqry`). Supports parallel
5//! execution and multiple output formats.
6
7use crate::args::{BatchFormat, Cli};
8use crate::commands::query::create_executor_with_plugins_for_cli;
9use crate::output::{DisplaySymbol, JsonSymbol};
10use anyhow::{Context, Result, bail};
11use rayon::prelude::*;
12use serde::Serialize;
13use sqry_core::query::results::QueryResults;
14use std::fs::File;
15use std::io::{BufRead, BufReader, BufWriter, Write};
16use std::path::{Path, PathBuf};
17use std::time::{Duration, Instant};
18
/// Per-query result limit used when the CLI does not supply `limit`.
///
/// Results beyond the limit are dropped from the captured/displayed list but
/// still counted in the reported total; a limit of `0` disables truncation.
const DEFAULT_BATCH_LIMIT: usize = 1_000;
20
21/// Run the batch command using the provided CLI configuration.
22///
23/// # Errors
24/// Returns an error if queries cannot be loaded, the index is missing,
25/// or any query execution fails.
26pub fn run_batch(
27    cli: &Cli,
28    path: &str,
29    queries_path: &Path,
30    output: BatchFormat,
31    output_file: Option<&Path>,
32    continue_on_error: bool,
33    stats: bool,
34    sequential: bool,
35) -> Result<()> {
36    let workspace = PathBuf::from(path);
37    ensure_index_exists(&workspace)?;
38
39    let queries = load_queries(queries_path)
40        .with_context(|| format!("failed to read queries from {}", queries_path.display()))?;
41    if queries.is_empty() {
42        bail!(
43            "no queries found in {} (ensure file contains non-empty, non-comment lines)",
44            queries_path.display()
45        );
46    }
47
48    let load_start = Instant::now();
49    let executor = create_executor_with_plugins_for_cli(cli, &workspace)?;
50    let preload_elapsed = load_start.elapsed();
51
52    let should_capture_results =
53        !cli.count && matches!(output, BatchFormat::Json | BatchFormat::Jsonl);
54    let limit = cli.limit.unwrap_or(DEFAULT_BATCH_LIMIT);
55
56    let outcomes = execute_queries(
57        &executor,
58        &workspace,
59        &queries,
60        limit,
61        should_capture_results,
62        continue_on_error,
63        sequential,
64    )?;
65
66    let summary = BatchSummary::from_outcomes(&outcomes);
67
68    let render_options = RenderOutputOptions {
69        output,
70        workspace: &workspace,
71        preload_elapsed,
72        summary: &summary,
73        should_capture_results,
74        limit,
75        stats,
76    };
77    let rendered = render_output(&outcomes, &render_options)?;
78
79    write_output(output_file, &rendered)?;
80
81    Ok(())
82}
83
84fn ensure_index_exists(path: &Path) -> Result<()> {
85    use sqry_core::graph::unified::persistence::GraphStorage;
86
87    let storage = GraphStorage::new(path);
88    if storage.exists() {
89        return Ok(());
90    }
91
92    bail!(
93        "no index found at {}. Run `sqry index {}` first.",
94        path.display(),
95        path.display()
96    );
97}
98
99fn load_queries(path: &Path) -> Result<Vec<String>> {
100    let file = File::open(path)
101        .with_context(|| format!("failed to open queries file {}", path.display()))?;
102    let reader = BufReader::new(file);
103
104    let mut queries = Vec::new();
105    for line in reader.lines() {
106        let line = line.context("failed to read query line")?;
107        let trimmed = line.trim();
108        if trimmed.is_empty() || trimmed.starts_with('#') {
109            continue;
110        }
111        queries.push(trimmed.to_string());
112    }
113
114    Ok(queries)
115}
116
117/// Convert `QueryResults` to `Vec<DisplaySymbol>` for display purposes.
118fn query_results_to_display_symbols(results: &QueryResults) -> Vec<DisplaySymbol> {
119    results
120        .iter()
121        .map(|m| DisplaySymbol::from_query_match(&m))
122        .collect()
123}
124
/// Raw result of running one query, before progress reporting and truncation.
struct QueryExecution {
    /// One-based position of the query in the batch.
    position: usize,
    /// Query text as loaded from the queries file.
    query: String,
    /// All matching symbols (not yet limited), or the execution error.
    result: Result<Vec<DisplaySymbol>>,
    /// Time spent in `execute_on_graph`; symbol conversion is not included.
    elapsed: Duration,
}
131
132fn execute_single_query(
133    executor: &sqry_core::query::QueryExecutor,
134    workspace: &Path,
135    idx: usize,
136    query: &str,
137) -> QueryExecution {
138    let position = idx + 1;
139    let start = Instant::now();
140    let query_results = executor.execute_on_graph(query, workspace);
141    let elapsed = start.elapsed();
142
143    // Convert QueryResults to Vec<DisplaySymbol> for batch processing
144    let result = query_results.map(|qr| query_results_to_display_symbols(&qr));
145
146    QueryExecution {
147        position,
148        query: query.to_string(),
149        result,
150        elapsed,
151    }
152}
153
154fn process_query_result(
155    execution: QueryExecution,
156    total_queries: usize,
157    limit: usize,
158    should_capture_results: bool,
159    continue_on_error: bool,
160) -> Result<QueryOutcome> {
161    let QueryExecution {
162        position,
163        query,
164        result,
165        elapsed,
166    } = execution;
167
168    eprintln!("[{position}/{total_queries}] {query}");
169
170    match result {
171        Ok(mut symbols) => {
172            let total_matches = symbols.len();
173
174            if limit > 0 && limit < symbols.len() {
175                symbols.truncate(limit);
176            }
177
178            eprintln!(
179                "[{position}/{total_queries}] ok: {} results in {}ms",
180                total_matches,
181                elapsed.as_millis()
182            );
183
184            let displayed_matches = symbols.len();
185            let captured_results = if should_capture_results {
186                Some(symbols)
187            } else {
188                None
189            };
190
191            Ok(QueryOutcome::Success(BatchEntry {
192                position,
193                query,
194                elapsed,
195                total_matches,
196                displayed_matches,
197                results: captured_results,
198            }))
199        }
200        Err(err) => {
201            let message = err.to_string();
202            eprintln!("[{position}/{total_queries}] error: {message}");
203
204            if continue_on_error {
205                Ok(QueryOutcome::Failure(BatchFailedEntry {
206                    position,
207                    query,
208                    error: message,
209                }))
210            } else {
211                Err(err).context(format!("failed to execute query \"{query}\""))
212            }
213        }
214    }
215}
216
217fn execute_queries(
218    executor: &sqry_core::query::QueryExecutor,
219    workspace: &Path,
220    queries: &[String],
221    limit: usize,
222    should_capture_results: bool,
223    continue_on_error: bool,
224    sequential: bool,
225) -> Result<Vec<QueryOutcome>> {
226    // Execute queries in parallel (or sequential if flag is set)
227    // P2-7 Phase 2: Parallel batch execution using Rayon
228    let intermediate_results: Vec<_> = if sequential {
229        // Sequential fallback for debugging
230        queries
231            .iter()
232            .enumerate()
233            .map(|(idx, query)| execute_single_query(executor, workspace, idx, query))
234            .collect()
235    } else {
236        // Parallel execution (default)
237        queries
238            .par_iter()
239            .enumerate()
240            .map(|(idx, query)| execute_single_query(executor, workspace, idx, query))
241            .collect()
242    };
243
244    // Process results sequentially to print progress messages in order.
245    let mut outcomes = Vec::with_capacity(queries.len());
246    for execution in intermediate_results {
247        let outcome = process_query_result(
248            execution,
249            queries.len(),
250            limit,
251            should_capture_results,
252            continue_on_error,
253        )?;
254        outcomes.push(outcome);
255    }
256
257    Ok(outcomes)
258}
259
/// Inputs needed by `render_output` beyond the per-query outcomes.
struct RenderOutputOptions<'a> {
    /// Output format to render.
    output: BatchFormat,
    /// Workspace root; reported as the session `path` in JSON output.
    workspace: &'a Path,
    /// Time taken to construct the query executor.
    preload_elapsed: Duration,
    /// Aggregate counts and timings for the batch.
    summary: &'a BatchSummary,
    /// Whether per-query symbol lists are serialized (used by JSON/JSONL rendering).
    should_capture_results: bool,
    /// Per-query result limit; `0` disables truncation.
    limit: usize,
    /// Append a statistics block (honored for text and CSV formats only).
    stats: bool,
}
269
270fn render_output(outcomes: &[QueryOutcome], options: &RenderOutputOptions<'_>) -> Result<String> {
271    let mut rendered = match options.output {
272        BatchFormat::Text => render_text(outcomes, options.limit),
273        BatchFormat::Csv => render_csv(outcomes),
274        BatchFormat::Json => render_json(
275            outcomes,
276            options.workspace,
277            options.preload_elapsed,
278            options.summary,
279            options.should_capture_results,
280        )?,
281        BatchFormat::Jsonl => render_jsonl(outcomes, options.should_capture_results)?,
282    };
283
284    if options.stats && matches!(options.output, BatchFormat::Text | BatchFormat::Csv) {
285        let stats_block = render_stats_block(options.preload_elapsed, options.summary);
286        if !stats_block.is_empty() {
287            if !rendered.is_empty() && !rendered.ends_with('\n') {
288                rendered.push('\n');
289            }
290            rendered.push_str(&stats_block);
291        }
292    }
293
294    Ok(rendered)
295}
296
297fn render_text(outcomes: &[QueryOutcome], limit: usize) -> String {
298    use std::fmt::Write as _;
299
300    let mut lines = Vec::new();
301    for outcome in outcomes {
302        match outcome {
303            QueryOutcome::Success(entry) => {
304                let mut line = format!(
305                    "Query {}: {} ({}ms) - {} results",
306                    entry.position,
307                    entry.query,
308                    entry.elapsed.as_millis(),
309                    entry.total_matches
310                );
311                if entry.total_matches != entry.displayed_matches && limit > 0 {
312                    let _ = write!(
313                        line,
314                        " (showing {} results; use --limit to adjust)",
315                        entry.displayed_matches
316                    );
317                }
318                lines.push(line);
319            }
320            QueryOutcome::Failure(entry) => {
321                lines.push(format!("Query {} failed: {}", entry.position, entry.error));
322            }
323        }
324    }
325
326    lines.join("\n")
327}
328
329fn render_csv(outcomes: &[QueryOutcome]) -> String {
330    let mut rows = Vec::with_capacity(outcomes.len() + 1);
331    rows.push("position,query,elapsed_ms,result_count,displayed_count,error".to_string());
332
333    for outcome in outcomes {
334        match outcome {
335            QueryOutcome::Success(entry) => {
336                rows.push(format!(
337                    "{},{},{},{},{},",
338                    entry.position,
339                    csv_escape(&entry.query),
340                    entry.elapsed.as_millis(),
341                    entry.total_matches,
342                    entry.displayed_matches
343                ));
344            }
345            QueryOutcome::Failure(entry) => {
346                rows.push(format!(
347                    "{},{},,,,{}",
348                    entry.position,
349                    csv_escape(&entry.query),
350                    csv_escape(&entry.error)
351                ));
352            }
353        }
354    }
355
356    rows.join("\n")
357}
358
359fn render_json(
360    outcomes: &[QueryOutcome],
361    workspace: &Path,
362    preload_elapsed: Duration,
363    summary: &BatchSummary,
364    capture_results: bool,
365) -> Result<String> {
366    let mut queries = Vec::new();
367    let mut errors = Vec::new();
368
369    for outcome in outcomes {
370        match outcome {
371            QueryOutcome::Success(entry) => {
372                let results = if capture_results {
373                    entry
374                        .results
375                        .as_ref()
376                        .map(|symbols| symbols.iter().map(JsonSymbol::from).collect())
377                } else {
378                    None
379                };
380
381                queries.push(BatchJsonQuery {
382                    position: entry.position,
383                    query: entry.query.clone(),
384                    elapsed_ms: entry.elapsed.as_millis(),
385                    result_count: entry.total_matches,
386                    displayed_count: entry.displayed_matches,
387                    truncated_count: if entry.total_matches == entry.displayed_matches {
388                        None
389                    } else {
390                        Some(entry.displayed_matches)
391                    },
392                    results,
393                });
394            }
395            QueryOutcome::Failure(entry) => errors.push(BatchJsonError {
396                position: entry.position,
397                query: entry.query.clone(),
398                error: entry.error.clone(),
399            }),
400        }
401    }
402
403    let payload = BatchJsonPayload {
404        session: BatchJsonSession {
405            path: workspace.display().to_string(),
406            executor_setup_ms: preload_elapsed.as_millis(),
407            total_queries: summary.total_queries,
408            success_count: summary.success_count,
409            failure_count: summary.failure_count,
410            total_query_time_ms: summary.total_query_time.as_millis(),
411            average_query_time_ms: summary.average_query_time_ms(),
412        },
413        queries,
414        errors,
415    };
416
417    Ok(serde_json::to_string_pretty(&payload)?)
418}
419
420fn render_jsonl(outcomes: &[QueryOutcome], capture_results: bool) -> Result<String> {
421    let mut lines = Vec::new();
422
423    for outcome in outcomes {
424        match outcome {
425            QueryOutcome::Success(entry) => {
426                let results = if capture_results {
427                    entry
428                        .results
429                        .as_ref()
430                        .map(|symbols| symbols.iter().map(JsonSymbol::from).collect())
431                } else {
432                    None
433                };
434
435                let row = serde_json::to_string(&BatchJsonQuery {
436                    position: entry.position,
437                    query: entry.query.clone(),
438                    elapsed_ms: entry.elapsed.as_millis(),
439                    result_count: entry.total_matches,
440                    displayed_count: entry.displayed_matches,
441                    truncated_count: if entry.total_matches == entry.displayed_matches {
442                        None
443                    } else {
444                        Some(entry.displayed_matches)
445                    },
446                    results,
447                })?;
448                lines.push(row);
449            }
450            QueryOutcome::Failure(entry) => {
451                let row = serde_json::to_string(&BatchJsonError {
452                    position: entry.position,
453                    query: entry.query.clone(),
454                    error: entry.error.clone(),
455                })?;
456                lines.push(row);
457            }
458        }
459    }
460
461    Ok(lines.join("\n"))
462}
463
464fn render_stats_block(preload_elapsed: Duration, summary: &BatchSummary) -> String {
465    let mut lines = Vec::new();
466
467    lines.push("Batch Statistics:".to_string());
468    lines.push(format!(
469        "  Executor setup time: {} ms",
470        preload_elapsed.as_millis()
471    ));
472    lines.push(format!("  Total queries: {}", summary.total_queries));
473    lines.push(format!("  Successful queries: {}", summary.success_count));
474    lines.push(format!("  Failed queries: {}", summary.failure_count));
475    lines.push(format!(
476        "  Total query time: {} ms",
477        summary.total_query_time.as_millis()
478    ));
479    if let Some(avg) = summary.average_query_time_ms() {
480        lines.push(format!("  Average query time: {avg} ms"));
481    }
482
483    lines.join("\n")
484}
485
486fn write_output(path: Option<&Path>, content: &str) -> Result<()> {
487    if content.is_empty() {
488        return Ok(());
489    }
490
491    if let Some(file_path) = path {
492        let mut writer = BufWriter::new(
493            File::create(file_path)
494                .with_context(|| format!("failed to create {}", file_path.display()))?,
495        );
496        writer.write_all(content.as_bytes())?;
497        if !content.ends_with('\n') {
498            writer.write_all(b"\n")?;
499        }
500        writer.flush()?;
501    } else {
502        print!("{content}");
503        if !content.ends_with('\n') {
504            println!();
505        }
506    }
507
508    Ok(())
509}
510
/// Escape `value` as a single CSV field (RFC 4180 style).
///
/// A field is wrapped in double quotes, with embedded quotes doubled, when it
/// contains a comma, a double quote, a line feed, or a carriage return. All
/// other values are returned unchanged.
///
/// Carriage returns are included because a bare `\r` (or `\r\n`) in an
/// unquoted field can be read as a record break by CSV readers.
fn csv_escape(value: &str) -> String {
    let needs_quotes = value.contains(|c: char| matches!(c, ',' | '"' | '\n' | '\r'));
    if needs_quotes {
        format!("\"{}\"", value.replace('"', "\"\""))
    } else {
        value.to_string()
    }
}
520
/// Per-query outcome recorded during execution.
///
/// A `Failure` is only recorded (instead of aborting the batch) when
/// `continue_on_error` is set.
enum QueryOutcome {
    /// The query ran; counts and optionally captured symbols are attached.
    Success(BatchEntry),
    /// The query failed; only its error message is kept.
    Failure(BatchFailedEntry),
}
526
/// Successful query execution details.
struct BatchEntry {
    /// One-based position of the query in the batch.
    position: usize,
    /// Query text.
    query: String,
    /// Time spent executing the query on the graph.
    elapsed: Duration,
    /// Number of matches before the result limit was applied.
    total_matches: usize,
    /// Number of matches kept after applying the result limit.
    displayed_matches: usize,
    /// Kept symbols; `None` unless results are being captured for output.
    results: Option<Vec<DisplaySymbol>>,
}
536
/// Failed query details.
struct BatchFailedEntry {
    /// One-based position of the query in the batch.
    position: usize,
    /// Query text.
    query: String,
    /// The error rendered with `to_string()`.
    error: String,
}
543
/// Aggregate execution summary used for stats/output.
struct BatchSummary {
    /// Number of recorded outcomes (successes plus recorded failures).
    total_queries: usize,
    /// Number of successful queries.
    success_count: usize,
    /// `total_queries - success_count`.
    failure_count: usize,
    /// Sum of execution time over successful queries only.
    total_query_time: Duration,
}
551
552impl BatchSummary {
553    fn from_outcomes(outcomes: &[QueryOutcome]) -> Self {
554        let mut success_count = 0usize;
555        let mut total_time = Duration::ZERO;
556
557        for outcome in outcomes {
558            if let QueryOutcome::Success(entry) = outcome {
559                success_count += 1;
560                total_time += entry.elapsed;
561            }
562        }
563
564        let total_queries = outcomes.len();
565        let failure_count = total_queries.saturating_sub(success_count);
566        Self {
567            total_queries,
568            success_count,
569            failure_count,
570            total_query_time: total_time,
571        }
572    }
573
574    fn average_query_time_ms(&self) -> Option<u128> {
575        if self.success_count == 0 {
576            None
577        } else {
578            Some(self.total_query_time.as_millis() / self.success_count as u128)
579        }
580    }
581}
582
/// Top-level JSON document emitted by `render_json`.
#[derive(Serialize)]
struct BatchJsonPayload {
    /// Batch-wide metadata and aggregate timings.
    session: BatchJsonSession,
    /// Successful queries; omitted from the output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    queries: Vec<BatchJsonQuery>,
    /// Failed queries; omitted from the output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    errors: Vec<BatchJsonError>,
}
591
/// `session` object of the JSON output.
#[derive(Serialize)]
struct BatchJsonSession {
    /// Workspace path as displayed by `Path::display`.
    path: String,
    /// Executor construction time in milliseconds.
    executor_setup_ms: u128,
    total_queries: usize,
    success_count: usize,
    failure_count: usize,
    /// Summed execution time of successful queries, in milliseconds.
    total_query_time_ms: u128,
    /// Omitted when no query succeeded.
    #[serde(skip_serializing_if = "Option::is_none")]
    average_query_time_ms: Option<u128>,
}
603
/// One successful query in JSON output (also used as a JSONL line).
#[derive(Serialize)]
struct BatchJsonQuery {
    /// One-based position of the query in the batch.
    position: usize,
    query: String,
    elapsed_ms: u128,
    /// Matches before the result limit was applied.
    result_count: usize,
    /// Matches kept after the result limit was applied.
    displayed_count: usize,
    /// Serialized as `truncated`; present only when results were truncated.
    /// NOTE(review): the renderers currently store the displayed count here
    /// (same value as `displayed_count`), not the number of dropped results —
    /// confirm the intended meaning.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "truncated")]
    truncated_count: Option<usize>,
    /// Captured symbols; omitted unless results are captured.
    #[serde(skip_serializing_if = "Option::is_none")]
    results: Option<Vec<JsonSymbol>>,
}
617
/// One failed query in JSON output (also used as a JSONL line).
#[derive(Serialize)]
struct BatchJsonError {
    /// One-based position of the query in the batch.
    position: usize,
    query: String,
    /// Error message text.
    error: String,
}
624
// Unit tests for query-file parsing, CSV escaping, and the text/JSON/JSONL
// renderers. Outcomes are constructed directly, so no index or executor is
// needed.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::tempdir;

    // Build a minimal `DisplaySymbol` for renderer tests. The `__raw_*`
    // metadata keys are filled in here; presumably some display/JSON
    // conversion reads them — TODO confirm which consumers depend on them.
    fn create_test_display_symbol(name: &str, file: &str, line: usize) -> DisplaySymbol {
        let mut metadata = HashMap::new();
        metadata.insert("__raw_language".to_string(), "rust".to_string());
        metadata.insert("__raw_file_path".to_string(), file.to_string());
        DisplaySymbol {
            name: name.to_string(),
            qualified_name: format!("test::{name}"),
            kind: "function".to_string(),
            file_path: PathBuf::from(file),
            start_line: line,
            start_column: 1,
            end_line: line,
            end_column: 10,
            metadata,
            caller_identity: None,
            callee_identity: None,
        }
    }

    // Comment and blank lines are dropped; remaining lines are trimmed
    // (note the leading spaces before `callers(main)`).
    #[test]
    fn load_queries_ignores_comments_and_blank_lines() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("queries.txt");
        fs::write(
            &path,
            "# comment\nkind:function\n\n   callers(main)\n# another\n",
        )
        .unwrap();

        let queries = load_queries(&path).unwrap();
        assert_eq!(queries, vec!["kind:function", "callers(main)"]);
    }

    // Plain fields pass through; commas and quotes force quoting with
    // doubled inner quotes.
    #[test]
    fn csv_escape_quotes_fields() {
        assert_eq!(csv_escape("kind:function"), "kind:function");
        assert_eq!(csv_escape("name,kind"), "\"name,kind\"");
        assert_eq!(csv_escape("say \"hi\""), "\"say \"\"hi\"\"\"");
    }

    // A truncated result (120 total, 100 shown) gets the "showing N" note.
    #[test]
    fn render_text_includes_limit_note() {
        let outcomes = vec![QueryOutcome::Success(BatchEntry {
            position: 1,
            query: "kind:function".to_string(),
            elapsed: Duration::from_millis(42),
            total_matches: 120,
            displayed_matches: 100,
            results: None,
        })];

        let rendered = render_text(&outcomes, 100);
        assert!(rendered.contains("Query 1: kind:function (42ms) - 120 results"));
        assert!(rendered.contains("showing 100 results"));
    }

    // The pretty-printed document carries the session path and setup time.
    #[test]
    fn render_json_includes_session_metadata() {
        let symbol = create_test_display_symbol("test_func", "src/lib.rs", 5);

        let outcomes = vec![QueryOutcome::Success(BatchEntry {
            position: 1,
            query: "kind:function".to_string(),
            elapsed: Duration::from_millis(40),
            total_matches: 1,
            displayed_matches: 1,
            results: Some(vec![symbol]),
        })];

        let summary = BatchSummary::from_outcomes(&outcomes);
        let rendered = render_json(
            &outcomes,
            Path::new("/tmp/workspace"),
            Duration::from_millis(885),
            &summary,
            true,
        )
        .unwrap();

        assert!(rendered.contains("\"path\": \"/tmp/workspace\""));
        assert!(rendered.contains("\"executor_setup_ms\": 885"));
        assert!(rendered.contains("\"queries\""));
    }

    // One compact JSON line per outcome, successes and failures in order.
    #[test]
    fn render_jsonl_outputs_lines() {
        let outcomes = vec![
            QueryOutcome::Success(BatchEntry {
                position: 1,
                query: "kind:function".to_string(),
                elapsed: Duration::from_millis(40),
                total_matches: 2,
                displayed_matches: 2,
                results: Some(vec![]),
            }),
            QueryOutcome::Failure(BatchFailedEntry {
                position: 2,
                query: "broken".to_string(),
                error: "syntax error".to_string(),
            }),
        ];

        let rendered = render_jsonl(&outcomes, false).unwrap();
        let lines: Vec<_> = rendered.lines().collect();
        assert_eq!(lines.len(), 2);
        assert!(lines[0].contains("\"position\":1"));
        assert!(lines[1].contains("\"error\""));
    }
}