use crate::agent::ide::{Diagnostic, DiagnosticSeverity, DiagnosticsResponse, IdeClient};
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::Deserialize;
use serde_json::json;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use tokio::process::Command;
use tokio::sync::Mutex;
32
33#[derive(Debug, Deserialize)]
34pub struct DiagnosticsArgs {
35 pub path: Option<String>,
37 pub include_warnings: Option<bool>,
39 pub limit: Option<usize>,
41}
42
43#[derive(Debug, thiserror::Error)]
44#[error("Diagnostics error: {0}")]
45pub struct DiagnosticsError(String);
46
47#[derive(Debug, Clone)]
48pub struct DiagnosticsTool {
49 project_path: PathBuf,
50 ide_client: Option<Arc<Mutex<IdeClient>>>,
52}
53
54impl DiagnosticsTool {
55 pub fn new(project_path: PathBuf) -> Self {
56 Self {
57 project_path,
58 ide_client: None,
59 }
60 }
61
62 pub fn with_ide_client(mut self, ide_client: Arc<Mutex<IdeClient>>) -> Self {
64 self.ide_client = Some(ide_client);
65 self
66 }
67
68 fn detect_project_type(&self) -> ProjectType {
70 let cargo_toml = self.project_path.join("Cargo.toml");
71 let package_json = self.project_path.join("package.json");
72 let go_mod = self.project_path.join("go.mod");
73 let pyproject_toml = self.project_path.join("pyproject.toml");
74 let requirements_txt = self.project_path.join("requirements.txt");
75
76 if cargo_toml.exists() {
77 ProjectType::Rust
78 } else if package_json.exists() {
79 ProjectType::JavaScript
80 } else if go_mod.exists() {
81 ProjectType::Go
82 } else if pyproject_toml.exists() || requirements_txt.exists() {
83 ProjectType::Python
84 } else {
85 ProjectType::Unknown
86 }
87 }
88
89 async fn get_ide_diagnostics(
91 &self,
92 file_path: Option<&str>,
93 ) -> Option<DiagnosticsResponse> {
94 let client = self.ide_client.as_ref()?;
95 let guard = client.lock().await;
96
97 if !guard.is_connected() {
98 return None;
99 }
100
101 guard.get_diagnostics(file_path).await.ok()
102 }
103
104 async fn get_command_diagnostics(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
106 let project_type = self.detect_project_type();
107
108 match project_type {
109 ProjectType::Rust => self.run_cargo_check().await,
110 ProjectType::JavaScript => self.run_npm_lint().await,
111 ProjectType::Go => self.run_go_build().await,
112 ProjectType::Python => self.run_python_check().await,
113 ProjectType::Unknown => Ok(DiagnosticsResponse {
114 diagnostics: Vec::new(),
115 total_errors: 0,
116 total_warnings: 0,
117 }),
118 }
119 }
120
121 async fn run_cargo_check(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
123 let output = Command::new("cargo")
124 .args(["check", "--message-format=json"])
125 .current_dir(&self.project_path)
126 .output()
127 .await
128 .map_err(|e| DiagnosticsError(format!("Failed to run cargo check: {}", e)))?;
129
130 let stdout = String::from_utf8_lossy(&output.stdout);
131 let mut diagnostics = Vec::new();
132
133 for line in stdout.lines() {
134 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(line) {
135 if msg.get("reason").and_then(|r| r.as_str()) == Some("compiler-message") {
136 if let Some(message) = msg.get("message") {
137 if let Some(diag) = self.parse_cargo_message(message) {
138 diagnostics.push(diag);
139 }
140 }
141 }
142 }
143 }
144
145 let total_errors = diagnostics
146 .iter()
147 .filter(|d| d.severity == DiagnosticSeverity::Error)
148 .count() as u32;
149 let total_warnings = diagnostics
150 .iter()
151 .filter(|d| d.severity == DiagnosticSeverity::Warning)
152 .count() as u32;
153
154 Ok(DiagnosticsResponse {
155 diagnostics,
156 total_errors,
157 total_warnings,
158 })
159 }
160
161 fn parse_cargo_message(&self, message: &serde_json::Value) -> Option<Diagnostic> {
163 let level = message.get("level")?.as_str()?;
164 let msg = message.get("message")?.as_str()?;
165
166 let severity = match level {
167 "error" => DiagnosticSeverity::Error,
168 "warning" => DiagnosticSeverity::Warning,
169 "note" | "help" => DiagnosticSeverity::Hint,
170 _ => DiagnosticSeverity::Information,
171 };
172
173 let spans = message.get("spans")?.as_array()?;
175 let span = spans.iter().find(|s| {
176 s.get("is_primary").and_then(|v| v.as_bool()).unwrap_or(false)
177 }).or_else(|| spans.first())?;
178
179 let file = span.get("file_name")?.as_str()?;
180 let line = span.get("line_start")?.as_u64()? as u32;
181 let column = span.get("column_start")?.as_u64()? as u32;
182 let end_line = span.get("line_end").and_then(|v| v.as_u64()).map(|v| v as u32);
183 let end_column = span.get("column_end").and_then(|v| v.as_u64()).map(|v| v as u32);
184
185 let code = message
186 .get("code")
187 .and_then(|c| c.get("code"))
188 .and_then(|c| c.as_str())
189 .map(|s| s.to_string());
190
191 Some(Diagnostic {
192 file: file.to_string(),
193 line,
194 column,
195 end_line,
196 end_column,
197 severity,
198 message: msg.to_string(),
199 source: Some("rustc".to_string()),
200 code,
201 })
202 }
203
204 async fn run_npm_lint(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
206 let output = Command::new("npm")
208 .args(["run", "lint", "--", "--format=json"])
209 .current_dir(&self.project_path)
210 .output()
211 .await;
212
213 if let Ok(output) = output {
214 if output.status.success() || !output.stdout.is_empty() {
215 let stdout = String::from_utf8_lossy(&output.stdout);
216 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
217 return Ok(self.parse_eslint_output(&results));
218 }
219 }
220 }
221
222 let output = Command::new("npx")
224 .args(["eslint", ".", "--format=json"])
225 .current_dir(&self.project_path)
226 .output()
227 .await
228 .map_err(|e| DiagnosticsError(format!("Failed to run eslint: {}", e)))?;
229
230 let stdout = String::from_utf8_lossy(&output.stdout);
231 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
232 return Ok(self.parse_eslint_output(&results));
233 }
234
235 Ok(DiagnosticsResponse {
237 diagnostics: Vec::new(),
238 total_errors: 0,
239 total_warnings: 0,
240 })
241 }
242
243 fn parse_eslint_output(&self, results: &[serde_json::Value]) -> DiagnosticsResponse {
245 let mut diagnostics = Vec::new();
246
247 for file_result in results {
248 let file = file_result
249 .get("filePath")
250 .and_then(|f| f.as_str())
251 .unwrap_or("");
252
253 if let Some(messages) = file_result.get("messages").and_then(|m| m.as_array()) {
254 for msg in messages {
255 let severity = match msg.get("severity").and_then(|s| s.as_u64()) {
256 Some(2) => DiagnosticSeverity::Error,
257 Some(1) => DiagnosticSeverity::Warning,
258 _ => DiagnosticSeverity::Information,
259 };
260
261 let message = msg
262 .get("message")
263 .and_then(|m| m.as_str())
264 .unwrap_or("")
265 .to_string();
266 let line = msg.get("line").and_then(|l| l.as_u64()).unwrap_or(1) as u32;
267 let column = msg.get("column").and_then(|c| c.as_u64()).unwrap_or(1) as u32;
268 let end_line = msg.get("endLine").and_then(|l| l.as_u64()).map(|v| v as u32);
269 let end_column = msg.get("endColumn").and_then(|c| c.as_u64()).map(|v| v as u32);
270 let code = msg.get("ruleId").and_then(|r| r.as_str()).map(|s| s.to_string());
271
272 diagnostics.push(Diagnostic {
273 file: file.to_string(),
274 line,
275 column,
276 end_line,
277 end_column,
278 severity,
279 message,
280 source: Some("eslint".to_string()),
281 code,
282 });
283 }
284 }
285 }
286
287 let total_errors = diagnostics
288 .iter()
289 .filter(|d| d.severity == DiagnosticSeverity::Error)
290 .count() as u32;
291 let total_warnings = diagnostics
292 .iter()
293 .filter(|d| d.severity == DiagnosticSeverity::Warning)
294 .count() as u32;
295
296 DiagnosticsResponse {
297 diagnostics,
298 total_errors,
299 total_warnings,
300 }
301 }
302
303 async fn run_go_build(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
305 let output = Command::new("go")
306 .args(["build", "-o", "/dev/null", "./..."])
307 .current_dir(&self.project_path)
308 .output()
309 .await
310 .map_err(|e| DiagnosticsError(format!("Failed to run go build: {}", e)))?;
311
312 let stderr = String::from_utf8_lossy(&output.stderr);
313 let mut diagnostics = Vec::new();
314
315 for line in stderr.lines() {
317 if let Some(diag) = self.parse_go_error(line) {
318 diagnostics.push(diag);
319 }
320 }
321
322 let total_errors = diagnostics
323 .iter()
324 .filter(|d| d.severity == DiagnosticSeverity::Error)
325 .count() as u32;
326 let total_warnings = diagnostics
327 .iter()
328 .filter(|d| d.severity == DiagnosticSeverity::Warning)
329 .count() as u32;
330
331 Ok(DiagnosticsResponse {
332 diagnostics,
333 total_errors,
334 total_warnings,
335 })
336 }
337
338 fn parse_go_error(&self, line: &str) -> Option<Diagnostic> {
340 let parts: Vec<&str> = line.splitn(4, ':').collect();
342 if parts.len() < 4 {
343 return None;
344 }
345
346 let file = parts[0].to_string();
347 let line_num = parts[1].parse::<u32>().ok()?;
348 let column = parts[2].parse::<u32>().ok()?;
349 let message = parts[3].trim().to_string();
350
351 Some(Diagnostic {
352 file,
353 line: line_num,
354 column,
355 end_line: None,
356 end_column: None,
357 severity: DiagnosticSeverity::Error,
358 message,
359 source: Some("go".to_string()),
360 code: None,
361 })
362 }
363
364 async fn run_python_check(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
366 let output = Command::new("pylint")
368 .args(["--output-format=json", "."])
369 .current_dir(&self.project_path)
370 .output()
371 .await;
372
373 if let Ok(output) = output {
374 let stdout = String::from_utf8_lossy(&output.stdout);
375 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
376 return Ok(self.parse_pylint_output(&results));
377 }
378 }
379
380 Ok(DiagnosticsResponse {
382 diagnostics: Vec::new(),
383 total_errors: 0,
384 total_warnings: 0,
385 })
386 }
387
388 fn parse_pylint_output(&self, results: &[serde_json::Value]) -> DiagnosticsResponse {
390 let mut diagnostics = Vec::new();
391
392 for msg in results {
393 let msg_type = msg.get("type").and_then(|t| t.as_str()).unwrap_or("");
394 let severity = match msg_type {
395 "error" | "fatal" => DiagnosticSeverity::Error,
396 "warning" => DiagnosticSeverity::Warning,
397 "convention" | "refactor" => DiagnosticSeverity::Hint,
398 _ => DiagnosticSeverity::Information,
399 };
400
401 let file = msg
402 .get("path")
403 .and_then(|p| p.as_str())
404 .unwrap_or("")
405 .to_string();
406 let line = msg.get("line").and_then(|l| l.as_u64()).unwrap_or(1) as u32;
407 let column = msg.get("column").and_then(|c| c.as_u64()).unwrap_or(1) as u32;
408 let message = msg
409 .get("message")
410 .and_then(|m| m.as_str())
411 .unwrap_or("")
412 .to_string();
413 let code = msg
414 .get("message-id")
415 .and_then(|m| m.as_str())
416 .map(|s| s.to_string());
417
418 diagnostics.push(Diagnostic {
419 file,
420 line,
421 column,
422 end_line: None,
423 end_column: None,
424 severity,
425 message,
426 source: Some("pylint".to_string()),
427 code,
428 });
429 }
430
431 let total_errors = diagnostics
432 .iter()
433 .filter(|d| d.severity == DiagnosticSeverity::Error)
434 .count() as u32;
435 let total_warnings = diagnostics
436 .iter()
437 .filter(|d| d.severity == DiagnosticSeverity::Warning)
438 .count() as u32;
439
440 DiagnosticsResponse {
441 diagnostics,
442 total_errors,
443 total_warnings,
444 }
445 }
446
447 fn filter_diagnostics(
449 &self,
450 mut response: DiagnosticsResponse,
451 include_warnings: bool,
452 limit: usize,
453 file_path: Option<&str>,
454 ) -> DiagnosticsResponse {
455 if let Some(path) = file_path {
457 response.diagnostics.retain(|d| d.file.contains(path));
458 }
459
460 if !include_warnings {
462 response.diagnostics.retain(|d| d.severity == DiagnosticSeverity::Error);
463 }
464
465 response.diagnostics.truncate(limit);
467
468 response.total_errors = response
470 .diagnostics
471 .iter()
472 .filter(|d| d.severity == DiagnosticSeverity::Error)
473 .count() as u32;
474 response.total_warnings = response
475 .diagnostics
476 .iter()
477 .filter(|d| d.severity == DiagnosticSeverity::Warning)
478 .count() as u32;
479
480 response
481 }
482}
483
/// Language family of a project, inferred from marker files in its root.
///
/// Selects which command-line checker is run when no IDE diagnostics are available.
#[derive(Clone, Copy, Debug)]
enum ProjectType {
    /// `Cargo.toml` present.
    Rust,
    /// `package.json` present.
    JavaScript,
    /// `go.mod` present.
    Go,
    /// `pyproject.toml` or `requirements.txt` present.
    Python,
    /// No recognised marker file.
    Unknown,
}
492
493impl Tool for DiagnosticsTool {
494 const NAME: &'static str = "diagnostics";
495
496 type Error = DiagnosticsError;
497 type Args = DiagnosticsArgs;
498 type Output = String;
499
500 async fn definition(&self, _prompt: String) -> ToolDefinition {
501 ToolDefinition {
502 name: Self::NAME.to_string(),
503 description: r#"Check for code errors, warnings, and linting issues.
504
505This tool queries language servers or runs language-specific commands to detect:
506- Compilation errors
507- Type errors
508- Syntax errors
509- Linting warnings
510- Best practice violations
511
512Use this tool after writing or modifying code to verify there are no errors.
513
514The tool automatically detects the project type and uses appropriate checking:
515- Rust: Uses rust-analyzer or cargo check
516- JavaScript/TypeScript: Uses ESLint or TypeScript compiler
517- Go: Uses gopls or go build
518- Python: Uses pylint or pyright
519
520Returns a list of diagnostics with file locations, severity, and messages."#
521 .to_string(),
522 parameters: json!({
523 "type": "object",
524 "properties": {
525 "path": {
526 "type": "string",
527 "description": "Optional file path to check. If not provided, checks all files in the project."
528 },
529 "include_warnings": {
530 "type": "boolean",
531 "description": "Whether to include warnings in addition to errors (default: true)"
532 },
533 "limit": {
534 "type": "integer",
535 "description": "Maximum number of diagnostics to return (default: 50)"
536 }
537 }
538 }),
539 }
540 }
541
542 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
543 let include_warnings = args.include_warnings.unwrap_or(true);
544 let limit = args.limit.unwrap_or(50);
545 let file_path = args.path.as_deref();
546
547 let response = if let Some(ide_response) = self.get_ide_diagnostics(file_path).await {
549 ide_response
550 } else {
551 self.get_command_diagnostics().await?
553 };
554
555 let filtered = self.filter_diagnostics(response, include_warnings, limit, file_path);
557
558 let result = if filtered.diagnostics.is_empty() {
560 json!({
561 "success": true,
562 "message": "No errors or warnings found",
563 "total_errors": 0,
564 "total_warnings": 0,
565 "diagnostics": []
566 })
567 } else {
568 let formatted_diagnostics: Vec<serde_json::Value> = filtered
569 .diagnostics
570 .iter()
571 .map(|d| {
572 json!({
573 "file": d.file,
574 "line": d.line,
575 "column": d.column,
576 "severity": d.severity.as_str(),
577 "message": d.message,
578 "source": d.source,
579 "code": d.code
580 })
581 })
582 .collect();
583
584 json!({
585 "success": filtered.total_errors == 0,
586 "total_errors": filtered.total_errors,
587 "total_warnings": filtered.total_warnings,
588 "diagnostics": formatted_diagnostics
589 })
590 };
591
592 serde_json::to_string_pretty(&result)
593 .map_err(|e| DiagnosticsError(format!("Failed to serialize: {}", e)))
594 }
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an empty scratch directory unique to this process and `name`,
    /// so tests running in parallel never share a directory.
    fn scratch_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "diagnostics_tool_test_{}_{}",
            std::process::id(),
            name
        ));
        // Leftovers from an earlier run with a recycled PID would skew detection.
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("failed to create scratch dir");
        dir
    }

    /// Runs project-type detection on a scratch directory containing `markers`.
    fn detect_with(name: &str, markers: &[&str]) -> ProjectType {
        let dir = scratch_dir(name);
        for marker in markers {
            std::fs::write(dir.join(marker), "").expect("failed to write marker file");
        }
        let detected = DiagnosticsTool::new(dir.clone()).detect_project_type();
        let _ = std::fs::remove_dir_all(&dir);
        detected
    }

    #[test]
    fn test_diagnostics_tool_creation() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        assert_eq!(tool.project_path, PathBuf::from("."));
        assert!(tool.ide_client.is_none());
    }

    #[test]
    fn test_project_type_detection() {
        assert!(matches!(detect_with("rust", &["Cargo.toml"]), ProjectType::Rust));
        assert!(matches!(detect_with("js", &["package.json"]), ProjectType::JavaScript));
        assert!(matches!(detect_with("go", &["go.mod"]), ProjectType::Go));
        assert!(matches!(detect_with("pyproject", &["pyproject.toml"]), ProjectType::Python));
        assert!(matches!(
            detect_with("requirements", &["requirements.txt"]),
            ProjectType::Python
        ));
        assert!(matches!(detect_with("empty", &[]), ProjectType::Unknown));
        // Cargo.toml has the highest priority when several markers coexist.
        assert!(matches!(
            detect_with("mixed", &["package.json", "Cargo.toml"]),
            ProjectType::Rust
        ));
    }

    #[test]
    fn test_parse_go_error() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let line = "main.go:10:5: undefined: foo";
        let diag = tool.parse_go_error(line);
        assert!(diag.is_some());
        let diag = diag.unwrap();
        assert_eq!(diag.file, "main.go");
        assert_eq!(diag.line, 10);
        assert_eq!(diag.column, 5);
        assert_eq!(diag.message, "undefined: foo");
    }

    #[test]
    fn test_parse_cargo_message_uses_primary_span() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let message = json!({
            "level": "error",
            "message": "mismatched types",
            "code": { "code": "E0308", "explanation": null },
            "spans": [
                {
                    "file_name": "src/other.rs",
                    "line_start": 1,
                    "column_start": 1,
                    "is_primary": false
                },
                {
                    "file_name": "src/main.rs",
                    "line_start": 4,
                    "column_start": 9,
                    "line_end": 4,
                    "column_end": 12,
                    "is_primary": true
                }
            ]
        });

        let diag = tool
            .parse_cargo_message(&message)
            .expect("message with spans should parse");
        assert_eq!(diag.file, "src/main.rs");
        assert_eq!(diag.line, 4);
        assert_eq!(diag.column, 9);
        assert_eq!(diag.end_line, Some(4));
        assert_eq!(diag.end_column, Some(12));
        assert!(diag.severity == DiagnosticSeverity::Error);
        assert_eq!(diag.code.as_deref(), Some("E0308"));
        assert_eq!(diag.source.as_deref(), Some("rustc"));
    }

    #[test]
    fn test_parse_cargo_message_without_spans_is_skipped() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let message = json!({
            "level": "error",
            "message": "aborting due to 1 previous error",
            "spans": []
        });
        assert!(tool.parse_cargo_message(&message).is_none());
    }

    #[test]
    fn test_parse_eslint_output() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let results = vec![json!({
            "filePath": "/project/src/app.js",
            "messages": [
                {
                    "ruleId": "no-unused-vars",
                    "severity": 2,
                    "message": "'x' is assigned a value but never used.",
                    "line": 3,
                    "column": 7
                },
                {
                    "ruleId": "semi",
                    "severity": 1,
                    "message": "Missing semicolon.",
                    "line": 5,
                    "column": 10,
                    "endLine": 5,
                    "endColumn": 11
                }
            ]
        })];

        let response = tool.parse_eslint_output(&results);
        assert_eq!(response.diagnostics.len(), 2);
        assert_eq!(response.total_errors, 1);
        assert_eq!(response.total_warnings, 1);
        assert_eq!(response.diagnostics[0].file, "/project/src/app.js");
        assert_eq!(response.diagnostics[0].code.as_deref(), Some("no-unused-vars"));
        assert_eq!(response.diagnostics[1].end_column, Some(11));
    }

    #[test]
    fn test_filter_diagnostics() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let response = DiagnosticsResponse {
            diagnostics: vec![
                Diagnostic {
                    file: "src/main.rs".to_string(),
                    line: 1,
                    column: 1,
                    end_line: None,
                    end_column: None,
                    severity: DiagnosticSeverity::Error,
                    message: "error".to_string(),
                    source: None,
                    code: None,
                },
                Diagnostic {
                    file: "src/lib.rs".to_string(),
                    line: 1,
                    column: 1,
                    end_line: None,
                    end_column: None,
                    severity: DiagnosticSeverity::Warning,
                    message: "warning".to_string(),
                    source: None,
                    code: None,
                },
            ],
            total_errors: 1,
            total_warnings: 1,
        };

        let filtered = tool.filter_diagnostics(response.clone(), false, 50, None);
        assert_eq!(filtered.diagnostics.len(), 1);
        assert_eq!(filtered.total_errors, 1);
        assert_eq!(filtered.total_warnings, 0);

        let filtered = tool.filter_diagnostics(response, true, 50, Some("main.rs"));
        assert_eq!(filtered.diagnostics.len(), 1);
        assert_eq!(filtered.diagnostics[0].file, "src/main.rs");
    }
}