1use std::collections::{HashMap, HashSet, VecDeque};
22use std::path::{Path, PathBuf};
23
24use serde::{Deserialize, Serialize};
25
26use crate::ast::extract::extract_file;
27use crate::callgraph::build_project_call_graph;
28use crate::cfg::get_cfg_context;
29use crate::error::TldrError;
30use crate::types::{FunctionInfo, Language, ProjectCallGraph};
31use crate::TldrResult;
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct FunctionContext {
36 pub name: String,
38 pub file: PathBuf,
40 pub line: u32,
42 pub signature: String,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub docstring: Option<String>,
47 pub calls: Vec<String>,
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub blocks: Option<usize>,
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub cyclomatic: Option<u32>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct RelevantContext {
60 pub entry_point: String,
62 pub depth: usize,
64 pub functions: Vec<FunctionContext>,
66}
67
68impl RelevantContext {
69 pub fn to_llm_string(&self) -> String {
73 let mut output = String::new();
74
75 output.push_str(&format!(
76 "# Code Context: {} (depth={})\n\n",
77 self.entry_point, self.depth
78 ));
79
80 output.push_str(&format!(
81 "## Summary\n- Entry point: `{}`\n- Functions included: {}\n\n",
82 self.entry_point,
83 self.functions.len()
84 ));
85
86 output.push_str("## Functions\n\n");
87
88 for func in &self.functions {
89 output.push_str(&format!(
90 "### {} ({}:{})\n\n",
91 func.name,
92 func.file.display(),
93 func.line
94 ));
95 output.push_str(&format!("```\n{}\n```\n\n", func.signature));
96
97 if let Some(ref doc) = func.docstring {
98 output.push_str(&format!("**Docstring:** {}\n\n", doc.trim()));
99 }
100
101 if !func.calls.is_empty() {
102 output.push_str(&format!("**Calls:** {}\n\n", func.calls.join(", ")));
103 }
104
105 if let (Some(blocks), Some(cyclomatic)) = (func.blocks, func.cyclomatic) {
106 output.push_str(&format!(
107 "**Complexity:** {} blocks, cyclomatic={}\n\n",
108 blocks, cyclomatic
109 ));
110 }
111
112 output.push_str("---\n\n");
113 }
114
115 output
116 }
117}
118
119pub fn get_relevant_context(
148 project: &Path,
149 entry_point: &str,
150 depth: usize,
151 language: Language,
152 include_docstrings: bool,
153 file_filter: Option<&Path>,
154) -> TldrResult<RelevantContext> {
155 let call_graph = build_project_call_graph(project, language, None, true)?;
157
158 let entry_location = find_function_in_graph(&call_graph, entry_point, project, file_filter)?;
160
161 let function_keys = bfs_collect_functions(&call_graph, &entry_location, depth);
163
164 let mut functions = Vec::new();
166 let mut seen_files: HashMap<PathBuf, crate::types::ModuleInfo> = HashMap::new();
167
168 for (file, func_name) in function_keys {
169 let full_path = if file.is_relative() {
172 project.join(&file)
173 } else {
174 file.clone()
175 };
176
177 let module_info = if let Some(info) = seen_files.get(&file) {
179 info.clone()
180 } else {
181 let info = extract_file(&full_path, Some(project)).unwrap_or_else(|_| {
182 crate::types::ModuleInfo {
184 file_path: file.clone(),
185 language,
186 docstring: None,
187 imports: vec![],
188 functions: vec![],
189 classes: vec![],
190 constants: vec![],
191 call_graph: Default::default(),
192 }
193 });
194 seen_files.insert(file.clone(), info.clone());
195 info
196 };
197
198 if let Some(func_info) = find_function_info(&module_info, &func_name) {
200 let func_context = build_function_context(
201 &full_path,
202 &func_name,
203 func_info,
204 &module_info,
205 project,
206 language,
207 include_docstrings,
208 );
209 functions.push(func_context);
210 }
211 }
212
213 Ok(RelevantContext {
214 entry_point: entry_point.to_string(),
215 depth,
216 functions,
217 })
218}
219
220fn find_function_in_graph(
225 call_graph: &ProjectCallGraph,
226 func_name: &str,
227 project: &Path,
228 file_filter: Option<&Path>,
229) -> TldrResult<(PathBuf, String)> {
230 let file_matches = |file: &Path| -> bool {
232 match file_filter {
233 None => true,
234 Some(filter) => file.ends_with(filter),
235 }
236 };
237
238 for edge in call_graph.edges() {
240 if (edge.src_func == func_name || edge.src_func.ends_with(&format!(".{}", func_name)))
241 && file_matches(&edge.src_file)
242 {
243 return Ok((edge.src_file.clone(), edge.src_func.clone()));
244 }
245 if (edge.dst_func == func_name || edge.dst_func.ends_with(&format!(".{}", func_name)))
246 && file_matches(&edge.dst_file)
247 {
248 return Ok((edge.dst_file.clone(), edge.dst_func.clone()));
249 }
250 }
251
252 if let Some(location) = scan_project_for_function(project, func_name, file_filter)? {
255 return Ok(location);
256 }
257
258 let suggestions = collect_similar_function_names(call_graph, func_name);
260
261 Err(TldrError::FunctionNotFound {
262 name: func_name.to_string(),
263 file: None,
264 suggestions,
265 })
266}
267
268fn scan_project_for_function(
273 project: &Path,
274 func_name: &str,
275 file_filter: Option<&Path>,
276) -> TldrResult<Option<(PathBuf, String)>> {
277 use crate::fs::tree::{collect_files, get_file_tree};
278 use crate::types::IgnoreSpec;
279
280 let tree = get_file_tree(project, None, true, Some(&IgnoreSpec::default()))?;
282 let files = collect_files(&tree, project);
283
284 for file_path in files {
285 if let Some(filter) = file_filter {
287 let relative = file_path.strip_prefix(project).unwrap_or(&file_path);
289 if !relative.ends_with(filter) {
290 continue;
291 }
292 }
293
294 if let Ok(module_info) = extract_file(&file_path, Some(project)) {
295 for func in &module_info.functions {
297 if func.name == func_name {
298 return Ok(Some((file_path, func.name.clone())));
299 }
300 }
301 for class in &module_info.classes {
303 for method in &class.methods {
304 if method.name == func_name {
305 let full_name = format!("{}.{}", class.name, method.name);
306 return Ok(Some((file_path, full_name)));
307 }
308 }
309 }
310 }
311 }
312
313 Ok(None)
314}
315
316fn collect_similar_function_names(call_graph: &ProjectCallGraph, target: &str) -> Vec<String> {
318 let mut seen = HashSet::new();
319 let mut suggestions = Vec::new();
320 let target_lower = target.to_lowercase();
321
322 for edge in call_graph.edges() {
323 for func in [&edge.src_func, &edge.dst_func] {
324 if !seen.contains(func) {
325 seen.insert(func.clone());
326 let func_lower = func.to_lowercase();
327 if func_lower.contains(&target_lower) || target_lower.contains(&func_lower) {
329 suggestions.push(func.clone());
330 }
331 }
332 }
333 }
334
335 suggestions.sort();
336 suggestions.truncate(5);
337 suggestions
338}
339
340fn bfs_collect_functions(
342 call_graph: &ProjectCallGraph,
343 entry: &(PathBuf, String),
344 max_depth: usize,
345) -> Vec<(PathBuf, String)> {
346 let mut result = Vec::new();
347 let mut visited: HashSet<(PathBuf, String)> = HashSet::new();
348 let mut queue: VecDeque<((PathBuf, String), usize)> = VecDeque::new();
349
350 let forward_graph = build_forward_graph(call_graph);
352
353 queue.push_back((entry.clone(), 0));
355 visited.insert(entry.clone());
356
357 while let Some(((file, func), current_depth)) = queue.pop_front() {
358 result.push((file.clone(), func.clone()));
359
360 if current_depth >= max_depth {
362 continue;
363 }
364
365 let key = (file.clone(), func.clone());
367 if let Some(callees) = forward_graph.get(&key) {
368 for callee in callees {
369 if !visited.contains(callee) {
370 visited.insert(callee.clone());
371 queue.push_back((callee.clone(), current_depth + 1));
372 }
373 }
374 }
375 }
376
377 result
378}
379
380fn build_forward_graph(
382 call_graph: &ProjectCallGraph,
383) -> HashMap<(PathBuf, String), Vec<(PathBuf, String)>> {
384 let mut forward: HashMap<(PathBuf, String), Vec<(PathBuf, String)>> = HashMap::new();
385
386 for edge in call_graph.edges() {
387 let src_key = (edge.src_file.clone(), edge.src_func.clone());
388 let dst_key = (edge.dst_file.clone(), edge.dst_func.clone());
389
390 forward.entry(src_key).or_default().push(dst_key);
391 }
392
393 forward
394}
395
396fn find_function_info<'a>(
398 module_info: &'a crate::types::ModuleInfo,
399 func_name: &str,
400) -> Option<&'a FunctionInfo> {
401 for func in &module_info.functions {
403 if func.name == func_name {
404 return Some(func);
405 }
406 }
407
408 if let Some(dot_idx) = func_name.find('.') {
410 let class_name = &func_name[..dot_idx];
411 let method_name = &func_name[dot_idx + 1..];
412
413 for class in &module_info.classes {
414 if class.name == class_name {
415 for method in &class.methods {
416 if method.name == method_name {
417 return Some(method);
418 }
419 }
420 }
421 }
422 }
423
424 None
425}
426
427fn build_function_context(
429 file: &Path,
430 func_name: &str,
431 func_info: &FunctionInfo,
432 module_info: &crate::types::ModuleInfo,
433 project: &Path,
434 language: Language,
435 include_docstrings: bool,
436) -> FunctionContext {
437 let signature = build_signature(func_info, language);
439
440 let calls = module_info
442 .call_graph
443 .calls
444 .get(&func_info.name)
445 .cloned()
446 .unwrap_or_default();
447
448 let (blocks, cyclomatic) = get_cfg_metrics(file, func_name, language);
450
451 let relative_file = file
453 .strip_prefix(project)
454 .map(|p| p.to_path_buf())
455 .unwrap_or_else(|_| file.to_path_buf());
456
457 FunctionContext {
458 name: func_name.to_string(),
459 file: relative_file,
460 line: func_info.line_number,
461 signature,
462 docstring: if include_docstrings {
463 func_info.docstring.clone()
464 } else {
465 None
466 },
467 calls,
468 blocks,
469 cyclomatic,
470 }
471}
472
473fn build_signature(func_info: &FunctionInfo, language: Language) -> String {
475 let params = func_info.params.join(", ");
476
477 let return_type = func_info
478 .return_type
479 .as_ref()
480 .map(|t| format!(" -> {}", t))
481 .unwrap_or_default();
482
483 let async_prefix = if func_info.is_async { "async " } else { "" };
484
485 match language {
486 Language::Python => {
487 format!(
488 "{}def {}({}){}",
489 async_prefix, func_info.name, params, return_type
490 )
491 }
492 Language::TypeScript | Language::JavaScript => {
493 format!(
494 "{}function {}({}){}",
495 async_prefix, func_info.name, params, return_type
496 )
497 }
498 Language::Go => {
499 format!("func {}({}){}", func_info.name, params, return_type)
500 }
501 Language::Rust => {
502 format!(
503 "{}fn {}({}){}",
504 async_prefix, func_info.name, params, return_type
505 )
506 }
507 _ => {
508 format!("{}({}){}", func_info.name, params, return_type)
509 }
510 }
511}
512
513fn get_cfg_metrics(
515 file: &Path,
516 func_name: &str,
517 language: Language,
518) -> (Option<usize>, Option<u32>) {
519 let lookup_name = if let Some(dot_idx) = func_name.rfind('.') {
521 &func_name[dot_idx + 1..]
522 } else {
523 func_name
524 };
525
526 match get_cfg_context(file.to_str().unwrap_or(""), lookup_name, language) {
527 Ok(cfg) => (Some(cfg.blocks.len()), Some(cfg.cyclomatic_complexity)),
528 Err(_) => (None, None),
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-file project: `main` calls `do_work` from another file.
    ///
    /// The call graph records project-relative paths. This test fails with an
    /// empty context if those paths are handed to `extract_file()` unresolved.
    #[test]
    fn test_context_resolves_relative_paths_from_callgraph() {
        use std::fs;
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let project = temp_dir.path();

        let main_py = r#"from helper import do_work

def main():
    """Entry point."""
    result = do_work(42)
    return result
"#;

        let helper_py = r#"def do_work(x):
    """Do some work."""
    return internal_calc(x) + 1

def internal_calc(x):
    """Internal calculation."""
    return x * 2
"#;

        fs::write(project.join("main.py"), main_py).unwrap();
        fs::write(project.join("helper.py"), helper_py).unwrap();

        let result = get_relevant_context(project, "main", 1, Language::Python, true, None);

        assert!(
            result.is_ok(),
            "get_relevant_context failed: {:?}",
            result.err()
        );
        let ctx = result.unwrap();

        // An empty result means no function could be located in its parsed file.
        assert!(
            !ctx.functions.is_empty(),
            "Expected non-empty functions in context, got 0. \
             This indicates extract_file() failed to resolve relative paths from the call graph."
        );

        let func_names: Vec<&str> = ctx.functions.iter().map(|f| f.name.as_str()).collect();

        assert!(
            func_names.contains(&"main"),
            "Expected 'main' in context functions, got: {:?}",
            func_names
        );

        // depth=1 must pull in the direct (cross-file) callee.
        assert!(
            func_names.contains(&"do_work"),
            "Expected callee 'do_work' in context at depth=1, got: {:?}",
            func_names
        );
    }

    /// Single-file project where the entry point calls a helper in the same file.
    /// Only checks that a non-empty context containing the entry point is returned.
    #[test]
    fn test_context_intra_file_calls() {
        use std::fs;
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let project = temp_dir.path();

        let main_py = r#"def entry():
    """Entry function."""
    return helper(10)

def helper(n):
    """Helper function."""
    return n + 1
"#;

        fs::write(project.join("main.py"), main_py).unwrap();

        let result = get_relevant_context(project, "entry", 1, Language::Python, true, None);

        assert!(
            result.is_ok(),
            "get_relevant_context failed: {:?}",
            result.err()
        );
        let ctx = result.unwrap();

        assert!(
            !ctx.functions.is_empty(),
            "Expected non-empty functions in context"
        );

        let func_names: Vec<&str> = ctx.functions.iter().map(|f| f.name.as_str()).collect();
        assert!(
            func_names.contains(&"entry"),
            "Expected 'entry' in context, got: {:?}",
            func_names
        );
    }

    /// Markdown rendering mentions the function names, the docstring and the depth.
    #[test]
    fn test_relevant_context_to_llm_string() {
        let ctx = RelevantContext {
            entry_point: "main".to_string(),
            depth: 1,
            functions: vec![
                FunctionContext {
                    name: "main".to_string(),
                    file: PathBuf::from("src/main.py"),
                    line: 10,
                    signature: "def main()".to_string(),
                    docstring: Some("Entry point".to_string()),
                    calls: vec!["helper".to_string()],
                    blocks: Some(3),
                    cyclomatic: Some(2),
                },
                FunctionContext {
                    name: "helper".to_string(),
                    file: PathBuf::from("src/utils.py"),
                    line: 5,
                    signature: "def helper(x: int) -> str".to_string(),
                    docstring: None,
                    calls: vec![],
                    blocks: Some(1),
                    cyclomatic: Some(1),
                },
            ],
        };

        let output = ctx.to_llm_string();
        assert!(output.contains("main"));
        assert!(output.contains("helper"));
        assert!(output.contains("Entry point"));
        assert!(output.contains("depth=1"));
    }

    /// Python signature with typed parameters and a return annotation.
    #[test]
    fn test_build_signature_python() {
        let func = FunctionInfo {
            name: "process".to_string(),
            params: vec!["x: int".to_string(), "y: str".to_string()],
            return_type: Some("bool".to_string()),
            docstring: None,
            is_method: false,
            is_async: false,
            decorators: vec![],
            line_number: 1,
        };

        let sig = build_signature(&func, Language::Python);
        assert_eq!(sig, "def process(x: int, y: str) -> bool");
    }

    /// Async Python functions get the `async ` prefix.
    #[test]
    fn test_build_signature_async() {
        let func = FunctionInfo {
            name: "fetch".to_string(),
            params: vec!["url: str".to_string()],
            return_type: Some("Response".to_string()),
            docstring: None,
            is_method: false,
            is_async: true,
            decorators: vec![],
            line_number: 1,
        };

        let sig = build_signature(&func, Language::Python);
        assert_eq!(sig, "async def fetch(url: str) -> Response");
    }

    /// With no edges, the BFS still returns the entry point itself.
    #[test]
    fn test_bfs_collect_empty_graph() {
        let graph = ProjectCallGraph::new();
        let entry = (PathBuf::from("main.py"), "main".to_string());
        let result = bfs_collect_functions(&graph, &entry, 5);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1, "main");
    }

    /// Two files both define `render`.
    ///
    /// Without a filter some `render` is resolved. With `file_filter`, the
    /// requested file's `render` is chosen and only its own callee is
    /// collected.
    #[test]
    fn test_file_filter_disambiguates_same_function_name() {
        use std::fs;
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let project = temp_dir.path();

        let shortcuts_py = r#"def render(request, template_name):
    """Shortcut render function."""
    return load_template(template_name)

def load_template(name):
    """Load a template by name."""
    return name
"#;

        let backends_py = r#"def render(template, context):
    """Backend render function."""
    return compile_template(template)

def compile_template(template):
    """Compile a template."""
    return template
"#;

        fs::create_dir_all(project.join("django")).unwrap();
        fs::write(project.join("django/shortcuts.py"), shortcuts_py).unwrap();
        fs::create_dir_all(project.join("django/template/backends")).unwrap();
        fs::write(
            project.join("django/template/backends/django.py"),
            backends_py,
        )
        .unwrap();

        // No filter: either `render` is acceptable; only non-emptiness is checked.
        let result_any = get_relevant_context(
            project,
            "render",
            1,
            Language::Python,
            false,
            None,
        );
        assert!(
            result_any.is_ok(),
            "get_relevant_context without filter failed: {:?}",
            result_any.err()
        );
        let ctx_any = result_any.unwrap();
        assert!(
            !ctx_any.functions.is_empty(),
            "Expected non-empty functions without filter"
        );

        // Filter to django/shortcuts.py.
        let result_shortcuts = get_relevant_context(
            project,
            "render",
            1,
            Language::Python,
            false,
            Some(Path::new("django/shortcuts.py")),
        );
        assert!(
            result_shortcuts.is_ok(),
            "get_relevant_context with shortcuts filter failed: {:?}",
            result_shortcuts.err()
        );
        let ctx_shortcuts = result_shortcuts.unwrap();
        assert!(
            !ctx_shortcuts.functions.is_empty(),
            "Expected non-empty functions with shortcuts filter"
        );

        // The BFS emits the entry point first.
        let entry_func = &ctx_shortcuts.functions[0];
        assert_eq!(entry_func.name, "render");
        assert!(
            entry_func.file.ends_with("django/shortcuts.py"),
            "Expected render from django/shortcuts.py, got: {}",
            entry_func.file.display()
        );

        let callee_names: Vec<&str> = ctx_shortcuts
            .functions
            .iter()
            .map(|f| f.name.as_str())
            .collect();
        assert!(
            callee_names.contains(&"load_template"),
            "Expected callee 'load_template' from shortcuts, got: {:?}",
            callee_names
        );
        // The other file's callee must not leak in.
        assert!(
            !callee_names.contains(&"compile_template"),
            "Should not contain 'compile_template' from backends when filtering to shortcuts"
        );

        // Filter to the backends file.
        let result_backends = get_relevant_context(
            project,
            "render",
            1,
            Language::Python,
            false,
            Some(Path::new("django/template/backends/django.py")),
        );
        assert!(
            result_backends.is_ok(),
            "get_relevant_context with backends filter failed: {:?}",
            result_backends.err()
        );
        let ctx_backends = result_backends.unwrap();
        let backend_entry = &ctx_backends.functions[0];
        assert_eq!(backend_entry.name, "render");
        assert!(
            backend_entry
                .file
                .ends_with("django/template/backends/django.py"),
            "Expected render from backends/django.py, got: {}",
            backend_entry.file.display()
        );

        let backend_names: Vec<&str> = ctx_backends
            .functions
            .iter()
            .map(|f| f.name.as_str())
            .collect();
        assert!(
            backend_names.contains(&"compile_template"),
            "Expected callee 'compile_template' from backends, got: {:?}",
            backend_names
        );
        assert!(
            !backend_names.contains(&"load_template"),
            "Should not contain 'load_template' from shortcuts when filtering to backends"
        );
    }

    /// A filter naming a file that does not exist yields an error rather than
    /// falling back to a same-named function elsewhere.
    #[test]
    fn test_file_filter_nonexistent_file_returns_error() {
        use std::fs;
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let project = temp_dir.path();

        let main_py = r#"def render():
    """A render function."""
    pass
"#;
        fs::write(project.join("main.py"), main_py).unwrap();

        let result = get_relevant_context(
            project,
            "render",
            0,
            Language::Python,
            false,
            Some(Path::new("nonexistent.py")),
        );

        assert!(
            result.is_err(),
            "Expected FunctionNotFound error when filtering to nonexistent file"
        );
    }
}