spec_ai_core/tools/builtin/
code_search.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::path::{Path, PathBuf};
7use toak_rs::{JsonDatabaseGenerator, JsonDatabaseOptions, SemanticSearch};
8
9const DEFAULT_TOP_N: usize = 3;
10const MAX_TOP_N: usize = 25;
11
12#[derive(Debug, Deserialize)]
13struct CodeSearchArgs {
14    query: String,
15    top_n: Option<usize>,
16    root: Option<String>,
17    refresh: Option<bool>,
18}
19
20#[derive(Debug, Serialize)]
21struct CodeSearchResult {
22    path: String,
23    similarity: f32,
24    snippet: String,
25}
26
27#[derive(Debug, Serialize)]
28struct CodeSearchResponse {
29    query: String,
30    root: String,
31    top_n: usize,
32    results: Vec<CodeSearchResult>,
33}
34
/// Simple semantic code search powered by toak-rs embeddings.
///
/// The tool remembers a default search root; individual calls may override it
/// through the `root` argument.
#[derive(Debug, Clone)]
pub struct CodeSearchTool {
    /// Directory searched when a call does not supply its own `root`.
    root: PathBuf,
}
39
40impl CodeSearchTool {
41    pub fn new() -> Self {
42        let root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
43        Self { root }
44    }
45
46    fn resolve_root(&self, override_root: &Option<String>) -> PathBuf {
47        override_root
48            .as_ref()
49            .map(PathBuf::from)
50            .unwrap_or_else(|| self.root.clone())
51    }
52
53    fn cache_path(root: &Path) -> PathBuf {
54        root.join(".spec-ai").join("code_search_embeddings.json")
55    }
56
57    async fn ensure_embeddings(&self, root: &Path, refresh: bool, top_n: usize) -> Result<PathBuf> {
58        let embeddings_path = Self::cache_path(root);
59        if embeddings_path.exists() && !refresh {
60            return Ok(embeddings_path);
61        }
62
63        if let Some(parent) = embeddings_path.parent() {
64            std::fs::create_dir_all(parent).context("creating code-search cache dir")?;
65        }
66
67        let options = JsonDatabaseOptions {
68            dir: root.to_path_buf(),
69            output_file_path: embeddings_path.clone(),
70            verbose: false,
71            chunker_config: Default::default(),
72            max_concurrent_files: 4,
73            file_type_exclusions: Default::default(),
74            file_exclusions: Default::default(),
75        };
76
77        let generator = JsonDatabaseGenerator::new(options)
78            .context("initializing toak embeddings generator")?;
79
80        generator
81            .generate_database()
82            .await
83            .with_context(|| format!("generating embeddings database (top_n={})", top_n))?;
84
85        Ok(embeddings_path)
86    }
87}
88
89impl Default for CodeSearchTool {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95#[async_trait]
96impl Tool for CodeSearchTool {
97    fn name(&self) -> &str {
98        "code_search"
99    }
100
101    fn description(&self) -> &str {
102        "Semantic code search using toak-rs embeddings"
103    }
104
105    fn parameters(&self) -> Value {
106        serde_json::json!({
107            "type": "object",
108            "properties": {
109                "query": {
110                    "type": "string",
111                    "description": "Search query text"
112                },
113                "top_n": {
114                    "type": "integer",
115                    "description": "Number of results to return (default 3, max 25)"
116                },
117                "root": {
118                    "type": "string",
119                    "description": "Repository root to search (defaults to current dir)"
120                },
121                "refresh": {
122                    "type": "boolean",
123                    "description": "Force re-generation of embeddings (default false)"
124                }
125            },
126            "required": ["query"]
127        })
128    }
129
130    async fn execute(&self, args: Value) -> Result<ToolResult> {
131        let args: CodeSearchArgs =
132            serde_json::from_value(args).context("Failed to parse code_search arguments")?;
133
134        if args.query.trim().is_empty() {
135            return Err(anyhow!("query cannot be empty"));
136        }
137
138        let top_n = args.top_n.unwrap_or(DEFAULT_TOP_N).clamp(1, MAX_TOP_N);
139
140        let root = self.resolve_root(&args.root);
141        if !root.exists() {
142            return Err(anyhow!("Search root {} does not exist", root.display()));
143        }
144
145        let refresh = args.refresh.unwrap_or(false);
146        let embeddings_path = self
147            .ensure_embeddings(&root, refresh, top_n)
148            .await
149            .context("building embeddings database")?;
150
151        let mut searcher =
152            SemanticSearch::new(&embeddings_path).context("loading embeddings database")?;
153        let hits = searcher
154            .search(&args.query, top_n)
155            .context("running semantic search")?;
156
157        let results = hits
158            .into_iter()
159            .map(|hit| {
160                let mut snippet = hit.content;
161                if snippet.len() > 480 {
162                    snippet.truncate(480);
163                    snippet.push_str("...[truncated]");
164                }
165                CodeSearchResult {
166                    path: hit.file_path,
167                    similarity: hit.similarity,
168                    snippet,
169                }
170            })
171            .collect();
172
173        let response = CodeSearchResponse {
174            query: args.query,
175            root: root.display().to_string(),
176            top_n,
177            results,
178        };
179
180        Ok(ToolResult::success(
181            serde_json::to_string(&response).context("serializing search response")?,
182        ))
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::process::Command;
    use tempfile::tempdir;

    /// Runs `git <args>` inside `root` and fails the test if the command cannot be
    /// spawned or exits with a non-zero status.
    fn run_git(root: &std::path::Path, args: &[&str]) {
        let output = Command::new("git")
            .args(args)
            .current_dir(root)
            .output()
            .unwrap_or_else(|err| panic!("failed to spawn `git {}`: {}", args.join(" "), err));
        assert!(
            output.status.success(),
            "`git {}` failed: {}",
            args.join(" "),
            String::from_utf8_lossy(&output.stderr)
        );
    }

    /// End-to-end check: generates embeddings for a tiny repository and expects the
    /// file mentioning "beta" to rank first.
    #[tokio::test]
    async fn runs_search_with_generated_embeddings() {
        if std::env::var("RUN_TOAK_SEARCH_TEST").is_err() {
            // Skip in environments without fastembed model access
            return;
        }

        let dir = tempdir().unwrap();
        let root = dir.path();

        fs::write(root.join("a.rs"), "fn alpha() {}\n// comment\n").unwrap();
        fs::write(root.join("b.rs"), "fn beta_thing() { let x = 1; }\n").unwrap();

        // Initialize git so toak can discover files
        run_git(root, &["init"]);
        run_git(root, &["add", "."]);

        let tool = CodeSearchTool::new();
        let args = serde_json::json!({
            "query": "beta",
            "root": root.to_string_lossy(),
            "top_n": 2,
            "refresh": true
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        let hits = payload["results"].as_array().unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0]["path"].as_str().unwrap().contains("b.rs"));
    }
}
233}