spec_ai_core/tools/builtin/
code_search.rs1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::path::{Path, PathBuf};
7use toak_rs::{JsonDatabaseGenerator, JsonDatabaseOptions, SemanticSearch};
8
/// Number of hits returned when the caller does not supply `top_n`.
const DEFAULT_TOP_N: usize = 3;
/// Largest `top_n` a caller may request; larger values are clamped down to this.
const MAX_TOP_N: usize = 25;
11
12#[derive(Debug, Deserialize)]
13struct CodeSearchArgs {
14 query: String,
15 top_n: Option<usize>,
16 root: Option<String>,
17 refresh: Option<bool>,
18}
19
20#[derive(Debug, Serialize)]
21struct CodeSearchResult {
22 path: String,
23 similarity: f32,
24 snippet: String,
25}
26
27#[derive(Debug, Serialize)]
28struct CodeSearchResponse {
29 query: String,
30 root: String,
31 top_n: usize,
32 results: Vec<CodeSearchResult>,
33}
34
/// Semantic code-search tool backed by a cached toak-rs embeddings database.
pub struct CodeSearchTool {
    /// Directory searched when a call does not pass its own `root`.
    root: PathBuf,
}
39
40impl CodeSearchTool {
41 pub fn new() -> Self {
42 let root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
43 Self { root }
44 }
45
46 fn resolve_root(&self, override_root: &Option<String>) -> PathBuf {
47 override_root
48 .as_ref()
49 .map(PathBuf::from)
50 .unwrap_or_else(|| self.root.clone())
51 }
52
53 fn cache_path(root: &Path) -> PathBuf {
54 root.join(".spec-ai").join("code_search_embeddings.json")
55 }
56
57 async fn ensure_embeddings(&self, root: &Path, refresh: bool, top_n: usize) -> Result<PathBuf> {
58 let embeddings_path = Self::cache_path(root);
59 if embeddings_path.exists() && !refresh {
60 return Ok(embeddings_path);
61 }
62
63 if let Some(parent) = embeddings_path.parent() {
64 std::fs::create_dir_all(parent).context("creating code-search cache dir")?;
65 }
66
67 let options = JsonDatabaseOptions {
68 dir: root.to_path_buf(),
69 output_file_path: embeddings_path.clone(),
70 verbose: false,
71 chunker_config: Default::default(),
72 max_concurrent_files: 4,
73 file_type_exclusions: Default::default(),
74 file_exclusions: Default::default(),
75 };
76
77 let generator = JsonDatabaseGenerator::new(options)
78 .context("initializing toak embeddings generator")?;
79
80 generator
81 .generate_database()
82 .await
83 .with_context(|| format!("generating embeddings database (top_n={})", top_n))?;
84
85 Ok(embeddings_path)
86 }
87}
88
89impl Default for CodeSearchTool {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95#[async_trait]
96impl Tool for CodeSearchTool {
97 fn name(&self) -> &str {
98 "code_search"
99 }
100
101 fn description(&self) -> &str {
102 "Semantic code search using toak-rs embeddings"
103 }
104
105 fn parameters(&self) -> Value {
106 serde_json::json!({
107 "type": "object",
108 "properties": {
109 "query": {
110 "type": "string",
111 "description": "Search query text"
112 },
113 "top_n": {
114 "type": "integer",
115 "description": "Number of results to return (default 3, max 25)"
116 },
117 "root": {
118 "type": "string",
119 "description": "Repository root to search (defaults to current dir)"
120 },
121 "refresh": {
122 "type": "boolean",
123 "description": "Force re-generation of embeddings (default false)"
124 }
125 },
126 "required": ["query"]
127 })
128 }
129
130 async fn execute(&self, args: Value) -> Result<ToolResult> {
131 let args: CodeSearchArgs =
132 serde_json::from_value(args).context("Failed to parse code_search arguments")?;
133
134 if args.query.trim().is_empty() {
135 return Err(anyhow!("query cannot be empty"));
136 }
137
138 let top_n = args.top_n.unwrap_or(DEFAULT_TOP_N).clamp(1, MAX_TOP_N);
139
140 let root = self.resolve_root(&args.root);
141 if !root.exists() {
142 return Err(anyhow!("Search root {} does not exist", root.display()));
143 }
144
145 let refresh = args.refresh.unwrap_or(false);
146 let embeddings_path = self
147 .ensure_embeddings(&root, refresh, top_n)
148 .await
149 .context("building embeddings database")?;
150
151 let mut searcher =
152 SemanticSearch::new(&embeddings_path).context("loading embeddings database")?;
153 let hits = searcher
154 .search(&args.query, top_n)
155 .context("running semantic search")?;
156
157 let results = hits
158 .into_iter()
159 .map(|hit| {
160 let mut snippet = hit.content;
161 if snippet.len() > 480 {
162 snippet.truncate(480);
163 snippet.push_str("...[truncated]");
164 }
165 CodeSearchResult {
166 path: hit.file_path,
167 similarity: hit.similarity,
168 snippet,
169 }
170 })
171 .collect();
172
173 let response = CodeSearchResponse {
174 query: args.query,
175 root: root.display().to_string(),
176 top_n,
177 results,
178 };
179
180 Ok(ToolResult::success(
181 serde_json::to_string(&response).context("serializing search response")?,
182 ))
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::process::Command;
    use tempfile::tempdir;

    /// Runs `git <args>` inside `dir` and fails the test unless git exits successfully.
    ///
    /// `Command::output` only errors when git cannot be spawned; a non-zero
    /// exit status has to be checked separately, which the original calls did not do.
    fn git(dir: &std::path::Path, args: &[&str]) {
        let output = Command::new("git")
            .args(args)
            .current_dir(dir)
            .output()
            .unwrap_or_else(|err| panic!("failed to run git {:?}: {}", args, err));
        assert!(
            output.status.success(),
            "git {:?} exited with {}: {}",
            args,
            output.status,
            String::from_utf8_lossy(&output.stderr)
        );
    }

    /// End-to-end check: build embeddings for a tiny two-file repository and
    /// expect the query "beta" to rank `b.rs` first.
    ///
    /// Opt-in via the `RUN_TOAK_SEARCH_TEST` environment variable, because it
    /// generates real embeddings. Without the variable the test returns
    /// immediately and passes.
    #[tokio::test]
    async fn runs_search_with_generated_embeddings() {
        if std::env::var("RUN_TOAK_SEARCH_TEST").is_err() {
            return;
        }

        let dir = tempdir().unwrap();
        let root = dir.path();

        fs::write(root.join("a.rs"), "fn alpha() {}\n// comment\n").unwrap();
        fs::write(root.join("b.rs"), "fn beta_thing() { let x = 1; }\n").unwrap();

        // NOTE(review): the files are staged in a git repo, presumably because
        // toak discovers files through git; confirm this setup is required.
        git(root, &["init"]);
        git(root, &["add", "."]);

        let tool = CodeSearchTool::new();
        let args = serde_json::json!({
            "query": "beta",
            "root": root.to_string_lossy(),
            "top_n": 2,
            "refresh": true
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        let hits = payload["results"].as_array().unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0]["path"].as_str().unwrap().contains("b.rs"));
    }
}