1use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6
7use async_trait::async_trait;
8use schemars::JsonSchema;
9use serde::Deserialize;
10use tree_sitter::{Parser, QueryCursor, StreamingIterator};
11use zeph_index::languages::detect_language;
12
13use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
14use crate::registry::{InvocationHint, ToolDef};
15
/// Which retrieval strategy produced a [`SearchCodeHit`].
///
/// The variant drives both the human-readable label attached to each result
/// and the default ranking score assigned to hits from that strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchCodeSource {
    /// Hit returned by the semantic (vector) search backend.
    Semantic,
    /// Definition found by parsing local files with tree-sitter.
    Structural,
    /// Symbol returned by an LSP workspace-symbol lookup.
    LspSymbol,
    /// Location returned by an LSP references lookup.
    LspReferences,
    /// Literal, case-insensitive line match used only when every other source came up empty.
    GrepFallback,
}
24
25impl SearchCodeSource {
26 fn label(self) -> &'static str {
27 match self {
28 Self::Semantic => "vector search",
29 Self::Structural => "tree-sitter",
30 Self::LspSymbol => "LSP symbol search",
31 Self::LspReferences => "LSP references",
32 Self::GrepFallback => "grep fallback",
33 }
34 }
35
36 #[must_use]
37 pub fn default_score(self) -> f32 {
38 match self {
39 Self::Structural => 0.98,
40 Self::LspSymbol => 0.95,
41 Self::LspReferences => 0.90,
42 Self::Semantic => 0.75,
43 Self::GrepFallback => 0.45,
44 }
45 }
46}
47
48#[derive(Debug, Clone)]
49pub struct SearchCodeHit {
50 pub file_path: String,
51 pub line_start: usize,
52 pub line_end: usize,
53 pub snippet: String,
54 pub source: SearchCodeSource,
55 pub score: f32,
56 pub symbol_name: Option<String>,
57}
58
59#[async_trait]
60pub trait SemanticSearchBackend: Send + Sync {
61 async fn search(
62 &self,
63 query: &str,
64 file_pattern: Option<&str>,
65 max_results: usize,
66 ) -> Result<Vec<SearchCodeHit>, ToolError>;
67}
68
69#[async_trait]
70pub trait LspSearchBackend: Send + Sync {
71 async fn workspace_symbol(
72 &self,
73 symbol: &str,
74 file_pattern: Option<&str>,
75 max_results: usize,
76 ) -> Result<Vec<SearchCodeHit>, ToolError>;
77
78 async fn references(
79 &self,
80 symbol: &str,
81 file_pattern: Option<&str>,
82 max_results: usize,
83 ) -> Result<Vec<SearchCodeHit>, ToolError>;
84}
85
86#[derive(Deserialize, JsonSchema)]
87struct SearchCodeParams {
88 #[serde(default)]
90 query: Option<String>,
91 #[serde(default)]
93 symbol: Option<String>,
94 #[serde(default)]
96 file_pattern: Option<String>,
97 #[serde(default)]
99 include_references: bool,
100 #[serde(default = "default_max_results")]
102 max_results: usize,
103}
104
/// Serde default for `SearchCodeParams::max_results` when the caller omits it.
const fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 10;
    DEFAULT_MAX_RESULTS
}
108
109pub struct SearchCodeExecutor {
110 allowed_paths: Vec<PathBuf>,
111 semantic_backend: Option<std::sync::Arc<dyn SemanticSearchBackend>>,
112 lsp_backend: Option<std::sync::Arc<dyn LspSearchBackend>>,
113}
114
115impl std::fmt::Debug for SearchCodeExecutor {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 f.debug_struct("SearchCodeExecutor")
118 .field("allowed_paths", &self.allowed_paths)
119 .field("has_semantic_backend", &self.semantic_backend.is_some())
120 .field("has_lsp_backend", &self.lsp_backend.is_some())
121 .finish()
122 }
123}
124
125impl SearchCodeExecutor {
126 #[must_use]
127 pub fn new(allowed_paths: Vec<PathBuf>) -> Self {
128 let paths = if allowed_paths.is_empty() {
129 vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
130 } else {
131 allowed_paths
132 };
133 Self {
134 allowed_paths: paths
135 .into_iter()
136 .map(|p| p.canonicalize().unwrap_or(p))
137 .collect(),
138 semantic_backend: None,
139 lsp_backend: None,
140 }
141 }
142
143 #[must_use]
144 pub fn with_semantic_backend(
145 mut self,
146 backend: std::sync::Arc<dyn SemanticSearchBackend>,
147 ) -> Self {
148 self.semantic_backend = Some(backend);
149 self
150 }
151
152 #[must_use]
153 pub fn with_lsp_backend(mut self, backend: std::sync::Arc<dyn LspSearchBackend>) -> Self {
154 self.lsp_backend = Some(backend);
155 self
156 }
157
158 async fn handle_search_code(
159 &self,
160 params: &SearchCodeParams,
161 ) -> Result<Option<ToolOutput>, ToolError> {
162 let query = params
163 .query
164 .as_deref()
165 .map(str::trim)
166 .filter(|s| !s.is_empty());
167 let symbol = params
168 .symbol
169 .as_deref()
170 .map(str::trim)
171 .filter(|s| !s.is_empty());
172
173 if query.is_none() && symbol.is_none() {
174 return Err(ToolError::InvalidParams {
175 message: "at least one of `query` or `symbol` must be provided".into(),
176 });
177 }
178
179 let max_results = params.max_results.clamp(1, 50);
180 let mut hits = Vec::new();
181
182 if let Some(query) = query
183 && let Some(backend) = &self.semantic_backend
184 {
185 hits.extend(
186 backend
187 .search(query, params.file_pattern.as_deref(), max_results)
188 .await?,
189 );
190 }
191
192 if let Some(symbol) = symbol {
193 hits.extend(self.structural_search(
194 symbol,
195 params.file_pattern.as_deref(),
196 max_results,
197 )?);
198
199 if let Some(backend) = &self.lsp_backend {
200 if let Ok(lsp_hits) = backend
201 .workspace_symbol(symbol, params.file_pattern.as_deref(), max_results)
202 .await
203 {
204 hits.extend(lsp_hits);
205 }
206 if params.include_references
207 && let Ok(lsp_refs) = backend
208 .references(symbol, params.file_pattern.as_deref(), max_results)
209 .await
210 {
211 hits.extend(lsp_refs);
212 }
213 }
214 }
215
216 if hits.is_empty() {
217 let fallback_term = symbol.or(query).unwrap_or_default();
218 hits.extend(self.grep_fallback(
219 fallback_term,
220 params.file_pattern.as_deref(),
221 max_results,
222 )?);
223 }
224
225 let merged = dedupe_hits(hits, max_results);
226 let root = self
227 .allowed_paths
228 .first()
229 .map_or(Path::new("."), PathBuf::as_path);
230 let summary = format_hits(&merged, root);
231 let locations = merged
232 .iter()
233 .map(|hit| hit.file_path.clone())
234 .collect::<Vec<_>>();
235 let raw_response = serde_json::json!({
236 "results": merged.iter().map(|hit| {
237 serde_json::json!({
238 "file_path": hit.file_path,
239 "line_start": hit.line_start,
240 "line_end": hit.line_end,
241 "snippet": hit.snippet,
242 "source": hit.source.label(),
243 "score": hit.score,
244 "symbol_name": hit.symbol_name,
245 })
246 }).collect::<Vec<_>>()
247 });
248
249 Ok(Some(ToolOutput {
250 tool_name: "search_code".to_owned(),
251 summary,
252 blocks_executed: 1,
253 filter_stats: None,
254 diff: None,
255 streamed: false,
256 terminal_id: None,
257 locations: Some(locations),
258 raw_response: Some(raw_response),
259 }))
260 }
261
262 fn structural_search(
263 &self,
264 symbol: &str,
265 file_pattern: Option<&str>,
266 max_results: usize,
267 ) -> Result<Vec<SearchCodeHit>, ToolError> {
268 let matcher = file_pattern
269 .map(glob::Pattern::new)
270 .transpose()
271 .map_err(|e| ToolError::InvalidParams {
272 message: format!("invalid file_pattern: {e}"),
273 })?;
274 let mut hits = Vec::new();
275 let symbol_lower = symbol.to_lowercase();
276
277 for root in &self.allowed_paths {
278 collect_structural_hits(root, root, matcher.as_ref(), &symbol_lower, &mut hits)?;
279 if hits.len() >= max_results {
280 break;
281 }
282 }
283
284 Ok(hits)
285 }
286
287 fn grep_fallback(
288 &self,
289 pattern: &str,
290 file_pattern: Option<&str>,
291 max_results: usize,
292 ) -> Result<Vec<SearchCodeHit>, ToolError> {
293 let matcher = file_pattern
294 .map(glob::Pattern::new)
295 .transpose()
296 .map_err(|e| ToolError::InvalidParams {
297 message: format!("invalid file_pattern: {e}"),
298 })?;
299 let escaped = regex::escape(pattern);
300 let regex = regex::RegexBuilder::new(&escaped)
301 .case_insensitive(true)
302 .build()
303 .map_err(|e| ToolError::InvalidParams {
304 message: e.to_string(),
305 })?;
306 let mut hits = Vec::new();
307 for root in &self.allowed_paths {
308 collect_grep_hits(root, root, matcher.as_ref(), ®ex, &mut hits, max_results)?;
309 if hits.len() >= max_results {
310 break;
311 }
312 }
313 Ok(hits)
314 }
315}
316
317impl ToolExecutor for SearchCodeExecutor {
318 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
319 Ok(None)
320 }
321
322 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
323 if call.tool_id != "search_code" {
324 return Ok(None);
325 }
326 let params: SearchCodeParams = deserialize_params(&call.params)?;
327 self.handle_search_code(¶ms).await
328 }
329
330 fn tool_definitions(&self) -> Vec<ToolDef> {
331 vec![ToolDef {
332 id: "search_code".into(),
333 description: "Search the codebase using semantic, structural, and LSP sources.\n\nParameters: query (string, optional) - natural language description to find semantically similar code; symbol (string, optional) - exact or partial symbol name for definition search; file_pattern (string, optional) - glob restricting files; include_references (boolean, optional) - also return symbol references when LSP is available; max_results (integer, optional) - cap results 1-50, default 10\nReturns: ranked code locations with file path, line range, snippet, source label, and score\nErrors: InvalidParams when both query and symbol are empty\nExample: {\"query\": \"where is retry backoff calculated\", \"symbol\": \"retry_backoff_ms\", \"include_references\": true}".into(),
334 schema: schemars::schema_for!(SearchCodeParams),
335 invocation: InvocationHint::ToolCall,
336 }]
337 }
338}
339
340fn dedupe_hits(mut hits: Vec<SearchCodeHit>, max_results: usize) -> Vec<SearchCodeHit> {
341 let mut merged: HashMap<(String, usize, usize), SearchCodeHit> = HashMap::new();
342 for hit in hits.drain(..) {
343 let key = (hit.file_path.clone(), hit.line_start, hit.line_end);
344 merged
345 .entry(key)
346 .and_modify(|existing| {
347 if hit.score > existing.score {
348 existing.score = hit.score;
349 existing.snippet.clone_from(&hit.snippet);
350 existing.symbol_name = hit.symbol_name.clone().or(existing.symbol_name.clone());
351 }
352 if existing.source != hit.source {
353 existing.source = if existing.score >= hit.score {
354 existing.source
355 } else {
356 hit.source
357 };
358 }
359 })
360 .or_insert(hit);
361 }
362
363 let mut merged = merged.into_values().collect::<Vec<_>>();
364 merged.sort_by(|a, b| {
365 b.score
366 .partial_cmp(&a.score)
367 .unwrap_or(std::cmp::Ordering::Equal)
368 .then_with(|| a.file_path.cmp(&b.file_path))
369 .then_with(|| a.line_start.cmp(&b.line_start))
370 });
371 merged.truncate(max_results);
372 merged
373}
374
375fn format_hits(hits: &[SearchCodeHit], root: &Path) -> String {
376 if hits.is_empty() {
377 return "No code matches found.".into();
378 }
379
380 hits.iter()
381 .enumerate()
382 .map(|(idx, hit)| {
383 let display_path = Path::new(&hit.file_path)
384 .strip_prefix(root)
385 .map_or_else(|_| hit.file_path.clone(), |p| p.display().to_string());
386 format!(
387 "[{}] {}:{}-{}\n {}\n source: {}\n score: {:.2}",
388 idx + 1,
389 display_path,
390 hit.line_start,
391 hit.line_end,
392 hit.snippet.replace('\n', " "),
393 hit.source.label(),
394 hit.score,
395 )
396 })
397 .collect::<Vec<_>>()
398 .join("\n\n")
399}
400
401fn collect_structural_hits(
402 root: &Path,
403 current: &Path,
404 matcher: Option<&glob::Pattern>,
405 symbol_lower: &str,
406 hits: &mut Vec<SearchCodeHit>,
407) -> Result<(), ToolError> {
408 if should_skip_path(current) {
409 return Ok(());
410 }
411
412 let entries = std::fs::read_dir(current).map_err(ToolError::Execution)?;
413 for entry in entries {
414 let entry = entry.map_err(ToolError::Execution)?;
415 let path = entry.path();
416 if path.is_dir() {
417 collect_structural_hits(root, &path, matcher, symbol_lower, hits)?;
418 continue;
419 }
420 if !matches_pattern(root, &path, matcher) {
421 continue;
422 }
423 let Some(lang) = detect_language(&path) else {
424 continue;
425 };
426 let Some(grammar) = lang.grammar() else {
427 continue;
428 };
429 let Some(query) = lang.symbol_query() else {
430 continue;
431 };
432 let Ok(source) = std::fs::read_to_string(&path) else {
433 continue;
434 };
435 let mut parser = Parser::new();
436 if parser.set_language(&grammar).is_err() {
437 continue;
438 }
439 let Some(tree) = parser.parse(&source, None) else {
440 continue;
441 };
442 let mut cursor = QueryCursor::new();
443 let capture_names = query.capture_names();
444 let def_idx = capture_names.iter().position(|name| *name == "def");
445 let name_idx = capture_names.iter().position(|name| *name == "name");
446 let (Some(def_idx), Some(name_idx)) = (def_idx, name_idx) else {
447 continue;
448 };
449
450 let mut query_matches = cursor.matches(query, tree.root_node(), source.as_bytes());
451 while let Some(match_) = query_matches.next() {
452 let mut def_node = None;
453 let mut name = None;
454 for capture in match_.captures {
455 if capture.index as usize == def_idx {
456 def_node = Some(capture.node);
457 }
458 if capture.index as usize == name_idx {
459 name = Some(source[capture.node.byte_range()].to_string());
460 }
461 }
462 let Some(name) = name else {
463 continue;
464 };
465 if !name.to_lowercase().contains(symbol_lower) {
466 continue;
467 }
468 let Some(def_node) = def_node else {
469 continue;
470 };
471 hits.push(SearchCodeHit {
472 file_path: canonical_string(&path),
473 line_start: def_node.start_position().row + 1,
474 line_end: def_node.end_position().row + 1,
475 snippet: extract_snippet(&source, def_node.start_position().row + 1),
476 source: SearchCodeSource::Structural,
477 score: SearchCodeSource::Structural.default_score(),
478 symbol_name: Some(name),
479 });
480 }
481 }
482 Ok(())
483}
484
485fn collect_grep_hits(
486 root: &Path,
487 current: &Path,
488 matcher: Option<&glob::Pattern>,
489 regex: ®ex::Regex,
490 hits: &mut Vec<SearchCodeHit>,
491 max_results: usize,
492) -> Result<(), ToolError> {
493 if hits.len() >= max_results || should_skip_path(current) {
494 return Ok(());
495 }
496
497 let entries = std::fs::read_dir(current).map_err(ToolError::Execution)?;
498 for entry in entries {
499 let entry = entry.map_err(ToolError::Execution)?;
500 let path = entry.path();
501 if path.is_dir() {
502 collect_grep_hits(root, &path, matcher, regex, hits, max_results)?;
503 continue;
504 }
505 if !matches_pattern(root, &path, matcher) {
506 continue;
507 }
508 let Ok(source) = std::fs::read_to_string(&path) else {
509 continue;
510 };
511 for (idx, line) in source.lines().enumerate() {
512 if regex.is_match(line) {
513 hits.push(SearchCodeHit {
514 file_path: canonical_string(&path),
515 line_start: idx + 1,
516 line_end: idx + 1,
517 snippet: line.trim().to_string(),
518 source: SearchCodeSource::GrepFallback,
519 score: SearchCodeSource::GrepFallback.default_score(),
520 symbol_name: None,
521 });
522 if hits.len() >= max_results {
523 return Ok(());
524 }
525 }
526 }
527 }
528 Ok(())
529}
530
531fn matches_pattern(root: &Path, path: &Path, matcher: Option<&glob::Pattern>) -> bool {
532 let Some(matcher) = matcher else {
533 return true;
534 };
535 let relative = path.strip_prefix(root).unwrap_or(path);
536 matcher.matches_path(relative)
537}
538
/// Returns `true` when the final component of `path` names a directory that is
/// never searched: VCS metadata, build output, dependencies or `.zeph` state.
fn should_skip_path(path: &Path) -> bool {
    const SKIPPED_NAMES: [&str; 4] = [".git", "target", "node_modules", ".zeph"];
    let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
        return false;
    };
    SKIPPED_NAMES.contains(&name)
}
544
/// Canonical (absolute, symlink-resolved) form of `path` as a display string.
/// Falls back to `path` unchanged when it cannot be canonicalized, e.g. when
/// it does not exist.
fn canonical_string(path: &Path) -> String {
    match path.canonicalize() {
        Ok(resolved) => resolved.display().to_string(),
        Err(_) => path.display().to_string(),
    }
}
551
/// Returns the trimmed text of the 1-based `line_number` in `source`.
///
/// Line 0 is treated as line 1. A line past the end yields an empty string.
fn extract_snippet(source: &str, line_number: usize) -> String {
    let index = line_number.saturating_sub(1);
    match source.lines().nth(index) {
        Some(line) => line.trim().to_owned(),
        None => String::new(),
    }
}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Semantic backend stub that never finds anything.
    struct EmptySemantic;

    #[async_trait]
    impl SemanticSearchBackend for EmptySemantic {
        async fn search(
            &self,
            _query: &str,
            _file_pattern: Option<&str>,
            _max_results: usize,
        ) -> Result<Vec<SearchCodeHit>, ToolError> {
            Ok(Vec::new())
        }
    }

    /// Builds a `search_code` tool call whose params are the given JSON object.
    fn search_call(params: serde_json::Value) -> ToolCall {
        let serde_json::Value::Object(params) = params else {
            panic!("search_code params must be a JSON object");
        };
        ToolCall {
            tool_id: "search_code".into(),
            params,
        }
    }

    #[tokio::test]
    async fn search_code_requires_query_or_symbol() {
        let dir = tempfile::tempdir().unwrap();
        let exec = SearchCodeExecutor::new(vec![dir.path().to_path_buf()]);
        let result = exec
            .execute_tool_call(&search_call(serde_json::json!({})))
            .await;
        assert!(matches!(result, Err(ToolError::InvalidParams { .. })));
    }

    #[tokio::test]
    async fn search_code_finds_structural_symbol() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("lib.rs"),
            "pub fn retry_backoff_ms() -> u64 { 0 }\n",
        )
        .unwrap();
        let exec = SearchCodeExecutor::new(vec![dir.path().to_path_buf()]);
        let out = exec
            .execute_tool_call(&search_call(
                serde_json::json!({ "symbol": "retry_backoff_ms" }),
            ))
            .await
            .unwrap()
            .expect("search_code must produce output");
        assert_eq!(out.tool_name, "search_code");
        assert!(out.summary.contains("retry_backoff_ms"));
        assert!(out.summary.contains("tree-sitter"));
    }

    #[tokio::test]
    async fn search_code_uses_grep_fallback() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("mod.rs"), "let retry_backoff_ms = 5;\n").unwrap();
        let exec = SearchCodeExecutor::new(vec![dir.path().to_path_buf()]);
        let out = exec
            .execute_tool_call(&search_call(
                serde_json::json!({ "query": "retry_backoff_ms" }),
            ))
            .await
            .unwrap()
            .expect("search_code must produce output");
        assert!(out.summary.contains("grep fallback"));
    }

    #[test]
    fn tool_definitions_include_search_code() {
        let backend: std::sync::Arc<dyn SemanticSearchBackend> =
            std::sync::Arc::new(EmptySemantic);
        let defs = SearchCodeExecutor::new(vec![])
            .with_semantic_backend(backend)
            .tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id.as_ref(), "search_code");
    }

    #[test]
    fn format_hits_strips_root_prefix() {
        let hit = SearchCodeHit {
            file_path: "/tmp/myproject/crates/foo/src/lib.rs".to_owned(),
            line_start: 10,
            line_end: 15,
            snippet: "pub fn example() {}".to_owned(),
            source: SearchCodeSource::GrepFallback,
            score: 0.45,
            symbol_name: None,
        };
        let rendered = format_hits(&[hit], Path::new("/tmp/myproject"));
        assert!(
            rendered.contains("crates/foo/src/lib.rs"),
            "expected relative path in output, got: {rendered}"
        );
        assert!(
            !rendered.contains("/tmp/myproject"),
            "absolute path must not appear in output, got: {rendered}"
        );
    }
}