rustant_tools/
use crate::registry::Tool;
use async_trait::async_trait;
use rustant_core::error::ToolError;
use rustant_core::indexer::ProjectIndexer;
use rustant_core::search::SearchConfig;
use rustant_core::types::{RiskLevel, ToolOutput};
use std::fmt::Write as _;
use std::path::PathBuf;
use std::sync::Mutex;
14
15pub struct CodebaseSearchTool {
17 indexer: Mutex<Option<ProjectIndexer>>,
18 workspace: PathBuf,
19}
20
21impl CodebaseSearchTool {
22 pub fn new(workspace: PathBuf) -> Self {
23 Self {
24 indexer: Mutex::new(None),
25 workspace,
26 }
27 }
28
29 fn ensure_indexed(&self) -> Result<(), ToolError> {
31 let mut guard = self
32 .indexer
33 .lock()
34 .map_err(|e| ToolError::ExecutionFailed {
35 name: "codebase_search".into(),
36 message: format!("Lock error: {}", e),
37 })?;
38
39 if guard.is_none() {
40 let search_config = SearchConfig {
41 index_path: self.workspace.join(".rustant/search_index"),
42 db_path: self.workspace.join(".rustant/vectors.db"),
43 ..Default::default()
44 };
45
46 let mut indexer =
47 ProjectIndexer::new(self.workspace.clone(), search_config).map_err(|e| {
48 ToolError::ExecutionFailed {
49 name: "codebase_search".into(),
50 message: format!("Failed to initialize indexer: {}", e),
51 }
52 })?;
53
54 indexer.index_workspace();
55 *guard = Some(indexer);
56 }
57
58 Ok(())
59 }
60}
61
62#[async_trait]
63impl Tool for CodebaseSearchTool {
64 fn name(&self) -> &str {
65 "codebase_search"
66 }
67
68 fn description(&self) -> &str {
69 "Search the project codebase using natural language queries. \
70 Finds relevant files, function signatures, and code content. \
71 The workspace is automatically indexed on first use."
72 }
73
74 fn parameters_schema(&self) -> serde_json::Value {
75 serde_json::json!({
76 "type": "object",
77 "properties": {
78 "query": {
79 "type": "string",
80 "description": "Natural language search query (e.g., 'authentication handler', \
81 'database connection', 'error types')"
82 },
83 "max_results": {
84 "type": "integer",
85 "description": "Maximum number of results to return (default: 10)"
86 }
87 },
88 "required": ["query"]
89 })
90 }
91
92 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
93 let query = args["query"]
94 .as_str()
95 .ok_or_else(|| ToolError::InvalidArguments {
96 name: "codebase_search".into(),
97 reason: "'query' parameter is required".into(),
98 })?;
99
100 let max_results = args["max_results"].as_u64().unwrap_or(10) as usize;
101
102 self.ensure_indexed()?;
104
105 let guard = self
106 .indexer
107 .lock()
108 .map_err(|e| ToolError::ExecutionFailed {
109 name: "codebase_search".into(),
110 message: format!("Lock error: {}", e),
111 })?;
112
113 let indexer = guard.as_ref().ok_or_else(|| ToolError::ExecutionFailed {
114 name: "codebase_search".into(),
115 message: "Indexer not initialized".into(),
116 })?;
117
118 let results = indexer
119 .search(query)
120 .map_err(|e| ToolError::ExecutionFailed {
121 name: "codebase_search".into(),
122 message: format!("Search failed: {}", e),
123 })?;
124
125 if results.is_empty() {
126 return Ok(ToolOutput::text(format!(
127 "No results found for query: '{}'",
128 query
129 )));
130 }
131
132 let mut output = format!(
133 "Found {} results for '{}':\n\n",
134 results.len().min(max_results),
135 query
136 );
137
138 for (i, result) in results.iter().take(max_results).enumerate() {
139 output.push_str(&format!(
140 "{}. [score: {:.2}] {}\n",
141 i + 1,
142 result.combined_score,
143 result.content.lines().next().unwrap_or(&result.content)
144 ));
145
146 if i < 3 {
148 let extra_lines: Vec<&str> = result.content.lines().skip(1).take(3).collect();
149 if !extra_lines.is_empty() {
150 for line in extra_lines {
151 output.push_str(&format!(" {}\n", line));
152 }
153 }
154 }
155 output.push('\n');
156 }
157
158 Ok(ToolOutput::text(output))
159 }
160
161 fn risk_level(&self) -> RiskLevel {
162 RiskLevel::ReadOnly
163 }
164
165 fn timeout(&self) -> std::time::Duration {
166 std::time::Duration::from_secs(120)
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates a throwaway workspace holding a tiny Rust project.
    ///
    /// The `TempDir` is returned so the directory lives as long as the test.
    fn setup_workspace() -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().to_path_buf();

        fs::create_dir_all(path.join("src")).unwrap();
        fs::write(
            path.join("src/main.rs"),
            "fn main() {\n run_server();\n}\n\nfn run_server() {\n println!(\"starting\");\n}\n",
        )
        .unwrap();
        fs::write(
            path.join("src/auth.rs"),
            "pub fn authenticate(token: &str) -> bool {\n !token.is_empty()\n}\n",
        )
        .unwrap();
        fs::write(path.join("Cargo.toml"), "[package]\nname = \"test\"\n").unwrap();

        (dir, path)
    }

    #[tokio::test]
    async fn test_codebase_search_basic() {
        let (_dir, path) = setup_workspace();
        let tool = CodebaseSearchTool::new(path);

        let args = serde_json::json!({
            "query": "authenticate"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(
            result.content.contains("authenticate") || result.content.contains("auth"),
            "Should find auth-related content: {}",
            result.content
        );
    }

    #[tokio::test]
    async fn test_codebase_search_no_results() {
        let (_dir, path) = setup_workspace();
        let tool = CodebaseSearchTool::new(path);

        let args = serde_json::json!({
            "query": "zzz_nonexistent_xyz_999"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.content.contains("No results") || result.content.contains("Found"));
    }

    #[tokio::test]
    async fn test_codebase_search_missing_query() {
        let (_dir, path) = setup_workspace();
        let tool = CodebaseSearchTool::new(path);

        let args = serde_json::json!({});
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_codebase_search_non_string_query() {
        let (_dir, path) = setup_workspace();
        let tool = CodebaseSearchTool::new(path);

        let result = tool.execute(serde_json::json!({ "query": 42 })).await;
        assert!(result.is_err(), "a non-string query must be rejected");
    }

    #[tokio::test]
    async fn test_codebase_search_respects_max_results() {
        let (_dir, path) = setup_workspace();
        let tool = CodebaseSearchTool::new(path);

        let args = serde_json::json!({ "query": "fn", "max_results": 1 });
        let result = tool.execute(args).await.unwrap();

        assert!(
            result.content.contains("No results") || result.content.contains("Found 1 results"),
            "header should reflect the limit: {}",
            result.content
        );
        assert!(
            !result.content.contains("2. [score:"),
            "only one hit should be listed: {}",
            result.content
        );
    }

    #[tokio::test]
    async fn test_codebase_search_reuses_index() {
        let (_dir, path) = setup_workspace();
        let tool = CodebaseSearchTool::new(path);

        // The second call goes through the already-initialised indexer.
        let first = tool.execute(serde_json::json!({ "query": "server" })).await;
        let second = tool.execute(serde_json::json!({ "query": "server" })).await;
        assert!(first.is_ok());
        assert!(second.is_ok());
    }

    #[test]
    fn test_tool_properties() {
        let dir = TempDir::new().unwrap();
        let tool = CodebaseSearchTool::new(dir.path().to_path_buf());
        assert_eq!(tool.name(), "codebase_search");
        assert_eq!(tool.risk_level(), RiskLevel::ReadOnly);
        assert!(tool.description().contains("Search"));
        assert_eq!(tool.timeout(), std::time::Duration::from_secs(120));
    }

    #[test]
    fn test_parameters_schema() {
        let dir = TempDir::new().unwrap();
        let tool = CodebaseSearchTool::new(dir.path().to_path_buf());
        let schema = tool.parameters_schema();

        assert_eq!(schema["type"], "object");
        assert_eq!(schema["required"], serde_json::json!(["query"]));
        assert_eq!(schema["properties"]["query"]["type"], "string");
        assert_eq!(schema["properties"]["max_results"]["type"], "integer");
    }
}