spec_ai_core/tools/builtin/
code_search.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::path::{Path, PathBuf};
7use toak_rs::{JsonDatabaseGenerator, JsonDatabaseOptions, SemanticSearch};
8
/// Number of results returned when the caller does not supply `top_n`.
const DEFAULT_TOP_N: usize = 3;
/// Upper bound applied to a caller-supplied `top_n`.
const MAX_TOP_N: usize = 25;
11
12#[derive(Debug, Deserialize)]
13struct CodeSearchArgs {
14    query: String,
15    top_n: Option<usize>,
16    root: Option<String>,
17    refresh: Option<bool>,
18}
19
20#[derive(Debug, Serialize)]
21struct CodeSearchResult {
22    path: String,
23    similarity: f32,
24    snippet: String,
25}
26
27#[derive(Debug, Serialize)]
28struct CodeSearchResponse {
29    query: String,
30    root: String,
31    top_n: usize,
32    results: Vec<CodeSearchResult>,
33}
34
/// Semantic code search tool backed by toak-rs embeddings.
pub struct CodeSearchTool {
    /// Directory searched when a call does not provide its own `root`.
    root: PathBuf,
}
39
40impl CodeSearchTool {
41    pub fn new() -> Self {
42        let root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
43        Self { root }
44    }
45
46    fn resolve_root(&self, override_root: &Option<String>) -> PathBuf {
47        override_root
48            .as_ref()
49            .map(PathBuf::from)
50            .unwrap_or_else(|| self.root.clone())
51    }
52
53    fn cache_path(root: &Path) -> PathBuf {
54        root.join(".spec-ai").join("code_search_embeddings.json")
55    }
56
57    async fn ensure_embeddings(&self, root: &Path, refresh: bool, top_n: usize) -> Result<PathBuf> {
58        let embeddings_path = Self::cache_path(root);
59        if embeddings_path.exists() && !refresh {
60            return Ok(embeddings_path);
61        }
62
63        if let Some(parent) = embeddings_path.parent() {
64            std::fs::create_dir_all(parent).context("creating code-search cache dir")?;
65        }
66
67        let options = JsonDatabaseOptions {
68            dir: root.to_path_buf(),
69            output_file_path: embeddings_path.clone(),
70            verbose: false,
71            chunker_config: Default::default(),
72            max_concurrent_files: 4,
73            ..Default::default()
74        };
75
76        let generator = JsonDatabaseGenerator::new(options)
77            .context("initializing toak embeddings generator")?;
78
79        generator
80            .generate_database()
81            .await
82            .with_context(|| format!("generating embeddings database (top_n={})", top_n))?;
83
84        Ok(embeddings_path)
85    }
86}
87
88impl Default for CodeSearchTool {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94#[async_trait]
95impl Tool for CodeSearchTool {
96    fn name(&self) -> &str {
97        "code_search"
98    }
99
100    fn description(&self) -> &str {
101        "Semantic code search using toak-rs embeddings"
102    }
103
104    fn parameters(&self) -> Value {
105        serde_json::json!({
106            "type": "object",
107            "properties": {
108                "query": {
109                    "type": "string",
110                    "description": "Search query text"
111                },
112                "top_n": {
113                    "type": "integer",
114                    "description": "Number of results to return (default 3, max 25)"
115                },
116                "root": {
117                    "type": "string",
118                    "description": "Repository root to search (defaults to current dir)"
119                },
120                "refresh": {
121                    "type": "boolean",
122                    "description": "Force re-generation of embeddings (default false)"
123                }
124            },
125            "required": ["query"]
126        })
127    }
128
129    async fn execute(&self, args: Value) -> Result<ToolResult> {
130        let args: CodeSearchArgs =
131            serde_json::from_value(args).context("Failed to parse code_search arguments")?;
132
133        if args.query.trim().is_empty() {
134            return Err(anyhow!("query cannot be empty"));
135        }
136
137        let top_n = args.top_n.unwrap_or(DEFAULT_TOP_N).clamp(1, MAX_TOP_N);
138
139        let root = self.resolve_root(&args.root);
140        if !root.exists() {
141            return Err(anyhow!("Search root {} does not exist", root.display()));
142        }
143
144        let refresh = args.refresh.unwrap_or(false);
145        let embeddings_path = self
146            .ensure_embeddings(&root, refresh, top_n)
147            .await
148            .context("building embeddings database")?;
149
150        let mut searcher =
151            SemanticSearch::new(&embeddings_path).context("loading embeddings database")?;
152        let hits = searcher
153            .search(&args.query, top_n)
154            .context("running semantic search")?;
155
156        let results = hits
157            .into_iter()
158            .map(|hit| {
159                let mut snippet = hit.content;
160                if snippet.len() > 480 {
161                    snippet.truncate(480);
162                    snippet.push_str("...[truncated]");
163                }
164                CodeSearchResult {
165                    path: hit.file_path,
166                    similarity: hit.similarity,
167                    snippet,
168                }
169            })
170            .collect();
171
172        let response = CodeSearchResponse {
173            query: args.query,
174            root: root.display().to_string(),
175            top_n,
176            results,
177        };
178
179        Ok(ToolResult::success(
180            serde_json::to_string(&response).context("serializing search response")?,
181        ))
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::process::Command;
    use tempfile::tempdir;

    /// Runs `git` with `args` inside `dir`, panicking (with git's stderr) if
    /// the command cannot be spawned or exits unsuccessfully.
    fn run_git(dir: &std::path::Path, args: &[&str]) {
        let output = Command::new("git")
            .args(args)
            .current_dir(dir)
            .output()
            .unwrap_or_else(|err| panic!("failed to spawn git {:?}: {}", args, err));
        assert!(
            output.status.success(),
            "git {:?} failed: {}",
            args,
            String::from_utf8_lossy(&output.stderr)
        );
    }

    /// End-to-end check: generates embeddings for a tiny repository and
    /// expects the file mentioning "beta" to be the top hit.
    ///
    /// Only runs when `RUN_TOAK_SEARCH_TEST` is set.
    #[tokio::test]
    async fn runs_search_with_generated_embeddings() {
        if std::env::var("RUN_TOAK_SEARCH_TEST").is_err() {
            // Skip in environments without fastembed model access
            return;
        }

        let dir = tempdir().unwrap();
        let root = dir.path();

        fs::write(root.join("a.rs"), "fn alpha() {}\n// comment\n").unwrap();
        fs::write(root.join("b.rs"), "fn beta_thing() { let x = 1; }\n").unwrap();

        // Initialize git so toak can discover files. Exit statuses are
        // checked so a git failure is reported here rather than as a
        // confusing search failure below.
        run_git(root, &["init"]);
        run_git(root, &["add", "."]);

        let tool = CodeSearchTool::new();
        let args = serde_json::json!({
            "query": "beta",
            "root": root.to_string_lossy(),
            "top_n": 2,
            "refresh": true
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        let hits = payload["results"].as_array().unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0]["path"].as_str().unwrap().contains("b.rs"));
    }
}