1use std::collections::{HashMap, HashSet};
15use std::path::{Path, PathBuf};
16
17use rayon::prelude::*;
18use serde::{Deserialize, Serialize};
19
20use super::bm25::{Bm25Index, Bm25Result};
21use super::text::{self, SearchMatch};
22use crate::ast::parser::parse_file;
23use crate::types::{CodeStructure, DefinitionInfo, Language};
24use crate::TldrResult;
25
/// Strategy used to produce the raw hits that are subsequently enriched
/// with structural information.
#[derive(Debug, Clone, Default)]
pub enum SearchMode {
    /// Rank documents with BM25 only. This is the default mode.
    #[default]
    Bm25,

    /// Search with a regular expression; the payload is the pattern.
    Regex(String),

    /// Combine a BM25 query with a regex pattern.
    Hybrid {
        /// Free-text query handed to the BM25 index.
        query: String,
        /// Pattern handed to the regex search.
        pattern: String,
    },
}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct EnrichedResult {
61 pub name: String,
63 pub kind: String,
65 pub file: PathBuf,
67 pub line_range: (u32, u32),
69 pub signature: String,
71 pub callers: Vec<String>,
73 pub callees: Vec<String>,
75 pub score: f64,
77 pub matched_terms: Vec<String>,
79 #[serde(default, skip_serializing_if = "String::is_empty")]
81 pub preview: String,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct EnrichedSearchReport {
87 pub query: String,
89 pub results: Vec<EnrichedResult>,
91 pub total_files_searched: usize,
93 pub search_mode: String,
95}
96
/// Internal record describing one definition found in a source file.
#[derive(Debug, Clone)]
struct StructureEntry {
    /// Identifier of the definition.
    name: String,
    /// Kind label such as `"function"`, `"method"`, `"class"` or `"struct"`.
    kind: String,
    /// First line covered by the definition.
    line_start: u32,
    /// Last line covered by the definition.
    line_end: u32,
    /// Signature line of the definition.
    signature: String,
    /// Leading lines of the definition; may be empty.
    preview: String,
}
109
110#[derive(Debug, Clone)]
112pub struct EnrichedSearchOptions {
113 pub top_k: usize,
115 pub include_callgraph: bool,
118 pub search_mode: SearchMode,
120}
121
122impl Default for EnrichedSearchOptions {
123 fn default() -> Self {
124 Self {
125 top_k: 10,
126 include_callgraph: true,
127 search_mode: SearchMode::default(),
128 }
129 }
130}
131
/// Name-keyed view of a call graph in both directions.
#[derive(Debug, Clone)]
pub struct CallGraphLookup {
    /// Caller name -> names of the functions it calls.
    pub forward: HashMap<String, Vec<String>>,
    /// Callee name -> names of the functions that call it.
    pub reverse: HashMap<String, Vec<String>>,
}
140
141#[derive(Debug, Clone, Deserialize)]
145struct WarmCallEdge {
146 #[allow(dead_code)]
147 from_file: PathBuf,
148 from_func: String,
149 #[allow(dead_code)]
150 to_file: PathBuf,
151 to_func: String,
152}
153
154#[derive(Debug, Clone, Deserialize)]
156struct WarmCallGraphCache {
157 edges: Vec<WarmCallEdge>,
158 #[allow(dead_code)]
159 languages: Vec<String>,
160 #[allow(dead_code)]
161 timestamp: i64,
162}
163
164pub fn read_callgraph_cache(cache_path: &Path) -> TldrResult<CallGraphLookup> {
169 let content = std::fs::read_to_string(cache_path).map_err(crate::TldrError::IoError)?;
170 let cache: WarmCallGraphCache = serde_json::from_str(&content).map_err(|e| {
171 crate::TldrError::SerializationError(format!("Failed to parse call graph cache: {}", e))
172 })?;
173
174 let mut forward: HashMap<String, Vec<String>> = HashMap::new();
175 let mut reverse: HashMap<String, Vec<String>> = HashMap::new();
176
177 for edge in &cache.edges {
178 forward
179 .entry(edge.from_func.clone())
180 .or_default()
181 .push(edge.to_func.clone());
182 reverse
183 .entry(edge.to_func.clone())
184 .or_default()
185 .push(edge.from_func.clone());
186 }
187
188 Ok(CallGraphLookup { forward, reverse })
189}
190
191#[derive(Debug, Clone)]
197pub struct StructureLookup {
198 pub by_file: HashMap<PathBuf, Vec<DefinitionInfo>>,
200}
201
202#[derive(Debug, Clone, Serialize, Deserialize)]
204struct StructureCacheEnvelope {
205 files: Vec<CachedFileEntry>,
206 timestamp: i64,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
211struct CachedFileEntry {
212 path: PathBuf,
213 definitions: Vec<DefinitionInfo>,
214}
215
216pub fn write_structure_cache(structure: &CodeStructure, cache_path: &Path) -> TldrResult<()> {
221 let envelope = StructureCacheEnvelope {
222 files: structure
223 .files
224 .iter()
225 .map(|f| CachedFileEntry {
226 path: f.path.clone(),
227 definitions: f.definitions.clone(),
228 })
229 .collect(),
230 timestamp: std::time::SystemTime::now()
231 .duration_since(std::time::UNIX_EPOCH)
232 .unwrap_or_default()
233 .as_secs() as i64,
234 };
235 let json = serde_json::to_string_pretty(&envelope).map_err(|e| {
236 crate::TldrError::SerializationError(format!("Failed to serialize structure cache: {}", e))
237 })?;
238 if let Some(parent) = cache_path.parent() {
239 std::fs::create_dir_all(parent).map_err(crate::TldrError::IoError)?;
240 }
241 std::fs::write(cache_path, json).map_err(crate::TldrError::IoError)?;
242 Ok(())
243}
244
245pub fn read_structure_cache(cache_path: &Path) -> TldrResult<StructureLookup> {
249 let content = std::fs::read_to_string(cache_path).map_err(crate::TldrError::IoError)?;
250 let envelope: StructureCacheEnvelope = serde_json::from_str(&content).map_err(|e| {
251 crate::TldrError::SerializationError(format!("Failed to parse structure cache: {}", e))
252 })?;
253 let mut by_file = HashMap::new();
254 for entry in envelope.files {
255 by_file.insert(entry.path, entry.definitions);
256 }
257 Ok(StructureLookup { by_file })
258}
259
260fn regex_matches_to_bm25_results(matches: &[SearchMatch]) -> Vec<Bm25Result> {
267 let mut file_counts: HashMap<PathBuf, usize> = HashMap::new();
269 for m in matches {
270 *file_counts.entry(m.file.clone()).or_insert(0) += 1;
271 }
272
273 matches
274 .iter()
275 .map(|m| {
276 let file_match_count = file_counts[&m.file] as f64;
277 Bm25Result {
278 file_path: m.file.clone(),
279 score: file_match_count,
280 line_start: m.line,
281 line_end: m.line,
282 snippet: m.content.clone(),
283 matched_terms: vec![], }
285 })
286 .collect()
287}
288
289fn do_regex_search(
294 pattern: &str,
295 root: &Path,
296 language: Language,
297 top_k: usize,
298) -> crate::TldrResult<(Vec<SearchMatch>, usize)> {
299 let extensions: HashSet<String> = language
300 .extensions()
301 .iter()
302 .map(|e| e.to_string())
303 .collect();
304 let raw_limit = (top_k * 10).max(200);
305 let matches = text::search(
306 pattern,
307 root,
308 Some(&extensions),
309 0, raw_limit,
311 usize::MAX, None, )?;
314 let unique_files: HashSet<&PathBuf> = matches.iter().map(|m| &m.file).collect();
318 let total_files = walkdir::WalkDir::new(root)
320 .follow_links(false)
321 .into_iter()
322 .filter_map(|e| e.ok())
323 .filter(|e| !e.file_type().is_dir())
324 .filter(|e| {
325 e.path()
326 .extension()
327 .and_then(|ext| ext.to_str())
328 .map(|ext| {
329 let with_dot = format!(".{}", ext);
330 extensions.contains(&with_dot) || extensions.contains(ext)
331 })
332 .unwrap_or(false)
333 })
334 .count();
335 let total = total_files.max(unique_files.len());
337 Ok((matches, total))
338}
339
340pub fn enriched_search(
359 query: &str,
360 root: &Path,
361 language: Language,
362 options: EnrichedSearchOptions,
363) -> TldrResult<EnrichedSearchReport> {
364 search_with_inner(query, root, language, options, None, None, None)
365}
366
367pub fn enriched_search_with_callgraph_cache(
386 query: &str,
387 root: &Path,
388 language: Language,
389 options: EnrichedSearchOptions,
390 cache_path: &Path,
391) -> TldrResult<EnrichedSearchReport> {
392 search_with_inner(query, root, language, options, None, None, Some(cache_path))
393}
394
395pub fn enriched_search_with_index(
413 query: &str,
414 root: &Path,
415 language: Language,
416 options: EnrichedSearchOptions,
417 index: &Bm25Index,
418) -> TldrResult<EnrichedSearchReport> {
419 search_with_inner(query, root, language, options, Some(index), None, None)
420}
421
422fn process_file_results(
431 rel_path: &PathBuf,
432 results: &[&Bm25Result],
433 root: &Path,
434 language: Language,
435 cached_defs: Option<&[DefinitionInfo]>,
436) -> Vec<((PathBuf, String), EnrichedResult)> {
437 let abs_path = root.join(rel_path);
438
439 let entries = if let Some(defs) = cached_defs {
441 defs.iter()
442 .map(|d| StructureEntry {
443 name: d.name.clone(),
444 kind: d.kind.clone(),
445 line_start: d.line_start,
446 line_end: d.line_end,
447 signature: d.signature.clone(),
448 preview: String::new(), })
450 .collect()
451 } else {
452 match extract_structure_entries(&abs_path, language) {
453 Ok(entries) => entries,
454 Err(_) => {
455 let mut local_dedup: HashMap<(PathBuf, String), EnrichedResult> = HashMap::new();
458 for result in results {
459 let key = (rel_path.clone(), rel_path.display().to_string());
460 let entry = local_dedup.entry(key).or_insert_with(|| EnrichedResult {
461 name: rel_path.display().to_string(),
462 kind: "module".to_string(),
463 file: rel_path.clone(),
464 line_range: (result.line_start, result.line_end),
465 signature: result.snippet.lines().next().unwrap_or("").to_string(),
466 callers: Vec::new(),
467 callees: Vec::new(),
468 score: result.score,
469 matched_terms: result.matched_terms.clone(),
470 preview: String::new(),
471 });
472 if result.score > entry.score {
473 entry.score = result.score;
474 }
475 }
476 return local_dedup.into_iter().collect();
477 }
478 }
479 };
480
481 let mut local_dedup: HashMap<(PathBuf, String), EnrichedResult> = HashMap::new();
483
484 for result in results {
488 let enclosing = (result.line_start..=result.line_end)
492 .filter_map(|line| find_enclosing_entry(&entries, line))
493 .min_by_key(|e| e.line_end - e.line_start);
494
495 match enclosing {
496 Some(entry) => {
497 let key = (rel_path.clone(), entry.name.clone());
498 let enriched = local_dedup.entry(key).or_insert_with(|| EnrichedResult {
499 name: entry.name.clone(),
500 kind: entry.kind.clone(),
501 file: rel_path.clone(),
502 line_range: (entry.line_start, entry.line_end),
503 signature: entry.signature.clone(),
504 callers: Vec::new(),
505 callees: Vec::new(),
506 score: result.score,
507 matched_terms: result.matched_terms.clone(),
508 preview: entry.preview.clone(),
509 });
510 if result.score > enriched.score {
512 enriched.score = result.score;
513 }
514 for term in &result.matched_terms {
515 if !enriched.matched_terms.contains(term) {
516 enriched.matched_terms.push(term.clone());
517 }
518 }
519 }
520 None => {
521 let sig = result
524 .snippet
525 .lines()
526 .find(|l| {
527 let t = l.trim();
528 !t.is_empty()
529 && !t.starts_with("///")
530 && !t.starts_with("//!")
531 && !t.starts_with("//")
532 && !t.starts_with("/*")
533 && !t.starts_with("*")
534 })
535 .or_else(|| result.snippet.lines().next())
536 .unwrap_or("")
537 .trim()
538 .to_string();
539 let key = (
540 rel_path.clone(),
541 format!("{}:{}", rel_path.display(), result.line_start),
542 );
543 local_dedup.entry(key).or_insert_with(|| EnrichedResult {
544 name: rel_path
545 .file_stem()
546 .map(|s| s.to_string_lossy().to_string())
547 .unwrap_or_else(|| rel_path.display().to_string()),
548 kind: "module".to_string(),
549 file: rel_path.clone(),
550 line_range: (result.line_start, result.line_end),
551 signature: sig,
552 callers: Vec::new(),
553 callees: Vec::new(),
554 score: result.score,
555 matched_terms: result.matched_terms.clone(),
556 preview: result.snippet.clone(),
557 });
558 }
559 }
560 }
561
562 local_dedup.into_iter().collect()
563}
564
565fn enrich_and_deduplicate(
569 raw_results: &[Bm25Result],
570 root: &Path,
571 language: Language,
572) -> Vec<EnrichedResult> {
573 let mut by_file: HashMap<PathBuf, Vec<&Bm25Result>> = HashMap::new();
575 for result in raw_results {
576 by_file
577 .entry(result.file_path.clone())
578 .or_default()
579 .push(result);
580 }
581
582 let by_file_vec: Vec<(&PathBuf, &Vec<&Bm25Result>)> = by_file.iter().collect();
584
585 let file_results: Vec<Vec<((PathBuf, String), EnrichedResult)>> = if by_file_vec.len() >= 4 {
587 by_file_vec
588 .par_iter()
589 .map(|(rel_path, results)| {
590 process_file_results(rel_path, results, root, language, None)
591 })
592 .collect()
593 } else {
594 by_file_vec
595 .iter()
596 .map(|(rel_path, results)| {
597 process_file_results(rel_path, results, root, language, None)
598 })
599 .collect()
600 };
601
602 let mut dedup: HashMap<(PathBuf, String), EnrichedResult> = HashMap::new();
604 for file_entries in file_results {
605 for (key, entry) in file_entries {
606 let existing = dedup.entry(key).or_insert(entry.clone());
607 if entry.score > existing.score {
608 existing.score = entry.score;
609 }
610 for term in &entry.matched_terms {
611 if !existing.matched_terms.contains(term) {
612 existing.matched_terms.push(term.clone());
613 }
614 }
615 }
616 }
617
618 dedup.into_values().collect()
619}
620
621fn enrich_and_deduplicate_with_cache(
627 raw_results: &[Bm25Result],
628 root: &Path,
629 language: Language,
630 structure_lookup: &StructureLookup,
631) -> Vec<EnrichedResult> {
632 let mut by_file: HashMap<PathBuf, Vec<&Bm25Result>> = HashMap::new();
634 for result in raw_results {
635 by_file
636 .entry(result.file_path.clone())
637 .or_default()
638 .push(result);
639 }
640
641 let by_file_vec: Vec<(&PathBuf, &Vec<&Bm25Result>)> = by_file.iter().collect();
643
644 let file_results: Vec<Vec<((PathBuf, String), EnrichedResult)>> = if by_file_vec.len() >= 4 {
647 by_file_vec
648 .par_iter()
649 .map(|(rel_path, results)| {
650 let cached = structure_lookup
651 .by_file
652 .get(*rel_path)
653 .map(|v| v.as_slice());
654 process_file_results(rel_path, results, root, language, cached)
655 })
656 .collect()
657 } else {
658 by_file_vec
659 .iter()
660 .map(|(rel_path, results)| {
661 let cached = structure_lookup
662 .by_file
663 .get(*rel_path)
664 .map(|v| v.as_slice());
665 process_file_results(rel_path, results, root, language, cached)
666 })
667 .collect()
668 };
669
670 let mut dedup: HashMap<(PathBuf, String), EnrichedResult> = HashMap::new();
672 for file_entries in file_results {
673 for (key, entry) in file_entries {
674 let existing = dedup.entry(key).or_insert(entry.clone());
675 if entry.score > existing.score {
676 existing.score = entry.score;
677 }
678 for term in &entry.matched_terms {
679 if !existing.matched_terms.contains(term) {
680 existing.matched_terms.push(term.clone());
681 }
682 }
683 }
684 }
685
686 dedup.into_values().collect()
687}
688
689pub fn enriched_search_with_structure_cache(
706 query: &str,
707 root: &Path,
708 language: Language,
709 options: EnrichedSearchOptions,
710 structure_lookup: &StructureLookup,
711) -> TldrResult<EnrichedSearchReport> {
712 search_with_inner(
713 query,
714 root,
715 language,
716 options,
717 None,
718 Some(structure_lookup),
719 None,
720 )
721}
722
723pub fn search_with_inner(
746 query: &str,
747 root: &Path,
748 language: Language,
749 options: EnrichedSearchOptions,
750 bm25_index: Option<&Bm25Index>,
751 structure_cache: Option<&StructureLookup>,
752 callgraph_cache_path: Option<&Path>,
753) -> TldrResult<EnrichedSearchReport> {
754 let top_k = options.top_k;
755 let mode_prefix;
756
757 let (raw_results, total_files) = match &options.search_mode {
759 SearchMode::Bm25 => {
760 mode_prefix = "bm25";
761 match bm25_index {
762 Some(idx) => {
763 let total = idx.document_count();
765 if idx.is_empty() {
766 return Ok(EnrichedSearchReport {
767 query: query.to_string(),
768 results: Vec::new(),
769 total_files_searched: 0,
770 search_mode: if structure_cache.is_some() {
771 "bm25+cached-structure".to_string()
772 } else {
773 "bm25+structure".to_string()
774 },
775 });
776 }
777 let raw_limit = (top_k * 5).max(50);
778 (idx.search(query, raw_limit), total)
779 }
780 None => {
781 let index = Bm25Index::from_project(root, language)?;
783 let total = index.document_count();
784 if index.is_empty() {
785 return Ok(EnrichedSearchReport {
786 query: query.to_string(),
787 results: Vec::new(),
788 total_files_searched: 0,
789 search_mode: if structure_cache.is_some() {
790 "bm25+cached-structure".to_string()
791 } else {
792 "bm25+structure".to_string()
793 },
794 });
795 }
796 let raw_limit = (top_k * 5).max(50);
797 (index.search(query, raw_limit), total)
798 }
799 }
800 }
801 SearchMode::Regex(pattern) => {
802 mode_prefix = "regex";
803 let (matches, total) = do_regex_search(pattern, root, language, top_k)?;
804 if matches.is_empty() {
805 return Ok(EnrichedSearchReport {
806 query: pattern.clone(),
807 results: Vec::new(),
808 total_files_searched: total,
809 search_mode: if structure_cache.is_some() {
810 "regex+cached-structure".to_string()
811 } else {
812 "regex+structure".to_string()
813 },
814 });
815 }
816 (regex_matches_to_bm25_results(&matches), total)
817 }
818 SearchMode::Hybrid {
819 query: hybrid_query,
820 pattern,
821 } => {
822 mode_prefix = "hybrid(bm25+regex)";
823
824 let raw_limit = (top_k * 5).max(50);
826 let (bm25_results, total_files) = match bm25_index {
827 Some(idx) => {
828 let total = idx.document_count();
829 if idx.is_empty() {
830 return Ok(EnrichedSearchReport {
831 query: hybrid_query.clone(),
832 results: Vec::new(),
833 total_files_searched: 0,
834 search_mode: "hybrid(bm25+regex)".to_string(),
835 });
836 }
837 (idx.search(hybrid_query, raw_limit), total)
838 }
839 None => {
840 let index = Bm25Index::from_project(root, language)?;
841 let total = index.document_count();
842 if index.is_empty() {
843 return Ok(EnrichedSearchReport {
844 query: hybrid_query.clone(),
845 results: Vec::new(),
846 total_files_searched: 0,
847 search_mode: "hybrid(bm25+regex)".to_string(),
848 });
849 }
850 (index.search(hybrid_query, raw_limit), total)
851 }
852 };
853
854 let (regex_matches, _regex_total) = do_regex_search(pattern, root, language, top_k)?;
856 if regex_matches.is_empty() {
857 return Ok(EnrichedSearchReport {
858 query: hybrid_query.clone(),
859 results: Vec::new(),
860 total_files_searched: total_files,
861 search_mode: "hybrid(bm25+regex)".to_string(),
862 });
863 }
864 let regex_results = regex_matches_to_bm25_results(®ex_matches);
865
866 let bm25_ranks: HashMap<&Path, usize> = bm25_results
868 .iter()
869 .enumerate()
870 .map(|(i, r)| (r.file_path.as_path(), i + 1))
871 .collect();
872 let regex_ranks: HashMap<&Path, usize> = regex_results
873 .iter()
874 .enumerate()
875 .map(|(i, r)| (r.file_path.as_path(), i + 1))
876 .collect();
877
878 let k = 60.0_f64;
880 let mut fused: Vec<Bm25Result> = Vec::new();
881 for bm25_result in &bm25_results {
882 if let Some(®ex_rank) = regex_ranks.get(bm25_result.file_path.as_path()) {
883 let bm25_rank = bm25_ranks[bm25_result.file_path.as_path()];
884 let rrf_score = 1.0 / (k + bm25_rank as f64) + 1.0 / (k + regex_rank as f64);
885 let mut result = bm25_result.clone();
886 result.score = rrf_score;
887 fused.push(result);
888 }
889 }
890
891 fused.sort_by(|a, b| {
893 b.score
894 .partial_cmp(&a.score)
895 .unwrap_or(std::cmp::Ordering::Equal)
896 });
897
898 let mut seen_files: HashSet<PathBuf> = HashSet::new();
900 fused.retain(|r| seen_files.insert(r.file_path.clone()));
901
902 (fused, total_files)
903 }
904 };
905
906 let report_query = match &options.search_mode {
908 SearchMode::Bm25 => query.to_string(),
909 SearchMode::Regex(pattern) => pattern.clone(),
910 SearchMode::Hybrid {
911 query: hybrid_query,
912 ..
913 } => hybrid_query.clone(),
914 };
915
916 let mut enriched = match structure_cache {
918 Some(lookup) => enrich_and_deduplicate_with_cache(&raw_results, root, language, lookup),
919 None => enrich_and_deduplicate(&raw_results, root, language),
920 };
921
922 let has_function_results = enriched.iter().any(|r| r.kind != "module");
924 for result in &mut enriched {
925 if result.kind == "module" {
926 result.score *= if has_function_results { 0.2 } else { 0.5 };
927 }
928 }
929
930 let mut sorted = enriched;
932 sorted.sort_by(|a, b| {
933 b.score
934 .partial_cmp(&a.score)
935 .unwrap_or(std::cmp::Ordering::Equal)
936 .then_with(|| a.file.cmp(&b.file))
937 .then_with(|| a.name.cmp(&b.name))
938 });
939 sorted.truncate(top_k);
940
941 let structure_label = if structure_cache.is_some() {
943 "cached-structure"
944 } else {
945 "structure"
946 };
947
948 match callgraph_cache_path {
949 Some(path) => {
950 let lookup = read_callgraph_cache(path)?;
952 for result in &mut sorted {
953 if result.kind == "module" {
954 continue;
955 }
956 if let Some(callees) = lookup.forward.get(&result.name) {
957 result.callees = callees.clone();
958 result.callees.sort();
959 }
960 if let Some(callers) = lookup.reverse.get(&result.name) {
961 result.callers = callers.clone();
962 result.callers.sort();
963 }
964 }
965 Ok(EnrichedSearchReport {
966 query: report_query,
967 results: sorted,
968 total_files_searched: total_files,
969 search_mode: format!("{}+{}+callgraph", mode_prefix, structure_label),
970 })
971 }
972 None if options.include_callgraph => {
973 let sorted_enriched = try_enrich_with_callgraph(sorted, root, language);
975 Ok(EnrichedSearchReport {
976 query: report_query,
977 results: sorted_enriched,
978 total_files_searched: total_files,
979 search_mode: format!("{}+{}+callgraph", mode_prefix, structure_label),
980 })
981 }
982 None => {
983 Ok(EnrichedSearchReport {
985 query: report_query,
986 results: sorted,
987 total_files_searched: total_files,
988 search_mode: format!("{}+{}", mode_prefix, structure_label),
989 })
990 }
991 }
992}
993
994fn extract_structure_entries(path: &Path, language: Language) -> TldrResult<Vec<StructureEntry>> {
997 let (tree, source, _) = parse_file(path)?;
998 let root_node = tree.root_node();
999 let mut entries = Vec::new();
1000
1001 collect_structure_nodes(root_node, &source, language, &mut entries);
1002
1003 Ok(entries)
1004}
1005
1006fn collect_structure_nodes(
1008 node: tree_sitter::Node,
1009 source: &str,
1010 language: Language,
1011 entries: &mut Vec<StructureEntry>,
1012) {
1013 let kind = node.kind();
1014
1015 let (is_func, is_class) = classify_node(kind, language);
1016
1017 if is_func || is_class {
1018 if let Some(name) = get_definition_name(node, source, language) {
1019 let line_start = node.start_position().row as u32 + 1; let line_end = node.end_position().row as u32 + 1;
1021
1022 let signature = extract_definition_signature(node, source);
1026
1027 let entry_kind = if is_class {
1028 match kind {
1029 "struct_item" | "struct_definition" | "struct_specifier" => "struct",
1030 _ => "class",
1031 }
1032 } else {
1033 if is_inside_class_node(node) {
1035 "method"
1036 } else {
1037 "function"
1038 }
1039 };
1040
1041 let preview = extract_code_preview(node, source, &signature, 5);
1043
1044 entries.push(StructureEntry {
1045 name,
1046 kind: entry_kind.to_string(),
1047 line_start,
1048 line_end,
1049 signature,
1050 preview,
1051 });
1052 }
1053 }
1054
1055 let mut cursor = node.walk();
1057 for child in node.children(&mut cursor) {
1058 collect_structure_nodes(child, source, language, entries);
1059 }
1060}
1061
1062fn classify_node(kind: &str, _language: Language) -> (bool, bool) {
1064 let is_func = matches!(
1065 kind,
1066 "function_definition"
1067 | "function_declaration"
1068 | "function_item" | "method_definition"
1070 | "method_declaration"
1071 | "arrow_function"
1072 | "function_expression"
1073 | "function" | "func_literal" | "function_type"
1076 );
1077
1078 let is_class = matches!(
1079 kind,
1080 "class_definition"
1081 | "class_declaration"
1082 | "struct_item" | "struct_definition" | "struct_specifier" | "type_spec" | "interface_declaration"
1087 );
1088
1089 (is_func, is_class)
1090}
1091
1092fn get_definition_name(
1094 node: tree_sitter::Node,
1095 source: &str,
1096 _language: Language,
1097) -> Option<String> {
1098 if let Some(name_node) = node.child_by_field_name("name") {
1100 let text = name_node.utf8_text(source.as_bytes()).ok()?;
1101 return Some(text.to_string());
1102 }
1103
1104 if node.kind() == "arrow_function" || node.kind() == "function_expression" {
1107 if let Some(parent) = node.parent() {
1108 if parent.kind() == "variable_declarator" {
1109 if let Some(name_node) = parent.child_by_field_name("name") {
1110 let text = name_node.utf8_text(source.as_bytes()).ok()?;
1111 return Some(text.to_string());
1112 }
1113 }
1114 }
1115 }
1116
1117 None
1118}
1119
1120fn is_inside_class_node(node: tree_sitter::Node) -> bool {
1122 let mut current = node.parent();
1123 while let Some(parent) = current {
1124 let kind = parent.kind();
1125 if matches!(
1126 kind,
1127 "class_definition" | "class_declaration" | "class_body" | "impl_item" | "struct_item"
1128 ) {
1129 return true;
1130 }
1131 current = parent.parent();
1132 }
1133 false
1134}
1135
1136fn extract_definition_signature(node: tree_sitter::Node, source: &str) -> String {
1140 let mut cursor = node.walk();
1143 for child in node.children(&mut cursor) {
1144 let ckind = child.kind();
1145 if ckind == "line_comment"
1147 || ckind == "block_comment"
1148 || ckind == "comment"
1149 || ckind == "attribute_item" || ckind == "attribute" || ckind == "decorator" || ckind == "decorator_list"
1153 {
1155 continue;
1156 }
1157 let start_byte = child.start_byte();
1159 let line_from_start = &source[start_byte..];
1161 let sig = line_from_start
1162 .lines()
1163 .next()
1164 .unwrap_or("")
1165 .trim()
1166 .to_string();
1167 if !sig.is_empty() {
1168 return sig;
1169 }
1170 }
1171
1172 let node_text = &source[node.start_byte()..node.end_byte()];
1175 for line in node_text.lines() {
1176 let trimmed = line.trim();
1177 if !trimmed.is_empty()
1178 && !trimmed.starts_with("///")
1179 && !trimmed.starts_with("//!")
1180 && !trimmed.starts_with("//")
1181 && !trimmed.starts_with("/*")
1182 && !trimmed.starts_with("*")
1183 && !trimmed.starts_with("#[")
1184 && !trimmed.starts_with("@")
1185 && !trimmed.starts_with("#")
1186 {
1187 return trimmed.to_string();
1188 }
1189 }
1190
1191 source[node.start_byte()..]
1193 .lines()
1194 .next()
1195 .unwrap_or("")
1196 .trim()
1197 .to_string()
1198}
1199
1200fn extract_code_preview(
1204 node: tree_sitter::Node,
1205 source: &str,
1206 signature: &str,
1207 max_lines: usize,
1208) -> String {
1209 let node_text = &source[node.start_byte()..node.end_byte()];
1210 let mut lines: Vec<&str> = Vec::new();
1211 let mut found_sig = false;
1212
1213 for line in node_text.lines() {
1214 let trimmed = line.trim();
1215 if !found_sig {
1217 if trimmed == signature
1218 || (trimmed.starts_with(&signature[..signature.len().min(20)])
1219 && !trimmed.starts_with("///")
1220 && !trimmed.starts_with("//!"))
1221 {
1222 found_sig = true;
1223 lines.push(line);
1224 }
1225 continue;
1226 }
1227 lines.push(line);
1228 if lines.len() >= max_lines {
1229 break;
1230 }
1231 }
1232
1233 if lines.is_empty() {
1235 for line in node_text.lines() {
1236 let trimmed = line.trim();
1237 if trimmed.is_empty() || trimmed.starts_with("///") || trimmed.starts_with("//!") {
1238 continue;
1239 }
1240 lines.push(line);
1241 if lines.len() >= max_lines {
1242 break;
1243 }
1244 }
1245 }
1246
1247 lines.join("\n")
1248}
1249
1250fn find_enclosing_entry(entries: &[StructureEntry], line: u32) -> Option<&StructureEntry> {
1252 let mut best: Option<&StructureEntry> = None;
1253
1254 for entry in entries {
1255 if line >= entry.line_start && line <= entry.line_end {
1256 match best {
1257 None => best = Some(entry),
1258 Some(current_best) => {
1259 let current_range = current_best.line_end - current_best.line_start;
1261 let new_range = entry.line_end - entry.line_start;
1262 if new_range < current_range {
1263 best = Some(entry);
1264 }
1265 }
1266 }
1267 }
1268 }
1269
1270 best
1271}
1272
1273fn try_enrich_with_callgraph(
1276 mut results: Vec<EnrichedResult>,
1277 root: &Path,
1278 language: Language,
1279) -> Vec<EnrichedResult> {
1280 use crate::callgraph::{build_forward_graph, build_reverse_graph};
1281
1282 let call_graph = match crate::build_project_call_graph(root, language, None, true) {
1284 Ok(graph) => graph,
1285 Err(_) => return results, };
1287
1288 let forward = build_forward_graph(&call_graph);
1289 let reverse = build_reverse_graph(&call_graph);
1290
1291 for result in &mut results {
1294 if result.kind == "module" {
1295 continue; }
1297
1298 let result_file = result.file.to_string_lossy();
1299
1300 let mut found_callees = false;
1302 for (func_ref, callees) in &forward {
1303 let ref_file = func_ref.file.to_string_lossy();
1304 if func_ref.name == result.name
1305 && (ref_file.is_empty()
1306 || result_file.is_empty()
1307 || ref_file.ends_with(result_file.as_ref())
1308 || result_file.ends_with(ref_file.as_ref()))
1309 {
1310 result.callees = callees.iter().map(|f| f.name.clone()).collect();
1311 result.callees.sort();
1312 found_callees = true;
1313 break;
1314 }
1315 }
1316 if !found_callees {
1318 for (func_ref, callees) in &forward {
1319 if func_ref.name == result.name {
1320 result.callees = callees.iter().map(|f| f.name.clone()).collect();
1321 result.callees.sort();
1322 break;
1323 }
1324 }
1325 }
1326
1327 let mut found_callers = false;
1329 for (func_ref, callers) in &reverse {
1330 let ref_file = func_ref.file.to_string_lossy();
1331 if func_ref.name == result.name
1332 && (ref_file.is_empty()
1333 || result_file.is_empty()
1334 || ref_file.ends_with(result_file.as_ref())
1335 || result_file.ends_with(ref_file.as_ref()))
1336 {
1337 result.callers = callers.iter().map(|f| f.name.clone()).collect();
1338 result.callers.sort();
1339 found_callers = true;
1340 break;
1341 }
1342 }
1343 if !found_callers {
1344 for (func_ref, callers) in &reverse {
1345 if func_ref.name == result.name {
1346 result.callers = callers.iter().map(|f| f.name.clone()).collect();
1347 result.callers.sort();
1348 break;
1349 }
1350 }
1351 }
1352 }
1353
1354 results
1355}
1356
1357#[cfg(test)]
1362mod tests {
1363 use super::*;
1364 use std::fs;
1365 use tempfile::TempDir;
1366
1367 fn opts(top_k: usize) -> EnrichedSearchOptions {
1369 EnrichedSearchOptions {
1370 top_k,
1371 include_callgraph: false,
1372 search_mode: SearchMode::Bm25,
1373 }
1374 }
1375
1376 fn create_test_project() -> (TempDir, PathBuf) {
1381 let dir = TempDir::new().unwrap();
1382 let project = dir.path().join("project");
1383 fs::create_dir(&project).unwrap();
1384
1385 fs::write(
1387 project.join("auth.py"),
1388 r#"
1389def verify_jwt_token(request):
1390 """Verify JWT token from request headers."""
1391 token = request.headers.get("Authorization")
1392 if not token:
1393 raise AuthError("Missing token")
1394 claims = decode_token(token)
1395 check_expiry(claims)
1396 return claims
1397
1398def decode_token(token):
1399 """Decode a JWT token string."""
1400 import jwt
1401 return jwt.decode(token, key="secret")
1402
1403def check_expiry(claims):
1404 """Check if token has expired."""
1405 if claims["exp"] < time.time():
1406 raise AuthError("Token expired")
1407
1408class AuthMiddleware:
1409 """Middleware for authentication."""
1410 def __init__(self, app):
1411 self.app = app
1412
1413 def process_request(self, request):
1414 """Process incoming request for auth."""
1415 verify_jwt_token(request)
1416 return self.app(request)
1417"#,
1418 )
1419 .unwrap();
1420
1421 fs::write(
1423 project.join("routes.py"),
1424 r#"
1425def user_routes(app):
1426 """Register user routes."""
1427 @app.route("/users")
1428 def list_users():
1429 return get_all_users()
1430
1431def admin_routes(app):
1432 """Register admin routes."""
1433 @app.route("/admin")
1434 def admin_panel():
1435 return render_admin()
1436
1437def get_all_users():
1438 """Fetch all users from database."""
1439 return db.query("SELECT * FROM users")
1440
1441def render_admin():
1442 """Render admin panel."""
1443 return template.render("admin.html")
1444"#,
1445 )
1446 .unwrap();
1447
1448 fs::write(
1450 project.join("utils.py"),
1451 r#"
1452def format_date(dt):
1453 """Format a datetime object."""
1454 return dt.strftime("%Y-%m-%d")
1455
1456def parse_json(text):
1457 """Parse JSON string."""
1458 import json
1459 return json.loads(text)
1460"#,
1461 )
1462 .unwrap();
1463
1464 (dir, project)
1465 }
1466
/// A hand-built `EnrichedResult` reports back the field values it was given.
#[test]
fn test_enriched_result_has_required_fields() {
    let matched_terms: Vec<String> = ["verify", "jwt", "token"]
        .iter()
        .map(|term| String::from(*term))
        .collect();
    let sample = EnrichedResult {
        name: String::from("verify_jwt_token"),
        kind: String::from("function"),
        file: PathBuf::from("auth.py"),
        line_range: (2, 9),
        signature: String::from("def verify_jwt_token(request):"),
        callers: vec![String::from("process_request")],
        callees: vec![String::from("decode_token"), String::from("check_expiry")],
        score: 0.94,
        matched_terms,
        preview: String::new(),
    };

    let (first_line, _last_line) = sample.line_range;

    assert_eq!(sample.name, "verify_jwt_token");
    assert_eq!(sample.kind, "function");
    assert_eq!(first_line, 2);
    assert!(sample.score > 0.0);
    assert_eq!(sample.callers.len(), 1);
    assert_eq!(sample.callees.len(), 2);
}
1493
/// The `serde_json` encoding of an `EnrichedResult` contains its name and kind.
#[test]
fn test_enriched_result_serializes_to_json() {
    let sample = EnrichedResult {
        name: "test_func".to_owned(),
        kind: "function".to_owned(),
        file: PathBuf::from("test.py"),
        line_range: (1, 5),
        signature: "def test_func():".to_owned(),
        callers: vec![],
        callees: vec![],
        score: 0.5,
        matched_terms: vec!["test".to_owned()],
        preview: String::new(),
    };

    let encoded = serde_json::to_string(&sample).unwrap();

    let mentions_name = encoded.contains("test_func");
    let mentions_kind = encoded.contains("function");
    assert!(mentions_name);
    assert!(mentions_kind);
}
1513
/// An `EnrichedSearchReport` keeps the query, file count, and mode it was built with.
#[test]
fn test_enriched_search_report_has_metadata() {
    let summary = EnrichedSearchReport {
        query: String::from("authentication"),
        results: vec![],
        total_files_searched: 42,
        search_mode: String::from("bm25+structure"),
    };

    // Pull the metadata fields out by name; `results` is not inspected here.
    let EnrichedSearchReport {
        query,
        total_files_searched,
        search_mode,
        ..
    } = summary;

    assert_eq!(query, "authentication");
    assert_eq!(total_files_searched, 42);
    assert_eq!(search_mode, "bm25+structure");
}
1527
/// A BM25 index built over the fixture project holds at least three documents
/// and returns hits for "jwt token".
#[test]
fn test_bm25_index_finds_test_files() {
    // `_tmp` keeps the temporary directory alive for the duration of the test.
    let (_tmp, project_root) = create_test_project();

    let bm25 = Bm25Index::from_project(&project_root, Language::Python).unwrap();
    let indexed = bm25.document_count();
    assert!(
        indexed >= 3,
        "Should index at least 3 .py files, got {}",
        indexed
    );

    let hits = bm25.search("jwt token", 10);
    let found_any = !hits.is_empty();
    assert!(found_any, "BM25 should find results for 'jwt token'");
}
1546
/// A query matching the fixture yields results, a non-zero file count, and
/// the "bm25+structure" mode label.
#[test]
fn test_enriched_search_returns_results_for_matching_query() {
    let (_tmp, project_root) = create_test_project();
    let search_opts = opts(10);

    let outcome =
        enriched_search("jwt token verify", &project_root, Language::Python, search_opts).unwrap();

    let found_any = !outcome.results.is_empty();
    assert!(found_any, "Should find results for 'jwt token verify'");
    assert!(outcome.total_files_searched > 0);
    assert_eq!(outcome.search_mode, "bm25+structure");
}
1560
/// Searching with an empty query string succeeds and produces no results.
#[test]
fn test_enriched_search_empty_query_returns_empty() {
    let (_tmp, project_root) = create_test_project();

    let hits = enriched_search("", &project_root, Language::Python, opts(10))
        .unwrap()
        .results;

    assert!(hits.is_empty(), "Empty query should return no results");
}
1571
/// A query term that appears nowhere in the fixture produces no results.
#[test]
fn test_enriched_search_no_match_returns_empty() {
    let (_tmp, project_root) = create_test_project();

    let hits = enriched_search("xyznonexistent123", &project_root, Language::Python, opts(10))
        .unwrap()
        .results;

    assert!(hits.is_empty(), "Non-matching query should return no results");
}
1583
/// Searching for "jwt token" returns at least one of the token-related
/// functions defined in the fixture's `auth.py`.
#[test]
fn test_enriched_search_results_have_function_names() {
    let (_tmp, project_root) = create_test_project();
    let outcome = enriched_search("jwt token", &project_root, Language::Python, opts(10)).unwrap();

    let names: Vec<&str> = outcome
        .results
        .iter()
        .map(|hit| hit.name.as_str())
        .collect();

    let has_func = names.iter().any(|candidate| {
        matches!(
            *candidate,
            "verify_jwt_token" | "decode_token" | "check_expiry"
        )
    });
    assert!(has_func, "Should find function names, got: {:?}", names);
}
1597
/// Every function or method result for "verify jwt" carries a non-empty signature.
#[test]
fn test_enriched_search_results_have_signatures() {
    let (_tmp, project_root) = create_test_project();
    let outcome = enriched_search("verify jwt", &project_root, Language::Python, opts(10)).unwrap();

    // Only callable kinds are checked; other kinds are skipped.
    let callables = outcome
        .results
        .iter()
        .filter(|hit| matches!(hit.kind.as_str(), "function" | "method"));

    for hit in callables {
        assert!(
            !hit.signature.is_empty(),
            "Function '{}' should have a signature",
            hit.name
        );
    }
}
1613
/// Every result for "decode token" has a positive start line and an end line
/// that is not before the start.
#[test]
fn test_enriched_search_results_have_line_ranges() {
    let (_tmp, project_root) = create_test_project();
    let outcome =
        enriched_search("decode token", &project_root, Language::Python, opts(10)).unwrap();

    for hit in outcome.results.iter() {
        let (start, end) = hit.line_range;
        assert!(start > 0, "Line start should be > 0 (1-indexed)");
        assert!(end >= start, "Line end should be >= line start");
    }
}
1630
/// `verify_jwt_token` shows up at most once in the results for "token".
#[test]
fn test_enriched_search_deduplicates_same_function() {
    let (_tmp, project_root) = create_test_project();
    let outcome = enriched_search("token", &project_root, Language::Python, opts(20)).unwrap();

    let mut occurrences: usize = 0;
    for hit in &outcome.results {
        if hit.name == "verify_jwt_token" {
            occurrences += 1;
        }
    }

    assert!(
        occurrences <= 1,
        "verify_jwt_token should appear at most once (deduplication), found {}",
        occurrences
    );
}
1650
/// With options built by `opts(3)`, the search returns no more than three results.
#[test]
fn test_enriched_search_respects_top_k() {
    let (_tmp, project_root) = create_test_project();
    let outcome = enriched_search("def", &project_root, Language::Python, opts(3)).unwrap();

    let returned = outcome.results.len();
    assert!(
        returned <= 3,
        "Should respect top_k=3, got {} results",
        returned
    );
}
1662
/// Results for "token" are ordered by non-increasing score.
#[test]
fn test_enriched_search_results_sorted_by_score() {
    let (_tmp, project_root) = create_test_project();
    let outcome = enriched_search("token", &project_root, Language::Python, opts(10)).unwrap();

    // `windows(2)` yields nothing for zero or one result, so no length guard is needed.
    for pair in outcome.results.windows(2) {
        let (earlier, later) = (&pair[0], &pair[1]);
        assert!(
            earlier.score >= later.score,
            "Results should be sorted by score descending: {} >= {}",
            earlier.score,
            later.score
        );
    }
}
1679
/// Every result for "jwt token" lists at least one matched term.
#[test]
fn test_enriched_search_has_matched_terms() {
    let (_tmp, project_root) = create_test_project();
    let outcome = enriched_search("jwt token", &project_root, Language::Python, opts(10)).unwrap();

    outcome.results.iter().for_each(|hit| {
        let has_terms = !hit.matched_terms.is_empty();
        assert!(
            has_terms,
            "Result '{}' should have at least one matched term",
            hit.name
        );
    });
}
1693
/// Searching for "AuthMiddleware" returns at least one result of kind "class".
#[test]
fn test_enriched_search_finds_classes() {
    let (_tmp, project_root) = create_test_project();
    let outcome =
        enriched_search("AuthMiddleware", &project_root, Language::Python, opts(10)).unwrap();

    let class_hit = outcome.results.iter().find(|hit| hit.kind == "class");
    assert!(
        class_hit.is_some(),
        "Should find class-level results for 'AuthMiddleware'"
    );
}
1705
/// Searching for "process_request" returns at least one result of kind "method".
#[test]
fn test_enriched_search_finds_methods() {
    let (_tmp, project_root) = create_test_project();
    let outcome =
        enriched_search("process_request", &project_root, Language::Python, opts(10)).unwrap();

    let method_hit = outcome.results.iter().find(|hit| hit.kind == "method");
    assert!(
        method_hit.is_some(),
        "Should find method-level results for 'process_request'"
    );
}
1717
/// Structure extraction on the fixture's `auth.py` reports both
/// `verify_jwt_token` and `decode_token`.
#[test]
fn test_extract_structure_entries_finds_functions() {
    let (_tmp, project_root) = create_test_project();
    let auth_file = project_root.join("auth.py");
    let entries = extract_structure_entries(&auth_file, Language::Python).unwrap();

    let mut func_names: Vec<&str> = Vec::with_capacity(entries.len());
    for entry in &entries {
        func_names.push(entry.name.as_str());
    }
    let is_listed = |wanted: &str| func_names.iter().any(|name| *name == wanted);

    assert!(
        is_listed("verify_jwt_token"),
        "Should find verify_jwt_token, got: {:?}",
        func_names
    );
    assert!(
        is_listed("decode_token"),
        "Should find decode_token, got: {:?}",
        func_names
    );
}
1739
/// Structure extraction on the fixture's `auth.py` reports `AuthMiddleware`
/// among the entries of kind "class".
#[test]
fn test_extract_structure_entries_finds_classes() {
    let (_tmp, project_root) = create_test_project();
    let auth_file = project_root.join("auth.py");
    let entries = extract_structure_entries(&auth_file, Language::Python).unwrap();

    let class_names: Vec<&str> = entries
        .iter()
        .filter_map(|entry| {
            if entry.kind == "class" {
                Some(entry.name.as_str())
            } else {
                None
            }
        })
        .collect();

    let found = class_names.iter().any(|name| *name == "AuthMiddleware");
    assert!(
        found,
        "Should find AuthMiddleware class, got: {:?}",
        class_names
    );
}
1756
/// Every entry extracted from `auth.py` has a positive start line and an end
/// line that is not before the start.
#[test]
fn test_extract_structure_entries_has_line_ranges() {
    let (_tmp, project_root) = create_test_project();
    let auth_file = project_root.join("auth.py");
    let entries = extract_structure_entries(&auth_file, Language::Python).unwrap();

    for item in entries.iter() {
        let (start, end) = (item.line_start, item.line_end);
        assert!(start > 0, "Line start should be 1-indexed");
        assert!(
            end >= start,
            "Line end should be >= line start for {}",
            item.name
        );
    }
}
1771
/// The extracted entry for `verify_jwt_token` has a signature containing its
/// `def` line text.
#[test]
fn test_extract_structure_entries_has_signatures() {
    let (_tmp, project_root) = create_test_project();
    let auth_file = project_root.join("auth.py");
    let entries = extract_structure_entries(&auth_file, Language::Python).unwrap();

    // The fixture is expected to contain this function; a missing entry panics here.
    let target = entries
        .iter()
        .find(|entry| entry.name == "verify_jwt_token")
        .unwrap();
    let signature = &target.signature;

    assert!(
        signature.contains("def verify_jwt_token"),
        "Signature should contain function definition, got: '{}'",
        signature
    );
}
1787
/// When a line falls inside both a class (lines 1-20) and one of its methods
/// (lines 5-10), `find_enclosing_entry` picks the method.
#[test]
fn test_find_enclosing_entry_returns_innermost() {
    // Local builder so each entry reads as a single line.
    let entry = |name: &str, kind: &str, span: (u32, u32), signature: &str| StructureEntry {
        name: name.to_owned(),
        kind: kind.to_owned(),
        line_start: span.0,
        line_end: span.1,
        signature: signature.to_owned(),
        preview: String::new(),
    };

    let entries = vec![
        entry("OuterClass", "class", (1, 20), "class OuterClass:"),
        entry("inner_method", "method", (5, 10), "def inner_method(self):"),
    ];

    let innermost = find_enclosing_entry(&entries, 7);
    assert!(innermost.is_some());
    assert_eq!(innermost.unwrap().name, "inner_method");
}
1817
/// A line before the only entry's range (10-20) has no enclosing entry.
#[test]
fn test_find_enclosing_entry_returns_none_outside() {
    let only_entry = StructureEntry {
        name: String::from("some_func"),
        kind: String::from("function"),
        line_start: 10,
        line_end: 20,
        signature: String::from("def some_func():"),
        preview: String::new(),
    };
    let entries = vec![only_entry];

    let enclosing = find_enclosing_entry(&entries, 5);
    assert!(enclosing.is_none());
}
1832
/// Searching a freshly created, empty directory succeeds with no results and
/// zero files searched.
#[test]
fn test_enriched_search_on_empty_directory() {
    let scratch = TempDir::new().unwrap();
    let empty_root = scratch.path().join("empty_project");
    fs::create_dir(&empty_root).unwrap();

    let outcome = enriched_search("anything", &empty_root, Language::Python, opts(10)).unwrap();

    let no_hits = outcome.results.is_empty();
    assert!(no_hits);
    assert_eq!(outcome.total_files_searched, 0);
}
1847
/// The report echoes back the exact query string that was searched for.
#[test]
fn test_enriched_search_report_query_preserved() {
    let (_tmp, project_root) = create_test_project();
    let search_opts = opts(10);

    let outcome = enriched_search(
        "authentication middleware",
        &project_root,
        Language::Python,
        search_opts,
    )
    .unwrap();

    let echoed = outcome.query.as_str();
    assert_eq!(echoed, "authentication middleware");
}
1861
/// Times repeated searches over the fixture project.
///
/// Only the second search against a pre-built `Bm25Index` is held to the
/// 200 ms limit. The timings of the index-less calls are collected and printed
/// to stderr at the end but never asserted on.
#[test]
fn test_perf_enriched_search_repeated_calls_under_200ms() {
    let (_tmp, project_root) = create_test_project();
    let query = "jwt token verify";

    // Warm-up call; the report itself is discarded.
    drop(enriched_search(query, &project_root, Language::Python, opts(10)).unwrap());

    // Two timed calls that do not reuse an index.
    let uncached_timings: Vec<std::time::Duration> = (0..2)
        .map(|_| {
            let started = std::time::Instant::now();
            let outcome =
                enriched_search(query, &project_root, Language::Python, opts(10)).unwrap();
            let took = started.elapsed();

            assert!(!outcome.results.is_empty(), "Should find results");
            took
        })
        .collect();

    let bm25 = Bm25Index::from_project(&project_root, Language::Python).unwrap();

    // First call against the shared index: checked for results, not timed.
    let first_cached =
        enriched_search_with_index(query, &project_root, Language::Python, opts(10), &bm25)
            .unwrap();
    assert!(
        !first_cached.results.is_empty(),
        "Cached search should find results"
    );

    // Second call against the shared index: this is the one with a time limit.
    let started = std::time::Instant::now();
    let _second_cached =
        enriched_search_with_index(query, &project_root, Language::Python, opts(10), &bm25)
            .unwrap();
    let cached_millis = started.elapsed().as_millis();

    assert!(
        cached_millis < 200,
        "Cached enriched_search should complete in < 200ms, took {}ms",
        cached_millis
    );

    for took in uncached_timings.iter() {
        eprintln!(" enriched_search call took: {:?}", took);
    }
}
1941
/// Two edges leaving `foo` produce a forward-map entry for `foo` that lists
/// both `bar` and `baz`.
#[test]
fn test_read_callgraph_cache_builds_forward_map() {
    let scratch = tempfile::TempDir::new().unwrap();
    let cache_file = scratch.path().join("call_graph.json");
    let payload = r#"{
    "edges": [
        {"from_file": "a.py", "from_func": "foo", "to_file": "a.py", "to_func": "bar"},
        {"from_file": "a.py", "from_func": "foo", "to_file": "b.py", "to_func": "baz"}
    ],
    "languages": ["python"],
    "timestamp": 1740000000
}"#;
    fs::write(&cache_file, payload).unwrap();

    let graph = read_callgraph_cache(&cache_file).unwrap();
    let foo_callees = graph.forward.get("foo").unwrap();
    let calls = |name: &str| foo_callees.iter().any(|callee| callee == name);

    assert!(calls("bar"));
    assert!(calls("baz"));
}
1968
/// Two edges arriving at `bar` produce a reverse-map entry for `bar` that
/// lists both `foo` and `qux`.
#[test]
fn test_read_callgraph_cache_builds_reverse_map() {
    let scratch = tempfile::TempDir::new().unwrap();
    let cache_file = scratch.path().join("call_graph.json");
    let payload = r#"{
    "edges": [
        {"from_file": "a.py", "from_func": "foo", "to_file": "a.py", "to_func": "bar"},
        {"from_file": "b.py", "from_func": "qux", "to_file": "a.py", "to_func": "bar"}
    ],
    "languages": ["python"],
    "timestamp": 1740000000
}"#;
    fs::write(&cache_file, payload).unwrap();

    let graph = read_callgraph_cache(&cache_file).unwrap();
    let bar_callers = graph.reverse.get("bar").unwrap();
    let called_by = |name: &str| bar_callers.iter().any(|caller| caller == name);

    assert!(called_by("foo"));
    assert!(called_by("qux"));
}
1991
/// A cache file with an empty `edges` array loads into empty forward and
/// reverse maps.
#[test]
fn test_read_callgraph_cache_empty_edges() {
    let scratch = tempfile::TempDir::new().unwrap();
    let cache_file = scratch.path().join("call_graph.json");
    let payload = r#"{
    "edges": [],
    "languages": ["python"],
    "timestamp": 1740000000
}"#;
    fs::write(&cache_file, payload).unwrap();

    let graph = read_callgraph_cache(&cache_file).unwrap();

    let no_forward = graph.forward.is_empty();
    let no_reverse = graph.reverse.is_empty();
    assert!(no_forward);
    assert!(no_reverse);
}
2010
/// A cache file whose contents are not JSON makes `read_callgraph_cache`
/// return an error.
#[test]
fn test_read_callgraph_cache_invalid_json_returns_error() {
    let scratch = tempfile::TempDir::new().unwrap();
    let cache_file = scratch.path().join("call_graph.json");
    fs::write(&cache_file, "not valid json").unwrap();

    let failed = read_callgraph_cache(&cache_file).is_err();
    assert!(failed);
}
2020
/// Pointing `read_callgraph_cache` at a path that does not exist returns an error.
#[test]
fn test_read_callgraph_cache_missing_file_returns_error() {
    let missing = Path::new("/nonexistent/path/call_graph.json");

    let failed = read_callgraph_cache(missing).is_err();
    assert!(failed);
}
2026
/// With a call-graph cache file supplied, the report is labelled
/// "bm25+structure+callgraph" and the `verify_jwt_token` result carries the
/// callers and callees recorded in that cache.
///
/// The caller/callee checks only run when `verify_jwt_token` is among the
/// results; if it is absent, they are skipped without failing.
#[test]
fn test_enriched_search_with_callgraph_cache_populates_callers_callees() {
    let (_tmp, project_root) = create_test_project();

    let cache_dir = project_root.join(".tldr").join("cache");
    fs::create_dir_all(&cache_dir).unwrap();
    let cache_file = cache_dir.join("call_graph.json");
    let payload = r#"{
    "edges": [
        {"from_file": "auth.py", "from_func": "verify_jwt_token", "to_file": "auth.py", "to_func": "decode_token"},
        {"from_file": "auth.py", "from_func": "verify_jwt_token", "to_file": "auth.py", "to_func": "check_expiry"},
        {"from_file": "auth.py", "from_func": "process_request", "to_file": "auth.py", "to_func": "verify_jwt_token"}
    ],
    "languages": ["python"],
    "timestamp": 1740000000
}"#;
    fs::write(&cache_file, payload).unwrap();

    let outcome = enriched_search_with_callgraph_cache(
        "jwt token verify",
        &project_root,
        Language::Python,
        EnrichedSearchOptions {
            top_k: 10,
            include_callgraph: true,
            search_mode: SearchMode::Bm25,
        },
        &cache_file,
    )
    .unwrap();

    assert!(!outcome.results.is_empty());
    assert_eq!(outcome.search_mode, "bm25+structure+callgraph");

    let target = outcome
        .results
        .iter()
        .find(|hit| hit.name == "verify_jwt_token");
    if let Some(hit) = target {
        let calls = |name: &str| hit.callees.iter().any(|callee| callee == name);
        let called_by = |name: &str| hit.callers.iter().any(|caller| caller == name);

        assert!(
            calls("decode_token"),
            "verify_jwt_token should call decode_token, got: {:?}",
            hit.callees
        );
        assert!(
            calls("check_expiry"),
            "verify_jwt_token should call check_expiry, got: {:?}",
            hit.callees
        );
        assert!(
            called_by("process_request"),
            "verify_jwt_token should be called by process_request, got: {:?}",
            hit.callers
        );
    }
}
2081
/// The callees attached to the `verify_jwt_token` result are already in
/// sorted order.
///
/// The check only runs when `verify_jwt_token` is among the results; if it is
/// absent, the test passes without asserting anything.
#[test]
fn test_enriched_search_with_callgraph_cache_sorts_callers_callees() {
    let (_tmp, project_root) = create_test_project();

    let cache_dir = project_root.join(".tldr").join("cache");
    fs::create_dir_all(&cache_dir).unwrap();
    let cache_file = cache_dir.join("call_graph.json");
    let payload = r#"{
    "edges": [
        {"from_file": "auth.py", "from_func": "verify_jwt_token", "to_file": "auth.py", "to_func": "decode_token"},
        {"from_file": "auth.py", "from_func": "verify_jwt_token", "to_file": "auth.py", "to_func": "check_expiry"}
    ],
    "languages": ["python"],
    "timestamp": 1740000000
}"#;
    fs::write(&cache_file, payload).unwrap();

    let outcome = enriched_search_with_callgraph_cache(
        "verify jwt token",
        &project_root,
        Language::Python,
        EnrichedSearchOptions {
            top_k: 10,
            include_callgraph: true,
            search_mode: SearchMode::Bm25,
        },
        &cache_file,
    )
    .unwrap();

    let target = outcome
        .results
        .iter()
        .find(|hit| hit.name == "verify_jwt_token");
    if let Some(hit) = target {
        // Sort a copy and compare: equality means the original was already sorted.
        let sorted_copy = {
            let mut copy = hit.callees.clone();
            copy.sort();
            copy
        };
        assert_eq!(
            hit.callees, sorted_copy,
            "Callees should be sorted alphabetically"
        );
    }
}
2122}