1use std::path::{Path, PathBuf};
5
6use schemars::JsonSchema;
7use serde::Deserialize;
8
9use zeph_common::ToolName;
10
11use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
12use crate::registry::{InvocationHint, ToolDef};
13
/// Which cargo diagnostics pass to run.
#[derive(Debug, Default, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DiagnosticsLevel {
    /// `cargo check` — the default when the caller omits `level`.
    #[default]
    Check,
    /// `cargo clippy` — lint pass on top of checking.
    Clippy,
}
24
/// Parameters accepted by the `diagnostics` tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct DiagnosticsParams {
    /// Workspace directory to run cargo in; defaults to the current
    /// working directory when omitted.
    path: Option<String>,
    /// Diagnostics pass to run; defaults to [`DiagnosticsLevel::Check`].
    #[serde(default)]
    level: DiagnosticsLevel,
}
33
/// Tool executor that runs `cargo check`/`cargo clippy` in a sandboxed
/// workspace directory and reports compiler diagnostics.
#[derive(Debug)]
pub struct DiagnosticsExecutor {
    /// Canonicalized roots; a requested workspace path must resolve under
    /// one of these (see `validate_path`).
    allowed_paths: Vec<PathBuf>,
    /// Upper bound on diagnostics included in the summary (50 by default).
    max_diagnostics: usize,
}
41
42impl DiagnosticsExecutor {
43 #[must_use]
44 pub fn new(allowed_paths: Vec<PathBuf>) -> Self {
45 let paths = if allowed_paths.is_empty() {
46 vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
47 } else {
48 allowed_paths
49 };
50 Self {
51 allowed_paths: paths
52 .into_iter()
53 .map(|p| p.canonicalize().unwrap_or(p))
54 .collect(),
55 max_diagnostics: 50,
56 }
57 }
58
59 #[must_use]
60 pub fn with_max_diagnostics(mut self, max: usize) -> Self {
61 self.max_diagnostics = max;
62 self
63 }
64
65 fn validate_path(&self, path: &Path) -> Result<PathBuf, ToolError> {
66 let resolved = if path.is_absolute() {
67 path.to_path_buf()
68 } else {
69 std::env::current_dir()
70 .unwrap_or_else(|_| PathBuf::from("."))
71 .join(path)
72 };
73 let canonical = resolved.canonicalize().map_err(|e| {
74 ToolError::Execution(std::io::Error::new(
75 std::io::ErrorKind::NotFound,
76 format!("path not found: {}: {e}", resolved.display()),
77 ))
78 })?;
79 if !self.allowed_paths.iter().any(|a| canonical.starts_with(a)) {
80 return Err(ToolError::SandboxViolation {
81 path: canonical.display().to_string(),
82 });
83 }
84 Ok(canonical)
85 }
86}
87
88impl ToolExecutor for DiagnosticsExecutor {
89 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
90 Ok(None)
91 }
92
93 #[cfg_attr(
94 feature = "profiling",
95 tracing::instrument(name = "tool.diagnostics", skip_all)
96 )]
97 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
98 if call.tool_id != "diagnostics" {
99 return Ok(None);
100 }
101 let p: DiagnosticsParams = deserialize_params(&call.params)?;
102 let work_dir = if let Some(path) = &p.path {
103 self.validate_path(Path::new(path))?
104 } else {
105 let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
106 self.validate_path(&cwd)?
107 };
108
109 let subcmd = match p.level {
110 DiagnosticsLevel::Check => "check",
111 DiagnosticsLevel::Clippy => "clippy",
112 };
113
114 let cargo = which_cargo()?;
115
116 let output = tokio::process::Command::new(&cargo)
117 .arg(subcmd)
118 .arg("--message-format=json")
119 .current_dir(&work_dir)
120 .output()
121 .await
122 .map_err(|e| {
123 ToolError::Execution(std::io::Error::new(
124 std::io::ErrorKind::NotFound,
125 format!("failed to run cargo: {e}"),
126 ))
127 })?;
128
129 let stdout = String::from_utf8_lossy(&output.stdout);
130 let diagnostics = parse_cargo_json(&stdout, self.max_diagnostics);
131
132 let summary = if diagnostics.is_empty() {
133 "No diagnostics".to_owned()
134 } else {
135 diagnostics.join("\n")
136 };
137
138 Ok(Some(ToolOutput {
139 tool_name: ToolName::new("diagnostics"),
140 summary,
141 blocks_executed: 1,
142 filter_stats: None,
143 diff: None,
144 streamed: false,
145 terminal_id: None,
146 locations: None,
147 raw_response: None,
148 claim_source: Some(crate::executor::ClaimSource::Diagnostics),
149 }))
150 }
151
152 fn tool_definitions(&self) -> Vec<ToolDef> {
153 vec![ToolDef {
154 id: "diagnostics".into(),
155 description: "Run cargo check or cargo clippy on a Rust workspace and return compiler diagnostics.\n\nParameters: path (string, optional) - workspace directory (default: cwd); level (string, optional) - \"check\" or \"clippy\" (default: \"check\")\nReturns: structured diagnostics with file paths, line numbers, severity, and messages; capped at 50 results\nErrors: SandboxViolation if path outside allowed dirs; Execution if cargo is not found\nExample: {\"path\": \".\", \"level\": \"clippy\"}".into(),
156 schema: schemars::schema_for!(DiagnosticsParams),
157 invocation: InvocationHint::ToolCall,
158 }]
159 }
160}
161
162fn which_cargo() -> Result<PathBuf, ToolError> {
169 if let Ok(cargo) = std::env::var("CARGO") {
171 let p = PathBuf::from(&cargo);
172 if p.is_file() {
173 return Ok(p.canonicalize().unwrap_or(p));
174 }
175 }
176 for dir in std::env::var("PATH").unwrap_or_default().split(':') {
178 let candidate = PathBuf::from(dir).join("cargo");
179 if candidate.is_file() {
180 return Ok(candidate.canonicalize().unwrap_or(candidate));
181 }
182 }
183 Err(ToolError::Execution(std::io::Error::new(
184 std::io::ErrorKind::NotFound,
185 "cargo not found in PATH",
186 )))
187}
188
189pub(crate) fn parse_cargo_json(output: &str, max: usize) -> Vec<String> {
194 let mut results = Vec::new();
195 for line in output.lines() {
196 if results.len() >= max {
197 break;
198 }
199 let Ok(val) = serde_json::from_str::<serde_json::Value>(line) else {
200 continue;
201 };
202 if val.get("reason").and_then(|r| r.as_str()) != Some("compiler-message") {
203 continue;
204 }
205 let Some(msg) = val.get("message") else {
206 continue;
207 };
208 let level = msg
209 .get("level")
210 .and_then(|l| l.as_str())
211 .unwrap_or("unknown");
212 let text = msg
213 .get("message")
214 .and_then(|m| m.as_str())
215 .unwrap_or("")
216 .trim();
217 if text.is_empty() {
218 continue;
219 }
220
221 let spans = msg
223 .get("spans")
224 .and_then(serde_json::Value::as_array)
225 .map_or(&[] as &[_], Vec::as_slice);
226
227 let primary = spans.iter().find(|s| {
228 s.get("is_primary")
229 .and_then(serde_json::Value::as_bool)
230 .unwrap_or(false)
231 });
232
233 if let Some(span) = primary {
234 let file = span
235 .get("file_name")
236 .and_then(|f| f.as_str())
237 .unwrap_or("?");
238 let line = span
239 .get("line_start")
240 .and_then(serde_json::Value::as_u64)
241 .unwrap_or(0);
242 let col = span
243 .get("column_start")
244 .and_then(serde_json::Value::as_u64)
245 .unwrap_or(0);
246 results.push(format!("{file}:{line}:{col}: {level}: {text}"));
247 } else {
248 results.push(format!("{level}: {text}"));
249 }
250 }
251 results
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a JSON params map from `(key, value)` pairs.
    fn make_params(
        pairs: &[(&str, serde_json::Value)],
    ) -> serde_json::Map<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_owned(), v.clone()))
            .collect()
    }

    #[test]
    fn parse_cargo_json_empty_input() {
        let result = parse_cargo_json("", 50);
        assert!(result.is_empty());
    }

    // Non-"compiler-message" reasons (build scripts, artifacts) are dropped.
    #[test]
    fn parse_cargo_json_non_compiler_message_ignored() {
        let line = r#"{"reason":"build-script-executed","package_id":"foo"}"#;
        let result = parse_cargo_json(line, 50);
        assert!(result.is_empty());
    }

    #[test]
    fn parse_cargo_json_compiler_message_with_span() {
        let line = r#"{"reason":"compiler-message","message":{"level":"error","message":"cannot find value `foo` in this scope","spans":[{"file_name":"src/main.rs","line_start":10,"column_start":5,"is_primary":true}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("src/main.rs"));
        assert!(result[0].contains("10"));
        assert!(result[0].contains("error"));
        assert!(result[0].contains("cannot find value"));
    }

    // Exact output format: "file:line:col: level: text".
    #[test]
    fn parse_cargo_json_warning_with_span() {
        let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused variable: `x`","spans":[{"file_name":"src/lib.rs","line_start":3,"column_start":9,"is_primary":true}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert!(result[0].starts_with("src/lib.rs:3:9: warning:"));
    }

    // Without a primary span, only "level: text" is emitted.
    #[test]
    fn parse_cargo_json_no_primary_span_uses_message_only() {
        let line = r#"{"reason":"compiler-message","message":{"level":"error","message":"aborting due to previous error","spans":[]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "error: aborting due to previous error");
    }

    #[test]
    fn parse_cargo_json_max_cap_respected() {
        let single = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused","spans":[]}}"#;
        let input: String = (0..20).map(|_| single).collect::<Vec<_>>().join("\n");
        let result = parse_cargo_json(&input, 5);
        assert_eq!(result.len(), 5);
    }

    // Whitespace-only message text is treated as empty and skipped.
    #[test]
    fn parse_cargo_json_empty_message_skipped() {
        let line = r#"{"reason":"compiler-message","message":{"level":"note","message":" ","spans":[]}}"#;
        let result = parse_cargo_json(line, 50);
        assert!(result.is_empty());
    }

    // Spans with is_primary=false do not contribute a file location.
    #[test]
    fn parse_cargo_json_non_primary_span_skipped_for_location() {
        let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"some warning","spans":[{"file_name":"src/foo.rs","line_start":1,"column_start":1,"is_primary":false}]}}"#;
        let result = parse_cargo_json(line, 50);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "warning: some warning");
    }

    #[test]
    fn parse_cargo_json_invalid_json_line_skipped() {
        let input = "not json\n{\"reason\":\"build-script-executed\"}";
        let result = parse_cargo_json(input, 50);
        assert!(result.is_empty());
    }

    // A path outside the allowed roots must be rejected.
    #[tokio::test]
    async fn diagnostics_sandbox_violation() {
        let dir = tempfile::tempdir().unwrap();
        let exec = DiagnosticsExecutor::new(vec![dir.path().to_path_buf()]);

        let call = ToolCall {
            tool_id: ToolName::new("diagnostics"),
            params: make_params(&[("path", serde_json::json!("/etc"))]),
            caller_id: None,
        };
        let result = exec.execute_tool_call(&call).await;
        assert!(result.is_err());
    }

    // Calls addressed to other tools yield Ok(None), not an error.
    #[tokio::test]
    async fn diagnostics_unknown_tool_returns_none() {
        let exec = DiagnosticsExecutor::new(vec![]);
        let call = ToolCall {
            tool_id: ToolName::new("other"),
            params: serde_json::Map::new(),
            caller_id: None,
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn diagnostics_tool_definition() {
        let exec = DiagnosticsExecutor::new(vec![]);
        let defs = exec.tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id, "diagnostics");
        assert_eq!(defs[0].invocation, InvocationHint::ToolCall);
    }

    #[test]
    fn diagnostics_level_default_is_check() {
        assert_eq!(DiagnosticsLevel::default(), DiagnosticsLevel::Check);
    }

    #[test]
    fn diagnostics_level_deserialize_check() {
        let p: DiagnosticsParams = serde_json::from_str(r#"{"level":"check"}"#).unwrap();
        assert_eq!(p.level, DiagnosticsLevel::Check);
    }

    #[test]
    fn diagnostics_level_deserialize_clippy() {
        let p: DiagnosticsParams = serde_json::from_str(r#"{"level":"clippy"}"#).unwrap();
        assert_eq!(p.level, DiagnosticsLevel::Clippy);
    }

    // Both fields are optional; defaults are None / Check.
    #[test]
    fn diagnostics_params_path_optional() {
        let p: DiagnosticsParams = serde_json::from_str(r"{}").unwrap();
        assert!(p.path.is_none());
        assert_eq!(p.level, DiagnosticsLevel::Check);
    }

    // Mirrors the level -> subcommand mapping in execute_tool_call.
    #[test]
    fn diagnostics_clippy_subcmd_string() {
        let subcmd = match DiagnosticsLevel::Clippy {
            DiagnosticsLevel::Check => "check",
            DiagnosticsLevel::Clippy => "clippy",
        };
        assert_eq!(subcmd, "clippy");
    }

    #[test]
    fn diagnostics_check_subcmd_string() {
        let subcmd = match DiagnosticsLevel::Check {
            DiagnosticsLevel::Check => "check",
            DiagnosticsLevel::Clippy => "clippy",
        };
        assert_eq!(subcmd, "check");
    }
}