1use crate::agent::ide::{Diagnostic, DiagnosticSeverity, DiagnosticsResponse, IdeClient};
24use rig::completion::ToolDefinition;
25use rig::tool::Tool;
26use serde::Deserialize;
27use serde_json::json;
28use std::path::PathBuf;
29use std::sync::Arc;
30use tokio::process::Command;
31use tokio::sync::Mutex;
32
33#[derive(Debug, Deserialize)]
34pub struct DiagnosticsArgs {
35 pub path: Option<String>,
37 pub include_warnings: Option<bool>,
39 pub limit: Option<usize>,
41}
42
43#[derive(Debug, thiserror::Error)]
44#[error("Diagnostics error: {0}")]
45pub struct DiagnosticsError(String);
46
47#[derive(Debug, Clone)]
48pub struct DiagnosticsTool {
49 project_path: PathBuf,
50 ide_client: Option<Arc<Mutex<IdeClient>>>,
52}
53
54impl DiagnosticsTool {
55 pub fn new(project_path: PathBuf) -> Self {
56 Self {
57 project_path,
58 ide_client: None,
59 }
60 }
61
62 pub fn with_ide_client(mut self, ide_client: Arc<Mutex<IdeClient>>) -> Self {
64 self.ide_client = Some(ide_client);
65 self
66 }
67
68 fn detect_project_type(&self) -> ProjectType {
70 let cargo_toml = self.project_path.join("Cargo.toml");
71 let package_json = self.project_path.join("package.json");
72 let go_mod = self.project_path.join("go.mod");
73 let pyproject_toml = self.project_path.join("pyproject.toml");
74 let requirements_txt = self.project_path.join("requirements.txt");
75
76 if cargo_toml.exists() {
77 ProjectType::Rust
78 } else if package_json.exists() {
79 ProjectType::JavaScript
80 } else if go_mod.exists() {
81 ProjectType::Go
82 } else if pyproject_toml.exists() || requirements_txt.exists() {
83 ProjectType::Python
84 } else {
85 ProjectType::Unknown
86 }
87 }
88
89 async fn get_ide_diagnostics(&self, file_path: Option<&str>) -> Option<DiagnosticsResponse> {
91 let client = self.ide_client.as_ref()?;
92 let guard = client.lock().await;
93
94 if !guard.is_connected() {
95 return None;
96 }
97
98 guard.get_diagnostics(file_path).await.ok()
99 }
100
101 async fn get_command_diagnostics(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
103 let project_type = self.detect_project_type();
104
105 match project_type {
106 ProjectType::Rust => self.run_cargo_check().await,
107 ProjectType::JavaScript => self.run_npm_lint().await,
108 ProjectType::Go => self.run_go_build().await,
109 ProjectType::Python => self.run_python_check().await,
110 ProjectType::Unknown => Ok(DiagnosticsResponse {
111 diagnostics: Vec::new(),
112 total_errors: 0,
113 total_warnings: 0,
114 }),
115 }
116 }
117
118 async fn run_cargo_check(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
120 let output = Command::new("cargo")
121 .args(["check", "--message-format=json"])
122 .current_dir(&self.project_path)
123 .output()
124 .await
125 .map_err(|e| DiagnosticsError(format!("Failed to run cargo check: {}", e)))?;
126
127 let stdout = String::from_utf8_lossy(&output.stdout);
128 let mut diagnostics = Vec::new();
129
130 for line in stdout.lines() {
131 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(line)
132 && msg.get("reason").and_then(|r| r.as_str()) == Some("compiler-message")
133 && let Some(message) = msg.get("message")
134 && let Some(diag) = self.parse_cargo_message(message)
135 {
136 diagnostics.push(diag);
137 }
138 }
139
140 let total_errors = diagnostics
141 .iter()
142 .filter(|d| d.severity == DiagnosticSeverity::Error)
143 .count() as u32;
144 let total_warnings = diagnostics
145 .iter()
146 .filter(|d| d.severity == DiagnosticSeverity::Warning)
147 .count() as u32;
148
149 Ok(DiagnosticsResponse {
150 diagnostics,
151 total_errors,
152 total_warnings,
153 })
154 }
155
156 fn parse_cargo_message(&self, message: &serde_json::Value) -> Option<Diagnostic> {
158 let level = message.get("level")?.as_str()?;
159 let msg = message.get("message")?.as_str()?;
160
161 let severity = match level {
162 "error" => DiagnosticSeverity::Error,
163 "warning" => DiagnosticSeverity::Warning,
164 "note" | "help" => DiagnosticSeverity::Hint,
165 _ => DiagnosticSeverity::Information,
166 };
167
168 let spans = message.get("spans")?.as_array()?;
170 let span = spans
171 .iter()
172 .find(|s| {
173 s.get("is_primary")
174 .and_then(|v| v.as_bool())
175 .unwrap_or(false)
176 })
177 .or_else(|| spans.first())?;
178
179 let file = span.get("file_name")?.as_str()?;
180 let line = span.get("line_start")?.as_u64()? as u32;
181 let column = span.get("column_start")?.as_u64()? as u32;
182 let end_line = span
183 .get("line_end")
184 .and_then(|v| v.as_u64())
185 .map(|v| v as u32);
186 let end_column = span
187 .get("column_end")
188 .and_then(|v| v.as_u64())
189 .map(|v| v as u32);
190
191 let code = message
192 .get("code")
193 .and_then(|c| c.get("code"))
194 .and_then(|c| c.as_str())
195 .map(|s| s.to_string());
196
197 Some(Diagnostic {
198 file: file.to_string(),
199 line,
200 column,
201 end_line,
202 end_column,
203 severity,
204 message: msg.to_string(),
205 source: Some("rustc".to_string()),
206 code,
207 })
208 }
209
210 async fn run_npm_lint(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
212 let output = Command::new("npm")
214 .args(["run", "lint", "--", "--format=json"])
215 .current_dir(&self.project_path)
216 .output()
217 .await;
218
219 if let Ok(output) = output
220 && (output.status.success() || !output.stdout.is_empty())
221 {
222 let stdout = String::from_utf8_lossy(&output.stdout);
223 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
224 return Ok(self.parse_eslint_output(&results));
225 }
226 }
227
228 let output = Command::new("npx")
230 .args(["eslint", ".", "--format=json"])
231 .current_dir(&self.project_path)
232 .output()
233 .await
234 .map_err(|e| DiagnosticsError(format!("Failed to run eslint: {}", e)))?;
235
236 let stdout = String::from_utf8_lossy(&output.stdout);
237 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
238 return Ok(self.parse_eslint_output(&results));
239 }
240
241 Ok(DiagnosticsResponse {
243 diagnostics: Vec::new(),
244 total_errors: 0,
245 total_warnings: 0,
246 })
247 }
248
249 fn parse_eslint_output(&self, results: &[serde_json::Value]) -> DiagnosticsResponse {
251 let mut diagnostics = Vec::new();
252
253 for file_result in results {
254 let file = file_result
255 .get("filePath")
256 .and_then(|f| f.as_str())
257 .unwrap_or("");
258
259 if let Some(messages) = file_result.get("messages").and_then(|m| m.as_array()) {
260 for msg in messages {
261 let severity = match msg.get("severity").and_then(|s| s.as_u64()) {
262 Some(2) => DiagnosticSeverity::Error,
263 Some(1) => DiagnosticSeverity::Warning,
264 _ => DiagnosticSeverity::Information,
265 };
266
267 let message = msg
268 .get("message")
269 .and_then(|m| m.as_str())
270 .unwrap_or("")
271 .to_string();
272 let line = msg.get("line").and_then(|l| l.as_u64()).unwrap_or(1) as u32;
273 let column = msg.get("column").and_then(|c| c.as_u64()).unwrap_or(1) as u32;
274 let end_line = msg
275 .get("endLine")
276 .and_then(|l| l.as_u64())
277 .map(|v| v as u32);
278 let end_column = msg
279 .get("endColumn")
280 .and_then(|c| c.as_u64())
281 .map(|v| v as u32);
282 let code = msg
283 .get("ruleId")
284 .and_then(|r| r.as_str())
285 .map(|s| s.to_string());
286
287 diagnostics.push(Diagnostic {
288 file: file.to_string(),
289 line,
290 column,
291 end_line,
292 end_column,
293 severity,
294 message,
295 source: Some("eslint".to_string()),
296 code,
297 });
298 }
299 }
300 }
301
302 let total_errors = diagnostics
303 .iter()
304 .filter(|d| d.severity == DiagnosticSeverity::Error)
305 .count() as u32;
306 let total_warnings = diagnostics
307 .iter()
308 .filter(|d| d.severity == DiagnosticSeverity::Warning)
309 .count() as u32;
310
311 DiagnosticsResponse {
312 diagnostics,
313 total_errors,
314 total_warnings,
315 }
316 }
317
318 async fn run_go_build(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
320 let output = Command::new("go")
321 .args(["build", "-o", "/dev/null", "./..."])
322 .current_dir(&self.project_path)
323 .output()
324 .await
325 .map_err(|e| DiagnosticsError(format!("Failed to run go build: {}", e)))?;
326
327 let stderr = String::from_utf8_lossy(&output.stderr);
328 let mut diagnostics = Vec::new();
329
330 for line in stderr.lines() {
332 if let Some(diag) = self.parse_go_error(line) {
333 diagnostics.push(diag);
334 }
335 }
336
337 let total_errors = diagnostics
338 .iter()
339 .filter(|d| d.severity == DiagnosticSeverity::Error)
340 .count() as u32;
341 let total_warnings = diagnostics
342 .iter()
343 .filter(|d| d.severity == DiagnosticSeverity::Warning)
344 .count() as u32;
345
346 Ok(DiagnosticsResponse {
347 diagnostics,
348 total_errors,
349 total_warnings,
350 })
351 }
352
353 fn parse_go_error(&self, line: &str) -> Option<Diagnostic> {
355 let parts: Vec<&str> = line.splitn(4, ':').collect();
357 if parts.len() < 4 {
358 return None;
359 }
360
361 let file = parts[0].to_string();
362 let line_num = parts[1].parse::<u32>().ok()?;
363 let column = parts[2].parse::<u32>().ok()?;
364 let message = parts[3].trim().to_string();
365
366 Some(Diagnostic {
367 file,
368 line: line_num,
369 column,
370 end_line: None,
371 end_column: None,
372 severity: DiagnosticSeverity::Error,
373 message,
374 source: Some("go".to_string()),
375 code: None,
376 })
377 }
378
379 async fn run_python_check(&self) -> Result<DiagnosticsResponse, DiagnosticsError> {
381 let output = Command::new("pylint")
383 .args(["--output-format=json", "."])
384 .current_dir(&self.project_path)
385 .output()
386 .await;
387
388 if let Ok(output) = output {
389 let stdout = String::from_utf8_lossy(&output.stdout);
390 if let Ok(results) = serde_json::from_str::<Vec<serde_json::Value>>(&stdout) {
391 return Ok(self.parse_pylint_output(&results));
392 }
393 }
394
395 Ok(DiagnosticsResponse {
397 diagnostics: Vec::new(),
398 total_errors: 0,
399 total_warnings: 0,
400 })
401 }
402
403 fn parse_pylint_output(&self, results: &[serde_json::Value]) -> DiagnosticsResponse {
405 let mut diagnostics = Vec::new();
406
407 for msg in results {
408 let msg_type = msg.get("type").and_then(|t| t.as_str()).unwrap_or("");
409 let severity = match msg_type {
410 "error" | "fatal" => DiagnosticSeverity::Error,
411 "warning" => DiagnosticSeverity::Warning,
412 "convention" | "refactor" => DiagnosticSeverity::Hint,
413 _ => DiagnosticSeverity::Information,
414 };
415
416 let file = msg
417 .get("path")
418 .and_then(|p| p.as_str())
419 .unwrap_or("")
420 .to_string();
421 let line = msg.get("line").and_then(|l| l.as_u64()).unwrap_or(1) as u32;
422 let column = msg.get("column").and_then(|c| c.as_u64()).unwrap_or(1) as u32;
423 let message = msg
424 .get("message")
425 .and_then(|m| m.as_str())
426 .unwrap_or("")
427 .to_string();
428 let code = msg
429 .get("message-id")
430 .and_then(|m| m.as_str())
431 .map(|s| s.to_string());
432
433 diagnostics.push(Diagnostic {
434 file,
435 line,
436 column,
437 end_line: None,
438 end_column: None,
439 severity,
440 message,
441 source: Some("pylint".to_string()),
442 code,
443 });
444 }
445
446 let total_errors = diagnostics
447 .iter()
448 .filter(|d| d.severity == DiagnosticSeverity::Error)
449 .count() as u32;
450 let total_warnings = diagnostics
451 .iter()
452 .filter(|d| d.severity == DiagnosticSeverity::Warning)
453 .count() as u32;
454
455 DiagnosticsResponse {
456 diagnostics,
457 total_errors,
458 total_warnings,
459 }
460 }
461
462 fn filter_diagnostics(
464 &self,
465 mut response: DiagnosticsResponse,
466 include_warnings: bool,
467 limit: usize,
468 file_path: Option<&str>,
469 ) -> DiagnosticsResponse {
470 if let Some(path) = file_path {
472 response.diagnostics.retain(|d| d.file.contains(path));
473 }
474
475 if !include_warnings {
477 response
478 .diagnostics
479 .retain(|d| d.severity == DiagnosticSeverity::Error);
480 }
481
482 response.diagnostics.truncate(limit);
484
485 response.total_errors = response
487 .diagnostics
488 .iter()
489 .filter(|d| d.severity == DiagnosticSeverity::Error)
490 .count() as u32;
491 response.total_warnings = response
492 .diagnostics
493 .iter()
494 .filter(|d| d.severity == DiagnosticSeverity::Warning)
495 .count() as u32;
496
497 response
498 }
499}
500
/// Ecosystem inferred from marker files in the project root; selects which
/// command-line checker is run.
#[derive(Copy, Clone, Debug)]
enum ProjectType {
    /// `Cargo.toml` present.
    Rust,
    /// `package.json` present.
    JavaScript,
    /// `go.mod` present.
    Go,
    /// `pyproject.toml` or `requirements.txt` present.
    Python,
    /// No recognised marker file.
    Unknown,
}
509
510impl Tool for DiagnosticsTool {
511 const NAME: &'static str = "diagnostics";
512
513 type Error = DiagnosticsError;
514 type Args = DiagnosticsArgs;
515 type Output = String;
516
517 async fn definition(&self, _prompt: String) -> ToolDefinition {
518 ToolDefinition {
519 name: Self::NAME.to_string(),
520 description: r#"Check for code errors, warnings, and linting issues.
521
522This tool queries language servers or runs language-specific commands to detect:
523- Compilation errors
524- Type errors
525- Syntax errors
526- Linting warnings
527- Best practice violations
528
529Use this tool after writing or modifying code to verify there are no errors.
530
531The tool automatically detects the project type and uses appropriate checking:
532- Rust: Uses rust-analyzer or cargo check
533- JavaScript/TypeScript: Uses ESLint or TypeScript compiler
534- Go: Uses gopls or go build
535- Python: Uses pylint or pyright
536
537Returns a list of diagnostics with file locations, severity, and messages."#
538 .to_string(),
539 parameters: json!({
540 "type": "object",
541 "properties": {
542 "path": {
543 "type": "string",
544 "description": "Optional file path to check. If not provided, checks all files in the project."
545 },
546 "include_warnings": {
547 "type": "boolean",
548 "description": "Whether to include warnings in addition to errors (default: true)"
549 },
550 "limit": {
551 "type": "integer",
552 "description": "Maximum number of diagnostics to return (default: 50)"
553 }
554 }
555 }),
556 }
557 }
558
559 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
560 let include_warnings = args.include_warnings.unwrap_or(true);
561 let limit = args.limit.unwrap_or(50);
562 let file_path = args.path.as_deref();
563
564 let response = if let Some(ide_response) = self.get_ide_diagnostics(file_path).await {
566 ide_response
567 } else {
568 self.get_command_diagnostics().await?
570 };
571
572 let filtered = self.filter_diagnostics(response, include_warnings, limit, file_path);
574
575 let result = if filtered.diagnostics.is_empty() {
577 json!({
578 "success": true,
579 "message": "No errors or warnings found",
580 "total_errors": 0,
581 "total_warnings": 0,
582 "diagnostics": []
583 })
584 } else {
585 let formatted_diagnostics: Vec<serde_json::Value> = filtered
586 .diagnostics
587 .iter()
588 .map(|d| {
589 json!({
590 "file": d.file,
591 "line": d.line,
592 "column": d.column,
593 "severity": d.severity.as_str(),
594 "message": d.message,
595 "source": d.source,
596 "code": d.code
597 })
598 })
599 .collect();
600
601 json!({
602 "success": filtered.total_errors == 0,
603 "total_errors": filtered.total_errors,
604 "total_warnings": filtered.total_warnings,
605 "diagnostics": formatted_diagnostics
606 })
607 };
608
609 serde_json::to_string_pretty(&result)
610 .map_err(|e| DiagnosticsError(format!("Failed to serialize: {}", e)))
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Builds a diagnostic at `file:1:1` with the given severity and message.
    fn make_diag(file: &str, severity: DiagnosticSeverity, message: &str) -> Diagnostic {
        Diagnostic {
            file: file.to_string(),
            line: 1,
            column: 1,
            end_line: None,
            end_column: None,
            severity,
            message: message.to_string(),
            source: None,
            code: None,
        }
    }

    #[tokio::test]
    async fn test_diagnostics_tool_creation() {
        let project = PathBuf::from(".");
        let tool = DiagnosticsTool::new(project.clone());
        assert_eq!(tool.project_path, project);
    }

    #[test]
    fn test_project_type_detection() {
        // `cargo test` runs from the package root, which holds a Cargo.toml.
        let cwd = env::current_dir().unwrap();
        let detected = DiagnosticsTool::new(cwd).detect_project_type();
        assert!(matches!(detected, ProjectType::Rust));
    }

    #[test]
    fn test_parse_go_error() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let diag = tool
            .parse_go_error("main.go:10:5: undefined: foo")
            .expect("well-formed go error line should parse");
        assert_eq!(diag.file, "main.go");
        assert_eq!((diag.line, diag.column), (10, 5));
        assert_eq!(diag.message, "undefined: foo");
    }

    #[test]
    fn test_filter_diagnostics() {
        let tool = DiagnosticsTool::new(PathBuf::from("."));
        let response = DiagnosticsResponse {
            diagnostics: vec![
                make_diag("src/main.rs", DiagnosticSeverity::Error, "error"),
                make_diag("src/lib.rs", DiagnosticSeverity::Warning, "warning"),
            ],
            total_errors: 1,
            total_warnings: 1,
        };

        // Errors only: the warning is dropped and the totals follow suit.
        let errors_only = tool.filter_diagnostics(response.clone(), false, 50, None);
        assert_eq!(errors_only.diagnostics.len(), 1);
        assert_eq!(errors_only.total_errors, 1);
        assert_eq!(errors_only.total_warnings, 0);

        // Path filter keeps only the diagnostics for the requested file.
        let by_path = tool.filter_diagnostics(response, true, 50, Some("main.rs"));
        assert_eq!(by_path.diagnostics.len(), 1);
        assert_eq!(by_path.diagnostics[0].file, "src/main.rs");
    }
}