1use crate::assembler;
7use crate::cache::{CacheConfig, CacheKey, CachedResult, TranslationCache};
8use crate::error::{AssemblerError, NlResult};
9use crate::extractor;
10use crate::preprocess;
11use crate::types::{
12 DisambiguationOption, ExtractedEntities, Intent, TranslationResponse, ValidationStatus,
13};
14use crate::validator;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::time::Instant;
17
/// Default for [`TranslatorConfig::execute_threshold`]: translations scoring at or
/// above this confidence are returned as `Execute` (and become cacheable).
const EXECUTE_THRESHOLD: f32 = 0.85;
/// Default for [`TranslatorConfig::confirm_threshold`]: translations scoring at or
/// above this (but below the execute threshold) are returned as `Confirm`;
/// anything lower produces a disambiguation prompt.
const CONFIRM_THRESHOLD: f32 = 0.65;

/// Default `capacity` for the translation cache created by [`TranslatorConfig::default`]
/// (presumably a number of entries — TODO confirm against `CacheConfig`).
const DEFAULT_CACHE_CAPACITY: usize = 128;

/// Default for [`TranslatorConfig::default_limit`].
const DEFAULT_RESULT_LIMIT: u32 = 100;
27
/// Settings used to build a [`Translator`].
///
/// NOTE(review): `confirm_threshold` is expected to be <= `execute_threshold`.
/// If it is larger, any score that reaches the confirm threshold also reaches the
/// execute threshold, so `Confirm` responses can never be produced.
#[derive(Debug, Clone)]
pub struct TranslatorConfig {
    /// Directory containing the intent-classifier model. Only read when the
    /// `classifier` feature is enabled; `None` means keyword rules only.
    pub model_dir: Option<String>,
    /// Working directory; `Translator` includes it in the translation cache key.
    pub working_directory: Option<String>,
    /// Minimum confidence for an `Execute` response. Only results at or above
    /// this value are stored in the translation cache.
    pub execute_threshold: f32,
    /// Minimum confidence for a `Confirm` response; lower scores yield a
    /// disambiguation prompt.
    pub confirm_threshold: f32,
    /// Translation cache settings; `None` disables caching entirely.
    pub cache_config: Option<CacheConfig>,
    /// Result limit; `Translator` includes it in the translation cache key.
    pub default_limit: u32,
    /// Language filters; `Translator` includes them in the translation cache key.
    pub languages: Vec<String>,
}
47
48impl Default for TranslatorConfig {
49 fn default() -> Self {
50 Self {
51 model_dir: None,
52 working_directory: None,
53 execute_threshold: EXECUTE_THRESHOLD,
54 confirm_threshold: CONFIRM_THRESHOLD,
55 cache_config: Some(CacheConfig {
56 capacity: DEFAULT_CACHE_CAPACITY,
57 ..Default::default()
58 }),
59 default_limit: DEFAULT_RESULT_LIMIT,
60 languages: Vec::new(),
61 }
62 }
63}
64
/// Translates natural-language queries into validated command strings.
///
/// Intent is decided by an optional ML classifier (behind the `classifier`
/// feature) with keyword rules as the fallback. High-confidence results can be
/// served from an optional translation cache.
pub struct Translator {
    /// Settings this translator was constructed with.
    config: TranslatorConfig,
    /// Number of `translate` calls made so far.
    translations: AtomicU64,
    /// Cache of high-confidence translations; `None` when caching is disabled.
    cache: Option<TranslationCache>,
    /// Loaded intent classifier, if a model directory was configured.
    #[cfg(feature = "classifier")]
    classifier: Option<crate::classifier::IntentClassifier>,
}
75
76impl Translator {
77 pub fn new(config: TranslatorConfig) -> NlResult<Self> {
83 #[cfg(feature = "classifier")]
84 let classifier = if let Some(model_dir) = &config.model_dir {
85 use std::path::Path;
86 Some(crate::classifier::IntentClassifier::load(Path::new(
87 model_dir,
88 ))?)
89 } else {
90 None
91 };
92
93 let cache = config
95 .cache_config
96 .as_ref()
97 .map(|cfg| TranslationCache::with_config(cfg.clone()));
98
99 Ok(Self {
100 config,
101 translations: AtomicU64::new(0),
102 cache,
103 #[cfg(feature = "classifier")]
104 classifier,
105 })
106 }
107
108 pub fn load_default() -> NlResult<Self> {
114 Self::new(TranslatorConfig::default())
115 }
116
117 pub fn translate(&mut self, input: &str) -> TranslationResponse {
135 self.translations.fetch_add(1, Ordering::Relaxed);
136 self.translate_impl(input)
137 }
138
139 fn translate_impl(&mut self, input: &str) -> TranslationResponse {
141 let start_time = Instant::now();
142
143 let cache_key = CacheKey::new(
145 input,
146 &self.config.languages,
147 self.config.working_directory.clone(),
148 self.config.default_limit,
149 );
150
151 if let Some(cached_response) = self.cached_response(&cache_key, start_time) {
153 return cached_response;
154 }
155
156 let preprocessed = match preprocess::preprocess_input(input) {
158 Ok(p) => p,
159 Err(e) => {
160 return TranslationResponse::Reject {
161 reason: format!("Preprocessing failed: {e}"),
162 suggestions: vec!["Try simplifying your query".to_string()],
163 };
164 }
165 };
166
167 let entities = extractor::extract_entities(&preprocessed.text);
169
170 let (intent, confidence) = self.classify_intent(&preprocessed.text, &entities);
172
173 let command = match assembler::assemble_command(&intent, &entities) {
175 Ok(cmd) => cmd,
176 Err(e) => return Self::handle_assembly_error(e, &entities),
177 };
178
179 self.handle_validation_result(
181 command, confidence, intent, &entities, cache_key, start_time,
182 )
183 }
184
185 fn cached_response(
186 &self,
187 cache_key: &CacheKey,
188 start_time: Instant,
189 ) -> Option<TranslationResponse> {
190 let cache = self.cache.as_ref()?;
191 let cached = cache.get(cache_key)?;
192 Some(TranslationResponse::Execute {
193 command: cached.command,
194 confidence: cached.confidence,
195 intent: cached.intent,
196 cached: true,
197 latency_ms: Self::elapsed_ms(start_time),
198 })
199 }
200
201 fn handle_validation_result(
202 &self,
203 command: String,
204 confidence: f32,
205 intent: Intent,
206 entities: &ExtractedEntities,
207 cache_key: CacheKey,
208 start_time: Instant,
209 ) -> TranslationResponse {
210 match validator::validate_command(&command) {
211 ValidationStatus::Valid => {
212 let latency_ms = Self::elapsed_ms(start_time);
213
214 if confidence >= self.config.execute_threshold
215 && let Some(ref cache) = self.cache
216 {
217 cache.put(
218 cache_key,
219 CachedResult {
220 command: command.clone(),
221 intent,
222 confidence,
223 created_at: Instant::now(),
224 },
225 );
226 }
227
228 self.create_response_with_latency(command, confidence, intent, entities, latency_ms)
229 }
230 ValidationStatus::RejectedMetachar => TranslationResponse::Reject {
231 reason: "Command contains disallowed shell characters".to_string(),
232 suggestions: vec![
233 "Avoid special characters like ;, |, &, $".to_string(),
234 "Use quoted strings for literal values".to_string(),
235 ],
236 },
237 ValidationStatus::RejectedEnvVar => TranslationResponse::Reject {
238 reason: "Command contains environment variable references".to_string(),
239 suggestions: vec![
240 "Use literal paths instead of $HOME, ${VAR}".to_string(),
241 "Specify the full path explicitly".to_string(),
242 ],
243 },
244 ValidationStatus::RejectedPathTraversal => TranslationResponse::Reject {
245 reason: "Command contains path traversal patterns".to_string(),
246 suggestions: vec![
247 "Use relative paths within the project".to_string(),
248 "Avoid .. in paths".to_string(),
249 ],
250 },
251 ValidationStatus::RejectedTooLong => TranslationResponse::Reject {
252 reason: "Generated command exceeds maximum length".to_string(),
253 suggestions: vec![
254 "Try a simpler query".to_string(),
255 "Reduce the number of filters".to_string(),
256 ],
257 },
258 ValidationStatus::RejectedWriteMode => TranslationResponse::Reject {
259 reason: "Command attempts write operation".to_string(),
260 suggestions: vec![
261 "NL translation only supports read operations".to_string(),
262 "Use CLI directly for write operations".to_string(),
263 ],
264 },
265 ValidationStatus::RejectedUnknown => {
266 let template_names = assembler::templates::TEMPLATES
267 .iter()
268 .map(|(name, _)| *name)
269 .collect::<Vec<_>>()
270 .join(", ");
271 let template_examples = ["query", "search", "trace-path"]
272 .into_iter()
273 .filter_map(assembler::templates::get_template)
274 .map(str::to_string)
275 .collect::<Vec<_>>()
276 .join(" | ");
277
278 TranslationResponse::Reject {
279 reason: "Command does not match any allowed template".to_string(),
280 suggestions: vec![
281 format!("Use supported command templates: {template_names}"),
282 format!("Examples: {template_examples}"),
283 "Try rephrasing your query".to_string(),
284 ],
285 }
286 }
287 }
288 }
289
290 fn elapsed_ms(start_time: Instant) -> u64 {
291 u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX)
292 }
293
294 #[allow(clippy::unused_self)] fn classify_intent(&mut self, text: &str, entities: &ExtractedEntities) -> (Intent, f32) {
297 #[cfg(feature = "classifier")]
298 if let Some(ref mut classifier) = self.classifier {
299 match classifier.classify(text) {
300 Ok(result) => return (result.intent, result.confidence),
301 Err(e) => {
302 eprintln!("Classifier failed, using fallback: {e}");
304 }
305 }
306 }
307
308 Self::classify_intent_rules(text, entities)
310 }
311
312 fn classify_intent_rules(text: &str, entities: &ExtractedEntities) -> (Intent, f32) {
314 let text_lower = text.to_lowercase();
315
316 if let Some(intent) = Self::classify_graph_intent(&text_lower) {
317 return intent;
318 }
319
320 if let Some(intent) = Self::classify_index_intent(&text_lower) {
321 return intent;
322 }
323
324 if let Some(intent) = Self::classify_text_search_intent(&text_lower, text) {
325 return intent;
326 }
327
328 if let Some(intent) = Self::classify_symbol_query_intent(&text_lower, entities) {
329 return intent;
330 }
331
332 if Self::is_ambiguous(&text_lower) {
333 return (Intent::Ambiguous, 0.3);
334 }
335
336 (Intent::SymbolQuery, 0.5)
337 }
338
339 fn classify_graph_intent(text_lower: &str) -> Option<(Intent, f32)> {
340 if Self::matches_callers(text_lower) {
341 return Some((Intent::FindCallers, 0.85));
342 }
343
344 if Self::matches_callees(text_lower) {
345 return Some((Intent::FindCallees, 0.85));
346 }
347
348 if Self::matches_trace_path(text_lower) {
349 return Some((Intent::TracePath, 0.8));
350 }
351
352 if Self::matches_visualize(text_lower) {
353 return Some((Intent::Visualize, 0.8));
354 }
355
356 None
357 }
358
359 fn matches_callers(text_lower: &str) -> bool {
360 text_lower.contains("callers")
361 || text_lower.contains("who calls")
362 || text_lower.contains("what calls")
363 || text_lower.contains("who uses")
364 || text_lower.contains("who depends")
365 || text_lower.contains("find usages")
366 || text_lower.contains("find all references")
367 || text_lower.contains("where is") && text_lower.contains("used")
368 }
369
370 fn matches_callees(text_lower: &str) -> bool {
371 text_lower.contains("callees")
372 || text_lower.contains("what does") && text_lower.contains("call")
373 || text_lower.contains("functions called by")
374 || text_lower.contains("methods called by")
375 || text_lower.contains("dependencies of")
376 || text_lower.contains("outgoing calls")
377 || text_lower.contains("what functions does")
378 || text_lower.contains("what methods does")
379 }
380
381 fn matches_trace_path(text_lower: &str) -> bool {
382 text_lower.contains("trace")
383 || text_lower.contains("path from")
384 || text_lower.contains("path to")
385 || text_lower.contains("path between")
386 || text_lower.contains("call chain")
387 || text_lower.contains("call sequence")
388 || (text_lower.contains("how does") && text_lower.contains("reach"))
389 || (text_lower.contains("how does") && text_lower.contains("flow"))
390 }
391
392 fn matches_visualize(text_lower: &str) -> bool {
393 text_lower.contains("visualize")
394 || text_lower.contains("diagram")
395 || text_lower.contains("draw")
396 || text_lower.contains("mermaid")
397 || text_lower.contains("dot graph")
398 || (text_lower.contains("generate") && text_lower.contains("graph"))
399 || (text_lower.contains("show") && text_lower.contains("visual"))
400 }
401
402 fn classify_index_intent(text_lower: &str) -> Option<(Intent, f32)> {
403 if (text_lower.contains("index") && text_lower.contains("status"))
404 || text_lower.starts_with("index status")
405 || text_lower.contains("is index")
406 || text_lower.contains("check index")
407 || text_lower.contains("index info")
408 || text_lower.contains("index stat")
409 || text_lower.contains("indexed")
410 || text_lower.contains("what files are indexed")
411 || text_lower.contains("how many symbols")
412 || text_lower.contains("when was index")
413 {
414 return Some((Intent::IndexStatus, 0.85));
415 }
416
417 None
418 }
419
420 fn classify_text_search_intent(text_lower: &str, text: &str) -> Option<(Intent, f32)> {
421 let is_predicate_query = Self::is_predicate_query(text_lower);
422
423 if text_lower.starts_with("grep")
424 || text_lower.starts_with("search for")
425 || text_lower.contains("grep for")
426 || text_lower.contains("grep ")
427 || text_lower.contains("look for")
428 || (text_lower.contains("search") && !text_lower.contains("code search"))
429 || text_lower.contains("todo")
430 || text_lower.contains("fixme")
431 || text_lower.contains("deprecated")
432 || text_lower.contains("copyright")
433 || text_lower.contains("hardcoded")
434 || text.contains('!')
435 || (!is_predicate_query && text_lower.contains("unsafe"))
436 || text_lower.contains(" pub ")
437 || text_lower.contains(" mut ")
438 || (!is_predicate_query && text_lower.contains("async"))
439 || text_lower.contains("unsafe blocks")
440 || text_lower.contains("impl blocks")
441 || text_lower.contains("import")
442 || text_lower.contains("use statement")
443 || text_lower.contains("require")
444 {
445 return Some((Intent::TextSearch, 0.8));
446 }
447
448 None
449 }
450
451 fn classify_symbol_query_intent(
452 text_lower: &str,
453 entities: &ExtractedEntities,
454 ) -> Option<(Intent, f32)> {
455 if text_lower.starts_with("find")
456 || text_lower.starts_with("show")
457 || text_lower.starts_with("list")
458 || text_lower.starts_with("where is")
459 || text_lower.starts_with("where are")
460 || text_lower.contains("function")
461 || text_lower.contains("method")
462 || text_lower.contains("class")
463 || text_lower.contains("struct")
464 || text_lower.contains("enum")
465 || text_lower.contains("trait")
466 || text_lower.contains("interface")
467 || text_lower.contains("module")
468 || text_lower.contains("constant")
469 || text_lower.contains("variable")
470 || text_lower.contains("public")
471 || text_lower.contains("private")
472 || text_lower.contains("defined")
473 {
474 return Some((Intent::SymbolQuery, 0.8));
475 }
476
477 if entities.kind.is_some() {
478 return Some((Intent::SymbolQuery, 0.85));
479 }
480
481 if !entities.symbols.is_empty() {
482 return Some((Intent::SymbolQuery, 0.7));
483 }
484
485 if !entities.languages.is_empty() {
486 return Some((Intent::SymbolQuery, 0.65));
487 }
488
489 None
490 }
491
492 fn is_predicate_query(text_lower: &str) -> bool {
493 text_lower.contains("functions")
494 || text_lower.contains("methods")
495 || text_lower.contains("function")
496 || text_lower.contains("method")
497 }
498
499 fn is_ambiguous(text_lower: &str) -> bool {
500 text_lower.split_whitespace().count() <= 2
501 }
502
503 fn create_response_with_latency(
505 &self,
506 command: String,
507 confidence: f32,
508 intent: Intent,
509 entities: &ExtractedEntities,
510 latency_ms: u64,
511 ) -> TranslationResponse {
512 if confidence >= self.config.execute_threshold {
513 TranslationResponse::Execute {
514 command,
515 confidence,
516 intent,
517 cached: false,
518 latency_ms,
519 }
520 } else if confidence >= self.config.confirm_threshold {
521 let prompt = format!(
522 "I'll run: {}\nConfidence: {:.0}%. Proceed? [y/N]",
523 command,
524 confidence * 100.0
525 );
526 TranslationResponse::Confirm {
527 command,
528 confidence,
529 prompt,
530 }
531 } else {
532 let options = Self::generate_disambiguation_options(entities);
534 TranslationResponse::Disambiguate {
535 options,
536 prompt: "I'm not sure what you mean. Did you want to:".to_string(),
537 }
538 }
539 }
540
541 #[allow(dead_code)]
543 fn create_response(
544 &self,
545 command: String,
546 confidence: f32,
547 intent: Intent,
548 entities: &ExtractedEntities,
549 ) -> TranslationResponse {
550 self.create_response_with_latency(command, confidence, intent, entities, 0)
551 }
552
553 fn generate_disambiguation_options(entities: &ExtractedEntities) -> Vec<DisambiguationOption> {
555 let mut options = Vec::new();
556
557 if let Some(symbol) = entities.primary_symbol() {
558 options.push(DisambiguationOption {
559 command: format!("sqry query \"{symbol}\""),
560 intent: Intent::SymbolQuery,
561 description: format!("Search for symbol \"{symbol}\""),
562 confidence: 0.5,
563 });
564 options.push(DisambiguationOption {
565 command: format!("sqry graph direct-callers \"{symbol}\""),
566 intent: Intent::FindCallers,
567 description: format!("Find callers of \"{symbol}\""),
568 confidence: 0.4,
569 });
570 } else {
571 options.push(DisambiguationOption {
572 command: "sqry query \"<symbol>\"".to_string(),
573 intent: Intent::SymbolQuery,
574 description: "Search for a specific symbol".to_string(),
575 confidence: 0.3,
576 });
577 }
578
579 options
580 }
581
582 fn handle_assembly_error(
584 error: AssemblerError,
585 entities: &ExtractedEntities,
586 ) -> TranslationResponse {
587 match error {
588 AssemblerError::MissingSymbol => {
589 let suggestions = if entities.languages.is_empty() {
590 vec![
591 "Specify what symbol or pattern you're looking for".to_string(),
592 "Example: find \"authenticate\" in rust".to_string(),
593 ]
594 } else {
595 vec![
596 format!(
597 "Try: find <symbol name> in {}",
598 entities.languages.join(", ")
599 ),
600 "Specify what you're looking for in quotes".to_string(),
601 ]
602 };
603 TranslationResponse::Reject {
604 reason: "Could not determine what to search for".to_string(),
605 suggestions,
606 }
607 }
608 AssemblerError::AmbiguousIntent => TranslationResponse::Disambiguate {
609 options: vec![
610 DisambiguationOption {
611 command: "sqry query \"<symbol>\"".to_string(),
612 intent: Intent::SymbolQuery,
613 description: "Search for symbols matching a pattern".to_string(),
614 confidence: 0.3,
615 },
616 DisambiguationOption {
617 command: "sqry graph direct-callers \"<symbol>\"".to_string(),
618 intent: Intent::FindCallers,
619 description: "Find callers of a function".to_string(),
620 confidence: 0.3,
621 },
622 ],
623 prompt: "Please clarify what you'd like to do:".to_string(),
624 },
625 AssemblerError::MissingTracePath => TranslationResponse::Reject {
626 reason: "Trace path requires both source and target symbols".to_string(),
627 suggestions: vec![
628 "Specify two symbols: trace path from X to Y".to_string(),
629 "Example: trace path from login to database".to_string(),
630 ],
631 },
632 AssemblerError::CommandTooLong { .. } => TranslationResponse::Reject {
633 reason: "Generated command is too long".to_string(),
634 suggestions: vec![
635 "Try a simpler query".to_string(),
636 "Reduce the number of filters".to_string(),
637 ],
638 },
639 AssemblerError::NoTemplate(intent_name) => TranslationResponse::Reject {
640 reason: format!("No template available for intent: {intent_name}"),
641 suggestions: vec![
642 "Try a different query type".to_string(),
643 "Supported queries: symbol search, callers, callees, trace path".to_string(),
644 ],
645 },
646 }
647 }
648
649 #[must_use]
651 pub fn translation_count(&self) -> u64 {
652 self.translations.load(Ordering::Relaxed)
653 }
654
655 #[must_use]
659 pub fn cache_stats(&self) -> Option<crate::cache::CacheStats> {
660 self.cache
661 .as_ref()
662 .map(super::cache::TranslationCache::stats)
663 }
664
665 #[must_use]
669 pub fn cache_hit_rate(&self) -> Option<f64> {
670 self.cache
671 .as_ref()
672 .map(super::cache::TranslationCache::hit_rate)
673 }
674
675 pub fn clear_cache(&self) {
679 if let Some(ref cache) = self.cache {
680 cache.clear();
681 }
682 }
683}
684
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh translator built from the default configuration.
    fn default_translator() -> Translator {
        Translator::load_default().unwrap()
    }

    /// Returns the command carried by an `Execute` or `Confirm` response, if any.
    fn runnable_command(response: TranslationResponse) -> Option<String> {
        match response {
            TranslationResponse::Execute { command, .. }
            | TranslationResponse::Confirm { command, .. } => Some(command),
            _ => None,
        }
    }

    /// Like `runnable_command`, but fails the test when no command was produced.
    fn expect_runnable_command(response: TranslationResponse) -> String {
        runnable_command(response).expect("Expected Execute or Confirm response")
    }

    #[test]
    fn test_translator_creation() {
        assert_eq!(default_translator().translation_count(), 0);
    }

    #[test]
    fn test_translate_simple_query() {
        let mut translator = default_translator();

        // Any outcome is acceptable except failing to find a search target.
        if let TranslationResponse::Reject { reason, .. } =
            translator.translate("find authentication functions")
        {
            assert!(!reason.contains("Could not determine"));
        }
        assert_eq!(translator.translation_count(), 1);
    }

    #[test]
    fn test_translate_with_language() {
        let response = default_translator().translate("find authentication in rust");

        if let Some(command) = runnable_command(response) {
            assert!(command.contains("--language rust"));
        }
    }

    #[test]
    fn test_translate_callers() {
        let response = default_translator().translate("who calls authenticate");

        if let TranslationResponse::Execute { intent, .. } = &response {
            assert_eq!(*intent, Intent::FindCallers);
        } else if let TranslationResponse::Confirm { command, .. } = &response {
            assert!(command.contains("graph direct-callers") || command.contains("authenticate"));
        }
    }

    #[test]
    fn test_custom_thresholds() {
        let mut translator = Translator::new(TranslatorConfig {
            execute_threshold: 0.99,
            confirm_threshold: 0.90,
            ..Default::default()
        })
        .unwrap();

        // Rule-based confidences never reach 0.99, so nothing auto-executes.
        let response = translator.translate("find foo");
        assert!(!matches!(response, TranslationResponse::Execute { .. }));
    }

    #[test]
    fn test_kind_only_query() {
        let command = expect_runnable_command(default_translator().translate("list all traits"));
        assert!(command.contains("kind:trait"));
    }

    #[test]
    fn test_snake_case_symbol() {
        let command =
            expect_runnable_command(default_translator().translate("find user_id variable"));
        assert!(command.contains("user_id"));
    }
}
786
#[cfg(test)]
mod predicate_translation_tests {
    use super::*;

    /// Translates `query` with a default-configured translator and returns the
    /// generated command. Fails the test unless the response is `Execute` or
    /// `Confirm`.
    fn translated_command(query: &str) -> String {
        let mut translator =
            Translator::new(TranslatorConfig::default()).expect("Translator init failed");

        match translator.translate(query) {
            TranslationResponse::Execute { command, .. }
            | TranslationResponse::Confirm { command, .. } => command,
            _ => panic!("should execute or confirm"),
        }
    }

    #[test]
    fn test_async_functions_translation() {
        let command = translated_command("find async functions");
        assert!(command.contains("async:true"));
    }

    #[test]
    fn test_unsafe_functions_translation() {
        let command = translated_command("find unsafe functions");
        assert!(command.contains("unsafe:true"));
    }

    #[test]
    fn test_public_async_functions_translation() {
        let command = translated_command("find public async functions");
        assert!(command.contains("visibility:public"));
        assert!(command.contains("async:true"));
    }
}
837}