1use crate::agent::ide::{Diagnostic, DiagnosticSeverity, DiagnosticsResponse, IdeClient};
24use rig::completion::ToolDefinition;
25use rig::tool::Tool;
26use serde::Deserialize;
27use serde_json::json;
28use std::path::PathBuf;
29use std::sync::Arc;
30use tokio::process::Command;
31use tokio::sync::Mutex;
32
33#[derive(Debug, Deserialize)]
34pub struct DiagnosticsArgs {
35 pub path: Option<String>,
37 pub include_warnings: Option<bool>,
39 pub limit: Option<usize>,
41}
42
43#[derive(Debug, thiserror::Error)]
44#[error("Diagnostics error: {0}")]
45pub struct DiagnosticsError(String);
46
47#[derive(Debug, Clone)]
48pub struct DiagnosticsTool {
49 project_path: PathBuf,
50 ide_client: Option<Arc<Mutex<IdeClient>>>,
52}
53
54impl DiagnosticsTool {
55 pub fn new(project_path: PathBuf) -> Self {
56 Self {
57 project_path,
58 ide_client: None,
59 }
60 }
61
62 pub fn with_ide_client(mut self, ide_client: Arc<Mutex<IdeClient>>) -> Self {
64 self.ide_client = Some(ide_client);
65 self
66 }
67
68 fn detect_project_type(&self) -> ProjectType {
70 let cargo_toml = self.project_path.join("Cargo.toml");
71 let package_json = self.project_path.join("package.json");
72 let go_mod = self.project_path.join("go.mod");
73 let pyproject_toml = self.project_path.join("pyproject.toml");
74 let requirements_txt = self.project_path.join("requirements.txt");
75
76 if cargo_toml.exists() {
77 ProjectType::Rust
78 } else if package_json.exists() {
79 ProjectType::JavaScript
80 } else if go_mod.exists() {
81 ProjectType::Go
82 } else if pyproject_toml.exists() || requirements_txt.exists() {
83 ProjectType::Python
84 } else {
85 ProjectType::Unknown
86 }
87 }
88
89 async fn get_ide_diagnostics(&self, file_path: Option<&str>) -> Option<DiagnosticsResponse> {
91 let client = self.ide_client.as_ref()?;
92 let guard = client.lock().await;
93
94 if !guard.is_connected() {
95 return None;
96 }
97
98 guard.get_diagnostics(file_path).await.ok()
99 }
100
101 async fn get_command_diagnostics(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
103 let project_type = self.detect_project_type();
104
105 match project_type {
106 ProjectType::Rust => self.run_cargo_check().await,
107 ProjectType::JavaScript => self.run_npm_lint().await,
108 ProjectType::Go => self.run_go_build().await,
109 ProjectType::Python => self.run_python_check().await,
110 ProjectType::Unknown => Ok(DiagnosticsResponse {
111 diagnostics: Vec::new(),
112 total_errors: 0,
113 total_warnings: 0,
114 }),
115 }
116 }
117
118 async fn run_cargo_check(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
120 let output = Command::new("cargo")
121 .args(["check", "--message-format=json"])
122 .current_dir(&self.project_path)
123 .output()
124 .await
125 .map_err(|e| DiagnosticsError(format!("Failed to run cargo check: {}", e)))?;
126
127 let stdout = String::from_utf8_lossy(&output.stdout);
128 let mut diagnostics = Vec::new();
129
130 for line in stdout.lines() {
131 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(line) {
132 if msg.get("reason").and_then(|r| r.as_str()) == Some("compiler-message") {
133 if let Some(message) = msg.get("message") {
134 if let Some(diag) = self.parse_cargo_message(message) {
135 diagnostics.push(diag);
136 }
137 }
138 }
139 }
140 }
141
142 let total_errors = diagnostics
143 .iter()
144 .filter(|d| d.severity == DiagnosticSeverity::Error)
145 .count() as u32;
146 let total_warnings = diagnostics
147 .iter()
148 .filter(|d| d.severity == DiagnosticSeverity::Warning)
149 .count() as u32;
150
151 Ok(DiagnosticsResponse {
152 diagnostics,
153 total_errors,
154 total_warnings,
155 })
156 }
157
158 fn parse_cargo_message(&self, message: &serde_json::Value) -> Option<Diagnostic> {
160 let level = message.get("level")?.as_str()?;
161 let msg = message.get("message")?.as_str()?;
162
163 let severity = match level {
164 "error" => DiagnosticSeverity::Error,
165 "warning" => DiagnosticSeverity::Warning,
166 "note" | "help" => DiagnosticSeverity::Hint,
167 _ => DiagnosticSeverity::Information,
168 };
169
170 let spans = message.get("spans")?.as_array()?;
172 let span = spans
173 .iter()
174 .find(|s| {
175 s.get("is_primary")
176 .and_then(|v| v.as_bool())
177 .unwrap_or(false)
178 })
179 .or_else(|| spans.first())?;
180
181 let file = span.get("file_name")?.as_str()?;
182 let line = span.get("line_start")?.as_u64()? as u32;
183 let column = span.get("column_start")?.as_u64()? as u32;
184 let end_line = span
185 .get("line_end")
186 .and_then(|v| v.as_u64())
187 .map(|v| v as u32);
188 let end_column = span
189 .get("column_end")
190 .and_then(|v| v.as_u64())
191 .map(|v| v as u32);
192
193 let code = message
194 .get("code")
195 .and_then(|c| c.get("code"))
196 .and_then(|c| c.as_str())
197 .map(|s| s.to_string());
198
199 Some(Diagnostic {
200 file: file.to_string(),
201 line,
202 column,
203 end_line,
204 end_column,
205 severity,
206 message: msg.to_string(),
207 source: Some("rustc".to_string()),
208 code,
209 })
210 }
211
212 async fn run_npm_lint(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
214 let output = Command::new("npm")
216 .args(["run", "lint", "--", "--format=json"])
217 .current_dir(&self.project_path)
218 .output()
219 .await;
220
221 if let Ok(output) = output {
222 if output.status.success() || !output.stdout.is_empty() {
223 let stdout = String::from_utf8_lossy(&output.stdout);
224 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
225 return Ok(self.parse_eslint_output(&results));
226 }
227 }
228 }
229
230 let output = Command::new("npx")
232 .args(["eslint", ".", "--format=json"])
233 .current_dir(&self.project_path)
234 .output()
235 .await
236 .map_err(|e| DiagnosticsError(format!("Failed to run eslint: {}", e)))?;
237
238 let stdout = String::from_utf8_lossy(&output.stdout);
239 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
240 return Ok(self.parse_eslint_output(&results));
241 }
242
243 Ok(DiagnosticsResponse {
245 diagnostics: Vec::new(),
246 total_errors: 0,
247 total_warnings: 0,
248 })
249 }
250
251 fn parse_eslint_output(&self, results: &[serde_json::Value]) -> DiagnosticsResponse {
253 let mut diagnostics = Vec::new();
254
255 for file_result in results {
256 let file = file_result
257 .get("filePath")
258 .and_then(|f| f.as_str())
259 .unwrap_or("");
260
261 if let Some(messages) = file_result.get("messages").and_then(|m| m.as_array()) {
262 for msg in messages {
263 let severity = match msg.get("severity").and_then(|s| s.as_u64()) {
264 Some(2) => DiagnosticSeverity::Error,
265 Some(1) => DiagnosticSeverity::Warning,
266 _ => DiagnosticSeverity::Information,
267 };
268
269 let message = msg
270 .get("message")
271 .and_then(|m| m.as_str())
272 .unwrap_or("")
273 .to_string();
274 let line = msg.get("line").and_then(|l| l.as_u64()).unwrap_or(1) as u32;
275 let column = msg.get("column").and_then(|c| c.as_u64()).unwrap_or(1) as u32;
276 let end_line = msg
277 .get("endLine")
278 .and_then(|l| l.as_u64())
279 .map(|v| v as u32);
280 let end_column = msg
281 .get("endColumn")
282 .and_then(|c| c.as_u64())
283 .map(|v| v as u32);
284 let code = msg
285 .get("ruleId")
286 .and_then(|r| r.as_str())
287 .map(|s| s.to_string());
288
289 diagnostics.push(Diagnostic {
290 file: file.to_string(),
291 line,
292 column,
293 end_line,
294 end_column,
295 severity,
296 message,
297 source: Some("eslint".to_string()),
298 code,
299 });
300 }
301 }
302 }
303
304 let total_errors = diagnostics
305 .iter()
306 .filter(|d| d.severity == DiagnosticSeverity::Error)
307 .count() as u32;
308 let total_warnings = diagnostics
309 .iter()
310 .filter(|d| d.severity == DiagnosticSeverity::Warning)
311 .count() as u32;
312
313 DiagnosticsResponse {
314 diagnostics,
315 total_errors,
316 total_warnings,
317 }
318 }
319
320 async fn run_go_build(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
322 let output = Command::new("go")
323 .args(["build", "-o", "/dev/null", "./..."])
324 .current_dir(&self.project_path)
325 .output()
326 .await
327 .map_err(|e| DiagnosticsError(format!("Failed to run go build: {}", e)))?;
328
329 let stderr = String::from_utf8_lossy(&output.stderr);
330 let mut diagnostics = Vec::new();
331
332 for line in stderr.lines() {
334 if let Some(diag) = self.parse_go_error(line) {
335 diagnostics.push(diag);
336 }
337 }
338
339 let total_errors = diagnostics
340 .iter()
341 .filter(|d| d.severity == DiagnosticSeverity::Error)
342 .count() as u32;
343 let total_warnings = diagnostics
344 .iter()
345 .filter(|d| d.severity == DiagnosticSeverity::Warning)
346 .count() as u32;
347
348 Ok(DiagnosticsResponse {
349 diagnostics,
350 total_errors,
351 total_warnings,
352 })
353 }
354
355 fn parse_go_error(&self, line: &str) -> Option<Diagnostic> {
357 let parts: Vec<&str> = line.splitn(4, ':').collect();
359 if parts.len() < 4 {
360 return None;
361 }
362
363 let file = parts[0].to_string();
364 let line_num = parts[1].parse::<u32>().ok()?;
365 let column = parts[2].parse::<u32>().ok()?;
366 let message = parts[3].trim().to_string();
367
368 Some(Diagnostic {
369 file,
370 line: line_num,
371 column,
372 end_line: None,
373 end_column: None,
374 severity: DiagnosticSeverity::Error,
375 message,
376 source: Some("go".to_string()),
377 code: None,
378 })
379 }
380
381 async fn run_python_check(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
383 let output = Command::new("pylint")
385 .args(["--output-format=json", "."])
386 .current_dir(&self.project_path)
387 .output()
388 .await;
389
390 if let Ok(output) = output {
391 let stdout = String::from_utf8_lossy(&output.stdout);
392 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
393 return Ok(self.parse_pylint_output(&results));
394 }
395 }
396
397 Ok(DiagnosticsResponse {
399 diagnostics: Vec::new(),
400 total_errors: 0,
401 total_warnings: 0,
402 })
403 }
404
405 fn parse_pylint_output(&self, results: &[serde_json::Value]) -> DiagnosticsResponse {
407 let mut diagnostics = Vec::new();
408
409 for msg in results {
410 let msg_type = msg.get("type").and_then(|t| t.as_str()).unwrap_or("");
411 let severity = match msg_type {
412 "error" | "fatal" => DiagnosticSeverity::Error,
413 "warning" => DiagnosticSeverity::Warning,
414 "convention" | "refactor" => DiagnosticSeverity::Hint,
415 _ => DiagnosticSeverity::Information,
416 };
417
418 let file = msg
419 .get("path")
420 .and_then(|p| p.as_str())
421 .unwrap_or("")
422 .to_string();
423 let line = msg.get("line").and_then(|l| l.as_u64()).unwrap_or(1) as u32;
424 let column = msg.get("column").and_then(|c| c.as_u64()).unwrap_or(1) as u32;
425 let message = msg
426 .get("message")
427 .and_then(|m| m.as_str())
428 .unwrap_or("")
429 .to_string();
430 let code = msg
431 .get("message-id")
432 .and_then(|m| m.as_str())
433 .map(|s| s.to_string());
434
435 diagnostics.push(Diagnostic {
436 file,
437 line,
438 column,
439 end_line: None,
440 end_column: None,
441 severity,
442 message,
443 source: Some("pylint".to_string()),
444 code,
445 });
446 }
447
448 let total_errors = diagnostics
449 .iter()
450 .filter(|d| d.severity == DiagnosticSeverity::Error)
451 .count() as u32;
452 let total_warnings = diagnostics
453 .iter()
454 .filter(|d| d.severity == DiagnosticSeverity::Warning)
455 .count() as u32;
456
457 DiagnosticsResponse {
458 diagnostics,
459 total_errors,
460 total_warnings,
461 }
462 }
463
464 fn filter_diagnostics(
466 &self,
467 mut response: DiagnosticsResponse,
468 include_warnings: bool,
469 limit: usize,
470 file_path: Option<&str>,
471 ) -> DiagnosticsResponse {
472 if let Some(path) = file_path {
474 response.diagnostics.retain(|d| d.file.contains(path));
475 }
476
477 if !include_warnings {
479 response
480 .diagnostics
481 .retain(|d| d.severity == DiagnosticSeverity::Error);
482 }
483
484 response.diagnostics.truncate(limit);
486
487 response.total_errors = response
489 .diagnostics
490 .iter()
491 .filter(|d| d.severity == DiagnosticSeverity::Error)
492 .count() as u32;
493 response.total_warnings = response
494 .diagnostics
495 .iter()
496 .filter(|d| d.severity == DiagnosticSeverity::Warning)
497 .count() as u32;
498
499 response
500 }
501}
502
/// Language ecosystem inferred from a project's root directory; selects which
/// command-line checker runs when no IDE diagnostics are available.
#[derive(Clone, Copy, Debug)]
enum ProjectType {
    /// Rust crate or workspace, checked with `cargo check`.
    Rust,
    /// Node.js project, linted with ESLint.
    JavaScript,
    /// Go module, checked with `go build`.
    Go,
    /// Python project, linted with pylint.
    Python,
    /// No recognised project marker; no checker is run.
    Unknown,
}
511
512impl Tool for DiagnosticsTool {
513 const NAME: &'static str = "diagnostics";
514
515 type Error = DiagnosticsError;
516 type Args = DiagnosticsArgs;
517 type Output = String;
518
519 async fn definition(&self, _prompt: String) -> ToolDefinition {
520 ToolDefinition {
521 name: Self::NAME.to_string(),
522 description: r#"Check for code errors, warnings, and linting issues.
523
524This tool queries language servers or runs language-specific commands to detect:
525- Compilation errors
526- Type errors
527- Syntax errors
528- Linting warnings
529- Best practice violations
530
531Use this tool after writing or modifying code to verify there are no errors.
532
533The tool automatically detects the project type and uses appropriate checking:
534- Rust: Uses rust-analyzer or cargo check
535- JavaScript/TypeScript: Uses ESLint or TypeScript compiler
536- Go: Uses gopls or go build
537- Python: Uses pylint or pyright
538
539Returns a list of diagnostics with file locations, severity, and messages."#
540 .to_string(),
541 parameters: json!({
542 "type": "object",
543 "properties": {
544 "path": {
545 "type": "string",
546 "description": "Optional file path to check. If not provided, checks all files in the project."
547 },
548 "include_warnings": {
549 "type": "boolean",
550 "description": "Whether to include warnings in addition to errors (default: true)"
551 },
552 "limit": {
553 "type": "integer",
554 "description": "Maximum number of diagnostics to return (default: 50)"
555 }
556 }
557 }),
558 }
559 }
560
561 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
562 let include_warnings = args.include_warnings.unwrap_or(true);
563 let limit = args.limit.unwrap_or(50);
564 let file_path = args.path.as_deref();
565
566 let response = if let Some(ide_response) = self.get_ide_diagnostics(file_path).await {
568 ide_response
569 } else {
570 self.get_command_diagnostics().await?
572 };
573
574 let filtered = self.filter_diagnostics(response, include_warnings, limit, file_path);
576
577 let result = if filtered.diagnostics.is_empty() {
579 json!({
580 "success": true,
581 "message": "No errors or warnings found",
582 "total_errors": 0,
583 "total_warnings": 0,
584 "diagnostics": []
585 })
586 } else {
587 let formatted_diagnostics: Vec<serde_json::Value> = filtered
588 .diagnostics
589 .iter()
590 .map(|d| {
591 json!({
592 "file": d.file,
593 "line": d.line,
594 "column": d.column,
595 "severity": d.severity.as_str(),
596 "message": d.message,
597 "source": d.source,
598 "code": d.code
599 })
600 })
601 .collect();
602
603 json!({
604 "success": filtered.total_errors == 0,
605 "total_errors": filtered.total_errors,
606 "total_warnings": filtered.total_warnings,
607 "diagnostics": formatted_diagnostics
608 })
609 };
610
611 serde_json::to_string_pretty(&result)
612 .map_err(|e| DiagnosticsError(format!("Failed to serialize: {}", e)))
613 }
614}
615
/// Unit tests for construction, project detection, Go parsing, and filtering.
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Builds a diagnostic at 1:1 with no span end, source, or code.
    fn sample(file: &str, severity: DiagnosticSeverity, message: &str) -> Diagnostic {
        Diagnostic {
            file: file.to_string(),
            line: 1,
            column: 1,
            end_line: None,
            end_column: None,
            severity,
            message: message.to_string(),
            source: None,
            code: None,
        }
    }

    #[tokio::test]
    async fn test_diagnostics_tool_creation() {
        let root = PathBuf::from(".");
        let tool = DiagnosticsTool::new(root.clone());
        assert_eq!(tool.project_path, root);
    }

    #[test]
    fn test_project_type_detection() {
        // Tests run from the crate root, which contains Cargo.toml.
        let detected = DiagnosticsTool::new(env::current_dir().unwrap()).detect_project_type();
        assert!(matches!(detected, ProjectType::Rust));
    }

    #[test]
    fn test_parse_go_error() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let parsed = tool
            .parse_go_error("main.go:10:5: undefined: foo")
            .expect("well-formed go error line should parse");
        assert_eq!(parsed.file, "main.go");
        assert_eq!((parsed.line, parsed.column), (10, 5));
        assert_eq!(parsed.message, "undefined: foo");
    }

    #[test]
    fn test_filter_diagnostics() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let response = DiagnosticsResponse {
            diagnostics: vec![
                sample("src/main.rs", DiagnosticSeverity::Error, "error"),
                sample("src/lib.rs", DiagnosticSeverity::Warning, "warning"),
            ],
            total_errors: 1,
            total_warnings: 1,
        };

        let errors_only = tool.filter_diagnostics(response.clone(), false, 50, None);
        assert_eq!(errors_only.diagnostics.len(), 1);
        assert_eq!(errors_only.total_errors, 1);
        assert_eq!(errors_only.total_warnings, 0);

        let by_path = tool.filter_diagnostics(response, true, 50, Some("main.rs"));
        assert_eq!(by_path.diagnostics.len(), 1);
        assert_eq!(by_path.diagnostics[0].file, "src/main.rs");
    }
}
692}