1use std::path::{Path, PathBuf};
5
6use schemars::JsonSchema;
7use serde::Deserialize;
8
9use zeph_common::ToolName;
10
11use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
12use crate::registry::{InvocationHint, ToolDef};
13
/// Which cargo subcommand the diagnostics run uses: plain `cargo check`
/// or the stricter `cargo clippy` lints.
#[derive(Debug, Default, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DiagnosticsLevel {
    /// Run `cargo check` (the default when `level` is omitted).
    #[default]
    Check,
    /// Run `cargo clippy`.
    Clippy,
}
24
/// Parameters accepted by the `diagnostics` tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct DiagnosticsParams {
    /// Workspace directory to run in; defaults to the current working directory.
    path: Option<String>,
    /// Cargo subcommand to run; defaults to [`DiagnosticsLevel::Check`].
    #[serde(default)]
    level: DiagnosticsLevel,
}
33
/// Tool executor that runs `cargo check`/`cargo clippy` inside a set of
/// allowed directories and returns parsed compiler diagnostics.
#[derive(Debug)]
pub struct DiagnosticsExecutor {
    // Canonicalized sandbox roots; every requested path must live under one.
    allowed_paths: Vec<PathBuf>,
    // Upper bound on diagnostics included in a summary (50 by default).
    max_diagnostics: usize,
}
41
42impl DiagnosticsExecutor {
43 #[must_use]
44 pub fn new(allowed_paths: Vec<PathBuf>) -> Self {
45 let paths = if allowed_paths.is_empty() {
46 vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
47 } else {
48 allowed_paths
49 };
50 Self {
51 allowed_paths: paths
52 .into_iter()
53 .map(|p| p.canonicalize().unwrap_or(p))
54 .collect(),
55 max_diagnostics: 50,
56 }
57 }
58
59 #[must_use]
60 pub fn with_max_diagnostics(mut self, max: usize) -> Self {
61 self.max_diagnostics = max;
62 self
63 }
64
65 fn validate_path(&self, path: &Path) -> Result<PathBuf, ToolError> {
66 let resolved = if path.is_absolute() {
67 path.to_path_buf()
68 } else {
69 std::env::current_dir()
70 .unwrap_or_else(|_| PathBuf::from("."))
71 .join(path)
72 };
73 let canonical = resolved.canonicalize().map_err(|e| {
74 ToolError::Execution(std::io::Error::new(
75 std::io::ErrorKind::NotFound,
76 format!("path not found: {}: {e}", resolved.display()),
77 ))
78 })?;
79 if !self.allowed_paths.iter().any(|a| canonical.starts_with(a)) {
80 return Err(ToolError::SandboxViolation {
81 path: canonical.display().to_string(),
82 });
83 }
84 Ok(canonical)
85 }
86}
87
88impl ToolExecutor for DiagnosticsExecutor {
89 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
90 Ok(None)
91 }
92
93 #[cfg_attr(
94 feature = "profiling",
95 tracing::instrument(name = "tool.diagnostics", skip_all)
96 )]
97 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
98 if call.tool_id != "diagnostics" {
99 return Ok(None);
100 }
101 let p: DiagnosticsParams = deserialize_params(&call.params)?;
102 let work_dir = if let Some(path) = &p.path {
103 self.validate_path(Path::new(path))?
104 } else {
105 let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
106 self.validate_path(&cwd)?
107 };
108
109 let subcmd = match p.level {
110 DiagnosticsLevel::Check => "check",
111 DiagnosticsLevel::Clippy => "clippy",
112 };
113
114 let cargo = which_cargo()?;
115
116 let output = tokio::process::Command::new(&cargo)
117 .arg(subcmd)
118 .arg("--message-format=json")
119 .current_dir(&work_dir)
120 .output()
121 .await
122 .map_err(|e| {
123 ToolError::Execution(std::io::Error::new(
124 std::io::ErrorKind::NotFound,
125 format!("failed to run cargo: {e}"),
126 ))
127 })?;
128
129 let stdout = String::from_utf8_lossy(&output.stdout);
130 let diagnostics = parse_cargo_json(&stdout, self.max_diagnostics);
131
132 let summary = if diagnostics.is_empty() {
133 "No diagnostics".to_owned()
134 } else {
135 diagnostics.join("\n")
136 };
137
138 Ok(Some(ToolOutput {
139 tool_name: ToolName::new("diagnostics"),
140 summary,
141 blocks_executed: 1,
142 filter_stats: None,
143 diff: None,
144 streamed: false,
145 terminal_id: None,
146 locations: None,
147 raw_response: None,
148 claim_source: Some(crate::executor::ClaimSource::Diagnostics),
149 }))
150 }
151
152 fn tool_definitions(&self) -> Vec<ToolDef> {
153 vec![ToolDef {
154 id: "diagnostics".into(),
155 description: "Run cargo check or cargo clippy on a Rust workspace and return compiler diagnostics.\n\nParameters: path (string, optional) - workspace directory (default: cwd); level (string, optional) - \"check\" or \"clippy\" (default: \"check\")\nReturns: structured diagnostics with file paths, line numbers, severity, and messages; capped at 50 results\nErrors: SandboxViolation if path outside allowed dirs; Execution if cargo is not found\nExample: {\"path\": \".\", \"level\": \"clippy\"}".into(),
156 schema: schemars::schema_for!(DiagnosticsParams),
157 invocation: InvocationHint::ToolCall,
158 output_schema: None,
159 }]
160 }
161}
162
163fn which_cargo() -> Result<PathBuf, ToolError> {
170 if let Ok(cargo) = std::env::var("CARGO") {
172 let p = PathBuf::from(&cargo);
173 if p.is_file() {
174 return Ok(p.canonicalize().unwrap_or(p));
175 }
176 }
177 for dir in std::env::var("PATH").unwrap_or_default().split(':') {
179 let candidate = PathBuf::from(dir).join("cargo");
180 if candidate.is_file() {
181 return Ok(candidate.canonicalize().unwrap_or(candidate));
182 }
183 }
184 Err(ToolError::Execution(std::io::Error::new(
185 std::io::ErrorKind::NotFound,
186 "cargo not found in PATH",
187 )))
188}
189
190pub(crate) fn parse_cargo_json(output: &str, max: usize) -> Vec<String> {
195 let mut results = Vec::new();
196 for line in output.lines() {
197 if results.len() >= max {
198 break;
199 }
200 let Ok(val) = serde_json::from_str::<serde_json::Value>(line) else {
201 continue;
202 };
203 if val.get("reason").and_then(|r| r.as_str()) != Some("compiler-message") {
204 continue;
205 }
206 let Some(msg) = val.get("message") else {
207 continue;
208 };
209 let level = msg
210 .get("level")
211 .and_then(|l| l.as_str())
212 .unwrap_or("unknown");
213 let text = msg
214 .get("message")
215 .and_then(|m| m.as_str())
216 .unwrap_or("")
217 .trim();
218 if text.is_empty() {
219 continue;
220 }
221
222 let spans = msg
224 .get("spans")
225 .and_then(serde_json::Value::as_array)
226 .map_or(&[] as &[_], Vec::as_slice);
227
228 let primary = spans.iter().find(|s| {
229 s.get("is_primary")
230 .and_then(serde_json::Value::as_bool)
231 .unwrap_or(false)
232 });
233
234 if let Some(span) = primary {
235 let file = span
236 .get("file_name")
237 .and_then(|f| f.as_str())
238 .unwrap_or("?");
239 let line = span
240 .get("line_start")
241 .and_then(serde_json::Value::as_u64)
242 .unwrap_or(0);
243 let col = span
244 .get("column_start")
245 .and_then(serde_json::Value::as_u64)
246 .unwrap_or(0);
247 results.push(format!("{file}:{line}:{col}: {level}: {text}"));
248 } else {
249 results.push(format!("{level}: {text}"));
250 }
251 }
252 results
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool-call params map from literal key/value pairs.
    fn make_params(
        pairs: &[(&str, serde_json::Value)],
    ) -> serde_json::Map<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_owned(), v.clone()))
            .collect()
    }

    #[test]
    fn parse_cargo_json_empty_input() {
        let result = parse_cargo_json("", 50);
        assert!(result.is_empty());
    }

    #[test]
    fn parse_cargo_json_non_compiler_message_ignored() {
        // Records with a different "reason" (build scripts, artifacts, ...) are dropped.
        let line = r#"{"reason":"build-script-executed","package_id":"foo"}"#;
        let result = parse_cargo_json(line, 50);
        assert!(result.is_empty());
    }

    #[test]
    fn parse_cargo_json_compiler_message_with_span() {
        // A primary span yields the full `file:line:col: level: text` form.
        let line = r#"{"reason":"compiler-message","message":{"level":"error","message":"cannot find value `foo` in this scope","spans":[{"file_name":"src/main.rs","line_start":10,"column_start":5,"is_primary":true}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("src/main.rs"));
        assert!(result[0].contains("10"));
        assert!(result[0].contains("error"));
        assert!(result[0].contains("cannot find value"));
    }

    #[test]
    fn parse_cargo_json_warning_with_span() {
        let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused variable: `x`","spans":[{"file_name":"src/lib.rs","line_start":3,"column_start":9,"is_primary":true}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert!(result[0].starts_with("src/lib.rs:3:9: warning:"));
    }

    #[test]
    fn parse_cargo_json_no_primary_span_uses_message_only() {
        // Without a primary span the location prefix is omitted entirely.
        let line = r#"{"reason":"compiler-message","message":{"level":"error","message":"aborting due to previous error","spans":[]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "error: aborting due to previous error");
    }

    #[test]
    fn parse_cargo_json_max_cap_respected() {
        let single = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused","spans":[]}}"#;
        let input: String = (0..20).map(|_| single).collect::<Vec<_>>().join("\n");
        let result = parse_cargo_json(&input, 5);
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn parse_cargo_json_empty_message_skipped() {
        // Whitespace-only message text trims to empty and is dropped.
        let line = r#"{"reason":"compiler-message","message":{"level":"note","message":" ","spans":[]}}"#;
        let result = parse_cargo_json(line, 50);
        assert!(result.is_empty());
    }

    #[test]
    fn parse_cargo_json_non_primary_span_skipped_for_location() {
        // A span with is_primary=false must not contribute a file:line:col prefix.
        let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"some warning","spans":[{"file_name":"src/foo.rs","line_start":1,"column_start":1,"is_primary":false}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "warning: some warning");
    }

    #[test]
    fn parse_cargo_json_invalid_json_line_skipped() {
        let input = "not json\n{\"reason\":\"build-script-executed\"}";
        let result = parse_cargo_json(input, 50);
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn diagnostics_sandbox_violation() {
        // A path outside the allowed root must be rejected before cargo runs.
        let dir = tempfile::tempdir().unwrap();
        let exec = DiagnosticsExecutor::new(vec![dir.path().to_path_buf()]);

        let call = ToolCall {
            tool_id: ToolName::new("diagnostics"),
            params: make_params(&[("path", serde_json::json!("/etc"))]),
            caller_id: None,
        };
        let result = exec.execute_tool_call(&call).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn diagnostics_unknown_tool_returns_none() {
        // Calls addressed to other tools are ignored, not errors.
        let exec = DiagnosticsExecutor::new(vec![]);
        let call = ToolCall {
            tool_id: ToolName::new("other"),
            params: serde_json::Map::new(),
            caller_id: None,
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn diagnostics_tool_definition() {
        let exec = DiagnosticsExecutor::new(vec![]);
        let defs = exec.tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id, "diagnostics");
        assert_eq!(defs[0].invocation, InvocationHint::ToolCall);
    }

    #[test]
    fn diagnostics_level_default_is_check() {
        assert_eq!(DiagnosticsLevel::default(), DiagnosticsLevel::Check);
    }

    #[test]
    fn diagnostics_level_deserialize_check() {
        let p: DiagnosticsParams = serde_json::from_str(r#"{"level":"check"}"#).unwrap();
        assert_eq!(p.level, DiagnosticsLevel::Check);
    }

    #[test]
    fn diagnostics_level_deserialize_clippy() {
        let p: DiagnosticsParams = serde_json::from_str(r#"{"level":"clippy"}"#).unwrap();
        assert_eq!(p.level, DiagnosticsLevel::Clippy);
    }

    #[test]
    fn diagnostics_params_path_optional() {
        // Both fields are optional: `path` defaults to None, `level` to Check.
        let p: DiagnosticsParams = serde_json::from_str(r"{}").unwrap();
        assert!(p.path.is_none());
        assert_eq!(p.level, DiagnosticsLevel::Check);
    }

    #[test]
    fn diagnostics_clippy_subcmd_string() {
        let subcmd = match DiagnosticsLevel::Clippy {
            DiagnosticsLevel::Check => "check",
            DiagnosticsLevel::Clippy => "clippy",
        };
        assert_eq!(subcmd, "clippy");
    }

    #[test]
    fn diagnostics_check_subcmd_string() {
        let subcmd = match DiagnosticsLevel::Check {
            DiagnosticsLevel::Check => "check",
            DiagnosticsLevel::Clippy => "clippy",
        };
        assert_eq!(subcmd, "check");
    }
}