vtcode_core/tools/
ast_grep_tool.rs1use super::ast_grep::AstGrepEngine;
7use super::traits::Tool;
8use crate::config::constants::tools;
9use anyhow::{Context, Result};
10use async_trait::async_trait;
11use serde_json::Value;
12use std::path::PathBuf;
13use std::sync::Arc;
14
15pub struct AstGrepTool {
17 engine: Arc<AstGrepEngine>,
19 workspace_root: PathBuf,
21}
22
23impl AstGrepTool {
24 pub fn new(workspace_root: PathBuf) -> Result<Self> {
26 let engine =
27 Arc::new(AstGrepEngine::new().context("Failed to initialize AST-grep engine")?);
28
29 Ok(Self {
30 engine,
31 workspace_root,
32 })
33 }
34
35 pub fn workspace_root(&self) -> &PathBuf {
37 &self.workspace_root
38 }
39
40 fn normalize_path(&self, path: &str) -> Result<String> {
42 let path_buf = PathBuf::from(path);
43
44 if path_buf.is_absolute() {
46 if !path_buf.starts_with(&self.workspace_root) {
47 return Err(anyhow::anyhow!(
48 "Path {} is outside workspace root {}",
49 path,
50 self.workspace_root.display()
51 ));
52 }
53 Ok(path.to_string())
54 } else {
55 let resolved = self.workspace_root.join(path);
57 Ok(resolved.to_string_lossy().to_string())
58 }
59 }
60}
61
62#[async_trait]
63impl Tool for AstGrepTool {
64 async fn execute(&self, args: Value) -> Result<Value> {
65 let operation = args
66 .get("operation")
67 .and_then(|v| v.as_str())
68 .unwrap_or("search");
69
70 match operation {
71 "search" => self.search(args).await,
72 "transform" => self.transform(args).await,
73 "lint" => self.lint(args).await,
74 "refactor" => self.refactor(args).await,
75 "custom" => self.custom(args).await,
76 _ => Err(anyhow::anyhow!("Unknown AST-grep operation: {}", operation)),
77 }
78 }
79
80 fn name(&self) -> &'static str {
81 tools::AST_GREP_SEARCH
82 }
83
84 fn description(&self) -> &'static str {
85 "Advanced syntax-aware code search, transformation, and analysis using AST-grep patterns"
86 }
87
88 fn validate_args(&self, args: &Value) -> Result<()> {
89 if let Some(operation) = args.get("operation").and_then(|v| v.as_str()) {
90 match operation {
91 "search" => {
92 if args.get("pattern").is_none() {
93 return Err(anyhow::anyhow!(
94 "'pattern' is required for search operation"
95 ));
96 }
97 if args.get("path").is_none() {
98 return Err(anyhow::anyhow!("'path' is required for search operation"));
99 }
100 }
101 "transform" => {
102 if args.get("pattern").is_none() {
103 return Err(anyhow::anyhow!(
104 "'pattern' is required for transform operation"
105 ));
106 }
107 if args.get("replacement").is_none() {
108 return Err(anyhow::anyhow!(
109 "'replacement' is required for transform operation"
110 ));
111 }
112 if args.get("path").is_none() {
113 return Err(anyhow::anyhow!(
114 "'path' is required for transform operation"
115 ));
116 }
117 }
118 "refactor" => {
119 if args.get("path").is_none() {
120 return Err(anyhow::anyhow!("'path' is required for refactor operation"));
121 }
122 if args.get("refactor_type").is_none() {
123 return Err(anyhow::anyhow!(
124 "'refactor_type' is required for refactor operation"
125 ));
126 }
127 }
128 _ => {} }
130 }
131
132 Ok(())
133 }
134}
135
136impl AstGrepTool {
137 async fn search(&self, args: Value) -> Result<Value> {
139 let pattern = args
140 .get("pattern")
141 .and_then(|v| v.as_str())
142 .context("'pattern' is required")?;
143
144 let path = args
145 .get("path")
146 .and_then(|v| v.as_str())
147 .context("'path' is required")?;
148
149 let path = self.normalize_path(path)?;
150
151 let language = args.get("language").and_then(|v| v.as_str());
152 let context_lines = args
153 .get("context_lines")
154 .and_then(|v| v.as_u64())
155 .map(|v| v as usize);
156 let max_results = args
157 .get("max_results")
158 .and_then(|v| v.as_u64())
159 .map(|v| v as usize);
160
161 self.engine
162 .search(pattern, &path, language, context_lines, max_results)
163 .await
164 }
165
166 async fn transform(&self, args: Value) -> Result<Value> {
168 let pattern = args
169 .get("pattern")
170 .and_then(|v| v.as_str())
171 .context("'pattern' is required")?;
172
173 let replacement = args
174 .get("replacement")
175 .and_then(|v| v.as_str())
176 .context("'replacement' is required")?;
177
178 let path = args
179 .get("path")
180 .and_then(|v| v.as_str())
181 .context("'path' is required")?;
182
183 let path = self.normalize_path(path)?;
184
185 let language = args.get("language").and_then(|v| v.as_str());
186 let preview_only = args
187 .get("preview_only")
188 .and_then(|v| v.as_bool())
189 .unwrap_or(true);
190 let update_all = args
191 .get("update_all")
192 .and_then(|v| v.as_bool())
193 .unwrap_or(false);
194
195 self.engine
196 .transform(
197 pattern,
198 replacement,
199 &path,
200 language,
201 preview_only,
202 update_all,
203 )
204 .await
205 }
206
207 async fn lint(&self, args: Value) -> Result<Value> {
209 let path = args
210 .get("path")
211 .and_then(|v| v.as_str())
212 .context("'path' is required")?;
213
214 let path = self.normalize_path(path)?;
215
216 let language = args.get("language").and_then(|v| v.as_str());
217 let severity_filter = args.get("severity_filter").and_then(|v| v.as_str());
218
219 self.engine
220 .lint(&path, language, severity_filter, None)
221 .await
222 }
223
224 async fn refactor(&self, args: Value) -> Result<Value> {
226 let path = args
227 .get("path")
228 .and_then(|v| v.as_str())
229 .context("'path' is required")?;
230
231 let path = self.normalize_path(path)?;
232
233 let language = args.get("language").and_then(|v| v.as_str());
234 let refactor_type = args
235 .get("refactor_type")
236 .and_then(|v| v.as_str())
237 .context("'refactor_type' is required")?;
238
239 self.engine.refactor(&path, language, refactor_type).await
240 }
241
242 async fn custom(&self, args: Value) -> Result<Value> {
244 let pattern = args
245 .get("pattern")
246 .and_then(|v| v.as_str())
247 .context("'pattern' is required")?;
248
249 let path = args
250 .get("path")
251 .and_then(|v| v.as_str())
252 .context("'path' is required")?;
253
254 let path = self.normalize_path(path)?;
255
256 let language = args.get("language").and_then(|v| v.as_str());
257 let rewrite = args.get("rewrite").and_then(|v| v.as_str());
258 let context_lines = args
259 .get("context_lines")
260 .and_then(|v| v.as_u64())
261 .map(|v| v as usize);
262 let max_results = args
263 .get("max_results")
264 .and_then(|v| v.as_u64())
265 .map(|v| v as usize);
266 let interactive = args
267 .get("interactive")
268 .and_then(|v| v.as_bool())
269 .unwrap_or(false);
270 let update_all = args
271 .get("update_all")
272 .and_then(|v| v.as_bool())
273 .unwrap_or(false);
274
275 self.engine
276 .run_custom(
277 pattern,
278 &path,
279 language,
280 rewrite,
281 context_lines,
282 max_results,
283 interactive,
284 update_all,
285 )
286 .await
287 }
288}