1use anyhow::Result;
7use colored::Colorize;
8use std::io::{self, Write};
9
10use crate::args::Cli;
11use crate::output::OutputStreams;
12
/// Reports whether the environment variable `name` holds an affirmative value.
///
/// The value is trimmed and compared case-insensitively against `1`, `true`,
/// `yes` and `on`. Any failure to read the variable counts as `false`.
fn env_flag_truthy(name: &str) -> bool {
    const AFFIRMATIVE: [&str; 4] = ["1", "true", "yes", "on"];

    std::env::var(name)
        .map(|raw| {
            let value = raw.trim();
            AFFIRMATIVE
                .iter()
                .any(|accepted| value.eq_ignore_ascii_case(accepted))
        })
        .unwrap_or(false)
}
29
30struct ResponseConfig<'a> {
32 cli: &'a Cli,
33 path: &'a str,
34 auto_execute: bool,
35 dry_run: bool,
36}
37
38fn write_execute_json(
40 streams: &mut OutputStreams,
41 command: &str,
42 confidence: f32,
43 intent: &str,
44 dry_run: bool,
45 auto_execute: bool,
46) -> Result<()> {
47 let output = if dry_run {
48 serde_json::json!({
49 "type": "execute",
50 "command": command,
51 "confidence": confidence,
52 "intent": intent,
53 "dry_run": true
54 })
55 } else if auto_execute {
56 serde_json::json!({
57 "type": "execute",
58 "command": command,
59 "confidence": confidence,
60 "intent": intent,
61 "auto_execute": true
62 })
63 } else {
64 serde_json::json!({
65 "type": "confirm",
66 "command": command,
67 "confidence": confidence,
68 "intent": intent
69 })
70 };
71 streams.write_result(&serde_json::to_string_pretty(&output)?)?;
72 Ok(())
73}
74
75fn write_execute_text(
77 streams: &mut OutputStreams,
78 command: &str,
79 confidence: f32,
80 intent: &str,
81 dry_run: bool,
82 auto_execute: bool,
83) -> Result<()> {
84 if dry_run {
85 streams.write_result(&format!(
86 "{} {}\n{}: {:.0}%\n{}: {}\n",
87 "Command:".bold(),
88 command.green(),
89 "Confidence".dimmed(),
90 confidence * 100.0,
91 "Intent".dimmed(),
92 intent
93 ))?;
94 } else if auto_execute {
95 streams.write_result(&format!(
96 "{} {} ({:.0}% confidence)\n",
97 "Executing:".green().bold(),
98 command,
99 confidence * 100.0
100 ))?;
101 } else {
102 streams.write_result(&format!(
103 "{} {}\n{}: {:.0}%\n",
104 "Generated command:".bold(),
105 command.cyan(),
106 "Confidence".dimmed(),
107 confidence * 100.0
108 ))?;
109 }
110 Ok(())
111}
112
113fn handle_execute_response(
115 streams: &mut OutputStreams,
116 config: &ResponseConfig,
117 command: &str,
118 confidence: f32,
119 intent: &str,
120) -> Result<()> {
121 if config.cli.json {
122 write_execute_json(
123 streams,
124 command,
125 confidence,
126 intent,
127 config.dry_run,
128 config.auto_execute,
129 )?;
130 } else {
131 write_execute_text(
132 streams,
133 command,
134 confidence,
135 intent,
136 config.dry_run,
137 config.auto_execute,
138 )?;
139 }
140
141 if config.dry_run {
142 return Ok(());
143 }
144
145 if config.auto_execute {
146 execute_generated_command(command, config.path, config.cli)?;
147 } else if !config.cli.json {
148 if prompt_confirmation("Execute this command?")? {
150 execute_generated_command(command, config.path, config.cli)?;
151 } else {
152 streams.write_diagnostic("Cancelled.\n")?;
153 }
154 }
155
156 Ok(())
157}
158
159fn write_confirm_json(
161 streams: &mut OutputStreams,
162 command: &str,
163 confidence: f32,
164 prompt: &str,
165 dry_run: bool,
166 auto_execute: bool,
167) -> Result<()> {
168 let output = serde_json::json!({
169 "type": "confirm",
170 "command": command,
171 "confidence": confidence,
172 "prompt": prompt,
173 "dry_run": dry_run,
174 "auto_execute": auto_execute
175 });
176 streams.write_result(&serde_json::to_string_pretty(&output)?)?;
177 Ok(())
178}
179
180fn write_confirm_text(
182 streams: &mut OutputStreams,
183 command: &str,
184 confidence: f32,
185 prompt: &str,
186 dry_run: bool,
187) -> Result<()> {
188 if dry_run {
189 streams.write_result(&format!(
190 "{} {}\n{}: {:.0}%\n{}\n",
191 "Command:".bold(),
192 command.yellow(),
193 "Confidence".dimmed(),
194 confidence * 100.0,
195 "(Medium confidence - would require confirmation)".dimmed()
196 ))?;
197 } else {
198 streams.write_result(&format!(
199 "{}\n{} {}\n",
200 prompt.yellow(),
201 "Command:".bold(),
202 command.cyan()
203 ))?;
204 }
205 Ok(())
206}
207
208fn handle_confirm_response(
210 streams: &mut OutputStreams,
211 config: &ResponseConfig,
212 command: &str,
213 confidence: f32,
214 prompt: &str,
215) -> Result<()> {
216 if config.cli.json {
217 write_confirm_json(
218 streams,
219 command,
220 confidence,
221 prompt,
222 config.dry_run,
223 config.auto_execute,
224 )?;
225 } else {
226 write_confirm_text(streams, command, confidence, prompt, config.dry_run)?;
227 }
228
229 if config.dry_run {
230 return Ok(());
231 }
232
233 let should_execute = if config.cli.json {
235 config.auto_execute
236 } else {
237 config.auto_execute || prompt_confirmation("")?
238 };
239
240 if should_execute {
241 execute_generated_command(command, config.path, config.cli)?;
242 } else if !config.cli.json {
243 streams.write_diagnostic("Cancelled.\n")?;
244 }
245
246 Ok(())
247}
248
249fn handle_disambiguate_response(
251 streams: &mut OutputStreams,
252 config: &ResponseConfig,
253 options: &[sqry_nl::DisambiguationOption],
254 prompt: &str,
255) -> Result<()> {
256 let best_option = select_best_disambiguation(options);
257
258 if config.cli.json {
259 handle_disambiguate_json(streams, config, options, prompt, best_option)?;
260 } else {
261 handle_disambiguate_text(streams, config, options, prompt, best_option)?;
262 }
263
264 Ok(())
265}
266
267fn select_best_disambiguation(
268 options: &[sqry_nl::DisambiguationOption],
269) -> Option<&sqry_nl::DisambiguationOption> {
270 options.iter().max_by(|a, b| {
271 a.confidence
272 .partial_cmp(&b.confidence)
273 .unwrap_or(std::cmp::Ordering::Equal)
274 })
275}
276
277fn handle_disambiguate_json(
278 streams: &mut OutputStreams,
279 config: &ResponseConfig,
280 options: &[sqry_nl::DisambiguationOption],
281 prompt: &str,
282 best_option: Option<&sqry_nl::DisambiguationOption>,
283) -> Result<()> {
284 let output = serde_json::json!({
285 "type": "disambiguate",
286 "prompt": prompt,
287 "options": options.iter().map(|opt| {
288 serde_json::json!({
289 "command": opt.command,
290 "intent": opt.intent.as_str(),
291 "description": opt.description,
292 "confidence": opt.confidence
293 })
294 }).collect::<Vec<_>>(),
295 "auto_execute": config.auto_execute,
296 "dry_run": config.dry_run
297 });
298 streams.write_result(&serde_json::to_string_pretty(&output)?)?;
299
300 if let Some(selected) = best_option.filter(|_| config.auto_execute && !config.dry_run) {
301 execute_generated_command(&selected.command, config.path, config.cli)?;
302 }
303
304 Ok(())
305}
306
307fn handle_disambiguate_text(
308 streams: &mut OutputStreams,
309 config: &ResponseConfig,
310 options: &[sqry_nl::DisambiguationOption],
311 prompt: &str,
312 best_option: Option<&sqry_nl::DisambiguationOption>,
313) -> Result<()> {
314 streams.write_result(&format!("{}\n\n", prompt.yellow()))?;
315
316 for (i, opt) in options.iter().enumerate() {
317 streams.write_result(&format!(
318 " {}. {} - {}\n {}\n\n",
319 i + 1,
320 opt.description.bold(),
321 format!("{:.0}%", opt.confidence * 100.0).dimmed(),
322 opt.command.cyan()
323 ))?;
324 }
325
326 if config.dry_run || options.is_empty() {
327 return Ok(());
328 }
329
330 if config.auto_execute {
331 if let Some(selected) = best_option {
332 streams.write_result(&format!(
333 "\n{} {}\n",
334 "Auto-executing highest confidence:".green().bold(),
335 selected.command
336 ))?;
337 execute_generated_command(&selected.command, config.path, config.cli)?;
338 }
339 return Ok(());
340 }
341
342 execute_disambiguation_choice(streams, config, options)
343}
344
345fn execute_disambiguation_choice(
346 streams: &mut OutputStreams,
347 config: &ResponseConfig,
348 options: &[sqry_nl::DisambiguationOption],
349) -> Result<()> {
350 let choice = prompt_choice(options.len())?;
351 if let Some(idx) = choice {
352 let selected = &options[idx];
353 streams.write_result(&format!(
354 "\n{} {}\n",
355 "Executing:".green().bold(),
356 selected.command
357 ))?;
358 execute_generated_command(&selected.command, config.path, config.cli)?;
359 } else {
360 streams.write_diagnostic("Cancelled.\n")?;
361 }
362 Ok(())
363}
364
365fn handle_reject_response(
368 streams: &mut OutputStreams,
369 config: &ResponseConfig,
370 reason: &str,
371 suggestions: &[String],
372) -> Result<String> {
373 if config.cli.json {
374 let output = serde_json::json!({
375 "type": "reject",
376 "reason": reason,
377 "suggestions": suggestions
378 });
379 streams.write_result(&serde_json::to_string_pretty(&output)?)?;
380 } else {
381 streams.write_diagnostic(&format!(
382 "{} {}\n",
383 "Cannot translate:".red().bold(),
384 reason
385 ))?;
386
387 if !suggestions.is_empty() {
388 streams.write_diagnostic(&format!("\n{}:\n", "Suggestions".yellow()))?;
389 for suggestion in suggestions {
390 streams.write_diagnostic(&format!(" • {suggestion}\n"))?;
391 }
392 }
393 }
394 Ok(format!("Translation rejected: {reason}"))
395}
396
397#[allow(clippy::fn_params_excessive_bools, clippy::too_many_arguments)]
405pub fn run_ask(
406 cli: &Cli,
407 query: &str,
408 path: &str,
409 auto_execute: bool,
410 dry_run: bool,
411 threshold: f32,
412 model_dir_override: Option<&std::path::Path>,
413 allow_unverified_model_flag: bool,
414 allow_model_download_flag: bool,
415) -> Result<()> {
416 use sqry_nl::{TranslationResponse, Translator, TranslatorConfig};
417
418 let mut streams = OutputStreams::with_pager(cli.pager_config());
419
420 let allow_unverified_model =
423 allow_unverified_model_flag || env_flag_truthy("SQRY_NL_ALLOW_UNVERIFIED_MODEL");
424 let allow_model_download =
425 allow_model_download_flag || env_flag_truthy("SQRY_NL_ALLOW_DOWNLOAD");
426
427 let translator_config = TranslatorConfig {
429 execute_threshold: threshold,
430 confirm_threshold: threshold * 0.75, model_dir_override: model_dir_override.map(std::path::Path::to_path_buf),
432 allow_unverified_model,
433 allow_model_download,
434 ..Default::default()
435 };
436
437 let mut translator = match Translator::new(translator_config) {
441 Ok(t) => t,
442 Err(sqry_nl::NlError::OnnxRuntimeMissing { hint }) => {
443 return Err(crate::error::CliError::OnnxRuntimeMissing { hint }.into());
444 }
445 Err(e) => {
446 return Err(
447 anyhow::Error::new(e).context("Failed to initialize natural language translator")
448 );
449 }
450 };
451
452 let response = translator.translate(query);
454
455 let config = ResponseConfig {
457 cli,
458 path,
459 auto_execute,
460 dry_run,
461 };
462
463 let reject_error = match response {
465 TranslationResponse::Execute {
466 command,
467 confidence,
468 intent,
469 ..
470 } => {
471 handle_execute_response(&mut streams, &config, &command, confidence, intent.as_str())?;
472 None
473 }
474
475 TranslationResponse::Confirm {
476 command,
477 confidence,
478 prompt,
479 } => {
480 handle_confirm_response(&mut streams, &config, &command, confidence, &prompt)?;
481 None
482 }
483
484 TranslationResponse::Disambiguate { options, prompt } => {
485 handle_disambiguate_response(&mut streams, &config, &options, &prompt)?;
486 None
487 }
488
489 TranslationResponse::Reject {
490 reason,
491 suggestions,
492 } => {
493 let error_msg = handle_reject_response(&mut streams, &config, &reason, &suggestions)?;
494 Some(error_msg)
495 }
496 };
497
498 streams.finish_checked()?;
499
500 if let Some(error_msg) = reject_error {
502 anyhow::bail!("{error_msg}");
503 }
504
505 Ok(())
506}
507
/// Arguments recovered from a generated command string by
/// `parse_generated_command`.
#[derive(Debug, Default)]
struct ParsedCommandArgs {
    /// First double-quoted string of the command.
    primary: String,
    /// Value following `--language`, if any.
    language: Option<String>,
    /// Value following `--kind`, if any.
    kind: Option<String>,
    /// Value following `--limit`, when it parses as an unsigned integer.
    limit: Option<u32>,
    /// Value following `--path`, if any.
    path_filter: Option<String>,
    /// Second double-quoted string of the command, if any.
    secondary: Option<String>,
    /// Value following `--max-depth`, when it parses as an unsigned integer.
    max_depth: Option<u32>,
}
526
/// Extracts the value that follows `flag` in `command`.
///
/// `flag` must appear as a whole whitespace-delimited token outside of
/// double quotes. Occurrences that are merely a prefix of a longer flag
/// (`--path` inside `--path-filter`) or that sit inside a quoted argument
/// are ignored.
///
/// The value is either a double-quoted string (returned without the quotes;
/// an unterminated quote yields the rest of the command) or the next
/// whitespace-delimited token.
///
/// Returns `None` when the flag is absent or nothing follows it.
fn extract_flag_value(command: &str, flag: &str) -> Option<String> {
    let after_flag = command.match_indices(flag).find_map(|(start, matched)| {
        let before = &command[..start];
        let after = &command[start + matched.len()..];

        // The match must be a complete token, not part of a longer word.
        let starts_token = before.chars().next_back().map_or(true, char::is_whitespace);
        let ends_token = after.chars().next().map_or(true, char::is_whitespace);
        // An odd number of preceding quotes means we are inside a quoted string.
        let inside_quotes = before.matches('"').count() % 2 == 1;

        (starts_token && ends_token && !inside_quotes).then_some(after)
    })?;

    let value = after_flag.trim_start();

    if let Some(quoted) = value.strip_prefix('"') {
        // Without a closing quote, take everything that remains.
        let end = quoted.find('"').unwrap_or(quoted.len());
        return Some(quoted[..end].to_string());
    }

    value.split_whitespace().next().map(str::to_string)
}
556
557fn parse_generated_command(command: &str) -> Result<ParsedCommandArgs> {
559 let mut args = ParsedCommandArgs::default();
560
561 let mut quoted_strings = Vec::new();
563 let mut in_quote = false;
564 let mut current_quoted = String::new();
565
566 for c in command.chars() {
567 if c == '"' {
568 if in_quote {
569 quoted_strings.push(current_quoted.clone());
570 current_quoted.clear();
571 }
572 in_quote = !in_quote;
573 } else if in_quote {
574 current_quoted.push(c);
575 }
576 }
577
578 if let Some(primary) = quoted_strings.first() {
580 args.primary.clone_from(primary);
581 }
582
583 if let Some(secondary) = quoted_strings.get(1) {
585 args.secondary = Some(secondary.clone());
586 }
587
588 args.path_filter = extract_flag_value(command, "--path");
591
592 let parts: Vec<&str> = command.split_whitespace().collect();
594 let mut i = 0;
595 while i < parts.len() {
596 match parts[i] {
597 "--language" if i + 1 < parts.len() => {
598 args.language = Some(parts[i + 1].to_string());
599 i += 2;
600 }
601 "--kind" if i + 1 < parts.len() => {
602 args.kind = Some(parts[i + 1].to_string());
603 i += 2;
604 }
605 "--limit" if i + 1 < parts.len() => {
606 args.limit = parts[i + 1].parse().ok();
607 i += 2;
608 }
609 "--path" => {
610 i += 2;
613 }
614 "--max-depth" if i + 1 < parts.len() => {
615 args.max_depth = parts[i + 1].parse().ok();
616 i += 2;
617 }
618 _ => {
619 i += 1;
620 }
621 }
622 }
623
624 if args.primary.is_empty() {
625 anyhow::bail!("Could not extract primary argument from command: {command}");
626 }
627
628 Ok(args)
629}
630
631fn build_query_expression(args: &ParsedCommandArgs) -> String {
637 let mut expr_parts = vec![args.primary.clone()];
638
639 if let Some(lang) = &args.language
644 && !args.primary.contains("lang:")
645 && !args.primary.contains("language:")
646 {
647 expr_parts.push(format!("language:{lang}"));
648 }
649
650 if let Some(path) = &args.path_filter
652 && !args.primary.contains("path:")
653 {
654 if path.contains(' ') {
655 let escaped = path.replace('"', "\\\"");
657 expr_parts.push(format!("path:\"{escaped}\""));
658 } else {
659 expr_parts.push(format!("path:{path}"));
660 }
661 }
662
663 expr_parts.join(" ")
666}
667
668fn execute_generated_command(command: &str, path: &str, cli: &Cli) -> Result<()> {
670 let parts: Vec<&str> = command.split_whitespace().collect();
672
673 if parts.is_empty() || parts[0] != "sqry" {
674 anyhow::bail!("Invalid generated command: {command}");
675 }
676
677 if parts.len() < 2 {
678 anyhow::bail!("Generated command missing subcommand: {command}");
679 }
680
681 let subcommand = parts[1];
682
683 match subcommand {
684 "query" => {
685 let parsed = parse_generated_command(command)?;
687 let query_expr = build_query_expression(&parsed);
689 let result_limit = parsed.limit.map(|l| l as usize);
691 super::run_query(
692 cli,
693 &query_expr,
694 path,
695 false,
696 false,
697 false,
698 false,
699 None,
700 result_limit,
701 &[],
702 )?;
703 }
704 "search" => {
705 let parsed = parse_generated_command(command)?;
706 super::run_search(cli, &parsed.primary, path, None, false, false)?;
711 }
712 "graph" => {
713 if parts.len() < 3 {
715 anyhow::bail!("Graph command missing operation: {command}");
716 }
717 eprintln!(
719 "{}",
720 format!("Graph commands not yet auto-executable: {command}").yellow()
721 );
722 }
723 "index" => {
724 if command.contains("--status") {
725 super::run_index_status(cli, path, crate::args::MetricsFormat::Json)?;
726 } else {
727 eprintln!(
728 "{}",
729 format!("Index build not auto-executable: {command}").yellow()
730 );
731 }
732 }
733 _ => {
734 anyhow::bail!("Unsupported generated command: {subcommand}");
735 }
736 }
737
738 Ok(())
739}
740
/// Test-only helper that pulls one argument out of a command string.
///
/// Returns the contents of the first complete double-quoted string. When
/// there is none, falls back to the third whitespace-separated token with
/// surrounding quotes removed. `_position` is unused.
///
/// # Errors
///
/// Fails when neither a quoted string nor a third token exists.
#[cfg(test)]
fn extract_quoted_arg(command: &str, _position: usize) -> Result<String> {
    let quoted = command
        .split_once('"')
        .and_then(|(_, tail)| tail.split_once('"'))
        .map(|(inner, _)| inner);
    if let Some(inner) = quoted {
        return Ok(inner.to_string());
    }

    match command.split_whitespace().nth(2) {
        Some(token) => Ok(token.trim_matches('"').to_string()),
        None => anyhow::bail!("Could not extract argument from: {command}"),
    }
}
759
760fn prompt_confirmation(message: &str) -> Result<bool> {
762 if message.is_empty() {
763 eprint!("[y/N] ");
764 } else {
765 eprint!("{message} [y/N] ");
766 }
767 io::stderr().flush()?;
768
769 let mut input = String::new();
770 io::stdin().read_line(&mut input)?;
771
772 Ok(input.trim().eq_ignore_ascii_case("y") || input.trim().eq_ignore_ascii_case("yes"))
773}
774
775fn prompt_choice(max: usize) -> Result<Option<usize>> {
777 eprint!("Enter choice (1-{max}) or 'c' to cancel: ");
778 io::stderr().flush()?;
779
780 let mut input = String::new();
781 io::stdin().read_line(&mut input)?;
782
783 let trimmed = input.trim();
784 if trimmed.eq_ignore_ascii_case("c") || trimmed.is_empty() {
785 return Ok(None);
786 }
787
788 match trimmed.parse::<usize>() {
789 Ok(n) if n >= 1 && n <= max => Ok(Some(n - 1)),
790 _ => {
791 eprintln!("Invalid choice");
792 Ok(None)
793 }
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an owned `Some(String)` in comparisons.
    fn owned(text: &str) -> Option<String> {
        Some(text.to_owned())
    }

    #[test]
    fn test_extract_quoted_arg() {
        let extracted = extract_quoted_arg(r#"sqry query "kind:function""#, 2).unwrap();
        assert_eq!(extracted, "kind:function");
    }

    #[test]
    fn test_extract_quoted_arg_with_spaces() {
        let extracted = extract_quoted_arg(r#"sqry search "hello world""#, 2).unwrap();
        assert_eq!(extracted, "hello world");
    }

    #[test]
    fn test_parse_generated_command_basic() {
        let parsed = parse_generated_command(r#"sqry query "authenticate" --limit 100"#).unwrap();
        assert_eq!(parsed.primary, "authenticate");
        assert_eq!(parsed.limit, Some(100));
        assert_eq!(parsed.language, None);
        assert_eq!(parsed.kind, None);
    }

    #[test]
    fn test_parse_generated_command_with_all_flags() {
        let parsed = parse_generated_command(
            r#"sqry query "login" --language rust --kind function --limit 50"#,
        )
        .unwrap();
        assert_eq!(parsed.primary, "login");
        assert_eq!(parsed.language, owned("rust"));
        assert_eq!(parsed.kind, owned("function"));
        assert_eq!(parsed.limit, Some(50));
    }

    #[test]
    fn test_parse_generated_command_trace_path() {
        let parsed =
            parse_generated_command(r#"sqry graph trace-path "source" "target" --max-depth 5"#)
                .unwrap();
        assert_eq!(parsed.primary, "source");
        assert_eq!(parsed.secondary, owned("target"));
        assert_eq!(parsed.max_depth, Some(5));
    }

    #[test]
    fn test_build_query_expression_basic() {
        let args = ParsedCommandArgs {
            primary: String::from("authenticate"),
            ..Default::default()
        };
        assert_eq!(build_query_expression(&args), "authenticate");
    }

    #[test]
    fn test_build_query_expression_with_predicates() {
        // The primary already carries the `kind:` predicate.
        let args = ParsedCommandArgs {
            primary: String::from("kind:function login"),
            language: owned("rust"),
            kind: owned("function"),
            limit: Some(50),
            ..Default::default()
        };
        let expr = build_query_expression(&args);
        for expected in ["login", "kind:function", "language:rust"] {
            assert!(expr.contains(expected));
        }
        assert!(!expr.contains("limit:"));
    }

    #[test]
    fn test_build_query_expression_with_path() {
        let args = ParsedCommandArgs {
            primary: String::from("test"),
            path_filter: owned("src/lib.rs"),
            ..Default::default()
        };
        assert!(build_query_expression(&args).contains("path:src/lib.rs"));
    }

    #[test]
    fn test_build_query_expression_with_path_spaces() {
        let args = ParsedCommandArgs {
            primary: String::from("login"),
            path_filter: owned("src/api services"),
            language: owned("rust"),
            ..Default::default()
        };
        let expr = build_query_expression(&args);
        assert!(expr.contains(r#"path:"src/api services""#));
        assert!(expr.contains("language:rust"));
    }

    #[test]
    fn test_extract_flag_value_unquoted() {
        let value = extract_flag_value(r#"sqry query "test" --limit 50"#, "--limit");
        assert_eq!(value.as_deref(), Some("50"));
    }

    #[test]
    fn test_extract_flag_value_quoted() {
        let value = extract_flag_value(r#"sqry query "test" --path "src/api services""#, "--path");
        assert_eq!(value.as_deref(), Some("src/api services"));
    }

    #[test]
    fn test_extract_flag_value_not_present() {
        let value = extract_flag_value(r#"sqry query "test""#, "--limit");
        assert!(value.is_none());
    }

    #[test]
    fn test_parse_generated_command_with_path_spaces() {
        let parsed = parse_generated_command(
            r#"sqry query "login" --path "src/api services" --language rust"#,
        )
        .unwrap();
        assert_eq!(parsed.primary, "login");
        assert_eq!(parsed.path_filter, owned("src/api services"));
        assert_eq!(parsed.language, owned("rust"));
    }
}