Skip to main content

rustant_tools/
codebase_search.rs

//! Codebase search tool powered by the Project Context Auto-Indexer.
//!
//! Provides semantic search over the indexed project files, function signatures,
//! and content summaries. The workspace is indexed automatically on first use.
5
6use crate::registry::Tool;
7use async_trait::async_trait;
8use rustant_core::error::ToolError;
9use rustant_core::indexer::ProjectIndexer;
10use rustant_core::search::SearchConfig;
11use rustant_core::types::{RiskLevel, ToolOutput};
12use std::path::PathBuf;
13use std::sync::Mutex;
14
15/// Tool for searching the project codebase using hybrid search.
16pub struct CodebaseSearchTool {
17    indexer: Mutex<Option<ProjectIndexer>>,
18    workspace: PathBuf,
19}
20
21impl CodebaseSearchTool {
22    pub fn new(workspace: PathBuf) -> Self {
23        Self {
24            indexer: Mutex::new(None),
25            workspace,
26        }
27    }
28
29    /// Ensure the indexer is initialized and workspace is indexed.
30    fn ensure_indexed(&self) -> Result<(), ToolError> {
31        let mut guard = self
32            .indexer
33            .lock()
34            .map_err(|e| ToolError::ExecutionFailed {
35                name: "codebase_search".into(),
36                message: format!("Lock error: {}", e),
37            })?;
38
39        if guard.is_none() {
40            let search_config = SearchConfig {
41                index_path: self.workspace.join(".rustant/search_index"),
42                db_path: self.workspace.join(".rustant/vectors.db"),
43                ..Default::default()
44            };
45
46            let mut indexer =
47                ProjectIndexer::new(self.workspace.clone(), search_config).map_err(|e| {
48                    ToolError::ExecutionFailed {
49                        name: "codebase_search".into(),
50                        message: format!("Failed to initialize indexer: {}", e),
51                    }
52                })?;
53
54            indexer.index_workspace();
55            *guard = Some(indexer);
56        }
57
58        Ok(())
59    }
60}
61
62#[async_trait]
63impl Tool for CodebaseSearchTool {
64    fn name(&self) -> &str {
65        "codebase_search"
66    }
67
68    fn description(&self) -> &str {
69        "Search the project codebase using natural language queries. \
70         Finds relevant files, function signatures, and code content. \
71         The workspace is automatically indexed on first use."
72    }
73
74    fn parameters_schema(&self) -> serde_json::Value {
75        serde_json::json!({
76            "type": "object",
77            "properties": {
78                "query": {
79                    "type": "string",
80                    "description": "Natural language search query (e.g., 'authentication handler', \
81                        'database connection', 'error types')"
82                },
83                "max_results": {
84                    "type": "integer",
85                    "description": "Maximum number of results to return (default: 10)"
86                }
87            },
88            "required": ["query"]
89        })
90    }
91
92    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
93        let query = args["query"]
94            .as_str()
95            .ok_or_else(|| ToolError::InvalidArguments {
96                name: "codebase_search".into(),
97                reason: "'query' parameter is required".into(),
98            })?;
99
100        let max_results = args["max_results"].as_u64().unwrap_or(10) as usize;
101
102        // Ensure workspace is indexed (lazy initialization)
103        self.ensure_indexed()?;
104
105        let guard = self
106            .indexer
107            .lock()
108            .map_err(|e| ToolError::ExecutionFailed {
109                name: "codebase_search".into(),
110                message: format!("Lock error: {}", e),
111            })?;
112
113        let indexer = guard.as_ref().ok_or_else(|| ToolError::ExecutionFailed {
114            name: "codebase_search".into(),
115            message: "Indexer not initialized".into(),
116        })?;
117
118        let results = indexer
119            .search(query)
120            .map_err(|e| ToolError::ExecutionFailed {
121                name: "codebase_search".into(),
122                message: format!("Search failed: {}", e),
123            })?;
124
125        if results.is_empty() {
126            return Ok(ToolOutput::text(format!(
127                "No results found for query: '{}'",
128                query
129            )));
130        }
131
132        let mut output = format!(
133            "Found {} results for '{}':\n\n",
134            results.len().min(max_results),
135            query
136        );
137
138        for (i, result) in results.iter().take(max_results).enumerate() {
139            output.push_str(&format!(
140                "{}. [score: {:.2}] {}\n",
141                i + 1,
142                result.combined_score,
143                result.content.lines().next().unwrap_or(&result.content)
144            ));
145
146            // Show a bit more context for top results
147            if i < 3 {
148                let extra_lines: Vec<&str> = result.content.lines().skip(1).take(3).collect();
149                if !extra_lines.is_empty() {
150                    for line in extra_lines {
151                        output.push_str(&format!("   {}\n", line));
152                    }
153                }
154            }
155            output.push('\n');
156        }
157
158        Ok(ToolOutput::text(output))
159    }
160
161    fn risk_level(&self) -> RiskLevel {
162        RiskLevel::ReadOnly
163    }
164
165    fn timeout(&self) -> std::time::Duration {
166        // Indexing can take a while on first run
167        std::time::Duration::from_secs(120)
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates a throwaway workspace holding a tiny Rust crate.
    ///
    /// The `TempDir` is returned together with its path so that the caller
    /// keeps the directory alive for the whole test.
    fn setup_workspace() -> (TempDir, PathBuf) {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path().to_path_buf();

        fs::create_dir_all(root.join("src")).unwrap();

        let files = [
            (
                "src/main.rs",
                "fn main() {\n    run_server();\n}\n\nfn run_server() {\n    println!(\"starting\");\n}\n",
            ),
            (
                "src/auth.rs",
                "pub fn authenticate(token: &str) -> bool {\n    !token.is_empty()\n}\n",
            ),
            ("Cargo.toml", "[package]\nname = \"test\"\n"),
        ];
        for &(relative, body) in &files {
            fs::write(root.join(relative), body).unwrap();
        }

        (tmp, root)
    }

    #[tokio::test]
    async fn test_codebase_search_basic() {
        let (_workspace_guard, root) = setup_workspace();
        let tool = CodebaseSearchTool::new(root);

        let output = tool
            .execute(serde_json::json!({ "query": "authenticate" }))
            .await
            .unwrap();

        let text = &output.content;
        assert!(
            text.contains("authenticate") || text.contains("auth"),
            "Should find auth-related content: {}",
            text
        );
    }

    #[tokio::test]
    async fn test_codebase_search_no_results() {
        let (_workspace_guard, root) = setup_workspace();
        let tool = CodebaseSearchTool::new(root);

        let output = tool
            .execute(serde_json::json!({ "query": "zzz_nonexistent_xyz_999" }))
            .await
            .unwrap();

        let text = &output.content;
        assert!(text.contains("No results") || text.contains("Found"));
    }

    #[tokio::test]
    async fn test_codebase_search_missing_query() {
        let (_workspace_guard, root) = setup_workspace();
        let tool = CodebaseSearchTool::new(root);

        let outcome = tool.execute(serde_json::json!({})).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tool_properties() {
        let tmp = TempDir::new().unwrap();
        let tool = CodebaseSearchTool::new(tmp.path().to_path_buf());

        assert_eq!(tool.name(), "codebase_search");
        assert_eq!(tool.risk_level(), RiskLevel::ReadOnly);
        assert!(tool.description().contains("Search"));
    }
}