// spec_ai_core/tools/builtin/code_search.rs
use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::path::{Path, PathBuf};
7use toak_rs::{JsonDatabaseGenerator, JsonDatabaseOptions, SemanticSearch};
8
/// Number of results returned when the caller does not supply `top_n`.
const DEFAULT_TOP_N: usize = 3;
/// Largest `top_n` honoured; larger requests are clamped down to this value.
const MAX_TOP_N: usize = 25;
11
12#[derive(Debug, Deserialize)]
13struct CodeSearchArgs {
14 query: String,
15 top_n: Option<usize>,
16 root: Option<String>,
17 refresh: Option<bool>,
18}
19
20#[derive(Debug, Serialize)]
21struct CodeSearchResult {
22 path: String,
23 similarity: f32,
24 snippet: String,
25}
26
27#[derive(Debug, Serialize)]
28struct CodeSearchResponse {
29 query: String,
30 root: String,
31 top_n: usize,
32 results: Vec<CodeSearchResult>,
33}
34
/// Semantic code-search tool.
///
/// Holds the directory that is searched when a call does not supply its own `root`.
#[derive(Debug, Clone)]
pub struct CodeSearchTool {
    /// Default search root, captured when the tool is constructed.
    root: PathBuf,
}
39
40impl CodeSearchTool {
41 pub fn new() -> Self {
42 let root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
43 Self { root }
44 }
45
46 fn resolve_root(&self, override_root: &Option<String>) -> PathBuf {
47 override_root
48 .as_ref()
49 .map(PathBuf::from)
50 .unwrap_or_else(|| self.root.clone())
51 }
52
53 fn cache_path(root: &Path) -> PathBuf {
54 root.join(".spec-ai").join("code_search_embeddings.json")
55 }
56
57 async fn ensure_embeddings(&self, root: &Path, refresh: bool, top_n: usize) -> Result<PathBuf> {
58 let embeddings_path = Self::cache_path(root);
59 if embeddings_path.exists() && !refresh {
60 return Ok(embeddings_path);
61 }
62
63 if let Some(parent) = embeddings_path.parent() {
64 std::fs::create_dir_all(parent).context("creating code-search cache dir")?;
65 }
66
67 let options = JsonDatabaseOptions {
68 dir: root.to_path_buf(),
69 output_file_path: embeddings_path.clone(),
70 verbose: false,
71 chunker_config: Default::default(),
72 max_concurrent_files: 4,
73 ..Default::default()
74 };
75
76 let generator = JsonDatabaseGenerator::new(options)
77 .context("initializing toak embeddings generator")?;
78
79 generator
80 .generate_database()
81 .await
82 .with_context(|| format!("generating embeddings database (top_n={})", top_n))?;
83
84 Ok(embeddings_path)
85 }
86}
87
88impl Default for CodeSearchTool {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94#[async_trait]
95impl Tool for CodeSearchTool {
96 fn name(&self) -> &str {
97 "code_search"
98 }
99
100 fn description(&self) -> &str {
101 "Semantic code search using toak-rs embeddings"
102 }
103
104 fn parameters(&self) -> Value {
105 serde_json::json!({
106 "type": "object",
107 "properties": {
108 "query": {
109 "type": "string",
110 "description": "Search query text"
111 },
112 "top_n": {
113 "type": "integer",
114 "description": "Number of results to return (default 3, max 25)"
115 },
116 "root": {
117 "type": "string",
118 "description": "Repository root to search (defaults to current dir)"
119 },
120 "refresh": {
121 "type": "boolean",
122 "description": "Force re-generation of embeddings (default false)"
123 }
124 },
125 "required": ["query"]
126 })
127 }
128
129 async fn execute(&self, args: Value) -> Result<ToolResult> {
130 let args: CodeSearchArgs =
131 serde_json::from_value(args).context("Failed to parse code_search arguments")?;
132
133 if args.query.trim().is_empty() {
134 return Err(anyhow!("query cannot be empty"));
135 }
136
137 let top_n = args.top_n.unwrap_or(DEFAULT_TOP_N).clamp(1, MAX_TOP_N);
138
139 let root = self.resolve_root(&args.root);
140 if !root.exists() {
141 return Err(anyhow!("Search root {} does not exist", root.display()));
142 }
143
144 let refresh = args.refresh.unwrap_or(false);
145 let embeddings_path = self
146 .ensure_embeddings(&root, refresh, top_n)
147 .await
148 .context("building embeddings database")?;
149
150 let mut searcher =
151 SemanticSearch::new(&embeddings_path).context("loading embeddings database")?;
152 let hits = searcher
153 .search(&args.query, top_n)
154 .context("running semantic search")?;
155
156 let results = hits
157 .into_iter()
158 .map(|hit| {
159 let mut snippet = hit.content;
160 if snippet.len() > 480 {
161 snippet.truncate(480);
162 snippet.push_str("...[truncated]");
163 }
164 CodeSearchResult {
165 path: hit.file_path,
166 similarity: hit.similarity,
167 snippet,
168 }
169 })
170 .collect();
171
172 let response = CodeSearchResponse {
173 query: args.query,
174 root: root.display().to_string(),
175 top_n,
176 results,
177 };
178
179 Ok(ToolResult::success(
180 serde_json::to_string(&response).context("serializing search response")?,
181 ))
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::process::Command;
    use tempfile::tempdir;

    /// Runs `git <args>` inside `dir`.
    ///
    /// Panics when the process cannot be spawned or exits with a non-zero status, so a
    /// broken fixture is reported here rather than as a confusing search failure later.
    fn run_git(dir: &std::path::Path, args: &[&str]) {
        let output = Command::new("git")
            .args(args)
            .current_dir(dir)
            .output()
            .unwrap_or_else(|err| panic!("git {} failed: {}", args.join(" "), err));
        assert!(
            output.status.success(),
            "git {} exited with {}: {}",
            args.join(" "),
            output.status,
            String::from_utf8_lossy(&output.stderr)
        );
    }

    /// End-to-end search over a freshly generated embeddings database.
    ///
    /// Opt-in: does nothing unless `RUN_TOAK_SEARCH_TEST` is set in the environment.
    #[tokio::test]
    async fn runs_search_with_generated_embeddings() {
        if std::env::var("RUN_TOAK_SEARCH_TEST").is_err() {
            return;
        }

        let dir = tempdir().unwrap();
        let root = dir.path();

        fs::write(root.join("a.rs"), "fn alpha() {}\n// comment\n").unwrap();
        fs::write(root.join("b.rs"), "fn beta_thing() { let x = 1; }\n").unwrap();

        // NOTE(review): presumably the generator discovers files through git, which is
        // why the fixture is turned into a repository with staged files; confirm.
        run_git(root, &["init"]);
        run_git(root, &["add", "."]);

        let tool = CodeSearchTool::new();
        let args = serde_json::json!({
            "query": "beta",
            "root": root.to_string_lossy(),
            "top_n": 2,
            "refresh": true
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        let hits = payload["results"].as_array().unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0]["path"].as_str().unwrap().contains("b.rs"));
    }
}