Skip to main content

sqry_nl/
translator.rs

1//! Main Translator API for natural language to sqry command translation.
2//!
3//! This module ties together all the components of the translation pipeline:
4//! preprocess → extract → classify → assemble → validate → cache
5
6use crate::assembler;
7use crate::cache::{CacheConfig, CacheKey, CachedResult, TranslationCache};
8use crate::error::{AssemblerError, NlResult};
9use crate::extractor;
10use crate::preprocess;
11use crate::types::{
12    DisambiguationOption, ExtractedEntities, Intent, TranslationResponse, ValidationStatus,
13};
14use crate::validator;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Instant;
18
/// Default minimum confidence for the `Execute` response tier.
///
/// Used as the default value of [`TranslatorConfig::execute_threshold`].
const EXECUTE_THRESHOLD: f32 = 0.85;
/// Default minimum confidence for the `Confirm` response tier.
///
/// Used as the default value of [`TranslatorConfig::confirm_threshold`].
const CONFIRM_THRESHOLD: f32 = 0.65;

/// Default cache capacity (number of entries).
const DEFAULT_CACHE_CAPACITY: usize = 128;

/// Default result limit for cache key generation.
const DEFAULT_RESULT_LIMIT: u32 = 100;
28
/// Configuration for the Translator.
#[derive(Debug, Clone)]
pub struct TranslatorConfig {
    /// Path to model directory (for classifier feature).
    /// Directory should contain: `intent_classifier.onnx`, `tokenizer.json`, and optionally
    /// `calibration.json` or `temperature.json` for confidence calibration
    pub model_dir: Option<String>,
    /// Context: current working directory for relative paths.
    ///
    /// Also feeds into cache key generation.
    pub working_directory: Option<String>,
    /// Minimum confidence at which a validated translation is returned as
    /// `Execute` (and stored in the translation cache when one is configured).
    pub execute_threshold: f32,
    /// Minimum confidence at which a validated translation is returned as
    /// `Confirm`; anything lower yields a `Disambiguate` response.
    pub confirm_threshold: f32,
    /// Cache configuration. Set to None to disable caching.
    pub cache_config: Option<CacheConfig>,
    /// Default result limit (affects cache key generation).
    pub default_limit: u32,
    /// Languages to restrict searches to (affects cache key generation).
    pub languages: Vec<String>,
    /// Optional override for the resolved model directory (NL02 resolver chain entry).
    ///
    /// When set, the resolver short-circuits the rest of the lookup chain and uses
    /// this path directly. Distinct from the legacy [`Self::model_dir`] field, which
    /// remains in place for backward compatibility.
    pub model_dir_override: Option<PathBuf>,
    /// Permit loading a model whose checksums cannot be verified.
    ///
    /// Defaults to `false`. NL04 enforces strict integrity by default; this flag is
    /// the operator escape hatch for development workflows.
    pub allow_unverified_model: bool,
    /// Permit fetching the model from the network when not present locally.
    ///
    /// Defaults to `false`. NL03 wires the gated downloader behind this flag.
    pub allow_model_download: bool,
    /// Optional cache directory for downloaded model artifacts.
    ///
    /// Defaults to `None`, in which case NL03 selects a platform-appropriate cache root.
    pub model_cache_dir: Option<PathBuf>,
    /// Optional override for the classifier worker pool size.
    ///
    /// Defaults to `None`. NL07 resolves this to the `SQRY_NL_CLASSIFIER_POOL_SIZE`
    /// environment variable, falling back to `2` workers.
    pub classifier_pool_size: Option<usize>,
}
72
73impl Default for TranslatorConfig {
74    fn default() -> Self {
75        Self {
76            model_dir: None,
77            working_directory: None,
78            execute_threshold: EXECUTE_THRESHOLD,
79            confirm_threshold: CONFIRM_THRESHOLD,
80            cache_config: Some(CacheConfig {
81                capacity: DEFAULT_CACHE_CAPACITY,
82                ..Default::default()
83            }),
84            default_limit: DEFAULT_RESULT_LIMIT,
85            languages: Vec::new(),
86            model_dir_override: None,
87            allow_unverified_model: false,
88            allow_model_download: false,
89            model_cache_dir: None,
90            classifier_pool_size: None,
91        }
92    }
93}
94
/// The main Translator struct that provides the `translate()` API.
///
/// `Debug` is implemented manually because the inner classifier pool
/// (`ort::Session`-bearing) does not implement `Debug`. The manual
/// impl renders enough state for daemon `LoadedWorkspace` /
/// `WorkspaceManager` debug output without dumping model internals.
pub struct Translator {
    /// Configuration supplied at construction time; read on every translation
    /// for thresholds and cache-key context.
    config: TranslatorConfig,
    /// Translation counter for stats
    translations: AtomicU64,
    /// Translation cache for repeated queries (Step 7).
    /// `None` when `TranslatorConfig::cache_config` is `None`.
    cache: Option<TranslationCache>,
    /// NL07 — bounded pool of `N` independently-loaded classifier
    /// sessions. `None` when the resolver chain produced no model
    /// directory (rule-based fallback only).
    #[cfg(feature = "classifier")]
    classifier_pool: Option<crate::classifier::ClassifierPool>,
}
113
114impl Translator {
    /// Create a new Translator with the given configuration.
    ///
    /// With the `classifier` feature enabled this resolves a model directory,
    /// optionally falls back to the gated download path when nothing was found
    /// and `allow_model_download` is set, and loads a classifier pool from the
    /// result. When no model directory is available the translator is built
    /// without a pool. The translation cache is created whenever
    /// `cache_config` is `Some`.
    ///
    /// # Errors
    ///
    /// With the `classifier` feature enabled, returns an error if:
    /// - the download path is taken but neither `model_cache_dir` nor a
    ///   platform cache directory is available,
    /// - `ensure_model_in_cache` fails, or
    /// - building the classifier pool fails.
    ///
    /// Without the feature this function contains no fallible step.
    pub fn new(config: TranslatorConfig) -> NlResult<Self> {
        #[cfg(feature = "classifier")]
        let classifier_pool = {
            use crate::classifier::{
                BAKED_MANIFEST, ClassifierPool, RealDirs, ResolverLevel, TrustMode,
                ensure_model_in_cache, resolve_model_dir, resolve_pool_size,
            };
            use std::ffi::OsString;
            use std::path::{Path, PathBuf};

            // NL02 5-level resolver:
            // 1. CLI / programmatic override
            // 2. Legacy `TranslatorConfig::model_dir`
            // 3. SQRY_NL_MODEL_DIR env var
            // 4. XDG cache dir / sqry/models
            // 5. <exe-dir>/models
            let cli_override: Option<&Path> = config.model_dir_override.as_deref();
            let legacy_path: Option<&Path> = config.model_dir.as_deref().map(Path::new);
            let env_value: Option<OsString> = std::env::var_os("SQRY_NL_MODEL_DIR");
            let env_ref = env_value.as_deref();
            let exe = std::env::current_exe().ok();
            let exe_ref = exe.as_deref();

            let resolved =
                resolve_model_dir(cli_override, legacy_path, env_ref, &RealDirs, exe_ref);

            // NL03 gated downloader path. Only fires when:
            //   1. The resolver missed every level (no on-disk model anywhere
            //      sqry knows to look), AND
            //   2. The operator opted in via `allow_model_download = true`.
            // The download path lands the model under our managed cache
            // directory — structurally equivalent to a Level 4 (XDG) hit,
            // so the resulting trust mode is `Trusted`.
            let resolved: Option<(PathBuf, ResolverLevel)> = match resolved {
                Some(hit) => Some(hit),
                None if config.allow_model_download => {
                    // Explicit `model_cache_dir` wins; otherwise fall back to
                    // `<platform cache dir>/sqry/models`.
                    let cache_dir: PathBuf = config
                        .model_cache_dir
                        .clone()
                        .or_else(|| dirs::cache_dir().map(|p| p.join("sqry/models")))
                        .ok_or_else(|| {
                            crate::error::NlError::Config(
                                "no platform cache_dir available for model download".to_string(),
                            )
                        })?;
                    let dir = ensure_model_in_cache(&cache_dir, &BAKED_MANIFEST, true)?;
                    Some((dir, ResolverLevel::XdgCache))
                }
                None => None,
            };

            match resolved {
                Some((model_dir, level)) => {
                    let trust_mode = TrustMode::from(level);
                    // FR-14: Custom mode means integrity is rooted in
                    // a user-supplied `manifest.json`. Emit a single
                    // loud warn at Translator init time so the operator
                    // is aware their model directory is the trust root.
                    if matches!(trust_mode, TrustMode::Custom) {
                        tracing::warn!(
                            target: "sqry_nl::classifier",
                            model_dir = %model_dir.display(),
                            resolver_level = ?level,
                            "Loading NL classifier under custom trust mode — \
                             integrity rooted in user-supplied manifest.json. \
                             For trusted defaults use the XDG cache or the \
                             binary-adjacent install location."
                        );
                    }

                    // NL07: build a pool of N independently-loaded
                    // classifier sessions. The closure below is invoked
                    // exactly N times by `ClassifierPool::new` — that
                    // is the load-counter invariant the
                    // `n_concurrent_translates_use_n_distinct_sessions`
                    // integration test asserts.
                    let pool_size = resolve_pool_size(config.classifier_pool_size);
                    tracing::info!(
                        target: "sqry_nl::classifier",
                        model_dir = %model_dir.display(),
                        pool_size,
                        "Initialising NL classifier pool"
                    );
                    // The loader closure takes its own copy of the path so
                    // `model_dir` is not moved into it.
                    let model_dir_for_loader = model_dir.clone();
                    let pool = ClassifierPool::new(pool_size, move || {
                        crate::classifier::IntentClassifier::load(
                            &model_dir_for_loader,
                            config.allow_unverified_model,
                            trust_mode,
                        )
                        .map_err(crate::error::NlError::from)
                    })?;
                    Some(pool)
                }
                None => None,
            }
        };

        // Initialize cache if configured
        let cache = config
            .cache_config
            .as_ref()
            .map(|cfg| TranslationCache::with_config(cfg.clone()));

        Ok(Self {
            config,
            translations: AtomicU64::new(0),
            cache,
            #[cfg(feature = "classifier")]
            classifier_pool,
        })
    }
232
233    /// Create a Translator with default configuration.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if initialization fails.
238    pub fn load_default() -> NlResult<Self> {
239        Self::new(TranslatorConfig::default())
240    }
241
242    /// Translate a natural language query to a sqry command.
243    ///
244    /// # Arguments
245    ///
246    /// * `input` - The natural language query to translate
247    ///
248    /// # Returns
249    ///
250    /// A `TranslationResponse` indicating:
251    /// - `Execute`: High confidence, safe to run automatically
252    /// - `Confirm`: Medium confidence, ask user to confirm
253    /// - `Disambiguate`: Low confidence, need user clarification
254    /// - `Reject`: Cannot translate safely
255    ///
256    /// # Note
257    ///
258    /// NL07: this method now only needs `&self` because the classifier
259    /// is held inside a [`crate::classifier::ClassifierPool`] (each
260    /// pool slot wraps the underlying session in an
261    /// `Arc<parking_lot::Mutex<_>>`). The `&mut self` signature is
262    /// preserved for backwards compatibility with the pre-NL07 public
263    /// API. Concurrent callers can use [`Self::translate_shared`] to
264    /// drop the redundant `mut` requirement.
265    pub fn translate(&mut self, input: &str) -> TranslationResponse {
266        self.translate_shared(input)
267    }
268
269    /// Translate without requiring `&mut self`.
270    ///
271    /// Functionally identical to [`Self::translate`]; exposed because
272    /// the underlying pool architecture supports concurrent shared
273    /// access. Use this from multi-threaded callers (daemon, LSP) so
274    /// they don't need to hold a `&mut Translator` across calls.
275    pub fn translate_shared(&self, input: &str) -> TranslationResponse {
276        self.translations.fetch_add(1, Ordering::Relaxed);
277        self.translate_impl(input)
278    }
279
280    /// Internal translation implementation.
281    fn translate_impl(&self, input: &str) -> TranslationResponse {
282        let start_time = Instant::now();
283
284        // Create cache key from input and context
285        let cache_key = CacheKey::new(
286            input,
287            &self.config.languages,
288            self.config.working_directory.clone(),
289            self.config.default_limit,
290        );
291
292        // Check cache first
293        if let Some(cached_response) = self.cached_response(&cache_key, start_time) {
294            return cached_response;
295        }
296
297        // Step 1: Preprocess
298        let preprocessed = match preprocess::preprocess_input(input) {
299            Ok(p) => p,
300            Err(e) => {
301                return TranslationResponse::Reject {
302                    reason: format!("Preprocessing failed: {e}"),
303                    suggestions: vec!["Try simplifying your query".to_string()],
304                };
305            }
306        };
307
308        // Step 2: Extract entities
309        let entities = extractor::extract_entities(&preprocessed.text);
310
311        // Step 3: Classify intent
312        let (intent, confidence) = self.classify_intent(&preprocessed.text, &entities);
313
314        // Step 4: Assemble command
315        let command = match assembler::assemble_command(&intent, &entities) {
316            Ok(cmd) => cmd,
317            Err(e) => return Self::handle_assembly_error(e, &entities),
318        };
319
320        // Step 5: Validate command
321        self.handle_validation_result(
322            command, confidence, intent, &entities, cache_key, start_time,
323        )
324    }
325
326    fn cached_response(
327        &self,
328        cache_key: &CacheKey,
329        start_time: Instant,
330    ) -> Option<TranslationResponse> {
331        let cache = self.cache.as_ref()?;
332        let cached = cache.get(cache_key)?;
333        Some(TranslationResponse::Execute {
334            command: cached.command,
335            confidence: cached.confidence,
336            intent: cached.intent,
337            cached: true,
338            latency_ms: Self::elapsed_ms(start_time),
339        })
340    }
341
342    fn handle_validation_result(
343        &self,
344        command: String,
345        confidence: f32,
346        intent: Intent,
347        entities: &ExtractedEntities,
348        cache_key: CacheKey,
349        start_time: Instant,
350    ) -> TranslationResponse {
351        match validator::validate_command(&command) {
352            ValidationStatus::Valid => {
353                let latency_ms = Self::elapsed_ms(start_time);
354
355                if confidence >= self.config.execute_threshold
356                    && let Some(ref cache) = self.cache
357                {
358                    cache.put(
359                        cache_key,
360                        CachedResult {
361                            command: command.clone(),
362                            intent,
363                            confidence,
364                            created_at: Instant::now(),
365                        },
366                    );
367                }
368
369                self.create_response_with_latency(command, confidence, intent, entities, latency_ms)
370            }
371            ValidationStatus::RejectedMetachar => TranslationResponse::Reject {
372                reason: "Command contains disallowed shell characters".to_string(),
373                suggestions: vec![
374                    "Avoid special characters like ;, |, &, $".to_string(),
375                    "Use quoted strings for literal values".to_string(),
376                ],
377            },
378            ValidationStatus::RejectedEnvVar => TranslationResponse::Reject {
379                reason: "Command contains environment variable references".to_string(),
380                suggestions: vec![
381                    "Use literal paths instead of $HOME, ${VAR}".to_string(),
382                    "Specify the full path explicitly".to_string(),
383                ],
384            },
385            ValidationStatus::RejectedPathTraversal => TranslationResponse::Reject {
386                reason: "Command contains path traversal patterns".to_string(),
387                suggestions: vec![
388                    "Use relative paths within the project".to_string(),
389                    "Avoid .. in paths".to_string(),
390                ],
391            },
392            ValidationStatus::RejectedTooLong => TranslationResponse::Reject {
393                reason: "Generated command exceeds maximum length".to_string(),
394                suggestions: vec![
395                    "Try a simpler query".to_string(),
396                    "Reduce the number of filters".to_string(),
397                ],
398            },
399            ValidationStatus::RejectedWriteMode => TranslationResponse::Reject {
400                reason: "Command attempts write operation".to_string(),
401                suggestions: vec![
402                    "NL translation only supports read operations".to_string(),
403                    "Use CLI directly for write operations".to_string(),
404                ],
405            },
406            ValidationStatus::RejectedUnknown => {
407                let template_names = assembler::templates::TEMPLATES
408                    .iter()
409                    .map(|(name, _)| *name)
410                    .collect::<Vec<_>>()
411                    .join(", ");
412                let template_examples = ["query", "search", "trace-path"]
413                    .into_iter()
414                    .filter_map(assembler::templates::get_template)
415                    .map(str::to_string)
416                    .collect::<Vec<_>>()
417                    .join(" | ");
418
419                TranslationResponse::Reject {
420                    reason: "Command does not match any allowed template".to_string(),
421                    suggestions: vec![
422                        format!("Use supported command templates: {template_names}"),
423                        format!("Examples: {template_examples}"),
424                        "Try rephrasing your query".to_string(),
425                    ],
426                }
427            }
428        }
429    }
430
431    fn elapsed_ms(start_time: Instant) -> u64 {
432        u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX)
433    }
434
435    /// Classify the intent of the query.
436    ///
437    /// NL07: dispatches through the classifier pool. The pool guard
438    /// holds one of `N` loaded sessions for the duration of this call;
439    /// the slot is automatically returned on guard drop, including any
440    /// panic during `classify`.
441    #[allow(clippy::unused_self)] // Uses self when classifier feature is enabled.
442    fn classify_intent(&self, text: &str, entities: &ExtractedEntities) -> (Intent, f32) {
443        #[cfg(feature = "classifier")]
444        if let Some(ref pool) = self.classifier_pool {
445            // Acquire blocks the current thread until a slot is free.
446            // Async callers MUST wrap this in `spawn_blocking` — the
447            // pool is sync by design (no tokio dependency in sqry-nl).
448            let guard = pool.acquire();
449            let mut classifier = guard.classifier().lock();
450            match classifier.classify(text) {
451                Ok(result) => return (result.intent, result.confidence),
452                Err(e) => {
453                    // Log and fall back to rules. The guard's `Drop`
454                    // returns the slot regardless of the classify
455                    // outcome, so this fallback is panic-safe.
456                    tracing::warn!(
457                        target: "sqry_nl::classifier",
458                        error = %e,
459                        "Classifier failed, using rule-based fallback"
460                    );
461                }
462            }
463            // Explicit drop so the pool slot is back in the channel
464            // before we run the rule-based fallback below.
465            drop(classifier);
466            drop(guard);
467        }
468
469        // Fallback: rule-based classification
470        Self::classify_intent_rules(text, entities)
471    }
472
473    /// Rule-based intent classification fallback.
474    fn classify_intent_rules(text: &str, entities: &ExtractedEntities) -> (Intent, f32) {
475        let text_lower = text.to_lowercase();
476
477        if let Some(intent) = Self::classify_graph_intent(&text_lower) {
478            return intent;
479        }
480
481        if let Some(intent) = Self::classify_index_intent(&text_lower) {
482            return intent;
483        }
484
485        if let Some(intent) = Self::classify_text_search_intent(&text_lower, text) {
486            return intent;
487        }
488
489        if let Some(intent) = Self::classify_symbol_query_intent(&text_lower, entities) {
490            return intent;
491        }
492
493        if Self::is_ambiguous(&text_lower) {
494            return (Intent::Ambiguous, 0.3);
495        }
496
497        (Intent::SymbolQuery, 0.5)
498    }
499
500    fn classify_graph_intent(text_lower: &str) -> Option<(Intent, f32)> {
501        if Self::matches_callers(text_lower) {
502            return Some((Intent::FindCallers, 0.85));
503        }
504
505        if Self::matches_callees(text_lower) {
506            return Some((Intent::FindCallees, 0.85));
507        }
508
509        if Self::matches_trace_path(text_lower) {
510            return Some((Intent::TracePath, 0.8));
511        }
512
513        if Self::matches_visualize(text_lower) {
514            return Some((Intent::Visualize, 0.8));
515        }
516
517        None
518    }
519
520    fn matches_callers(text_lower: &str) -> bool {
521        text_lower.contains("callers")
522            || text_lower.contains("who calls")
523            || text_lower.contains("what calls")
524            || text_lower.contains("who uses")
525            || text_lower.contains("who depends")
526            || text_lower.contains("find usages")
527            || text_lower.contains("find all references")
528            || text_lower.contains("where is") && text_lower.contains("used")
529    }
530
531    fn matches_callees(text_lower: &str) -> bool {
532        text_lower.contains("callees")
533            || text_lower.contains("what does") && text_lower.contains("call")
534            || text_lower.contains("functions called by")
535            || text_lower.contains("methods called by")
536            || text_lower.contains("dependencies of")
537            || text_lower.contains("outgoing calls")
538            || text_lower.contains("what functions does")
539            || text_lower.contains("what methods does")
540    }
541
542    fn matches_trace_path(text_lower: &str) -> bool {
543        text_lower.contains("trace")
544            || text_lower.contains("path from")
545            || text_lower.contains("path to")
546            || text_lower.contains("path between")
547            || text_lower.contains("call chain")
548            || text_lower.contains("call sequence")
549            || (text_lower.contains("how does") && text_lower.contains("reach"))
550            || (text_lower.contains("how does") && text_lower.contains("flow"))
551    }
552
553    fn matches_visualize(text_lower: &str) -> bool {
554        text_lower.contains("visualize")
555            || text_lower.contains("diagram")
556            || text_lower.contains("draw")
557            || text_lower.contains("mermaid")
558            || text_lower.contains("dot graph")
559            || (text_lower.contains("generate") && text_lower.contains("graph"))
560            || (text_lower.contains("show") && text_lower.contains("visual"))
561    }
562
563    fn classify_index_intent(text_lower: &str) -> Option<(Intent, f32)> {
564        if (text_lower.contains("index") && text_lower.contains("status"))
565            || text_lower.starts_with("index status")
566            || text_lower.contains("is index")
567            || text_lower.contains("check index")
568            || text_lower.contains("index info")
569            || text_lower.contains("index stat")
570            || text_lower.contains("indexed")
571            || text_lower.contains("what files are indexed")
572            || text_lower.contains("how many symbols")
573            || text_lower.contains("when was index")
574        {
575            return Some((Intent::IndexStatus, 0.85));
576        }
577
578        None
579    }
580
581    fn classify_text_search_intent(text_lower: &str, text: &str) -> Option<(Intent, f32)> {
582        let is_predicate_query = Self::is_predicate_query(text_lower);
583
584        if text_lower.starts_with("grep")
585            || text_lower.starts_with("search for")
586            || text_lower.contains("grep for")
587            || text_lower.contains("grep ")
588            || text_lower.contains("look for")
589            || (text_lower.contains("search") && !text_lower.contains("code search"))
590            || text_lower.contains("todo")
591            || text_lower.contains("fixme")
592            || text_lower.contains("deprecated")
593            || text_lower.contains("copyright")
594            || text_lower.contains("hardcoded")
595            || text.contains('!')
596            || (!is_predicate_query && text_lower.contains("unsafe"))
597            || text_lower.contains(" pub ")
598            || text_lower.contains(" mut ")
599            || (!is_predicate_query && text_lower.contains("async"))
600            || text_lower.contains("unsafe blocks")
601            || text_lower.contains("impl blocks")
602            || text_lower.contains("import")
603            || text_lower.contains("use statement")
604            || text_lower.contains("require")
605        {
606            return Some((Intent::TextSearch, 0.8));
607        }
608
609        None
610    }
611
612    fn classify_symbol_query_intent(
613        text_lower: &str,
614        entities: &ExtractedEntities,
615    ) -> Option<(Intent, f32)> {
616        if text_lower.starts_with("find")
617            || text_lower.starts_with("show")
618            || text_lower.starts_with("list")
619            || text_lower.starts_with("where is")
620            || text_lower.starts_with("where are")
621            || text_lower.contains("function")
622            || text_lower.contains("method")
623            || text_lower.contains("class")
624            || text_lower.contains("struct")
625            || text_lower.contains("enum")
626            || text_lower.contains("trait")
627            || text_lower.contains("interface")
628            || text_lower.contains("module")
629            || text_lower.contains("constant")
630            || text_lower.contains("variable")
631            || text_lower.contains("public")
632            || text_lower.contains("private")
633            || text_lower.contains("defined")
634        {
635            return Some((Intent::SymbolQuery, 0.8));
636        }
637
638        if entities.kind.is_some() {
639            return Some((Intent::SymbolQuery, 0.85));
640        }
641
642        if !entities.symbols.is_empty() {
643            return Some((Intent::SymbolQuery, 0.7));
644        }
645
646        if !entities.languages.is_empty() {
647            return Some((Intent::SymbolQuery, 0.65));
648        }
649
650        None
651    }
652
653    fn is_predicate_query(text_lower: &str) -> bool {
654        text_lower.contains("functions")
655            || text_lower.contains("methods")
656            || text_lower.contains("function")
657            || text_lower.contains("method")
658    }
659
660    fn is_ambiguous(text_lower: &str) -> bool {
661        text_lower.split_whitespace().count() <= 2
662    }
663
664    /// Create the appropriate response based on confidence level (with latency tracking).
665    fn create_response_with_latency(
666        &self,
667        command: String,
668        confidence: f32,
669        intent: Intent,
670        entities: &ExtractedEntities,
671        latency_ms: u64,
672    ) -> TranslationResponse {
673        if confidence >= self.config.execute_threshold {
674            TranslationResponse::Execute {
675                command,
676                confidence,
677                intent,
678                cached: false,
679                latency_ms,
680            }
681        } else if confidence >= self.config.confirm_threshold {
682            let prompt = format!(
683                "I'll run: {}\nConfidence: {:.0}%. Proceed? [y/N]",
684                command,
685                confidence * 100.0
686            );
687            TranslationResponse::Confirm {
688                command,
689                confidence,
690                prompt,
691            }
692        } else {
693            // Disambiguate - present options to user
694            let options = Self::generate_disambiguation_options(entities);
695            TranslationResponse::Disambiguate {
696                options,
697                prompt: "I'm not sure what you mean. Did you want to:".to_string(),
698            }
699        }
700    }
701
702    /// Create the appropriate response based on confidence level.
703    #[allow(dead_code)]
704    fn create_response(
705        &self,
706        command: String,
707        confidence: f32,
708        intent: Intent,
709        entities: &ExtractedEntities,
710    ) -> TranslationResponse {
711        self.create_response_with_latency(command, confidence, intent, entities, 0)
712    }
713
714    /// Generate disambiguation options when confidence is low.
715    fn generate_disambiguation_options(entities: &ExtractedEntities) -> Vec<DisambiguationOption> {
716        let mut options = Vec::new();
717
718        if let Some(symbol) = entities.primary_symbol() {
719            options.push(DisambiguationOption {
720                command: format!("sqry query \"{symbol}\""),
721                intent: Intent::SymbolQuery,
722                description: format!("Search for symbol \"{symbol}\""),
723                confidence: 0.5,
724            });
725            options.push(DisambiguationOption {
726                command: format!("sqry graph direct-callers \"{symbol}\""),
727                intent: Intent::FindCallers,
728                description: format!("Find callers of \"{symbol}\""),
729                confidence: 0.4,
730            });
731        } else {
732            options.push(DisambiguationOption {
733                command: "sqry query \"<symbol>\"".to_string(),
734                intent: Intent::SymbolQuery,
735                description: "Search for a specific symbol".to_string(),
736                confidence: 0.3,
737            });
738        }
739
740        options
741    }
742
743    /// Handle assembly errors with helpful suggestions.
744    fn handle_assembly_error(
745        error: AssemblerError,
746        entities: &ExtractedEntities,
747    ) -> TranslationResponse {
748        match error {
749            AssemblerError::MissingSymbol => {
750                let suggestions = if entities.languages.is_empty() {
751                    vec![
752                        "Specify what symbol or pattern you're looking for".to_string(),
753                        "Example: find \"authenticate\" in rust".to_string(),
754                    ]
755                } else {
756                    vec![
757                        format!(
758                            "Try: find <symbol name> in {}",
759                            entities.languages.join(", ")
760                        ),
761                        "Specify what you're looking for in quotes".to_string(),
762                    ]
763                };
764                TranslationResponse::Reject {
765                    reason: "Could not determine what to search for".to_string(),
766                    suggestions,
767                }
768            }
769            AssemblerError::AmbiguousIntent => TranslationResponse::Disambiguate {
770                options: vec![
771                    DisambiguationOption {
772                        command: "sqry query \"<symbol>\"".to_string(),
773                        intent: Intent::SymbolQuery,
774                        description: "Search for symbols matching a pattern".to_string(),
775                        confidence: 0.3,
776                    },
777                    DisambiguationOption {
778                        command: "sqry graph direct-callers \"<symbol>\"".to_string(),
779                        intent: Intent::FindCallers,
780                        description: "Find callers of a function".to_string(),
781                        confidence: 0.3,
782                    },
783                ],
784                prompt: "Please clarify what you'd like to do:".to_string(),
785            },
786            AssemblerError::MissingTracePath => TranslationResponse::Reject {
787                reason: "Trace path requires both source and target symbols".to_string(),
788                suggestions: vec![
789                    "Specify two symbols: trace path from X to Y".to_string(),
790                    "Example: trace path from login to database".to_string(),
791                ],
792            },
793            AssemblerError::CommandTooLong { .. } => TranslationResponse::Reject {
794                reason: "Generated command is too long".to_string(),
795                suggestions: vec![
796                    "Try a simpler query".to_string(),
797                    "Reduce the number of filters".to_string(),
798                ],
799            },
800            AssemblerError::NoTemplate(intent_name) => TranslationResponse::Reject {
801                reason: format!("No template available for intent: {intent_name}"),
802                suggestions: vec![
803                    "Try a different query type".to_string(),
804                    "Supported queries: symbol search, callers, callees, trace path".to_string(),
805                ],
806            },
807        }
808    }
809
810    /// Get translation count.
811    #[must_use]
812    pub fn translation_count(&self) -> u64 {
813        self.translations.load(Ordering::Relaxed)
814    }
815
816    /// Get cache statistics.
817    ///
818    /// Returns `None` if caching is disabled.
819    #[must_use]
820    pub fn cache_stats(&self) -> Option<crate::cache::CacheStats> {
821        self.cache
822            .as_ref()
823            .map(super::cache::TranslationCache::stats)
824    }
825
826    /// Get cache hit rate (0.0-1.0).
827    ///
828    /// Returns `None` if caching is disabled.
829    #[must_use]
830    pub fn cache_hit_rate(&self) -> Option<f64> {
831        self.cache
832            .as_ref()
833            .map(super::cache::TranslationCache::hit_rate)
834    }
835
836    /// Clear the translation cache.
837    ///
838    /// Does nothing if caching is disabled.
839    pub fn clear_cache(&self) {
840        if let Some(ref cache) = self.cache {
841            cache.clear();
842        }
843    }
844}
845
846impl std::fmt::Debug for Translator {
847    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
848        let mut debug = f.debug_struct("Translator");
849        debug
850            .field("translations", &self.translations.load(Ordering::Relaxed))
851            .field("cache_enabled", &self.cache.is_some());
852        #[cfg(feature = "classifier")]
853        debug.field("classifier_pool", &self.classifier_pool);
854        debug.finish()
855    }
856}
857
/// End-to-end smoke tests for [`Translator`] using the default configuration.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a translator via `Translator::load_default`, failing the test on error.
    fn default_translator() -> Translator {
        Translator::load_default().unwrap()
    }

    /// Extract the command from an `Execute`/`Confirm` response; any other
    /// variant fails the test.
    fn require_command(response: TranslationResponse) -> String {
        match response {
            TranslationResponse::Execute { command, .. }
            | TranslationResponse::Confirm { command, .. } => command,
            _ => panic!("Expected Execute or Confirm response"),
        }
    }

    #[test]
    fn test_translator_creation() {
        let translator = default_translator();
        assert_eq!(translator.translation_count(), 0);
    }

    #[test]
    fn test_translate_simple_query() {
        let mut translator = default_translator();
        let response = translator.translate("find authentication functions");

        // A rejection is tolerated (e.g. from validation) as long as it is not
        // the missing-symbol rejection.
        if let TranslationResponse::Reject { reason, .. } = &response {
            assert!(!reason.contains("Could not determine"));
        }
        assert_eq!(translator.translation_count(), 1);
    }

    #[test]
    fn test_translate_with_language() {
        let response = default_translator().translate("find authentication in rust");

        // Disambiguate or Reject is acceptable for the rule-based fallback, so
        // only runnable responses are checked.
        if let TranslationResponse::Execute { command, .. }
        | TranslationResponse::Confirm { command, .. } = response
        {
            assert!(command.contains("--language rust"));
        }
    }

    #[test]
    fn test_translate_callers() {
        let response = default_translator().translate("who calls authenticate");

        match response {
            TranslationResponse::Execute { intent, .. } => {
                assert_eq!(intent, Intent::FindCallers);
            }
            TranslationResponse::Confirm { command, .. } => {
                // Confirm carries no intent here, so inspect the command text instead.
                let mentions_callers = command.contains("graph direct-callers");
                let mentions_symbol = command.contains("authenticate");
                assert!(mentions_callers || mentions_symbol);
            }
            _ => {}
        }
    }

    #[test]
    fn test_custom_thresholds() {
        let strict = TranslatorConfig {
            execute_threshold: 0.99,
            confirm_threshold: 0.90,
            ..Default::default()
        };

        // With thresholds this high the query must not be executed outright.
        let response = Translator::new(strict).unwrap().translate("find foo");
        let executed = matches!(response, TranslationResponse::Execute { .. });
        assert!(!executed);
    }

    #[test]
    fn test_kind_only_query() {
        // A query naming only a kind (no symbol) should still produce a command,
        // with the kind expressed as a predicate in the query expression.
        let response = default_translator().translate("list all traits");
        let command = require_command(response);
        assert!(command.contains("kind:trait"));
    }

    #[test]
    fn test_snake_case_symbol() {
        // The snake_case identifier must survive extraction intact.
        let response = default_translator().translate("find user_id variable");
        let command = require_command(response);
        assert!(command.contains("user_id"));
    }
}
959
/// Predicate translation regression tests.
///
/// Each test translates a natural-language query with a default-configured
/// translator and checks that the resulting command contains the expected
/// predicate text.
#[cfg(test)]
mod predicate_translation_tests {
    use super::*;

    /// Translate `query` with a fresh default-configured translator and return
    /// the command of the `Execute`/`Confirm` response; any other response
    /// variant fails the test.
    fn translated_command(query: &str) -> String {
        let config = TranslatorConfig::default();
        let mut translator = Translator::new(config).expect("Translator init failed");

        match translator.translate(query) {
            TranslationResponse::Execute { command, .. }
            | TranslationResponse::Confirm { command, .. } => command,
            _ => panic!("should execute or confirm"),
        }
    }

    #[test]
    fn test_async_functions_translation() {
        let command = translated_command("find async functions");
        assert!(command.contains("async:true"));
    }

    #[test]
    fn test_unsafe_functions_translation() {
        let command = translated_command("find unsafe functions");
        assert!(command.contains("unsafe:true"));
    }

    #[test]
    fn test_public_async_functions_translation() {
        let command = translated_command("find public async functions");
        assert!(command.contains("visibility:public"));
        assert!(command.contains("async:true"));
    }
}